From 275fa4b060e37948688ebea3b910600e960d1898 Mon Sep 17 00:00:00 2001 From: Rohan Khan Date: Tue, 14 Apr 2026 18:43:47 +0600 Subject: [PATCH] using non kv value to fix c_proj shape error. --- nanochat/flash_attention.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/nanochat/flash_attention.py b/nanochat/flash_attention.py index d90998df..000b01b6 100644 --- a/nanochat/flash_attention.py +++ b/nanochat/flash_attention.py @@ -194,10 +194,11 @@ def flash_attn_with_kvcache(q, k_cache, v_cache, k=None, v=None, cache_seqlens=N q_sdpa = q.transpose(1, 2) k_sdpa = k_full.transpose(1, 2) v_sdpa = v_full.transpose(1, 2) + v = v.transpose(1, 2) enable_gqa = q_sdpa.size(1) != k_sdpa.size(1) y_sdpa = _sdpa_attention(q_sdpa, k_sdpa, v_sdpa, window_size, enable_gqa) - y_sdpa = XSA(y_sdpa, v_sdpa, True) + y_sdpa = XSA(y_sdpa, v, True) return y_sdpa.transpose(1, 2) # back to (B, T, H, D)