fix shape mismatch and account for gqa.

This commit is contained in:
Rohan Khan 2026-04-14 16:48:56 +06:00
parent 6eedfa188c
commit 9d852177a9

View File

@ -101,9 +101,24 @@ def _sdpa_attention(q, k, v, window_size, enable_gqa):
return F.scaled_dot_product_attention(q, k, v, attn_mask=mask, enable_gqa=enable_gqa)
def XSA(atten, value):
Vn = F.normalize(value, dim=-1)
return atten - (atten * Vn).sum(dim=-1, keepdim=True) * Vn
def XSA(atten, value, sdpa=False):
if sdpa:
aB, aH, aT, aD = atten.shape
vB, vH, vT, vD = value.shape
rep = aH // vH
assert aH % vH == 0, "query heads and kv heads must be divisible"
value = torch.repeat_interleave(value, rep, dim=1)
Vn = F.normalize(value, dim=-1)
return atten - torch.sum(atten * Vn, dim=-1, keepdim=True) * Vn
else:
aB, aT, aH, aD = atten.shape
vB, vT, vH, vD = value.shape
rep = aH // vH
assert aH % vH == 0, "query heads and kv heads must be divisible"
value = torch.repeat_interleave(value, rep, dim=2)
Vn = F.normalize(value, dim=-1)
return atten - torch.sum(atten * Vn, dim=-1, keepdim=True) * Vn
# =============================================================================
# Public API: Same interface as FA3
@ -131,7 +146,7 @@ def flash_attn_func(q, k, v, causal=False, window_size=(-1, -1)):
v = v.transpose(1, 2)
enable_gqa = q.size(1) != k.size(1)
y = _sdpa_attention(q, k, v, window_size, enable_gqa)
y = XSA(y, v)
y = XSA(y, v, True)
return y.transpose(1, 2) # back to (B, T, H, D)
@ -182,7 +197,7 @@ def flash_attn_with_kvcache(q, k_cache, v_cache, k=None, v=None, cache_seqlens=N
enable_gqa = q_sdpa.size(1) != k_sdpa.size(1)
y_sdpa = _sdpa_attention(q_sdpa, k_sdpa, v_sdpa, window_size, enable_gqa)
y_sdpa = XSA(y_sdpa, v_sdpa)
y_sdpa = XSA(y_sdpa, v_sdpa, True)
return y_sdpa.transpose(1, 2) # back to (B, T, H, D)