diff --git a/nanochat/gpt.py b/nanochat/gpt.py index 208acd1..030a39d 100644 --- a/nanochat/gpt.py +++ b/nanochat/gpt.py @@ -24,6 +24,7 @@ from nanochat.optim import MuonAdamW, DistMuonAdamW # Our custom Flash Attention module that automatically uses FA3 on Hopper+ and SDPA fallback elsewhere from nanochat.flash_attention import flash_attn +from nanochat.l3 import L3Layer @dataclass class GPTConfig: @@ -37,6 +38,12 @@ class GPTConfig: # Characters: L=long (full context), S=short (half context) # Examples: "L"=all full context, "SL"=alternating, "SSL"=two short then one long window_pattern: str = "SSSL" + # L3 (Large Lookup Layers) config + l3_after_layers: str = "" # comma-separated layer indices (empty = disabled) + l3_n_emb: int = 0 # total embeddings (0 = disabled) + l3_d_up: int = 0 # up-projection dim (0 = auto: 4 * n_embd) + l3_k_max: int = 512 # max embeddings per token + l3_tie_kv: bool = True # tie key and value weights def norm(x): @@ -175,6 +182,13 @@ class GPT(nn.Module): head_dim = config.n_embd // config.n_head kv_dim = config.n_kv_head * head_dim self.value_embeds = nn.ModuleDict({str(i): nn.Embedding(padded_vocab_size, kv_dim) for i in range(config.n_layer) if has_ve(i, config.n_layer)}) + # L3 layers (placed between decoder blocks) + self.l3_layer_indices = set(int(x) for x in config.l3_after_layers.split(",") if x.strip()) if config.l3_after_layers else set() + l3_d_up = config.l3_d_up if config.l3_d_up > 0 else 4 * config.n_embd + self.l3_layers = nn.ModuleDict({ + str(i): L3Layer(config.n_embd, config.l3_n_emb, l3_d_up, config.l3_tie_kv) + for i in self.l3_layer_indices + }) if self.l3_layer_indices and config.l3_n_emb > 0 else nn.ModuleDict() # To support meta device initialization, we init the rotary embeddings here, but it's just "fake" meta tensors only. # As for rotary_seq_len, these rotary embeddings are pretty small/cheap in memory, # so let's just over-compute them by 10X, but assert fail if we ever reach that amount. @@ -229,6 +243,16 @@ class GPT(nn.Module): if block.attn.ve_gate is not None: torch.nn.init.zeros_(block.attn.ve_gate.weight) + # L3 layers + for l3_layer in self.l3_layers.values(): + if l3_layer.tie_kv: + torch.nn.init.normal_(l3_layer.kv_weight, mean=0.0, std=1.0) + else: + torch.nn.init.normal_(l3_layer.k_weight, mean=0.0, std=1.0) + torch.nn.init.normal_(l3_layer.v_weight, mean=0.0, std=1.0) + torch.nn.init.uniform_(l3_layer.w_up.weight, -s, s) + torch.nn.init.zeros_(l3_layer.w_mix.weight) # L3 starts as no-op + # Rotary embeddings head_dim = self.config.n_embd // self.config.n_head cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim) @@ -239,6 +263,12 @@ class GPT(nn.Module): self.transformer.wte.to(dtype=torch.bfloat16) for ve in self.value_embeds.values(): ve.to(dtype=torch.bfloat16) + for l3_layer in self.l3_layers.values(): + if l3_layer.tie_kv: + l3_layer.kv_weight.data = l3_layer.kv_weight.data.to(dtype=torch.bfloat16) + else: + l3_layer.k_weight.data = l3_layer.k_weight.data.to(dtype=torch.bfloat16) + l3_layer.v_weight.data = l3_layer.v_weight.data.to(dtype=torch.bfloat16) def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None): # TODO: bump base theta more? e.g. 100K is more common more recently @@ -304,7 +334,14 @@ class GPT(nn.Module): nparams = sum(p.numel() for p in self.parameters()) # Exclude non-matmul params: embeddings and per-layer scalars value_embeds_numel = sum(ve.weight.numel() for ve in self.value_embeds.values()) - nparams_exclude = (self.transformer.wte.weight.numel() + value_embeds_numel + + # L3 kv/k/v weights are embeddings (lookup tables), not matmul weights + l3_embed_numel = 0 + for l3_layer in self.l3_layers.values(): + if l3_layer.tie_kv: + l3_embed_numel += l3_layer.kv_weight.numel() + else: + l3_embed_numel += l3_layer.k_weight.numel() + l3_layer.v_weight.numel() + nparams_exclude = (self.transformer.wte.weight.numel() + value_embeds_numel + l3_embed_numel + self.resid_lambdas.numel() + self.x0_lambdas.numel()) h, q, t = self.config.n_head, self.config.n_embd // self.config.n_head, self.config.sequence_len # Sum attention FLOPs per layer, accounting for sliding window @@ -313,7 +350,15 @@ class GPT(nn.Module): window = window_size[0] # (left, right) tuple, we use left effective_seq = t if window < 0 else min(window, t) attn_flops += 12 * h * q * effective_seq - num_flops_per_token = 6 * (nparams - nparams_exclude) + attn_flops + # L3 FLOPs: per layer, approx 2*avg_k*n_embd (attention) + 2*n_embd*d_up (up) + 2*(d_up+n_embd)*n_embd (mix) + l3_flops = 0 + if self.l3_layers: + n_embd = self.config.n_embd + l3_d_up = self.config.l3_d_up if self.config.l3_d_up > 0 else 4 * n_embd + avg_k = self.config.l3_n_emb / self.config.vocab_size if self.config.vocab_size > 0 else 1 + per_l3 = 2 * avg_k * n_embd + 2 * n_embd * l3_d_up + 2 * (l3_d_up + n_embd) * n_embd + l3_flops = int(per_l3 * len(self.l3_layers)) + num_flops_per_token = 6 * (nparams - nparams_exclude) + attn_flops + l3_flops return num_flops_per_token def num_scaling_params(self): @@ -334,13 +379,24 @@ class GPT(nn.Module): lm_head = sum(p.numel() for p in self.lm_head.parameters()) transformer_matrices = sum(p.numel() for p in self.transformer.h.parameters()) scalars = self.resid_lambdas.numel() + self.x0_lambdas.numel() - total = wte + value_embeds + lm_head + transformer_matrices + scalars + # L3: separate embedding-like params from matrix params + l3_embeds = 0 + l3_matrices = 0 + for l3_layer in self.l3_layers.values(): + if l3_layer.tie_kv: + l3_embeds += l3_layer.kv_weight.numel() + else: + l3_embeds += l3_layer.k_weight.numel() + l3_layer.v_weight.numel() + l3_matrices += l3_layer.w_up.weight.numel() + l3_layer.w_mix.weight.numel() + total = wte + value_embeds + lm_head + transformer_matrices + scalars + l3_embeds + l3_matrices assert total == sum(p.numel() for p in self.parameters()), "Parameter count mismatch" return { 'wte': wte, 'value_embeds': value_embeds, 'lm_head': lm_head, 'transformer_matrices': transformer_matrices, + 'l3_embeds': l3_embeds, + 'l3_matrices': l3_matrices, 'scalars': scalars, 'total': total, } @@ -356,7 +412,18 @@ class GPT(nn.Module): lm_head_params = list(self.lm_head.parameters()) resid_params = [self.resid_lambdas] x0_params = [self.x0_lambdas] - assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params) + # L3 params: embedding-like (kv/k/v weights) and matrix (w_up, w_mix) + l3_embed_params = [] + l3_matrix_params = [] + for l3_layer in self.l3_layers.values(): + if l3_layer.tie_kv: + l3_embed_params.append(l3_layer.kv_weight) + else: + l3_embed_params.append(l3_layer.k_weight) + l3_embed_params.append(l3_layer.v_weight) + l3_matrix_params.append(l3_layer.w_up.weight) + l3_matrix_params.append(l3_layer.w_mix.weight) + assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params) + len(l3_embed_params) + len(l3_matrix_params) # Scale the LR for the AdamW parameters by ∝1/√dmodel (tuned for 768 dim model) dmodel_lr_scale = (model_dim / 768) ** -0.5 @@ -371,9 +438,13 @@ class GPT(nn.Module): dict(kind='adamw', params=resid_params, lr=scalar_lr * 0.01, betas=adam_betas, eps=1e-10, weight_decay=0.0), dict(kind='adamw', params=x0_params, lr=scalar_lr, betas=(0.96, 0.95), eps=1e-10, weight_decay=0.0), # higher beta1 for x0 ] + # L3 param groups: embeddings use AdamW (like token embeddings), matrices use Muon + if l3_embed_params: + param_groups.append(dict(kind='adamw', params=l3_embed_params, lr=embedding_lr * dmodel_lr_scale, betas=adam_betas, eps=1e-10, weight_decay=0.0)) # Muon groups (matrix params, grouped by shape for stacking) - for shape in sorted({p.shape for p in matrix_params}): - group_params = [p for p in matrix_params if p.shape == shape] + all_matrix_params = matrix_params + l3_matrix_params + for shape in sorted({p.shape for p in all_matrix_params}): + group_params = [p for p in all_matrix_params if p.shape == shape] param_groups.append(dict( kind='muon', params=group_params, lr=matrix_lr, momentum=0.95, ns_steps=5, beta2=0.95, weight_decay=weight_decay, @@ -404,6 +475,9 @@ class GPT(nn.Module): x = self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0 ve = self.value_embeds[str(i)](idx) if str(i) in self.value_embeds else None x = block(x, ve, cos_sin, self.window_sizes[i], kv_cache) + # L3 layer after this block (if configured) + if str(i) in self.l3_layers: + x = x + self.l3_layers[str(i)](x, idx) x = norm(x) # Forward the lm_head (compute logits) diff --git a/nanochat/l3.py b/nanochat/l3.py new file mode 100644 index 0000000..6f16008 --- /dev/null +++ b/nanochat/l3.py @@ -0,0 +1,201 @@ +""" +L3: Large Lookup Layers +Ref: arXiv:2601.21461v2 + +L3 generalizes token embeddings by placing per-token lookup tables inside +the decoder stack. Unlike MoE, routing is static (determined by token ID), +eliminating router training and load-balancing losses. +""" + +from collections import Counter + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def compute_lzw_allocation(token_sequences, vocab_size, n_emb, k_max): + """ + Compute per-token embedding allocation using LZW-style frequency analysis. + + Scans token sequences LZW-style, counting n-gram frequencies and allocating + embeddings to the last token of frequent n-grams. Every token starts with 1 + embedding, then we greedily add embeddings to tokens involved in the most + frequent n-grams until we reach n_emb total. + + Args: + token_sequences: list of token ID lists (training data sample) + vocab_size: size of vocabulary + n_emb: target total embeddings + k_max: max embeddings per token + Returns: + alloc: list[int] of length vocab_size (embeddings per token) + """ + assert n_emb >= vocab_size, f"n_emb ({n_emb}) must be >= vocab_size ({vocab_size})" + + # Count token frequencies across all sequences + token_freq = Counter() + for seq in token_sequences: + for tok in seq: + token_freq[tok] += 1 + + # Also count bigram frequencies (LZW-style: last token of n-gram gets credit) + bigram_freq = Counter() + for seq in token_sequences: + for i in range(len(seq) - 1): + bigram_freq[seq[i + 1]] += 1 # credit goes to last token + + # Combine unigram and bigram frequencies + combined_freq = Counter() + for tok in range(vocab_size): + combined_freq[tok] = token_freq.get(tok, 0) + bigram_freq.get(tok, 0) + + # Start with 1 embedding per token + alloc = [1] * vocab_size + remaining = n_emb - vocab_size + + if remaining <= 0: + return alloc + + # Sort tokens by frequency (descending) for greedy allocation + sorted_tokens = sorted(range(vocab_size), key=lambda t: combined_freq[t], reverse=True) + + # Greedily add embeddings to the most frequent tokens + while remaining > 0: + added_any = False + for tok in sorted_tokens: + if remaining <= 0: + break + if alloc[tok] < k_max: + alloc[tok] += 1 + remaining -= 1 + added_any = True + if not added_any: + break # all tokens at k_max, can't allocate more + + return alloc + + +def allocation_to_bounds(alloc): + """ + Convert allocation array to cumulative bounds tensor. + + bounds[0] = 0, bounds[i] = bounds[i-1] + alloc[i-1] + bounds[-1] = sum(alloc) = n_emb + + Args: + alloc: list[int] of per-token allocation counts + Returns: + bounds: torch.LongTensor of shape [len(alloc) + 1] + """ + bounds = [0] + for a in alloc: + bounds.append(bounds[-1] + a) + return torch.tensor(bounds, dtype=torch.long) + + +class L3Layer(nn.Module): + """ + L3 layer: per-token lookup table with attention-like aggregation. + + Forward pass (vectorized gather+pad approach): + 1. Look up bounds for each token, gather KV embeddings, pad to k_max + 2. Compute scores = K @ x_norm, mask invalid positions, softmax + 3. Aggregate: weighted sum of V embeddings + 4. Up-project, RMSNorm, concat with x, mix-project + Returns the delta (added residually by caller). + """ + + def __init__(self, n_embd, n_emb, d_up, tie_kv=True): + super().__init__() + self.n_embd = n_embd + self.n_emb = n_emb + self.d_up = d_up + self.tie_kv = tie_kv + + if tie_kv: + # Single shared weight for both keys and values + self.kv_weight = nn.Parameter(torch.empty(n_emb, n_embd)) + else: + # Separate key and value weights + self.k_weight = nn.Parameter(torch.empty(n_emb, n_embd)) + self.v_weight = nn.Parameter(torch.empty(n_emb, n_embd)) + + # Up-project from d_emb (= n_embd when tied) to d_up + self.w_up = nn.Linear(n_embd, d_up, bias=False) + # Mix-project: concat(up_projected, x) -> n_embd + self.w_mix = nn.Linear(d_up + n_embd, n_embd, bias=False) + + # Bounds buffer (set after LZW allocation) + self.register_buffer("bounds", torch.zeros(1, dtype=torch.long), persistent=True) + + def set_bounds(self, bounds): + """Register the precomputed bounds tensor as a buffer.""" + self.bounds = bounds + + def forward(self, x, token_ids): + """ + Args: + x: [B, T, n_embd] hidden states + token_ids: [B, T] token IDs + Returns: + delta: [B, T, n_embd] to be added residually by caller + """ + B, T, C = x.shape + device = x.device + + # 1. RMSNorm the input + x_norm = F.rms_norm(x, (C,)) + + # 2. Gather KV embeddings using bounds + token_ids + # Look up bounds for each token + flat_ids = token_ids.reshape(-1) # [B*T] + starts = self.bounds[flat_ids] # [B*T] + ends = self.bounds[flat_ids + 1] # [B*T] + lengths = ends - starts # [B*T] + k_max = lengths.max().item() + + if k_max == 0: + # Edge case: no embeddings for any token + return torch.zeros_like(x) + + # Build index tensor [B*T, k_max] with valid indices and padding + offsets = torch.arange(k_max, device=device).unsqueeze(0) # [1, k_max] + indices = starts.unsqueeze(1) + offsets # [B*T, k_max] + mask = offsets < lengths.unsqueeze(1) # [B*T, k_max] True for valid + + # Clamp indices to valid range for gathering (masked ones will be zeroed out) + indices = indices.clamp(0, self.n_emb - 1) + + # 3. Gather weights + if self.tie_kv: + kv = self.kv_weight[indices] # [B*T, k_max, n_embd] + k_emb = kv + v_emb = kv + else: + k_emb = self.k_weight[indices] # [B*T, k_max, n_embd] + v_emb = self.v_weight[indices] # [B*T, k_max, n_embd] + + # 4. Compute attention scores: K @ x_norm + x_flat = x_norm.reshape(B * T, C) # [B*T, C] + scores = torch.bmm(k_emb, x_flat.unsqueeze(2)).squeeze(2) # [B*T, k_max] + + # Mask invalid positions + scores = scores.masked_fill(~mask, float('-inf')) + + # Softmax over valid positions + weights = F.softmax(scores, dim=-1) # [B*T, k_max] + # Replace NaN from all-inf rows (shouldn't happen since min alloc=1, but safety) + weights = weights.masked_fill(~mask, 0.0) + + # 5. Aggregate: weighted sum of V embeddings + agg = torch.bmm(weights.unsqueeze(1), v_emb).squeeze(1) # [B*T, n_embd] + agg = agg.view(B, T, C) + + # 6. Up-project, RMSNorm, concat with x, mix-project + up = self.w_up(agg) # [B, T, d_up] + up = F.rms_norm(up, (self.d_up,)) # normalize + cat = torch.cat([up, x], dim=-1) # [B, T, d_up + n_embd] + delta = self.w_mix(cat) # [B, T, n_embd] + + return delta diff --git a/scripts/base_train.py b/scripts/base_train.py index 24091b6..a6b6211 100644 --- a/scripts/base_train.py +++ b/scripts/base_train.py @@ -51,6 +51,11 @@ parser.add_argument("--aspect-ratio", type=int, default=64, help="model_dim = de parser.add_argument("--head-dim", type=int, default=128, help="target head dimension for attention") parser.add_argument("--max-seq-len", type=int, default=2048, help="max context length") parser.add_argument("--window-pattern", type=str, default="SSSL", help="sliding window pattern tiled across layers: L=full, S=half context (e.g. 'SSL')") +# L3 (Large Lookup Layers) +parser.add_argument("--l3-after-layers", type=str, default="", help="comma-separated layer indices for L3 (empty = disabled)") +parser.add_argument("--l3-n-emb", type=int, default=0, help="total L3 embeddings (0 = auto-derive from model size)") +parser.add_argument("--l3-d-up", type=int, default=0, help="L3 up-projection dim (0 = 4*n_embd)") +parser.add_argument("--l3-k-max", type=int, default=512, help="max embeddings per token for L3") # Training horizon (only one used, in order of precedence) parser.add_argument("--num-iterations", type=int, default=-1, help="explicit number of optimization steps (-1 = disable)") parser.add_argument("--target-flops", type=float, default=-1.0, help="calculate num_iterations to reach target_flops (-1 = disable)") @@ -122,7 +127,7 @@ print0(f"Vocab size: {vocab_size:,}") # ----------------------------------------------------------------------------- # Initialize the Model -def build_model_meta(depth): +def build_model_meta(depth, l3_after_layers="", l3_n_emb=0): """Build a model on meta device for a given depth (shapes/dtypes only, no data).""" # Model dim is nudged up to nearest multiple of head_dim for clean division # (FA3 requires head_dim divisible by 8, and this guarantees head_dim == args.head_dim exactly) @@ -133,19 +138,53 @@ def build_model_meta(depth): sequence_len=args.max_seq_len, vocab_size=vocab_size, n_layer=depth, n_head=num_heads, n_kv_head=num_heads, n_embd=model_dim, window_pattern=args.window_pattern, + l3_after_layers=l3_after_layers, + l3_n_emb=l3_n_emb, + l3_d_up=args.l3_d_up, + l3_k_max=args.l3_k_max, ) with torch.device("meta"): model_meta = GPT(config) return model_meta +# L3 precomputation: compute LZW allocation from training data sample +l3_n_emb = args.l3_n_emb +l3_bounds = None +if args.l3_after_layers: + from nanochat.l3 import compute_lzw_allocation, allocation_to_bounds + # Auto-derive n_emb if not specified: scale proportional to model size + if l3_n_emb == 0: + # Build a temporary model to get param count for auto-derivation + tmp_model = build_model_meta(args.depth, l3_after_layers=args.l3_after_layers, l3_n_emb=vocab_size) + tmp_params = sum(p.numel() for p in tmp_model.parameters()) + l3_n_emb = max(vocab_size, tmp_params // 1000) + del tmp_model + print0(f"Auto-derived L3 n_emb: {l3_n_emb:,}") + # Read a sample of training data for LZW allocation + sample_loader = tokenizing_distributed_data_loader_bos_bestfit(tokenizer, 1, args.max_seq_len, split="train", device=device) + sample_sequences = [] + for _ in range(100): # 100 batches should be enough + x_sample, _ = next(sample_loader) + sample_sequences.append(x_sample[0].tolist()) + del sample_loader + l3_alloc = compute_lzw_allocation(sample_sequences, vocab_size, l3_n_emb, args.l3_k_max) + l3_bounds = allocation_to_bounds(l3_alloc).to(device) + print0(f"L3 allocation: {l3_n_emb:,} total embeddings, k_max={args.l3_k_max}, avg={l3_n_emb/vocab_size:.1f}/token") + # Build the model, move to device, init the weights -model = build_model_meta(args.depth) # 1) Build on meta device (only shapes/dtypes, no data) +model = build_model_meta(args.depth, l3_after_layers=args.l3_after_layers, l3_n_emb=l3_n_emb) # 1) Build on meta device (only shapes/dtypes, no data) model_config = model.config model_config_kwargs = asdict(model_config) print0(f"Model config:\n{json.dumps(model_config_kwargs, indent=2)}") model.to_empty(device=device) # 2) All tensors get storage on target device but with uninitialized (garbage) data model.init_weights() # 3) All tensors get initialized +# Set L3 bounds after model creation +if l3_bounds is not None: + for l3_layer in model.l3_layers.values(): + l3_layer.set_bounds(l3_bounds) + print0(f"L3 bounds set for {len(model.l3_layers)} layer(s)") + # If we are resuming, overwrite the model parameters with those of the checkpoint base_dir = get_base_dir() output_dirname = args.model_tag if args.model_tag else f"d{args.depth}" # e.g. d12 diff --git a/tests/test_l3.py b/tests/test_l3.py new file mode 100644 index 0000000..368a5c8 --- /dev/null +++ b/tests/test_l3.py @@ -0,0 +1,263 @@ +import torch +import torch.nn.functional as F + +from nanochat.l3 import compute_lzw_allocation, allocation_to_bounds, L3Layer + + +# ---- LZW Allocation Tests ---- + +def test_lzw_allocation_total(): + """Allocation sums to target n_emb.""" + sequences = [[0, 1, 2, 3, 0, 1, 2, 3, 0, 1]] * 10 + alloc = compute_lzw_allocation(sequences, vocab_size=8, n_emb=100, k_max=32) + assert sum(alloc) == 100 + + +def test_lzw_allocation_min_one(): + """Every token gets >= 1 embedding, even unseen tokens.""" + sequences = [[0, 1, 0, 1, 0, 1]] # only tokens 0 and 1 appear + alloc = compute_lzw_allocation(sequences, vocab_size=8, n_emb=100, k_max=32) + assert all(a >= 1 for a in alloc) + assert len(alloc) == 8 + + +def test_lzw_allocation_max_k(): + """No token exceeds k_max embeddings.""" + sequences = [[0, 0, 0, 0, 0, 0, 0, 0]] * 100 # token 0 is extremely frequent + alloc = compute_lzw_allocation(sequences, vocab_size=4, n_emb=50, k_max=8) + assert all(a <= 8 for a in alloc) + + +def test_lzw_allocation_distribution(): + """Frequent tokens get more embeddings than rare ones.""" + # Token 0 appears 100x, token 1 appears 10x, tokens 2-7 appear once each + sequences = [[0] * 100 + [1] * 10 + [2, 3, 4, 5, 6, 7]] + alloc = compute_lzw_allocation(sequences, vocab_size=8, n_emb=64, k_max=32) + assert alloc[0] >= alloc[1], f"Token 0 (freq 100) should get >= Token 1 (freq 10): {alloc[0]} vs {alloc[1]}" + assert alloc[1] >= alloc[7], f"Token 1 (freq 10) should get >= Token 7 (freq 1): {alloc[1]} vs {alloc[7]}" + + +def test_lzw_bounds_from_allocation(): + """Bounds array is correct cumulative sum.""" + alloc = [3, 1, 5, 2] + bounds = allocation_to_bounds(alloc) + assert bounds.tolist() == [0, 3, 4, 9, 11] + assert bounds[-1].item() == sum(alloc) + + +def test_lzw_allocation_edge_n_emb_equals_vocab(): + """When n_emb == vocab_size, every token gets exactly 1.""" + sequences = [[0, 1, 2, 3]] + alloc = compute_lzw_allocation(sequences, vocab_size=4, n_emb=4, k_max=32) + assert alloc == [1, 1, 1, 1] + + +# ---- L3Layer Tests ---- + +def _make_l3(n_embd=16, n_emb=32, d_up=64, vocab_size=8, tie_kv=True): + """Helper to create an L3Layer with bounds set and properly initialized.""" + torch.manual_seed(42) + layer = L3Layer(n_embd=n_embd, n_emb=n_emb, d_up=d_up, tie_kv=tie_kv) + # Initialize weights to avoid garbage values from torch.empty() + for name, p in layer.named_parameters(): + if p.dim() >= 2: + torch.nn.init.normal_(p, std=0.1) + else: + torch.nn.init.zeros_(p) + # Simple uniform allocation: each token gets n_emb // vocab_size embeddings + per_token = n_emb // vocab_size + alloc = [per_token] * vocab_size + # Distribute remainder + remainder = n_emb - sum(alloc) + for i in range(remainder): + alloc[i] += 1 + bounds = allocation_to_bounds(alloc) + layer.set_bounds(bounds) + return layer + + +def test_l3_layer_output_shape(): + """Output is [B, T, n_embd].""" + n_embd = 16 + layer = _make_l3(n_embd=n_embd, n_emb=32, d_up=64, vocab_size=8) + x = torch.randn(2, 5, n_embd) + token_ids = torch.randint(0, 8, (2, 5)) + out = layer(x, token_ids) + assert out.shape == (2, 5, n_embd) + + +def test_l3_layer_gradient_flow(): + """All parameters receive gradients.""" + layer = _make_l3(n_embd=16, n_emb=32, d_up=64, vocab_size=8) + x = torch.randn(2, 5, 16, requires_grad=True) + token_ids = torch.randint(0, 8, (2, 5)) + out = layer(x, token_ids) + loss = out.sum() + loss.backward() + for name, p in layer.named_parameters(): + assert p.grad is not None, f"No gradient for {name}" + assert p.grad.abs().sum() > 0, f"Zero gradient for {name}" + assert x.grad is not None, "No gradient for input x" + + +def test_l3_layer_tied_kv(): + """Tied mode uses single weight matrix (kv_weight), no separate k/v.""" + layer_tied = _make_l3(tie_kv=True) + layer_untied = _make_l3(tie_kv=False) + tied_params = {n for n, _ in layer_tied.named_parameters()} + untied_params = {n for n, _ in layer_untied.named_parameters()} + assert "kv_weight" in tied_params + assert "k_weight" not in tied_params + assert "v_weight" not in tied_params + assert "k_weight" in untied_params + assert "v_weight" in untied_params + assert "kv_weight" not in untied_params + + +def test_l3_layer_masking(): + """Tokens with fewer embeddings are properly masked (no NaN/inf).""" + n_embd = 16 + # Non-uniform allocation: token 0 gets 10, token 1 gets 2 + layer = L3Layer(n_embd=n_embd, n_emb=12, d_up=64, tie_kv=True) + bounds = torch.tensor([0, 10, 12, 12]) # 3 tokens: 10, 2, 0 embeddings + # Token 2 has 0 embeddings - adjust to at least 1 + bounds = torch.tensor([0, 9, 11, 12]) # 3 tokens: 9, 2, 1 + layer.set_bounds(bounds) + x = torch.randn(1, 3, n_embd) + token_ids = torch.tensor([[0, 1, 2]]) + out = layer(x, token_ids) + assert not torch.isnan(out).any(), "Output contains NaN" + assert not torch.isinf(out).any(), "Output contains inf" + assert out.shape == (1, 3, n_embd) + + +def test_l3_layer_deterministic(): + """Same input produces same output.""" + layer = _make_l3() + x = torch.randn(2, 5, 16) + token_ids = torch.randint(0, 8, (2, 5)) + out1 = layer(x, token_ids) + out2 = layer(x, token_ids) + assert torch.allclose(out1, out2) + + +def test_l3_layer_untied_output_shape(): + """Untied mode also produces correct output shape.""" + n_embd = 16 + layer = _make_l3(n_embd=n_embd, n_emb=32, d_up=64, vocab_size=8, tie_kv=False) + x = torch.randn(2, 5, n_embd) + token_ids = torch.randint(0, 8, (2, 5)) + out = layer(x, token_ids) + assert out.shape == (2, 5, n_embd) + + +def test_l3_layer_untied_gradient_flow(): + """All parameters receive gradients in untied mode.""" + layer = _make_l3(n_embd=16, n_emb=32, d_up=64, vocab_size=8, tie_kv=False) + x = torch.randn(2, 5, 16, requires_grad=True) + token_ids = torch.randint(0, 8, (2, 5)) + out = layer(x, token_ids) + loss = out.sum() + loss.backward() + for name, p in layer.named_parameters(): + assert p.grad is not None, f"No gradient for {name}" + assert p.grad.abs().sum() > 0, f"Zero gradient for {name}" + + +# ---- GPT Integration Tests ---- + +def test_gpt_with_l3_forward(): + """Full model with L3 runs forward pass.""" + from nanochat.gpt import GPT, GPTConfig + config = GPTConfig( + sequence_len=8, vocab_size=64, n_layer=4, + n_head=2, n_kv_head=2, n_embd=64, + l3_after_layers="2", l3_n_emb=128, l3_d_up=32, l3_k_max=16, + ) + model = GPT(config) + model.init_weights() + # Set bounds for L3 layers + alloc = compute_lzw_allocation([[0, 1, 2, 3] * 4], vocab_size=64, n_emb=128, k_max=16) + bounds = allocation_to_bounds(alloc) + for l3_layer in model.l3_layers.values(): + l3_layer.set_bounds(bounds) + x = torch.randint(0, 64, (2, 8)) + y = torch.randint(0, 64, (2, 8)) + loss = model(x, y) + assert loss.ndim == 0 # scalar loss + assert not torch.isnan(loss) + + +def test_gpt_with_l3_backward(): + """loss.backward() works, L3 params get gradients.""" + from nanochat.gpt import GPT, GPTConfig + config = GPTConfig( + sequence_len=8, vocab_size=64, n_layer=4, + n_head=2, n_kv_head=2, n_embd=64, + l3_after_layers="2", l3_n_emb=128, l3_d_up=32, l3_k_max=16, + ) + model = GPT(config) + model.init_weights() + alloc = compute_lzw_allocation([[0, 1, 2, 3] * 4], vocab_size=64, n_emb=128, k_max=16) + bounds = allocation_to_bounds(alloc) + for l3_layer in model.l3_layers.values(): + l3_layer.set_bounds(bounds) + + x = torch.randint(0, 64, (2, 8)) + y = torch.randint(0, 64, (2, 8)) + + # Need two forward passes: first to propagate signal through lm_head (init zeros) + loss = model(x, y) + loss.backward() + optimizer = model.setup_optimizer() + optimizer.step() + model.zero_grad(set_to_none=True) + + # Second pass should give gradients to L3 params + loss = model(x, y) + loss.backward() + for name, p in model.l3_layers.named_parameters(): + if 'bounds' not in name: # bounds is a buffer, not a parameter + assert p.grad is not None, f"No gradient for L3 param {name}" + + +def test_gpt_with_l3_optimizer(): + """setup_optimizer includes L3 params.""" + from nanochat.gpt import GPT, GPTConfig + config = GPTConfig( + sequence_len=8, vocab_size=64, n_layer=4, + n_head=2, n_kv_head=2, n_embd=64, + l3_after_layers="2", l3_n_emb=128, l3_d_up=32, l3_k_max=16, + ) + model = GPT(config) + model.init_weights() + alloc = compute_lzw_allocation([[0, 1, 2, 3] * 4], vocab_size=64, n_emb=128, k_max=16) + bounds = allocation_to_bounds(alloc) + for l3_layer in model.l3_layers.values(): + l3_layer.set_bounds(bounds) + optimizer = model.setup_optimizer() + # Collect all optimizer param ids + opt_param_ids = set() + for group in optimizer.param_groups: + for p in group["params"]: + opt_param_ids.add(id(p)) + # Check all model params are in optimizer + for name, p in model.named_parameters(): + assert id(p) in opt_param_ids, f"Parameter {name} not in optimizer" + + +def test_gpt_without_l3_unchanged(): + """L3 disabled = identical to current behavior (no L3 layers created).""" + from nanochat.gpt import GPT, GPTConfig + config = GPTConfig( + sequence_len=8, vocab_size=64, n_layer=4, + n_head=2, n_kv_head=2, n_embd=64, + ) + model = GPT(config) + assert len(model.l3_layers) == 0 + model.init_weights() + x = torch.randint(0, 64, (2, 8)) + y = torch.randint(0, 64, (2, 8)) + loss = model(x, y) + assert loss.ndim == 0 + assert not torch.isnan(loss)