mirror of
https://github.com/karpathy/nanochat.git
synced 2026-04-01 13:15:21 +00:00
Merge branch 'master' into fix/typo
This commit is contained in:
commit
02e865c2ab
|
|
@ -19,6 +19,7 @@ Presently, the main focus of development is on tuning the pretraining stage, whi
|
|||
| 3 | 2.76 | 0.74645 | 0.2602 | bump total batch size to 1M tokens | Feb 5 2026 | 2c062aa | @karpathy |
|
||||
| 4 | 2.02 | 0.71854 | 0.2571 | change dataset to NVIDIA ClimbMix | Mar 4 2026 | 324e69c | @ddudek @karpathy |
|
||||
| 5 | 1.80 | 0.71808 | 0.2690 | autoresearch [round 1](https://x.com/karpathy/status/2031135152349524125) | Mar 9 2026 | 6ed7d1d | @karpathy |
|
||||
| 5 | 1.65 | 0.71800 | 0.2626 | autoresearch round 2 | Mar 14 2026 | a825e63 | @karpathy |
|
||||
|
||||
The primary metric we care about is "time to GPT-2" - the wall clock time needed to outperform the GPT-2 (1.6B) CORE metric on an 8XH100 GPU node. The GPT-2 CORE score is 0.256525. In 2019, the training of GPT-2 cost approximately $43,000 so it is incredible that due to many advances over 7 years across the stack, we can now do so much faster and for well below $100 (e.g. at the current ~$3/GPU/hr, an 8XH100 node is ~$24/hr, so 2 hours is ~$48).
|
||||
|
||||
|
|
|
|||
|
|
@ -196,3 +196,7 @@ NOTE: The `val_bpb` is as of this run *NOT* comparable due to the data distribut
|
|||
Achieved Mar 9, 2026 on commit `6ed7d1d`. Exactly the same launch command as Run 4 except `--target-param-data-ratio=8.7`. I ran 5 identical runs, the average CORE was 0.2690, which is quite a bit above the needed threshold of 0.2565. But the reason I didn't decrease the ratio further (i.e. train shorter) is that while the CORE "safety gap" is large, the val_loss safety gap is smaller - 0.71808, which we want to be below the Run 4 val loss of 0.71854. It's likely that we could have reduced the ratio even lower, possibly to 8.6, but it's not worth splitting hairs at this point.
|
||||
|
||||
This commit is special because all of the improvements that went into [this commit](https://github.com/karpathy/nanochat/commit/6ed7d1d82cee16c2e26f45d559ad3338447a6c1b) came from fully autonomous "research" done by a private version of [autoresearch](https://github.com/karpathy/autoresearch) run on a d12 model. I wrote more about this in [this tweet](https://x.com/karpathy/status/2031135152349524125). The changes easily translated from d12 to d24, hence new leaderboard record, taking us from 2.02 hours "time to GPT-2" to 1.80 hours.
|
||||
|
||||
## Run 6
|
||||
|
||||
Achieved Mar 14, 2026 on commit `a825e63`. Exactly the same launch command as Run 4 except `--target-param-data-ratio=8`. Improvements in the architecture are allowing us to train shorter and shorter time. Instead of an undertrained d24 I attempted to train an overtrained d22 but it was worse. This set of changes came from autoresearch round 2, where I asked it to reference the modded-nanogpt repo for inspiration. So the exploration tried out a number of ideas and in particular found a way to incorporate the backout and smear in such a way that they are helpful (I had previously tried them manually a long time ago and they caused regressions). The smear idea in particular is a little bit heavier and bloaty because it is essentially an "early fusion" of context across tokens, producing a kind of a bigram input into the network and allowing it to focus on higher ngrams earlier. But for this reason the code gets a bit more complex and required some changes to inference. I verified with a unit test that the Engine inference is correct compared to the naive inference of `GPT.generate()`. The average of 5 runs was CORE 0.262634 and each of them lasted 1.65 hours (99 minutes).
|
||||
|
|
|
|||
|
|
@ -100,10 +100,13 @@ class KVCache:
|
|||
self.v_cache = torch.zeros(num_layers, batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype)
|
||||
# Current sequence length per batch element (FA3 needs int32)
|
||||
self.cache_seqlens = torch.zeros(batch_size, dtype=torch.int32, device=device)
|
||||
# Previous token's normalized embedding for smear (set by model forward pass)
|
||||
self.prev_embedding = None
|
||||
|
||||
def reset(self):
|
||||
"""Reset cache to empty state."""
|
||||
self.cache_seqlens.zero_()
|
||||
self.prev_embedding = None
|
||||
|
||||
def get_pos(self):
|
||||
"""Get current position (assumes all batch elements at same position)."""
|
||||
|
|
@ -129,6 +132,9 @@ class KVCache:
|
|||
self.k_cache[:, :, :other_pos, :, :] = other.k_cache[:, :, :other_pos, :, :]
|
||||
self.v_cache[:, :, :other_pos, :, :] = other.v_cache[:, :, :other_pos, :, :]
|
||||
self.cache_seqlens.fill_(other_pos)
|
||||
# Copy smear state: expand batch=1 prev_embedding to num_samples
|
||||
if other.prev_embedding is not None:
|
||||
self.prev_embedding = other.prev_embedding.expand(self.batch_size, -1, -1).clone()
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@torch.inference_mode()
|
||||
|
|
|
|||
|
|
@ -34,7 +34,7 @@ class GPTConfig:
|
|||
n_kv_head: int = 6 # number of key/value heads (GQA)
|
||||
n_embd: int = 768
|
||||
# Sliding window attention pattern string, tiled across layers. Final layer always L.
|
||||
# Characters: L=long (full context), S=short (half context)
|
||||
# Characters: L=long (full context), S=short (quarter context)
|
||||
# Examples: "L"=all full context, "SL"=alternating, "SSL"=two short then one long
|
||||
window_pattern: str = "SSSL"
|
||||
|
||||
|
|
@ -98,8 +98,8 @@ class CausalSelfAttention(nn.Module):
|
|||
cos, sin = cos_sin
|
||||
q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin)
|
||||
q, k = norm(q), norm(k) # QK norm
|
||||
q = q * 1.15 # sharper attention (split scale between Q and K), TODO think through better
|
||||
k = k * 1.15
|
||||
q = q * 1.2 # sharper attention (split scale between Q and K), TODO think through better
|
||||
k = k * 1.2
|
||||
|
||||
# Flash Attention (FA3 on Hopper+, PyTorch SDPA fallback elsewhere)
|
||||
# window_size is (left, right) tuple: (N, 0) for causal, (-1, 0) for full context
|
||||
|
|
@ -179,6 +179,11 @@ class GPT(nn.Module):
|
|||
# Separate parameters so they can have different optimizer treatment
|
||||
self.resid_lambdas = nn.Parameter(torch.ones(config.n_layer)) # fake init, real init in init_weights()
|
||||
self.x0_lambdas = nn.Parameter(torch.zeros(config.n_layer)) # fake init, real init in init_weights()
|
||||
# Smear: mix previous token's embedding into current token (cheap bigram-like info)
|
||||
self.smear_gate = Linear(24, 1, bias=False)
|
||||
self.smear_lambda = nn.Parameter(torch.zeros(1))
|
||||
# Backout: subtract cached mid-layer residual before final norm to remove low-level features
|
||||
self.backout_lambda = nn.Parameter(0.2 * torch.ones(1))
|
||||
# Value embeddings (ResFormer-style): alternating layers, last layer always included
|
||||
head_dim = config.n_embd // config.n_head
|
||||
kv_dim = config.n_kv_head * head_dim
|
||||
|
|
@ -221,12 +226,17 @@ class GPT(nn.Module):
|
|||
torch.nn.init.uniform_(block.attn.c_k.weight, -s, s)
|
||||
torch.nn.init.uniform_(block.attn.c_v.weight, -s, s)
|
||||
torch.nn.init.zeros_(block.attn.c_proj.weight) # projections are zero
|
||||
torch.nn.init.uniform_(block.mlp.c_fc.weight, -s * 0.5, s * 0.5) # 0.5x init scale for c_fc
|
||||
torch.nn.init.uniform_(block.mlp.c_fc.weight, -s * 0.4, s * 0.4) # 0.4x init scale for c_fc
|
||||
torch.nn.init.zeros_(block.mlp.c_proj.weight)
|
||||
|
||||
# Per-layer scalars
|
||||
self.resid_lambdas.fill_(1.0) # 1.0 => typical residual connections at init
|
||||
self.x0_lambdas.fill_(0.1) # 0.1 => small initial weight for skip connection to input embedding
|
||||
# Per-layer resid init: stronger residual at early layers, weaker at deep layers
|
||||
n_layer = self.config.n_layer
|
||||
for i in range(n_layer):
|
||||
self.resid_lambdas.data[i] = 1.15 - (0.10 * i / max(n_layer - 1, 1))
|
||||
# Decaying x0 init: earlier layers get more input embedding blending
|
||||
for i in range(n_layer):
|
||||
self.x0_lambdas.data[i] = 0.20 - (0.15 * i / max(n_layer - 1, 1))
|
||||
|
||||
# Value embeddings (init like c_v: uniform with same std)
|
||||
for ve in self.value_embeds.values():
|
||||
|
|
@ -276,13 +286,13 @@ class GPT(nn.Module):
|
|||
- right: how many tokens after current position to attend to (0 for causal)
|
||||
|
||||
Pattern string is tiled across layers. Final layer always gets L (full context).
|
||||
Characters: L=long (full context), S=short (half context)
|
||||
Characters: L=long (full context), S=short (quarter context)
|
||||
"""
|
||||
pattern = config.window_pattern.upper()
|
||||
assert all(c in "SL" for c in pattern), f"Invalid window_pattern: {pattern}. Use only S and L."
|
||||
# Map characters to window sizes
|
||||
long_window = config.sequence_len
|
||||
short_window = -(-long_window // 3 // 128) * 128 # ceil to FA3 tile size (2048 -> 768)
|
||||
short_window = -(-long_window // 4 // 128) * 128 # ceil to FA3 tile size (2048 -> 768)
|
||||
char_to_window = {
|
||||
"L": (long_window, 0),
|
||||
"S": (short_window, 0),
|
||||
|
|
@ -315,7 +325,8 @@ class GPT(nn.Module):
|
|||
# Exclude non-matmul params: embeddings and per-layer scalars
|
||||
value_embeds_numel = sum(ve.weight.numel() for ve in self.value_embeds.values())
|
||||
nparams_exclude = (self.transformer.wte.weight.numel() + value_embeds_numel +
|
||||
self.resid_lambdas.numel() + self.x0_lambdas.numel())
|
||||
self.resid_lambdas.numel() + self.x0_lambdas.numel() +
|
||||
self.smear_gate.weight.numel() + self.smear_lambda.numel() + self.backout_lambda.numel())
|
||||
h, q, t = self.config.n_head, self.config.n_embd // self.config.n_head, self.config.sequence_len
|
||||
# Sum attention FLOPs per layer, accounting for sliding window
|
||||
attn_flops = 0
|
||||
|
|
@ -343,7 +354,7 @@ class GPT(nn.Module):
|
|||
value_embeds = sum(p.numel() for p in self.value_embeds.parameters())
|
||||
lm_head = sum(p.numel() for p in self.lm_head.parameters())
|
||||
transformer_matrices = sum(p.numel() for p in self.transformer.h.parameters())
|
||||
scalars = self.resid_lambdas.numel() + self.x0_lambdas.numel()
|
||||
scalars = self.resid_lambdas.numel() + self.x0_lambdas.numel() + self.smear_gate.weight.numel() + self.smear_lambda.numel() + self.backout_lambda.numel()
|
||||
total = wte + value_embeds + lm_head + transformer_matrices + scalars
|
||||
assert total == sum(p.numel() for p in self.parameters()), "Parameter count mismatch"
|
||||
return {
|
||||
|
|
@ -366,7 +377,8 @@ class GPT(nn.Module):
|
|||
lm_head_params = list(self.lm_head.parameters())
|
||||
resid_params = [self.resid_lambdas]
|
||||
x0_params = [self.x0_lambdas]
|
||||
assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params)
|
||||
smear_params = [self.smear_gate.weight, self.smear_lambda, self.backout_lambda]
|
||||
assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params) + len(smear_params)
|
||||
|
||||
# Scale the LR for the AdamW parameters by ∝1/√dmodel (tuned for 768 dim model)
|
||||
dmodel_lr_scale = (model_dim / 768) ** -0.5
|
||||
|
|
@ -380,6 +392,7 @@ class GPT(nn.Module):
|
|||
dict(kind='adamw', params=value_embeds_params, lr=embedding_lr * dmodel_lr_scale * 0.5, betas=(0.8, 0.995), eps=1e-10, weight_decay=0.01),
|
||||
dict(kind='adamw', params=resid_params, lr=scalar_lr * 0.01, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.05),
|
||||
dict(kind='adamw', params=x0_params, lr=scalar_lr, betas=(0.96, 0.95), eps=1e-10, weight_decay=0.0), # higher beta1 for x0
|
||||
dict(kind='adamw', params=smear_params, lr=0.2, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0),
|
||||
]
|
||||
# Muon groups (matrix params, grouped by shape for stacking)
|
||||
for shape in sorted({p.shape for p in matrix_params}):
|
||||
|
|
@ -406,15 +419,44 @@ class GPT(nn.Module):
|
|||
T0 = 0 if kv_cache is None else kv_cache.get_pos()
|
||||
cos_sin = self.cos[:, T0:T0+T], self.sin[:, T0:T0+T] # truncate cache to current sequence length
|
||||
|
||||
# Forward the trunk of the Transformer
|
||||
# Embed the tokens
|
||||
x = self.transformer.wte(idx) # embed current token
|
||||
x = x.to(COMPUTE_DTYPE) # ensure activations are in compute dtype (no-op usually, but active for fp16 code path)
|
||||
x = norm(x)
|
||||
|
||||
# Smear: mix previous token's embedding into current position (cheap bigram info)
|
||||
if kv_cache is None:
|
||||
# Training / naive generate: full sequence available, use fast slice
|
||||
assert T > 1, "Training forward pass should have T > 1"
|
||||
gate = self.smear_lambda.to(x.dtype) * torch.sigmoid(self.smear_gate(x[:, 1:, :24]))
|
||||
x = torch.cat([x[:, :1], x[:, 1:] + gate * x[:, :-1]], dim=1)
|
||||
else:
|
||||
# KV cache inference: read prev embedding from cache, store current for next step
|
||||
x_pre_smear = kv_cache.prev_embedding
|
||||
kv_cache.prev_embedding = x[:, -1:, :]
|
||||
if T > 1:
|
||||
# Prefill: apply smear to positions 1+, same as training
|
||||
gate = self.smear_lambda.to(x.dtype) * torch.sigmoid(self.smear_gate(x[:, 1:, :24]))
|
||||
x = torch.cat([x[:, :1], x[:, 1:] + gate * x[:, :-1]], dim=1)
|
||||
elif x_pre_smear is not None:
|
||||
# Decode: single token, use cached prev embedding
|
||||
gate = self.smear_lambda.to(x.dtype) * torch.sigmoid(self.smear_gate(x[:, :, :24]))
|
||||
x = x + gate * x_pre_smear
|
||||
|
||||
# Forward the trunk of the Transformer
|
||||
x0 = x # save initial normalized embedding for x0 residual
|
||||
n_layer = self.config.n_layer
|
||||
backout_layer = n_layer // 2 # cache at halfway point
|
||||
x_backout = None
|
||||
for i, block in enumerate(self.transformer.h):
|
||||
x = self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0
|
||||
ve = self.value_embeds[str(i)](idx).to(x.dtype) if str(i) in self.value_embeds else None
|
||||
x = block(x, ve, cos_sin, self.window_sizes[i], kv_cache)
|
||||
if i == backout_layer:
|
||||
x_backout = x
|
||||
# Subtract mid-layer residual to remove low-level features before logit projection
|
||||
if x_backout is not None:
|
||||
x = x - self.backout_lambda.to(x.dtype) * x_backout
|
||||
x = norm(x)
|
||||
|
||||
# Forward the lm_head (compute logits)
|
||||
|
|
|
|||
|
|
@ -69,8 +69,8 @@ python -m scripts.tok_eval
|
|||
echo "Waiting for dataset download to complete..."
|
||||
wait $DATASET_DOWNLOAD_PID
|
||||
|
||||
# d24 model (slightly undertrained to beat GPT-2 => decrease data:params ratio from compute optimal 10.5 (default) to 9.5)
|
||||
torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- --depth=24 --target-param-data-ratio=9.5 --device-batch-size=16 --fp8 --run=$WANDB_RUN
|
||||
# d24 model (slightly undertrained to beat GPT-2 => decrease data:params ratio from compute optimal 10.5 (default) to 8)
|
||||
torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- --depth=24 --target-param-data-ratio=8 --device-batch-size=16 --fp8 --run=$WANDB_RUN
|
||||
# evaluate the model: CORE metric, BPB on train/val, and draw samples
|
||||
torchrun --standalone --nproc_per_node=8 -m scripts.base_eval -- --device-batch-size=16
|
||||
|
||||
|
|
|
|||
|
|
@ -367,11 +367,18 @@ def get_lr_multiplier(it):
|
|||
progress = (num_iterations - it) / warmdown_iters
|
||||
return progress * 1.0 + (1 - progress) * args.final_lr_frac
|
||||
|
||||
# Momentum scheduler for Muon optimizer (warms up to 0.97 over the first 400 steps)
|
||||
# Momentum scheduler for Muon optimizer (warms up to 0.97, warms down to 0.90 during LR warmdown)
|
||||
def get_muon_momentum(it):
|
||||
frac = min(it / 400, 1)
|
||||
momentum = (1 - frac) * 0.85 + frac * 0.97
|
||||
return momentum
|
||||
warmdown_iters = round(args.warmdown_ratio * num_iterations)
|
||||
warmdown_start = num_iterations - warmdown_iters
|
||||
if it < 400:
|
||||
frac = it / 400
|
||||
return (1 - frac) * 0.85 + frac * 0.97
|
||||
elif it >= warmdown_start:
|
||||
progress = (it - warmdown_start) / warmdown_iters
|
||||
return 0.97 * (1 - progress) + 0.90 * progress
|
||||
else:
|
||||
return 0.97
|
||||
|
||||
# Weight decay scheduler for Muon optimizer (cosine decay to zero over the course of training)
|
||||
def get_weight_decay(it):
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user