From e3f58b838e98a5ea013a3c1773fde9d4a3c5d090 Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Fri, 16 Jan 2026 20:59:42 +0000 Subject: [PATCH 1/9] ranked version --- nanochat/gpt.py | 48 ++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 40 insertions(+), 8 deletions(-) diff --git a/nanochat/gpt.py b/nanochat/gpt.py index 86f440bf..ffb7862a 100644 --- a/nanochat/gpt.py +++ b/nanochat/gpt.py @@ -68,7 +68,7 @@ class CausalSelfAttention(nn.Module): self.c_v = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False) self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False) - def forward(self, x, cos_sin, window_size, kv_cache): + def forward(self, x, cos_sin, window_size, kv_cache, v0, v0_lambda): B, T, C = x.size() # Project the input to get queries, keys, and values @@ -77,6 +77,11 @@ class CausalSelfAttention(nn.Module): k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim) v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim) + # Value residual (ResFormer): mix in projected initial embedding for later layers + if v0 is not None: + v0_reshaped = v0.view(B, T, self.n_kv_head, self.head_dim) + v = v + v0_lambda * v0_reshaped + # Apply Rotary Embeddings to queries and keys to get relative positional encoding cos, sin = cos_sin q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin) @@ -126,8 +131,8 @@ class Block(nn.Module): self.attn = CausalSelfAttention(config, layer_idx) self.mlp = MLP(config) - def forward(self, x, cos_sin, window_size, kv_cache): - x = x + self.attn(norm(x), cos_sin, window_size, kv_cache) + def forward(self, x, cos_sin, window_size, kv_cache, v0, v0_lambda): + x = x + self.attn(norm(x), cos_sin, window_size, kv_cache, v0, v0_lambda) x = x + self.mlp(norm(x)) return x @@ -160,6 +165,17 @@ class GPT(nn.Module): # Separate parameters so they can have different optimizer treatment self.resid_lambdas = nn.Parameter(torch.ones(config.n_layer)) # fake init, real init in init_weights() self.x0_lambdas = nn.Parameter(torch.zeros(config.n_layer)) # fake init, real init in init_weights() + # Value residual (ResFormer-style): low-rank factorized embedding for value residual + # Paper: "Value Residual Learning" (arXiv:2410.17897) shows this improves information flow + # We apply to last 1/4 of layers as the paper shows later layers benefit most + # Low-rank factorization: (vocab, r) @ (r, kv_dim) instead of full (vocab, kv_dim) + head_dim = config.n_embd // config.n_head + kv_dim = config.n_kv_head * head_dim + value_rank = 32 # low-rank bottleneck dimension + self.value_embed_A = nn.Embedding(padded_vocab_size, value_rank) # token -> low-rank + self.value_embed_B = nn.Linear(value_rank, kv_dim, bias=False) # low-rank -> kv_dim + self.v0_lambdas = nn.Parameter(torch.zeros(config.n_layer)) # fake init, real init in init_weights() + self.value_residual_start = config.n_layer - config.n_layer // 4 # last 1/4 of layers # To support meta device initialization, we init the rotary embeddings here, but it's just "fake" meta tensors only. # As for rotary_seq_len, these rotary embeddings are pretty small/cheap in memory, # so let's just over-compute them by 10X, but assert fail if we ever reach that amount. @@ -204,15 +220,21 @@ class GPT(nn.Module): with torch.no_grad(): self.resid_lambdas.fill_(1.0) # 1.0 => typical residual connections at init self.x0_lambdas.fill_(0.0) # 0.0 => skip connection to input is disabled at init + self.v0_lambdas.fill_(0.0) # 0.0 => value residual is disabled at init + + # Value embedding low-rank factors (init like embeddings/projections) + torch.nn.init.normal_(self.value_embed_A.weight, mean=0.0, std=1.0) # like wte + torch.nn.init.uniform_(self.value_embed_B.weight, -s, s) # like c_v # Rotary embeddings head_dim = self.config.n_embd // self.config.n_head cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim) self.cos, self.sin = cos, sin - # Cast token embeddings to bf16: optimizer can tolerate it and it saves memory + # Cast embeddings to bf16: optimizer can tolerate it and it saves memory if self.transformer.wte.weight.device.type == "cuda": self.transformer.wte.to(dtype=torch.bfloat16) + self.value_embed_A.to(dtype=torch.bfloat16) def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None): # TODO: bump base theta more? e.g. 100K is more common more recently @@ -277,7 +299,8 @@ class GPT(nn.Module): """ nparams = sum(p.numel() for p in self.parameters()) # Exclude non-matmul params: embeddings and per-layer scalars - nparams_exclude = self.transformer.wte.weight.numel() + self.resid_lambdas.numel() + self.x0_lambdas.numel() + nparams_exclude = (self.transformer.wte.weight.numel() + self.value_embed_A.weight.numel() + + self.resid_lambdas.numel() + self.x0_lambdas.numel() + self.v0_lambdas.numel()) h, q, t = self.config.n_head, self.config.n_embd // self.config.n_head, self.config.sequence_len # Sum attention FLOPs per layer, accounting for sliding window attn_flops = 0 @@ -303,13 +326,16 @@ class GPT(nn.Module): def setup_optimizers(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0, adam_betas=(0.8, 0.95), scalar_lr=0.5): model_dim = self.config.n_embd ddp, rank, local_rank, world_size = get_dist_info() - # Separate out all parameters into 5 groups (matrix, embedding, lm_head, resid_lambdas, x0_lambdas) + # Separate out all parameters into groups (matrix, embedding, lm_head, value_embed, resid_lambdas, x0_lambdas, v0_lambdas) matrix_params = list(self.transformer.h.parameters()) embedding_params = list(self.transformer.wte.parameters()) lm_head_params = list(self.lm_head.parameters()) + value_embed_A_params = list(self.value_embed_A.parameters()) + value_embed_B_params = list(self.value_embed_B.parameters()) resid_params = [self.resid_lambdas] x0_params = [self.x0_lambdas] - assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(resid_params) + len(x0_params) + v0_params = [self.v0_lambdas] + assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(value_embed_A_params) + len(value_embed_B_params) + len(resid_params) + len(x0_params) + len(v0_params) # Create the AdamW optimizer for the embedding, lm_head, and per-layer scalars # Scale the LR for the AdamW parameters by ∝1/√dmodel (having tuned the LRs for 768 dim model) dmodel_lr_scale = (model_dim / 768) ** -0.5 @@ -317,8 +343,11 @@ class GPT(nn.Module): adam_groups = [ dict(params=lm_head_params, lr=unembedding_lr * dmodel_lr_scale), dict(params=embedding_params, lr=embedding_lr * dmodel_lr_scale), + dict(params=value_embed_A_params, lr=embedding_lr * dmodel_lr_scale), # low-rank embedding + dict(params=value_embed_B_params, lr=embedding_lr * dmodel_lr_scale), # low-rank projection dict(params=resid_params, lr=scalar_lr * 0.01), # these are a lot more sensitive because they accumulate in the residual stream dict(params=x0_params, lr=scalar_lr), + dict(params=v0_params, lr=scalar_lr), ] adamw_kwargs = dict(betas=adam_betas, eps=1e-10, weight_decay=0.0) # NOTE: weight decay is hardcoded to 0.0 for AdamW, only used in Muon AdamWFactory = DistAdamW if ddp else partial(torch.optim.AdamW, fused=True) @@ -349,9 +378,12 @@ class GPT(nn.Module): x = self.transformer.wte(idx) x = norm(x) x0 = x # save initial normalized embedding for x0 residual + # Value residual (ResFormer): low-rank factorized embedding for later layers + v0 = self.value_embed_B(self.value_embed_A(idx)) # (B, T, kv_dim) for i, block in enumerate(self.transformer.h): x = self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0 - x = block(x, cos_sin, self.window_sizes[i], kv_cache) + v0_for_layer = v0 if i >= self.value_residual_start else None + x = block(x, cos_sin, self.window_sizes[i], kv_cache, v0_for_layer, self.v0_lambdas[i]) x = norm(x) # Forward the lm_head (compute logits) From 0b58d70e9975d42b4357dfb33f321f764759af9f Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Fri, 16 Jan 2026 21:16:47 +0000 Subject: [PATCH 2/9] full ve version works very well --- nanochat/gpt.py | 28 +++++++++++----------------- 1 file changed, 11 insertions(+), 17 deletions(-) diff --git a/nanochat/gpt.py b/nanochat/gpt.py index ffb7862a..0356413d 100644 --- a/nanochat/gpt.py +++ b/nanochat/gpt.py @@ -165,15 +165,12 @@ class GPT(nn.Module): # Separate parameters so they can have different optimizer treatment self.resid_lambdas = nn.Parameter(torch.ones(config.n_layer)) # fake init, real init in init_weights() self.x0_lambdas = nn.Parameter(torch.zeros(config.n_layer)) # fake init, real init in init_weights() - # Value residual (ResFormer-style): low-rank factorized embedding for value residual + # Value residual (ResFormer-style): separate embedding for values, mixed into later layers # Paper: "Value Residual Learning" (arXiv:2410.17897) shows this improves information flow # We apply to last 1/4 of layers as the paper shows later layers benefit most - # Low-rank factorization: (vocab, r) @ (r, kv_dim) instead of full (vocab, kv_dim) head_dim = config.n_embd // config.n_head kv_dim = config.n_kv_head * head_dim - value_rank = 32 # low-rank bottleneck dimension - self.value_embed_A = nn.Embedding(padded_vocab_size, value_rank) # token -> low-rank - self.value_embed_B = nn.Linear(value_rank, kv_dim, bias=False) # low-rank -> kv_dim + self.value_embed = nn.Embedding(padded_vocab_size, kv_dim) self.v0_lambdas = nn.Parameter(torch.zeros(config.n_layer)) # fake init, real init in init_weights() self.value_residual_start = config.n_layer - config.n_layer // 4 # last 1/4 of layers # To support meta device initialization, we init the rotary embeddings here, but it's just "fake" meta tensors only. @@ -222,9 +219,8 @@ class GPT(nn.Module): self.x0_lambdas.fill_(0.0) # 0.0 => skip connection to input is disabled at init self.v0_lambdas.fill_(0.0) # 0.0 => value residual is disabled at init - # Value embedding low-rank factors (init like embeddings/projections) - torch.nn.init.normal_(self.value_embed_A.weight, mean=0.0, std=1.0) # like wte - torch.nn.init.uniform_(self.value_embed_B.weight, -s, s) # like c_v + # Value embedding (init like c_v: uniform with same std) + torch.nn.init.uniform_(self.value_embed.weight, -s, s) # Rotary embeddings head_dim = self.config.n_embd // self.config.n_head @@ -234,7 +230,7 @@ class GPT(nn.Module): # Cast embeddings to bf16: optimizer can tolerate it and it saves memory if self.transformer.wte.weight.device.type == "cuda": self.transformer.wte.to(dtype=torch.bfloat16) - self.value_embed_A.to(dtype=torch.bfloat16) + self.value_embed.to(dtype=torch.bfloat16) def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None): # TODO: bump base theta more? e.g. 100K is more common more recently @@ -299,7 +295,7 @@ class GPT(nn.Module): """ nparams = sum(p.numel() for p in self.parameters()) # Exclude non-matmul params: embeddings and per-layer scalars - nparams_exclude = (self.transformer.wte.weight.numel() + self.value_embed_A.weight.numel() + + nparams_exclude = (self.transformer.wte.weight.numel() + self.value_embed.weight.numel() + self.resid_lambdas.numel() + self.x0_lambdas.numel() + self.v0_lambdas.numel()) h, q, t = self.config.n_head, self.config.n_embd // self.config.n_head, self.config.sequence_len # Sum attention FLOPs per layer, accounting for sliding window @@ -330,12 +326,11 @@ class GPT(nn.Module): matrix_params = list(self.transformer.h.parameters()) embedding_params = list(self.transformer.wte.parameters()) lm_head_params = list(self.lm_head.parameters()) - value_embed_A_params = list(self.value_embed_A.parameters()) - value_embed_B_params = list(self.value_embed_B.parameters()) + value_embed_params = list(self.value_embed.parameters()) resid_params = [self.resid_lambdas] x0_params = [self.x0_lambdas] v0_params = [self.v0_lambdas] - assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(value_embed_A_params) + len(value_embed_B_params) + len(resid_params) + len(x0_params) + len(v0_params) + assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(value_embed_params) + len(resid_params) + len(x0_params) + len(v0_params) # Create the AdamW optimizer for the embedding, lm_head, and per-layer scalars # Scale the LR for the AdamW parameters by ∝1/√dmodel (having tuned the LRs for 768 dim model) dmodel_lr_scale = (model_dim / 768) ** -0.5 @@ -343,8 +338,7 @@ class GPT(nn.Module): adam_groups = [ dict(params=lm_head_params, lr=unembedding_lr * dmodel_lr_scale), dict(params=embedding_params, lr=embedding_lr * dmodel_lr_scale), - dict(params=value_embed_A_params, lr=embedding_lr * dmodel_lr_scale), # low-rank embedding - dict(params=value_embed_B_params, lr=embedding_lr * dmodel_lr_scale), # low-rank projection + dict(params=value_embed_params, lr=embedding_lr * dmodel_lr_scale), # same LR as token embedding dict(params=resid_params, lr=scalar_lr * 0.01), # these are a lot more sensitive because they accumulate in the residual stream dict(params=x0_params, lr=scalar_lr), dict(params=v0_params, lr=scalar_lr), @@ -378,8 +372,8 @@ class GPT(nn.Module): x = self.transformer.wte(idx) x = norm(x) x0 = x # save initial normalized embedding for x0 residual - # Value residual (ResFormer): low-rank factorized embedding for later layers - v0 = self.value_embed_B(self.value_embed_A(idx)) # (B, T, kv_dim) + # Value residual (ResFormer): separate value embedding for later layers + v0 = self.value_embed(idx) # (B, T, kv_dim) for i, block in enumerate(self.transformer.h): x = self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0 v0_for_layer = v0 if i >= self.value_residual_start else None From 9a88194c3f684a3418c0c0f4069e6f3b3af10736 Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Fri, 16 Jan 2026 22:08:52 +0000 Subject: [PATCH 3/9] simply one VE per layer, works best --- nanochat/gpt.py | 34 +++++++++++++++++----------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/nanochat/gpt.py b/nanochat/gpt.py index 0356413d..ea7a4d86 100644 --- a/nanochat/gpt.py +++ b/nanochat/gpt.py @@ -165,14 +165,12 @@ class GPT(nn.Module): # Separate parameters so they can have different optimizer treatment self.resid_lambdas = nn.Parameter(torch.ones(config.n_layer)) # fake init, real init in init_weights() self.x0_lambdas = nn.Parameter(torch.zeros(config.n_layer)) # fake init, real init in init_weights() - # Value residual (ResFormer-style): separate embedding for values, mixed into later layers + # Value residual (ResFormer-style): every layer gets its own value embedding # Paper: "Value Residual Learning" (arXiv:2410.17897) shows this improves information flow - # We apply to last 1/4 of layers as the paper shows later layers benefit most head_dim = config.n_embd // config.n_head kv_dim = config.n_kv_head * head_dim - self.value_embed = nn.Embedding(padded_vocab_size, kv_dim) - self.v0_lambdas = nn.Parameter(torch.zeros(config.n_layer)) # fake init, real init in init_weights() - self.value_residual_start = config.n_layer - config.n_layer // 4 # last 1/4 of layers + self.value_embeds = nn.ModuleList([nn.Embedding(padded_vocab_size, kv_dim) for _ in range(config.n_layer)]) + self.v0_lambdas = nn.Parameter(torch.zeros(config.n_layer)) # fake init, real init in init_weights() # To support meta device initialization, we init the rotary embeddings here, but it's just "fake" meta tensors only. # As for rotary_seq_len, these rotary embeddings are pretty small/cheap in memory, # so let's just over-compute them by 10X, but assert fail if we ever reach that amount. @@ -219,8 +217,9 @@ class GPT(nn.Module): self.x0_lambdas.fill_(0.0) # 0.0 => skip connection to input is disabled at init self.v0_lambdas.fill_(0.0) # 0.0 => value residual is disabled at init - # Value embedding (init like c_v: uniform with same std) - torch.nn.init.uniform_(self.value_embed.weight, -s, s) + # Value embeddings (init like c_v: uniform with same std) + for ve in self.value_embeds: + torch.nn.init.uniform_(ve.weight, -s, s) # Rotary embeddings head_dim = self.config.n_embd // self.config.n_head @@ -230,7 +229,8 @@ class GPT(nn.Module): # Cast embeddings to bf16: optimizer can tolerate it and it saves memory if self.transformer.wte.weight.device.type == "cuda": self.transformer.wte.to(dtype=torch.bfloat16) - self.value_embed.to(dtype=torch.bfloat16) + for ve in self.value_embeds: + ve.to(dtype=torch.bfloat16) def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None): # TODO: bump base theta more? e.g. 100K is more common more recently @@ -295,7 +295,8 @@ class GPT(nn.Module): """ nparams = sum(p.numel() for p in self.parameters()) # Exclude non-matmul params: embeddings and per-layer scalars - nparams_exclude = (self.transformer.wte.weight.numel() + self.value_embed.weight.numel() + + value_embeds_numel = sum(ve.weight.numel() for ve in self.value_embeds) + nparams_exclude = (self.transformer.wte.weight.numel() + value_embeds_numel + self.resid_lambdas.numel() + self.x0_lambdas.numel() + self.v0_lambdas.numel()) h, q, t = self.config.n_head, self.config.n_embd // self.config.n_head, self.config.sequence_len # Sum attention FLOPs per layer, accounting for sliding window @@ -322,15 +323,15 @@ class GPT(nn.Module): def setup_optimizers(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0, adam_betas=(0.8, 0.95), scalar_lr=0.5): model_dim = self.config.n_embd ddp, rank, local_rank, world_size = get_dist_info() - # Separate out all parameters into groups (matrix, embedding, lm_head, value_embed, resid_lambdas, x0_lambdas, v0_lambdas) + # Separate out all parameters into groups (matrix, embedding, lm_head, value_embeds, resid_lambdas, x0_lambdas, v0_lambdas) matrix_params = list(self.transformer.h.parameters()) embedding_params = list(self.transformer.wte.parameters()) lm_head_params = list(self.lm_head.parameters()) - value_embed_params = list(self.value_embed.parameters()) + value_embeds_params = list(self.value_embeds.parameters()) resid_params = [self.resid_lambdas] x0_params = [self.x0_lambdas] v0_params = [self.v0_lambdas] - assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(value_embed_params) + len(resid_params) + len(x0_params) + len(v0_params) + assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params) + len(v0_params) # Create the AdamW optimizer for the embedding, lm_head, and per-layer scalars # Scale the LR for the AdamW parameters by ∝1/√dmodel (having tuned the LRs for 768 dim model) dmodel_lr_scale = (model_dim / 768) ** -0.5 @@ -338,7 +339,7 @@ class GPT(nn.Module): adam_groups = [ dict(params=lm_head_params, lr=unembedding_lr * dmodel_lr_scale), dict(params=embedding_params, lr=embedding_lr * dmodel_lr_scale), - dict(params=value_embed_params, lr=embedding_lr * dmodel_lr_scale), # same LR as token embedding + dict(params=value_embeds_params, lr=embedding_lr * dmodel_lr_scale), # same LR as token embedding dict(params=resid_params, lr=scalar_lr * 0.01), # these are a lot more sensitive because they accumulate in the residual stream dict(params=x0_params, lr=scalar_lr), dict(params=v0_params, lr=scalar_lr), @@ -372,12 +373,11 @@ class GPT(nn.Module): x = self.transformer.wte(idx) x = norm(x) x0 = x # save initial normalized embedding for x0 residual - # Value residual (ResFormer): separate value embedding for later layers - v0 = self.value_embed(idx) # (B, T, kv_dim) + # Value residual (ResFormer): every layer gets its own value embedding + v0s = [ve(idx) for ve in self.value_embeds] # n_layer x (B, T, kv_dim) for i, block in enumerate(self.transformer.h): x = self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0 - v0_for_layer = v0 if i >= self.value_residual_start else None - x = block(x, cos_sin, self.window_sizes[i], kv_cache, v0_for_layer, self.v0_lambdas[i]) + x = block(x, cos_sin, self.window_sizes[i], kv_cache, v0s[i], self.v0_lambdas[i]) x = norm(x) # Forward the lm_head (compute logits) From e85db6b4a4351eb562bec220b3bbcaad28be6722 Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Fri, 16 Jan 2026 23:52:12 +0000 Subject: [PATCH 4/9] alternating design --- nanochat/gpt.py | 60 +++++++++++++++++++++++++++---------------------- 1 file changed, 33 insertions(+), 27 deletions(-) diff --git a/nanochat/gpt.py b/nanochat/gpt.py index ea7a4d86..a077256e 100644 --- a/nanochat/gpt.py +++ b/nanochat/gpt.py @@ -45,6 +45,10 @@ def norm(x): return F.rms_norm(x, (x.size(-1),)) +def has_ve(layer_idx, n_layer): + """Returns True if GPT layer should have Value Embedding (alternating, last layer always included).""" + return layer_idx % 2 == (n_layer - 1) % 2 + def apply_rotary_emb(x, cos, sin): assert x.ndim == 4 # multihead attention d = x.shape[3] // 2 @@ -67,8 +71,10 @@ class CausalSelfAttention(nn.Module): self.c_k = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False) self.c_v = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False) self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False) + self.ve_gate_channels = 32 + self.ve_gate = nn.Linear(self.ve_gate_channels, self.n_kv_head, bias=False) if has_ve(layer_idx, config.n_layer) else None - def forward(self, x, cos_sin, window_size, kv_cache, v0, v0_lambda): + def forward(self, x, ve, cos_sin, window_size, kv_cache): B, T, C = x.size() # Project the input to get queries, keys, and values @@ -77,10 +83,11 @@ class CausalSelfAttention(nn.Module): k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim) v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim) - # Value residual (ResFormer): mix in projected initial embedding for later layers - if v0 is not None: - v0_reshaped = v0.view(B, T, self.n_kv_head, self.head_dim) - v = v + v0_lambda * v0_reshaped + # Value residual (ResFormer): mix in value embedding with input-dependent gate per head + if ve is not None: + ve = ve.view(B, T, self.n_kv_head, self.head_dim) + gate = 2 * torch.sigmoid(self.ve_gate(x[..., :self.ve_gate_channels])) # (B, T, n_kv_head), range (0, 2) + v = v + gate.unsqueeze(-1) * ve # Apply Rotary Embeddings to queries and keys to get relative positional encoding cos, sin = cos_sin @@ -131,8 +138,8 @@ class Block(nn.Module): self.attn = CausalSelfAttention(config, layer_idx) self.mlp = MLP(config) - def forward(self, x, cos_sin, window_size, kv_cache, v0, v0_lambda): - x = x + self.attn(norm(x), cos_sin, window_size, kv_cache, v0, v0_lambda) + def forward(self, x, ve, cos_sin, window_size, kv_cache): + x = x + self.attn(norm(x), ve, cos_sin, window_size, kv_cache) x = x + self.mlp(norm(x)) return x @@ -165,12 +172,10 @@ class GPT(nn.Module): # Separate parameters so they can have different optimizer treatment self.resid_lambdas = nn.Parameter(torch.ones(config.n_layer)) # fake init, real init in init_weights() self.x0_lambdas = nn.Parameter(torch.zeros(config.n_layer)) # fake init, real init in init_weights() - # Value residual (ResFormer-style): every layer gets its own value embedding - # Paper: "Value Residual Learning" (arXiv:2410.17897) shows this improves information flow + # Value embeddings (ResFormer-style): alternating layers, last layer always included head_dim = config.n_embd // config.n_head kv_dim = config.n_kv_head * head_dim - self.value_embeds = nn.ModuleList([nn.Embedding(padded_vocab_size, kv_dim) for _ in range(config.n_layer)]) - self.v0_lambdas = nn.Parameter(torch.zeros(config.n_layer)) # fake init, real init in init_weights() + self.value_embeds = nn.ModuleDict({str(i): nn.Embedding(padded_vocab_size, kv_dim) for i in range(config.n_layer) if has_ve(i, config.n_layer)}) # To support meta device initialization, we init the rotary embeddings here, but it's just "fake" meta tensors only. # As for rotary_seq_len, these rotary embeddings are pretty small/cheap in memory, # so let's just over-compute them by 10X, but assert fail if we ever reach that amount. @@ -181,6 +186,7 @@ class GPT(nn.Module): self.register_buffer("cos", cos, persistent=False) # persistent=False means it's not saved to the checkpoint self.register_buffer("sin", sin, persistent=False) + @torch.no_grad() def init_weights(self): """ Initialize the full model in this one function for maximum clarity. @@ -212,15 +218,18 @@ class GPT(nn.Module): torch.nn.init.zeros_(block.mlp.c_proj.weight) # Per-layer scalars - with torch.no_grad(): - self.resid_lambdas.fill_(1.0) # 1.0 => typical residual connections at init - self.x0_lambdas.fill_(0.0) # 0.0 => skip connection to input is disabled at init - self.v0_lambdas.fill_(0.0) # 0.0 => value residual is disabled at init + self.resid_lambdas.fill_(1.0) # 1.0 => typical residual connections at init + self.x0_lambdas.fill_(0.0) # 0.0 => skip connection to input is disabled at init # Value embeddings (init like c_v: uniform with same std) - for ve in self.value_embeds: + for ve in self.value_embeds.values(): torch.nn.init.uniform_(ve.weight, -s, s) + # Gate weights init to zero so gates start at sigmoid(0) = 0.5, scaled by 2 -> 1.0 (neutral) + for block in self.transformer.h: + if block.attn.ve_gate is not None: + torch.nn.init.zeros_(block.attn.ve_gate.weight) + # Rotary embeddings head_dim = self.config.n_embd // self.config.n_head cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim) @@ -229,7 +238,7 @@ class GPT(nn.Module): # Cast embeddings to bf16: optimizer can tolerate it and it saves memory if self.transformer.wte.weight.device.type == "cuda": self.transformer.wte.to(dtype=torch.bfloat16) - for ve in self.value_embeds: + for ve in self.value_embeds.values(): ve.to(dtype=torch.bfloat16) def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None): @@ -295,9 +304,9 @@ class GPT(nn.Module): """ nparams = sum(p.numel() for p in self.parameters()) # Exclude non-matmul params: embeddings and per-layer scalars - value_embeds_numel = sum(ve.weight.numel() for ve in self.value_embeds) + value_embeds_numel = sum(ve.weight.numel() for ve in self.value_embeds.values()) nparams_exclude = (self.transformer.wte.weight.numel() + value_embeds_numel + - self.resid_lambdas.numel() + self.x0_lambdas.numel() + self.v0_lambdas.numel()) + self.resid_lambdas.numel() + self.x0_lambdas.numel()) h, q, t = self.config.n_head, self.config.n_embd // self.config.n_head, self.config.sequence_len # Sum attention FLOPs per layer, accounting for sliding window attn_flops = 0 @@ -323,15 +332,14 @@ class GPT(nn.Module): def setup_optimizers(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0, adam_betas=(0.8, 0.95), scalar_lr=0.5): model_dim = self.config.n_embd ddp, rank, local_rank, world_size = get_dist_info() - # Separate out all parameters into groups (matrix, embedding, lm_head, value_embeds, resid_lambdas, x0_lambdas, v0_lambdas) + # Separate out all parameters into groups matrix_params = list(self.transformer.h.parameters()) + value_embeds_params = list(self.value_embeds.parameters()) embedding_params = list(self.transformer.wte.parameters()) lm_head_params = list(self.lm_head.parameters()) - value_embeds_params = list(self.value_embeds.parameters()) resid_params = [self.resid_lambdas] x0_params = [self.x0_lambdas] - v0_params = [self.v0_lambdas] - assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params) + len(v0_params) + assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params) # Create the AdamW optimizer for the embedding, lm_head, and per-layer scalars # Scale the LR for the AdamW parameters by ∝1/√dmodel (having tuned the LRs for 768 dim model) dmodel_lr_scale = (model_dim / 768) ** -0.5 @@ -342,7 +350,6 @@ class GPT(nn.Module): dict(params=value_embeds_params, lr=embedding_lr * dmodel_lr_scale), # same LR as token embedding dict(params=resid_params, lr=scalar_lr * 0.01), # these are a lot more sensitive because they accumulate in the residual stream dict(params=x0_params, lr=scalar_lr), - dict(params=v0_params, lr=scalar_lr), ] adamw_kwargs = dict(betas=adam_betas, eps=1e-10, weight_decay=0.0) # NOTE: weight decay is hardcoded to 0.0 for AdamW, only used in Muon AdamWFactory = DistAdamW if ddp else partial(torch.optim.AdamW, fused=True) @@ -373,11 +380,10 @@ class GPT(nn.Module): x = self.transformer.wte(idx) x = norm(x) x0 = x # save initial normalized embedding for x0 residual - # Value residual (ResFormer): every layer gets its own value embedding - v0s = [ve(idx) for ve in self.value_embeds] # n_layer x (B, T, kv_dim) for i, block in enumerate(self.transformer.h): x = self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0 - x = block(x, cos_sin, self.window_sizes[i], kv_cache, v0s[i], self.v0_lambdas[i]) + ve = self.value_embeds[str(i)](idx) if str(i) in self.value_embeds else None + x = block(x, ve, cos_sin, self.window_sizes[i], kv_cache) x = norm(x) # Forward the lm_head (compute logits) From 3b95d4fd392fb4d593adb80530e80c8009d06f75 Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Sat, 17 Jan 2026 00:23:30 +0000 Subject: [PATCH 5/9] allow label for scaling laws script --- scaling_laws.sh | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/scaling_laws.sh b/scaling_laws.sh index 321b286a..7c269c6a 100644 --- a/scaling_laws.sh +++ b/scaling_laws.sh @@ -1,5 +1,7 @@ #!/bin/bash +LABEL="jan16" + FLOPS_BUDGETS=( 1e18 3e18 @@ -7,14 +9,14 @@ FLOPS_BUDGETS=( ) DEPTHS=(8 10 12 14 16 18 20) NPROC_PER_NODE="${NPROC_PER_NODE:-8}" -WANDB_RUN="${WANDB_RUN:-scaling}" +WANDB_RUN="${WANDB_RUN:-scaling_${LABEL}}" EVAL_TOKENS=$((100 * 524288)) # ~100M tokens for final eval (default is ~10M) export OMP_NUM_THREADS=1 export NANOCHAT_BASE_DIR="${NANOCHAT_BASE_DIR:-$HOME/.cache/nanochat}" source .venv/bin/activate -RESULTS_DIR="$NANOCHAT_BASE_DIR/scaling_laws_results" +RESULTS_DIR="$NANOCHAT_BASE_DIR/scaling_laws_results_${LABEL}" mkdir -p "$RESULTS_DIR" RESULTS_FILE="$RESULTS_DIR/results.csv" From e7ed2082b836ac21e45020759e799c3bf1d511fe Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Sat, 17 Jan 2026 21:16:46 +0000 Subject: [PATCH 6/9] update the default GPTConfig kwargs otherwise they are confusing --- nanochat/gpt.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/nanochat/gpt.py b/nanochat/gpt.py index a077256e..cb4bd05b 100644 --- a/nanochat/gpt.py +++ b/nanochat/gpt.py @@ -28,8 +28,8 @@ from nanochat.flash_attention import flash_attn @dataclass class GPTConfig: - sequence_len: int = 1024 - vocab_size: int = 50304 + sequence_len: int = 2048 + vocab_size: int = 32768 n_layer: int = 12 n_head: int = 6 # number of query heads n_kv_head: int = 6 # number of key/value heads (GQA) @@ -37,7 +37,7 @@ class GPTConfig: # Sliding window attention pattern string, tiled across layers. Final layer always L. # Characters: L=long (full context), S=short (half context) # Examples: "L"=all full context, "SL"=alternating, "SSL"=two short then one long - window_pattern: str = "L" + window_pattern: str = "SSSL" def norm(x): From 413e91aa0f5f3f841dbdc0009e64811cf75c5a9d Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Sat, 17 Jan 2026 23:51:09 +0000 Subject: [PATCH 7/9] optimal ratio is now around 4 --- scripts/base_train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/base_train.py b/scripts/base_train.py index c61986e6..bb8d8a68 100644 --- a/scripts/base_train.py +++ b/scripts/base_train.py @@ -47,7 +47,7 @@ parser.add_argument("--window-pattern", type=str, default="SSSL", help="sliding # Training horizon (only one used, in order of precedence) parser.add_argument("--num-iterations", type=int, default=-1, help="explicit number of optimization steps (-1 = disable)") parser.add_argument("--target-flops", type=float, default=-1.0, help="calculate num_iterations to reach target_flops (-1 = disable)") -parser.add_argument("--target-param-data-ratio", type=int, default=8, help="calculate num_iterations to maintain data:param ratio (Chinchilla=20, -1 = disable)") +parser.add_argument("--target-param-data-ratio", type=int, default=4, help="calculate num_iterations to maintain data:param ratio (Chinchilla=20, -1 = disable)") # Optimization parser.add_argument("--device-batch-size", type=int, default=32, help="per-device batch size") parser.add_argument("--total-batch-size", type=int, default=524288, help="total batch size in tokens") From cf5c9e5b8eb2e06c7c2c1c4a280ed95a7f4aa68d Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Sun, 18 Jan 2026 00:07:08 +0000 Subject: [PATCH 8/9] resolve a crash for odd depths because FA3 needs head_dim % 8 == 0 --- scripts/base_train.py | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/scripts/base_train.py b/scripts/base_train.py index bb8d8a68..bcbd4841 100644 --- a/scripts/base_train.py +++ b/scripts/base_train.py @@ -106,21 +106,19 @@ vocab_size = tokenizer.get_vocab_size() print0(f"Vocab size: {vocab_size:,}") # Model kwargs are derived from the desired depth of the model +# We nudge model_dim up to the nearest multiple of head_dim to ensure clean division +# (FA3 requires head_dim divisible by 8, and this guarantees head_dim == args.head_dim exactly) +# (For very small depths, this gives a slight "unfair" advantage to models with odd depths) num_layers = args.depth -model_dim = args.depth * args.aspect_ratio -def find_num_heads(model_dim, target_head_dim): - # Find num_heads that divides model_dim evenly, with head_dim closest to target. - ideal = max(1, round(model_dim / target_head_dim)) - for offset in range(model_dim): - for candidate in [ideal + offset, ideal - offset]: - if candidate > 0 and model_dim % candidate == 0: - return candidate - return 1 -num_heads = find_num_heads(model_dim, args.head_dim) +base_dim = args.depth * args.aspect_ratio +model_dim = ((base_dim + args.head_dim - 1) // args.head_dim) * args.head_dim +num_heads = model_dim // args.head_dim num_kv_heads = num_heads # default is 1:1 GQA (Group Query Attention) ratio (i.e. GQA is disabled) +head_dim = model_dim // num_heads print0(f"num_layers: {num_layers}") -print0(f"model_dim: {model_dim}") +print0(f"model_dim: {model_dim} (base: {base_dim}, nudge: {model_dim - base_dim:+d})") print0(f"num_heads: {num_heads}") +print0(f"head_dim: {head_dim}") print0(f"num_kv_heads: {num_kv_heads}") # Optimizer / data / training length related hyperparameters From babde18ce1cb59cb3d36f8874d1248983c7ba9c3 Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Sun, 18 Jan 2026 03:00:38 +0000 Subject: [PATCH 9/9] small tweaks --- miniseries.sh | 1 - scaling_laws.sh | 3 ++- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/miniseries.sh b/miniseries.sh index 9a4512b6..c42544e3 100644 --- a/miniseries.sh +++ b/miniseries.sh @@ -61,7 +61,6 @@ for d in "${DEPTHS[@]}"; do # No --target-flops, let it use the default ratio from base_train torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train -- \ --depth=$d \ - --target-param-data-ratio=8 \ --run="${WANDB_RUN}_d${d}" \ --model-tag="${TAG}" \ --core-metric-every=999999 \ diff --git a/scaling_laws.sh b/scaling_laws.sh index 7c269c6a..1f9dab87 100644 --- a/scaling_laws.sh +++ b/scaling_laws.sh @@ -7,7 +7,8 @@ FLOPS_BUDGETS=( 3e18 6e18 ) -DEPTHS=(8 10 12 14 16 18 20) +DEPTHS=(6 7 8 9 10 11 12 13 14) + NPROC_PER_NODE="${NPROC_PER_NODE:-8}" WANDB_RUN="${WANDB_RUN:-scaling_${LABEL}}" EVAL_TOKENS=$((100 * 524288)) # ~100M tokens for final eval (default is ~10M)