From d515407deb4b77d1b3aa9bee5b2ada16b1f6531b Mon Sep 17 00:00:00 2001 From: Lantianyou Date: Sat, 10 Jan 2026 23:12:27 +0800 Subject: [PATCH] feat: add canon layer --- nanochat/gpt.py | 61 +++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 59 insertions(+), 2 deletions(-) diff --git a/nanochat/gpt.py b/nanochat/gpt.py index 478f687..582f278 100644 --- a/nanochat/gpt.py +++ b/nanochat/gpt.py @@ -9,6 +9,7 @@ Notable features: - no learnable params in rmsnorm - no bias in linear layers - Group-Query Attention (GQA) support for more efficient inference +- Canon layers """ import math @@ -120,15 +121,66 @@ class MLP(nn.Module): return x +class CanonLayer(nn.Module): + """ + Canon layer using depth-wise 1D convolution with residual connection. + This is an efficient implementation where groups=n_embd means each feature + dimension is convolved independently (depth-wise convolution). + """ + def __init__(self, config, kernel_size=3): + super().__init__() + self.n_embd = config.n_embd + # Depth-wise 1D convolution + # groups=n_embd means each input channel has its own filter (depth-wise) + self.conv = nn.Conv1d( + in_channels=config.n_embd, + out_channels=config.n_embd, + kernel_size=kernel_size, + groups=config.n_embd, + padding=kernel_size - 1 # Extra padding for causality, will be trimmed + ) + + def forward(self, x): + # x shape: (Batch, Time, Channels) -> (B, T, C) + # Conv1d expects: (Batch, Channels, Time) -> (B, C, T) + + # 1. Transpose to match Conv1d input format + x_transposed = x.transpose(1, 2) + + # 2. Apply convolution + out = self.conv(x_transposed) + + # 3. Trim to original sequence length for causal convolution + out = out[:, :, :x.size(1)] + + # 4. Transpose back to original format + out = out.transpose(1, 2) + + # 5. Residual connection: h' = h + canon + return x + out + + class Block(nn.Module): def __init__(self, config, layer_idx): super().__init__() + self.ln_1 = nn.LayerNorm(config.n_embd) self.attn = CausalSelfAttention(config, layer_idx) + self.ln_2 = nn.LayerNorm(config.n_embd) self.mlp = MLP(config) + # Canon Layers: local mixing before attention and MLP + self.canon_a = CanonLayer(config, kernel_size=3) + self.canon_c = CanonLayer(config, kernel_size=3) + def forward(self, x, cos_sin, kv_cache): - x = x + self.attn(norm(x), cos_sin, kv_cache) - x = x + self.mlp(norm(x)) + # Canon-A: local mixing before attention + x = self.canon_a(x) + x = x + self.attn(self.ln_1(x), cos_sin, kv_cache) + + # Canon-C: local mixing before MLP + x = self.canon_c(x) + x = x + self.mlp(self.ln_2(x)) + return x @@ -169,6 +221,8 @@ class GPT(nn.Module): attn.c_proj: zeros mlp.c_fc: uniform, std=1/sqrt(n_embd) mlp.c_proj: zeros + canon_a.conv: uniform, std=1/sqrt(n_embd) + canon_c.conv: uniform, std=1/sqrt(n_embd) """ # Embedding and unembedding @@ -185,6 +239,9 @@ class GPT(nn.Module): torch.nn.init.zeros_(block.attn.c_proj.weight) # projections are zero torch.nn.init.uniform_(block.mlp.c_fc.weight, -s, s) torch.nn.init.zeros_(block.mlp.c_proj.weight) + # Canon layers: initialize conv weights + torch.nn.init.uniform_(block.canon_a.conv.weight, -s, s) + torch.nn.init.uniform_(block.canon_c.conv.weight, -s, s) # Rotary embeddings head_dim = self.config.n_embd // self.config.n_head