diff --git a/nanochat/gpt.py b/nanochat/gpt.py
index 208acd1..6c4dbd5 100644
--- a/nanochat/gpt.py
+++ b/nanochat/gpt.py
@@ -14,6 +14,7 @@ Notable features:
 
 from functools import partial
 from dataclasses import dataclass
+import math
 
 import torch
 import torch.nn as nn
@@ -23,7 +24,7 @@ from nanochat.common import get_dist_info, print0
 from nanochat.optim import MuonAdamW, DistMuonAdamW
 
 # Our custom Flash Attention module that automatically uses FA3 on Hopper+ and SDPA fallback elsewhere
-from nanochat.flash_attention import flash_attn
+from nanochat.flash_attention import flash_attn, _sdpa_attention
 
 @dataclass
 class GPTConfig:
@@ -37,6 +38,11 @@ class GPTConfig:
     # Characters: L=long (full context), S=short (half context)
     # Examples: "L"=all full context, "SL"=alternating, "SSL"=two short then one long
     window_pattern: str = "SSSL"
+    # Differential attention (Differential Transformer, arxiv 2410.05258).
+    # Splits each head's Q and K into two halves (Q1/Q2, K1/K2) and computes
+    # DiffAttn = (softmax(Q1 K1^T/√d) - λ·softmax(Q2 K2^T/√d)) V,
+    # cancelling attention noise and promoting sparse patterns.
+    use_diff_attn: bool = False
 
 
 def norm(x):
@@ -72,6 +78,19 @@ class CausalSelfAttention(nn.Module):
         self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)
         self.ve_gate_channels = 32
         self.ve_gate = nn.Linear(self.ve_gate_channels, self.n_kv_head, bias=False) if has_ve(layer_idx, config.n_layer) else None
+        # Differential attention parameters (paper: arxiv 2410.05258)
+        self.use_diff_attn = config.use_diff_attn
+        if self.use_diff_attn:
+            assert self.head_dim % 2 == 0, "head_dim must be even for differential attention"
+            self.half_head_dim = self.head_dim // 2
+            # Per-layer constant λ_init = 0.8 - 0.6·exp(-0.3·(l-1)), l=1-indexed layer
+            self.lambda_init = 0.8 - 0.6 * math.exp(-0.3 * layer_idx)
+            # Learnable vectors for the re-parameterized λ per head:
+            # λ = exp(λq1·λk1) - exp(λq2·λk2) + λ_init
+            self.lambda_q1 = nn.Parameter(torch.empty(self.n_head, self.half_head_dim))
+            self.lambda_k1 = nn.Parameter(torch.empty(self.n_head, self.half_head_dim))
+            self.lambda_q2 = nn.Parameter(torch.empty(self.n_head, self.half_head_dim))
+            self.lambda_k2 = nn.Parameter(torch.empty(self.n_head, self.half_head_dim))
 
     def forward(self, x, ve, cos_sin, window_size, kv_cache):
         B, T, C = x.size()
@@ -91,32 +110,120 @@ class CausalSelfAttention(nn.Module):
         # Apply Rotary Embeddings to queries and keys to get relative positional encoding
         cos, sin = cos_sin
         q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin)
-        q, k = norm(q), norm(k) # QK norm
 
-        # Flash Attention (FA3 on Hopper+, PyTorch SDPA fallback elsewhere)
-        # window_size is (left, right) tuple: (N, 0) for causal, (-1, 0) for full context
-        if kv_cache is None:
-            # Training: causal attention with optional sliding window
-            y = flash_attn.flash_attn_func(q, k, v, causal=True, window_size=window_size)
+        if self.use_diff_attn:
+            y = self._diff_attn_forward(q, k, v, window_size, kv_cache, B, T)
         else:
-            # Inference: use flash_attn_with_kvcache which handles cache management
-            k_cache, v_cache = kv_cache.get_layer_cache(self.layer_idx)
-            y = flash_attn.flash_attn_with_kvcache(
-                q, k_cache, v_cache,
-                k=k, v=v,
-                cache_seqlens=kv_cache.cache_seqlens,
-                causal=True,
-                window_size=window_size,
-            )
-            # Advance position after last layer processes
-            if self.layer_idx == kv_cache.n_layers - 1:
-                kv_cache.advance(T)
+            q, k = norm(q), norm(k) # QK norm
+            # Flash Attention (FA3 on Hopper+, PyTorch SDPA fallback elsewhere)
+            # window_size is (left, right) tuple: (N, 0) for causal, (-1, 0) for full context
+            if kv_cache is None:
+                # Training: causal attention with optional sliding window
+                y = flash_attn.flash_attn_func(q, k, v, causal=True, window_size=window_size)
+            else:
+                # Inference: use flash_attn_with_kvcache which handles cache management
+                k_cache, v_cache = kv_cache.get_layer_cache(self.layer_idx)
+                y = flash_attn.flash_attn_with_kvcache(
+                    q, k_cache, v_cache,
+                    k=k, v=v,
+                    cache_seqlens=kv_cache.cache_seqlens,
+                    causal=True,
+                    window_size=window_size,
+                )
+                # Advance position after last layer processes
+                if self.layer_idx == kv_cache.n_layers - 1:
+                    kv_cache.advance(T)
 
         # Re-assemble the heads and project back to residual stream
         y = y.contiguous().view(B, T, -1)
         y = self.c_proj(y)
         return y
 
+    def _diff_attn_forward(self, q, k, v, window_size, kv_cache, B, T):
+        """
+        Differential attention forward pass.
+
+        Splits each head's Q and K into two halves (Q1/Q2, K1/K2):
+          DiffAttn = (softmax(Q1 K1^T/√d) - λ·softmax(Q2 K2^T/√d)) V
+        where λ = exp(λq1·λk1) - exp(λq2·λk2) + λ_init  (per-head scalar).
+
+        After computing all heads, applies per-head RMSNorm and scales by (1 - λ_init).
+
+        V retains the full head_dim so output shape matches standard attention.
+        """
+        # Split Q and K each into two halves along the head_dim axis
+        # q: (B, T, n_head, head_dim)    -> q1, q2: (B, T, n_head, half_head_dim) each
+        # k: (B, T, n_kv_head, head_dim) -> k1, k2: (B, T, n_kv_head, half_head_dim) each
+        q1 = q[..., :self.half_head_dim]
+        q2 = q[..., self.half_head_dim:]
+        k1 = k[..., :self.half_head_dim]
+        k2 = k[..., self.half_head_dim:]
+
+        # QK norm applied independently to each sub-head
+        q1, q2 = norm(q1), norm(q2)
+        k1, k2 = norm(k1), norm(k2)
+
+        if kv_cache is None:
+            # Training: two separate causal attention computations (FA3 or SDPA)
+            # attn1, attn2: (B, T, n_head, head_dim) — V dimension is head_dim
+            attn1 = flash_attn.flash_attn_func(q1, k1, v, causal=True, window_size=window_size)
+            attn2 = flash_attn.flash_attn_func(q2, k2, v, causal=True, window_size=window_size)
+        else:
+            # Inference: manually manage the KV cache.
+            # Cache the QK-normed halves re-joined (norm(K1)||norm(K2), head_dim) and V, so keys read back match the training branch.
+            k_full_new = torch.cat([k1, k2], dim=-1)  # (B, T, n_kv_head, head_dim) — normed K1 and K2 concatenated
+            k_cache, v_cache = kv_cache.get_layer_cache(self.layer_idx)
+            pos = kv_cache.cache_seqlens[0].item()
+
+            # Write new K (full head_dim) and V into the cache
+            k_cache[:, pos:pos + T] = k_full_new
+            v_cache[:, pos:pos + T] = v
+            end_pos = pos + T
+
+            # Advance position after the last layer has written its cache
+            if self.layer_idx == kv_cache.n_layers - 1:
+                kv_cache.advance(T)
+
+            # Read the full cached context
+            k_ctx = k_cache[:, :end_pos]  # (B, end_pos, n_kv_head, head_dim)
+            v_ctx = v_cache[:, :end_pos]  # (B, end_pos, n_kv_head, head_dim)
+
+            # Split cached K into K1 and K2
+            k1_ctx = k_ctx[..., :self.half_head_dim]
+            k2_ctx = k_ctx[..., self.half_head_dim:]
+
+            # Compute two attention operations via SDPA (handles decode/prefill/window cases)
+            enable_gqa = self.n_head != self.n_kv_head
+            # Transpose to SDPA layout (B, H, T, D)
+            q1_t = q1.transpose(1, 2)
+            q2_t = q2.transpose(1, 2)
+            k1_t = k1_ctx.transpose(1, 2)
+            k2_t = k2_ctx.transpose(1, 2)
+            v_t = v_ctx.transpose(1, 2)
+            attn1 = _sdpa_attention(q1_t, k1_t, v_t, window_size, enable_gqa).transpose(1, 2)
+            attn2 = _sdpa_attention(q2_t, k2_t, v_t, window_size, enable_gqa).transpose(1, 2)
+
+        # Per-head scalar λ = exp(λq1·λk1) - exp(λq2·λk2) + λ_init
+        # lambda_q1, lambda_k1: (n_head, half_head_dim) -> dot product -> (n_head,)
+        lambda_val = (
+            torch.exp((self.lambda_q1 * self.lambda_k1).sum(-1)) -
+            torch.exp((self.lambda_q2 * self.lambda_k2).sum(-1))
+            + self.lambda_init
+        )  # shape: (n_head,)
+
+        # Differential combination: subtract the noise attention scaled by λ
+        # attn1, attn2: (B, T, n_head, head_dim); lambda_val broadcast over (1, 1, n_head, 1)
+        y = attn1 - lambda_val[None, None, :, None] * attn2
+
+        # Per-head RMSNorm (equivalent to GroupNorm with n_head groups).
+        # F.rms_norm normalizes over the last dim (head_dim) for each (B, T, head) element.
+        y = norm(y)
+
+        # Fixed scaling by (1 - λ_init) as specified in the paper
+        y = y * (1.0 - self.lambda_init)
+
+        return y
+
 
 class MLP(nn.Module):
     def __init__(self, config):
@@ -229,6 +336,14 @@ class GPT(nn.Module):
             if block.attn.ve_gate is not None:
                 torch.nn.init.zeros_(block.attn.ve_gate.weight)
 
+        # Differential attention lambda vectors: Normal(0, 0.1) as per the paper
+        for block in self.transformer.h:
+            if block.attn.use_diff_attn:
+                torch.nn.init.normal_(block.attn.lambda_q1, mean=0.0, std=0.1)
+                torch.nn.init.normal_(block.attn.lambda_k1, mean=0.0, std=0.1)
+                torch.nn.init.normal_(block.attn.lambda_q2, mean=0.0, std=0.1)
+                torch.nn.init.normal_(block.attn.lambda_k2, mean=0.0, std=0.1)
+
         # Rotary embeddings
         head_dim = self.config.n_embd // self.config.n_head
         cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
@@ -349,14 +464,23 @@ class GPT(nn.Module):
         model_dim = self.config.n_embd
         ddp, rank, local_rank, world_size = get_dist_info()
 
+        # Differential attention lambda vectors live inside transformer.h but must be
+        # excluded from Muon (which expects square weight matrices) and trained with AdamW.
+        _diff_lambda_names = {'lambda_q1', 'lambda_k1', 'lambda_q2', 'lambda_k2'}
+        diff_lambda_params = [
+            p for name, p in self.transformer.h.named_parameters()
+            if name.split('.')[-1] in _diff_lambda_names
+        ]
+        _diff_lambda_ids = {id(p) for p in diff_lambda_params}
+
         # Separate out all parameters into groups
-        matrix_params = list(self.transformer.h.parameters())
+        matrix_params = [p for p in self.transformer.h.parameters() if id(p) not in _diff_lambda_ids]
         value_embeds_params = list(self.value_embeds.parameters())
         embedding_params = list(self.transformer.wte.parameters())
         lm_head_params = list(self.lm_head.parameters())
         resid_params = [self.resid_lambdas]
         x0_params = [self.x0_lambdas]
-        assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params)
+        assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params) + len(diff_lambda_params)
 
         # Scale the LR for the AdamW parameters by ∝1/√dmodel (tuned for 768 dim model)
         dmodel_lr_scale = (model_dim / 768) ** -0.5
@@ -370,6 +494,8 @@ class GPT(nn.Module):
             dict(kind='adamw', params=value_embeds_params, lr=embedding_lr * dmodel_lr_scale, betas=adam_betas, eps=1e-10, weight_decay=0.0),
             dict(kind='adamw', params=resid_params, lr=scalar_lr * 0.01, betas=adam_betas, eps=1e-10, weight_decay=0.0),
             dict(kind='adamw', params=x0_params, lr=scalar_lr, betas=(0.96, 0.95), eps=1e-10, weight_decay=0.0),  # higher beta1 for x0
+            # Differential attention lambda vectors trained with AdamW (not Muon)
+            dict(kind='adamw', params=diff_lambda_params, lr=matrix_lr * dmodel_lr_scale, betas=adam_betas, eps=1e-10, weight_decay=0.0),
         ]
         # Muon groups (matrix params, grouped by shape for stacking)
         for shape in sorted({p.shape for p in matrix_params}):
diff --git a/scripts/base_train.py b/scripts/base_train.py
index 24091b6..4882cfb 100644
--- a/scripts/base_train.py
+++ b/scripts/base_train.py
@@ -51,6 +51,7 @@ parser.add_argument("--aspect-ratio", type=int, default=64, help="model_dim = de
 parser.add_argument("--head-dim", type=int, default=128, help="target head dimension for attention")
 parser.add_argument("--max-seq-len", type=int, default=2048, help="max context length")
 parser.add_argument("--window-pattern", type=str, default="SSSL", help="sliding window pattern tiled across layers: L=full, S=half context (e.g. 'SSL')")
+parser.add_argument("--use-diff-attn", action="store_true", help="enable Differential Attention (arxiv 2410.05258)")
 # Training horizon (only one used, in order of precedence)
 parser.add_argument("--num-iterations", type=int, default=-1, help="explicit number of optimization steps (-1 = disable)")
 parser.add_argument("--target-flops", type=float, default=-1.0, help="calculate num_iterations to reach target_flops (-1 = disable)")
@@ -133,6 +134,7 @@ def build_model_meta(depth):
         sequence_len=args.max_seq_len, vocab_size=vocab_size,
         n_layer=depth, n_head=num_heads, n_kv_head=num_heads, n_embd=model_dim,
         window_pattern=args.window_pattern,
+        use_diff_attn=args.use_diff_attn,
     )
     with torch.device("meta"):
         model_meta = GPT(config)