diff --git a/nanochat/engine.py b/nanochat/engine.py index aa2e6a98..b6884614 100644 --- a/nanochat/engine.py +++ b/nanochat/engine.py @@ -102,11 +102,14 @@ class KVCache: self.cache_seqlens = torch.zeros(batch_size, dtype=torch.int32, device=device) # Previous token's normalized embedding for smear (set by model forward pass) self.prev_embedding = None + # Engram N-gram history: (B, max_ngram-1) tensor of recent token IDs + self.ngram_history = None def reset(self): """Reset cache to empty state.""" self.cache_seqlens.zero_() self.prev_embedding = None + self.ngram_history = None def get_pos(self): """Get current position (assumes all batch elements at same position).""" @@ -135,6 +138,9 @@ class KVCache: # Copy smear state: expand batch=1 prev_embedding to num_samples if other.prev_embedding is not None: self.prev_embedding = other.prev_embedding.expand(self.batch_size, -1, -1).clone() + # Copy Engram N-gram history: expand batch=1 to num_samples + if other.ngram_history is not None: + self.ngram_history = other.ngram_history.expand(self.batch_size, -1).clone() # ----------------------------------------------------------------------------- @torch.inference_mode() diff --git a/nanochat/gpt.py b/nanochat/gpt.py index 07a1eae8..2c7105df 100644 --- a/nanochat/gpt.py +++ b/nanochat/gpt.py @@ -37,6 +37,14 @@ class GPTConfig: # Characters: L=long (full context), S=short (quarter context) # Examples: "L"=all full context, "SL"=alternating, "SSL"=two short then one long window_pattern: str = "SSSL" + # Engram: conditional memory via N-gram hash lookup (DeepSeek Engram paper) + engram_enabled: bool = False + engram_ngram_size: int = 3 # max N-gram order (uses 2-gram .. 
class Block(nn.Module):
    """Transformer block: optional Engram memory, then attention, then MLP.

    Every sub-layer output is added to the residual stream `x`. `engram` is an
    optional module invoked as `engram(x, input_ids, kv_cache)`; when it is
    None the memory lookup is skipped and `input_ids` is unused.
    """

    def __init__(self, config, layer_idx, engram=None):
        super().__init__()
        self.attn = CausalSelfAttention(config, layer_idx)
        self.mlp = MLP(config)
        self.engram = engram

    def forward(self, x, input_ids, ve, cos_sin, window_size, kv_cache):
        # Engram reads the un-normalized residual stream and runs before attention
        if self.engram is not None:
            x = x + self.engram(x, input_ids, kv_cache)
        x = x + self.attn(norm(x), ve, cos_sin, window_size, kv_cache)
        x = x + self.mlp(norm(x))
        return x


def _next_prime(n):
    """Return the smallest prime >= n (2 for any n <= 2)."""
    if n <= 2:
        return 2
    candidate = n | 1  # first odd number >= n; even numbers > 2 are never prime
    while True:
        # Integer-only trial division by odd d with d*d <= candidate, so the
        # answer never depends on floating-point sqrt rounding.
        d = 3
        while d * d <= candidate and candidate % d:
            d += 2
        if d * d > candidate:
            return candidate
        candidate += 2


class EngramModule(nn.Module):
    """Conditional memory via N-gram hash lookup (DeepSeek Engram).

    For each N-gram order (2..max_ngram) and each hash head (1..n_heads),
    computes a multiplicative hash over the last n tokens to index into
    prime-sized embedding tables. Retrieved embeddings are projected to
    key/value, gated by hidden state similarity, and refined via depthwise
    causal convolution.

    Known deviations from paper:
    - Token compression (Section 2.2, NFKC/lowercase normalization) is omitted.
      nanochat's 32K BPE vocab is already compact; the hash tables are prime-sized
      and larger than the vocab, so collision rates remain low.
    - Hash function uses polynomial rolling hash over the exact n-gram window
      rather than the demo's multiplicative-XOR accumulation.

    Inference state:
    - `kv_cache.ngram_history` holds the last (max_ngram - 1) token IDs seen so far.
      It is a single tensor shared by every Engram layer, so it must advance exactly
      once per forward pass; see `_step_history`.
    - NOTE(review): the depthwise conv keeps no cached state. During single-token
      decode it sees zero padding where training saw earlier positions, so decode
      diverges from the full-sequence forward once the (zero-initialized) conv
      weights have been trained. Fixing this needs per-layer conv state in KVCache.

    Paper: "Conditional Memory via Scalable Lookup: A New Axis of Sparsity for LLMs"
    """

    def __init__(self, config, layer_idx, padded_vocab_size):
        super().__init__()
        # Fail early with a clear message; otherwise these surface as an opaque
        # torch.cat([]) error on the first forward pass.
        if config.engram_ngram_size < 2:
            raise ValueError(f"engram_ngram_size must be >= 2, got {config.engram_ngram_size}")
        if config.engram_n_heads < 1:
            raise ValueError(f"engram_n_heads must be >= 1, got {config.engram_n_heads}")
        self.layer_idx = layer_idx
        self.max_ngram = config.engram_ngram_size
        self.vocab_size = padded_vocab_size
        self.n_heads = config.engram_n_heads
        self.embed_dim = config.engram_embed_dim
        self.n_embd = config.n_embd
        n_gram_orders = self.max_ngram - 1  # e.g. max_ngram=3 → orders 2,3
        self.d_mem = n_gram_orders * self.n_heads * self.embed_dim

        base_table_size = config.engram_table_size if config.engram_table_size > 0 else padded_vocab_size * 3
        self.embed_tables = nn.ModuleDict()
        # Kept only so the state_dict layout stays compatible with existing checkpoints.
        # The forward pass does NOT read these: parameters created under the meta device
        # hold no data, and init_weights() never rewrites them, so their values cannot
        # be trusted (and .item() would force a host sync per head on every call).
        self.hash_seeds = nn.ParameterDict()
        self.table_sizes = {}
        self._seeds = {}  # authoritative seeds as plain Python ints, keyed like embed_tables
        for n in range(2, self.max_ngram + 1):
            for k in range(self.n_heads):
                # Offset each table so heads/orders get distinct prime moduli
                prime_size = _next_prime(base_table_size + n * k * 997)
                key = f"{n}_{k}"
                self.embed_tables[key] = nn.Embedding(prime_size, self.embed_dim)
                seed = n * 2654435761 + k * 40503  # multiplicative hash constants
                self.hash_seeds[key] = nn.Parameter(torch.tensor(seed, dtype=torch.long), requires_grad=False)
                self._seeds[key] = seed
                self.table_sizes[key] = prime_size

        self.key_proj = Linear(self.d_mem, config.n_embd, bias=False)
        self.value_proj = Linear(self.d_mem, config.n_embd, bias=False)
        dilation = config.engram_ngram_size  # paper Section 2.3: dilation = max N-gram order
        # Symmetric padding of dilation*(kernel-1); forward() keeps only the first T
        # outputs, which makes the convolution causal.
        self.short_conv = nn.Conv1d(
            config.n_embd, config.n_embd,
            kernel_size=config.engram_kernel_size,
            padding=dilation * (config.engram_kernel_size - 1),
            dilation=dilation,
            groups=config.n_embd, bias=False,
        )

    def _ngram_hash(self, input_ids, n, seed, table_size, history=None):
        """Vectorized multiplicative hash over the last n tokens only.

        For order n, the hash at position t is:
            h(t) = (ids[t] * seed^1 + ids[t-1] * seed^2 + ... + ids[t-n+1] * seed^n) % table_size

        Args:
            input_ids: (B, T) integer token IDs.
            n: N-gram order (window length).
            seed: Python int hash multiplier.
            table_size: modulus; hashes fall in [0, table_size).
            history: optional (B, H) token IDs that immediately precede `input_ids`.
                Its last n-1 columns fill the window of the first positions. When
                None (or shorter than n-1), missing tokens are zero-padded.

        Returns:
            (B, T) long tensor of table indices.
        """
        ids = input_ids.long()
        if history is None:
            context = F.pad(ids, (n - 1, 0), value=0)  # (B, T + n - 1)
        else:
            prefix = history[:, -(n - 1):].to(device=ids.device, dtype=torch.long)
            missing = (n - 1) - prefix.size(1)
            if missing > 0:
                prefix = F.pad(prefix, (missing, 0), value=0)
            context = torch.cat([prefix, ids], dim=1)  # (B, T + n - 1)
        # windows[b, t, j] = context[b, t + j] = token at position t - (n-1-j)
        windows = context.unfold(dimension=1, size=n, step=1)  # (B, T, n)
        # Reduce each power mod table_size in Python (arbitrary precision) so the
        # int64 products below stay small.
        powers = torch.tensor([pow(seed, n - j, table_size) for j in range(n)], dtype=torch.long, device=ids.device)
        return (windows * powers).sum(dim=-1) % table_size  # (B, T)

    def _step_history(self, input_ids, kv_cache):
        """Return the tokens preceding `input_ids` and advance the cached history once per pass.

        `kv_cache.ngram_history` is shared by all Engram layers. If every layer
        appended the current tokens, the second layer of a pass would hash a window
        containing the current token twice, and the history would be corrupted for
        every later step. The first Engram layer to see the cache in a pass becomes
        the "owner": it snapshots the pre-step history, advances the cache, and
        records both on the private attribute `kv_cache._engram_step`. Other layers
        reuse the snapshot. A new pass is recognized when the owner sees the cache
        again, or when `ngram_history` is no longer the tensor this code stored
        (e.g. after KVCache.reset() or KVCache.prefill()).

        Assumes every forward pass visits the Engram layers in the same order.

        Returns:
            None when `kv_cache` is None (training / uncached forward), else a
            (B, max_ngram - 1) long tensor; all zeros before any token was seen,
            which matches the zero padding used without a cache.
        """
        if kv_cache is None:
            return None
        max_hist = self.max_ngram - 1
        ids = input_ids.long()
        step = getattr(kv_cache, "_engram_step", None)
        if step is not None:
            owner, produced, before = step
            if produced is kv_cache.ngram_history and owner != self.layer_idx:
                return before  # another layer already advanced the cache in this pass
        before = kv_cache.ngram_history
        if before is None:
            before = ids.new_zeros(ids.size(0), max_hist)
        else:
            before = before.to(device=ids.device, dtype=torch.long)
        after = torch.cat([before, ids], dim=1)[:, -max_hist:]
        kv_cache.ngram_history = after
        kv_cache._engram_step = (self.layer_idx, after, before)
        return before

    def forward(self, x, input_ids, kv_cache=None):
        """Look up N-gram memory for `input_ids` and return the residual update.

        Args:
            x: (B, T, D) hidden states.
            input_ids: (B, T) token IDs aligned with `x`.
            kv_cache: optional inference cache exposing an `ngram_history` attribute.

        Returns:
            (B, T, D) tensor that the caller adds to the residual stream.
        """
        B, T, D = x.shape
        # One code path for training, prefill and decode: with a cache, the window of
        # the first positions is completed from the tokens seen in earlier calls.
        history = self._step_history(input_ids, kv_cache)
        embeds = []
        for n in range(2, self.max_ngram + 1):
            for k in range(self.n_heads):
                key = f"{n}_{k}"
                hash_idx = self._ngram_hash(input_ids, n, self._seeds[key], self.table_sizes[key], history)
                embeds.append(self.embed_tables[key](hash_idx))
        e = torch.cat(embeds, dim=-1)  # (B, T, d_mem)

        k = norm(self.key_proj(e))
        v = self.value_proj(e)
        # Scalar gate per position: scaled similarity between hidden state and memory key
        gate = torch.sigmoid((norm(x) * k).sum(dim=-1, keepdim=True) / (D ** 0.5))
        gated_v = gate * v
        normed_v = norm(gated_v)
        # Drop the trailing outputs produced by the right-side padding (causal conv)
        conv_out = self.short_conv(normed_v.transpose(1, 2))[:, :, :T].transpose(1, 2)
        return F.silu(conv_out) + gated_v
nn.ModuleDict({str(i): nn.Embedding(padded_vocab_size, kv_dim) for i in range(config.n_layer) if has_ve(i, config.n_layer) and i not in engram_layers}) + # Create Engram modules on the designated layers + if config.engram_enabled: + for i in engram_layer_ids: + self.transformer.h[i] = Block(config, i, engram=EngramModule(config, i, padded_vocab_size)) + # Disable VE gate in Engram layers (ve is set to None, gate would have no gradient) + self.transformer.h[i].attn.ve_gate = None + else: + self.engram_layer_ids = () # To support meta device initialization, we init the rotary embeddings here, but it's just "fake" meta tensors only. # As for rotary_seq_len, these rotary embeddings are pretty small/cheap in memory, # so let's just over-compute them by 10X, but assert fail if we ever reach that amount. @@ -264,6 +432,20 @@ class GPT(nn.Module): self.transformer.wte.to(dtype=COMPUTE_DTYPE) for ve in self.value_embeds.values(): ve.to(dtype=COMPUTE_DTYPE) + for block in self.transformer.h: + if block.engram is not None: + for t in block.engram.embed_tables.values(): + t.to(dtype=COMPUTE_DTYPE) + + # Engram: init embeddings (uniform), projections (uniform), conv (zeros for identity) + for block in self.transformer.h: + if block.engram is not None: + em = block.engram + for t in em.embed_tables.values(): + torch.nn.init.uniform_(t.weight, -s, s) + torch.nn.init.uniform_(em.key_proj.weight, -s, s) + torch.nn.init.uniform_(em.value_proj.weight, -s, s) + torch.nn.init.zeros_(em.short_conv.weight) def _precompute_rotary_embeddings(self, seq_len, head_dim, base=100000, device=None): # TODO: bump base theta more? e.g. 
100K is more common more recently @@ -329,7 +511,12 @@ class GPT(nn.Module): nparams = sum(p.numel() for p in self.parameters()) # Exclude non-matmul params: embeddings and per-layer scalars value_embeds_numel = sum(ve.weight.numel() for ve in self.value_embeds.values()) - nparams_exclude = (self.transformer.wte.weight.numel() + value_embeds_numel + + engram_embed_numel = 0 + for block in self.transformer.h: + if block.engram is not None: + for t in block.engram.embed_tables.values(): + engram_embed_numel += t.weight.numel() + nparams_exclude = (self.transformer.wte.weight.numel() + value_embeds_numel + engram_embed_numel + self.resid_lambdas.numel() + self.x0_lambdas.numel() + self.smear_gate.weight.numel() + self.smear_lambda.numel() + self.backout_lambda.numel()) h, q, t = self.config.n_head, self.config.n_embd // self.config.n_head, self.config.sequence_len @@ -362,12 +549,18 @@ class GPT(nn.Module): scalars = self.resid_lambdas.numel() + self.x0_lambdas.numel() + self.smear_gate.weight.numel() + self.smear_lambda.numel() + self.backout_lambda.numel() total = wte + value_embeds + lm_head + transformer_matrices + scalars assert total == sum(p.numel() for p in self.parameters()), "Parameter count mismatch" + # Break down engram params for reporting + engram_params = 0 + for block in self.transformer.h: + if block.engram is not None: + engram_params += sum(p.numel() for p in block.engram.parameters()) return { 'wte': wte, 'value_embeds': value_embeds, 'lm_head': lm_head, 'transformer_matrices': transformer_matrices, 'scalars': scalars, + 'engram': engram_params, 'total': total, } @@ -383,6 +576,17 @@ class GPT(nn.Module): resid_params = [self.resid_lambdas] x0_params = [self.x0_lambdas] smear_params = [self.smear_gate.weight, self.smear_lambda, self.backout_lambda] + # Engram: extract embed weights (-> AdamW) and matrix weights (already in matrix_params) + engram_embed_params = [] + engram_matrix_params_set = set() + for block in self.transformer.h: + if 
block.engram is not None: + em = block.engram + for t in em.embed_tables.values(): + engram_embed_params.append(t.weight) + engram_matrix_params_set.add(id(em.key_proj.weight)) + engram_matrix_params_set.add(id(em.value_proj.weight)) + engram_matrix_params_set.add(id(em.short_conv.weight)) assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params) + len(smear_params) # Scale the LR for the AdamW parameters by ∝1/√dmodel (tuned for 768 dim model) @@ -399,9 +603,19 @@ class GPT(nn.Module): dict(kind='adamw', params=x0_params, lr=scalar_lr, betas=(0.96, 0.95), eps=1e-10, weight_decay=0.0), # higher beta1 for x0 dict(kind='adamw', params=smear_params, lr=0.2, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0), ] - # Muon groups (matrix params, grouped by shape for stacking) - for shape in sorted({p.shape for p in matrix_params}): - group_params = [p for p in matrix_params if p.shape == shape] + # Engram embedding tables: AdamW with 5x lr and no weight decay (per paper) + # Remove engram embeds and non-2D params from matrix_params so they don't end up in Muon groups + engram_embed_ids = {id(p) for p in engram_embed_params} + muon_params = [p for p in matrix_params if id(p) not in engram_embed_ids and p.dim() == 2 and p.requires_grad] + if engram_embed_params: + param_groups.append(dict(kind='adamw', params=engram_embed_params, lr=embedding_lr * dmodel_lr_scale * 5, betas=(0.8, 0.995), eps=1e-10, weight_decay=0.0)) + # Non-2D engram params (conv weights, hash seeds) go to AdamW + non_matrix_engram = [p for p in matrix_params if id(p) not in engram_embed_ids and id(p) not in {id(mp) for mp in muon_params}] + if non_matrix_engram: + param_groups.append(dict(kind='adamw', params=non_matrix_engram, lr=embedding_lr * dmodel_lr_scale, betas=(0.8, 0.995), eps=1e-10, weight_decay=0.0)) + # Muon groups (2D matrix params from blocks, grouped by shape for stacking) + for shape 
in sorted({p.shape for p in muon_params}): + group_params = [p for p in muon_params if p.shape == shape] param_groups.append(dict( kind='muon', params=group_params, lr=matrix_lr, momentum=0.95, ns_steps=5, beta2=0.9, weight_decay=weight_decay, @@ -455,8 +669,8 @@ class GPT(nn.Module): x_backout = None for i, block in enumerate(self.transformer.h): x = self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0 - ve = self.value_embeds[str(i)](idx).to(x.dtype) if str(i) in self.value_embeds else None - x = block(x, ve, cos_sin, self.window_sizes[i], kv_cache) + ve = self.value_embeds[str(i)](idx).to(x.dtype) if (str(i) in self.value_embeds and i not in self.engram_layer_ids) else None + x = block(x, idx, ve, cos_sin, self.window_sizes[i], kv_cache) if i == backout_layer: x_backout = x # Subtract mid-layer residual to remove low-level features before logit projection diff --git a/scripts/base_train.py b/scripts/base_train.py index a161c477..ae3b2bf1 100644 --- a/scripts/base_train.py +++ b/scripts/base_train.py @@ -52,6 +52,14 @@ parser.add_argument("--aspect-ratio", type=int, default=64, help="model_dim = de parser.add_argument("--head-dim", type=int, default=128, help="target head dimension for attention") parser.add_argument("--max-seq-len", type=int, default=2048, help="max context length") parser.add_argument("--window-pattern", type=str, default="SSSL", help="sliding window pattern tiled across layers: L=full, S=half context (e.g. 
'SSL')") +# Engram: conditional memory via N-gram hash lookup +parser.add_argument("--engram", action="store_true", help="enable Engram conditional memory") +parser.add_argument("--engram-ngram-size", type=int, default=3, help="max N-gram order (2+3=bigram+trigram)") +parser.add_argument("--engram-n-heads", type=int, default=4, help="hash heads per N-gram order") +parser.add_argument("--engram-embed-dim", type=int, default=128, help="embedding dim per hash head") +parser.add_argument("--engram-table-size", type=int, default=0, help="hash table size (0=auto)") +parser.add_argument("--engram-layer-ids", type=str, default="", help="comma-separated layer IDs for Engram (empty=auto)") +parser.add_argument("--engram-kernel-size", type=int, default=4, help="depthwise causal conv kernel size") # Training horizon (only one used, in order of precedence) parser.add_argument("--num-iterations", type=int, default=-1, help="explicit number of optimization steps (-1 = disable)") parser.add_argument("--target-flops", type=float, default=-1.0, help="calculate num_iterations to reach target_flops (-1 = disable)") @@ -133,10 +141,15 @@ def build_model_meta(depth): base_dim = depth * args.aspect_ratio model_dim = ((base_dim + args.head_dim - 1) // args.head_dim) * args.head_dim num_heads = model_dim // args.head_dim + engram_layer_ids = tuple(int(x) for x in args.engram_layer_ids.split(",")) if args.engram_layer_ids else () config = GPTConfig( sequence_len=args.max_seq_len, vocab_size=vocab_size, n_layer=depth, n_head=num_heads, n_kv_head=num_heads, n_embd=model_dim, window_pattern=args.window_pattern, + engram_enabled=args.engram, engram_ngram_size=args.engram_ngram_size, + engram_n_heads=args.engram_n_heads, engram_embed_dim=args.engram_embed_dim, + engram_table_size=args.engram_table_size, engram_layer_ids=engram_layer_ids, + engram_kernel_size=args.engram_kernel_size, ) with torch.device("meta"): model_meta = GPT(config) diff --git a/tests/test_engram.py b/tests/test_engram.py new 
"""
Regression tests for Engram integration.

These tests guard against bugs found during implementation:

1. Stale Block.forward left inside EngramModule scope overrode the correct method
2. Parameter double-counting when Engram modules stored in both engram_modules and blocks
3. Engram embed params appearing in both Muon and AdamW optimizer groups
4. Non-2D params (hash_seeds, conv weights) crashing Muon's shape-based stacking
5. ve_gate on Engram layers getting no gradient (ve=None → gate unused → no grad)
6. Block.forward signature must accept input_ids for Engram layers
7. KVCache must support ngram_history state for decode

Run: python -m pytest tests/test_engram.py -v
"""

import torch
import pytest
from nanochat.gpt import GPT, GPTConfig, Block, EngramModule, _next_prime


def _make_config(engram_layer_ids=(1, 2), n_layer=4):
    """Small Engram-enabled GPTConfig shared by most tests."""
    return GPTConfig(
        sequence_len=128, vocab_size=32768,
        n_layer=n_layer, n_head=4, n_kv_head=4, n_embd=256,
        engram_enabled=True, engram_ngram_size=3,
        engram_n_heads=2, engram_embed_dim=64,
        engram_layer_ids=engram_layer_ids,
    )


def _build_model(config):
    """Instantiate GPT on CPU and run init_weights (the setup every model-level test needs)."""
    model = GPT(config)
    model.to(device="cpu")
    model.init_weights()
    return model


def _make_kv_cache(batch_size, num_heads=4, head_dim=64, num_layers=4, seq_len=128):
    """Build a float32 CPU KVCache; imported lazily so gpt-only tests don't need the engine."""
    from nanochat.engine import KVCache
    return KVCache(
        batch_size=batch_size, num_heads=num_heads, seq_len=seq_len, head_dim=head_dim,
        num_layers=num_layers, device="cpu", dtype=torch.float32,
    )


class TestEngramForwardSignature:
    """Block.forward must accept input_ids when Engram is present."""

    def test_block_forward_accepts_input_ids(self):
        config = _make_config()
        model = _build_model(config)

        idx = torch.randint(0, config.vocab_size, (1, 16))
        # Should not raise TypeError about missing arguments
        loss = model(idx, targets=idx)
        assert loss.item() > 0

    def test_block_without_engram_ignores_input_ids(self):
        """Non-Engram blocks still work (engram=None path)."""
        config = _make_config(engram_layer_ids=(1,), n_layer=4)
        model = _build_model(config)
        block0 = model.transformer.h[0]
        assert block0.engram is None

        idx = torch.randint(0, config.vocab_size, (1, 16))
        loss = model(idx, targets=idx)
        assert loss.item() > 0


class TestEngramNoStaleForward:
    """EngramModule.forward must have (x, input_ids, kv_cache) signature,
    NOT the old Block.forward signature (x, ve, cos_sin, window_size, kv_cache).
    This catches the bug where a leftover Block.forward was nested inside
    EngramModule's scope due to indentation."""

    def test_engram_forward_signature(self):
        import inspect
        sig = inspect.signature(EngramModule.forward)
        param_names = list(sig.parameters.keys())
        assert "input_ids" in param_names, f"EngramModule.forward missing input_ids, got {param_names}"
        assert "ve" not in param_names, f"EngramModule.forward should not have 've' (stale Block.forward), got {param_names}"

    def test_engram_forward_runs(self):
        config = _make_config()
        em = EngramModule(config, layer_idx=1, padded_vocab_size=32768)
        x = torch.randn(1, 16, config.n_embd)
        ids = torch.randint(0, 32768, (1, 16))
        out = em(x, ids)
        assert out.shape == x.shape


class TestParameterCounting:
    """num_scaling_params total must match actual parameter count (no double-counting)."""

    def test_param_count_matches(self):
        model = _build_model(_make_config())
        params = model.num_scaling_params()
        actual = sum(p.numel() for p in model.parameters())
        assert params["total"] == actual, f"Mismatch: {params['total']} != {actual}"

    def test_engram_params_nonzero(self):
        model = _build_model(_make_config())
        params = model.num_scaling_params()
        assert params["engram"] > 0

    def test_no_engram_params_when_disabled(self):
        config = GPTConfig(n_layer=4, n_head=4, n_kv_head=4, n_embd=256, engram_enabled=False)
        model = _build_model(config)
        params = model.num_scaling_params()
        assert params["engram"] == 0


class TestOptimizerNoDuplicates:
    """Each parameter must appear in exactly one optimizer group."""

    def test_no_duplicate_param_groups(self):
        model = _build_model(_make_config())
        opt = model.setup_optimizer()

        seen = {}
        for group in opt.param_groups:
            for p in group["params"]:
                pid = id(p)
                assert pid not in seen, (
                    f"Param shape {p.shape} in multiple groups: "
                    f"kind={group['kind']} and kind={seen[pid]}"
                )
                seen[pid] = group["kind"]

    def test_engram_embeds_not_in_muon(self):
        """Engram embedding table weights must be in AdamW, not Muon."""
        model = _build_model(_make_config())
        opt = model.setup_optimizer()

        engram_embed_ids = set()
        for block in model.transformer.h:
            if block.engram is not None:
                for t in block.engram.embed_tables.values():
                    engram_embed_ids.add(id(t.weight))

        muon_params = set()
        for group in opt.param_groups:
            if group.get("kind") == "muon":
                for p in group["params"]:
                    muon_params.add(id(p))

        overlap = engram_embed_ids & muon_params
        assert not overlap, f"Engram embed params found in Muon groups: {len(overlap)}"


class TestMuonOnlyMatrixParams:
    """Muon optimizer groups must only contain 2D trainable parameters."""

    def test_muon_groups_only_2d_trainable(self):
        model = _build_model(_make_config())
        opt = model.setup_optimizer()

        for group in opt.param_groups:
            if group.get("kind") != "muon":
                continue
            for p in group["params"]:
                assert p.dim() == 2, f"Muon got {p.dim()}D param (shape={p.shape})"
                assert p.requires_grad, f"Muon got non-trainable param"


class TestVEGateNoDeadGrad:
    """On Engram layers, ve_gate must be None so it never has a dead gradient."""

    def test_engram_layers_have_no_ve_gate(self):
        model = _build_model(_make_config(engram_layer_ids=(1, 2), n_layer=4))

        for i in model.engram_layer_ids:
            block = model.transformer.h[i]
            assert block.attn.ve_gate is None, f"Layer {i} has ve_gate but is an Engram layer"

    def test_all_trainable_params_get_gradients(self):
        """Every requires_grad=True param must receive a gradient after backward."""
        config = _make_config()
        model = _build_model(config)

        idx = torch.randint(0, config.vocab_size, (1, 16))
        loss = model(idx, targets=idx)
        loss.backward()

        dead_grads = []
        for name, p in model.named_parameters():
            if p.requires_grad and p.grad is None:
                dead_grads.append(name)
        assert not dead_grads, f"Params with no gradient: {dead_grads}"


class TestAutoLayerSelection:
    """Default engram_layer_ids should be (1, max(2, n_layer // 2 - 1))."""

    @pytest.mark.parametrize("n_layer,expected", [
        (6, (1, 2)),
        (12, (1, 5)),
        (20, (1, 9)),
        (30, (1, 14)),
    ])
    def test_auto_layer_ids(self, n_layer, expected):
        config = GPTConfig(n_layer=n_layer, n_head=4, n_kv_head=4, n_embd=256, engram_enabled=True)
        model = GPT(config)
        assert model.engram_layer_ids == expected


class TestKVCacheNgramState:
    """KVCache must support ngram_history for Engram decode."""

    def test_ngram_history_initially_none(self):
        cache = _make_kv_cache(batch_size=1)
        assert cache.ngram_history is None

    def test_reset_clears_ngram_history(self):
        cache = _make_kv_cache(batch_size=1)
        cache.ngram_history = torch.zeros(1, 2, dtype=torch.long)
        cache.reset()
        assert cache.ngram_history is None

    def test_prefill_copies_ngram_history(self):
        src = _make_kv_cache(batch_size=1)
        src.ngram_history = torch.tensor([[1, 2]], dtype=torch.long)
        dst = _make_kv_cache(batch_size=4)
        dst.prefill(src)
        assert dst.ngram_history is not None
        assert dst.ngram_history.shape == (4, 2)
        assert (dst.ngram_history == torch.tensor([[1, 2]])).all()


class TestNextPrime:
    """_next_prime helper must return the smallest prime >= n."""

    @pytest.mark.parametrize("n,expected", [
        (1, 2),
        (2, 2),
        (3, 3),
        (4, 5),
        (10, 11),
        (100, 101),
    ])
    def test_next_prime(self, n, expected):
        assert _next_prime(n) == expected

    def test_next_prime_returns_prime(self):
        result = _next_prime(99999)
        assert result >= 99999
        # Verify it's actually prime
        assert all(result % d for d in range(2, int(result**0.5) + 1))


class TestEngramInitWeightsIdentity:
    """Short conv must be zero-initialized so the conv branch contributes nothing at init
    (silu(0) == 0, so the module output reduces to the gated value)."""

    def test_conv_weights_zero_after_init(self):
        model = _build_model(_make_config())

        for i in model.engram_layer_ids:
            conv = model.transformer.h[i].engram.short_conv
            assert torch.all(conv.weight == 0), "Engram conv weights should be zero-initialized"


class TestEngramDecodeIntegration:
    """Inference with KV cache + Engram: prefill must match the uncached forward pass
    and leave a usable ngram_history behind for decode."""

    def _make_model(self):
        config = GPTConfig(
            sequence_len=64, vocab_size=32768,
            n_layer=6, n_head=4, n_kv_head=4, n_embd=256,
            engram_enabled=True, engram_ngram_size=3,
            engram_n_heads=2, engram_embed_dim=64,
            engram_layer_ids=(1, 2),
        )
        return _build_model(config)

    def _make_cache(self, config, batch_size):
        return _make_kv_cache(
            batch_size, num_heads=config.n_kv_head, head_dim=config.n_embd // config.n_head,
            num_layers=config.n_layer, seq_len=64,
        )

    def test_naive_vs_kv_cache_single_token(self):
        """Last-position logits from a KV-cache prefill should match the naive forward pass.

        NOTE: despite the name, no single-token decode step is run here; only the
        prefill path is compared against naive recomputation.
        """
        model = self._make_model()
        config = model.config
        prompt = torch.randint(0, config.vocab_size, (1, 16))

        # Naive: full sequence forward, take last token logits
        with torch.no_grad():
            logits_naive = model(prompt)[:, -1, :]

        # KV cache: prefill the same prompt
        kv = self._make_cache(config, batch_size=1)
        with torch.no_grad():
            logits_prefill = model(prompt, kv_cache=kv)[:, -1, :]

        # Prefill should match naive
        assert torch.allclose(logits_naive, logits_prefill, atol=1e-4), "Prefill logits should match naive"

    def test_prefill_copies_ngram_history_to_batch_cache(self):
        """After prefill, ngram_history should be available and copyable for batch decode."""
        model = self._make_model()
        config = model.config
        prompt = torch.randint(0, config.vocab_size, (1, 8))

        kv_prefill = self._make_cache(config, batch_size=1)
        with torch.no_grad():
            model(prompt, kv_cache=kv_prefill)

        assert kv_prefill.ngram_history is not None, "Prefill should set ngram_history"
        max_hist = config.engram_ngram_size - 1
        assert kv_prefill.ngram_history.shape == (1, max_hist)
        # The history must hold exactly the last max_hist prompt tokens, no matter
        # how many Engram layers touched the cache during the pass.
        assert (kv_prefill.ngram_history == prompt[:, -max_hist:]).all()

        # Copy to batch cache
        kv_batch = self._make_cache(config, batch_size=3)
        kv_batch.prefill(kv_prefill)
        assert kv_batch.ngram_history.shape == (3, max_hist)
        assert (kv_batch.ngram_history == prompt[:, -max_hist:]).all()