mirror of
https://github.com/karpathy/nanochat.git
synced 2026-06-15 10:39:08 +00:00
Merge 275fa4b060 into b9b6ce137b
This commit is contained in:
commit
bfe68455e0
|
|
@ -101,6 +101,25 @@ def _sdpa_attention(q, k, v, window_size, enable_gqa):
|
|||
|
||||
return F.scaled_dot_product_attention(q, k, v, attn_mask=mask, enable_gqa=enable_gqa)
|
||||
|
||||
|
||||
def XSA(atten, value, sdpa=False):
    """Remove from the attention output its component along the value vector.

    The value tensor is L2-normalised over its last dimension, and the
    projection of ``atten`` onto that unit vector is subtracted. Each output
    vector is therefore orthogonal to the corresponding value direction.

    Args:
        atten: 4-D attention output. The head axis is dim 1 when ``sdpa`` is
            true (layout (B, H, T, D)) and dim 2 otherwise (layout
            (B, T, H, D)).
        value: 4-D value tensor in the same layout as ``atten``. It may carry
            fewer heads than ``atten`` (grouped-query attention); each kv head
            is then repeated so it lines up with its group of query heads.
        sdpa: Selects the layout, and therefore which axis holds the heads.

    Returns:
        A tensor shaped like ``atten`` (assuming the non-head dimensions of
        ``value`` match those of ``atten``).

    Raises:
        ValueError: If either tensor is not 4-D.
        AssertionError: If the number of heads in ``atten`` is not a positive
            multiple of the number of heads in ``value``.
    """
    if atten.dim() != 4 or value.dim() != 4:
        raise ValueError(
            f"XSA expects 4-D tensors, got atten.dim()={atten.dim()} "
            f"and value.dim()={value.dim()}"
        )

    # The two supported layouts differ only in where the head axis lives.
    head_axis = 1 if sdpa else 2
    n_heads = atten.size(head_axis)
    n_kv_heads = value.size(head_axis)

    # Validate before dividing so a zero kv-head count reports the intended
    # message instead of a ZeroDivisionError.
    assert n_kv_heads > 0 and n_heads % n_kv_heads == 0, "query heads and kv heads must be divisible"

    group = n_heads // n_kv_heads
    if group > 1:
        # GQA: expand kv heads so every query head has a matching value head.
        value = torch.repeat_interleave(value, group, dim=head_axis)

    unit_v = F.normalize(value, dim=-1)
    # Subtract the projection of atten onto the unit value direction.
    return atten - torch.sum(atten * unit_v, dim=-1, keepdim=True) * unit_v
|
||||
|
||||
# =============================================================================
|
||||
# Public API: Same interface as FA3
|
||||
# =============================================================================
|
||||
|
|
@ -117,7 +136,9 @@ def flash_attn_func(q, k, v, causal=False, window_size=(-1, -1)):
|
|||
Output tensor of shape (B, T, H, D)
|
||||
"""
|
||||
if USE_FA3:
|
||||
return _fa3.flash_attn_func(q, k, v, causal=causal, window_size=window_size)
|
||||
y = _fa3.flash_attn_func(q, k, v, causal=causal, window_size=window_size)
|
||||
y = XSA(y, v)
|
||||
return y
|
||||
|
||||
# SDPA fallback: transpose (B, T, H, D) -> (B, H, T, D)
|
||||
q = q.transpose(1, 2)
|
||||
|
|
@ -125,6 +146,7 @@ def flash_attn_func(q, k, v, causal=False, window_size=(-1, -1)):
|
|||
v = v.transpose(1, 2)
|
||||
enable_gqa = q.size(1) != k.size(1)
|
||||
y = _sdpa_attention(q, k, v, window_size, enable_gqa)
|
||||
y = XSA(y, v, True)
|
||||
return y.transpose(1, 2) # back to (B, T, H, D)
|
||||
|
||||
|
||||
|
|
@ -147,10 +169,12 @@ def flash_attn_with_kvcache(q, k_cache, v_cache, k=None, v=None, cache_seqlens=N
|
|||
Output tensor of shape (B, T_new, H, D)
|
||||
"""
|
||||
if USE_FA3:
|
||||
return _fa3.flash_attn_with_kvcache(
|
||||
y = _fa3.flash_attn_with_kvcache(
|
||||
q, k_cache, v_cache, k=k, v=v, cache_seqlens=cache_seqlens,
|
||||
causal=causal, window_size=window_size
|
||||
)
|
||||
y = XSA(y, v)
|
||||
return y
|
||||
|
||||
# SDPA fallback: manually manage KV cache
|
||||
B, T_new, H, D = q.shape
|
||||
|
|
@ -170,9 +194,11 @@ def flash_attn_with_kvcache(q, k_cache, v_cache, k=None, v=None, cache_seqlens=N
|
|||
q_sdpa = q.transpose(1, 2)
|
||||
k_sdpa = k_full.transpose(1, 2)
|
||||
v_sdpa = v_full.transpose(1, 2)
|
||||
v = v.transpose(1, 2)
|
||||
|
||||
enable_gqa = q_sdpa.size(1) != k_sdpa.size(1)
|
||||
y_sdpa = _sdpa_attention(q_sdpa, k_sdpa, v_sdpa, window_size, enable_gqa)
|
||||
y_sdpa = XSA(y_sdpa, v, True)
|
||||
|
||||
return y_sdpa.transpose(1, 2) # back to (B, T, H, D)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user