Mirror of https://github.com/karpathy/nanochat.git, synced 2025-12-06 12:22:18 +00:00

Commit message: use enable_gqa of PyTorch SDPA, allows us to delete some code, didn't realize it's available
This commit is contained in:
parent 94ee507054
commit a088b7a6ec
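For context on what the new flag buys: with enable_gqa=True, F.scaled_dot_product_attention accepts fewer key/value heads than query heads (the query head count must be a multiple of the KV head count) and broadcasts the KV heads internally, which is the replication the hand-rolled repeat_kv helper used to do. Below is a minimal sketch of that equivalence, assuming a PyTorch build recent enough to have the enable_gqa argument; the shapes are arbitrary and not taken from the repo.

import torch
import torch.nn.functional as F

B, n_head, n_kv_head, T, head_dim = 2, 8, 2, 16, 64
q = torch.randn(B, n_head, T, head_dim)
k = torch.randn(B, n_kv_head, T, head_dim)
v = torch.randn(B, n_kv_head, T, head_dim)

# New path: let SDPA broadcast the 2 KV heads across the 8 query heads
y_gqa = F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=True)

# Old path: replicate KV heads by hand before calling SDPA
n_rep = n_head // n_kv_head
k_rep = torch.repeat_interleave(k, repeats=n_rep, dim=1)
v_rep = torch.repeat_interleave(v, repeats=n_rep, dim=1)
y_manual = F.scaled_dot_product_attention(q, k_rep, v_rep, is_causal=True)

print(torch.allclose(y_gqa, y_manual, atol=1e-6))  # expected: True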
@@ -48,19 +48,6 @@ def apply_rotary_emb(x, cos, sin):
     out = out.to(x.dtype) # ensure input/output dtypes match
     return out

-
-def repeat_kv(x, n_rep):
-    """torch.repeat_interleave(x, dim=1, repeats=n_rep)"""
-    if n_rep == 1:
-        return x
-    bs, n_kv_heads, slen, head_dim = x.shape
-    return (
-        x[:, :, None, :, :]
-        .expand(bs, n_kv_heads, n_rep, slen, head_dim)
-        .reshape(bs, n_kv_heads * n_rep, slen, head_dim)
-    )
-
-
 class CausalSelfAttention(nn.Module):
     def __init__(self, config, layer_idx):
         super().__init__()
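As its docstring notes, the deleted repeat_kv helper was an expand/reshape spelling of torch.repeat_interleave along the head dimension (it avoids materializing an intermediate copy per repeat). A small standalone check of that equivalence, not part of the repo:

import torch

def repeat_kv(x, n_rep):
    """torch.repeat_interleave(x, dim=1, repeats=n_rep)"""
    if n_rep == 1:
        return x
    bs, n_kv_heads, slen, head_dim = x.shape
    return (
        x[:, :, None, :, :]
        .expand(bs, n_kv_heads, n_rep, slen, head_dim)
        .reshape(bs, n_kv_heads * n_rep, slen, head_dim)
    )

x = torch.randn(2, 4, 16, 64)              # (batch, n_kv_heads, seq_len, head_dim)
a = repeat_kv(x, 3)
b = torch.repeat_interleave(x, repeats=3, dim=1)
print(a.shape, torch.equal(a, b))           # torch.Size([2, 12, 16, 64]) True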
@@ -96,19 +83,16 @@ class CausalSelfAttention(nn.Module):
         Tq = q.size(2) # number of queries in this forward pass
         Tk = k.size(2) # number of keys/values in total (in the cache + current forward pass)

-        # Apply MQA: replicate the key/value heads for each query head
-        nrep = self.n_head // self.n_kv_head
-        k, v = repeat_kv(k, nrep), repeat_kv(v, nrep)
-
         # Attention: queries attend to keys/values autoregressively. A few cases to handle:
+        enable_gqa = self.n_head != self.n_kv_head # Group Query Attention (GQA): duplicate key/value heads to match query heads if desired
         if kv_cache is None or Tq == Tk:
             # During training (no KV cache), attend as usual with causal attention
             # And even if there is KV cache, we can still use this simple version when Tq == Tk
-            y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
+            y = F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=enable_gqa)
         elif Tq == 1:
             # During inference but with a single query in this forward pass:
             # The query has to attend to all the keys/values in the cache
-            y = F.scaled_dot_product_attention(q, k, v, is_causal=False)
+            y = F.scaled_dot_product_attention(q, k, v, is_causal=False, enable_gqa=enable_gqa)
         else:
             # During inference AND we have a chunk of queries in this forward pass:
             # First, each query attends to all the cached keys/values (i.e. full prefix)
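The elif Tq == 1 branch above is the single-token decode case. A rough sketch of that situation with made-up shapes and a toy KV cache (not the repo's actual cache object): the one new query attends to everything accumulated so far, so no causal mask is needed and is_causal=False is correct.

import torch
import torch.nn.functional as F

B, n_head, n_kv_head, head_dim = 1, 8, 2, 64
cache_k = torch.randn(B, n_kv_head, 10, head_dim)   # 10 cached positions
cache_v = torch.randn(B, n_kv_head, 10, head_dim)

q_new = torch.randn(B, n_head, 1, head_dim)          # Tq == 1
k_new = torch.randn(B, n_kv_head, 1, head_dim)
v_new = torch.randn(B, n_kv_head, 1, head_dim)

k = torch.cat([cache_k, k_new], dim=2)                # Tk == 11
v = torch.cat([cache_v, v_new], dim=2)

y = F.scaled_dot_product_attention(q_new, k, v, is_causal=False, enable_gqa=True)
print(y.shape)  # torch.Size([1, 8, 1, 64])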
@@ -118,7 +102,7 @@ class CausalSelfAttention(nn.Module):
             attn_mask[:, :prefix_len] = True
             # Then, causal attention within this chunk
             attn_mask[:, prefix_len:] = torch.tril(torch.ones((Tq, Tq), dtype=torch.bool, device=q.device))
-            y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
+            y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, enable_gqa=enable_gqa)

         # Re-assemble the heads side by side and project back to residual stream
         y = y.transpose(1, 2).contiguous().view(B, T, -1)
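The mask construction in this hunk can be seen end to end in a small standalone sketch (shapes and prefix length are arbitrary): each of the Tq chunk queries sees the full cached prefix, and attention within the chunk itself is causal. The boolean mask broadcasts over batch and head dimensions.

import torch
import torch.nn.functional as F

B, n_head, n_kv_head, head_dim = 1, 8, 2, 64
prefix_len, Tq = 5, 3
Tk = prefix_len + Tq

q = torch.randn(B, n_head, Tq, head_dim)
k = torch.randn(B, n_kv_head, Tk, head_dim)
v = torch.randn(B, n_kv_head, Tk, head_dim)

attn_mask = torch.zeros(Tq, Tk, dtype=torch.bool)
attn_mask[:, :prefix_len] = True                                               # full prefix is visible
attn_mask[:, prefix_len:] = torch.tril(torch.ones(Tq, Tq, dtype=torch.bool))   # causal within the chunk

y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, enable_gqa=True)
print(attn_mask.int())
# tensor([[1, 1, 1, 1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 1, 1, 1, 0],
#         [1, 1, 1, 1, 1, 1, 1, 1]])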
@@ -85,7 +85,7 @@ print0(f"Vocab size: {vocab_size:,}")
 num_layers = depth
 model_dim = depth * 64 # aspect ratio 64 (usually this is varied from 64 -> 128 as model size increases)
 num_heads = max(1, (model_dim + 127) // 128) # head dim 128 (the division here is ceil div)
-num_kv_heads = num_heads # 1:1 MQA ratio
+num_kv_heads = num_heads # default is 1:1 GQA (Group Query Attention) ratio (i.e. GQA is disabled)
 print0(f"num_layers: {num_layers}")
 print0(f"model_dim: {model_dim}")
 print0(f"num_heads: {num_heads}")
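To make the derivation above concrete, a small sketch with an arbitrary depth value (print0 swapped for plain print so it runs standalone). Keeping num_kv_heads equal to num_heads leaves GQA disabled; picking a smaller value is how it would be turned on.

depth = 20                                    # hypothetical value, just for illustration
num_layers = depth
num_layers = depth
model_dim = depth * 64                        # aspect ratio 64 -> 1280
num_heads = max(1, (model_dim + 127) // 128)  # ceil(1280 / 128) = 10 heads of dim 128
num_kv_heads = num_heads                      # 1:1, i.e. GQA disabled

print(num_layers, model_dim, num_heads, num_kv_heads)  # 20 1280 10 10

# Turning GQA on would just mean picking fewer KV heads, e.g.
# num_kv_heads = num_heads // 4  # 4 query heads would then share each KV head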