mirror of
https://github.com/karpathy/nanochat.git
synced 2026-06-15 10:39:08 +00:00
using non kv value to fix c_proj shape error.
This commit is contained in:
parent
9d852177a9
commit
275fa4b060
|
|
@ -194,10 +194,11 @@ def flash_attn_with_kvcache(q, k_cache, v_cache, k=None, v=None, cache_seqlens=N
|
|||
q_sdpa = q.transpose(1, 2)
|
||||
k_sdpa = k_full.transpose(1, 2)
|
||||
v_sdpa = v_full.transpose(1, 2)
|
||||
v = v.transpose(1, 2)
|
||||
|
||||
enable_gqa = q_sdpa.size(1) != k_sdpa.size(1)
|
||||
y_sdpa = _sdpa_attention(q_sdpa, k_sdpa, v_sdpa, window_size, enable_gqa)
|
||||
y_sdpa = XSA(y_sdpa, v_sdpa, True)
|
||||
y_sdpa = XSA(y_sdpa, v, True)
|
||||
|
||||
return y_sdpa.transpose(1, 2) # back to (B, T, H, D)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user