Implement Exclusive Self Attention in the forward pass.

Adds Exclusive Self-Attention (XSA) from “Exclusive Self Attention” by Shuangfei Zhai (Apple), which removes from each head's attention output the component aligned with the token's own value vector, eliminating attention similarity bias and improving context modeling. This is a two-line change in the attention block.
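
For reference, a minimal standalone sketch of the XSA step as applied in this change. The function name xsa, the dummy shapes, and the head counts are illustrative assumptions, not repo code; tensors are laid out (B, T, heads, head_dim) as in the diff below.

    import torch
    import torch.nn.functional as F

    def xsa(y, v, n_head, n_kv_head):
        # y: per-head attention output, shape (B, T, n_head, head_dim)
        # v: value vectors, shape (B, T, n_kv_head, head_dim)
        vn = F.normalize(v, dim=-1)                             # unit-length value vectors
        vn = vn.repeat_interleave(n_head // n_kv_head, dim=2)   # expand KV heads to match query heads
        return y - (y * vn).sum(dim=-1, keepdim=True) * vn      # subtract the component of y along vn

    # Toy check that the result is orthogonal to the token's own value direction:
    B, T, n_head, n_kv_head, head_dim = 2, 8, 4, 2, 16
    y = torch.randn(B, T, n_head, head_dim)
    v = torch.randn(B, T, n_kv_head, head_dim)
    out = xsa(y, v, n_head, n_kv_head)
    vn = F.normalize(v, dim=-1).repeat_interleave(n_head // n_kv_head, dim=2)
    print((out * vn).sum(dim=-1).abs().max())  # ~0 up to float error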
@@ -119,7 +119,10 @@ class CausalSelfAttention(nn.Module):
 # Advance position after last layer processes
 if self.layer_idx == kv_cache.n_layers - 1:
     kv_cache.advance(T)
+# XSA (Exclusive Self Attention)
+Vn = F.normalize(v, dim=-1)
+Vn = Vn.repeat_interleave(self.n_head // self.n_kv_head, dim=2)
+y = y - (y * Vn).sum(dim=-1, keepdim=True) * Vn
 # Re-assemble the heads and project back to residual stream
 y = y.contiguous().view(B, T, -1)
 y = self.c_proj(y)
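
Design note: because Vn is L2-normalized, the subtraction is an exact orthogonal projection, so the corrected output has zero component along the token's own value direction while everything orthogonal to it is left untouched. The repeat_interleave call appears to expand the n_kv_head value heads to n_head query heads so the projection is applied per query head under grouped-query attention.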