diff --git a/nanochat/flash_attention.py b/nanochat/flash_attention.py index d90998df..000b01b6 100644 --- a/nanochat/flash_attention.py +++ b/nanochat/flash_attention.py @@ -194,10 +194,11 @@ def flash_attn_with_kvcache(q, k_cache, v_cache, k=None, v=None, cache_seqlens=N q_sdpa = q.transpose(1, 2) k_sdpa = k_full.transpose(1, 2) v_sdpa = v_full.transpose(1, 2) + v = v.transpose(1, 2) enable_gqa = q_sdpa.size(1) != k_sdpa.size(1) y_sdpa = _sdpa_attention(q_sdpa, k_sdpa, v_sdpa, window_size, enable_gqa) - y_sdpa = XSA(y_sdpa, v_sdpa, True) + y_sdpa = XSA(y_sdpa, v, True) return y_sdpa.transpose(1, 2) # back to (B, T, H, D)