From e85db6b4a4351eb562bec220b3bbcaad28be6722 Mon Sep 17 00:00:00 2001
From: Andrej Karpathy
Date: Fri, 16 Jan 2026 23:52:12 +0000
Subject: [PATCH] alternating design

---
 nanochat/gpt.py | 60 +++++++++++++++++++++++++++----------------------
 1 file changed, 33 insertions(+), 27 deletions(-)

diff --git a/nanochat/gpt.py b/nanochat/gpt.py
index ea7a4d8..a077256 100644
--- a/nanochat/gpt.py
+++ b/nanochat/gpt.py
@@ -45,6 +45,10 @@ def norm(x):
     return F.rms_norm(x, (x.size(-1),))
 
+def has_ve(layer_idx, n_layer):
+    """Returns True if GPT layer should have Value Embedding (alternating, last layer always included)."""
+    return layer_idx % 2 == (n_layer - 1) % 2
+
 
 def apply_rotary_emb(x, cos, sin):
     assert x.ndim == 4 # multihead attention
     d = x.shape[3] // 2
@@ -67,8 +71,10 @@ class CausalSelfAttention(nn.Module):
         self.c_k = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
         self.c_v = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
         self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)
+        self.ve_gate_channels = 32
+        self.ve_gate = nn.Linear(self.ve_gate_channels, self.n_kv_head, bias=False) if has_ve(layer_idx, config.n_layer) else None
 
-    def forward(self, x, cos_sin, window_size, kv_cache, v0, v0_lambda):
+    def forward(self, x, ve, cos_sin, window_size, kv_cache):
         B, T, C = x.size()
 
         # Project the input to get queries, keys, and values
@@ -77,10 +83,11 @@ class CausalSelfAttention(nn.Module):
         k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim)
         v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim)
 
-        # Value residual (ResFormer): mix in projected initial embedding for later layers
-        if v0 is not None:
-            v0_reshaped = v0.view(B, T, self.n_kv_head, self.head_dim)
-            v = v + v0_lambda * v0_reshaped
+        # Value residual (ResFormer): mix in value embedding with input-dependent gate per head
+        if ve is not None:
+            ve = ve.view(B, T, self.n_kv_head, self.head_dim)
+            gate = 2 * torch.sigmoid(self.ve_gate(x[..., :self.ve_gate_channels])) # (B, T, n_kv_head), range (0, 2)
+            v = v + gate.unsqueeze(-1) * ve
 
         # Apply Rotary Embeddings to queries and keys to get relative positional encoding
         cos, sin = cos_sin
@@ -131,8 +138,8 @@ class Block(nn.Module):
         self.attn = CausalSelfAttention(config, layer_idx)
         self.mlp = MLP(config)
 
-    def forward(self, x, cos_sin, window_size, kv_cache, v0, v0_lambda):
-        x = x + self.attn(norm(x), cos_sin, window_size, kv_cache, v0, v0_lambda)
+    def forward(self, x, ve, cos_sin, window_size, kv_cache):
+        x = x + self.attn(norm(x), ve, cos_sin, window_size, kv_cache)
         x = x + self.mlp(norm(x))
         return x
 
@@ -165,12 +172,10 @@ class GPT(nn.Module):
         # Separate parameters so they can have different optimizer treatment
         self.resid_lambdas = nn.Parameter(torch.ones(config.n_layer)) # fake init, real init in init_weights()
         self.x0_lambdas = nn.Parameter(torch.zeros(config.n_layer)) # fake init, real init in init_weights()
-        # Value residual (ResFormer-style): every layer gets its own value embedding
-        # Paper: "Value Residual Learning" (arXiv:2410.17897) shows this improves information flow
+        # Value embeddings (ResFormer-style): alternating layers, last layer always included
         head_dim = config.n_embd // config.n_head
         kv_dim = config.n_kv_head * head_dim
-        self.value_embeds = nn.ModuleList([nn.Embedding(padded_vocab_size, kv_dim) for _ in range(config.n_layer)])
-        self.v0_lambdas = nn.Parameter(torch.zeros(config.n_layer)) # fake init, real init in init_weights()
+        self.value_embeds = nn.ModuleDict({str(i): nn.Embedding(padded_vocab_size, kv_dim) for i in range(config.n_layer) if has_ve(i, config.n_layer)})
         # To support meta device initialization, we init the rotary embeddings here, but it's just "fake" meta tensors only.
         # As for rotary_seq_len, these rotary embeddings are pretty small/cheap in memory,
         # so let's just over-compute them by 10X, but assert fail if we ever reach that amount.
@@ -181,6 +186,7 @@ class GPT(nn.Module):
         self.register_buffer("cos", cos, persistent=False) # persistent=False means it's not saved to the checkpoint
         self.register_buffer("sin", sin, persistent=False)
 
+    @torch.no_grad()
     def init_weights(self):
         """
         Initialize the full model in this one function for maximum clarity.
@@ -212,15 +218,18 @@ class GPT(nn.Module):
             torch.nn.init.zeros_(block.mlp.c_proj.weight)
 
         # Per-layer scalars
-        with torch.no_grad():
-            self.resid_lambdas.fill_(1.0) # 1.0 => typical residual connections at init
-            self.x0_lambdas.fill_(0.0) # 0.0 => skip connection to input is disabled at init
-            self.v0_lambdas.fill_(0.0) # 0.0 => value residual is disabled at init
+        self.resid_lambdas.fill_(1.0) # 1.0 => typical residual connections at init
+        self.x0_lambdas.fill_(0.0) # 0.0 => skip connection to input is disabled at init
 
         # Value embeddings (init like c_v: uniform with same std)
-        for ve in self.value_embeds:
+        for ve in self.value_embeds.values():
             torch.nn.init.uniform_(ve.weight, -s, s)
 
+        # Gate weights init to zero so gates start at sigmoid(0) = 0.5, scaled by 2 -> 1.0 (neutral)
+        for block in self.transformer.h:
+            if block.attn.ve_gate is not None:
+                torch.nn.init.zeros_(block.attn.ve_gate.weight)
+
         # Rotary embeddings
         head_dim = self.config.n_embd // self.config.n_head
         cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
@@ -229,7 +238,7 @@ class GPT(nn.Module):
         # Cast embeddings to bf16: optimizer can tolerate it and it saves memory
         if self.transformer.wte.weight.device.type == "cuda":
             self.transformer.wte.to(dtype=torch.bfloat16)
-        for ve in self.value_embeds:
+        for ve in self.value_embeds.values():
             ve.to(dtype=torch.bfloat16)
 
     def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None):
@@ -295,9 +304,9 @@ class GPT(nn.Module):
         """
         nparams = sum(p.numel() for p in self.parameters())
         # Exclude non-matmul params: embeddings and per-layer scalars
-        value_embeds_numel = sum(ve.weight.numel() for ve in self.value_embeds)
+        value_embeds_numel = sum(ve.weight.numel() for ve in self.value_embeds.values())
         nparams_exclude = (self.transformer.wte.weight.numel() + value_embeds_numel +
-                           self.resid_lambdas.numel() + self.x0_lambdas.numel() + self.v0_lambdas.numel())
+                           self.resid_lambdas.numel() + self.x0_lambdas.numel())
         h, q, t = self.config.n_head, self.config.n_embd // self.config.n_head, self.config.sequence_len
         # Sum attention FLOPs per layer, accounting for sliding window
         attn_flops = 0
@@ -323,15 +332,14 @@ class GPT(nn.Module):
     def setup_optimizers(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0, adam_betas=(0.8, 0.95), scalar_lr=0.5):
         model_dim = self.config.n_embd
         ddp, rank, local_rank, world_size = get_dist_info()
-        # Separate out all parameters into groups (matrix, embedding, lm_head, value_embeds, resid_lambdas, x0_lambdas, v0_lambdas)
+        # Separate out all parameters into groups
         matrix_params = list(self.transformer.h.parameters())
+        value_embeds_params = list(self.value_embeds.parameters())
         embedding_params = list(self.transformer.wte.parameters())
         lm_head_params = list(self.lm_head.parameters())
-        value_embeds_params = list(self.value_embeds.parameters())
         resid_params = [self.resid_lambdas]
         x0_params = [self.x0_lambdas]
-        v0_params = [self.v0_lambdas]
-        assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params) + len(v0_params)
+        assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params)
         # Create the AdamW optimizer for the embedding, lm_head, and per-layer scalars
         # Scale the LR for the AdamW parameters by ∝1/√dmodel (having tuned the LRs for 768 dim model)
         dmodel_lr_scale = (model_dim / 768) ** -0.5
@@ -342,7 +350,6 @@ class GPT(nn.Module):
             dict(params=value_embeds_params, lr=embedding_lr * dmodel_lr_scale), # same LR as token embedding
             dict(params=resid_params, lr=scalar_lr * 0.01), # these are a lot more sensitive because they accumulate in the residual stream
             dict(params=x0_params, lr=scalar_lr),
-            dict(params=v0_params, lr=scalar_lr),
         ]
         adamw_kwargs = dict(betas=adam_betas, eps=1e-10, weight_decay=0.0) # NOTE: weight decay is hardcoded to 0.0 for AdamW, only used in Muon
         AdamWFactory = DistAdamW if ddp else partial(torch.optim.AdamW, fused=True)
@@ -373,11 +380,10 @@ class GPT(nn.Module):
         x = self.transformer.wte(idx)
         x = norm(x)
         x0 = x # save initial normalized embedding for x0 residual
-        # Value residual (ResFormer): every layer gets its own value embedding
-        v0s = [ve(idx) for ve in self.value_embeds] # n_layer x (B, T, kv_dim)
         for i, block in enumerate(self.transformer.h):
             x = self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0
-            x = block(x, cos_sin, self.window_sizes[i], kv_cache, v0s[i], self.v0_lambdas[i])
+            ve = self.value_embeds[str(i)](idx) if str(i) in self.value_embeds else None
+            x = block(x, ve, cos_sin, self.window_sizes[i], kv_cache)
         x = norm(x)
 
         # Forward the lm_head (compute logits)
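
For reference, the gated value-embedding path introduced by this patch can be exercised in isolation. The sketch below is not part of the patch: it is a minimal standalone illustration with made-up toy dimensions that mirrors the has_ve() selection rule and the v + 2*sigmoid(ve_gate(x[..., :32])) * ve mix from the diff, so the intended shapes and the neutral behaviour of the zero-initialized gate can be checked directly.

    # Illustrative check of the alternating value-embedding gate (not part of the patch).
    import torch
    import torch.nn as nn

    def has_ve(layer_idx, n_layer):
        # same rule as in the patch: alternating layers, last layer always included
        return layer_idx % 2 == (n_layer - 1) % 2

    n_layer = 12
    print([i for i in range(n_layer) if has_ve(i, n_layer)])  # [1, 3, 5, 7, 9, 11]

    # toy sizes for illustration only (not nanochat's real config)
    B, T, n_embd, n_kv_head, head_dim = 2, 8, 128, 4, 32
    ve_gate_channels = 32

    x = torch.randn(B, T, n_embd)                 # normalized residual-stream input
    ve = torch.randn(B, T, n_kv_head * head_dim)  # value embedding looked up from token ids
    v = torch.randn(B, T, n_kv_head, head_dim)    # values from the c_v projection

    ve_gate = nn.Linear(ve_gate_channels, n_kv_head, bias=False)
    nn.init.zeros_(ve_gate.weight)                # zero init => gate = 2*sigmoid(0) = 1.0 (neutral)

    gate = 2 * torch.sigmoid(ve_gate(x[..., :ve_gate_channels]))  # (B, T, n_kv_head), range (0, 2)
    v_mixed = v + gate.unsqueeze(-1) * ve.view(B, T, n_kv_head, head_dim)

    print(gate.min().item(), gate.max().item())   # 1.0 1.0 at init
    print(v_mixed.shape)                          # torch.Size([2, 8, 4, 32])

With ve_gate.weight zero-initialized, the model starts out equivalent to an ungated value residual with weight 1.0, and the (0, 2) gate range lets training scale each KV head's value-embedding contribution up or down per token.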