mirror of
https://github.com/karpathy/nanochat.git
synced 2026-03-25 14:15:21 +00:00
Autoresearch round 2: smear, backout, and hyperparameter tuning
New architectural features: - Smear: mix previous token embedding into current position via learned gate, providing cheap bigram-like info (works in training + KV cache) - Backout: subtract learned fraction of mid-layer residual before logit projection to remove low-level features Hyperparameter tuning: - Muon momentum warmdown 0.97→0.90 during LR warmdown phase - Non-uniform per-layer init: resid_lambdas 1.15→1.05, x0_lambdas 0.20→0.05 - c_fc init scale 0.4x, QK norm scale 1.2, sliding window seq_len/4 - Speedrun data:params ratio reduced to 8 Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
f068604948
commit
a825e63f81
|
|
@ -100,10 +100,13 @@ class KVCache:
|
|||
self.v_cache = torch.zeros(num_layers, batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype)
|
||||
# Current sequence length per batch element (FA3 needs int32)
|
||||
self.cache_seqlens = torch.zeros(batch_size, dtype=torch.int32, device=device)
|
||||
# Previous token's normalized embedding for smear (set by model forward pass)
|
||||
self.prev_embedding = None
|
||||
|
||||
def reset(self):
|
||||
"""Reset cache to empty state."""
|
||||
self.cache_seqlens.zero_()
|
||||
self.prev_embedding = None
|
||||
|
||||
def get_pos(self):
|
||||
"""Get current position (assumes all batch elements at same position)."""
|
||||
|
|
@ -129,6 +132,9 @@ class KVCache:
|
|||
self.k_cache[:, :, :other_pos, :, :] = other.k_cache[:, :, :other_pos, :, :]
|
||||
self.v_cache[:, :, :other_pos, :, :] = other.v_cache[:, :, :other_pos, :, :]
|
||||
self.cache_seqlens.fill_(other_pos)
|
||||
# Copy smear state: expand batch=1 prev_embedding to num_samples
|
||||
if other.prev_embedding is not None:
|
||||
self.prev_embedding = other.prev_embedding.expand(self.batch_size, -1, -1).clone()
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@torch.inference_mode()
|
||||
|
|
|
|||
|
|
@ -34,7 +34,7 @@ class GPTConfig:
|
|||
n_kv_head: int = 6 # number of key/value heads (GQA)
|
||||
n_embd: int = 768
|
||||
# Sliding window attention pattern string, tiled across layers. Final layer always L.
|
||||
# Characters: L=long (full context), S=short (half context)
|
||||
# Characters: L=long (full context), S=short (quarter context)
|
||||
# Examples: "L"=all full context, "SL"=alternating, "SSL"=two short then one long
|
||||
window_pattern: str = "SSSL"
|
||||
|
||||
|
|
@ -98,8 +98,8 @@ class CausalSelfAttention(nn.Module):
|
|||
cos, sin = cos_sin
|
||||
q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin)
|
||||
q, k = norm(q), norm(k) # QK norm
|
||||
q = q * 1.15 # sharper attention (split scale between Q and K), TODO think through better
|
||||
k = k * 1.15
|
||||
q = q * 1.2 # sharper attention (split scale between Q and K), TODO think through better
|
||||
k = k * 1.2
|
||||
|
||||
# Flash Attention (FA3 on Hopper+, PyTorch SDPA fallback elsewhere)
|
||||
# window_size is (left, right) tuple: (N, 0) for causal, (-1, 0) for full context
|
||||
|
|
@ -179,6 +179,11 @@ class GPT(nn.Module):
|
|||
# Separate parameters so they can have different optimizer treatment
|
||||
self.resid_lambdas = nn.Parameter(torch.ones(config.n_layer)) # fake init, real init in init_weights()
|
||||
self.x0_lambdas = nn.Parameter(torch.zeros(config.n_layer)) # fake init, real init in init_weights()
|
||||
# Smear: mix previous token's embedding into current token (cheap bigram-like info)
|
||||
self.smear_gate = Linear(24, 1, bias=False)
|
||||
self.smear_lambda = nn.Parameter(torch.zeros(1))
|
||||
# Backout: subtract cached mid-layer residual before final norm to remove low-level features
|
||||
self.backout_lambda = nn.Parameter(0.2 * torch.ones(1))
|
||||
# Value embeddings (ResFormer-style): alternating layers, last layer always included
|
||||
head_dim = config.n_embd // config.n_head
|
||||
kv_dim = config.n_kv_head * head_dim
|
||||
|
|
@ -221,12 +226,17 @@ class GPT(nn.Module):
|
|||
torch.nn.init.uniform_(block.attn.c_k.weight, -s, s)
|
||||
torch.nn.init.uniform_(block.attn.c_v.weight, -s, s)
|
||||
torch.nn.init.zeros_(block.attn.c_proj.weight) # projections are zero
|
||||
torch.nn.init.uniform_(block.mlp.c_fc.weight, -s * 0.5, s * 0.5) # 0.5x init scale for c_fc
|
||||
torch.nn.init.uniform_(block.mlp.c_fc.weight, -s * 0.4, s * 0.4) # 0.4x init scale for c_fc
|
||||
torch.nn.init.zeros_(block.mlp.c_proj.weight)
|
||||
|
||||
# Per-layer scalars
|
||||
self.resid_lambdas.fill_(1.0) # 1.0 => typical residual connections at init
|
||||
self.x0_lambdas.fill_(0.1) # 0.1 => small initial weight for skip connection to input embedding
|
||||
# Per-layer resid init: stronger residual at early layers, weaker at deep layers
|
||||
n_layer = self.config.n_layer
|
||||
for i in range(n_layer):
|
||||
self.resid_lambdas.data[i] = 1.15 - (0.10 * i / max(n_layer - 1, 1))
|
||||
# Decaying x0 init: earlier layers get more input embedding blending
|
||||
for i in range(n_layer):
|
||||
self.x0_lambdas.data[i] = 0.20 - (0.15 * i / max(n_layer - 1, 1))
|
||||
|
||||
# Value embeddings (init like c_v: uniform with same std)
|
||||
for ve in self.value_embeds.values():
|
||||
|
|
@ -276,13 +286,13 @@ class GPT(nn.Module):
|
|||
- right: how many tokens after current position to attend to (0 for causal)
|
||||
|
||||
Pattern string is tiled across layers. Final layer always gets L (full context).
|
||||
Characters: L=long (full context), S=short (half context)
|
||||
Characters: L=long (full context), S=short (quarter context)
|
||||
"""
|
||||
pattern = config.window_pattern.upper()
|
||||
assert all(c in "SL" for c in pattern), f"Invalid window_pattern: {pattern}. Use only S and L."
|
||||
# Map characters to window sizes
|
||||
long_window = config.sequence_len
|
||||
short_window = -(-long_window // 3 // 128) * 128 # ceil to FA3 tile size (2048 -> 768)
|
||||
short_window = -(-long_window // 4 // 128) * 128 # ceil to FA3 tile size (2048 -> 768)
|
||||
char_to_window = {
|
||||
"L": (long_window, 0),
|
||||
"S": (short_window, 0),
|
||||
|
|
@ -315,7 +325,8 @@ class GPT(nn.Module):
|
|||
# Exclude non-matmul params: embeddings and per-layer scalars
|
||||
value_embeds_numel = sum(ve.weight.numel() for ve in self.value_embeds.values())
|
||||
nparams_exclude = (self.transformer.wte.weight.numel() + value_embeds_numel +
|
||||
self.resid_lambdas.numel() + self.x0_lambdas.numel())
|
||||
self.resid_lambdas.numel() + self.x0_lambdas.numel() +
|
||||
self.smear_gate.weight.numel() + self.smear_lambda.numel() + self.backout_lambda.numel())
|
||||
h, q, t = self.config.n_head, self.config.n_embd // self.config.n_head, self.config.sequence_len
|
||||
# Sum attention FLOPs per layer, accounting for sliding window
|
||||
attn_flops = 0
|
||||
|
|
@ -343,7 +354,7 @@ class GPT(nn.Module):
|
|||
value_embeds = sum(p.numel() for p in self.value_embeds.parameters())
|
||||
lm_head = sum(p.numel() for p in self.lm_head.parameters())
|
||||
transformer_matrices = sum(p.numel() for p in self.transformer.h.parameters())
|
||||
scalars = self.resid_lambdas.numel() + self.x0_lambdas.numel()
|
||||
scalars = self.resid_lambdas.numel() + self.x0_lambdas.numel() + self.smear_gate.weight.numel() + self.smear_lambda.numel() + self.backout_lambda.numel()
|
||||
total = wte + value_embeds + lm_head + transformer_matrices + scalars
|
||||
assert total == sum(p.numel() for p in self.parameters()), "Parameter count mismatch"
|
||||
return {
|
||||
|
|
@ -366,7 +377,8 @@ class GPT(nn.Module):
|
|||
lm_head_params = list(self.lm_head.parameters())
|
||||
resid_params = [self.resid_lambdas]
|
||||
x0_params = [self.x0_lambdas]
|
||||
assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params)
|
||||
smear_params = [self.smear_gate.weight, self.smear_lambda, self.backout_lambda]
|
||||
assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params) + len(smear_params)
|
||||
|
||||
# Scale the LR for the AdamW parameters by ∝1/√dmodel (tuned for 768 dim model)
|
||||
dmodel_lr_scale = (model_dim / 768) ** -0.5
|
||||
|
|
@ -380,6 +392,7 @@ class GPT(nn.Module):
|
|||
dict(kind='adamw', params=value_embeds_params, lr=embedding_lr * dmodel_lr_scale * 0.5, betas=(0.8, 0.995), eps=1e-10, weight_decay=0.01),
|
||||
dict(kind='adamw', params=resid_params, lr=scalar_lr * 0.01, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.05),
|
||||
dict(kind='adamw', params=x0_params, lr=scalar_lr, betas=(0.96, 0.95), eps=1e-10, weight_decay=0.0), # higher beta1 for x0
|
||||
dict(kind='adamw', params=smear_params, lr=0.2, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0),
|
||||
]
|
||||
# Muon groups (matrix params, grouped by shape for stacking)
|
||||
for shape in sorted({p.shape for p in matrix_params}):
|
||||
|
|
@ -406,15 +419,44 @@ class GPT(nn.Module):
|
|||
T0 = 0 if kv_cache is None else kv_cache.get_pos()
|
||||
cos_sin = self.cos[:, T0:T0+T], self.sin[:, T0:T0+T] # truncate cache to current sequence length
|
||||
|
||||
# Forward the trunk of the Transformer
|
||||
# Embed the tokens
|
||||
x = self.transformer.wte(idx) # embed current token
|
||||
x = x.to(COMPUTE_DTYPE) # ensure activations are in compute dtype (no-op usually, but active for fp16 code path)
|
||||
x = norm(x)
|
||||
|
||||
# Smear: mix previous token's embedding into current position (cheap bigram info)
|
||||
if kv_cache is None:
|
||||
# Training / naive generate: full sequence available, use fast slice
|
||||
assert T > 1, "Training forward pass should have T > 1"
|
||||
gate = self.smear_lambda.to(x.dtype) * torch.sigmoid(self.smear_gate(x[:, 1:, :24]))
|
||||
x = torch.cat([x[:, :1], x[:, 1:] + gate * x[:, :-1]], dim=1)
|
||||
else:
|
||||
# KV cache inference: read prev embedding from cache, store current for next step
|
||||
x_pre_smear = kv_cache.prev_embedding
|
||||
kv_cache.prev_embedding = x[:, -1:, :]
|
||||
if T > 1:
|
||||
# Prefill: apply smear to positions 1+, same as training
|
||||
gate = self.smear_lambda.to(x.dtype) * torch.sigmoid(self.smear_gate(x[:, 1:, :24]))
|
||||
x = torch.cat([x[:, :1], x[:, 1:] + gate * x[:, :-1]], dim=1)
|
||||
elif x_pre_smear is not None:
|
||||
# Decode: single token, use cached prev embedding
|
||||
gate = self.smear_lambda.to(x.dtype) * torch.sigmoid(self.smear_gate(x[:, :, :24]))
|
||||
x = x + gate * x_pre_smear
|
||||
|
||||
# Forward the trunk of the Transformer
|
||||
x0 = x # save initial normalized embedding for x0 residual
|
||||
n_layer = self.config.n_layer
|
||||
backout_layer = n_layer // 2 # cache at halfway point
|
||||
x_backout = None
|
||||
for i, block in enumerate(self.transformer.h):
|
||||
x = self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0
|
||||
ve = self.value_embeds[str(i)](idx).to(x.dtype) if str(i) in self.value_embeds else None
|
||||
x = block(x, ve, cos_sin, self.window_sizes[i], kv_cache)
|
||||
if i == backout_layer:
|
||||
x_backout = x
|
||||
# Subtract mid-layer residual to remove low-level features before logit projection
|
||||
if x_backout is not None:
|
||||
x = x - self.backout_lambda.to(x.dtype) * x_backout
|
||||
x = norm(x)
|
||||
|
||||
# Forward the lm_head (compute logits)
|
||||
|
|
|
|||
|
|
@ -69,8 +69,8 @@ python -m scripts.tok_eval
|
|||
echo "Waiting for dataset download to complete..."
|
||||
wait $DATASET_DOWNLOAD_PID
|
||||
|
||||
# d24 model (slightly undertrained to beat GPT-2 => decrease data:params ratio from compute optimal 10.5 (default) to 9.5)
|
||||
torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- --depth=24 --target-param-data-ratio=9.5 --device-batch-size=16 --fp8 --run=$WANDB_RUN
|
||||
# d24 model (slightly undertrained to beat GPT-2 => decrease data:params ratio from compute optimal 10.5 (default) to 8)
|
||||
torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- --depth=24 --target-param-data-ratio=8 --device-batch-size=16 --fp8 --run=$WANDB_RUN
|
||||
# evaluate the model: CORE metric, BPB on train/val, and draw samples
|
||||
torchrun --standalone --nproc_per_node=8 -m scripts.base_eval -- --device-batch-size=16
|
||||
|
||||
|
|
|
|||
|
|
@ -367,11 +367,18 @@ def get_lr_multiplier(it):
|
|||
progress = (num_iterations - it) / warmdown_iters
|
||||
return progress * 1.0 + (1 - progress) * args.final_lr_frac
|
||||
|
||||
# Momentum scheduler for Muon optimizer (warms up to 0.97 over the first 400 steps)
|
||||
# Momentum scheduler for Muon optimizer (warms up to 0.97, warms down to 0.90 during LR warmdown)
|
||||
def get_muon_momentum(it):
|
||||
frac = min(it / 400, 1)
|
||||
momentum = (1 - frac) * 0.85 + frac * 0.97
|
||||
return momentum
|
||||
warmdown_iters = round(args.warmdown_ratio * num_iterations)
|
||||
warmdown_start = num_iterations - warmdown_iters
|
||||
if it < 400:
|
||||
frac = it / 400
|
||||
return (1 - frac) * 0.85 + frac * 0.97
|
||||
elif it >= warmdown_start:
|
||||
progress = (it - warmdown_start) / warmdown_iters
|
||||
return 0.97 * (1 - progress) + 0.90 * progress
|
||||
else:
|
||||
return 0.97
|
||||
|
||||
# Weight decay scheduler for Muon optimizer (cosine decay to zero over the course of training)
|
||||
def get_weight_decay(it):
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user