Add L3 (Large Lookup Layers) following arXiv:2601.21461v2

L3 generalizes token embeddings by placing per-token lookup tables inside
the decoder stack. Unlike MoE, routing is static (determined by token ID),
eliminating router training and load-balancing losses.

Implementation:
- nanochat/l3.py: LZW allocation algorithm and L3Layer module with
  vectorized gather+pad+mask forward pass, tied/untied KV support
- GPT integration: L3 layers sit between decoder blocks, applied
  residually (x = x + l3_layer(x, token_ids))
- CLI: --l3-after-layers, --l3-n-emb, --l3-d-up, --l3-k-max flags
  with LZW precomputation from training data sample
- 17 tests covering allocation, layer, and GPT integration

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
William Thurston 2026-02-22 15:49:15 -08:00
parent 194c98a5b3
commit b7629eff5d
4 changed files with 585 additions and 8 deletions

View File

@ -24,6 +24,7 @@ from nanochat.optim import MuonAdamW, DistMuonAdamW
# Our custom Flash Attention module that automatically uses FA3 on Hopper+ and SDPA fallback elsewhere
from nanochat.flash_attention import flash_attn
from nanochat.l3 import L3Layer
@dataclass
class GPTConfig:
@ -37,6 +38,12 @@ class GPTConfig:
# Characters: L=long (full context), S=short (half context)
# Examples: "L"=all full context, "SL"=alternating, "SSL"=two short then one long
window_pattern: str = "SSSL"
# L3 (Large Lookup Layers) config
l3_after_layers: str = "" # comma-separated layer indices (empty = disabled)
l3_n_emb: int = 0 # total embeddings (0 = disabled)
l3_d_up: int = 0 # up-projection dim (0 = auto: 4 * n_embd)
l3_k_max: int = 512 # max embeddings per token
l3_tie_kv: bool = True # tie key and value weights
def norm(x):
@ -175,6 +182,13 @@ class GPT(nn.Module):
head_dim = config.n_embd // config.n_head
kv_dim = config.n_kv_head * head_dim
self.value_embeds = nn.ModuleDict({str(i): nn.Embedding(padded_vocab_size, kv_dim) for i in range(config.n_layer) if has_ve(i, config.n_layer)})
# L3 layers (placed between decoder blocks)
self.l3_layer_indices = set(int(x) for x in config.l3_after_layers.split(",") if x.strip()) if config.l3_after_layers else set()
l3_d_up = config.l3_d_up if config.l3_d_up > 0 else 4 * config.n_embd
self.l3_layers = nn.ModuleDict({
str(i): L3Layer(config.n_embd, config.l3_n_emb, l3_d_up, config.l3_tie_kv)
for i in self.l3_layer_indices
}) if self.l3_layer_indices and config.l3_n_emb > 0 else nn.ModuleDict()
# To support meta device initialization, we init the rotary embeddings here, but it's just "fake" meta tensors only.
# As for rotary_seq_len, these rotary embeddings are pretty small/cheap in memory,
# so let's just over-compute them by 10X, but assert fail if we ever reach that amount.
@ -229,6 +243,16 @@ class GPT(nn.Module):
if block.attn.ve_gate is not None:
torch.nn.init.zeros_(block.attn.ve_gate.weight)
# L3 layers
for l3_layer in self.l3_layers.values():
if l3_layer.tie_kv:
torch.nn.init.normal_(l3_layer.kv_weight, mean=0.0, std=1.0)
else:
torch.nn.init.normal_(l3_layer.k_weight, mean=0.0, std=1.0)
torch.nn.init.normal_(l3_layer.v_weight, mean=0.0, std=1.0)
torch.nn.init.uniform_(l3_layer.w_up.weight, -s, s)
torch.nn.init.zeros_(l3_layer.w_mix.weight) # L3 starts as no-op
# Rotary embeddings
head_dim = self.config.n_embd // self.config.n_head
cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
@ -239,6 +263,12 @@ class GPT(nn.Module):
self.transformer.wte.to(dtype=torch.bfloat16)
for ve in self.value_embeds.values():
ve.to(dtype=torch.bfloat16)
for l3_layer in self.l3_layers.values():
if l3_layer.tie_kv:
l3_layer.kv_weight.data = l3_layer.kv_weight.data.to(dtype=torch.bfloat16)
else:
l3_layer.k_weight.data = l3_layer.k_weight.data.to(dtype=torch.bfloat16)
l3_layer.v_weight.data = l3_layer.v_weight.data.to(dtype=torch.bfloat16)
def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None):
# TODO: bump base theta more? e.g. 100K is more common more recently
@ -304,7 +334,14 @@ class GPT(nn.Module):
nparams = sum(p.numel() for p in self.parameters())
# Exclude non-matmul params: embeddings and per-layer scalars
value_embeds_numel = sum(ve.weight.numel() for ve in self.value_embeds.values())
nparams_exclude = (self.transformer.wte.weight.numel() + value_embeds_numel +
# L3 kv/k/v weights are embeddings (lookup tables), not matmul weights
l3_embed_numel = 0
for l3_layer in self.l3_layers.values():
if l3_layer.tie_kv:
l3_embed_numel += l3_layer.kv_weight.numel()
else:
l3_embed_numel += l3_layer.k_weight.numel() + l3_layer.v_weight.numel()
nparams_exclude = (self.transformer.wte.weight.numel() + value_embeds_numel + l3_embed_numel +
self.resid_lambdas.numel() + self.x0_lambdas.numel())
h, q, t = self.config.n_head, self.config.n_embd // self.config.n_head, self.config.sequence_len
# Sum attention FLOPs per layer, accounting for sliding window
@ -313,7 +350,15 @@ class GPT(nn.Module):
window = window_size[0] # (left, right) tuple, we use left
effective_seq = t if window < 0 else min(window, t)
attn_flops += 12 * h * q * effective_seq
num_flops_per_token = 6 * (nparams - nparams_exclude) + attn_flops
# L3 FLOPs: per layer, approx 2*avg_k*n_embd (attention) + 2*n_embd*d_up (up) + 2*(d_up+n_embd)*n_embd (mix)
l3_flops = 0
if self.l3_layers:
n_embd = self.config.n_embd
l3_d_up = self.config.l3_d_up if self.config.l3_d_up > 0 else 4 * n_embd
avg_k = self.config.l3_n_emb / self.config.vocab_size if self.config.vocab_size > 0 else 1
per_l3 = 2 * avg_k * n_embd + 2 * n_embd * l3_d_up + 2 * (l3_d_up + n_embd) * n_embd
l3_flops = int(per_l3 * len(self.l3_layers))
num_flops_per_token = 6 * (nparams - nparams_exclude) + attn_flops + l3_flops
return num_flops_per_token
def num_scaling_params(self):
@ -334,13 +379,24 @@ class GPT(nn.Module):
lm_head = sum(p.numel() for p in self.lm_head.parameters())
transformer_matrices = sum(p.numel() for p in self.transformer.h.parameters())
scalars = self.resid_lambdas.numel() + self.x0_lambdas.numel()
total = wte + value_embeds + lm_head + transformer_matrices + scalars
# L3: separate embedding-like params from matrix params
l3_embeds = 0
l3_matrices = 0
for l3_layer in self.l3_layers.values():
if l3_layer.tie_kv:
l3_embeds += l3_layer.kv_weight.numel()
else:
l3_embeds += l3_layer.k_weight.numel() + l3_layer.v_weight.numel()
l3_matrices += l3_layer.w_up.weight.numel() + l3_layer.w_mix.weight.numel()
total = wte + value_embeds + lm_head + transformer_matrices + scalars + l3_embeds + l3_matrices
assert total == sum(p.numel() for p in self.parameters()), "Parameter count mismatch"
return {
'wte': wte,
'value_embeds': value_embeds,
'lm_head': lm_head,
'transformer_matrices': transformer_matrices,
'l3_embeds': l3_embeds,
'l3_matrices': l3_matrices,
'scalars': scalars,
'total': total,
}
@ -356,7 +412,18 @@ class GPT(nn.Module):
lm_head_params = list(self.lm_head.parameters())
resid_params = [self.resid_lambdas]
x0_params = [self.x0_lambdas]
assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params)
# L3 params: embedding-like (kv/k/v weights) and matrix (w_up, w_mix)
l3_embed_params = []
l3_matrix_params = []
for l3_layer in self.l3_layers.values():
if l3_layer.tie_kv:
l3_embed_params.append(l3_layer.kv_weight)
else:
l3_embed_params.append(l3_layer.k_weight)
l3_embed_params.append(l3_layer.v_weight)
l3_matrix_params.append(l3_layer.w_up.weight)
l3_matrix_params.append(l3_layer.w_mix.weight)
assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params) + len(l3_embed_params) + len(l3_matrix_params)
# Scale the LR for the AdamW parameters by ∝1/√dmodel (tuned for 768 dim model)
dmodel_lr_scale = (model_dim / 768) ** -0.5
@ -371,9 +438,13 @@ class GPT(nn.Module):
dict(kind='adamw', params=resid_params, lr=scalar_lr * 0.01, betas=adam_betas, eps=1e-10, weight_decay=0.0),
dict(kind='adamw', params=x0_params, lr=scalar_lr, betas=(0.96, 0.95), eps=1e-10, weight_decay=0.0), # higher beta1 for x0
]
# L3 param groups: embeddings use AdamW (like token embeddings), matrices use Muon
if l3_embed_params:
param_groups.append(dict(kind='adamw', params=l3_embed_params, lr=embedding_lr * dmodel_lr_scale, betas=adam_betas, eps=1e-10, weight_decay=0.0))
# Muon groups (matrix params, grouped by shape for stacking)
for shape in sorted({p.shape for p in matrix_params}):
group_params = [p for p in matrix_params if p.shape == shape]
all_matrix_params = matrix_params + l3_matrix_params
for shape in sorted({p.shape for p in all_matrix_params}):
group_params = [p for p in all_matrix_params if p.shape == shape]
param_groups.append(dict(
kind='muon', params=group_params, lr=matrix_lr,
momentum=0.95, ns_steps=5, beta2=0.95, weight_decay=weight_decay,
@ -404,6 +475,9 @@ class GPT(nn.Module):
x = self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0
ve = self.value_embeds[str(i)](idx) if str(i) in self.value_embeds else None
x = block(x, ve, cos_sin, self.window_sizes[i], kv_cache)
# L3 layer after this block (if configured)
if str(i) in self.l3_layers:
x = x + self.l3_layers[str(i)](x, idx)
x = norm(x)
# Forward the lm_head (compute logits)

201
nanochat/l3.py Normal file
View File

@ -0,0 +1,201 @@
"""
L3: Large Lookup Layers
Ref: arXiv:2601.21461v2
L3 generalizes token embeddings by placing per-token lookup tables inside
the decoder stack. Unlike MoE, routing is static (determined by token ID),
eliminating router training and load-balancing losses.
"""
from collections import Counter
import torch
import torch.nn as nn
import torch.nn.functional as F
def compute_lzw_allocation(token_sequences, vocab_size, n_emb, k_max):
    """
    Compute per-token embedding allocation using LZW-style frequency analysis.

    Scans token sequences, counting unigram and bigram frequencies (LZW-style:
    the last token of an n-gram gets the credit) and greedily hands out extra
    embeddings to frequent tokens. Every token starts with 1 embedding; extras
    are distributed round-robin over tokens in descending frequency order, one
    per token per pass, until n_emb total are allocated or every token has hit
    k_max. The round-robin is computed in closed form (O(vocab_size) instead of
    O(n_emb) increments), which matters when n_emb is in the millions.

    Args:
        token_sequences: list of token ID lists (training data sample)
        vocab_size: size of vocabulary
        n_emb: target total embeddings (must be >= vocab_size)
        k_max: max embeddings per token

    Returns:
        alloc: list[int] of length vocab_size (embeddings per token);
        sums to n_emb unless every token is capped at k_max first.
    """
    assert n_emb >= vocab_size, f"n_emb ({n_emb}) must be >= vocab_size ({vocab_size})"
    # Count token (unigram) frequencies across all sequences
    token_freq = Counter()
    for seq in token_sequences:
        token_freq.update(seq)
    # Bigram frequencies, LZW-style: the credit goes to the last token of the
    # pair, i.e. every token occurrence except a sequence's first.
    bigram_freq = Counter()
    for seq in token_sequences:
        bigram_freq.update(seq[1:])
    # Combine unigram and bigram frequencies (Counter returns 0 for missing keys)
    combined_freq = Counter()
    for tok in range(vocab_size):
        combined_freq[tok] = token_freq[tok] + bigram_freq[tok]
    # Start with 1 embedding per token
    alloc = [1] * vocab_size
    remaining = n_emb - vocab_size
    if remaining <= 0 or vocab_size == 0:
        return alloc
    # Tokens by frequency (descending); sort is stable, so ties keep ID order
    sorted_tokens = sorted(range(vocab_size), key=lambda t: combined_freq[t], reverse=True)
    # Closed-form round-robin: every token starts at 1 with the same cap, so
    # each full sweep adds exactly 1 to every token until the cap is reached.
    headroom = k_max - 1  # how many extras any single token can still absorb
    full_passes = min(remaining // vocab_size, headroom)
    if full_passes:
        for tok in sorted_tokens:
            alloc[tok] += full_passes
        remaining -= full_passes * vocab_size
    if full_passes < headroom:
        # One final partial sweep: the most frequent tokens get the leftovers
        for tok in sorted_tokens[:remaining]:
            alloc[tok] += 1
    # else: all tokens are at k_max, can't allocate more
    return alloc
def allocation_to_bounds(alloc):
    """
    Convert a per-token allocation list into a cumulative bounds tensor.

    bounds[0] = 0 and bounds[i] = sum(alloc[:i]), so bounds[-1] == sum(alloc)
    (= n_emb). Token t owns embedding rows [bounds[t], bounds[t+1]).

    Args:
        alloc: list[int] of per-token allocation counts

    Returns:
        bounds: torch.LongTensor of shape [len(alloc) + 1]
    """
    bounds = torch.zeros(len(alloc) + 1, dtype=torch.long)
    bounds[1:] = torch.cumsum(torch.tensor(alloc, dtype=torch.long), dim=0)
    return bounds
class L3Layer(nn.Module):
"""
L3 layer: per-token lookup table with attention-like aggregation.
Forward pass (vectorized gather+pad approach):
1. Look up bounds for each token, gather KV embeddings, pad to k_max
2. Compute scores = K @ x_norm, mask invalid positions, softmax
3. Aggregate: weighted sum of V embeddings
4. Up-project, RMSNorm, concat with x, mix-project
Returns the delta (added residually by caller).
"""
def __init__(self, n_embd, n_emb, d_up, tie_kv=True):
super().__init__()
self.n_embd = n_embd
self.n_emb = n_emb
self.d_up = d_up
self.tie_kv = tie_kv
if tie_kv:
# Single shared weight for both keys and values
self.kv_weight = nn.Parameter(torch.empty(n_emb, n_embd))
else:
# Separate key and value weights
self.k_weight = nn.Parameter(torch.empty(n_emb, n_embd))
self.v_weight = nn.Parameter(torch.empty(n_emb, n_embd))
# Up-project from d_emb (= n_embd when tied) to d_up
self.w_up = nn.Linear(n_embd, d_up, bias=False)
# Mix-project: concat(up_projected, x) -> n_embd
self.w_mix = nn.Linear(d_up + n_embd, n_embd, bias=False)
# Bounds buffer (set after LZW allocation)
self.register_buffer("bounds", torch.zeros(1, dtype=torch.long), persistent=True)
def set_bounds(self, bounds):
"""Register the precomputed bounds tensor as a buffer."""
self.bounds = bounds
def forward(self, x, token_ids):
"""
Args:
x: [B, T, n_embd] hidden states
token_ids: [B, T] token IDs
Returns:
delta: [B, T, n_embd] to be added residually by caller
"""
B, T, C = x.shape
device = x.device
# 1. RMSNorm the input
x_norm = F.rms_norm(x, (C,))
# 2. Gather KV embeddings using bounds + token_ids
# Look up bounds for each token
flat_ids = token_ids.reshape(-1) # [B*T]
starts = self.bounds[flat_ids] # [B*T]
ends = self.bounds[flat_ids + 1] # [B*T]
lengths = ends - starts # [B*T]
k_max = lengths.max().item()
if k_max == 0:
# Edge case: no embeddings for any token
return torch.zeros_like(x)
# Build index tensor [B*T, k_max] with valid indices and padding
offsets = torch.arange(k_max, device=device).unsqueeze(0) # [1, k_max]
indices = starts.unsqueeze(1) + offsets # [B*T, k_max]
mask = offsets < lengths.unsqueeze(1) # [B*T, k_max] True for valid
# Clamp indices to valid range for gathering (masked ones will be zeroed out)
indices = indices.clamp(0, self.n_emb - 1)
# 3. Gather weights
if self.tie_kv:
kv = self.kv_weight[indices] # [B*T, k_max, n_embd]
k_emb = kv
v_emb = kv
else:
k_emb = self.k_weight[indices] # [B*T, k_max, n_embd]
v_emb = self.v_weight[indices] # [B*T, k_max, n_embd]
# 4. Compute attention scores: K @ x_norm
x_flat = x_norm.reshape(B * T, C) # [B*T, C]
scores = torch.bmm(k_emb, x_flat.unsqueeze(2)).squeeze(2) # [B*T, k_max]
# Mask invalid positions
scores = scores.masked_fill(~mask, float('-inf'))
# Softmax over valid positions
weights = F.softmax(scores, dim=-1) # [B*T, k_max]
# Replace NaN from all-inf rows (shouldn't happen since min alloc=1, but safety)
weights = weights.masked_fill(~mask, 0.0)
# 5. Aggregate: weighted sum of V embeddings
agg = torch.bmm(weights.unsqueeze(1), v_emb).squeeze(1) # [B*T, n_embd]
agg = agg.view(B, T, C)
# 6. Up-project, RMSNorm, concat with x, mix-project
up = self.w_up(agg) # [B, T, d_up]
up = F.rms_norm(up, (self.d_up,)) # normalize
cat = torch.cat([up, x], dim=-1) # [B, T, d_up + n_embd]
delta = self.w_mix(cat) # [B, T, n_embd]
return delta

View File

@ -51,6 +51,11 @@ parser.add_argument("--aspect-ratio", type=int, default=64, help="model_dim = de
parser.add_argument("--head-dim", type=int, default=128, help="target head dimension for attention")
parser.add_argument("--max-seq-len", type=int, default=2048, help="max context length")
parser.add_argument("--window-pattern", type=str, default="SSSL", help="sliding window pattern tiled across layers: L=full, S=half context (e.g. 'SSL')")
# L3 (Large Lookup Layers)
parser.add_argument("--l3-after-layers", type=str, default="", help="comma-separated layer indices for L3 (empty = disabled)")
parser.add_argument("--l3-n-emb", type=int, default=0, help="total L3 embeddings (0 = auto-derive from model size)")
parser.add_argument("--l3-d-up", type=int, default=0, help="L3 up-projection dim (0 = 4*n_embd)")
parser.add_argument("--l3-k-max", type=int, default=512, help="max embeddings per token for L3")
# Training horizon (only one used, in order of precedence)
parser.add_argument("--num-iterations", type=int, default=-1, help="explicit number of optimization steps (-1 = disable)")
parser.add_argument("--target-flops", type=float, default=-1.0, help="calculate num_iterations to reach target_flops (-1 = disable)")
@ -122,7 +127,7 @@ print0(f"Vocab size: {vocab_size:,}")
# -----------------------------------------------------------------------------
# Initialize the Model
def build_model_meta(depth):
def build_model_meta(depth, l3_after_layers="", l3_n_emb=0):
"""Build a model on meta device for a given depth (shapes/dtypes only, no data)."""
# Model dim is nudged up to nearest multiple of head_dim for clean division
# (FA3 requires head_dim divisible by 8, and this guarantees head_dim == args.head_dim exactly)
@ -133,19 +138,53 @@ def build_model_meta(depth):
sequence_len=args.max_seq_len, vocab_size=vocab_size,
n_layer=depth, n_head=num_heads, n_kv_head=num_heads, n_embd=model_dim,
window_pattern=args.window_pattern,
l3_after_layers=l3_after_layers,
l3_n_emb=l3_n_emb,
l3_d_up=args.l3_d_up,
l3_k_max=args.l3_k_max,
)
with torch.device("meta"):
model_meta = GPT(config)
return model_meta
# L3 precomputation: compute LZW allocation from training data sample
l3_n_emb = args.l3_n_emb
l3_bounds = None
if args.l3_after_layers:
from nanochat.l3 import compute_lzw_allocation, allocation_to_bounds
# Auto-derive n_emb if not specified: scale proportional to model size
if l3_n_emb == 0:
# Build a temporary model to get param count for auto-derivation
tmp_model = build_model_meta(args.depth, l3_after_layers=args.l3_after_layers, l3_n_emb=vocab_size)
tmp_params = sum(p.numel() for p in tmp_model.parameters())
l3_n_emb = max(vocab_size, tmp_params // 1000)
del tmp_model
print0(f"Auto-derived L3 n_emb: {l3_n_emb:,}")
# Read a sample of training data for LZW allocation
sample_loader = tokenizing_distributed_data_loader_bos_bestfit(tokenizer, 1, args.max_seq_len, split="train", device=device)
sample_sequences = []
for _ in range(100): # 100 batches should be enough
x_sample, _ = next(sample_loader)
sample_sequences.append(x_sample[0].tolist())
del sample_loader
l3_alloc = compute_lzw_allocation(sample_sequences, vocab_size, l3_n_emb, args.l3_k_max)
l3_bounds = allocation_to_bounds(l3_alloc).to(device)
print0(f"L3 allocation: {l3_n_emb:,} total embeddings, k_max={args.l3_k_max}, avg={l3_n_emb/vocab_size:.1f}/token")
# Build the model, move to device, init the weights
model = build_model_meta(args.depth) # 1) Build on meta device (only shapes/dtypes, no data)
model = build_model_meta(args.depth, l3_after_layers=args.l3_after_layers, l3_n_emb=l3_n_emb) # 1) Build on meta device (only shapes/dtypes, no data)
model_config = model.config
model_config_kwargs = asdict(model_config)
print0(f"Model config:\n{json.dumps(model_config_kwargs, indent=2)}")
model.to_empty(device=device) # 2) All tensors get storage on target device but with uninitialized (garbage) data
model.init_weights() # 3) All tensors get initialized
# Set L3 bounds after model creation
if l3_bounds is not None:
for l3_layer in model.l3_layers.values():
l3_layer.set_bounds(l3_bounds)
print0(f"L3 bounds set for {len(model.l3_layers)} layer(s)")
# If we are resuming, overwrite the model parameters with those of the checkpoint
base_dir = get_base_dir()
output_dirname = args.model_tag if args.model_tag else f"d{args.depth}" # e.g. d12

263
tests/test_l3.py Normal file
View File

@ -0,0 +1,263 @@
import torch
import torch.nn.functional as F
from nanochat.l3 import compute_lzw_allocation, allocation_to_bounds, L3Layer
# ---- LZW Allocation Tests ----
def test_lzw_allocation_total():
    """The allocation adds up to exactly the requested n_emb."""
    data = [[0, 1, 2, 3, 0, 1, 2, 3, 0, 1] for _ in range(10)]
    result = compute_lzw_allocation(data, vocab_size=8, n_emb=100, k_max=32)
    assert sum(result) == 100
def test_lzw_allocation_min_one():
    """Every token receives at least one embedding, including unseen ones."""
    # vocab has 8 tokens but the sample only ever contains 0 and 1
    sample = [[0, 1] * 3]
    result = compute_lzw_allocation(sample, vocab_size=8, n_emb=100, k_max=32)
    assert len(result) == 8
    assert min(result) >= 1
def test_lzw_allocation_max_k():
    """k_max is a hard per-token ceiling."""
    # token 0 dominates the sample completely
    sample = [[0] * 8 for _ in range(100)]
    result = compute_lzw_allocation(sample, vocab_size=4, n_emb=50, k_max=8)
    assert max(result) <= 8
def test_lzw_allocation_distribution():
    """More frequent tokens are allocated at least as many embeddings."""
    # token 0: 100 occurrences, token 1: 10, tokens 2..7: one each
    sequences = [[0] * 100 + [1] * 10 + [2, 3, 4, 5, 6, 7]]
    alloc = compute_lzw_allocation(sequences, vocab_size=8, n_emb=64, k_max=32)
    assert alloc[0] >= alloc[1], f"Token 0 (freq 100) should get >= Token 1 (freq 10): {alloc[0]} vs {alloc[1]}"
    assert alloc[1] >= alloc[7], f"Token 1 (freq 10) should get >= Token 7 (freq 1): {alloc[1]} vs {alloc[7]}"
def test_lzw_bounds_from_allocation():
    """Bounds are the cumulative sums of the allocation, starting at 0."""
    counts = [3, 1, 5, 2]
    b = allocation_to_bounds(counts)
    assert b.tolist() == [0, 3, 4, 9, 11]
    assert b[-1].item() == sum(counts)
def test_lzw_allocation_edge_n_emb_equals_vocab():
    """With n_emb == vocab_size there is nothing extra to hand out."""
    result = compute_lzw_allocation([[0, 1, 2, 3]], vocab_size=4, n_emb=4, k_max=32)
    assert result == [1] * 4
# ---- L3Layer Tests ----
def _make_l3(n_embd=16, n_emb=32, d_up=64, vocab_size=8, tie_kv=True):
    """Build an L3Layer with deterministic init and a near-uniform allocation."""
    torch.manual_seed(42)
    layer = L3Layer(n_embd=n_embd, n_emb=n_emb, d_up=d_up, tie_kv=tie_kv)
    # L3Layer allocates its tables with torch.empty(); give every matrix
    # real values so outputs are well-defined
    for _, param in layer.named_parameters():
        if param.dim() >= 2:
            torch.nn.init.normal_(param, std=0.1)
        else:
            torch.nn.init.zeros_(param)
    # Spread n_emb as evenly as possible: base share per token, remainder
    # handed to the lowest token IDs
    base, extra = divmod(n_emb, vocab_size)
    alloc = [base + (1 if i < extra else 0) for i in range(vocab_size)]
    layer.set_bounds(allocation_to_bounds(alloc))
    return layer
def test_l3_layer_output_shape():
    """Forward returns a [B, T, n_embd] tensor."""
    dim = 16
    layer = _make_l3(n_embd=dim, n_emb=32, d_up=64, vocab_size=8)
    hidden = torch.randn(2, 5, dim)
    ids = torch.randint(0, 8, (2, 5))
    assert layer(hidden, ids).shape == (2, 5, dim)
def test_l3_layer_gradient_flow():
    """Backprop reaches every parameter and the input."""
    layer = _make_l3(n_embd=16, n_emb=32, d_up=64, vocab_size=8)
    x = torch.randn(2, 5, 16, requires_grad=True)
    ids = torch.randint(0, 8, (2, 5))
    layer(x, ids).sum().backward()
    for name, param in layer.named_parameters():
        assert param.grad is not None, f"No gradient for {name}"
        assert param.grad.abs().sum() > 0, f"Zero gradient for {name}"
    assert x.grad is not None, "No gradient for input x"
def test_l3_layer_tied_kv():
    """Tied mode exposes only kv_weight; untied exposes k_weight and v_weight."""
    names_tied = {name for name, _ in _make_l3(tie_kv=True).named_parameters()}
    names_untied = {name for name, _ in _make_l3(tie_kv=False).named_parameters()}
    assert "kv_weight" in names_tied
    assert names_tied.isdisjoint({"k_weight", "v_weight"})
    assert {"k_weight", "v_weight"} <= names_untied
    assert "kv_weight" not in names_untied
def test_l3_layer_masking():
    """Tokens with fewer embeddings are properly masked (no NaN/inf)."""
    n_embd = 16
    layer = L3Layer(n_embd=n_embd, n_emb=12, d_up=64, tie_kv=True)
    # L3Layer creates its lookup table with torch.empty(); initialize all
    # weights deterministically — uninitialized memory can itself contain
    # NaN/inf and would make the assertions below flaky for the wrong reason.
    torch.manual_seed(0)
    for param in layer.parameters():
        torch.nn.init.normal_(param, std=0.1)
    # Non-uniform allocation over 3 tokens: 9, 2 and 1 embeddings (sums to
    # n_emb=12), so tokens 1 and 2 require masked padding up to width 9.
    layer.set_bounds(torch.tensor([0, 9, 11, 12]))
    x = torch.randn(1, 3, n_embd)
    token_ids = torch.tensor([[0, 1, 2]])
    out = layer(x, token_ids)
    assert not torch.isnan(out).any(), "Output contains NaN"
    assert not torch.isinf(out).any(), "Output contains inf"
    assert out.shape == (1, 3, n_embd)
def test_l3_layer_deterministic():
    """Two identical calls give identical outputs (no hidden randomness)."""
    layer = _make_l3()
    x = torch.randn(2, 5, 16)
    ids = torch.randint(0, 8, (2, 5))
    first = layer(x, ids)
    second = layer(x, ids)
    assert torch.allclose(first, second)
def test_l3_layer_untied_output_shape():
    """Untied K/V weights still produce a [B, T, n_embd] output."""
    dim = 16
    layer = _make_l3(n_embd=dim, n_emb=32, d_up=64, vocab_size=8, tie_kv=False)
    hidden = torch.randn(2, 5, dim)
    ids = torch.randint(0, 8, (2, 5))
    assert layer(hidden, ids).shape == (2, 5, dim)
def test_l3_layer_untied_gradient_flow():
    """Backprop reaches every parameter when K and V are separate tables."""
    layer = _make_l3(n_embd=16, n_emb=32, d_up=64, vocab_size=8, tie_kv=False)
    x = torch.randn(2, 5, 16, requires_grad=True)
    ids = torch.randint(0, 8, (2, 5))
    layer(x, ids).sum().backward()
    for name, param in layer.named_parameters():
        assert param.grad is not None, f"No gradient for {name}"
        assert param.grad.abs().sum() > 0, f"Zero gradient for {name}"
# ---- GPT Integration Tests ----
def test_gpt_with_l3_forward():
    """A GPT with an L3 layer completes a forward pass with a finite loss."""
    from nanochat.gpt import GPT, GPTConfig
    config = GPTConfig(
        sequence_len=8, vocab_size=64, n_layer=4,
        n_head=2, n_kv_head=2, n_embd=64,
        l3_after_layers="2", l3_n_emb=128, l3_d_up=32, l3_k_max=16,
    )
    model = GPT(config)
    model.init_weights()
    # L3 layers need their bounds installed before they can run
    bounds = allocation_to_bounds(compute_lzw_allocation([[0, 1, 2, 3] * 4], 64, 128, 16))
    for l3 in model.l3_layers.values():
        l3.set_bounds(bounds)
    tokens = torch.randint(0, 64, (2, 8))
    targets = torch.randint(0, 64, (2, 8))
    loss = model(tokens, targets)
    assert loss.ndim == 0  # scalar loss
    assert not torch.isnan(loss)
def test_gpt_with_l3_backward():
    """loss.backward() works and L3 parameters end up with gradients."""
    from nanochat.gpt import GPT, GPTConfig
    config = GPTConfig(
        sequence_len=8, vocab_size=64, n_layer=4,
        n_head=2, n_kv_head=2, n_embd=64,
        l3_after_layers="2", l3_n_emb=128, l3_d_up=32, l3_k_max=16,
    )
    model = GPT(config)
    model.init_weights()
    bounds = allocation_to_bounds(compute_lzw_allocation([[0, 1, 2, 3] * 4], 64, 128, 16))
    for l3 in model.l3_layers.values():
        l3.set_bounds(bounds)
    tokens = torch.randint(0, 64, (2, 8))
    targets = torch.randint(0, 64, (2, 8))
    # First step: lm_head starts zero-initialized, so one optimizer update is
    # needed before a learning signal can flow back into the L3 parameters.
    model(tokens, targets).backward()
    optimizer = model.setup_optimizer()
    optimizer.step()
    model.zero_grad(set_to_none=True)
    # Second pass: now the L3 params should receive gradients
    model(tokens, targets).backward()
    for name, param in model.l3_layers.named_parameters():
        if 'bounds' not in name:  # bounds is a buffer, not a parameter
            assert param.grad is not None, f"No gradient for L3 param {name}"
def test_gpt_with_l3_optimizer():
    """Every model parameter (including L3) lands in some optimizer group."""
    from nanochat.gpt import GPT, GPTConfig
    config = GPTConfig(
        sequence_len=8, vocab_size=64, n_layer=4,
        n_head=2, n_kv_head=2, n_embd=64,
        l3_after_layers="2", l3_n_emb=128, l3_d_up=32, l3_k_max=16,
    )
    model = GPT(config)
    model.init_weights()
    bounds = allocation_to_bounds(compute_lzw_allocation([[0, 1, 2, 3] * 4], 64, 128, 16))
    for l3 in model.l3_layers.values():
        l3.set_bounds(bounds)
    optimizer = model.setup_optimizer()
    # Flatten every optimizer group into a set of parameter identities
    tracked = {id(p) for group in optimizer.param_groups for p in group["params"]}
    for name, param in model.named_parameters():
        assert id(param) in tracked, f"Parameter {name} not in optimizer"
def test_gpt_without_l3_unchanged():
    """With L3 disabled no L3 layers are created and training works as before."""
    from nanochat.gpt import GPT, GPTConfig
    config = GPTConfig(
        sequence_len=8, vocab_size=64, n_layer=4,
        n_head=2, n_kv_head=2, n_embd=64,
    )
    model = GPT(config)
    assert not model.l3_layers
    model.init_weights()
    tokens = torch.randint(0, 64, (2, 8))
    targets = torch.randint(0, 64, (2, 8))
    loss = model(tokens, targets)
    assert loss.ndim == 0
    assert not torch.isnan(loss)