Autoresearch round 2: smear, backout, and hyperparameter tuning

New architectural features:
- Smear: mix previous token embedding into current position via learned
  gate, providing cheap bigram-like info (works in training + KV cache)
- Backout: subtract learned fraction of mid-layer residual before logit
  projection to remove low-level features

Hyperparameter tuning:
- Muon momentum warmdown 0.97→0.90 during LR warmdown phase
- Non-uniform per-layer init: resid_lambdas 1.15→1.05, x0_lambdas 0.20→0.05
- c_fc init scale 0.4x, QK norm scale 1.2, sliding window seq_len/4
- Speedrun data:params ratio reduced to 8

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
Andrej Karpathy 2026-03-14 17:03:06 +00:00
parent f068604948
commit a825e63f81
4 changed files with 73 additions and 18 deletions

View File

@ -100,10 +100,13 @@ class KVCache:
self.v_cache = torch.zeros(num_layers, batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype)
# Current sequence length per batch element (FA3 needs int32)
self.cache_seqlens = torch.zeros(batch_size, dtype=torch.int32, device=device)
# Previous token's normalized embedding for smear (set by model forward pass)
self.prev_embedding = None
def reset(self):
"""Reset cache to empty state."""
self.cache_seqlens.zero_()
self.prev_embedding = None
def get_pos(self):
"""Get current position (assumes all batch elements at same position)."""
@ -129,6 +132,9 @@ class KVCache:
self.k_cache[:, :, :other_pos, :, :] = other.k_cache[:, :, :other_pos, :, :]
self.v_cache[:, :, :other_pos, :, :] = other.v_cache[:, :, :other_pos, :, :]
self.cache_seqlens.fill_(other_pos)
# Copy smear state: expand batch=1 prev_embedding to num_samples
if other.prev_embedding is not None:
self.prev_embedding = other.prev_embedding.expand(self.batch_size, -1, -1).clone()
# -----------------------------------------------------------------------------
@torch.inference_mode()

View File

@ -34,7 +34,7 @@ class GPTConfig:
n_kv_head: int = 6 # number of key/value heads (GQA)
n_embd: int = 768
# Sliding window attention pattern string, tiled across layers. Final layer always L.
# Characters: L=long (full context), S=short (half context)
# Characters: L=long (full context), S=short (quarter context)
# Examples: "L"=all full context, "SL"=alternating, "SSL"=two short then one long
window_pattern: str = "SSSL"
@ -98,8 +98,8 @@ class CausalSelfAttention(nn.Module):
cos, sin = cos_sin
q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin)
q, k = norm(q), norm(k) # QK norm
q = q * 1.15 # sharper attention (split scale between Q and K), TODO think through better
k = k * 1.15
q = q * 1.2 # sharper attention (split scale between Q and K), TODO think through better
k = k * 1.2
# Flash Attention (FA3 on Hopper+, PyTorch SDPA fallback elsewhere)
# window_size is (left, right) tuple: (N, 0) for causal, (-1, 0) for full context
@ -179,6 +179,11 @@ class GPT(nn.Module):
# Separate parameters so they can have different optimizer treatment
self.resid_lambdas = nn.Parameter(torch.ones(config.n_layer)) # fake init, real init in init_weights()
self.x0_lambdas = nn.Parameter(torch.zeros(config.n_layer)) # fake init, real init in init_weights()
# Smear: mix previous token's embedding into current token (cheap bigram-like info)
self.smear_gate = Linear(24, 1, bias=False)
self.smear_lambda = nn.Parameter(torch.zeros(1))
# Backout: subtract cached mid-layer residual before final norm to remove low-level features
self.backout_lambda = nn.Parameter(0.2 * torch.ones(1))
# Value embeddings (ResFormer-style): alternating layers, last layer always included
head_dim = config.n_embd // config.n_head
kv_dim = config.n_kv_head * head_dim
@ -221,12 +226,17 @@ class GPT(nn.Module):
torch.nn.init.uniform_(block.attn.c_k.weight, -s, s)
torch.nn.init.uniform_(block.attn.c_v.weight, -s, s)
torch.nn.init.zeros_(block.attn.c_proj.weight) # projections are zero
torch.nn.init.uniform_(block.mlp.c_fc.weight, -s * 0.5, s * 0.5) # 0.5x init scale for c_fc
torch.nn.init.uniform_(block.mlp.c_fc.weight, -s * 0.4, s * 0.4) # 0.4x init scale for c_fc
torch.nn.init.zeros_(block.mlp.c_proj.weight)
# Per-layer scalars
self.resid_lambdas.fill_(1.0) # 1.0 => typical residual connections at init
self.x0_lambdas.fill_(0.1) # 0.1 => small initial weight for skip connection to input embedding
# Per-layer resid init: stronger residual at early layers, weaker at deep layers
n_layer = self.config.n_layer
for i in range(n_layer):
self.resid_lambdas.data[i] = 1.15 - (0.10 * i / max(n_layer - 1, 1))
# Decaying x0 init: earlier layers get more input embedding blending
for i in range(n_layer):
self.x0_lambdas.data[i] = 0.20 - (0.15 * i / max(n_layer - 1, 1))
# Value embeddings (init like c_v: uniform with same std)
for ve in self.value_embeds.values():
@ -276,13 +286,13 @@ class GPT(nn.Module):
- right: how many tokens after current position to attend to (0 for causal)
Pattern string is tiled across layers. Final layer always gets L (full context).
Characters: L=long (full context), S=short (half context)
Characters: L=long (full context), S=short (quarter context)
"""
pattern = config.window_pattern.upper()
assert all(c in "SL" for c in pattern), f"Invalid window_pattern: {pattern}. Use only S and L."
# Map characters to window sizes
long_window = config.sequence_len
short_window = -(-long_window // 3 // 128) * 128 # ceil to FA3 tile size (2048 -> 768)
short_window = -(-long_window // 4 // 128) * 128 # ceil to FA3 tile size (2048 -> 768)
char_to_window = {
"L": (long_window, 0),
"S": (short_window, 0),
@ -315,7 +325,8 @@ class GPT(nn.Module):
# Exclude non-matmul params: embeddings and per-layer scalars
value_embeds_numel = sum(ve.weight.numel() for ve in self.value_embeds.values())
nparams_exclude = (self.transformer.wte.weight.numel() + value_embeds_numel +
self.resid_lambdas.numel() + self.x0_lambdas.numel())
self.resid_lambdas.numel() + self.x0_lambdas.numel() +
self.smear_gate.weight.numel() + self.smear_lambda.numel() + self.backout_lambda.numel())
h, q, t = self.config.n_head, self.config.n_embd // self.config.n_head, self.config.sequence_len
# Sum attention FLOPs per layer, accounting for sliding window
attn_flops = 0
@ -343,7 +354,7 @@ class GPT(nn.Module):
value_embeds = sum(p.numel() for p in self.value_embeds.parameters())
lm_head = sum(p.numel() for p in self.lm_head.parameters())
transformer_matrices = sum(p.numel() for p in self.transformer.h.parameters())
scalars = self.resid_lambdas.numel() + self.x0_lambdas.numel()
scalars = self.resid_lambdas.numel() + self.x0_lambdas.numel() + self.smear_gate.weight.numel() + self.smear_lambda.numel() + self.backout_lambda.numel()
total = wte + value_embeds + lm_head + transformer_matrices + scalars
assert total == sum(p.numel() for p in self.parameters()), "Parameter count mismatch"
return {
@ -366,7 +377,8 @@ class GPT(nn.Module):
lm_head_params = list(self.lm_head.parameters())
resid_params = [self.resid_lambdas]
x0_params = [self.x0_lambdas]
assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params)
smear_params = [self.smear_gate.weight, self.smear_lambda, self.backout_lambda]
assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params) + len(smear_params)
# Scale the LR for the AdamW parameters by ∝1/√dmodel (tuned for 768 dim model)
dmodel_lr_scale = (model_dim / 768) ** -0.5
@ -380,6 +392,7 @@ class GPT(nn.Module):
dict(kind='adamw', params=value_embeds_params, lr=embedding_lr * dmodel_lr_scale * 0.5, betas=(0.8, 0.995), eps=1e-10, weight_decay=0.01),
dict(kind='adamw', params=resid_params, lr=scalar_lr * 0.01, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.05),
dict(kind='adamw', params=x0_params, lr=scalar_lr, betas=(0.96, 0.95), eps=1e-10, weight_decay=0.0), # higher beta1 for x0
dict(kind='adamw', params=smear_params, lr=0.2, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0),
]
# Muon groups (matrix params, grouped by shape for stacking)
for shape in sorted({p.shape for p in matrix_params}):
@ -406,15 +419,44 @@ class GPT(nn.Module):
T0 = 0 if kv_cache is None else kv_cache.get_pos()
cos_sin = self.cos[:, T0:T0+T], self.sin[:, T0:T0+T] # truncate cache to current sequence length
# Forward the trunk of the Transformer
# Embed the tokens
x = self.transformer.wte(idx) # embed current token
x = x.to(COMPUTE_DTYPE) # ensure activations are in compute dtype (no-op usually, but active for fp16 code path)
x = norm(x)
# Smear: mix previous token's embedding into current position (cheap bigram info)
if kv_cache is None:
# Training / naive generate: full sequence available, use fast slice
assert T > 1, "Training forward pass should have T > 1"
gate = self.smear_lambda.to(x.dtype) * torch.sigmoid(self.smear_gate(x[:, 1:, :24]))
x = torch.cat([x[:, :1], x[:, 1:] + gate * x[:, :-1]], dim=1)
else:
# KV cache inference: read prev embedding from cache, store current for next step
x_pre_smear = kv_cache.prev_embedding
kv_cache.prev_embedding = x[:, -1:, :]
if T > 1:
# Prefill: apply smear to positions 1+, same as training
gate = self.smear_lambda.to(x.dtype) * torch.sigmoid(self.smear_gate(x[:, 1:, :24]))
x = torch.cat([x[:, :1], x[:, 1:] + gate * x[:, :-1]], dim=1)
elif x_pre_smear is not None:
# Decode: single token, use cached prev embedding
gate = self.smear_lambda.to(x.dtype) * torch.sigmoid(self.smear_gate(x[:, :, :24]))
x = x + gate * x_pre_smear
# Forward the trunk of the Transformer
x0 = x # save initial normalized embedding for x0 residual
n_layer = self.config.n_layer
backout_layer = n_layer // 2 # cache at halfway point
x_backout = None
for i, block in enumerate(self.transformer.h):
x = self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0
ve = self.value_embeds[str(i)](idx).to(x.dtype) if str(i) in self.value_embeds else None
x = block(x, ve, cos_sin, self.window_sizes[i], kv_cache)
if i == backout_layer:
x_backout = x
# Subtract mid-layer residual to remove low-level features before logit projection
if x_backout is not None:
x = x - self.backout_lambda.to(x.dtype) * x_backout
x = norm(x)
# Forward the lm_head (compute logits)

View File

@ -69,8 +69,8 @@ python -m scripts.tok_eval
echo "Waiting for dataset download to complete..."
wait $DATASET_DOWNLOAD_PID
# d24 model (slightly undertrained to beat GPT-2 => decrease data:params ratio from compute optimal 10.5 (default) to 9.5)
torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- --depth=24 --target-param-data-ratio=9.5 --device-batch-size=16 --fp8 --run=$WANDB_RUN
# d24 model (slightly undertrained to beat GPT-2 => decrease data:params ratio from compute optimal 10.5 (default) to 8)
torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- --depth=24 --target-param-data-ratio=8 --device-batch-size=16 --fp8 --run=$WANDB_RUN
# evaluate the model: CORE metric, BPB on train/val, and draw samples
torchrun --standalone --nproc_per_node=8 -m scripts.base_eval -- --device-batch-size=16

View File

@ -367,11 +367,18 @@ def get_lr_multiplier(it):
progress = (num_iterations - it) / warmdown_iters
return progress * 1.0 + (1 - progress) * args.final_lr_frac
# Momentum scheduler for Muon optimizer (warms up to 0.97 over the first 400 steps)
# Momentum scheduler for Muon optimizer (warms up to 0.97, warms down to 0.90 during LR warmdown)
def get_muon_momentum(it):
frac = min(it / 400, 1)
momentum = (1 - frac) * 0.85 + frac * 0.97
return momentum
warmdown_iters = round(args.warmdown_ratio * num_iterations)
warmdown_start = num_iterations - warmdown_iters
if it < 400:
frac = it / 400
return (1 - frac) * 0.85 + frac * 0.97
elif it >= warmdown_start:
progress = (it - warmdown_start) / warmdown_iters
return 0.97 * (1 - progress) + 0.90 * progress
else:
return 0.97
# Weight decay scheduler for Muon optimizer (cosine decay to zero over the course of training)
def get_weight_decay(it):