Build Recurrent-Depth Transformers with OpenMythos for MLA, GQA, Sparse MoE, and Loop-Scaled Reasoning
In this tutorial, we explore OpenMythos by building an advanced recurrent-depth transformer workflow that runs end-to-end in Google Colab. We create both MLA and GQA model variants, compare their parameter counts, and check the stability of the recurrent injection matrix through its spectral radius. We then move from simple forward and generation tests into a synthetic compositional reasoning task, where the model learns to predict the sum of digit chains modulo a fixed value. Through this setup, we study how recurrent loops enable a single model to reuse its parameters for deeper computation.
import subprocess, sys
def pip(*args):
subprocess.run([sys.executable, "-m", "pip", "install", "-q", *args], check=False)
try:
import open_mythos # noqa: F401
except Exception:
pip("open-mythos")
try:
import open_mythos # noqa: F401
except Exception:
pip("git+https://github.com/kyegomez/OpenMythos.git")
import math, random, time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
from open_mythos.main import OpenMythos, MythosConfig
SEED = 42
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)
print(f"Device: {device} | Torch: {torch.__version__}")We install OpenMythos and fall back to the GitHub source if installing from PyPI fails. We import the required Python, PyTorch, NumPy, and plotting libraries for model building, training, and visualization. We also set a fixed random seed and use CUDA when available, so the tutorial runs efficiently in Colab.
def build_model(attn_type: str = "mla", max_loop_iters: int = 8) -> tuple:
"""Build a small OpenMythos model. Two attention variants supported.
MLA — Multi-Latent Attention (compressed KV cache, DeepSeek-V2 style)
GQA — Grouped-Query Attention (fewer KV heads than Q heads)
"""
base = dict(
vocab_size = 64,
dim = 128,
n_heads = 4,
max_seq_len = 32,
max_loop_iters = max_loop_iters,
prelude_layers = 1,
coda_layers = 1,
n_experts = 4,
n_shared_experts = 1,
n_experts_per_tok= 2,
expert_dim = 64,
lora_rank = 8,
attn_type = attn_type,
)
if attn_type == "gqa":
cfg = MythosConfig(**base, n_kv_heads=2)
else:
cfg = MythosConfig(
**base, n_kv_heads=4,
kv_lora_rank=32, q_lora_rank=32,
qk_rope_head_dim=16, qk_nope_head_dim=16, v_head_dim=16,
)
model = OpenMythos(cfg).to(device)
return model, cfg
model_mla, cfg_mla = build_model("mla")
model_gqa, cfg_gqa = build_model("gqa")
def n_params(m): return sum(p.numel() for p in m.parameters())
print(f"\n[MLA] params: {n_params(model_mla):>10,}")
print(f"[GQA] params: {n_params(model_gqa):>10,}")
def spectral_radius(model):
A = model.recurrent.injection.get_A().detach().cpu()
if A.dim() == 1:
rho = A.abs().max().item()
else:
rho = torch.linalg.eigvals(A.float()).abs().max().item()
return rho
print(f"\nρ(A) MLA: {spectral_radius(model_mla):.4f} (must be < 1)")
print(f"ρ(A) GQA: {spectral_radius(model_gqa):.4f} (must be < 1)")
ids = torch.randint(0, cfg_mla.vocab_size, (2, 16), device=device)
with torch.no_grad():
logits = model_mla(ids, n_loops=4)
gen = model_mla.generate(ids, max_new_tokens=4, n_loops=8)
print(f"\nForward logits shape: {tuple(logits.shape)}")
print(f"Generation shape: {tuple(gen.shape)}")We define a reusable model factory that builds small OpenMythos models with either MLA or GQA attention. We compare both variants by checking their parameter counts and the spectral radius of the recurrent injection matrix. We then run a quick forward pass and generation test to confirm that the MLA model produces logits and generated tokens correctly.
PAD, START, EQ = 0, 1, 2
DIGIT_BASE = 10
M = 7
SEQ_LEN = cfg_mla.max_seq_len
MIN_LEN, MAX_LEN = 2, 5
def make_example(chain_len: int):
digits = [random.randint(0, M-1) for _ in range(chain_len)]
target = sum(digits) % M
toks = [START] + [DIGIT_BASE + d for d in digits] + [EQ]
toks = toks + [PAD] * (SEQ_LEN - len(toks))
return toks[:SEQ_LEN], DIGIT_BASE + target
class ChainDataset(Dataset):
def __init__(self, n, lo, hi):
self.items = [make_example(random.randint(lo, hi)) for _ in range(n)]
def __len__(self): return len(self.items)
def __getitem__(self, i):
x, y = self.items[i]
return torch.tensor(x, dtype=torch.long), torch.tensor(y, dtype=torch.long)
train_loader = DataLoader(ChainDataset(3000, MIN_LEN, MAX_LEN), batch_size=64, shuffle=True)
test_loader = DataLoader(ChainDataset(400, MIN_LEN, MAX_LEN), batch_size=64)
ood_loader = DataLoader(ChainDataset(400, MAX_LEN+1, MAX_LEN+3), batch_size=64)We create a synthetic compositional task in which the model predicts the sum of digit tokens modulo 7. We define the token scheme, sequence structure, and dataset class that generates random digit-chain examples. We then build training, test, and out-of-distribution loaders to evaluate both normal performance and depth extrapolation.
model = model_mla
TRAIN_LOOPS = 4
EPOCHS = 6
opt = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.01)
sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=EPOCHS)
def loss_at_eq(logits, x, y):
"""Predict the answer at the position immediately after the EQ token."""
eq_pos = (x == EQ).int().argmax(dim=1)
pred = logits[torch.arange(x.size(0)), eq_pos]
return F.cross_entropy(pred, y), pred
train_losses = []
print("\n--- Training ---")
t0 = time.time()
for ep in range(EPOCHS):
model.train(); running = 0.0
for x, y in train_loader:
x, y = x.to(device), y.to(device)
logits = model(x, n_loops=TRAIN_LOOPS)
loss, _ = loss_at_eq(logits, x, y)
opt.zero_grad(); loss.backward()
opt.step()
running += loss.item()
sched.step()
train_losses.append(running / len(train_loader))
print(f" epoch {ep+1}/{EPOCHS} loss={train_losses[-1]:.4f} ρ(A)={spectral_radius(model):.3f}")
print(f"Train time: {time.time()-t0:.1f}s")
@torch.no_grad()
def accuracy(loader, n_loops):
model.eval(); correct = total = 0
for x, y in loader:
x, y = x.to(device), y.to(device)
logits = model(x, n_loops=n_loops)
_, pred = loss_at_eq(logits, x, y)
correct += (pred.argmax(-1) == y).sum().item()
total += y.size(0)
return correct / total
LOOP_GRID = [1, 2, 4, 6, 8]
print("\n--- Loop-count scaling (same weights, varying compute) ---")
in_dist_acc = [accuracy(test_loader, L) for L in LOOP_GRID]
ood_acc = [accuracy(ood_loader, L) for L in LOOP_GRID]
for L, a, o in zip(LOOP_GRID, in_dist_acc, ood_acc):
print(f" n_loops={L}: in-dist acc={a:.3f} OOD (longer chains) acc={o:.3f}")We train the MLA model with a fixed number of recurrent loops and optimize it with AdamW and a cosine learning rate schedule. We compute the loss at the EQ token position, clip gradients, track training loss, and monitor recurrent stability after each epoch. We then evaluate inference-time loop scaling by testing the same trained model with different loop counts.
fig, axes = plt.subplots(1, 2, figsize=(12, 4))
axes[0].plot(range(1, EPOCHS+1), train_losses, marker="o")
axes[0].set_title("Training loss"); axes[0].set_xlabel("epoch"); axes[0].set_ylabel("CE loss")
axes[0].grid(alpha=0.3)
axes[1].plot(LOOP_GRID, in_dist_acc, marker="s", label="in-distribution")
axes[1].plot(LOOP_GRID, ood_acc, marker="^", label="longer chains (OOD depth)")
axes[1].set_title("Inference-time loop scaling")
axes[1].set_xlabel("# recurrent loops at inference"); axes[1].set_ylabel("test accuracy")
axes[1].legend(); axes[1].grid(alpha=0.3)
plt.tight_layout(); plt.show()
chain_len = 4
toks, true_tok = make_example(chain_len)
digits = [t - DIGIT_BASE for t in toks if t >= DIGIT_BASE]
prompt = torch.tensor([toks], device=device)
with torch.no_grad():
gen = model.generate(prompt, max_new_tokens=1, n_loops=8)
predicted = gen[0, -1].item()
print(f"\nDemo: digits={digits}, target=({'+'.join(map(str, digits))}) % {M} = {sum(digits)%M}")
print(f" true token={true_tok} (digit {true_tok-DIGIT_BASE}) | "
f"predicted token={predicted} (digit {predicted-DIGIT_BASE if predicted>=DIGIT_BASE else '?'})")
print("\nDone. Key takeaway: at inference, increasing n_loops trades compute for")
print("reasoning depth on the same fixed-parameter model — that's the RDT premise.")We visualize the training loss curve and compare in-distribution accuracy with longer-chain out-of-distribution accuracy across loop counts. We also run a small qualitative generation example to inspect whether the trained model predicts the correct modulo-sum digit. We conclude by showing that increasing the number of recurrent loops gives the same fixed-parameter model more reasoning depth at inference time.
In conclusion, we understood how OpenMythos combines recurrent-depth transformer design, attention variants, sparse MoE components, and inference-time loop scaling into a compact experimental pipeline. We trained the model on a controlled reasoning task, evaluated it on both in-distribution and longer out-of-distribution chains, and visualized how accuracy changes as we increase the number of recurrent loops. It helped us see how recurrent depth can trade additional inference computation for stronger reasoning behavior without changing the model’s learned parameters.
Check out the Full Codes with Notebook. Also, feel free to follow us on Twitter and don’t forget to join our 150k+ ML SubReddit and Subscribe to our Newsletter. Wait! are you on telegram? now you can join us on telegram as well.
Need to partner with us for promoting your GitHub Repo OR Hugging Face Page OR Product Release OR Webinar etc.? Connect with us
The post Build Recurrent-Depth Transformers with OpenMythos for MLA, GQA, Sparse MoE, and Loop-Scaled Reasoning appeared first on MarkTechPost.
from MarkTechPost https://ift.tt/TD7ZlaI
via IFTTT

Comments
Post a Comment