diff --git a/nanochat/gpt.py b/nanochat/gpt.py
index 07a1eae8..536720a6 100644
--- a/nanochat/gpt.py
+++ b/nanochat/gpt.py
@@ -119,7 +119,10 @@ class CausalSelfAttention(nn.Module):
         # Advance position after last layer processes
         if self.layer_idx == kv_cache.n_layers - 1:
             kv_cache.advance(T)
-
+        # XSA (Exclusive Self Attention)
+        Vn = F.normalize(v, dim=-1)
+        Vn = Vn.repeat_interleave(self.n_head // self.n_kv_head, dim=2)
+        y = y - (y * Vn).sum(dim=-1, keepdim=True) * Vn
         # Re-assemble the heads and project back to residual stream
         y = y.contiguous().view(B, T, -1)
         y = self.c_proj(y)