diff --git a/nanochat/flash_attention.py b/nanochat/flash_attention.py index b8be8f4c..d90998df 100644 --- a/nanochat/flash_attention.py +++ b/nanochat/flash_attention.py @@ -101,9 +101,24 @@ def _sdpa_attention(q, k, v, window_size, enable_gqa): return F.scaled_dot_product_attention(q, k, v, attn_mask=mask, enable_gqa=enable_gqa) -def XSA(atten, value): - Vn = F.normalize(value, dim=-1) - return atten - (atten * Vn).sum(dim=-1, keepdim=True) * Vn + +def XSA(atten, value, sdpa=False): + if sdpa: + aB, aH, aT, aD = atten.shape + vB, vH, vT, vD = value.shape + rep = aH // vH + assert aH % vH == 0, "query heads and kv heads must be divisible" + value = torch.repeat_interleave(value, rep, dim=1) + Vn = F.normalize(value, dim=-1) + return atten - torch.sum(atten * Vn, dim=-1, keepdim=True) * Vn + else: + aB, aT, aH, aD = atten.shape + vB, vT, vH, vD = value.shape + rep = aH // vH + assert aH % vH == 0, "query heads and kv heads must be divisible" + value = torch.repeat_interleave(value, rep, dim=2) + Vn = F.normalize(value, dim=-1) + return atten - torch.sum(atten * Vn, dim=-1, keepdim=True) * Vn # ============================================================================= # Public API: Same interface as FA3 @@ -131,7 +146,7 @@ def flash_attn_func(q, k, v, causal=False, window_size=(-1, -1)): v = v.transpose(1, 2) enable_gqa = q.size(1) != k.size(1) y = _sdpa_attention(q, k, v, window_size, enable_gqa) - y = XSA(y, v) + y = XSA(y, v, True) return y.transpose(1, 2) # back to (B, T, H, D) @@ -182,7 +197,7 @@ def flash_attn_with_kvcache(q, k_cache, v_cache, k=None, v=None, cache_seqlens=N enable_gqa = q_sdpa.size(1) != k_sdpa.size(1) y_sdpa = _sdpa_attention(q_sdpa, k_sdpa, v_sdpa, window_size, enable_gqa) - y_sdpa = XSA(y_sdpa, v_sdpa) + y_sdpa = XSA(y_sdpa, v_sdpa, True) return y_sdpa.transpose(1, 2) # back to (B, T, H, D)