mirror of
https://github.com/karpathy/nanochat.git
synced 2026-02-09 12:09:49 +00:00
feat: restore flash attention
This commit is contained in:
parent
8d89db3195
commit
d47431d87d
|
|
@@ -177,14 +177,14 @@ class Block(nn.Module):
|
|||
self.canon_a = CanonLayer(config, kernel_size=3)
|
||||
self.canon_c = CanonLayer(config, kernel_size=3)
|
||||
|
||||
-def forward(self, x, cos_sin, kv_cache):
|
||||
+def forward(self, x, cos_sin, window_size, kv_cache):
|
||||
# Canon-A: local mixing before attention
|
||||
x = self.canon_a(x)
|
||||
-x = x + self.attn(self.ln_1(x), cos_sin, kv_cache)
|
||||
+x = x + self.attn(norm(x), cos_sin, window_size, kv_cache)
|
||||
|
||||
# Canon-C: local mixing before MLP
|
||||
x = self.canon_c(x)
|
||||
-x = x + self.mlp(self.ln_2(x))
|
||||
+x = x + self.mlp(norm(x))
|
||||
|
||||
return x
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user