using non kv value to fix c_proj shape error.

This commit is contained in:
Rohan Khan 2026-04-14 18:43:47 +06:00
parent 9d852177a9
commit 275fa4b060

View File

@ -194,10 +194,11 @@ def flash_attn_with_kvcache(q, k_cache, v_cache, k=None, v=None, cache_seqlens=N
q_sdpa = q.transpose(1, 2)
k_sdpa = k_full.transpose(1, 2)
v_sdpa = v_full.transpose(1, 2)
v = v.transpose(1, 2)
enable_gqa = q_sdpa.size(1) != k_sdpa.size(1)
y_sdpa = _sdpa_attention(q_sdpa, k_sdpa, v_sdpa, window_size, enable_gqa)
y_sdpa = XSA(y_sdpa, v_sdpa, True)
y_sdpa = XSA(y_sdpa, v, True)
return y_sdpa.transpose(1, 2) # back to (B, T, H, D)