From a088b7a6ec6e144b99a2a89d0b1f772198abcb97 Mon Sep 17 00:00:00 2001
From: Andrej Karpathy
Date: Tue, 21 Oct 2025 18:07:33 +0000
Subject: [PATCH] use enable_gqa of pytorch sdpa, allows us to delete some code, didnt realize it's available

---
 nanochat/gpt.py       | 24 ++++--------------------
 scripts/base_train.py |  2 +-
 2 files changed, 5 insertions(+), 21 deletions(-)

diff --git a/nanochat/gpt.py b/nanochat/gpt.py
index d744550..b640f1e 100644
--- a/nanochat/gpt.py
+++ b/nanochat/gpt.py
@@ -48,19 +48,6 @@ def apply_rotary_emb(x, cos, sin):
     out = out.to(x.dtype) # ensure input/output dtypes match
     return out
 
-
-def repeat_kv(x, n_rep):
-    """torch.repeat_interleave(x, dim=1, repeats=n_rep)"""
-    if n_rep == 1:
-        return x
-    bs, n_kv_heads, slen, head_dim = x.shape
-    return (
-        x[:, :, None, :, :]
-        .expand(bs, n_kv_heads, n_rep, slen, head_dim)
-        .reshape(bs, n_kv_heads * n_rep, slen, head_dim)
-    )
-
-
 class CausalSelfAttention(nn.Module):
     def __init__(self, config, layer_idx):
         super().__init__()
@@ -96,19 +83,16 @@ class CausalSelfAttention(nn.Module):
         Tq = q.size(2) # number of queries in this forward pass
         Tk = k.size(2) # number of keys/values in total (in the cache + current forward pass)
 
-        # Apply MQA: replicate the key/value heads for each query head
-        nrep = self.n_head // self.n_kv_head
-        k, v = repeat_kv(k, nrep), repeat_kv(v, nrep)
-
         # Attention: queries attend to keys/values autoregressively. A few cases to handle:
+        enable_gqa = self.n_head != self.n_kv_head # Group Query Attention (GQA): duplicate key/value heads to match query heads if desired
         if kv_cache is None or Tq == Tk:
             # During training (no KV cache), attend as usual with causal attention
             # And even if there is KV cache, we can still use this simple version when Tq == Tk
-            y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
+            y = F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=enable_gqa)
         elif Tq == 1:
             # During inference but with a single query in this forward pass:
             # The query has to attend to all the keys/values in the cache
-            y = F.scaled_dot_product_attention(q, k, v, is_causal=False)
+            y = F.scaled_dot_product_attention(q, k, v, is_causal=False, enable_gqa=enable_gqa)
         else:
             # During inference AND we have a chunk of queries in this forward pass:
             # First, each query attends to all the cached keys/values (i.e. full prefix)
@@ -118,7 +102,7 @@ class CausalSelfAttention(nn.Module):
             attn_mask[:, :prefix_len] = True
             # Then, causal attention within this chunk
             attn_mask[:, prefix_len:] = torch.tril(torch.ones((Tq, Tq), dtype=torch.bool, device=q.device))
-            y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
+            y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, enable_gqa=enable_gqa)
 
         # Re-assemble the heads side by side and project back to residual stream
         y = y.transpose(1, 2).contiguous().view(B, T, -1)
diff --git a/scripts/base_train.py b/scripts/base_train.py
index ef7db17..4ca8cdc 100644
--- a/scripts/base_train.py
+++ b/scripts/base_train.py
@@ -85,7 +85,7 @@ print0(f"Vocab size: {vocab_size:,}")
 num_layers = depth
 model_dim = depth * 64 # aspect ratio 64 (usually this is varied from 64 -> 128 as model size increases)
 num_heads = max(1, (model_dim + 127) // 128) # head dim 128 (the division here is ceil div)
-num_kv_heads = num_heads # 1:1 MQA ratio
+num_kv_heads = num_heads # default is 1:1 GQA (Group Query Attention) ratio (i.e. GQA is disabled)
 print0(f"num_layers: {num_layers}")
 print0(f"model_dim: {model_dim}")
 print0(f"num_heads: {num_heads}")
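
Note (not part of the patch): a minimal standalone sanity check of the idea, assuming PyTorch >= 2.5, where F.scaled_dot_product_attention gained the enable_gqa flag. It checks that letting SDPA broadcast the grouped key/value heads internally matches the explicit replication that the deleted repeat_kv() helper performed (equivalent to torch.repeat_interleave on the head dimension). The shapes and head counts below are illustrative, not taken from nanochat's configs.

# Sanity check: enable_gqa vs. manually repeated key/value heads.
# Assumes PyTorch >= 2.5 (enable_gqa keyword). Shapes are illustrative.
import torch
import torch.nn.functional as F

B, T, head_dim = 2, 16, 64
n_head, n_kv_head = 8, 2              # 4 query heads share each key/value head
n_rep = n_head // n_kv_head

q = torch.randn(B, n_head, T, head_dim)
k = torch.randn(B, n_kv_head, T, head_dim)
v = torch.randn(B, n_kv_head, T, head_dim)

# New path: SDPA broadcasts the grouped key/value heads across query heads
y_gqa = F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=True)

# Old path: replicate key/value heads explicitly (what repeat_kv used to do)
k_rep = torch.repeat_interleave(k, n_rep, dim=1)
v_rep = torch.repeat_interleave(v, n_rep, dim=1)
y_ref = F.scaled_dot_product_attention(q, k_rep, v_rep, is_causal=True)

print(torch.allclose(y_gqa, y_ref, atol=1e-6))  # expected: True (up to floating-point tolerance)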