feat: restore flash attention

Lantianyou 2026-01-12 11:07:46 +08:00
parent 8d89db3195
commit d47431d87d


@@ -177,14 +177,14 @@ class Block(nn.Module):
         self.canon_a = CanonLayer(config, kernel_size=3)
         self.canon_c = CanonLayer(config, kernel_size=3)
-    def forward(self, x, cos_sin, kv_cache):
+    def forward(self, x, cos_sin, window_size, kv_cache):
         # Canon-A: local mixing before attention
         x = self.canon_a(x)
-        x = x + self.attn(self.ln_1(x), cos_sin, kv_cache)
+        x = x + self.attn(norm(x), cos_sin, window_size, kv_cache)
         # Canon-C: local mixing before MLP
         x = self.canon_c(x)
-        x = x + self.mlp(self.ln_2(x))
+        x = x + self.mlp(norm(x))
         return x
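
The change threads a per-call window_size from Block.forward into the attention layer, which is what lets flash attention run with a sliding window. Below is a minimal sketch of an attention module that could consume that argument, assuming the flash-attn package (its flash_attn_func accepts a window_size=(left, right) tuple). The class name SlidingWindowAttention, the tensor shapes, and the (window_size - 1, 0) mapping are illustrative assumptions, not this repository's actual attention module.

import torch
import torch.nn as nn
from flash_attn import flash_attn_func  # requires the flash-attn package and a CUDA GPU

class SlidingWindowAttention(nn.Module):  # hypothetical stand-in for the repo's attention module
    def __init__(self, n_embd: int, n_head: int):
        super().__init__()
        self.n_head = n_head
        self.head_dim = n_embd // n_head
        self.qkv = nn.Linear(n_embd, 3 * n_embd, bias=False)
        self.proj = nn.Linear(n_embd, n_embd, bias=False)

    def forward(self, x, cos_sin, window_size, kv_cache=None):
        B, T, C = x.shape
        q, k, v = self.qkv(x).split(C, dim=2)
        # flash_attn_func expects (batch, seqlen, n_head, head_dim) in fp16/bf16
        q = q.view(B, T, self.n_head, self.head_dim)
        k = k.view(B, T, self.n_head, self.head_dim)
        v = v.view(B, T, self.n_head, self.head_dim)
        # rotary embedding (cos_sin) and kv_cache handling omitted for brevity
        # window_size=(left, right): each query attends to `left` previous tokens
        # plus itself; (-1, -1) would mean full causal attention
        y = flash_attn_func(q, k, v, causal=True, window_size=(window_size - 1, 0))
        y = y.contiguous().view(B, T, C)
        return self.proj(y)

Passing window_size through forward rather than fixing it at construction time matches the diff above: presumably different blocks, or different training stages, can then alternate between full and sliding-window attention without rebuilding the model.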