This commit is contained in:
Rohan Khan 2026-04-13 21:08:17 -05:00 committed by GitHub
commit a0f7bfea16
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -101,6 +101,10 @@ def _sdpa_attention(q, k, v, window_size, enable_gqa):
return F.scaled_dot_product_attention(q, k, v, attn_mask=mask, enable_gqa=enable_gqa)
def XSA(atten, value):
    """Remove from ``atten`` its component along the direction of ``value``.

    ``value`` is L2-normalised over its last dimension, and the projection of
    ``atten`` onto that unit vector is subtracted from ``atten``. The result
    is therefore orthogonal to ``value`` along the last dimension (up to
    floating-point error). Where ``value`` is all zeros, the normalised
    vector is zero as well and ``atten`` is returned unchanged there.

    Args:
        atten: Attention output tensor; the projection runs over its last
            dimension.
        value: Value tensor. It must either broadcast against ``atten`` or,
            on any leading axis, have a size that evenly divides the
            corresponding size of ``atten`` (grouped-query attention, where
            there are fewer KV heads than query heads).

    Returns:
        Tensor with the same shape as ``atten``.
    """
    # Normalise before any expansion so the norm is computed on the smaller
    # tensor.
    Vn = F.normalize(value, dim=-1)

    # With grouped-query attention the attention output has more heads than
    # `value`, so a plain elementwise product fails to broadcast. Repeat each
    # entry of the smaller axis consecutively, matching how GQA assigns
    # consecutive query heads to one KV head. The head axis position depends
    # on the caller's layout, so every leading axis is checked; axes that
    # already match (or have size 1 and broadcast) are left alone.
    if Vn.dim() == atten.dim():
        for axis in range(atten.dim() - 1):
            a_size = atten.size(axis)
            v_size = Vn.size(axis)
            if 1 < v_size < a_size and a_size % v_size == 0:
                Vn = Vn.repeat_interleave(a_size // v_size, dim=axis)

    # Scalar projection of `atten` onto the unit value direction.
    coeff = (atten * Vn).sum(dim=-1, keepdim=True)
    return atten - coeff * Vn
# =============================================================================
# Public API: Same interface as FA3
# =============================================================================
@ -117,7 +121,9 @@ def flash_attn_func(q, k, v, causal=False, window_size=(-1, -1)):
Output tensor of shape (B, T, H, D)
"""
if USE_FA3:
return _fa3.flash_attn_func(q, k, v, causal=causal, window_size=window_size)
y = _fa3.flash_attn_func(q, k, v, causal=causal, window_size=window_size)
y = XSA(y, v)
return y
# SDPA fallback: transpose (B, T, H, D) -> (B, H, T, D)
q = q.transpose(1, 2)
@ -125,6 +131,7 @@ def flash_attn_func(q, k, v, causal=False, window_size=(-1, -1)):
v = v.transpose(1, 2)
enable_gqa = q.size(1) != k.size(1)
y = _sdpa_attention(q, k, v, window_size, enable_gqa)
y = XSA(y, v)
return y.transpose(1, 2) # back to (B, T, H, D)
@ -147,10 +154,12 @@ def flash_attn_with_kvcache(q, k_cache, v_cache, k=None, v=None, cache_seqlens=N
Output tensor of shape (B, T_new, H, D)
"""
if USE_FA3:
return _fa3.flash_attn_with_kvcache(
y = _fa3.flash_attn_with_kvcache(
q, k_cache, v_cache, k=k, v=v, cache_seqlens=cache_seqlens,
causal=causal, window_size=window_size
)
y = XSA(y, v)
return y
# SDPA fallback: manually manage KV cache
B, T_new, H, D = q.shape
@ -173,6 +182,7 @@ def flash_attn_with_kvcache(q, k_cache, v_cache, k=None, v=None, cache_seqlens=N
enable_gqa = q_sdpa.size(1) != k_sdpa.size(1)
y_sdpa = _sdpa_attention(q_sdpa, k_sdpa, v_sdpa, window_size, enable_gqa)
y_sdpa = XSA(y_sdpa, v_sdpa)
return y_sdpa.transpose(1, 2) # back to (B, T, H, D)