This commit is contained in:
Mithun Kannaa 2026-05-05 06:39:08 +03:00 committed by GitHub
commit cee0b39642
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -119,7 +119,10 @@ class CausalSelfAttention(nn.Module):
# Advance position after last layer processes
if self.layer_idx == kv_cache.n_layers - 1:
kv_cache.advance(T)
# XSA (Exclusive Self Attention)
Vn = F.normalize(v, dim=-1)
Vn = Vn.repeat_interleave(self.n_head // self.n_kv_head, dim=2)
y = y - (y * Vn).sum(dim=-1, keepdim=True) * Vn
# Re-assemble the heads and project back to residual stream
y = y.contiguous().view(B, T, -1)
y = self.c_proj(y)