broadcast logic to get speedup.

This commit is contained in:
Rohan Khan 2026-04-15 13:38:27 +06:00
parent 275fa4b060
commit aef1254205

View File

@ -108,17 +108,21 @@ def XSA(atten, value, sdpa=False):
vB, vH, vT, vD = value.shape
rep = aH // vH
assert aH % vH == 0, "query heads and kv heads must be divisible"
value = torch.repeat_interleave(value, rep, dim=1)
Vn = F.normalize(value, dim=-1)
return atten - torch.sum(atten * Vn, dim=-1, keepdim=True) * Vn
atten_r = atten.reshape(aB, vH, rep, aT, aD)
Vn = F.normalize(value, dim=-1).unsqueeze(2)
atten_r = atten_r - torch.sum(atten_r * Vn, dim=-1, keepdim=True) * Vn
return atten_r.reshape(aB, aH, aT, aD)
else:
aB, aT, aH, aD = atten.shape
vB, vT, vH, vD = value.shape
rep = aH // vH
assert aH % vH == 0, "query heads and kv heads must be divisible"
value = torch.repeat_interleave(value, rep, dim=2)
Vn = F.normalize(value, dim=-1)
return atten - torch.sum(atten * Vn, dim=-1, keepdim=True) * Vn
atten_r = atten.reshape(aB, aT, vH, rep, aD)
Vn = F.normalize(value, dim=-1).unsqueeze(-2)
atten_r = atten_r - torch.sum(atten_r * Vn, dim=-1, keepdim=True) * Vn
return atten_r.reshape(aB, aT, aH, aD)
# =============================================================================
# Public API: Same interface as FA3