Mirror of https://github.com/karpathy/nanochat.git (synced 2026-02-05 18:19:49 +00:00)
feat: add canon layer
parent f5a0ea4d3f
commit d515407deb
@@ -9,6 +9,7 @@ Notable features:
 - no learnable params in rmsnorm
 - no bias in linear layers
 - Group-Query Attention (GQA) support for more efficient inference
+- Canon layers
 """

 import math
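The feature list above mentions a parameter-free rmsnorm; the norm(...) calls that this commit removes from Block.forward in the next hunk refer to that helper. A minimal sketch of such a helper (an assumption for illustration, not part of the diff), using PyTorch's functional RMSNorm:

import torch
import torch.nn.functional as F

def norm(x: torch.Tensor) -> torch.Tensor:
    # Parameter-free RMSNorm: normalize over the last dimension, no learnable gain.
    return F.rms_norm(x, (x.size(-1),))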
@@ -120,15 +121,66 @@ class MLP(nn.Module):
         return x


+class CanonLayer(nn.Module):
+    """
+    Canon layer using depth-wise 1D convolution with residual connection.
+    This is an efficient implementation where groups=n_embd means each feature
+    dimension is convolved independently (depth-wise convolution).
+    """
+    def __init__(self, config, kernel_size=3):
+        super().__init__()
+        self.n_embd = config.n_embd
+        # Depth-wise 1D convolution
+        # groups=n_embd means each input channel has its own filter (depth-wise)
+        self.conv = nn.Conv1d(
+            in_channels=config.n_embd,
+            out_channels=config.n_embd,
+            kernel_size=kernel_size,
+            groups=config.n_embd,
+            padding=kernel_size - 1  # Extra padding for causality, will be trimmed
+        )
+
+    def forward(self, x):
+        # x shape: (Batch, Time, Channels) -> (B, T, C)
+        # Conv1d expects: (Batch, Channels, Time) -> (B, C, T)
+
+        # 1. Transpose to match Conv1d input format
+        x_transposed = x.transpose(1, 2)
+
+        # 2. Apply convolution
+        out = self.conv(x_transposed)
+
+        # 3. Trim to original sequence length for causal convolution
+        out = out[:, :, :x.size(1)]
+
+        # 4. Transpose back to original format
+        out = out.transpose(1, 2)
+
+        # 5. Residual connection: h' = h + canon
+        return x + out
+
+
 class Block(nn.Module):
     def __init__(self, config, layer_idx):
         super().__init__()
+        self.ln_1 = nn.LayerNorm(config.n_embd)
         self.attn = CausalSelfAttention(config, layer_idx)
+        self.ln_2 = nn.LayerNorm(config.n_embd)
         self.mlp = MLP(config)

+        # Canon Layers: local mixing before attention and MLP
+        self.canon_a = CanonLayer(config, kernel_size=3)
+        self.canon_c = CanonLayer(config, kernel_size=3)
+
     def forward(self, x, cos_sin, kv_cache):
-        x = x + self.attn(norm(x), cos_sin, kv_cache)
-        x = x + self.mlp(norm(x))
+        # Canon-A: local mixing before attention
+        x = self.canon_a(x)
+        x = x + self.attn(self.ln_1(x), cos_sin, kv_cache)
+
+        # Canon-C: local mixing before MLP
+        x = self.canon_c(x)
+        x = x + self.mlp(self.ln_2(x))
+
         return x

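As a quick sanity check on the CanonLayer added above, here is a standalone sketch (not part of the diff; the config stand-in and tensor sizes are made up for illustration) that verifies the layer keeps the (B, T, C) shape and that the pad-then-trim scheme is causal, i.e. the output at time t depends only on inputs at times <= t:

import torch
from types import SimpleNamespace

config = SimpleNamespace(n_embd=16)       # minimal stand-in; CanonLayer only reads n_embd
layer = CanonLayer(config, kernel_size=3)

B, T = 2, 8
x = torch.randn(B, T, config.n_embd)
y = layer(x)
assert y.shape == (B, T, config.n_embd)   # residual + trim preserve the input shape

# Perturb only the last time step; all earlier outputs must stay unchanged.
x2 = x.clone()
x2[:, -1, :] += 1.0
y2 = layer(x2)
assert torch.allclose(y[:, :-1, :], y2[:, :-1, :])  # no leakage from future positions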
@@ -169,6 +221,8 @@ class GPT(nn.Module):
         attn.c_proj: zeros
         mlp.c_fc: uniform, std=1/sqrt(n_embd)
         mlp.c_proj: zeros
+        canon_a.conv: uniform, std=1/sqrt(n_embd)
+        canon_c.conv: uniform, std=1/sqrt(n_embd)
         """

         # Embedding and unembedding
@@ -185,6 +239,9 @@ class GPT(nn.Module):
             torch.nn.init.zeros_(block.attn.c_proj.weight) # projections are zero
             torch.nn.init.uniform_(block.mlp.c_fc.weight, -s, s)
             torch.nn.init.zeros_(block.mlp.c_proj.weight)
+            # Canon layers: initialize conv weights
+            torch.nn.init.uniform_(block.canon_a.conv.weight, -s, s)
+            torch.nn.init.uniform_(block.canon_c.conv.weight, -s, s)

         # Rotary embeddings
         head_dim = self.config.n_embd // self.config.n_head
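For reference, the same uniform initialization applied to a standalone depth-wise Conv1d (a sketch, not part of the diff). The bound s is defined outside the hunk shown above; it is assumed here to be 1/sqrt(n_embd), matching the "uniform, std=1/sqrt(n_embd)" note in the docstring hunk:

import math
import torch
import torch.nn as nn

n_embd = 768                   # example width, not taken from the diff
s = 1.0 / math.sqrt(n_embd)    # assumed value of the bound used above
conv = nn.Conv1d(n_embd, n_embd, kernel_size=3, groups=n_embd, padding=2)
torch.nn.init.uniform_(conv.weight, -s, s)
assert conv.weight.abs().max().item() <= s   # every weight lies within the (-s, s) bound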