From 7ecaf86519b89a48f721d71026c6c1c3e9e3e737 Mon Sep 17 00:00:00 2001 From: Rohan Khan Date: Sun, 12 Apr 2026 18:32:55 +0600 Subject: [PATCH 1/3] add exclusive self attention. --- nanochat/flash_attention.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/nanochat/flash_attention.py b/nanochat/flash_attention.py index af2aee32..190fa938 100644 --- a/nanochat/flash_attention.py +++ b/nanochat/flash_attention.py @@ -117,7 +117,10 @@ def flash_attn_func(q, k, v, causal=False, window_size=(-1, -1)): Output tensor of shape (B, T, H, D) """ if USE_FA3: - return _fa3.flash_attn_func(q, k, v, causal=causal, window_size=window_size) + y = _fa3.flash_attn_func(q, k, v, causal=causal, window_size=window_size) + Vn = F.normalize(v, dim=-1) + y = y - (y * Vn).sum(dim=-1, keepdim=True) * Vn + return y # SDPA fallback: transpose (B, T, H, D) -> (B, H, T, D) q = q.transpose(1, 2) @@ -125,6 +128,8 @@ def flash_attn_func(q, k, v, causal=False, window_size=(-1, -1)): v = v.transpose(1, 2) enable_gqa = q.size(1) != k.size(1) y = _sdpa_attention(q, k, v, window_size, enable_gqa) + Vn = F.normalize(v, dim=-1) + y = y - (y * Vn).sum(dim=-1, keepdim=True) * Vn return y.transpose(1, 2) # back to (B, T, H, D) @@ -147,10 +152,13 @@ def flash_attn_with_kvcache(q, k_cache, v_cache, k=None, v=None, cache_seqlens=N Output tensor of shape (B, T_new, H, D) """ if USE_FA3: - return _fa3.flash_attn_with_kvcache( + y = _fa3.flash_attn_with_kvcache( q, k_cache, v_cache, k=k, v=v, cache_seqlens=cache_seqlens, causal=causal, window_size=window_size ) + Vn = F.normalize(v, dim=-1) + y = y - (y * Vn).sum(dim=-1, keepdim=True) * Vn + return y # SDPA fallback: manually manage KV cache B, T_new, H, D = q.shape @@ -173,6 +181,8 @@ def flash_attn_with_kvcache(q, k_cache, v_cache, k=None, v=None, cache_seqlens=N enable_gqa = q_sdpa.size(1) != k_sdpa.size(1) y_sdpa = _sdpa_attention(q_sdpa, k_sdpa, v_sdpa, window_size, enable_gqa) + Vn = F.normalize(v_sdpa, dim=-1) + y_sdpa = y_sdpa - (y_sdpa * Vn).sum(dim=-1, keepdim=True) * Vn return y_sdpa.transpose(1, 2) # back to (B, T, H, D) From 0725192e07322bb2dc076996e9faa215782d426b Mon Sep 17 00:00:00 2001 From: Rohan Khan Date: Mon, 13 Apr 2026 07:57:54 +0600 Subject: [PATCH 2/3] move XSA in a function and using that function. --- nanochat/flash_attention.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/nanochat/flash_attention.py b/nanochat/flash_attention.py index 190fa938..d9058f0c 100644 --- a/nanochat/flash_attention.py +++ b/nanochat/flash_attention.py @@ -101,6 +101,10 @@ def _sdpa_attention(q, k, v, window_size, enable_gqa): return F.scaled_dot_product_attention(q, k, v, attn_mask=mask, enable_gqa=enable_gqa) +def XSA(atten, value): + Vn = F.normalize(value, dim=-1) + return atten - (atten * Vn).sum(dim=-1, keepdim=True) * Vn + # ============================================================================= # Public API: Same interface as FA3 # ============================================================================= @@ -118,8 +122,7 @@ def flash_attn_func(q, k, v, causal=False, window_size=(-1, -1)): """ if USE_FA3: y = _fa3.flash_attn_func(q, k, v, causal=causal, window_size=window_size) - Vn = F.normalize(v, dim=-1) - y = y - (y * Vn).sum(dim=-1, keepdim=True) * Vn + y = XSA(y, v) return y # SDPA fallback: transpose (B, T, H, D) -> (B, H, T, D) @@ -128,8 +131,7 @@ def flash_attn_func(q, k, v, causal=False, window_size=(-1, -1)): v = v.transpose(1, 2) enable_gqa = q.size(1) != k.size(1) y = _sdpa_attention(q, k, v, window_size, enable_gqa) - Vn = F.normalize(v, dim=-1) - y = y - (y * Vn).sum(dim=-1, keepdim=True) * Vn + y = XSA(y, v) return y.transpose(1, 2) # back to (B, T, H, D) @@ -156,8 +158,7 @@ def flash_attn_with_kvcache(q, k_cache, v_cache, k=None, v=None, cache_seqlens=N q, k_cache, v_cache, k=k, v=v, cache_seqlens=cache_seqlens, causal=causal, window_size=window_size ) - Vn = F.normalize(v, dim=-1) - y = y - (y * Vn).sum(dim=-1, keepdim=True) * Vn + y = XSA(y, v_cache) return y # SDPA fallback: manually manage KV cache @@ -181,8 +182,7 @@ def flash_attn_with_kvcache(q, k_cache, v_cache, k=None, v=None, cache_seqlens=N enable_gqa = q_sdpa.size(1) != k_sdpa.size(1) y_sdpa = _sdpa_attention(q_sdpa, k_sdpa, v_sdpa, window_size, enable_gqa) - Vn = F.normalize(v_sdpa, dim=-1) - y_sdpa = y_sdpa - (y_sdpa * Vn).sum(dim=-1, keepdim=True) * Vn + y_sdpa = XSA(y_sdpa, v_sdpa) return y_sdpa.transpose(1, 2) # back to (B, T, H, D) From 0b6d93f3c27d9d2dd1a0ba46b19937cb6162b3f3 Mon Sep 17 00:00:00 2001 From: Rohan Khan Date: Mon, 13 Apr 2026 22:37:07 +0600 Subject: [PATCH 3/3] claude suggested fix. --- nanochat/flash_attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nanochat/flash_attention.py b/nanochat/flash_attention.py index d9058f0c..b8be8f4c 100644 --- a/nanochat/flash_attention.py +++ b/nanochat/flash_attention.py @@ -158,7 +158,7 @@ def flash_attn_with_kvcache(q, k_cache, v_cache, k=None, v=None, cache_seqlens=N q, k_cache, v_cache, k=k, v=v, cache_seqlens=cache_seqlens, causal=causal, window_size=window_size ) - y = XSA(y, v_cache) + y = XSA(y, v) return y # SDPA fallback: manually manage KV cache