move XSA in a function and using that function.

This commit is contained in:
Rohan Khan 2026-04-13 07:57:54 +06:00
parent 7ecaf86519
commit 0725192e07

View File

@ -101,6 +101,10 @@ def _sdpa_attention(q, k, v, window_size, enable_gqa):
return F.scaled_dot_product_attention(q, k, v, attn_mask=mask, enable_gqa=enable_gqa)
def XSA(atten, value):
Vn = F.normalize(value, dim=-1)
return atten - (atten * Vn).sum(dim=-1, keepdim=True) * Vn
# =============================================================================
# Public API: Same interface as FA3
# =============================================================================
@ -118,8 +122,7 @@ def flash_attn_func(q, k, v, causal=False, window_size=(-1, -1)):
"""
if USE_FA3:
y = _fa3.flash_attn_func(q, k, v, causal=causal, window_size=window_size)
Vn = F.normalize(v, dim=-1)
y = y - (y * Vn).sum(dim=-1, keepdim=True) * Vn
y = XSA(y, v)
return y
# SDPA fallback: transpose (B, T, H, D) -> (B, H, T, D)
@ -128,8 +131,7 @@ def flash_attn_func(q, k, v, causal=False, window_size=(-1, -1)):
v = v.transpose(1, 2)
enable_gqa = q.size(1) != k.size(1)
y = _sdpa_attention(q, k, v, window_size, enable_gqa)
Vn = F.normalize(v, dim=-1)
y = y - (y * Vn).sum(dim=-1, keepdim=True) * Vn
y = XSA(y, v)
return y.transpose(1, 2) # back to (B, T, H, D)
@ -156,8 +158,7 @@ def flash_attn_with_kvcache(q, k_cache, v_cache, k=None, v=None, cache_seqlens=N
q, k_cache, v_cache, k=k, v=v, cache_seqlens=cache_seqlens,
causal=causal, window_size=window_size
)
Vn = F.normalize(v, dim=-1)
y = y - (y * Vn).sum(dim=-1, keepdim=True) * Vn
y = XSA(y, v_cache)
return y
# SDPA fallback: manually manage KV cache
@ -181,8 +182,7 @@ def flash_attn_with_kvcache(q, k_cache, v_cache, k=None, v=None, cache_seqlens=N
enable_gqa = q_sdpa.size(1) != k_sdpa.size(1)
y_sdpa = _sdpa_attention(q_sdpa, k_sdpa, v_sdpa, window_size, enable_gqa)
Vn = F.normalize(v_sdpa, dim=-1)
y_sdpa = y_sdpa - (y_sdpa * Vn).sum(dim=-1, keepdim=True) * Vn
y_sdpa = XSA(y_sdpa, v_sdpa)
return y_sdpa.transpose(1, 2) # back to (B, T, H, D)