feat: add canon layer

Author: Lantianyou
Date:   2026-01-10 23:12:27 +08:00
Parent: f5a0ea4d3f
Commit: d515407deb


@@ -9,6 +9,7 @@ Notable features:
- no learnable params in rmsnorm
- no bias in linear layers
- Group-Query Attention (GQA) support for more efficient inference
- Canon layers
"""
import math
@@ -120,15 +121,66 @@ class MLP(nn.Module):
return x
class CanonLayer(nn.Module):
"""
Canon layer using depth-wise 1D convolution with residual connection.
With groups=n_embd, each feature dimension has its own filter and is convolved
independently (depth-wise), so the layer performs cheap, purely local mixing along the time axis.
"""
def __init__(self, config, kernel_size=3):
super().__init__()
self.n_embd = config.n_embd
# Depth-wise 1D convolution
# groups=n_embd means each input channel has its own filter (depth-wise)
self.conv = nn.Conv1d(
in_channels=config.n_embd,
out_channels=config.n_embd,
kernel_size=kernel_size,
groups=config.n_embd,
padding=kernel_size - 1 # pad both ends by k-1; forward() trims the tail so the conv stays causal
)
def forward(self, x):
# x shape: (Batch, Time, Channels) -> (B, T, C)
# Conv1d expects: (Batch, Channels, Time) -> (B, C, T)
# 1. Transpose to match Conv1d input format
x_transposed = x.transpose(1, 2)
# 2. Apply convolution
out = self.conv(x_transposed)
# 3. Trim to original sequence length for causal convolution
out = out[:, :, :x.size(1)]
# 4. Transpose back to original format
out = out.transpose(1, 2)
# 5. Residual connection: h' = h + canon
return x + out
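# --- Illustrative sketch, not part of this commit: a quick shape/causality check for
# CanonLayer. The SimpleNamespace config and the function name are hypothetical stand-ins
# used only for this demo; only n_embd is needed by the layer.
def _canon_layer_demo():
    import torch
    from types import SimpleNamespace
    cfg = SimpleNamespace(n_embd=8)            # hypothetical tiny config
    layer = CanonLayer(cfg, kernel_size=3)
    x = torch.randn(2, 16, cfg.n_embd)         # (B, T, C)
    y = layer(x)
    assert y.shape == x.shape                  # residual output keeps (B, T, C)
    # Causality: perturbing future positions must not change earlier outputs.
    x2 = x.clone()
    x2[:, 10:, :] += 1.0                       # change only positions t >= 10
    y2 = layer(x2)
    assert torch.allclose(y[:, :10], y2[:, :10])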
class Block(nn.Module):
def __init__(self, config, layer_idx):
super().__init__()
self.ln_1 = nn.LayerNorm(config.n_embd)
self.attn = CausalSelfAttention(config, layer_idx)
self.ln_2 = nn.LayerNorm(config.n_embd)
self.mlp = MLP(config)
# Canon Layers: local mixing before attention and MLP
self.canon_a = CanonLayer(config, kernel_size=3)
self.canon_c = CanonLayer(config, kernel_size=3)
def forward(self, x, cos_sin, kv_cache):
# Canon-A: local mixing of the residual stream before attention
x = self.canon_a(x)
x = x + self.attn(self.ln_1(x), cos_sin, kv_cache)
# Canon-C: local mixing of the residual stream before the MLP
x = self.canon_c(x)
x = x + self.mlp(self.ln_2(x))
return x
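# --- Illustrative note, not part of this commit: since CanonLayer already includes its
# own residual, the block above computes, in order,
#   x <- x + conv_a(x)        # Canon-A, depth-wise causal conv on the residual stream
#   x <- x + attn(ln_1(x))    # attention branch
#   x <- x + conv_c(x)        # Canon-C
#   x <- x + mlp(ln_2(x))     # MLP branch
# i.e. the canon layers act on the residual stream itself, not inside the pre-norm branches.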
@@ -169,6 +221,8 @@ class GPT(nn.Module):
attn.c_proj: zeros
mlp.c_fc: uniform, std=1/sqrt(n_embd)
mlp.c_proj: zeros
canon_a.conv: uniform, std=1/sqrt(n_embd)
canon_c.conv: uniform, std=1/sqrt(n_embd)
"""
# Embedding and unembedding
@@ -185,6 +239,9 @@ class GPT(nn.Module):
torch.nn.init.zeros_(block.attn.c_proj.weight) # projections are zero
torch.nn.init.uniform_(block.mlp.c_fc.weight, -s, s)
torch.nn.init.zeros_(block.mlp.c_proj.weight)
# Canon layers: conv weights get the same uniform bound as the attention/MLP input projections
torch.nn.init.uniform_(block.canon_a.conv.weight, -s, s)
torch.nn.init.uniform_(block.canon_c.conv.weight, -s, s)
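# --- Illustrative note, not part of this commit: for example, with n_embd=768 the bound
# is s = 1/sqrt(768) ~= 0.036, so the canon conv weights start small like the other
# input projections; the Conv1d biases are not re-initialized here and keep PyTorch's default.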
# Rotary embeddings
head_dim = self.config.n_embd // self.config.n_head