From d47431d87d029f57d4a0f4c0ec01d94bdb4fd2ea Mon Sep 17 00:00:00 2001
From: Lantianyou
Date: Mon, 12 Jan 2026 11:07:46 +0800
Subject: [PATCH] feat: restore flash attention

---
 nanochat/gpt.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/nanochat/gpt.py b/nanochat/gpt.py
index 35ae2c2..3b5677d 100644
--- a/nanochat/gpt.py
+++ b/nanochat/gpt.py
@@ -177,14 +177,14 @@ class Block(nn.Module):
         self.canon_a = CanonLayer(config, kernel_size=3)
         self.canon_c = CanonLayer(config, kernel_size=3)
 
-    def forward(self, x, cos_sin, kv_cache):
+    def forward(self, x, cos_sin, window_size, kv_cache):
         # Canon-A: local mixing before attention
         x = self.canon_a(x)
-        x = x + self.attn(self.ln_1(x), cos_sin, kv_cache)
+        x = x + self.attn(norm(x), cos_sin, window_size, kv_cache)
 
         # Canon-C: local mixing before MLP
         x = self.canon_c(x)
-        x = x + self.mlp(self.ln_2(x))
+        x = x + self.mlp(norm(x))
         return x
 
 