mirror of
https://github.com/karpathy/nanochat.git
synced 2026-05-25 17:18:01 +00:00
Fix torch.compile graph breaks with Engram module
Refactor Engram data structures for torch.compile compatibility:
- Replace nn.ModuleDict (string-keyed) with nested nn.ModuleList (integer-indexed)
- Precompute hash powers in __init__ and re-materialize them in init_weights to eliminate torch.tensor() construction during forward
- Remove the hash_seeds and table_sizes dicts (their info is now embedded in _powers / embed_tables)
- Add _patch_removed_state_keys for backward compatibility with old Engram checkpoints

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
dcd9b0668b
commit
4011f5e3d5
|
|
@ -39,6 +39,32 @@ def _patch_missing_keys(model_data, model_config):
|
|||
model_data["x0_lambdas"] = torch.zeros(n_layer)
|
||||
log0(f"Patching missing x0_lambdas in model data to 0.0")
|
||||
|
||||
def _patch_removed_state_keys(model_data):
    """Drop or rename checkpoint entries left over from the old Engram layout.

    ``model_data`` is modified in place:

    * every key that starts with ``transformer.h.`` and contains
      ``engram.hash_seeds`` is deleted;
    * every key of the form
      ``transformer.h.<i>.engram.embed_tables.<n>_<k>.<rest>`` is moved to
      ``transformer.h.<i>.engram.embed_tables.<n - 2>.<k>.<rest>``.

    One summary line is emitted through ``log0`` for each kind of change
    that actually occurred. Nothing is returned.
    """
    import re

    # Old string-keyed table entry: "<prefix>.<n>_<k>.<rest>"
    old_table_key = re.compile(r"(transformer\.h\.\d+\.engram\.embed_tables)\.(\d+)_(\d+)\.(.*)")

    stale_keys = []
    renames = {}
    # Scan a snapshot of the keys so the mapping can be mutated afterwards.
    for name in tuple(model_data.keys()):
        if name.startswith("transformer.h.") and "engram.hash_seeds" in name:
            stale_keys.append(name)
        else:
            hit = old_table_key.match(name)
            if hit is not None:
                base, order, head, tail = hit.groups()
                # Old tables were keyed by n-gram order starting at 2; the new
                # layout indexes them from 0 (old n=2,3 -> new n_idx=0,1).
                renames[name] = f"{base}.{int(order) - 2}.{head}.{tail}"

    # Apply deletions first, then the renames collected above.
    for name in stale_keys:
        del model_data[name]
    for source, target in renames.items():
        model_data[target] = model_data.pop(source)

    if stale_keys:
        log0(f"Removed {len(stale_keys)} stale engram.hash_seeds keys from checkpoint")
    if renames:
        log0(f"Remapped {len(renames)} engram.embed_tables keys from old string-keyed format")
|
||||
|
||||
def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data, rank=0):
|
||||
if rank == 0:
|
||||
os.makedirs(checkpoint_dir, exist_ok=True)
|
||||
|
|
@ -97,6 +123,7 @@ def build_model(checkpoint_dir, step, device, phase):
|
|||
log0(f"Building model with config: {model_config_kwargs}")
|
||||
model_config = GPTConfig(**model_config_kwargs)
|
||||
_patch_missing_keys(model_data, model_config)
|
||||
_patch_removed_state_keys(model_data)
|
||||
with torch.device("meta"):
|
||||
model = GPT(model_config)
|
||||
# Load the model state
|
||||
|
|
|
|||
|
|
@ -205,17 +205,20 @@ class EngramModule(nn.Module):
|
|||
self.d_mem = n_gram_orders * self.n_heads * self.embed_dim
|
||||
|
||||
base_table_size = config.engram_table_size if config.engram_table_size > 0 else padded_vocab_size * 3
|
||||
self.embed_tables = nn.ModuleDict()
|
||||
self.hash_seeds = nn.ParameterDict() # buffers stored as 0-dim params for checkpointing
|
||||
self.table_sizes = {}
|
||||
# Nested nn.ModuleList[nn.ModuleList[Embedding]]: index as [n_idx][k]
|
||||
self.embed_tables = nn.ModuleList()
|
||||
self._powers = [] # list of LongTensor: precomputed hash powers per (n, k); may be meta
|
||||
self._power_configs = [] # list of (prime_size, seed, n) for reinit on real device
|
||||
for n in range(2, self.max_ngram + 1):
|
||||
tables_for_n = nn.ModuleList()
|
||||
for k in range(self.n_heads):
|
||||
prime_size = _next_prime(base_table_size + n * k * 997)
|
||||
key = f"{n}_{k}"
|
||||
self.embed_tables[key] = nn.Embedding(prime_size, self.embed_dim)
|
||||
tables_for_n.append(nn.Embedding(prime_size, self.embed_dim))
|
||||
seed = n * 2654435761 + k * 40503 # multiplicative hash constants
|
||||
self.hash_seeds[key] = nn.Parameter(torch.tensor(seed, dtype=torch.long), requires_grad=False)
|
||||
self.table_sizes[key] = prime_size
|
||||
powers = torch.tensor([pow(seed, n - j, prime_size) for j in range(n)], dtype=torch.long)
|
||||
self._powers.append(powers)
|
||||
self._power_configs.append((prime_size, seed, n))
|
||||
self.embed_tables.append(tables_for_n)
|
||||
|
||||
self.key_proj = Linear(self.d_mem, config.n_embd, bias=False)
|
||||
self.value_proj = Linear(self.d_mem, config.n_embd, bias=False)
|
||||
|
|
@ -228,37 +231,33 @@ class EngramModule(nn.Module):
|
|||
groups=config.n_embd, bias=False,
|
||||
)
|
||||
|
||||
def _ngram_hash(self, input_ids, n, seed, table_size):
|
||||
def _ngram_hash(self, input_ids, n, powers, table_size):
|
||||
"""Vectorized multiplicative hash over the last n tokens only.
|
||||
|
||||
For order n, the hash at position t is:
|
||||
h(t) = (ids[t] * seed^1 + ids[t-1] * seed^2 + ... + ids[t-n+1] * seed^n) % table_size
|
||||
Positions with fewer than n-1 preceding tokens use a window that is zero-padded on the left.
|
||||
|
||||
powers: LongTensor of shape (n,) — precomputed powers on the correct device.
|
||||
"""
|
||||
B, T = input_ids.shape
|
||||
ids = input_ids.long()
|
||||
# Build a (B, T, n) window tensor: window[b, t, j] = ids[b, t - (n-1-j)]
|
||||
# Use pad+slice for vectorization (no Python loop)
|
||||
padded = F.pad(ids, (n - 1, 0), value=0) # (B, T + n - 1)
|
||||
# Extract windows of size n ending at each position
|
||||
# window[b, t, :] = padded[b, t : t+n]
|
||||
padded = F.pad(ids, (n - 1, 0), value=0)
|
||||
windows = padded.unfold(dimension=1, size=n, step=1) # (B, T, n)
|
||||
# Compute polynomial hash: h = sum(window[j] * seed^(n-j)) mod table_size
|
||||
# Use modular exponentiation to avoid overflow
|
||||
powers = torch.tensor([pow(seed, n - j, table_size) for j in range(n)], dtype=torch.long, device=ids.device)
|
||||
h = (windows * powers).sum(dim=-1) % table_size # (B, T)
|
||||
h = (windows * powers).sum(dim=-1) % table_size
|
||||
return h
|
||||
|
||||
def forward(self, x, input_ids, kv_cache=None):
|
||||
B, T, D = x.shape
|
||||
embeds = []
|
||||
for n in range(2, self.max_ngram + 1):
|
||||
idx = 0
|
||||
for n_idx, n in enumerate(range(2, self.max_ngram + 1)):
|
||||
for k in range(self.n_heads):
|
||||
key = f"{n}_{k}"
|
||||
seed = self.hash_seeds[key].item()
|
||||
table_size = self.table_sizes[key]
|
||||
table = self.embed_tables[n_idx][k]
|
||||
table_size = table.num_embeddings
|
||||
powers = self._powers[idx].to(input_ids.device)
|
||||
if kv_cache is None:
|
||||
hash_idx = self._ngram_hash(input_ids, n, seed, table_size)
|
||||
hash_idx = self._ngram_hash(input_ids, n, powers, table_size)
|
||||
elif T == 1 and kv_cache.ngram_history is not None:
|
||||
# Decode: build n-gram window from cached history + current token
|
||||
history = kv_cache.ngram_history # (B, max_ngram-1)
|
||||
|
|
@ -266,11 +265,11 @@ class EngramModule(nn.Module):
|
|||
# Take last (n-1) tokens from history, append current
|
||||
hist_slice = history[:, -(n - 1):] # (B, n-1)
|
||||
window = torch.cat([hist_slice, current], dim=1) # (B, n)
|
||||
powers = torch.tensor([pow(seed, n - j, table_size) for j in range(n)], dtype=torch.long, device=input_ids.device)
|
||||
hash_idx = (window * powers).sum(dim=-1, keepdim=True) % table_size # (B, 1)
|
||||
hash_idx = (window * powers.unsqueeze(0)).sum(dim=-1, keepdim=True) % table_size # (B, 1)
|
||||
else:
|
||||
hash_idx = self._ngram_hash(input_ids, n, seed, table_size)
|
||||
embeds.append(self.embed_tables[key](hash_idx))
|
||||
hash_idx = self._ngram_hash(input_ids, n, powers, table_size)
|
||||
embeds.append(table(hash_idx))
|
||||
idx += 1
|
||||
# Update n-gram history in KV cache after all orders processed
|
||||
if kv_cache is not None:
|
||||
max_hist = self.max_ngram - 1
|
||||
|
|
@ -434,18 +433,26 @@ class GPT(nn.Module):
|
|||
ve.to(dtype=COMPUTE_DTYPE)
|
||||
for block in self.transformer.h:
|
||||
if block.engram is not None:
|
||||
for t in block.engram.embed_tables.values():
|
||||
t.to(dtype=COMPUTE_DTYPE)
|
||||
for tables_n in block.engram.embed_tables:
|
||||
for t in tables_n:
|
||||
t.to(dtype=COMPUTE_DTYPE)
|
||||
block.engram.short_conv.to(dtype=COMPUTE_DTYPE)
|
||||
|
||||
# Engram: init embeddings (uniform), projections (uniform), conv (zeros for identity)
|
||||
for block in self.transformer.h:
|
||||
if block.engram is not None:
|
||||
em = block.engram
|
||||
for t in em.embed_tables.values():
|
||||
torch.nn.init.uniform_(t.weight, -s, s)
|
||||
for tables_n in em.embed_tables:
|
||||
for t in tables_n:
|
||||
torch.nn.init.uniform_(t.weight, -s, s)
|
||||
torch.nn.init.uniform_(em.key_proj.weight, -s, s)
|
||||
torch.nn.init.uniform_(em.value_proj.weight, -s, s)
|
||||
torch.nn.init.zeros_(em.short_conv.weight)
|
||||
# Recompute _powers on the correct device (needed when __init__ ran on meta device)
|
||||
em._powers = [
|
||||
torch.tensor([pow(seed, n - j, prime_size) for j in range(n)], dtype=torch.long, device=em.embed_tables[0][0].weight.device)
|
||||
for prime_size, seed, n in em._power_configs
|
||||
]
|
||||
|
||||
def _precompute_rotary_embeddings(self, seq_len, head_dim, base=100000, device=None):
|
||||
# TODO: bump base theta more? e.g. 100K is more common more recently
|
||||
|
|
@ -514,8 +521,9 @@ class GPT(nn.Module):
|
|||
engram_embed_numel = 0
|
||||
for block in self.transformer.h:
|
||||
if block.engram is not None:
|
||||
for t in block.engram.embed_tables.values():
|
||||
engram_embed_numel += t.weight.numel()
|
||||
for tables_n in block.engram.embed_tables:
|
||||
for t in tables_n:
|
||||
engram_embed_numel += t.weight.numel()
|
||||
nparams_exclude = (self.transformer.wte.weight.numel() + value_embeds_numel + engram_embed_numel +
|
||||
self.resid_lambdas.numel() + self.x0_lambdas.numel() +
|
||||
self.smear_gate.weight.numel() + self.smear_lambda.numel() + self.backout_lambda.numel())
|
||||
|
|
@ -582,8 +590,9 @@ class GPT(nn.Module):
|
|||
for block in self.transformer.h:
|
||||
if block.engram is not None:
|
||||
em = block.engram
|
||||
for t in em.embed_tables.values():
|
||||
engram_embed_params.append(t.weight)
|
||||
for tables_n in em.embed_tables:
|
||||
for t in tables_n:
|
||||
engram_embed_params.append(t.weight)
|
||||
engram_matrix_params_set.add(id(em.key_proj.weight))
|
||||
engram_matrix_params_set.add(id(em.value_proj.weight))
|
||||
engram_matrix_params_set.add(id(em.short_conv.weight))
|
||||
|
|
|
|||
|
|
@ -139,8 +139,9 @@ class TestOptimizerNoDuplicates:
|
|||
engram_embed_ids = set()
|
||||
for block in model.transformer.h:
|
||||
if block.engram is not None:
|
||||
for t in block.engram.embed_tables.values():
|
||||
engram_embed_ids.add(id(t.weight))
|
||||
for tables_n in block.engram.embed_tables:
|
||||
for t in tables_n:
|
||||
engram_embed_ids.add(id(t.weight))
|
||||
|
||||
muon_params = set()
|
||||
for group in opt.param_groups:
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user