From 7ecaf86519b89a48f721d71026c6c1c3e9e3e737 Mon Sep 17 00:00:00 2001 From: Rohan Khan Date: Sun, 12 Apr 2026 18:32:55 +0600 Subject: [PATCH 1/5] add exclusive self attention. --- nanochat/flash_attention.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/nanochat/flash_attention.py b/nanochat/flash_attention.py index af2aee32..190fa938 100644 --- a/nanochat/flash_attention.py +++ b/nanochat/flash_attention.py @@ -117,7 +117,10 @@ def flash_attn_func(q, k, v, causal=False, window_size=(-1, -1)): Output tensor of shape (B, T, H, D) """ if USE_FA3: - return _fa3.flash_attn_func(q, k, v, causal=causal, window_size=window_size) + y = _fa3.flash_attn_func(q, k, v, causal=causal, window_size=window_size) + Vn = F.normalize(v, dim=-1) + y = y - (y * Vn).sum(dim=-1, keepdim=True) * Vn + return y # SDPA fallback: transpose (B, T, H, D) -> (B, H, T, D) q = q.transpose(1, 2) @@ -125,6 +128,8 @@ def flash_attn_func(q, k, v, causal=False, window_size=(-1, -1)): v = v.transpose(1, 2) enable_gqa = q.size(1) != k.size(1) y = _sdpa_attention(q, k, v, window_size, enable_gqa) + Vn = F.normalize(v, dim=-1) + y = y - (y * Vn).sum(dim=-1, keepdim=True) * Vn return y.transpose(1, 2) # back to (B, T, H, D) @@ -147,10 +152,13 @@ def flash_attn_with_kvcache(q, k_cache, v_cache, k=None, v=None, cache_seqlens=N Output tensor of shape (B, T_new, H, D) """ if USE_FA3: - return _fa3.flash_attn_with_kvcache( + y = _fa3.flash_attn_with_kvcache( q, k_cache, v_cache, k=k, v=v, cache_seqlens=cache_seqlens, causal=causal, window_size=window_size ) + Vn = F.normalize(v, dim=-1) + y = y - (y * Vn).sum(dim=-1, keepdim=True) * Vn + return y # SDPA fallback: manually manage KV cache B, T_new, H, D = q.shape @@ -173,6 +181,8 @@ def flash_attn_with_kvcache(q, k_cache, v_cache, k=None, v=None, cache_seqlens=N enable_gqa = q_sdpa.size(1) != k_sdpa.size(1) y_sdpa = _sdpa_attention(q_sdpa, k_sdpa, v_sdpa, window_size, enable_gqa) + Vn = F.normalize(v_sdpa, dim=-1) + y_sdpa = y_sdpa - (y_sdpa * Vn).sum(dim=-1, keepdim=True) * Vn return y_sdpa.transpose(1, 2) # back to (B, T, H, D) From 0725192e07322bb2dc076996e9faa215782d426b Mon Sep 17 00:00:00 2001 From: Rohan Khan Date: Mon, 13 Apr 2026 07:57:54 +0600 Subject: [PATCH 2/5] move XSA in a function and using that function. --- nanochat/flash_attention.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/nanochat/flash_attention.py b/nanochat/flash_attention.py index 190fa938..d9058f0c 100644 --- a/nanochat/flash_attention.py +++ b/nanochat/flash_attention.py @@ -101,6 +101,10 @@ def _sdpa_attention(q, k, v, window_size, enable_gqa): return F.scaled_dot_product_attention(q, k, v, attn_mask=mask, enable_gqa=enable_gqa) +def XSA(atten, value): + Vn = F.normalize(value, dim=-1) + return atten - (atten * Vn).sum(dim=-1, keepdim=True) * Vn + # ============================================================================= # Public API: Same interface as FA3 # ============================================================================= @@ -118,8 +122,7 @@ def flash_attn_func(q, k, v, causal=False, window_size=(-1, -1)): """ if USE_FA3: y = _fa3.flash_attn_func(q, k, v, causal=causal, window_size=window_size) - Vn = F.normalize(v, dim=-1) - y = y - (y * Vn).sum(dim=-1, keepdim=True) * Vn + y = XSA(y, v) return y # SDPA fallback: transpose (B, T, H, D) -> (B, H, T, D) @@ -128,8 +131,7 @@ def flash_attn_func(q, k, v, causal=False, window_size=(-1, -1)): v = v.transpose(1, 2) enable_gqa = q.size(1) != k.size(1) y = _sdpa_attention(q, k, v, window_size, enable_gqa) - Vn = F.normalize(v, dim=-1) - y = y - (y * Vn).sum(dim=-1, keepdim=True) * Vn + y = XSA(y, v) return y.transpose(1, 2) # back to (B, T, H, D) @@ -156,8 +158,7 @@ def flash_attn_with_kvcache(q, k_cache, v_cache, k=None, v=None, cache_seqlens=N q, k_cache, v_cache, k=k, v=v, cache_seqlens=cache_seqlens, causal=causal, window_size=window_size ) - Vn = F.normalize(v, dim=-1) - y = y - (y * Vn).sum(dim=-1, keepdim=True) * Vn + y = XSA(y, v_cache) return y # SDPA fallback: manually manage KV cache @@ -181,8 +182,7 @@ def flash_attn_with_kvcache(q, k_cache, v_cache, k=None, v=None, cache_seqlens=N enable_gqa = q_sdpa.size(1) != k_sdpa.size(1) y_sdpa = _sdpa_attention(q_sdpa, k_sdpa, v_sdpa, window_size, enable_gqa) - Vn = F.normalize(v_sdpa, dim=-1) - y_sdpa = y_sdpa - (y_sdpa * Vn).sum(dim=-1, keepdim=True) * Vn + y_sdpa = XSA(y_sdpa, v_sdpa) return y_sdpa.transpose(1, 2) # back to (B, T, H, D) From 0b6d93f3c27d9d2dd1a0ba46b19937cb6162b3f3 Mon Sep 17 00:00:00 2001 From: Rohan Khan Date: Mon, 13 Apr 2026 22:37:07 +0600 Subject: [PATCH 3/5] claude suggested fix. --- nanochat/flash_attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nanochat/flash_attention.py b/nanochat/flash_attention.py index d9058f0c..b8be8f4c 100644 --- a/nanochat/flash_attention.py +++ b/nanochat/flash_attention.py @@ -158,7 +158,7 @@ def flash_attn_with_kvcache(q, k_cache, v_cache, k=None, v=None, cache_seqlens=N q, k_cache, v_cache, k=k, v=v, cache_seqlens=cache_seqlens, causal=causal, window_size=window_size ) - y = XSA(y, v_cache) + y = XSA(y, v) return y # SDPA fallback: manually manage KV cache From 9d852177a972b04f81a1e1758e08be1d2833f8da Mon Sep 17 00:00:00 2001 From: Rohan Khan Date: Tue, 14 Apr 2026 16:48:56 +0600 Subject: [PATCH 4/5] fix shape mismatch and account for gqa. --- nanochat/flash_attention.py | 25 ++++++++++++++++++++----- 1 file changed, 20 insertions(+), 5 deletions(-) diff --git a/nanochat/flash_attention.py b/nanochat/flash_attention.py index b8be8f4c..d90998df 100644 --- a/nanochat/flash_attention.py +++ b/nanochat/flash_attention.py @@ -101,9 +101,24 @@ def _sdpa_attention(q, k, v, window_size, enable_gqa): return F.scaled_dot_product_attention(q, k, v, attn_mask=mask, enable_gqa=enable_gqa) -def XSA(atten, value): - Vn = F.normalize(value, dim=-1) - return atten - (atten * Vn).sum(dim=-1, keepdim=True) * Vn + +def XSA(atten, value, sdpa=False): + if sdpa: + aB, aH, aT, aD = atten.shape + vB, vH, vT, vD = value.shape + rep = aH // vH + assert aH % vH == 0, "query heads and kv heads must be divisible" + value = torch.repeat_interleave(value, rep, dim=1) + Vn = F.normalize(value, dim=-1) + return atten - torch.sum(atten * Vn, dim=-1, keepdim=True) * Vn + else: + aB, aT, aH, aD = atten.shape + vB, vT, vH, vD = value.shape + rep = aH // vH + assert aH % vH == 0, "query heads and kv heads must be divisible" + value = torch.repeat_interleave(value, rep, dim=2) + Vn = F.normalize(value, dim=-1) + return atten - torch.sum(atten * Vn, dim=-1, keepdim=True) * Vn # ============================================================================= # Public API: Same interface as FA3 @@ -131,7 +146,7 @@ def flash_attn_func(q, k, v, causal=False, window_size=(-1, -1)): v = v.transpose(1, 2) enable_gqa = q.size(1) != k.size(1) y = _sdpa_attention(q, k, v, window_size, enable_gqa) - y = XSA(y, v) + y = XSA(y, v, True) return y.transpose(1, 2) # back to (B, T, H, D) @@ -182,7 +197,7 @@ def flash_attn_with_kvcache(q, k_cache, v_cache, k=None, v=None, cache_seqlens=N enable_gqa = q_sdpa.size(1) != k_sdpa.size(1) y_sdpa = _sdpa_attention(q_sdpa, k_sdpa, v_sdpa, window_size, enable_gqa) - y_sdpa = XSA(y_sdpa, v_sdpa) + y_sdpa = XSA(y_sdpa, v_sdpa, True) return y_sdpa.transpose(1, 2) # back to (B, T, H, D) From 275fa4b060e37948688ebea3b910600e960d1898 Mon Sep 17 00:00:00 2001 From: Rohan Khan Date: Tue, 14 Apr 2026 18:43:47 +0600 Subject: [PATCH 5/5] using non kv value to fix c_proj shape error. --- nanochat/flash_attention.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/nanochat/flash_attention.py b/nanochat/flash_attention.py index d90998df..000b01b6 100644 --- a/nanochat/flash_attention.py +++ b/nanochat/flash_attention.py @@ -194,10 +194,11 @@ def flash_attn_with_kvcache(q, k_cache, v_cache, k=None, v=None, cache_seqlens=N q_sdpa = q.transpose(1, 2) k_sdpa = k_full.transpose(1, 2) v_sdpa = v_full.transpose(1, 2) + v = v.transpose(1, 2) enable_gqa = q_sdpa.size(1) != k_sdpa.size(1) y_sdpa = _sdpa_attention(q_sdpa, k_sdpa, v_sdpa, window_size, enable_gqa) - y_sdpa = XSA(y_sdpa, v_sdpa, True) + y_sdpa = XSA(y_sdpa, v, True) return y_sdpa.transpose(1, 2) # back to (B, T, H, D)