mirror of
https://github.com/karpathy/nanochat.git
synced 2026-05-25 17:18:01 +00:00
Add Engram conditional memory module
Integrates DeepSeek's Engram (N-gram hash lookup + context-aware gating + depthwise causal conv) as an optional module behind --engram CLI flag. Placed at two layers per paper ablation findings (layer 1 and n_layer//2-1). Coexists with existing Value Embeddings; disabled by default. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
0aaca56805
commit
dcd9b0668b
|
|
@ -102,11 +102,14 @@ class KVCache:
|
|||
self.cache_seqlens = torch.zeros(batch_size, dtype=torch.int32, device=device)
|
||||
# Previous token's normalized embedding for smear (set by model forward pass)
|
||||
self.prev_embedding = None
|
||||
# Engram N-gram history: (B, max_ngram-1) tensor of recent token IDs
|
||||
self.ngram_history = None
|
||||
|
||||
def reset(self):
|
||||
"""Reset cache to empty state."""
|
||||
self.cache_seqlens.zero_()
|
||||
self.prev_embedding = None
|
||||
self.ngram_history = None
|
||||
|
||||
def get_pos(self):
|
||||
"""Get current position (assumes all batch elements at same position)."""
|
||||
|
|
@ -135,6 +138,9 @@ class KVCache:
|
|||
# Copy smear state: expand batch=1 prev_embedding to num_samples
|
||||
if other.prev_embedding is not None:
|
||||
self.prev_embedding = other.prev_embedding.expand(self.batch_size, -1, -1).clone()
|
||||
# Copy Engram N-gram history: expand batch=1 to num_samples
|
||||
if other.ngram_history is not None:
|
||||
self.ngram_history = other.ngram_history.expand(self.batch_size, -1).clone()
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@torch.inference_mode()
|
||||
|
|
|
|||
232
nanochat/gpt.py
232
nanochat/gpt.py
|
|
@ -37,6 +37,14 @@ class GPTConfig:
|
|||
# Characters: L=long (full context), S=short (quarter context)
|
||||
# Examples: "L"=all full context, "SL"=alternating, "SSL"=two short then one long
|
||||
window_pattern: str = "SSSL"
|
||||
# Engram: conditional memory via N-gram hash lookup (DeepSeek Engram paper)
|
||||
engram_enabled: bool = False
|
||||
engram_ngram_size: int = 3 # max N-gram order (uses 2-gram .. N-gram)
|
||||
engram_n_heads: int = 4 # hash heads per N-gram order
|
||||
engram_embed_dim: int = 128 # embedding dim per hash head
|
||||
engram_table_size: int = 0 # hash table size (0=auto: ~vocab_size*3)
|
||||
engram_layer_ids: tuple = () # layers to place Engram (empty=auto: (1, max(2, n_layer//2 - 1)))
|
||||
engram_kernel_size: int = 4 # depthwise causal conv kernel size
|
||||
|
||||
|
||||
def norm(x):
|
||||
|
|
@ -140,17 +148,159 @@ class MLP(nn.Module):
|
|||
|
||||
|
||||
class Block(nn.Module):
|
||||
def __init__(self, config, layer_idx):
|
||||
def __init__(self, config, layer_idx, engram=None):
|
||||
super().__init__()
|
||||
self.attn = CausalSelfAttention(config, layer_idx)
|
||||
self.mlp = MLP(config)
|
||||
self.engram = engram
|
||||
|
||||
def forward(self, x, ve, cos_sin, window_size, kv_cache):
|
||||
def forward(self, x, input_ids, ve, cos_sin, window_size, kv_cache):
|
||||
if self.engram is not None:
|
||||
x = x + self.engram(x, input_ids, kv_cache)
|
||||
x = x + self.attn(norm(x), ve, cos_sin, window_size, kv_cache)
|
||||
x = x + self.mlp(norm(x))
|
||||
return x
|
||||
|
||||
|
||||
def _next_prime(n):
|
||||
"""Find the smallest prime >= n."""
|
||||
if n <= 2:
|
||||
return 2
|
||||
if n % 2 == 0:
|
||||
n += 1
|
||||
while True:
|
||||
if all(n % d for d in range(3, int(n**0.5) + 1, 2)):
|
||||
return n
|
||||
n += 2
|
||||
|
||||
|
||||
class EngramModule(nn.Module):
    """Conditional memory via N-gram hash lookup (DeepSeek Engram).

    For each N-gram order (2..max_ngram) and each hash head (1..n_heads),
    computes a polynomial hash over the last n tokens to index into
    prime-sized embedding tables. Retrieved embeddings are projected to
    key/value, gated by hidden state similarity, and refined via depthwise
    causal convolution.

    Known deviations from paper:
    - Token compression (Section 2.2, NFKC/lowercase normalization) is omitted.
      nanochat's 32K BPE vocab is already compact; the hash tables are prime-sized
      and larger than the vocab, so collision rates remain low.
    - Hash function uses polynomial rolling hash over the exact n-gram window
      rather than the demo's multiplicative-XOR accumulation.

    KV-cache behaviour:
    - ``kv_cache.ngram_history`` holds the last ``max_ngram - 1`` token ids and is
      shared by every Engram layer. Only the first Engram layer of a forward pass
      advances it; later layers reuse the same per-step context, so all layers
      hash identical windows.
    - Missing history is always zero-padded, matching the cache-free path.
    - NOTE(review): the depthwise conv keeps no state across decode steps, so with
      a KV cache and T == 1 its earlier taps see zero padding instead of past
      activations. Fixing that needs per-layer conv state in the cache.

    Paper: "Conditional Memory via Scalable Lookup: A New Axis of Sparsity for LLMs"
    """

    def __init__(self, config, layer_idx, padded_vocab_size):
        """Build the hash tables, projections and causal conv for one layer.

        Args:
            config: GPTConfig-like object providing the ``engram_*`` fields and ``n_embd``.
            layer_idx: index of the transformer layer hosting this module.
            padded_vocab_size: vocabulary size used to derive the default table size.

        Raises:
            ValueError: if ``engram_ngram_size < 2`` or ``engram_n_heads < 1``
                (either would leave the module with no table to look up).
        """
        super().__init__()
        if config.engram_ngram_size < 2:
            raise ValueError(f"engram_ngram_size must be >= 2, got {config.engram_ngram_size}")
        if config.engram_n_heads < 1:
            raise ValueError(f"engram_n_heads must be >= 1, got {config.engram_n_heads}")
        self.layer_idx = layer_idx
        self.max_ngram = config.engram_ngram_size
        self.vocab_size = padded_vocab_size
        self.n_heads = config.engram_n_heads
        self.embed_dim = config.engram_embed_dim
        self.n_embd = config.n_embd
        n_gram_orders = self.max_ngram - 1  # e.g. max_ngram=3 -> orders 2,3
        self.d_mem = n_gram_orders * self.n_heads * self.embed_dim

        base_table_size = config.engram_table_size if config.engram_table_size > 0 else padded_vocab_size * 3
        self.embed_tables = nn.ModuleDict()
        # Kept as 0-dim params so existing checkpoints / optimizer bookkeeping are unchanged.
        # They are NOT read at runtime: a tensor created under the meta device loses its
        # value, and init_weights never refills it, so the Python-side copies below are
        # the source of truth. Note: a checkpoint trained with corrupted seed params
        # will hash differently under this deterministic scheme.
        self.hash_seeds = nn.ParameterDict()
        self.table_sizes = {}
        self._hash_powers = {}  # key -> [seed^(n-j) mod table_size for j in range(n)]
        for n in range(2, self.max_ngram + 1):
            for k in range(self.n_heads):
                prime_size = _next_prime(base_table_size + n * k * 997)
                key = f"{n}_{k}"
                self.embed_tables[key] = nn.Embedding(prime_size, self.embed_dim)
                seed = n * 2654435761 + k * 40503  # multiplicative hash constants
                self.hash_seeds[key] = nn.Parameter(torch.tensor(seed, dtype=torch.long), requires_grad=False)
                self.table_sizes[key] = prime_size
                self._hash_powers[key] = self._power_list(seed, n, prime_size)

        self.key_proj = Linear(self.d_mem, config.n_embd, bias=False)
        self.value_proj = Linear(self.d_mem, config.n_embd, bias=False)
        dilation = config.engram_ngram_size  # paper Section 2.3: dilation = max N-gram order
        self.short_conv = nn.Conv1d(
            config.n_embd, config.n_embd,
            kernel_size=config.engram_kernel_size,
            # Symmetric padding; forward keeps only the first T outputs, which makes it causal.
            padding=dilation * (config.engram_kernel_size - 1),
            dilation=dilation,
            groups=config.n_embd, bias=False,
        )

    @staticmethod
    def _power_list(seed, n, table_size):
        """Return [seed^n, seed^(n-1), ..., seed^1] reduced mod table_size (Python ints)."""
        return [pow(seed, n - j, table_size) for j in range(n)]

    @staticmethod
    def _hash_padded(padded, n, powers, table_size):
        """Hash every length-n window of ``padded`` (B, T + n - 1) into (B, T) table indices."""
        windows = padded.unfold(dimension=1, size=n, step=1)  # (B, T, n), oldest token first
        weights = torch.tensor(powers, dtype=torch.long, device=padded.device)
        return (windows * weights).sum(dim=-1) % table_size

    def _ngram_hash(self, input_ids, n, seed, table_size):
        """Vectorized polynomial hash over the last n tokens only.

        For order n, the hash at position t is:
            h(t) = (ids[t] * seed^1 + ids[t-1] * seed^2 + ... + ids[t-n+1] * seed^n) % table_size
        Positions with fewer than n - 1 preceding tokens use a zero-padded window.

        Args:
            input_ids: (B, T) integer token ids.
            n: N-gram order.
            seed: integer hash base.
            table_size: modulus (size of the embedding table).

        Returns:
            (B, T) long tensor of table indices.
        """
        padded = F.pad(input_ids.long(), (n - 1, 0), value=0)  # (B, T + n - 1)
        return self._hash_padded(padded, n, self._power_list(seed, n, table_size), table_size)

    def _step_context(self, input_ids, kv_cache):
        """Return (B, max_ngram - 1 + T) ids: preceding history followed by the current tokens.

        Without a cache the history is zero padding. With a cache, the first Engram
        layer of a forward pass builds the context from ``kv_cache.ngram_history``
        and then advances the history; later Engram layers in the same pass reuse
        that context. Layers run in increasing index order, so a call from a layer
        index that is not greater than the last one seen marks a new pass.
        """
        ids = input_ids.long()
        max_hist = self.max_ngram - 1
        if kv_cache is None:
            return F.pad(ids, (max_hist, 0), value=0)

        # Scratch attributes added to the cache object on demand.
        ctx = getattr(kv_cache, "_engram_ctx", None)
        last_layer = getattr(kv_cache, "_engram_last_layer", None)
        reuse = (
            ctx is not None
            and last_layer is not None
            and self.layer_idx > last_layer
            and kv_cache.ngram_history is not None  # None after reset(): never reuse
            and ctx.shape == (ids.shape[0], max_hist + ids.shape[1])
        )
        if not reuse:
            history = kv_cache.ngram_history
            if history is None:
                ctx = F.pad(ids, (max_hist, 0), value=0)
            else:
                hist = history[:, -max_hist:].to(device=ids.device, dtype=torch.long)
                missing = max_hist - hist.shape[1]
                if missing > 0:
                    hist = F.pad(hist, (missing, 0), value=0)
                ctx = torch.cat([hist, ids], dim=1)
            kv_cache._engram_ctx = ctx
            # History for the NEXT step: the most recent max_hist tokens seen so far.
            kv_cache.ngram_history = ctx[:, -max_hist:].clone()
        kv_cache._engram_last_layer = self.layer_idx
        return ctx

    def forward(self, x, input_ids, kv_cache=None):
        """Look up N-gram memories for ``input_ids`` and return the gated update for ``x``.

        Args:
            x: (B, T, D) hidden states.
            input_ids: (B, T) token ids aligned with ``x``.
            kv_cache: optional inference cache carrying ``ngram_history``.

        Returns:
            (B, T, D) tensor; the caller adds it to the residual stream.
        """
        B, T, D = x.shape
        max_hist = self.max_ngram - 1
        ctx = self._step_context(input_ids, kv_cache)

        embeds = []
        for n in range(2, self.max_ngram + 1):
            # Keep exactly n - 1 tokens of left context for this order.
            padded = ctx[:, max_hist - (n - 1):]
            for k in range(self.n_heads):
                key = f"{n}_{k}"
                hash_idx = self._hash_padded(padded, n, self._hash_powers[key], self.table_sizes[key])
                embeds.append(self.embed_tables[key](hash_idx))
        e = torch.cat(embeds, dim=-1)  # (B, T, d_mem)

        mem_key = norm(self.key_proj(e))
        mem_value = self.value_proj(e)
        # Scalar gate per position: similarity between the hidden state and the retrieved key.
        gate = torch.sigmoid((norm(x) * mem_key).sum(dim=-1, keepdim=True) / (D ** 0.5))
        gated_v = gate * mem_value
        normed_v = norm(gated_v)
        conv = self.short_conv
        conv_in = normed_v.transpose(1, 2)
        # Same op as self.short_conv(conv_in); the weight is cast to the activation dtype
        # (a no-op when they already match) in case the tables run in a reduced-precision
        # compute dtype while the conv weight is not cast -- TODO confirm against Linear.
        conv_full = F.conv1d(
            conv_in, conv.weight.to(conv_in.dtype), None,
            conv.stride, conv.padding, conv.dilation, conv.groups,
        )
        conv_out = conv_full[:, :, :T].transpose(1, 2)  # drop the look-ahead tail -> causal
        return F.silu(conv_out) + gated_v
|
||||
|
||||
|
||||
class GPT(nn.Module):
|
||||
def __init__(self, config, pad_vocab_size_to=64):
|
||||
"""
|
||||
|
|
@ -184,10 +334,28 @@ class GPT(nn.Module):
|
|||
self.smear_lambda = nn.Parameter(torch.zeros(1))
|
||||
# Backout: subtract cached mid-layer residual before final norm to remove low-level features
|
||||
self.backout_lambda = nn.Parameter(0.2 * torch.ones(1))
|
||||
# Engram: conditional memory via N-gram hash lookup
|
||||
# Resolve engram_layer_ids (including auto-default) before creating VE, so VE excludes Engram layers
|
||||
if config.engram_enabled:
|
||||
engram_layer_ids = config.engram_layer_ids if config.engram_layer_ids else (1, max(2, config.n_layer // 2 - 1))
|
||||
self.engram_layer_ids = engram_layer_ids
|
||||
else:
|
||||
engram_layer_ids = ()
|
||||
self.engram_layer_ids = ()
|
||||
# Value embeddings (ResFormer-style): alternating layers, last layer always included
|
||||
# Skip VE for layers that will use Engram instead
|
||||
engram_layers = set(engram_layer_ids)
|
||||
head_dim = config.n_embd // config.n_head
|
||||
kv_dim = config.n_kv_head * head_dim
|
||||
self.value_embeds = nn.ModuleDict({str(i): nn.Embedding(padded_vocab_size, kv_dim) for i in range(config.n_layer) if has_ve(i, config.n_layer)})
|
||||
self.value_embeds = nn.ModuleDict({str(i): nn.Embedding(padded_vocab_size, kv_dim) for i in range(config.n_layer) if has_ve(i, config.n_layer) and i not in engram_layers})
|
||||
# Create Engram modules on the designated layers
|
||||
if config.engram_enabled:
|
||||
for i in engram_layer_ids:
|
||||
self.transformer.h[i] = Block(config, i, engram=EngramModule(config, i, padded_vocab_size))
|
||||
# Disable VE gate in Engram layers (ve is set to None, gate would have no gradient)
|
||||
self.transformer.h[i].attn.ve_gate = None
|
||||
else:
|
||||
self.engram_layer_ids = ()
|
||||
# To support meta device initialization, we init the rotary embeddings here, but it's just "fake" meta tensors only.
|
||||
# As for rotary_seq_len, these rotary embeddings are pretty small/cheap in memory,
|
||||
# so let's just over-compute them by 10X, but assert fail if we ever reach that amount.
|
||||
|
|
@ -264,6 +432,20 @@ class GPT(nn.Module):
|
|||
self.transformer.wte.to(dtype=COMPUTE_DTYPE)
|
||||
for ve in self.value_embeds.values():
|
||||
ve.to(dtype=COMPUTE_DTYPE)
|
||||
for block in self.transformer.h:
|
||||
if block.engram is not None:
|
||||
for t in block.engram.embed_tables.values():
|
||||
t.to(dtype=COMPUTE_DTYPE)
|
||||
|
||||
# Engram: init embeddings (uniform), projections (uniform), conv (zeros for identity)
|
||||
for block in self.transformer.h:
|
||||
if block.engram is not None:
|
||||
em = block.engram
|
||||
for t in em.embed_tables.values():
|
||||
torch.nn.init.uniform_(t.weight, -s, s)
|
||||
torch.nn.init.uniform_(em.key_proj.weight, -s, s)
|
||||
torch.nn.init.uniform_(em.value_proj.weight, -s, s)
|
||||
torch.nn.init.zeros_(em.short_conv.weight)
|
||||
|
||||
def _precompute_rotary_embeddings(self, seq_len, head_dim, base=100000, device=None):
|
||||
# TODO: bump base theta more? e.g. 100K is more common more recently
|
||||
|
|
@ -329,7 +511,12 @@ class GPT(nn.Module):
|
|||
nparams = sum(p.numel() for p in self.parameters())
|
||||
# Exclude non-matmul params: embeddings and per-layer scalars
|
||||
value_embeds_numel = sum(ve.weight.numel() for ve in self.value_embeds.values())
|
||||
nparams_exclude = (self.transformer.wte.weight.numel() + value_embeds_numel +
|
||||
engram_embed_numel = 0
|
||||
for block in self.transformer.h:
|
||||
if block.engram is not None:
|
||||
for t in block.engram.embed_tables.values():
|
||||
engram_embed_numel += t.weight.numel()
|
||||
nparams_exclude = (self.transformer.wte.weight.numel() + value_embeds_numel + engram_embed_numel +
|
||||
self.resid_lambdas.numel() + self.x0_lambdas.numel() +
|
||||
self.smear_gate.weight.numel() + self.smear_lambda.numel() + self.backout_lambda.numel())
|
||||
h, q, t = self.config.n_head, self.config.n_embd // self.config.n_head, self.config.sequence_len
|
||||
|
|
@ -362,12 +549,18 @@ class GPT(nn.Module):
|
|||
scalars = self.resid_lambdas.numel() + self.x0_lambdas.numel() + self.smear_gate.weight.numel() + self.smear_lambda.numel() + self.backout_lambda.numel()
|
||||
total = wte + value_embeds + lm_head + transformer_matrices + scalars
|
||||
assert total == sum(p.numel() for p in self.parameters()), "Parameter count mismatch"
|
||||
# Break down engram params for reporting
|
||||
engram_params = 0
|
||||
for block in self.transformer.h:
|
||||
if block.engram is not None:
|
||||
engram_params += sum(p.numel() for p in block.engram.parameters())
|
||||
return {
|
||||
'wte': wte,
|
||||
'value_embeds': value_embeds,
|
||||
'lm_head': lm_head,
|
||||
'transformer_matrices': transformer_matrices,
|
||||
'scalars': scalars,
|
||||
'engram': engram_params,
|
||||
'total': total,
|
||||
}
|
||||
|
||||
|
|
@ -383,6 +576,17 @@ class GPT(nn.Module):
|
|||
resid_params = [self.resid_lambdas]
|
||||
x0_params = [self.x0_lambdas]
|
||||
smear_params = [self.smear_gate.weight, self.smear_lambda, self.backout_lambda]
|
||||
# Engram: extract embed weights (-> AdamW) and matrix weights (already in matrix_params)
|
||||
engram_embed_params = []
|
||||
engram_matrix_params_set = set()
|
||||
for block in self.transformer.h:
|
||||
if block.engram is not None:
|
||||
em = block.engram
|
||||
for t in em.embed_tables.values():
|
||||
engram_embed_params.append(t.weight)
|
||||
engram_matrix_params_set.add(id(em.key_proj.weight))
|
||||
engram_matrix_params_set.add(id(em.value_proj.weight))
|
||||
engram_matrix_params_set.add(id(em.short_conv.weight))
|
||||
assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params) + len(smear_params)
|
||||
|
||||
# Scale the LR for the AdamW parameters by ∝1/√dmodel (tuned for 768 dim model)
|
||||
|
|
@ -399,9 +603,19 @@ class GPT(nn.Module):
|
|||
dict(kind='adamw', params=x0_params, lr=scalar_lr, betas=(0.96, 0.95), eps=1e-10, weight_decay=0.0), # higher beta1 for x0
|
||||
dict(kind='adamw', params=smear_params, lr=0.2, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0),
|
||||
]
|
||||
# Muon groups (matrix params, grouped by shape for stacking)
|
||||
for shape in sorted({p.shape for p in matrix_params}):
|
||||
group_params = [p for p in matrix_params if p.shape == shape]
|
||||
# Engram embedding tables: AdamW with 5x lr and no weight decay (per paper)
|
||||
# Remove engram embeds and non-2D params from matrix_params so they don't end up in Muon groups
|
||||
engram_embed_ids = {id(p) for p in engram_embed_params}
|
||||
muon_params = [p for p in matrix_params if id(p) not in engram_embed_ids and p.dim() == 2 and p.requires_grad]
|
||||
if engram_embed_params:
|
||||
param_groups.append(dict(kind='adamw', params=engram_embed_params, lr=embedding_lr * dmodel_lr_scale * 5, betas=(0.8, 0.995), eps=1e-10, weight_decay=0.0))
|
||||
# Non-2D engram params (conv weights, hash seeds) go to AdamW
|
||||
non_matrix_engram = [p for p in matrix_params if id(p) not in engram_embed_ids and id(p) not in {id(mp) for mp in muon_params}]
|
||||
if non_matrix_engram:
|
||||
param_groups.append(dict(kind='adamw', params=non_matrix_engram, lr=embedding_lr * dmodel_lr_scale, betas=(0.8, 0.995), eps=1e-10, weight_decay=0.0))
|
||||
# Muon groups (2D matrix params from blocks, grouped by shape for stacking)
|
||||
for shape in sorted({p.shape for p in muon_params}):
|
||||
group_params = [p for p in muon_params if p.shape == shape]
|
||||
param_groups.append(dict(
|
||||
kind='muon', params=group_params, lr=matrix_lr,
|
||||
momentum=0.95, ns_steps=5, beta2=0.9, weight_decay=weight_decay,
|
||||
|
|
@ -455,8 +669,8 @@ class GPT(nn.Module):
|
|||
x_backout = None
|
||||
for i, block in enumerate(self.transformer.h):
|
||||
x = self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0
|
||||
ve = self.value_embeds[str(i)](idx).to(x.dtype) if str(i) in self.value_embeds else None
|
||||
x = block(x, ve, cos_sin, self.window_sizes[i], kv_cache)
|
||||
ve = self.value_embeds[str(i)](idx).to(x.dtype) if (str(i) in self.value_embeds and i not in self.engram_layer_ids) else None
|
||||
x = block(x, idx, ve, cos_sin, self.window_sizes[i], kv_cache)
|
||||
if i == backout_layer:
|
||||
x_backout = x
|
||||
# Subtract mid-layer residual to remove low-level features before logit projection
|
||||
|
|
|
|||
|
|
@ -52,6 +52,14 @@ parser.add_argument("--aspect-ratio", type=int, default=64, help="model_dim = de
|
|||
parser.add_argument("--head-dim", type=int, default=128, help="target head dimension for attention")
|
||||
parser.add_argument("--max-seq-len", type=int, default=2048, help="max context length")
|
||||
parser.add_argument("--window-pattern", type=str, default="SSSL", help="sliding window pattern tiled across layers: L=full, S=half context (e.g. 'SSL')")
|
||||
# Engram: conditional memory via N-gram hash lookup
|
||||
parser.add_argument("--engram", action="store_true", help="enable Engram conditional memory")
|
||||
parser.add_argument("--engram-ngram-size", type=int, default=3, help="max N-gram order (2+3=bigram+trigram)")
|
||||
parser.add_argument("--engram-n-heads", type=int, default=4, help="hash heads per N-gram order")
|
||||
parser.add_argument("--engram-embed-dim", type=int, default=128, help="embedding dim per hash head")
|
||||
parser.add_argument("--engram-table-size", type=int, default=0, help="hash table size (0=auto)")
|
||||
parser.add_argument("--engram-layer-ids", type=str, default="", help="comma-separated layer IDs for Engram (empty=auto)")
|
||||
parser.add_argument("--engram-kernel-size", type=int, default=4, help="depthwise causal conv kernel size")
|
||||
# Training horizon (only one used, in order of precedence)
|
||||
parser.add_argument("--num-iterations", type=int, default=-1, help="explicit number of optimization steps (-1 = disable)")
|
||||
parser.add_argument("--target-flops", type=float, default=-1.0, help="calculate num_iterations to reach target_flops (-1 = disable)")
|
||||
|
|
@ -133,10 +141,15 @@ def build_model_meta(depth):
|
|||
base_dim = depth * args.aspect_ratio
|
||||
model_dim = ((base_dim + args.head_dim - 1) // args.head_dim) * args.head_dim
|
||||
num_heads = model_dim // args.head_dim
|
||||
engram_layer_ids = tuple(int(x) for x in args.engram_layer_ids.split(",")) if args.engram_layer_ids else ()
|
||||
config = GPTConfig(
|
||||
sequence_len=args.max_seq_len, vocab_size=vocab_size,
|
||||
n_layer=depth, n_head=num_heads, n_kv_head=num_heads, n_embd=model_dim,
|
||||
window_pattern=args.window_pattern,
|
||||
engram_enabled=args.engram, engram_ngram_size=args.engram_ngram_size,
|
||||
engram_n_heads=args.engram_n_heads, engram_embed_dim=args.engram_embed_dim,
|
||||
engram_table_size=args.engram_table_size, engram_layer_ids=engram_layer_ids,
|
||||
engram_kernel_size=args.engram_kernel_size,
|
||||
)
|
||||
with torch.device("meta"):
|
||||
model_meta = GPT(config)
|
||||
|
|
|
|||
336
tests/test_engram.py
Normal file
336
tests/test_engram.py
Normal file
|
|
@ -0,0 +1,336 @@
|
|||
"""
|
||||
Regression tests for Engram integration.
|
||||
|
||||
These tests guard against bugs found during implementation:
|
||||
|
||||
1. Stale Block.forward left inside EngramModule scope overrode the correct method
|
||||
2. Parameter double-counting when Engram modules stored in both engram_modules and blocks
|
||||
3. Engram embed params appearing in both Muon and AdamW optimizer groups
|
||||
4. Non-2D params (hash_seeds, conv weights) crashing Muon's shape-based stacking
|
||||
5. ve_gate on Engram layers getting no gradient (ve=None → gate unused → no grad)
|
||||
6. Block.forward signature must accept input_ids for Engram layers
|
||||
7. KVCache must support ngram_history state for decode
|
||||
|
||||
Run: python -m pytest tests/test_engram.py -v
|
||||
"""
|
||||
|
||||
import torch
|
||||
import pytest
|
||||
from nanochat.gpt import GPT, GPTConfig, Block, EngramModule, _next_prime
|
||||
|
||||
|
||||
def _make_model(engram_layer_ids=(1, 2), n_layer=4):
    """Build the small Engram-enabled GPTConfig shared by the tests below.

    Despite the name this returns a config, not a model; callers pass it to GPT.
    """
    settings = dict(
        sequence_len=128,
        vocab_size=32768,
        n_layer=n_layer,
        n_head=4,
        n_kv_head=4,
        n_embd=256,
        engram_enabled=True,
        engram_ngram_size=3,
        engram_n_heads=2,
        engram_embed_dim=64,
        engram_layer_ids=engram_layer_ids,
    )
    return GPTConfig(**settings)
|
||||
|
||||
|
||||
class TestEngramForwardSignature:
    """Block.forward has to take input_ids so Engram layers can hash tokens."""

    @staticmethod
    def _random_batch_loss(model, vocab_size):
        """Run one forward pass on a random (1, 16) batch, using it as its own target."""
        tokens = torch.randint(0, vocab_size, (1, 16))
        return model(tokens, targets=tokens)

    def test_block_forward_accepts_input_ids(self):
        cfg = _make_model()
        gpt = GPT(cfg)
        gpt.to(device="cpu")
        gpt.init_weights()

        # A TypeError about missing arguments here means the signature regressed.
        loss = self._random_batch_loss(gpt, cfg.vocab_size)
        assert loss.item() > 0

    def test_block_without_engram_ignores_input_ids(self):
        """Blocks built without an Engram module (engram=None) keep working."""
        cfg = _make_model(engram_layer_ids=(1,), n_layer=4)
        gpt = GPT(cfg)
        gpt.to(device="cpu")
        gpt.init_weights()
        first_block = gpt.transformer.h[0]
        assert first_block.engram is None

        loss = self._random_batch_loss(gpt, cfg.vocab_size)
        assert loss.item() > 0
|
||||
|
||||
|
||||
class TestEngramNoStaleForward:
    """EngramModule.forward must expose (x, input_ids, kv_cache).

    Guards against a leftover copy of the old Block.forward
    (x, ve, cos_sin, window_size, kv_cache) ending up inside EngramModule
    through an indentation slip and shadowing the real method.
    """

    def test_engram_forward_signature(self):
        import inspect

        param_names = list(inspect.signature(EngramModule.forward).parameters)
        assert "input_ids" in param_names, f"EngramModule.forward missing input_ids, got {param_names}"
        assert "ve" not in param_names, f"EngramModule.forward should not have 've' (stale Block.forward), got {param_names}"

    def test_engram_forward_runs(self):
        cfg = _make_model()
        module = EngramModule(cfg, layer_idx=1, padded_vocab_size=32768)
        hidden = torch.randn(1, 16, cfg.n_embd)
        token_ids = torch.randint(0, 32768, (1, 16))
        result = module(hidden, token_ids)
        assert result.shape == hidden.shape
|
||||
|
||||
|
||||
class TestParameterCounting:
    """The num_scaling_params breakdown must agree with the real parameter count."""

    @staticmethod
    def _scaling_params(config):
        """Instantiate GPT on CPU, initialise it, and return (model, breakdown dict)."""
        gpt = GPT(config)
        gpt.to(device="cpu")
        gpt.init_weights()
        return gpt, gpt.num_scaling_params()

    def test_param_count_matches(self):
        gpt, params = self._scaling_params(_make_model())
        actual = sum(p.numel() for p in gpt.parameters())
        assert params["total"] == actual, f"Mismatch: {params['total']} != {actual}"

    def test_engram_params_nonzero(self):
        _, params = self._scaling_params(_make_model())
        assert params["engram"] > 0

    def test_no_engram_params_when_disabled(self):
        plain_cfg = GPTConfig(n_layer=4, n_head=4, n_kv_head=4, n_embd=256, engram_enabled=False)
        _, params = self._scaling_params(plain_cfg)
        assert params["engram"] == 0
|
||||
|
||||
|
||||
class TestOptimizerNoDuplicates:
    """No parameter may be registered in more than one optimizer group."""

    @staticmethod
    def _model_and_optimizer():
        """Return an initialised CPU model together with its optimizer."""
        gpt = GPT(_make_model())
        gpt.to(device="cpu")
        gpt.init_weights()
        return gpt, gpt.setup_optimizer()

    def test_no_duplicate_param_groups(self):
        _, opt = self._model_and_optimizer()

        kind_by_param = {}  # id(param) -> kind of the group that first claimed it
        for group in opt.param_groups:
            for p in group["params"]:
                key = id(p)
                assert key not in kind_by_param, (
                    f"Param shape {p.shape} in multiple groups: "
                    f"kind={group['kind']} and kind={kind_by_param[key]}"
                )
                kind_by_param[key] = group["kind"]

    def test_engram_embeds_not_in_muon(self):
        """Engram table weights belong to AdamW groups and never to Muon."""
        gpt, opt = self._model_and_optimizer()

        engram_embed_ids = {
            id(table.weight)
            for block in gpt.transformer.h
            if block.engram is not None
            for table in block.engram.embed_tables.values()
        }
        muon_params = {
            id(p)
            for group in opt.param_groups
            if group.get("kind") == "muon"
            for p in group["params"]
        }

        overlap = engram_embed_ids & muon_params
        assert not overlap, f"Engram embed params found in Muon groups: {len(overlap)}"
|
||||
|
||||
|
||||
class TestMuonOnlyMatrixParams:
    """Every parameter handed to a Muon group must be 2D and trainable."""

    def test_muon_groups_only_2d_trainable(self):
        gpt = GPT(_make_model())
        gpt.to(device="cpu")
        gpt.init_weights()
        opt = gpt.setup_optimizer()

        muon_groups = [g for g in opt.param_groups if g.get("kind") == "muon"]
        for group in muon_groups:
            for p in group["params"]:
                assert p.dim() == 2, f"Muon got {p.dim()}D param (shape={p.shape})"
                assert p.requires_grad, f"Muon got non-trainable param"
|
||||
|
||||
|
||||
class TestVEGateNoDeadGrad:
    """Engram layers receive no value embedding, so their ve_gate must be None.

    Otherwise the gate would be a trainable parameter that never gets a gradient.
    """

    @staticmethod
    def _cpu_model(config):
        """Instantiate and initialise GPT on CPU for the given config."""
        gpt = GPT(config)
        gpt.to(device="cpu")
        gpt.init_weights()
        return gpt

    def test_engram_layers_have_no_ve_gate(self):
        gpt = self._cpu_model(_make_model(engram_layer_ids=(1, 2), n_layer=4))

        for i in gpt.engram_layer_ids:
            attn = gpt.transformer.h[i].attn
            assert attn.ve_gate is None, f"Layer {i} has ve_gate but is an Engram layer"

    def test_all_trainable_params_get_gradients(self):
        """After one backward pass every requires_grad=True parameter has a gradient."""
        cfg = _make_model()
        gpt = self._cpu_model(cfg)

        tokens = torch.randint(0, cfg.vocab_size, (1, 16))
        gpt(tokens, targets=tokens).backward()

        dead_grads = [
            name
            for name, p in gpt.named_parameters()
            if p.requires_grad and p.grad is None
        ]
        assert not dead_grads, f"Params with no gradient: {dead_grads}"
|
||||
|
||||
|
||||
class TestAutoLayerSelection:
    """Default engram_layer_ids should be (1, max(2, n_layer // 2 - 1)).

    The second layer is clamped to at least 2 so that it never coincides with
    the first Engram layer on shallow models.
    """

    @pytest.mark.parametrize("n_layer,expected", [
        (4, (1, 2)),   # n_layer // 2 - 1 == 1 -> clamped up to 2
        (6, (1, 2)),
        (12, (1, 5)),
        (20, (1, 9)),
        (30, (1, 14)),
    ])
    def test_auto_layer_ids(self, n_layer, expected):
        config = GPTConfig(n_layer=n_layer, n_head=4, n_kv_head=4, n_embd=256, engram_enabled=True)
        model = GPT(config)
        assert model.engram_layer_ids == expected
|
||||
|
||||
|
||||
class TestKVCacheNgramState:
    """KVCache carries an ngram_history slot that Engram relies on during decode."""

    @staticmethod
    def _new_cache(batch_size):
        """Create a small CPU KVCache with the given batch size."""
        from nanochat.engine import KVCache

        return KVCache(
            batch_size=batch_size, num_heads=4, seq_len=128, head_dim=64,
            num_layers=4, device="cpu", dtype=torch.float32,
        )

    def test_ngram_history_initially_none(self):
        fresh = self._new_cache(1)
        assert fresh.ngram_history is None

    def test_reset_clears_ngram_history(self):
        cache = self._new_cache(1)
        cache.ngram_history = torch.zeros(1, 2, dtype=torch.long)
        cache.reset()
        assert cache.ngram_history is None

    def test_prefill_copies_ngram_history(self):
        # A batch-1 source must be broadcast to every row of the destination cache.
        src = self._new_cache(1)
        src.ngram_history = torch.tensor([[1, 2]], dtype=torch.long)
        dst = self._new_cache(4)
        dst.prefill(src)
        copied = dst.ngram_history
        assert copied is not None
        assert copied.shape == (4, 2)
        assert (copied == torch.tensor([[1, 2]])).all()
|
||||
|
||||
|
||||
class TestNextPrime:
    """_next_prime returns the first prime at or above its argument."""

    @pytest.mark.parametrize("n,expected", [
        (1, 2),
        (2, 2),
        (3, 3),
        (4, 5),
        (10, 11),
        (100, 101),
    ])
    def test_next_prime(self, n, expected):
        assert _next_prime(n) == expected

    def test_next_prime_returns_prime(self):
        candidate = _next_prime(99999)
        # Independent primality check: no divisor in [2, sqrt(candidate)].
        upper = int(candidate**0.5) + 1
        has_divisor = any(candidate % d == 0 for d in range(2, upper))
        assert not has_divisor
|
||||
|
||||
|
||||
class TestEngramInitWeightsIdentity:
    """init_weights must leave the Engram short conv at all-zero weights.

    With zero weights the convolution branch contributes nothing at the start
    of training, so only the gated value path is active.
    """

    def test_conv_weights_zero_after_init(self):
        gpt = GPT(_make_model())
        gpt.to(device="cpu")
        gpt.init_weights()

        for layer_id in gpt.engram_layer_ids:
            weight = gpt.transformer.h[layer_id].engram.short_conv.weight
            assert torch.all(weight == 0), "Engram conv weights should be zero-initialized"
|
||||
|
||||
|
||||
class TestEngramDecodeIntegration:
    """Inference with KV cache + Engram: prefill consistency and history hand-off.

    NOTE(review): despite the class name, no test here runs a single-token decode
    step after prefill, so multi-layer decode consistency is not covered yet.
    """

    def _make_model(self):
        """Return an initialised 6-layer CPU model with Engram on layers 1 and 2."""
        config = GPTConfig(
            sequence_len=64, vocab_size=32768,
            n_layer=6, n_head=4, n_kv_head=4, n_embd=256,
            engram_enabled=True, engram_ngram_size=3,
            engram_n_heads=2, engram_embed_dim=64,
            engram_layer_ids=(1, 2),
        )
        model = GPT(config)
        model.to(device="cpu")
        model.init_weights()
        return model

    @staticmethod
    def _new_cache(config, batch_size):
        """Create a CPU float32 KVCache sized for ``config`` with room for 64 tokens."""
        from nanochat.engine import KVCache

        return KVCache(
            batch_size=batch_size, seq_len=64,
            device=torch.device("cpu"), dtype=torch.float32,
            num_heads=config.n_kv_head,
            head_dim=config.n_embd // config.n_head,
            num_layers=config.n_layer,
        )

    def test_naive_vs_kv_cache_single_token(self):
        """Last-token logits from a KV-cache prefill should match a cache-free forward.

        The name is historical: only the prefill pass is compared, no decode step runs.
        """
        model = self._make_model()
        config = model.config
        prompt = torch.randint(0, config.vocab_size, (1, 16))

        # Naive: full sequence forward, take last token logits
        with torch.no_grad():
            logits_naive = model(prompt)[:, -1, :]

        # KV cache: run the same prompt through the cache-backed path
        kv = self._new_cache(config, batch_size=1)
        with torch.no_grad():
            logits_prefill = model(prompt, kv_cache=kv)[:, -1, :]

        # Prefill should match naive
        assert torch.allclose(logits_naive, logits_prefill, atol=1e-4), "Prefill logits should match naive"

    def test_prefill_copies_ngram_history_to_batch_cache(self):
        """After prefill, ngram_history should be available and copyable for batch decode."""
        model = self._make_model()
        config = model.config
        prompt = torch.randint(0, config.vocab_size, (1, 8))

        kv_prefill = self._new_cache(config, batch_size=1)
        with torch.no_grad():
            model(prompt, kv_cache=kv_prefill)

        assert kv_prefill.ngram_history is not None, "Prefill should set ngram_history"
        max_hist = config.engram_ngram_size - 1
        assert kv_prefill.ngram_history.shape == (1, max_hist)

        # Copy to batch cache
        kv_batch = self._new_cache(config, batch_size=3)
        kv_batch.prefill(kv_prefill)
        assert kv_batch.ngram_history.shape == (3, max_hist)
|
||||
Loading…
Reference in New Issue
Block a user