Add Engram conditional memory module

Integrates DeepSeek's Engram (N-gram hash lookup + context-aware gating
+ depthwise causal conv) as an optional module behind --engram CLI flag.
Placed at two layers per paper ablation findings (layer 1 and n_layer//2-1).
Coexists with existing Value Embeddings; disabled by default.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
junran 2026-04-19 12:04:02 +08:00
parent 0aaca56805
commit dcd9b0668b
4 changed files with 578 additions and 9 deletions

View File

@ -102,11 +102,14 @@ class KVCache:
self.cache_seqlens = torch.zeros(batch_size, dtype=torch.int32, device=device)
# Previous token's normalized embedding for smear (set by model forward pass)
self.prev_embedding = None
# Engram N-gram history: (B, max_ngram-1) tensor of recent token IDs
self.ngram_history = None
def reset(self):
"""Reset cache to empty state."""
self.cache_seqlens.zero_()
self.prev_embedding = None
self.ngram_history = None
def get_pos(self):
"""Get current position (assumes all batch elements at same position)."""
@ -135,6 +138,9 @@ class KVCache:
# Copy smear state: expand batch=1 prev_embedding to num_samples
if other.prev_embedding is not None:
self.prev_embedding = other.prev_embedding.expand(self.batch_size, -1, -1).clone()
# Copy Engram N-gram history: expand batch=1 to num_samples
if other.ngram_history is not None:
self.ngram_history = other.ngram_history.expand(self.batch_size, -1).clone()
# -----------------------------------------------------------------------------
@torch.inference_mode()

View File

@ -37,6 +37,14 @@ class GPTConfig:
# Characters: L=long (full context), S=short (quarter context)
# Examples: "L"=all full context, "SL"=alternating, "SSL"=two short then one long
window_pattern: str = "SSSL"
# Engram: conditional memory via N-gram hash lookup (DeepSeek Engram paper)
engram_enabled: bool = False
engram_ngram_size: int = 3 # max N-gram order (uses 2-gram .. N-gram)
engram_n_heads: int = 4 # hash heads per N-gram order
engram_embed_dim: int = 128 # embedding dim per hash head
engram_table_size: int = 0 # hash table size (0=auto: ~vocab_size*3)
engram_layer_ids: tuple = () # layers to place Engram (empty=auto: (1, n_layer//2))
engram_kernel_size: int = 4 # depthwise causal conv kernel size
def norm(x):
@ -140,17 +148,159 @@ class MLP(nn.Module):
class Block(nn.Module):
def __init__(self, config, layer_idx):
def __init__(self, config, layer_idx, engram=None):
super().__init__()
self.attn = CausalSelfAttention(config, layer_idx)
self.mlp = MLP(config)
self.engram = engram
def forward(self, x, ve, cos_sin, window_size, kv_cache):
def forward(self, x, input_ids, ve, cos_sin, window_size, kv_cache):
if self.engram is not None:
x = x + self.engram(x, input_ids, kv_cache)
x = x + self.attn(norm(x), ve, cos_sin, window_size, kv_cache)
x = x + self.mlp(norm(x))
return x
def _next_prime(n):
"""Find the smallest prime >= n."""
if n <= 2:
return 2
if n % 2 == 0:
n += 1
while True:
if all(n % d for d in range(3, int(n**0.5) + 1, 2)):
return n
n += 2
class EngramModule(nn.Module):
"""Conditional memory via N-gram hash lookup (DeepSeek Engram).
For each N-gram order (2..max_ngram) and each hash head (1..n_heads),
computes a multiplicative hash over the last n tokens to index into
prime-sized embedding tables. Retrieved embeddings are projected to
key/value, gated by hidden state similarity, and refined via depthwise
causal convolution.
Known deviations from paper:
- Token compression (Section 2.2, NFKC/lowercase normalization) is omitted.
nanochat's 32K BPE vocab is already compact; the hash tables are prime-sized
and larger than the vocab, so collision rates remain low.
- Hash function uses polynomial rolling hash over the exact n-gram window
rather than the demo's multiplicative-XOR accumulation.
Paper: "Conditional Memory via Scalable Lookup: A New Axis of Sparsity for LLMs"
"""
def __init__(self, config, layer_idx, padded_vocab_size):
super().__init__()
self.layer_idx = layer_idx
self.max_ngram = config.engram_ngram_size
self.vocab_size = padded_vocab_size
self.n_heads = config.engram_n_heads
self.embed_dim = config.engram_embed_dim
self.n_embd = config.n_embd
n_gram_orders = self.max_ngram - 1 # e.g. max_ngram=3 → orders 2,3
self.d_mem = n_gram_orders * self.n_heads * self.embed_dim
base_table_size = config.engram_table_size if config.engram_table_size > 0 else padded_vocab_size * 3
self.embed_tables = nn.ModuleDict()
self.hash_seeds = nn.ParameterDict() # buffers stored as 0-dim params for checkpointing
self.table_sizes = {}
for n in range(2, self.max_ngram + 1):
for k in range(self.n_heads):
prime_size = _next_prime(base_table_size + n * k * 997)
key = f"{n}_{k}"
self.embed_tables[key] = nn.Embedding(prime_size, self.embed_dim)
seed = n * 2654435761 + k * 40503 # multiplicative hash constants
self.hash_seeds[key] = nn.Parameter(torch.tensor(seed, dtype=torch.long), requires_grad=False)
self.table_sizes[key] = prime_size
self.key_proj = Linear(self.d_mem, config.n_embd, bias=False)
self.value_proj = Linear(self.d_mem, config.n_embd, bias=False)
dilation = config.engram_ngram_size # paper Section 2.3: dilation = max N-gram order
self.short_conv = nn.Conv1d(
config.n_embd, config.n_embd,
kernel_size=config.engram_kernel_size,
padding=dilation * (config.engram_kernel_size - 1),
dilation=dilation,
groups=config.n_embd, bias=False,
)
def _ngram_hash(self, input_ids, n, seed, table_size):
"""Vectorized multiplicative hash over the last n tokens only.
For order n, the hash at position t is:
h(t) = (ids[t] * seed^1 + ids[t-1] * seed^2 + ... + ids[t-n+1] * seed^n) % table_size
Positions with fewer than n preceding tokens use a zero-padded window.
"""
B, T = input_ids.shape
ids = input_ids.long()
# Build a (B, T, n) window tensor: window[b, t, j] = ids[b, t - (n-1-j)]
# Use pad+slice for vectorization (no Python loop)
padded = F.pad(ids, (n - 1, 0), value=0) # (B, T + n - 1)
# Extract windows of size n ending at each position
# window[b, t, :] = padded[b, t : t+n]
windows = padded.unfold(dimension=1, size=n, step=1) # (B, T, n)
# Compute polynomial hash: h = sum(window[j] * seed^(n-j)) mod table_size
# Use modular exponentiation to avoid overflow
powers = torch.tensor([pow(seed, n - j, table_size) for j in range(n)], dtype=torch.long, device=ids.device)
h = (windows * powers).sum(dim=-1) % table_size # (B, T)
return h
def forward(self, x, input_ids, kv_cache=None):
B, T, D = x.shape
embeds = []
for n in range(2, self.max_ngram + 1):
for k in range(self.n_heads):
key = f"{n}_{k}"
seed = self.hash_seeds[key].item()
table_size = self.table_sizes[key]
if kv_cache is None:
hash_idx = self._ngram_hash(input_ids, n, seed, table_size)
elif T == 1 and kv_cache.ngram_history is not None:
# Decode: build n-gram window from cached history + current token
history = kv_cache.ngram_history # (B, max_ngram-1)
current = input_ids[:, 0:1] # (B, 1)
# Take last (n-1) tokens from history, append current
hist_slice = history[:, -(n - 1):] # (B, n-1)
window = torch.cat([hist_slice, current], dim=1) # (B, n)
powers = torch.tensor([pow(seed, n - j, table_size) for j in range(n)], dtype=torch.long, device=input_ids.device)
hash_idx = (window * powers).sum(dim=-1, keepdim=True) % table_size # (B, 1)
else:
hash_idx = self._ngram_hash(input_ids, n, seed, table_size)
embeds.append(self.embed_tables[key](hash_idx))
# Update n-gram history in KV cache after all orders processed
if kv_cache is not None:
max_hist = self.max_ngram - 1
if T == 1:
if kv_cache.ngram_history is None:
kv_cache.ngram_history = input_ids[:, -1:].expand(B, max_hist).clone()
else:
# Shift left by 1 and append current token
new_hist = torch.cat([kv_cache.ngram_history[:, 1:], input_ids[:, -1:]], dim=1)
kv_cache.ngram_history = new_hist
else:
# Prefill: take last max_hist tokens
if kv_cache.ngram_history is None:
pad_len = max(0, max_hist - T)
hist = F.pad(input_ids[:, -max_hist:], (pad_len, 0), value=0)
kv_cache.ngram_history = hist
else:
combined = torch.cat([kv_cache.ngram_history, input_ids], dim=1)
kv_cache.ngram_history = combined[:, -max_hist:]
e = torch.cat(embeds, dim=-1) # (B, T, d_mem)
k = norm(self.key_proj(e))
v = self.value_proj(e)
gate = torch.sigmoid((norm(x) * k).sum(dim=-1, keepdim=True) / (D ** 0.5))
gated_v = gate * v
normed_v = norm(gated_v)
conv_out = self.short_conv(normed_v.transpose(1, 2))[:, :, :T].transpose(1, 2)
return F.silu(conv_out) + gated_v
class GPT(nn.Module):
def __init__(self, config, pad_vocab_size_to=64):
"""
@ -184,10 +334,28 @@ class GPT(nn.Module):
self.smear_lambda = nn.Parameter(torch.zeros(1))
# Backout: subtract cached mid-layer residual before final norm to remove low-level features
self.backout_lambda = nn.Parameter(0.2 * torch.ones(1))
# Engram: conditional memory via N-gram hash lookup
# Resolve engram_layer_ids (including auto-default) before creating VE, so VE excludes Engram layers
if config.engram_enabled:
engram_layer_ids = config.engram_layer_ids if config.engram_layer_ids else (1, max(2, config.n_layer // 2 - 1))
self.engram_layer_ids = engram_layer_ids
else:
engram_layer_ids = ()
self.engram_layer_ids = ()
# Value embeddings (ResFormer-style): alternating layers, last layer always included
# Skip VE for layers that will use Engram instead
engram_layers = set(engram_layer_ids)
head_dim = config.n_embd // config.n_head
kv_dim = config.n_kv_head * head_dim
self.value_embeds = nn.ModuleDict({str(i): nn.Embedding(padded_vocab_size, kv_dim) for i in range(config.n_layer) if has_ve(i, config.n_layer)})
self.value_embeds = nn.ModuleDict({str(i): nn.Embedding(padded_vocab_size, kv_dim) for i in range(config.n_layer) if has_ve(i, config.n_layer) and i not in engram_layers})
# Create Engram modules on the designated layers
if config.engram_enabled:
for i in engram_layer_ids:
self.transformer.h[i] = Block(config, i, engram=EngramModule(config, i, padded_vocab_size))
# Disable VE gate in Engram layers (ve is set to None, gate would have no gradient)
self.transformer.h[i].attn.ve_gate = None
else:
self.engram_layer_ids = ()
# To support meta device initialization, we init the rotary embeddings here, but it's just "fake" meta tensors only.
# As for rotary_seq_len, these rotary embeddings are pretty small/cheap in memory,
# so let's just over-compute them by 10X, but assert fail if we ever reach that amount.
@ -264,6 +432,20 @@ class GPT(nn.Module):
self.transformer.wte.to(dtype=COMPUTE_DTYPE)
for ve in self.value_embeds.values():
ve.to(dtype=COMPUTE_DTYPE)
for block in self.transformer.h:
if block.engram is not None:
for t in block.engram.embed_tables.values():
t.to(dtype=COMPUTE_DTYPE)
# Engram: init embeddings (uniform), projections (uniform), conv (zeros for identity)
for block in self.transformer.h:
if block.engram is not None:
em = block.engram
for t in em.embed_tables.values():
torch.nn.init.uniform_(t.weight, -s, s)
torch.nn.init.uniform_(em.key_proj.weight, -s, s)
torch.nn.init.uniform_(em.value_proj.weight, -s, s)
torch.nn.init.zeros_(em.short_conv.weight)
def _precompute_rotary_embeddings(self, seq_len, head_dim, base=100000, device=None):
# TODO: bump base theta more? e.g. 100K is more common more recently
@ -329,7 +511,12 @@ class GPT(nn.Module):
nparams = sum(p.numel() for p in self.parameters())
# Exclude non-matmul params: embeddings and per-layer scalars
value_embeds_numel = sum(ve.weight.numel() for ve in self.value_embeds.values())
nparams_exclude = (self.transformer.wte.weight.numel() + value_embeds_numel +
engram_embed_numel = 0
for block in self.transformer.h:
if block.engram is not None:
for t in block.engram.embed_tables.values():
engram_embed_numel += t.weight.numel()
nparams_exclude = (self.transformer.wte.weight.numel() + value_embeds_numel + engram_embed_numel +
self.resid_lambdas.numel() + self.x0_lambdas.numel() +
self.smear_gate.weight.numel() + self.smear_lambda.numel() + self.backout_lambda.numel())
h, q, t = self.config.n_head, self.config.n_embd // self.config.n_head, self.config.sequence_len
@ -362,12 +549,18 @@ class GPT(nn.Module):
scalars = self.resid_lambdas.numel() + self.x0_lambdas.numel() + self.smear_gate.weight.numel() + self.smear_lambda.numel() + self.backout_lambda.numel()
total = wte + value_embeds + lm_head + transformer_matrices + scalars
assert total == sum(p.numel() for p in self.parameters()), "Parameter count mismatch"
# Break down engram params for reporting
engram_params = 0
for block in self.transformer.h:
if block.engram is not None:
engram_params += sum(p.numel() for p in block.engram.parameters())
return {
'wte': wte,
'value_embeds': value_embeds,
'lm_head': lm_head,
'transformer_matrices': transformer_matrices,
'scalars': scalars,
'engram': engram_params,
'total': total,
}
@ -383,6 +576,17 @@ class GPT(nn.Module):
resid_params = [self.resid_lambdas]
x0_params = [self.x0_lambdas]
smear_params = [self.smear_gate.weight, self.smear_lambda, self.backout_lambda]
# Engram: extract embed weights (-> AdamW) and matrix weights (already in matrix_params)
engram_embed_params = []
engram_matrix_params_set = set()
for block in self.transformer.h:
if block.engram is not None:
em = block.engram
for t in em.embed_tables.values():
engram_embed_params.append(t.weight)
engram_matrix_params_set.add(id(em.key_proj.weight))
engram_matrix_params_set.add(id(em.value_proj.weight))
engram_matrix_params_set.add(id(em.short_conv.weight))
assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params) + len(smear_params)
# Scale the LR for the AdamW parameters by ∝1/√dmodel (tuned for 768 dim model)
@ -399,9 +603,19 @@ class GPT(nn.Module):
dict(kind='adamw', params=x0_params, lr=scalar_lr, betas=(0.96, 0.95), eps=1e-10, weight_decay=0.0), # higher beta1 for x0
dict(kind='adamw', params=smear_params, lr=0.2, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0),
]
# Muon groups (matrix params, grouped by shape for stacking)
for shape in sorted({p.shape for p in matrix_params}):
group_params = [p for p in matrix_params if p.shape == shape]
# Engram embedding tables: AdamW with 5x lr and no weight decay (per paper)
# Remove engram embeds and non-2D params from matrix_params so they don't end up in Muon groups
engram_embed_ids = {id(p) for p in engram_embed_params}
muon_params = [p for p in matrix_params if id(p) not in engram_embed_ids and p.dim() == 2 and p.requires_grad]
if engram_embed_params:
param_groups.append(dict(kind='adamw', params=engram_embed_params, lr=embedding_lr * dmodel_lr_scale * 5, betas=(0.8, 0.995), eps=1e-10, weight_decay=0.0))
# Non-2D engram params (conv weights, hash seeds) go to AdamW
non_matrix_engram = [p for p in matrix_params if id(p) not in engram_embed_ids and id(p) not in {id(mp) for mp in muon_params}]
if non_matrix_engram:
param_groups.append(dict(kind='adamw', params=non_matrix_engram, lr=embedding_lr * dmodel_lr_scale, betas=(0.8, 0.995), eps=1e-10, weight_decay=0.0))
# Muon groups (2D matrix params from blocks, grouped by shape for stacking)
for shape in sorted({p.shape for p in muon_params}):
group_params = [p for p in muon_params if p.shape == shape]
param_groups.append(dict(
kind='muon', params=group_params, lr=matrix_lr,
momentum=0.95, ns_steps=5, beta2=0.9, weight_decay=weight_decay,
@ -455,8 +669,8 @@ class GPT(nn.Module):
x_backout = None
for i, block in enumerate(self.transformer.h):
x = self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0
ve = self.value_embeds[str(i)](idx).to(x.dtype) if str(i) in self.value_embeds else None
x = block(x, ve, cos_sin, self.window_sizes[i], kv_cache)
ve = self.value_embeds[str(i)](idx).to(x.dtype) if (str(i) in self.value_embeds and i not in self.engram_layer_ids) else None
x = block(x, idx, ve, cos_sin, self.window_sizes[i], kv_cache)
if i == backout_layer:
x_backout = x
# Subtract mid-layer residual to remove low-level features before logit projection

View File

@ -52,6 +52,14 @@ parser.add_argument("--aspect-ratio", type=int, default=64, help="model_dim = de
parser.add_argument("--head-dim", type=int, default=128, help="target head dimension for attention")
parser.add_argument("--max-seq-len", type=int, default=2048, help="max context length")
parser.add_argument("--window-pattern", type=str, default="SSSL", help="sliding window pattern tiled across layers: L=full, S=half context (e.g. 'SSL')")
# Engram: conditional memory via N-gram hash lookup
parser.add_argument("--engram", action="store_true", help="enable Engram conditional memory")
parser.add_argument("--engram-ngram-size", type=int, default=3, help="max N-gram order (2+3=bigram+trigram)")
parser.add_argument("--engram-n-heads", type=int, default=4, help="hash heads per N-gram order")
parser.add_argument("--engram-embed-dim", type=int, default=128, help="embedding dim per hash head")
parser.add_argument("--engram-table-size", type=int, default=0, help="hash table size (0=auto)")
parser.add_argument("--engram-layer-ids", type=str, default="", help="comma-separated layer IDs for Engram (empty=auto)")
parser.add_argument("--engram-kernel-size", type=int, default=4, help="depthwise causal conv kernel size")
# Training horizon (only one used, in order of precedence)
parser.add_argument("--num-iterations", type=int, default=-1, help="explicit number of optimization steps (-1 = disable)")
parser.add_argument("--target-flops", type=float, default=-1.0, help="calculate num_iterations to reach target_flops (-1 = disable)")
@ -133,10 +141,15 @@ def build_model_meta(depth):
base_dim = depth * args.aspect_ratio
model_dim = ((base_dim + args.head_dim - 1) // args.head_dim) * args.head_dim
num_heads = model_dim // args.head_dim
engram_layer_ids = tuple(int(x) for x in args.engram_layer_ids.split(",")) if args.engram_layer_ids else ()
config = GPTConfig(
sequence_len=args.max_seq_len, vocab_size=vocab_size,
n_layer=depth, n_head=num_heads, n_kv_head=num_heads, n_embd=model_dim,
window_pattern=args.window_pattern,
engram_enabled=args.engram, engram_ngram_size=args.engram_ngram_size,
engram_n_heads=args.engram_n_heads, engram_embed_dim=args.engram_embed_dim,
engram_table_size=args.engram_table_size, engram_layer_ids=engram_layer_ids,
engram_kernel_size=args.engram_kernel_size,
)
with torch.device("meta"):
model_meta = GPT(config)

336
tests/test_engram.py Normal file
View File

@ -0,0 +1,336 @@
"""
Regression tests for Engram integration.
These tests guard against bugs found during implementation:
1. Stale Block.forward left inside EngramModule scope overrode the correct method
2. Parameter double-counting when Engram modules stored in both engam_modules and blocks
3. Engram embed params appearing in both Muon and AdamW optimizer groups
4. Non-2D params (hash_seeds, conv weights) crashing Muon's shape-based stacking
5. ve_gate on Engram layers getting no gradient (ve=None gate unused no grad)
6. Block.forward signature must accept input_ids for Engram layers
7. KVCache must support ngram_running state for decode
Run: python -m pytest tests/test_engram.py -v
"""
import torch
import pytest
from nanochat.gpt import GPT, GPTConfig, Block, EngramModule, _next_prime
def _make_model(engram_layer_ids=(1, 2), n_layer=4):
return GPTConfig(
sequence_len=128, vocab_size=32768,
n_layer=n_layer, n_head=4, n_kv_head=4, n_embd=256,
engram_enabled=True, engram_ngram_size=3,
engram_n_heads=2, engram_embed_dim=64,
engram_layer_ids=engram_layer_ids,
)
class TestEngramForwardSignature:
"""Block.forward must accept input_ids when Engram is present."""
def test_block_forward_accepts_input_ids(self):
config = _make_model()
model = GPT(config)
model.to(device="cpu")
model.init_weights()
idx = torch.randint(0, config.vocab_size, (1, 16))
# Should not raise TypeError about missing arguments
loss = model(idx, targets=idx)
assert loss.item() > 0
def test_block_without_engram_ignores_input_ids(self):
"""Non-Engram blocks still work (engram=None path)."""
config = _make_model(engram_layer_ids=(1,), n_layer=4)
model = GPT(config)
model.to(device="cpu")
model.init_weights()
block0 = model.transformer.h[0]
assert block0.engram is None
idx = torch.randint(0, config.vocab_size, (1, 16))
loss = model(idx, targets=idx)
assert loss.item() > 0
class TestEngramNoStaleForward:
"""EngramModule.forward must have (x, input_ids, kv_cache) signature,
NOT the old Block.forward signature (x, ve, cos_sin, window_size, kv_cache).
This catches the bug where a leftover Block.forward was nested inside
EngramModule's scope due to indentation."""
def test_engram_forward_signature(self):
import inspect
sig = inspect.signature(EngramModule.forward)
param_names = list(sig.parameters.keys())
assert "input_ids" in param_names, f"EngramModule.forward missing input_ids, got {param_names}"
assert "ve" not in param_names, f"EngramModule.forward should not have 've' (stale Block.forward), got {param_names}"
def test_engram_forward_runs(self):
config = _make_model()
em = EngramModule(config, layer_idx=1, padded_vocab_size=32768)
x = torch.randn(1, 16, config.n_embd)
ids = torch.randint(0, 32768, (1, 16))
out = em(x, ids)
assert out.shape == x.shape
class TestParameterCounting:
"""num_scaling_params total must match actual parameter count (no double-counting)."""
def test_param_count_matches(self):
config = _make_model()
model = GPT(config)
model.to(device="cpu")
model.init_weights()
params = model.num_scaling_params()
actual = sum(p.numel() for p in model.parameters())
assert params["total"] == actual, f"Mismatch: {params['total']} != {actual}"
def test_engram_params_nonzero(self):
config = _make_model()
model = GPT(config)
model.to(device="cpu")
model.init_weights()
params = model.num_scaling_params()
assert params["engram"] > 0
def test_no_engram_params_when_disabled(self):
config = GPTConfig(n_layer=4, n_head=4, n_kv_head=4, n_embd=256, engram_enabled=False)
model = GPT(config)
model.to(device="cpu")
model.init_weights()
params = model.num_scaling_params()
assert params["engram"] == 0
class TestOptimizerNoDuplicates:
"""Each parameter must appear in exactly one optimizer group."""
def test_no_duplicate_param_groups(self):
config = _make_model()
model = GPT(config)
model.to(device="cpu")
model.init_weights()
opt = model.setup_optimizer()
seen = {}
for group in opt.param_groups:
for p in group["params"]:
pid = id(p)
assert pid not in seen, (
f"Param shape {p.shape} in multiple groups: "
f"kind={group['kind']} and kind={seen[pid]}"
)
seen[pid] = group["kind"]
def test_engram_embeds_not_in_muon(self):
"""Engram embedding table weights must be in AdamW, not Muon."""
config = _make_model()
model = GPT(config)
model.to(device="cpu")
model.init_weights()
opt = model.setup_optimizer()
engram_embed_ids = set()
for block in model.transformer.h:
if block.engram is not None:
for t in block.engram.embed_tables.values():
engram_embed_ids.add(id(t.weight))
muon_params = set()
for group in opt.param_groups:
if group.get("kind") == "muon":
for p in group["params"]:
muon_params.add(id(p))
overlap = engram_embed_ids & muon_params
assert not overlap, f"Engram embed params found in Muon groups: {len(overlap)}"
class TestMuonOnlyMatrixParams:
"""Muon optimizer groups must only contain 2D trainable parameters."""
def test_muon_groups_only_2d_trainable(self):
config = _make_model()
model = GPT(config)
model.to(device="cpu")
model.init_weights()
opt = model.setup_optimizer()
for group in opt.param_groups:
if group.get("kind") != "muon":
continue
for p in group["params"]:
assert p.dim() == 2, f"Muon got {p.dim()}D param (shape={p.shape})"
assert p.requires_grad, f"Muon got non-trainable param"
class TestVEGateNoDeadGrad:
"""On Engram layers, ve_gate must be None so it never has a dead gradient."""
def test_engram_layers_have_no_ve_gate(self):
config = _make_model(engram_layer_ids=(1, 2), n_layer=4)
model = GPT(config)
model.to(device="cpu")
model.init_weights()
for i in model.engram_layer_ids:
block = model.transformer.h[i]
assert block.attn.ve_gate is None, f"Layer {i} has ve_gate but is an Engram layer"
def test_all_trainable_params_get_gradients(self):
"""Every requires_grad=True param must receive a gradient after backward."""
config = _make_model()
model = GPT(config)
model.to(device="cpu")
model.init_weights()
idx = torch.randint(0, config.vocab_size, (1, 16))
loss = model(idx, targets=idx)
loss.backward()
dead_grads = []
for name, p in model.named_parameters():
if p.requires_grad and p.grad is None:
dead_grads.append(name)
assert not dead_grads, f"Params with no gradient: {dead_grads}"
class TestAutoLayerSelection:
"""Default engram_layer_ids should be (1, n_layer//2)."""
@pytest.mark.parametrize("n_layer,expected", [
(6, (1, 2)),
(12, (1, 5)),
(20, (1, 9)),
(30, (1, 14)),
])
def test_auto_layer_ids(self, n_layer, expected):
config = GPTConfig(n_layer=n_layer, n_head=4, n_kv_head=4, n_embd=256, engram_enabled=True)
model = GPT(config)
assert model.engram_layer_ids == expected
class TestKVCacheNgramState:
"""KVCache must support ngram_history for Engram decode."""
def test_ngram_history_initially_none(self):
from nanochat.engine import KVCache
cache = KVCache(batch_size=1, num_heads=4, seq_len=128, head_dim=64, num_layers=4, device="cpu", dtype=torch.float32)
assert cache.ngram_history is None
def test_reset_clears_ngram_history(self):
from nanochat.engine import KVCache
cache = KVCache(batch_size=1, num_heads=4, seq_len=128, head_dim=64, num_layers=4, device="cpu", dtype=torch.float32)
cache.ngram_history = torch.zeros(1, 2, dtype=torch.long)
cache.reset()
assert cache.ngram_history is None
def test_prefill_copies_ngram_history(self):
from nanochat.engine import KVCache
src = KVCache(batch_size=1, num_heads=4, seq_len=128, head_dim=64, num_layers=4, device="cpu", dtype=torch.float32)
src.ngram_history = torch.tensor([[1, 2]], dtype=torch.long)
dst = KVCache(batch_size=4, num_heads=4, seq_len=128, head_dim=64, num_layers=4, device="cpu", dtype=torch.float32)
dst.prefill(src)
assert dst.ngram_history is not None
assert dst.ngram_history.shape == (4, 2)
assert (dst.ngram_history == torch.tensor([[1, 2]])).all()
class TestNextPrime:
"""_next_prime helper must return primes."""
@pytest.mark.parametrize("n,expected", [
(1, 2),
(2, 2),
(3, 3),
(4, 5),
(10, 11),
(100, 101),
])
def test_next_prime(self, n, expected):
assert _next_prime(n) == expected
def test_next_prime_returns_prime(self):
result = _next_prime(99999)
# Verify it's actually prime
assert all(result % d for d in range(2, int(result**0.5) + 1))
class TestEngramInitWeightsIdentity:
"""Short conv must be zero-initialized so Engram starts as identity."""
def test_conv_weights_zero_after_init(self):
config = _make_model()
model = GPT(config)
model.to(device="cpu")
model.init_weights()
for i in model.engram_layer_ids:
conv = model.transformer.h[i].engram.short_conv
assert torch.all(conv.weight == 0), "Engram conv weights should be zero-initialized"
class TestEngramDecodeIntegration:
"""Full inference with KV cache + Engram: prefill then decode must produce consistent results."""
def _make_model(self):
config = GPTConfig(
sequence_len=64, vocab_size=32768,
n_layer=6, n_head=4, n_kv_head=4, n_embd=256,
engram_enabled=True, engram_ngram_size=3,
engram_n_heads=2, engram_embed_dim=64,
engram_layer_ids=(1, 2),
)
model = GPT(config)
model.to(device="cpu")
model.init_weights()
return model
def test_naive_vs_kv_cache_single_token(self):
"""Single-token decode with KV cache should match naive recomputation."""
from nanochat.engine import KVCache
model = self._make_model()
config = model.config
prompt = torch.randint(0, config.vocab_size, (1, 16))
# Naive: full sequence forward, take last token logits
with torch.no_grad():
logits_naive = model(prompt)[:, -1, :]
# KV cache: prefill then decode one token
m = config
kv_kwargs = {"num_heads": m.n_kv_head, "head_dim": m.n_embd // m.n_head, "num_layers": m.n_layer}
kv = KVCache(batch_size=1, seq_len=64, device=torch.device("cpu"), dtype=torch.float32, **kv_kwargs)
with torch.no_grad():
logits_prefill = model(prompt, kv_cache=kv)[:, -1, :]
# Prefill should match naive
assert torch.allclose(logits_naive, logits_prefill, atol=1e-4), "Prefill logits should match naive"
def test_prefill_copies_ngram_history_to_batch_cache(self):
"""After prefill, ngram_history should be available and copyable for batch decode."""
from nanochat.engine import KVCache
model = self._make_model()
config = model.config
prompt = torch.randint(0, config.vocab_size, (1, 8))
m = config
kv_kwargs = {"num_heads": m.n_kv_head, "head_dim": m.n_embd // m.n_head, "num_layers": m.n_layer}
kv_prefill = KVCache(batch_size=1, seq_len=64, device=torch.device("cpu"), dtype=torch.float32, **kv_kwargs)
with torch.no_grad():
model(prompt, kv_cache=kv_prefill)
assert kv_prefill.ngram_history is not None, "Prefill should set ngram_history"
max_hist = config.engram_ngram_size - 1
assert kv_prefill.ngram_history.shape == (1, max_hist)
# Copy to batch cache
kv_batch = KVCache(batch_size=3, seq_len=64, device=torch.device("cpu"), dtype=torch.float32, **kv_kwargs)
kv_batch.prefill(kv_prefill)
assert kv_batch.ngram_history.shape == (3, max_hist)