mirror of
https://github.com/karpathy/nanochat.git
synced 2026-03-07 01:40:30 +00:00
Add L3 (Large Lookup Layers) following arXiv:2601.21461v2
L3 generalizes token embeddings by placing per-token lookup tables inside the decoder stack. Unlike MoE, routing is static (determined by token ID), eliminating router training and load-balancing losses. Implementation: - nanochat/l3.py: LZW allocation algorithm and L3Layer module with vectorized gather+pad+mask forward pass, tied/untied KV support - GPT integration: L3 layers sit between decoder blocks, applied residually (x = x + l3_layer(x, token_ids)) - CLI: --l3-after-layers, --l3-n-emb, --l3-d-up, --l3-k-max flags with LZW precomputation from training data sample - 17 tests covering allocation, layer, and GPT integration Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
194c98a5b3
commit
b7629eff5d
|
|
@ -24,6 +24,7 @@ from nanochat.optim import MuonAdamW, DistMuonAdamW
|
|||
|
||||
# Our custom Flash Attention module that automatically uses FA3 on Hopper+ and SDPA fallback elsewhere
|
||||
from nanochat.flash_attention import flash_attn
|
||||
from nanochat.l3 import L3Layer
|
||||
|
||||
@dataclass
|
||||
class GPTConfig:
|
||||
|
|
@ -37,6 +38,12 @@ class GPTConfig:
|
|||
# Characters: L=long (full context), S=short (half context)
|
||||
# Examples: "L"=all full context, "SL"=alternating, "SSL"=two short then one long
|
||||
window_pattern: str = "SSSL"
|
||||
# L3 (Large Lookup Layers) config
|
||||
l3_after_layers: str = "" # comma-separated layer indices (empty = disabled)
|
||||
l3_n_emb: int = 0 # total embeddings (0 = disabled)
|
||||
l3_d_up: int = 0 # up-projection dim (0 = auto: 4 * n_embd)
|
||||
l3_k_max: int = 512 # max embeddings per token
|
||||
l3_tie_kv: bool = True # tie key and value weights
|
||||
|
||||
|
||||
def norm(x):
|
||||
|
|
@ -175,6 +182,13 @@ class GPT(nn.Module):
|
|||
head_dim = config.n_embd // config.n_head
|
||||
kv_dim = config.n_kv_head * head_dim
|
||||
self.value_embeds = nn.ModuleDict({str(i): nn.Embedding(padded_vocab_size, kv_dim) for i in range(config.n_layer) if has_ve(i, config.n_layer)})
|
||||
# L3 layers (placed between decoder blocks)
|
||||
self.l3_layer_indices = set(int(x) for x in config.l3_after_layers.split(",") if x.strip()) if config.l3_after_layers else set()
|
||||
l3_d_up = config.l3_d_up if config.l3_d_up > 0 else 4 * config.n_embd
|
||||
self.l3_layers = nn.ModuleDict({
|
||||
str(i): L3Layer(config.n_embd, config.l3_n_emb, l3_d_up, config.l3_tie_kv)
|
||||
for i in self.l3_layer_indices
|
||||
}) if self.l3_layer_indices and config.l3_n_emb > 0 else nn.ModuleDict()
|
||||
# To support meta device initialization, we init the rotary embeddings here, but it's just "fake" meta tensors only.
|
||||
# As for rotary_seq_len, these rotary embeddings are pretty small/cheap in memory,
|
||||
# so let's just over-compute them by 10X, but assert fail if we ever reach that amount.
|
||||
|
|
@ -229,6 +243,16 @@ class GPT(nn.Module):
|
|||
if block.attn.ve_gate is not None:
|
||||
torch.nn.init.zeros_(block.attn.ve_gate.weight)
|
||||
|
||||
# L3 layers
|
||||
for l3_layer in self.l3_layers.values():
|
||||
if l3_layer.tie_kv:
|
||||
torch.nn.init.normal_(l3_layer.kv_weight, mean=0.0, std=1.0)
|
||||
else:
|
||||
torch.nn.init.normal_(l3_layer.k_weight, mean=0.0, std=1.0)
|
||||
torch.nn.init.normal_(l3_layer.v_weight, mean=0.0, std=1.0)
|
||||
torch.nn.init.uniform_(l3_layer.w_up.weight, -s, s)
|
||||
torch.nn.init.zeros_(l3_layer.w_mix.weight) # L3 starts as no-op
|
||||
|
||||
# Rotary embeddings
|
||||
head_dim = self.config.n_embd // self.config.n_head
|
||||
cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
|
||||
|
|
@ -239,6 +263,12 @@ class GPT(nn.Module):
|
|||
self.transformer.wte.to(dtype=torch.bfloat16)
|
||||
for ve in self.value_embeds.values():
|
||||
ve.to(dtype=torch.bfloat16)
|
||||
for l3_layer in self.l3_layers.values():
|
||||
if l3_layer.tie_kv:
|
||||
l3_layer.kv_weight.data = l3_layer.kv_weight.data.to(dtype=torch.bfloat16)
|
||||
else:
|
||||
l3_layer.k_weight.data = l3_layer.k_weight.data.to(dtype=torch.bfloat16)
|
||||
l3_layer.v_weight.data = l3_layer.v_weight.data.to(dtype=torch.bfloat16)
|
||||
|
||||
def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None):
|
||||
# TODO: bump base theta more? e.g. 100K is more common more recently
|
||||
|
|
@ -304,7 +334,14 @@ class GPT(nn.Module):
|
|||
nparams = sum(p.numel() for p in self.parameters())
|
||||
# Exclude non-matmul params: embeddings and per-layer scalars
|
||||
value_embeds_numel = sum(ve.weight.numel() for ve in self.value_embeds.values())
|
||||
nparams_exclude = (self.transformer.wte.weight.numel() + value_embeds_numel +
|
||||
# L3 kv/k/v weights are embeddings (lookup tables), not matmul weights
|
||||
l3_embed_numel = 0
|
||||
for l3_layer in self.l3_layers.values():
|
||||
if l3_layer.tie_kv:
|
||||
l3_embed_numel += l3_layer.kv_weight.numel()
|
||||
else:
|
||||
l3_embed_numel += l3_layer.k_weight.numel() + l3_layer.v_weight.numel()
|
||||
nparams_exclude = (self.transformer.wte.weight.numel() + value_embeds_numel + l3_embed_numel +
|
||||
self.resid_lambdas.numel() + self.x0_lambdas.numel())
|
||||
h, q, t = self.config.n_head, self.config.n_embd // self.config.n_head, self.config.sequence_len
|
||||
# Sum attention FLOPs per layer, accounting for sliding window
|
||||
|
|
@ -313,7 +350,15 @@ class GPT(nn.Module):
|
|||
window = window_size[0] # (left, right) tuple, we use left
|
||||
effective_seq = t if window < 0 else min(window, t)
|
||||
attn_flops += 12 * h * q * effective_seq
|
||||
num_flops_per_token = 6 * (nparams - nparams_exclude) + attn_flops
|
||||
# L3 FLOPs: per layer, approx 2*avg_k*n_embd (attention) + 2*n_embd*d_up (up) + 2*(d_up+n_embd)*n_embd (mix)
|
||||
l3_flops = 0
|
||||
if self.l3_layers:
|
||||
n_embd = self.config.n_embd
|
||||
l3_d_up = self.config.l3_d_up if self.config.l3_d_up > 0 else 4 * n_embd
|
||||
avg_k = self.config.l3_n_emb / self.config.vocab_size if self.config.vocab_size > 0 else 1
|
||||
per_l3 = 2 * avg_k * n_embd + 2 * n_embd * l3_d_up + 2 * (l3_d_up + n_embd) * n_embd
|
||||
l3_flops = int(per_l3 * len(self.l3_layers))
|
||||
num_flops_per_token = 6 * (nparams - nparams_exclude) + attn_flops + l3_flops
|
||||
return num_flops_per_token
|
||||
|
||||
def num_scaling_params(self):
|
||||
|
|
@ -334,13 +379,24 @@ class GPT(nn.Module):
|
|||
lm_head = sum(p.numel() for p in self.lm_head.parameters())
|
||||
transformer_matrices = sum(p.numel() for p in self.transformer.h.parameters())
|
||||
scalars = self.resid_lambdas.numel() + self.x0_lambdas.numel()
|
||||
total = wte + value_embeds + lm_head + transformer_matrices + scalars
|
||||
# L3: separate embedding-like params from matrix params
|
||||
l3_embeds = 0
|
||||
l3_matrices = 0
|
||||
for l3_layer in self.l3_layers.values():
|
||||
if l3_layer.tie_kv:
|
||||
l3_embeds += l3_layer.kv_weight.numel()
|
||||
else:
|
||||
l3_embeds += l3_layer.k_weight.numel() + l3_layer.v_weight.numel()
|
||||
l3_matrices += l3_layer.w_up.weight.numel() + l3_layer.w_mix.weight.numel()
|
||||
total = wte + value_embeds + lm_head + transformer_matrices + scalars + l3_embeds + l3_matrices
|
||||
assert total == sum(p.numel() for p in self.parameters()), "Parameter count mismatch"
|
||||
return {
|
||||
'wte': wte,
|
||||
'value_embeds': value_embeds,
|
||||
'lm_head': lm_head,
|
||||
'transformer_matrices': transformer_matrices,
|
||||
'l3_embeds': l3_embeds,
|
||||
'l3_matrices': l3_matrices,
|
||||
'scalars': scalars,
|
||||
'total': total,
|
||||
}
|
||||
|
|
@ -356,7 +412,18 @@ class GPT(nn.Module):
|
|||
lm_head_params = list(self.lm_head.parameters())
|
||||
resid_params = [self.resid_lambdas]
|
||||
x0_params = [self.x0_lambdas]
|
||||
assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params)
|
||||
# L3 params: embedding-like (kv/k/v weights) and matrix (w_up, w_mix)
|
||||
l3_embed_params = []
|
||||
l3_matrix_params = []
|
||||
for l3_layer in self.l3_layers.values():
|
||||
if l3_layer.tie_kv:
|
||||
l3_embed_params.append(l3_layer.kv_weight)
|
||||
else:
|
||||
l3_embed_params.append(l3_layer.k_weight)
|
||||
l3_embed_params.append(l3_layer.v_weight)
|
||||
l3_matrix_params.append(l3_layer.w_up.weight)
|
||||
l3_matrix_params.append(l3_layer.w_mix.weight)
|
||||
assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params) + len(l3_embed_params) + len(l3_matrix_params)
|
||||
|
||||
# Scale the LR for the AdamW parameters by ∝1/√dmodel (tuned for 768 dim model)
|
||||
dmodel_lr_scale = (model_dim / 768) ** -0.5
|
||||
|
|
@ -371,9 +438,13 @@ class GPT(nn.Module):
|
|||
dict(kind='adamw', params=resid_params, lr=scalar_lr * 0.01, betas=adam_betas, eps=1e-10, weight_decay=0.0),
|
||||
dict(kind='adamw', params=x0_params, lr=scalar_lr, betas=(0.96, 0.95), eps=1e-10, weight_decay=0.0), # higher beta1 for x0
|
||||
]
|
||||
# L3 param groups: embeddings use AdamW (like token embeddings), matrices use Muon
|
||||
if l3_embed_params:
|
||||
param_groups.append(dict(kind='adamw', params=l3_embed_params, lr=embedding_lr * dmodel_lr_scale, betas=adam_betas, eps=1e-10, weight_decay=0.0))
|
||||
# Muon groups (matrix params, grouped by shape for stacking)
|
||||
for shape in sorted({p.shape for p in matrix_params}):
|
||||
group_params = [p for p in matrix_params if p.shape == shape]
|
||||
all_matrix_params = matrix_params + l3_matrix_params
|
||||
for shape in sorted({p.shape for p in all_matrix_params}):
|
||||
group_params = [p for p in all_matrix_params if p.shape == shape]
|
||||
param_groups.append(dict(
|
||||
kind='muon', params=group_params, lr=matrix_lr,
|
||||
momentum=0.95, ns_steps=5, beta2=0.95, weight_decay=weight_decay,
|
||||
|
|
@ -404,6 +475,9 @@ class GPT(nn.Module):
|
|||
x = self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0
|
||||
ve = self.value_embeds[str(i)](idx) if str(i) in self.value_embeds else None
|
||||
x = block(x, ve, cos_sin, self.window_sizes[i], kv_cache)
|
||||
# L3 layer after this block (if configured)
|
||||
if str(i) in self.l3_layers:
|
||||
x = x + self.l3_layers[str(i)](x, idx)
|
||||
x = norm(x)
|
||||
|
||||
# Forward the lm_head (compute logits)
|
||||
|
|
|
|||
201
nanochat/l3.py
Normal file
201
nanochat/l3.py
Normal file
|
|
@ -0,0 +1,201 @@
|
|||
"""
|
||||
L3: Large Lookup Layers
|
||||
Ref: arXiv:2601.21461v2
|
||||
|
||||
L3 generalizes token embeddings by placing per-token lookup tables inside
|
||||
the decoder stack. Unlike MoE, routing is static (determined by token ID),
|
||||
eliminating router training and load-balancing losses.
|
||||
"""
|
||||
|
||||
from collections import Counter
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def compute_lzw_allocation(token_sequences, vocab_size, n_emb, k_max):
    """
    Compute per-token embedding allocation using LZW-style frequency analysis.

    Counts unigram and bigram frequencies over the sample (the last token of
    each bigram gets the credit, LZW-style) and distributes embeddings in
    proportion to those frequencies. Every token starts with 1 embedding; the
    remaining n_emb - vocab_size embeddings are shared proportionally to
    combined frequency, capped at k_max per token.

    Args:
        token_sequences: list of token ID lists (training data sample)
        vocab_size: size of vocabulary
        n_emb: target total embeddings
        k_max: max embeddings per token
    Returns:
        alloc: list[int] of length vocab_size (embeddings per token).
            sum(alloc) == n_emb, unless every token hits the k_max cap, in
            which case sum(alloc) == vocab_size * k_max.
    """
    assert n_emb >= vocab_size, f"n_emb ({n_emb}) must be >= vocab_size ({vocab_size})"

    # Unigram frequencies across all sequences
    token_freq = Counter()
    for seq in token_sequences:
        token_freq.update(seq)

    # Bigram frequencies (LZW-style: the last token of the bigram gets credit)
    bigram_freq = Counter()
    for seq in token_sequences:
        bigram_freq.update(seq[1:])

    # Combined per-token score used to weight the allocation
    combined_freq = [token_freq[tok] + bigram_freq[tok] for tok in range(vocab_size)]

    # Start with 1 embedding per token (even unseen tokens get one)
    alloc = [1] * vocab_size
    remaining = n_emb - vocab_size

    if remaining <= 0:
        return alloc

    # Frequency-descending order so ties and remainders favor frequent tokens.
    # NOTE: a round-robin pass (one embedding per token per sweep) would make
    # the allocation nearly uniform regardless of frequency; we instead
    # distribute the surplus proportionally to combined frequency.
    order = sorted(range(vocab_size), key=lambda t: combined_freq[t], reverse=True)

    while remaining > 0:
        active = [t for t in order if alloc[t] < k_max]
        if not active:
            break  # all tokens at k_max, can't allocate more
        total_freq = sum(combined_freq[t] for t in active)
        if total_freq == 0:
            # Only unseen tokens remain: spread the rest evenly (round-robin)
            for t in active:
                if remaining <= 0:
                    break
                alloc[t] += 1
                remaining -= 1
            continue
        # Proportional shares computed against a fixed snapshot of the budget,
        # so that earlier tokens in the loop don't starve later ones
        budget = remaining
        progressed = False
        for t in active:
            share = min(k_max - alloc[t], (budget * combined_freq[t]) // total_freq, remaining)
            if share > 0:
                alloc[t] += share
                remaining -= share
                progressed = True
            if remaining <= 0:
                break
        if not progressed:
            # All proportional shares floored to zero: hand out the remainder
            # one at a time, most frequent tokens first
            for t in active:
                if remaining <= 0:
                    break
                alloc[t] += 1
                remaining -= 1

    return alloc
|
||||
|
||||
|
||||
def allocation_to_bounds(alloc):
    """
    Convert a per-token allocation list into a cumulative bounds tensor.

    bounds[0] == 0 and bounds[i] == sum(alloc[:i]), so token t owns the
    embedding rows [bounds[t], bounds[t+1]), and bounds[-1] == sum(alloc).

    Args:
        alloc: list[int] of per-token allocation counts
    Returns:
        bounds: torch.LongTensor of shape [len(alloc) + 1]
    """
    counts = torch.tensor([0] + list(alloc), dtype=torch.long)
    return torch.cumsum(counts, dim=0)
|
||||
|
||||
|
||||
class L3Layer(nn.Module):
|
||||
"""
|
||||
L3 layer: per-token lookup table with attention-like aggregation.
|
||||
|
||||
Forward pass (vectorized gather+pad approach):
|
||||
1. Look up bounds for each token, gather KV embeddings, pad to k_max
|
||||
2. Compute scores = K @ x_norm, mask invalid positions, softmax
|
||||
3. Aggregate: weighted sum of V embeddings
|
||||
4. Up-project, RMSNorm, concat with x, mix-project
|
||||
Returns the delta (added residually by caller).
|
||||
"""
|
||||
|
||||
def __init__(self, n_embd, n_emb, d_up, tie_kv=True):
|
||||
super().__init__()
|
||||
self.n_embd = n_embd
|
||||
self.n_emb = n_emb
|
||||
self.d_up = d_up
|
||||
self.tie_kv = tie_kv
|
||||
|
||||
if tie_kv:
|
||||
# Single shared weight for both keys and values
|
||||
self.kv_weight = nn.Parameter(torch.empty(n_emb, n_embd))
|
||||
else:
|
||||
# Separate key and value weights
|
||||
self.k_weight = nn.Parameter(torch.empty(n_emb, n_embd))
|
||||
self.v_weight = nn.Parameter(torch.empty(n_emb, n_embd))
|
||||
|
||||
# Up-project from d_emb (= n_embd when tied) to d_up
|
||||
self.w_up = nn.Linear(n_embd, d_up, bias=False)
|
||||
# Mix-project: concat(up_projected, x) -> n_embd
|
||||
self.w_mix = nn.Linear(d_up + n_embd, n_embd, bias=False)
|
||||
|
||||
# Bounds buffer (set after LZW allocation)
|
||||
self.register_buffer("bounds", torch.zeros(1, dtype=torch.long), persistent=True)
|
||||
|
||||
def set_bounds(self, bounds):
|
||||
"""Register the precomputed bounds tensor as a buffer."""
|
||||
self.bounds = bounds
|
||||
|
||||
def forward(self, x, token_ids):
|
||||
"""
|
||||
Args:
|
||||
x: [B, T, n_embd] hidden states
|
||||
token_ids: [B, T] token IDs
|
||||
Returns:
|
||||
delta: [B, T, n_embd] to be added residually by caller
|
||||
"""
|
||||
B, T, C = x.shape
|
||||
device = x.device
|
||||
|
||||
# 1. RMSNorm the input
|
||||
x_norm = F.rms_norm(x, (C,))
|
||||
|
||||
# 2. Gather KV embeddings using bounds + token_ids
|
||||
# Look up bounds for each token
|
||||
flat_ids = token_ids.reshape(-1) # [B*T]
|
||||
starts = self.bounds[flat_ids] # [B*T]
|
||||
ends = self.bounds[flat_ids + 1] # [B*T]
|
||||
lengths = ends - starts # [B*T]
|
||||
k_max = lengths.max().item()
|
||||
|
||||
if k_max == 0:
|
||||
# Edge case: no embeddings for any token
|
||||
return torch.zeros_like(x)
|
||||
|
||||
# Build index tensor [B*T, k_max] with valid indices and padding
|
||||
offsets = torch.arange(k_max, device=device).unsqueeze(0) # [1, k_max]
|
||||
indices = starts.unsqueeze(1) + offsets # [B*T, k_max]
|
||||
mask = offsets < lengths.unsqueeze(1) # [B*T, k_max] True for valid
|
||||
|
||||
# Clamp indices to valid range for gathering (masked ones will be zeroed out)
|
||||
indices = indices.clamp(0, self.n_emb - 1)
|
||||
|
||||
# 3. Gather weights
|
||||
if self.tie_kv:
|
||||
kv = self.kv_weight[indices] # [B*T, k_max, n_embd]
|
||||
k_emb = kv
|
||||
v_emb = kv
|
||||
else:
|
||||
k_emb = self.k_weight[indices] # [B*T, k_max, n_embd]
|
||||
v_emb = self.v_weight[indices] # [B*T, k_max, n_embd]
|
||||
|
||||
# 4. Compute attention scores: K @ x_norm
|
||||
x_flat = x_norm.reshape(B * T, C) # [B*T, C]
|
||||
scores = torch.bmm(k_emb, x_flat.unsqueeze(2)).squeeze(2) # [B*T, k_max]
|
||||
|
||||
# Mask invalid positions
|
||||
scores = scores.masked_fill(~mask, float('-inf'))
|
||||
|
||||
# Softmax over valid positions
|
||||
weights = F.softmax(scores, dim=-1) # [B*T, k_max]
|
||||
# Replace NaN from all-inf rows (shouldn't happen since min alloc=1, but safety)
|
||||
weights = weights.masked_fill(~mask, 0.0)
|
||||
|
||||
# 5. Aggregate: weighted sum of V embeddings
|
||||
agg = torch.bmm(weights.unsqueeze(1), v_emb).squeeze(1) # [B*T, n_embd]
|
||||
agg = agg.view(B, T, C)
|
||||
|
||||
# 6. Up-project, RMSNorm, concat with x, mix-project
|
||||
up = self.w_up(agg) # [B, T, d_up]
|
||||
up = F.rms_norm(up, (self.d_up,)) # normalize
|
||||
cat = torch.cat([up, x], dim=-1) # [B, T, d_up + n_embd]
|
||||
delta = self.w_mix(cat) # [B, T, n_embd]
|
||||
|
||||
return delta
|
||||
|
|
@ -51,6 +51,11 @@ parser.add_argument("--aspect-ratio", type=int, default=64, help="model_dim = de
|
|||
parser.add_argument("--head-dim", type=int, default=128, help="target head dimension for attention")
|
||||
parser.add_argument("--max-seq-len", type=int, default=2048, help="max context length")
|
||||
parser.add_argument("--window-pattern", type=str, default="SSSL", help="sliding window pattern tiled across layers: L=full, S=half context (e.g. 'SSL')")
|
||||
# L3 (Large Lookup Layers)
|
||||
parser.add_argument("--l3-after-layers", type=str, default="", help="comma-separated layer indices for L3 (empty = disabled)")
|
||||
parser.add_argument("--l3-n-emb", type=int, default=0, help="total L3 embeddings (0 = auto-derive from model size)")
|
||||
parser.add_argument("--l3-d-up", type=int, default=0, help="L3 up-projection dim (0 = 4*n_embd)")
|
||||
parser.add_argument("--l3-k-max", type=int, default=512, help="max embeddings per token for L3")
|
||||
# Training horizon (only one used, in order of precedence)
|
||||
parser.add_argument("--num-iterations", type=int, default=-1, help="explicit number of optimization steps (-1 = disable)")
|
||||
parser.add_argument("--target-flops", type=float, default=-1.0, help="calculate num_iterations to reach target_flops (-1 = disable)")
|
||||
|
|
@ -122,7 +127,7 @@ print0(f"Vocab size: {vocab_size:,}")
|
|||
# -----------------------------------------------------------------------------
|
||||
# Initialize the Model
|
||||
|
||||
def build_model_meta(depth):
|
||||
def build_model_meta(depth, l3_after_layers="", l3_n_emb=0):
|
||||
"""Build a model on meta device for a given depth (shapes/dtypes only, no data)."""
|
||||
# Model dim is nudged up to nearest multiple of head_dim for clean division
|
||||
# (FA3 requires head_dim divisible by 8, and this guarantees head_dim == args.head_dim exactly)
|
||||
|
|
@ -133,19 +138,53 @@ def build_model_meta(depth):
|
|||
sequence_len=args.max_seq_len, vocab_size=vocab_size,
|
||||
n_layer=depth, n_head=num_heads, n_kv_head=num_heads, n_embd=model_dim,
|
||||
window_pattern=args.window_pattern,
|
||||
l3_after_layers=l3_after_layers,
|
||||
l3_n_emb=l3_n_emb,
|
||||
l3_d_up=args.l3_d_up,
|
||||
l3_k_max=args.l3_k_max,
|
||||
)
|
||||
with torch.device("meta"):
|
||||
model_meta = GPT(config)
|
||||
return model_meta
|
||||
|
||||
# L3 precomputation: compute LZW allocation from training data sample
|
||||
l3_n_emb = args.l3_n_emb
|
||||
l3_bounds = None
|
||||
if args.l3_after_layers:
|
||||
from nanochat.l3 import compute_lzw_allocation, allocation_to_bounds
|
||||
# Auto-derive n_emb if not specified: scale proportional to model size
|
||||
if l3_n_emb == 0:
|
||||
# Build a temporary model to get param count for auto-derivation
|
||||
tmp_model = build_model_meta(args.depth, l3_after_layers=args.l3_after_layers, l3_n_emb=vocab_size)
|
||||
tmp_params = sum(p.numel() for p in tmp_model.parameters())
|
||||
l3_n_emb = max(vocab_size, tmp_params // 1000)
|
||||
del tmp_model
|
||||
print0(f"Auto-derived L3 n_emb: {l3_n_emb:,}")
|
||||
# Read a sample of training data for LZW allocation
|
||||
sample_loader = tokenizing_distributed_data_loader_bos_bestfit(tokenizer, 1, args.max_seq_len, split="train", device=device)
|
||||
sample_sequences = []
|
||||
for _ in range(100): # 100 batches should be enough
|
||||
x_sample, _ = next(sample_loader)
|
||||
sample_sequences.append(x_sample[0].tolist())
|
||||
del sample_loader
|
||||
l3_alloc = compute_lzw_allocation(sample_sequences, vocab_size, l3_n_emb, args.l3_k_max)
|
||||
l3_bounds = allocation_to_bounds(l3_alloc).to(device)
|
||||
print0(f"L3 allocation: {l3_n_emb:,} total embeddings, k_max={args.l3_k_max}, avg={l3_n_emb/vocab_size:.1f}/token")
|
||||
|
||||
# Build the model, move to device, init the weights
|
||||
model = build_model_meta(args.depth) # 1) Build on meta device (only shapes/dtypes, no data)
|
||||
model = build_model_meta(args.depth, l3_after_layers=args.l3_after_layers, l3_n_emb=l3_n_emb) # 1) Build on meta device (only shapes/dtypes, no data)
|
||||
model_config = model.config
|
||||
model_config_kwargs = asdict(model_config)
|
||||
print0(f"Model config:\n{json.dumps(model_config_kwargs, indent=2)}")
|
||||
model.to_empty(device=device) # 2) All tensors get storage on target device but with uninitialized (garbage) data
|
||||
model.init_weights() # 3) All tensors get initialized
|
||||
|
||||
# Set L3 bounds after model creation
|
||||
if l3_bounds is not None:
|
||||
for l3_layer in model.l3_layers.values():
|
||||
l3_layer.set_bounds(l3_bounds)
|
||||
print0(f"L3 bounds set for {len(model.l3_layers)} layer(s)")
|
||||
|
||||
# If we are resuming, overwrite the model parameters with those of the checkpoint
|
||||
base_dir = get_base_dir()
|
||||
output_dirname = args.model_tag if args.model_tag else f"d{args.depth}" # e.g. d12
|
||||
|
|
|
|||
263
tests/test_l3.py
Normal file
263
tests/test_l3.py
Normal file
|
|
@ -0,0 +1,263 @@
|
|||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from nanochat.l3 import compute_lzw_allocation, allocation_to_bounds, L3Layer
|
||||
|
||||
|
||||
# ---- LZW Allocation Tests ----
|
||||
|
||||
def test_lzw_allocation_total():
    """The allocation accounts for exactly n_emb embeddings in total."""
    data = [[0, 1, 2, 3, 0, 1, 2, 3, 0, 1] for _ in range(10)]
    alloc = compute_lzw_allocation(data, vocab_size=8, n_emb=100, k_max=32)
    assert sum(alloc) == 100
|
||||
|
||||
|
||||
def test_lzw_allocation_min_one():
    """Every token keeps at least one embedding, including unseen tokens."""
    data = [[0, 1, 0, 1, 0, 1]]  # vocabulary entries 2..7 never occur
    alloc = compute_lzw_allocation(data, vocab_size=8, n_emb=100, k_max=32)
    assert len(alloc) == 8
    assert min(alloc) >= 1
|
||||
|
||||
|
||||
def test_lzw_allocation_max_k():
    """The per-token cap k_max is never exceeded."""
    data = [[0] * 8 for _ in range(100)]  # token 0 dominates the sample
    alloc = compute_lzw_allocation(data, vocab_size=4, n_emb=50, k_max=8)
    assert max(alloc) <= 8
|
||||
|
||||
|
||||
def test_lzw_allocation_distribution():
    """A more frequent token never receives fewer embeddings than a rarer one."""
    # Token 0 occurs 100 times, token 1 ten times, tokens 2-7 once each
    data = [[0] * 100 + [1] * 10 + [2, 3, 4, 5, 6, 7]]
    alloc = compute_lzw_allocation(data, vocab_size=8, n_emb=64, k_max=32)
    assert alloc[0] >= alloc[1], f"Token 0 (freq 100) should get >= Token 1 (freq 10): {alloc[0]} vs {alloc[1]}"
    assert alloc[1] >= alloc[7], f"Token 1 (freq 10) should get >= Token 7 (freq 1): {alloc[1]} vs {alloc[7]}"
|
||||
|
||||
|
||||
def test_lzw_bounds_from_allocation():
    """The bounds tensor is the cumulative sum of the allocation, starting at 0."""
    counts = [3, 1, 5, 2]
    bounds = allocation_to_bounds(counts)
    assert bounds.tolist() == [0, 3, 4, 9, 11]
    assert bounds[-1].item() == sum(counts)
|
||||
|
||||
|
||||
def test_lzw_allocation_edge_n_emb_equals_vocab():
    """With n_emb == vocab_size there is nothing extra to hand out."""
    alloc = compute_lzw_allocation([[0, 1, 2, 3]], vocab_size=4, n_emb=4, k_max=32)
    assert alloc == [1, 1, 1, 1]
|
||||
|
||||
|
||||
# ---- L3Layer Tests ----
|
||||
|
||||
def _make_l3(n_embd=16, n_emb=32, d_up=64, vocab_size=8, tie_kv=True):
    """Build an L3Layer with initialized weights and a near-uniform allocation."""
    torch.manual_seed(42)
    layer = L3Layer(n_embd=n_embd, n_emb=n_emb, d_up=d_up, tie_kv=tie_kv)
    # torch.empty() leaves garbage behind; give every parameter real values
    for _, param in layer.named_parameters():
        if param.dim() >= 2:
            torch.nn.init.normal_(param, std=0.1)
        else:
            torch.nn.init.zeros_(param)
    # Even split of n_emb across the vocab, remainder going one-by-one to the
    # lowest token IDs
    base = n_emb // vocab_size
    alloc = [base] * vocab_size
    for i in range(n_emb - base * vocab_size):
        alloc[i] += 1
    layer.set_bounds(allocation_to_bounds(alloc))
    return layer
|
||||
|
||||
|
||||
def test_l3_layer_output_shape():
    """The layer maps [B, T, n_embd] inputs to a delta of the same shape."""
    dim = 16
    layer = _make_l3(n_embd=dim, n_emb=32, d_up=64, vocab_size=8)
    hidden = torch.randn(2, 5, dim)
    ids = torch.randint(0, 8, (2, 5))
    assert layer(hidden, ids).shape == (2, 5, dim)
|
||||
|
||||
|
||||
def test_l3_layer_gradient_flow():
    """Backprop reaches every parameter and the input."""
    layer = _make_l3(n_embd=16, n_emb=32, d_up=64, vocab_size=8)
    hidden = torch.randn(2, 5, 16, requires_grad=True)
    ids = torch.randint(0, 8, (2, 5))
    layer(hidden, ids).sum().backward()
    for name, param in layer.named_parameters():
        assert param.grad is not None, f"No gradient for {name}"
        assert param.grad.abs().sum() > 0, f"Zero gradient for {name}"
    assert hidden.grad is not None, "No gradient for input x"
|
||||
|
||||
|
||||
def test_l3_layer_tied_kv():
    """Tied mode exposes only kv_weight; untied mode exposes separate k/v."""
    names_tied = {name for name, _ in _make_l3(tie_kv=True).named_parameters()}
    names_untied = {name for name, _ in _make_l3(tie_kv=False).named_parameters()}
    assert "kv_weight" in names_tied
    assert "k_weight" not in names_tied
    assert "v_weight" not in names_tied
    assert "k_weight" in names_untied
    assert "v_weight" in names_untied
    assert "kv_weight" not in names_untied
|
||||
|
||||
|
||||
def test_l3_layer_masking():
    """Tokens with fewer embeddings are properly masked (no NaN/inf)."""
    n_embd = 16
    layer = L3Layer(n_embd=n_embd, n_emb=12, d_up=64, tie_kv=True)
    # Initialize parameters: torch.empty() can contain NaN/inf garbage, which
    # would make the NaN/inf assertions below flaky
    torch.manual_seed(0)
    for p in layer.parameters():
        torch.nn.init.normal_(p, std=0.1)
    # Non-uniform allocation across 3 tokens: 9, 2 and 1 embeddings (every
    # token keeps at least one, so no softmax row is ever fully masked).
    # NOTE: a previous dead assignment of [0, 10, 12, 12] (token 2 with zero
    # embeddings) was immediately overwritten and has been removed.
    bounds = torch.tensor([0, 9, 11, 12])
    layer.set_bounds(bounds)
    x = torch.randn(1, 3, n_embd)
    token_ids = torch.tensor([[0, 1, 2]])
    out = layer(x, token_ids)
    assert not torch.isnan(out).any(), "Output contains NaN"
    assert not torch.isinf(out).any(), "Output contains inf"
    assert out.shape == (1, 3, n_embd)
|
||||
|
||||
|
||||
def test_l3_layer_deterministic():
    """Two forward passes over identical inputs agree exactly."""
    layer = _make_l3()
    hidden = torch.randn(2, 5, 16)
    ids = torch.randint(0, 8, (2, 5))
    first = layer(hidden, ids)
    second = layer(hidden, ids)
    assert torch.allclose(first, second)
|
||||
|
||||
|
||||
def test_l3_layer_untied_output_shape():
    """Untied key/value tables still produce a [B, T, n_embd] delta."""
    dim = 16
    layer = _make_l3(n_embd=dim, n_emb=32, d_up=64, vocab_size=8, tie_kv=False)
    hidden = torch.randn(2, 5, dim)
    ids = torch.randint(0, 8, (2, 5))
    assert layer(hidden, ids).shape == (2, 5, dim)
|
||||
|
||||
|
||||
def test_l3_layer_untied_gradient_flow():
    """Backprop reaches every parameter in untied mode."""
    layer = _make_l3(n_embd=16, n_emb=32, d_up=64, vocab_size=8, tie_kv=False)
    hidden = torch.randn(2, 5, 16, requires_grad=True)
    ids = torch.randint(0, 8, (2, 5))
    layer(hidden, ids).sum().backward()
    for name, param in layer.named_parameters():
        assert param.grad is not None, f"No gradient for {name}"
        assert param.grad.abs().sum() > 0, f"Zero gradient for {name}"
|
||||
|
||||
|
||||
# ---- GPT Integration Tests ----
|
||||
|
||||
def test_gpt_with_l3_forward():
    """A model with one L3 layer produces a finite scalar loss."""
    from nanochat.gpt import GPT, GPTConfig
    model = GPT(GPTConfig(
        sequence_len=8, vocab_size=64, n_layer=4,
        n_head=2, n_kv_head=2, n_embd=64,
        l3_after_layers="2", l3_n_emb=128, l3_d_up=32, l3_k_max=16,
    ))
    model.init_weights()
    # L3 needs its bounds installed before the first forward pass
    alloc = compute_lzw_allocation([[0, 1, 2, 3] * 4], vocab_size=64, n_emb=128, k_max=16)
    bounds = allocation_to_bounds(alloc)
    for l3 in model.l3_layers.values():
        l3.set_bounds(bounds)
    idx = torch.randint(0, 64, (2, 8))
    targets = torch.randint(0, 64, (2, 8))
    loss = model(idx, targets)
    assert loss.ndim == 0  # scalar loss
    assert not torch.isnan(loss)
|
||||
|
||||
|
||||
def test_gpt_with_l3_backward():
    """After one optimizer step, L3 parameters receive gradients."""
    from nanochat.gpt import GPT, GPTConfig
    model = GPT(GPTConfig(
        sequence_len=8, vocab_size=64, n_layer=4,
        n_head=2, n_kv_head=2, n_embd=64,
        l3_after_layers="2", l3_n_emb=128, l3_d_up=32, l3_k_max=16,
    ))
    model.init_weights()
    alloc = compute_lzw_allocation([[0, 1, 2, 3] * 4], vocab_size=64, n_emb=128, k_max=16)
    bounds = allocation_to_bounds(alloc)
    for l3 in model.l3_layers.values():
        l3.set_bounds(bounds)

    idx = torch.randint(0, 64, (2, 8))
    targets = torch.randint(0, 64, (2, 8))

    # The first pass only propagates signal through the zero-initialized
    # lm_head, so take one optimizer step before checking gradients
    model(idx, targets).backward()
    opt = model.setup_optimizer()
    opt.step()
    model.zero_grad(set_to_none=True)

    # The second pass should reach the L3 parameters
    model(idx, targets).backward()
    for name, param in model.l3_layers.named_parameters():
        if 'bounds' not in name:  # bounds is a buffer, not a parameter
            assert param.grad is not None, f"No gradient for L3 param {name}"
|
||||
|
||||
|
||||
def test_gpt_with_l3_optimizer():
    """Every model parameter, L3 included, lands in some optimizer group."""
    from nanochat.gpt import GPT, GPTConfig
    model = GPT(GPTConfig(
        sequence_len=8, vocab_size=64, n_layer=4,
        n_head=2, n_kv_head=2, n_embd=64,
        l3_after_layers="2", l3_n_emb=128, l3_d_up=32, l3_k_max=16,
    ))
    model.init_weights()
    alloc = compute_lzw_allocation([[0, 1, 2, 3] * 4], vocab_size=64, n_emb=128, k_max=16)
    bounds = allocation_to_bounds(alloc)
    for l3 in model.l3_layers.values():
        l3.set_bounds(bounds)
    optimizer = model.setup_optimizer()
    # Set of every tensor the optimizer tracks, by identity
    tracked = {id(p) for group in optimizer.param_groups for p in group["params"]}
    for name, param in model.named_parameters():
        assert id(param) in tracked, f"Parameter {name} not in optimizer"
|
||||
|
||||
|
||||
def test_gpt_without_l3_unchanged():
    """Default config creates zero L3 layers and the model trains as before."""
    from nanochat.gpt import GPT, GPTConfig
    model = GPT(GPTConfig(
        sequence_len=8, vocab_size=64, n_layer=4,
        n_head=2, n_kv_head=2, n_embd=64,
    ))
    assert len(model.l3_layers) == 0
    model.init_weights()
    idx = torch.randint(0, 64, (2, 8))
    targets = torch.randint(0, 64, (2, 8))
    loss = model(idx, targets)
    assert loss.ndim == 0
    assert not torch.isnan(loss)
|
||||
Loading…
Reference in New Issue
Block a user