diff --git a/README.md b/README.md index 1fed675..79b12df 100644 --- a/README.md +++ b/README.md @@ -19,6 +19,7 @@ Presently, the main focus of development is on tuning the pretraining stage, whi | 3 | 2.76 | 0.74645 | 0.2602 | bump total batch size to 1M tokens | Feb 5 2026 | 2c062aa | @karpathy | | 4 | 2.02 | 0.71854 | 0.2571 | change dataset to NVIDIA ClimbMix | Mar 4 2026 | 324e69c | @ddudek @karpathy | | 5 | 1.80 | 0.71808 | 0.2690 | autoresearch [round 1](https://x.com/karpathy/status/2031135152349524125) | Mar 9 2026 | 6ed7d1d | @karpathy | +| 6 | 1.65 | 0.71800 | 0.2626 | autoresearch round 2 | Mar 14 2026 | a825e63 | @karpathy | The primary metric we care about is "time to GPT-2" - the wall clock time needed to outperform the GPT-2 (1.6B) CORE metric on an 8XH100 GPU node. The GPT-2 CORE score is 0.256525. In 2019, the training of GPT-2 cost approximately $43,000, so it is incredible that, due to many advances over 7 years across the stack, we can now do it so much faster and for well below $100 (e.g. at the current ~$3/GPU/hr, an 8XH100 node is ~$24/hr, so 2 hours is ~$48). diff --git a/dev/LEADERBOARD.md b/dev/LEADERBOARD.md index f20d455..65c0809 100644 --- a/dev/LEADERBOARD.md +++ b/dev/LEADERBOARD.md @@ -196,3 +196,7 @@ NOTE: The `val_bpb` is as of this run *NOT* comparable due to the data distribut Achieved Mar 9, 2026 on commit `6ed7d1d`. Exactly the same launch command as Run 4 except `--target-param-data-ratio=8.7`. I ran 5 identical runs; the average CORE was 0.2690, which is quite a bit above the needed threshold of 0.2565. But the reason I didn't decrease the ratio further (i.e. train shorter) is that while the CORE "safety gap" is large, the val_loss safety gap is smaller - 0.71808, which we want to be below the Run 4 val loss of 0.71854. It's likely that we could have reduced the ratio even lower, possibly to 8.6, but it's not worth splitting hairs at this point.
This commit is special because all of the improvements that went into [this commit](https://github.com/karpathy/nanochat/commit/6ed7d1d82cee16c2e26f45d559ad3338447a6c1b) came from fully autonomous "research" done by a private version of [autoresearch](https://github.com/karpathy/autoresearch) run on a d12 model. I wrote more about this in [this tweet](https://x.com/karpathy/status/2031135152349524125). The changes easily translated from d12 to d24, hence the new leaderboard record, taking us from 2.02 hours "time to GPT-2" to 1.80 hours. + +## Run 6 + +Achieved Mar 14, 2026 on commit `a825e63`. Exactly the same launch command as Run 4 except `--target-param-data-ratio=8`. Improvements in the architecture are allowing us to train for shorter and shorter times. Instead of an undertrained d24 I attempted to train an overtrained d22, but it was worse. This set of changes came from autoresearch round 2, where I asked it to reference the modded-nanogpt repo for inspiration. The exploration tried out a number of ideas and in particular found a way to incorporate the backout and smear in such a way that they are helpful (I had previously tried them manually a long time ago and they caused regressions). The smear idea in particular is a little heavier and bloatier because it is essentially an "early fusion" of context across tokens, producing a kind of bigram input into the network and allowing it to focus on higher-order ngrams earlier. For this reason the code gets a bit more complex and required some changes to inference. I verified with a unit test that the Engine inference matches the naive inference of `GPT.generate()`. The average of 5 runs was CORE 0.262634 and each of them lasted 1.65 hours (99 minutes).
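To make the smear idea concrete, here is a minimal standalone sketch (hypothetical function and argument names, not the repo's exact code): each position adds a sigmoid-gated copy of the previous token's embedding, with the gate reading only the first 24 channels, and the cached-decode path reproduces the full-sequence path exactly — which is the property the inference unit test checks.

```python
import torch

def smear(x, gate_w, smear_lambda, prev=None):
    # x: (B, T, C) embeddings. gate_w: (24,) gate weights reading the first
    # 24 channels. smear_lambda: learned scalar gain.
    if prev is None:
        # Full-sequence path (training / prefill): position t>0 mixes in a
        # gated copy of position t-1, giving cheap bigram-like input context.
        gate = smear_lambda * torch.sigmoid(x[:, 1:, :24] @ gate_w)  # (B, T-1)
        return torch.cat([x[:, :1], x[:, 1:] + gate.unsqueeze(-1) * x[:, :-1]], dim=1)
    # Single-token decode path: `prev` is the cached previous embedding
    # (kv_cache.prev_embedding in the diff above plays this role).
    gate = smear_lambda * torch.sigmoid(x[:, :, :24] @ gate_w)  # (B, 1)
    return x + gate.unsqueeze(-1) * prev
```

The point of the cached `prev` is that at decode time only one token's embedding is available, so the previous token's embedding must be carried in the KV cache rather than recomputed.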
diff --git a/nanochat/engine.py b/nanochat/engine.py index 4724c8f..aa2e6a9 100644 --- a/nanochat/engine.py +++ b/nanochat/engine.py @@ -100,10 +100,13 @@ class KVCache: self.v_cache = torch.zeros(num_layers, batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype) # Current sequence length per batch element (FA3 needs int32) self.cache_seqlens = torch.zeros(batch_size, dtype=torch.int32, device=device) + # Previous token's normalized embedding for smear (set by model forward pass) + self.prev_embedding = None def reset(self): """Reset cache to empty state.""" self.cache_seqlens.zero_() + self.prev_embedding = None def get_pos(self): """Get current position (assumes all batch elements at same position).""" @@ -129,6 +132,9 @@ class KVCache: self.k_cache[:, :, :other_pos, :, :] = other.k_cache[:, :, :other_pos, :, :] self.v_cache[:, :, :other_pos, :, :] = other.v_cache[:, :, :other_pos, :, :] self.cache_seqlens.fill_(other_pos) + # Copy smear state: expand batch=1 prev_embedding to num_samples + if other.prev_embedding is not None: + self.prev_embedding = other.prev_embedding.expand(self.batch_size, -1, -1).clone() # ----------------------------------------------------------------------------- @torch.inference_mode() diff --git a/nanochat/gpt.py b/nanochat/gpt.py index a8b4f1c..6ab5eee 100644 --- a/nanochat/gpt.py +++ b/nanochat/gpt.py @@ -34,7 +34,7 @@ class GPTConfig: n_kv_head: int = 6 # number of key/value heads (GQA) n_embd: int = 768 # Sliding window attention pattern string, tiled across layers. Final layer always L. - # Characters: L=long (full context), S=short (half context) + # Characters: L=long (full context), S=short (quarter context) # Examples: "L"=all full context, "SL"=alternating, "SSL"=two short then one long window_pattern: str = "SSSL" # muP (Maximal Update Parametrization): set > 0 to enable. Value is the base/proxy width. 
@@ -101,8 +101,8 @@ class CausalSelfAttention(nn.Module): cos, sin = cos_sin q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin) q, k = norm(q), norm(k) # QK norm - q = q * 1.15 # sharper attention (split scale between Q and K), TODO think through better - k = k * 1.15 + q = q * 1.2 # sharper attention (split scale between Q and K), TODO think through better + k = k * 1.2 # Flash Attention (FA3 on Hopper+, PyTorch SDPA fallback elsewhere) # window_size is (left, right) tuple: (N, 0) for causal, (-1, 0) for full context @@ -182,6 +182,11 @@ class GPT(nn.Module): # Separate parameters so they can have different optimizer treatment self.resid_lambdas = nn.Parameter(torch.ones(config.n_layer)) # fake init, real init in init_weights() self.x0_lambdas = nn.Parameter(torch.zeros(config.n_layer)) # fake init, real init in init_weights() + # Smear: mix previous token's embedding into current token (cheap bigram-like info) + self.smear_gate = Linear(24, 1, bias=False) + self.smear_lambda = nn.Parameter(torch.zeros(1)) + # Backout: subtract cached mid-layer residual before final norm to remove low-level features + self.backout_lambda = nn.Parameter(0.2 * torch.ones(1)) # Value embeddings (ResFormer-style): alternating layers, last layer always included head_dim = config.n_embd // config.n_head kv_dim = config.n_kv_head * head_dim @@ -228,7 +233,7 @@ class GPT(nn.Module): torch.nn.init.uniform_(block.attn.c_q.weight, -s, s) # weights use Uniform to avoid outliers torch.nn.init.uniform_(block.attn.c_k.weight, -s, s) torch.nn.init.uniform_(block.attn.c_v.weight, -s, s) - torch.nn.init.uniform_(block.mlp.c_fc.weight, -s * 0.5, s * 0.5) # 0.5x init scale for c_fc + torch.nn.init.uniform_(block.mlp.c_fc.weight, -s * 0.4, s * 0.4) # 0.4x init scale for c_fc if self.config.mup_base_width > 0: # muP: output projections use same scale as hidden weights (std = sigma_base/sqrt(m_d)) # Zero init causes attn/FFN outputs to vanish as width increases with muP LR scaling @@ 
-239,8 +244,13 @@ class GPT(nn.Module): torch.nn.init.zeros_(block.mlp.c_proj.weight) # Per-layer scalars - self.resid_lambdas.fill_(1.0) # 1.0 => typical residual connections at init - self.x0_lambdas.fill_(0.1) # 0.1 => small initial weight for skip connection to input embedding + # Per-layer resid init: stronger residual at early layers, weaker at deep layers + n_layer = self.config.n_layer + for i in range(n_layer): + self.resid_lambdas.data[i] = 1.15 - (0.10 * i / max(n_layer - 1, 1)) + # Decaying x0 init: earlier layers get more input embedding blending + for i in range(n_layer): + self.x0_lambdas.data[i] = 0.20 - (0.15 * i / max(n_layer - 1, 1)) # Value embeddings (init like c_v: uniform with same std) for ve in self.value_embeds.values(): @@ -289,13 +299,13 @@ class GPT(nn.Module): - right: how many tokens after current position to attend to (0 for causal) Pattern string is tiled across layers. Final layer always gets L (full context). - Characters: L=long (full context), S=short (half context) + Characters: L=long (full context), S=short (quarter context) """ pattern = config.window_pattern.upper() assert all(c in "SL" for c in pattern), f"Invalid window_pattern: {pattern}. Use only S and L." 
# Map characters to window sizes long_window = config.sequence_len - short_window = -(-long_window // 3 // 128) * 128 # ceil to FA3 tile size (2048 -> 768) + short_window = -(-long_window // 4 // 128) * 128 # ceil to FA3 tile size (2048 -> 512) char_to_window = { "L": (long_window, 0), "S": (short_window, 0), @@ -328,7 +338,8 @@ class GPT(nn.Module): # Exclude non-matmul params: embeddings and per-layer scalars value_embeds_numel = sum(ve.weight.numel() for ve in self.value_embeds.values()) nparams_exclude = (self.transformer.wte.weight.numel() + value_embeds_numel + - self.resid_lambdas.numel() + self.x0_lambdas.numel()) + self.resid_lambdas.numel() + self.x0_lambdas.numel() + + self.smear_gate.weight.numel() + self.smear_lambda.numel() + self.backout_lambda.numel()) h, q, t = self.config.n_head, self.config.n_embd // self.config.n_head, self.config.sequence_len # Sum attention FLOPs per layer, accounting for sliding window attn_flops = 0 @@ -356,7 +367,7 @@ class GPT(nn.Module): value_embeds = sum(p.numel() for p in self.value_embeds.parameters()) lm_head = sum(p.numel() for p in self.lm_head.parameters()) transformer_matrices = sum(p.numel() for p in self.transformer.h.parameters()) - scalars = self.resid_lambdas.numel() + self.x0_lambdas.numel() + scalars = self.resid_lambdas.numel() + self.x0_lambdas.numel() + self.smear_gate.weight.numel() + self.smear_lambda.numel() + self.backout_lambda.numel() total = wte + value_embeds + lm_head + transformer_matrices + scalars assert total == sum(p.numel() for p in self.parameters()), "Parameter count mismatch" return { @@ -379,7 +390,8 @@ class GPT(nn.Module): lm_head_params = list(self.lm_head.parameters()) resid_params = [self.resid_lambdas] x0_params = [self.x0_lambdas] - assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params) + smear_params = [self.smear_gate.weight, self.smear_lambda, self.backout_lambda]
+ assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params) + len(smear_params) # Compute LR scaling factors based on mode if use_mup: @@ -422,6 +434,7 @@ class GPT(nn.Module): dict(kind='adamw', params=value_embeds_params, lr=embedding_lr * emb_lr_scale * 0.5, betas=(0.8, 0.995), eps=1e-10, weight_decay=0.01), dict(kind='adamw', params=resid_params, lr=scalar_lr * 0.01, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.05), dict(kind='adamw', params=x0_params, lr=scalar_lr, betas=(0.96, 0.95), eps=1e-10, weight_decay=0.0), # higher beta1 for x0 + dict(kind='adamw', params=smear_params, lr=0.2, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0), ] # Muon groups (matrix params, grouped by shape for stacking) for shape in sorted({p.shape for p in matrix_params}): @@ -448,15 +461,44 @@ class GPT(nn.Module): T0 = 0 if kv_cache is None else kv_cache.get_pos() cos_sin = self.cos[:, T0:T0+T], self.sin[:, T0:T0+T] # truncate cache to current sequence length - # Forward the trunk of the Transformer + # Embed the tokens x = self.transformer.wte(idx) # embed current token x = x.to(COMPUTE_DTYPE) # ensure activations are in compute dtype (no-op usually, but active for fp16 code path) x = norm(x) + + # Smear: mix previous token's embedding into current position (cheap bigram info) + if kv_cache is None: + # Training / naive generate: full sequence available, use fast slice + assert T > 1, "Training forward pass should have T > 1" + gate = self.smear_lambda.to(x.dtype) * torch.sigmoid(self.smear_gate(x[:, 1:, :24])) + x = torch.cat([x[:, :1], x[:, 1:] + gate * x[:, :-1]], dim=1) + else: + # KV cache inference: read prev embedding from cache, store current for next step + x_pre_smear = kv_cache.prev_embedding + kv_cache.prev_embedding = x[:, -1:, :] + if T > 1: + # Prefill: apply smear to positions 1+, same as training + gate = self.smear_lambda.to(x.dtype) * 
torch.sigmoid(self.smear_gate(x[:, 1:, :24])) + x = torch.cat([x[:, :1], x[:, 1:] + gate * x[:, :-1]], dim=1) + elif x_pre_smear is not None: + # Decode: single token, use cached prev embedding + gate = self.smear_lambda.to(x.dtype) * torch.sigmoid(self.smear_gate(x[:, :, :24])) + x = x + gate * x_pre_smear + + # Forward the trunk of the Transformer x0 = x # save initial normalized embedding for x0 residual + n_layer = self.config.n_layer + backout_layer = n_layer // 2 # cache at halfway point + x_backout = None for i, block in enumerate(self.transformer.h): x = self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0 ve = self.value_embeds[str(i)](idx).to(x.dtype) if str(i) in self.value_embeds else None x = block(x, ve, cos_sin, self.window_sizes[i], kv_cache) + if i == backout_layer: + x_backout = x + # Subtract mid-layer residual to remove low-level features before logit projection + if x_backout is not None: + x = x - self.backout_lambda.to(x.dtype) * x_backout x = norm(x) # Forward the lm_head (compute logits) diff --git a/runs/speedrun.sh b/runs/speedrun.sh index fa50694..48fcc68 100644 --- a/runs/speedrun.sh +++ b/runs/speedrun.sh @@ -69,8 +69,8 @@ python -m scripts.tok_eval echo "Waiting for dataset download to complete..." 
wait $DATASET_DOWNLOAD_PID -# d24 model (slightly undertrained to beat GPT-2 => decrease data:params ratio from compute optimal 10.5 (default) to 9.5) -torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- --depth=24 --target-param-data-ratio=9.5 --device-batch-size=16 --fp8 --run=$WANDB_RUN +# d24 model (slightly undertrained to beat GPT-2 => decrease data:params ratio from compute optimal 10.5 (default) to 8) +torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- --depth=24 --target-param-data-ratio=8 --device-batch-size=16 --fp8 --run=$WANDB_RUN # evaluate the model: CORE metric, BPB on train/val, and draw samples torchrun --standalone --nproc_per_node=8 -m scripts.base_eval -- --device-batch-size=16 diff --git a/scripts/base_train.py b/scripts/base_train.py index 67db962..8232854 100644 --- a/scripts/base_train.py +++ b/scripts/base_train.py @@ -373,11 +373,18 @@ def get_lr_multiplier(it): progress = (num_iterations - it) / warmdown_iters return progress * 1.0 + (1 - progress) * args.final_lr_frac -# Momentum scheduler for Muon optimizer (warms up to 0.97 over the first 400 steps) +# Momentum scheduler for Muon optimizer (warms up to 0.97, warms down to 0.90 during LR warmdown) def get_muon_momentum(it): - frac = min(it / 400, 1) - momentum = (1 - frac) * 0.85 + frac * 0.97 - return momentum + warmdown_iters = round(args.warmdown_ratio * num_iterations) + warmdown_start = num_iterations - warmdown_iters + if it < 400: + frac = it / 400 + return (1 - frac) * 0.85 + frac * 0.97 + elif it >= warmdown_start: + progress = (it - warmdown_start) / warmdown_iters + return 0.97 * (1 - progress) + 0.90 * progress + else: + return 0.97 # Weight decay scheduler for Muon optimizer (cosine decay to zero over the course of training) def get_weight_decay(it):
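The revised Muon momentum schedule above can be sketched as a standalone function (the `num_iterations=10000` and `warmdown_ratio=0.2` below are illustrative values; the real ones come from the training config):

```python
def muon_momentum(it, num_iterations, warmdown_ratio=0.2):
    # Mirrors the schedule in scripts/base_train.py: linear warmup
    # 0.85 -> 0.97 over the first 400 steps, hold at 0.97, then linear
    # warmdown 0.97 -> 0.90 over the LR warmdown phase at the end.
    warmdown_iters = round(warmdown_ratio * num_iterations)
    warmdown_start = num_iterations - warmdown_iters
    if it < 400:
        frac = it / 400
        return (1 - frac) * 0.85 + frac * 0.97
    elif it >= warmdown_start:
        progress = (it - warmdown_start) / warmdown_iters
        return 0.97 * (1 - progress) + 0.90 * progress
    return 0.97
```

The change relative to the old scheduler is the third phase: momentum now decays alongside the learning rate instead of staying pinned at 0.97 to the end.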