diff --git a/nanochat/gpt.py b/nanochat/gpt.py index 35ae2c2..3b5677d 100644 --- a/nanochat/gpt.py +++ b/nanochat/gpt.py @@ -177,14 +177,14 @@ class Block(nn.Module): self.canon_a = CanonLayer(config, kernel_size=3) self.canon_c = CanonLayer(config, kernel_size=3) - def forward(self, x, cos_sin, kv_cache): + def forward(self, x, cos_sin, window_size, kv_cache): # Canon-A: local mixing before attention x = self.canon_a(x) - x = x + self.attn(self.ln_1(x), cos_sin, kv_cache) + x = x + self.attn(norm(x), cos_sin, window_size, kv_cache) # Canon-C: local mixing before MLP x = self.canon_c(x) - x = x + self.mlp(self.ln_2(x)) + x = x + self.mlp(norm(x)) return x