diff --git a/nanochat/flash_attention.py b/nanochat/flash_attention.py index d9058f0c..b8be8f4c 100644 --- a/nanochat/flash_attention.py +++ b/nanochat/flash_attention.py @@ -158,7 +158,7 @@ def flash_attn_with_kvcache(q, k_cache, v_cache, k=None, v=None, cache_seqlens=N q, k_cache, v_cache, k=k, v=v, cache_seqlens=cache_seqlens, causal=causal, window_size=window_size ) - y = XSA(y, v_cache) + y = XSA(y, v) return y # SDPA fallback: manually manage KV cache