diff --git a/nanochat/flash_attention.py b/nanochat/flash_attention.py index af2aee32..b8be8f4c 100644 --- a/nanochat/flash_attention.py +++ b/nanochat/flash_attention.py @@ -101,6 +101,10 @@ def _sdpa_attention(q, k, v, window_size, enable_gqa): return F.scaled_dot_product_attention(q, k, v, attn_mask=mask, enable_gqa=enable_gqa) +def XSA(atten, value): + Vn = F.normalize(value, dim=-1) + return atten - (atten * Vn).sum(dim=-1, keepdim=True) * Vn + # ============================================================================= # Public API: Same interface as FA3 # ============================================================================= @@ -117,7 +121,9 @@ def flash_attn_func(q, k, v, causal=False, window_size=(-1, -1)): Output tensor of shape (B, T, H, D) """ if USE_FA3: - return _fa3.flash_attn_func(q, k, v, causal=causal, window_size=window_size) + y = _fa3.flash_attn_func(q, k, v, causal=causal, window_size=window_size) + y = XSA(y, v) + return y # SDPA fallback: transpose (B, T, H, D) -> (B, H, T, D) q = q.transpose(1, 2) @@ -125,6 +131,7 @@ def flash_attn_func(q, k, v, causal=False, window_size=(-1, -1)): v = v.transpose(1, 2) enable_gqa = q.size(1) != k.size(1) y = _sdpa_attention(q, k, v, window_size, enable_gqa) + y = XSA(y, v) return y.transpose(1, 2) # back to (B, T, H, D) @@ -147,10 +154,12 @@ def flash_attn_with_kvcache(q, k_cache, v_cache, k=None, v=None, cache_seqlens=N Output tensor of shape (B, T_new, H, D) """ if USE_FA3: - return _fa3.flash_attn_with_kvcache( + y = _fa3.flash_attn_with_kvcache( q, k_cache, v_cache, k=k, v=v, cache_seqlens=cache_seqlens, causal=causal, window_size=window_size ) + y = XSA(y, v) + return y # SDPA fallback: manually manage KV cache B, T_new, H, D = q.shape @@ -173,6 +182,7 @@ def flash_attn_with_kvcache(q, k_cache, v_cache, k=None, v=None, cache_seqlens=N enable_gqa = q_sdpa.size(1) != k_sdpa.size(1) y_sdpa = _sdpa_attention(q_sdpa, k_sdpa, v_sdpa, window_size, enable_gqa) + y_sdpa = XSA(y_sdpa, v_sdpa) return y_sdpa.transpose(1, 2) # back to (B, T, H, D)