diff --git a/nanochat/flash_attention.py b/nanochat/flash_attention.py index 000b01b6..3c8da693 100644 --- a/nanochat/flash_attention.py +++ b/nanochat/flash_attention.py @@ -108,17 +108,21 @@ def XSA(atten, value, sdpa=False): vB, vH, vT, vD = value.shape rep = aH // vH assert aH % vH == 0, "query heads and kv heads must be divisible" - value = torch.repeat_interleave(value, rep, dim=1) - Vn = F.normalize(value, dim=-1) - return atten - torch.sum(atten * Vn, dim=-1, keepdim=True) * Vn + + atten_r = atten.reshape(aB, vH, rep, aT, aD) + Vn = F.normalize(value, dim=-1).unsqueeze(2) + atten_r = atten_r - torch.sum(atten_r * Vn, dim=-1, keepdim=True) * Vn + return atten_r.reshape(aB, aH, aT, aD) else: aB, aT, aH, aD = atten.shape vB, vT, vH, vD = value.shape rep = aH // vH assert aH % vH == 0, "query heads and kv heads must be divisible" - value = torch.repeat_interleave(value, rep, dim=2) - Vn = F.normalize(value, dim=-1) - return atten - torch.sum(atten * Vn, dim=-1, keepdim=True) * Vn + + atten_r = atten.reshape(aB, aT, vH, rep, aD) + Vn = F.normalize(value, dim=-1).unsqueeze(-2) + atten_r = atten_r - torch.sum(atten_r * Vn, dim=-1, keepdim=True) * Vn + return atten_r.reshape(aB, aT, aH, aD) # ============================================================================= # Public API: Same interface as FA3