Fix torch.compile graph breaks with Engram module

Refactor Engram data structures for torch.compile compatibility:
- Replace nn.ModuleDict (string-keyed) with nested nn.ModuleList (integer-indexed)
- Precompute hash powers in __init__ and re-materialize in init_weights to
  eliminate torch.tensor() construction during forward
- Remove hash_seeds and table_sizes dicts (info embedded in _powers / embed_tables)
- Add _patch_removed_state_keys for old engram checkpoint backward compat

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
junran 2026-04-28 08:18:50 +08:00
parent dcd9b0668b
commit 4011f5e3d5
3 changed files with 73 additions and 36 deletions

View File

@ -39,6 +39,32 @@ def _patch_missing_keys(model_data, model_config):
model_data["x0_lambdas"] = torch.zeros(n_layer)
log0(f"Patching missing x0_lambdas in model data to 0.0")
def _patch_removed_state_keys(model_data):
"""Remove/remap keys that no longer exist in the model (backward compat)."""
import re
keys_to_remove = []
keys_to_remap = {}
for key in list(model_data.keys()):
# Strip old engram.hash_seeds keys
if key.startswith("transformer.h.") and "engram.hash_seeds" in key:
keys_to_remove.append(key)
continue
# Remap old engram.embed_tables.<n>_<k>.* -> embed_tables.<n_idx>.<k>.*
m = re.match(r"(transformer\.h\.\d+\.engram\.embed_tables)\.(\d+)_(\d+)\.(.*)", key)
if m:
prefix, n_str, k_str, rest = m.groups()
n_idx = int(n_str) - 2 # old n=2,3 -> new n_idx=0,1
new_key = f"{prefix}.{n_idx}.{k_str}.{rest}"
keys_to_remap[key] = new_key
for old_key in keys_to_remove:
del model_data[old_key]
for old_key, new_key in keys_to_remap.items():
model_data[new_key] = model_data.pop(old_key)
if keys_to_remove:
log0(f"Removed {len(keys_to_remove)} stale engram.hash_seeds keys from checkpoint")
if keys_to_remap:
log0(f"Remapped {len(keys_to_remap)} engram.embed_tables keys from old string-keyed format")
def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data, rank=0):
if rank == 0:
os.makedirs(checkpoint_dir, exist_ok=True)
@ -97,6 +123,7 @@ def build_model(checkpoint_dir, step, device, phase):
log0(f"Building model with config: {model_config_kwargs}")
model_config = GPTConfig(**model_config_kwargs)
_patch_missing_keys(model_data, model_config)
_patch_removed_state_keys(model_data)
with torch.device("meta"):
model = GPT(model_config)
# Load the model state

View File

@ -205,17 +205,20 @@ class EngramModule(nn.Module):
self.d_mem = n_gram_orders * self.n_heads * self.embed_dim
base_table_size = config.engram_table_size if config.engram_table_size > 0 else padded_vocab_size * 3
self.embed_tables = nn.ModuleDict()
self.hash_seeds = nn.ParameterDict() # buffers stored as 0-dim params for checkpointing
self.table_sizes = {}
# Nested nn.ModuleList[nn.ModuleList[Embedding]]: index as [n_idx][k]
self.embed_tables = nn.ModuleList()
self._powers = [] # list of LongTensor: precomputed hash powers per (n, k); may be meta
self._power_configs = [] # list of (prime_size, seed, n) for reinit on real device
for n in range(2, self.max_ngram + 1):
tables_for_n = nn.ModuleList()
for k in range(self.n_heads):
prime_size = _next_prime(base_table_size + n * k * 997)
key = f"{n}_{k}"
self.embed_tables[key] = nn.Embedding(prime_size, self.embed_dim)
tables_for_n.append(nn.Embedding(prime_size, self.embed_dim))
seed = n * 2654435761 + k * 40503 # multiplicative hash constants
self.hash_seeds[key] = nn.Parameter(torch.tensor(seed, dtype=torch.long), requires_grad=False)
self.table_sizes[key] = prime_size
powers = torch.tensor([pow(seed, n - j, prime_size) for j in range(n)], dtype=torch.long)
self._powers.append(powers)
self._power_configs.append((prime_size, seed, n))
self.embed_tables.append(tables_for_n)
self.key_proj = Linear(self.d_mem, config.n_embd, bias=False)
self.value_proj = Linear(self.d_mem, config.n_embd, bias=False)
@ -228,37 +231,33 @@ class EngramModule(nn.Module):
groups=config.n_embd, bias=False,
)
def _ngram_hash(self, input_ids, n, seed, table_size):
def _ngram_hash(self, input_ids, n, powers, table_size):
"""Vectorized multiplicative hash over the last n tokens only.
For order n, the hash at position t is:
h(t) = (ids[t] * seed^1 + ids[t-1] * seed^2 + ... + ids[t-n+1] * seed^n) % table_size
Positions with fewer than n preceding tokens use a zero-padded window.
powers: LongTensor of shape (n,) precomputed powers on the correct device.
"""
B, T = input_ids.shape
ids = input_ids.long()
# Build a (B, T, n) window tensor: window[b, t, j] = ids[b, t - (n-1-j)]
# Use pad+slice for vectorization (no Python loop)
padded = F.pad(ids, (n - 1, 0), value=0) # (B, T + n - 1)
# Extract windows of size n ending at each position
# window[b, t, :] = padded[b, t : t+n]
padded = F.pad(ids, (n - 1, 0), value=0)
windows = padded.unfold(dimension=1, size=n, step=1) # (B, T, n)
# Compute polynomial hash: h = sum(window[j] * seed^(n-j)) mod table_size
# Use modular exponentiation to avoid overflow
powers = torch.tensor([pow(seed, n - j, table_size) for j in range(n)], dtype=torch.long, device=ids.device)
h = (windows * powers).sum(dim=-1) % table_size # (B, T)
h = (windows * powers).sum(dim=-1) % table_size
return h
def forward(self, x, input_ids, kv_cache=None):
B, T, D = x.shape
embeds = []
for n in range(2, self.max_ngram + 1):
idx = 0
for n_idx, n in enumerate(range(2, self.max_ngram + 1)):
for k in range(self.n_heads):
key = f"{n}_{k}"
seed = self.hash_seeds[key].item()
table_size = self.table_sizes[key]
table = self.embed_tables[n_idx][k]
table_size = table.num_embeddings
powers = self._powers[idx].to(input_ids.device)
if kv_cache is None:
hash_idx = self._ngram_hash(input_ids, n, seed, table_size)
hash_idx = self._ngram_hash(input_ids, n, powers, table_size)
elif T == 1 and kv_cache.ngram_history is not None:
# Decode: build n-gram window from cached history + current token
history = kv_cache.ngram_history # (B, max_ngram-1)
@ -266,11 +265,11 @@ class EngramModule(nn.Module):
# Take last (n-1) tokens from history, append current
hist_slice = history[:, -(n - 1):] # (B, n-1)
window = torch.cat([hist_slice, current], dim=1) # (B, n)
powers = torch.tensor([pow(seed, n - j, table_size) for j in range(n)], dtype=torch.long, device=input_ids.device)
hash_idx = (window * powers).sum(dim=-1, keepdim=True) % table_size # (B, 1)
hash_idx = (window * powers.unsqueeze(0)).sum(dim=-1, keepdim=True) % table_size # (B, 1)
else:
hash_idx = self._ngram_hash(input_ids, n, seed, table_size)
embeds.append(self.embed_tables[key](hash_idx))
hash_idx = self._ngram_hash(input_ids, n, powers, table_size)
embeds.append(table(hash_idx))
idx += 1
# Update n-gram history in KV cache after all orders processed
if kv_cache is not None:
max_hist = self.max_ngram - 1
@ -434,18 +433,26 @@ class GPT(nn.Module):
ve.to(dtype=COMPUTE_DTYPE)
for block in self.transformer.h:
if block.engram is not None:
for t in block.engram.embed_tables.values():
t.to(dtype=COMPUTE_DTYPE)
for tables_n in block.engram.embed_tables:
for t in tables_n:
t.to(dtype=COMPUTE_DTYPE)
block.engram.short_conv.to(dtype=COMPUTE_DTYPE)
# Engram: init embeddings (uniform), projections (uniform), conv (zeros for identity)
for block in self.transformer.h:
if block.engram is not None:
em = block.engram
for t in em.embed_tables.values():
torch.nn.init.uniform_(t.weight, -s, s)
for tables_n in em.embed_tables:
for t in tables_n:
torch.nn.init.uniform_(t.weight, -s, s)
torch.nn.init.uniform_(em.key_proj.weight, -s, s)
torch.nn.init.uniform_(em.value_proj.weight, -s, s)
torch.nn.init.zeros_(em.short_conv.weight)
# Recompute _powers on the correct device (needed when __init__ ran on meta device)
em._powers = [
torch.tensor([pow(seed, n - j, prime_size) for j in range(n)], dtype=torch.long, device=em.embed_tables[0][0].weight.device)
for prime_size, seed, n in em._power_configs
]
def _precompute_rotary_embeddings(self, seq_len, head_dim, base=100000, device=None):
# TODO: bump base theta more? e.g. 100K is more common more recently
@ -514,8 +521,9 @@ class GPT(nn.Module):
engram_embed_numel = 0
for block in self.transformer.h:
if block.engram is not None:
for t in block.engram.embed_tables.values():
engram_embed_numel += t.weight.numel()
for tables_n in block.engram.embed_tables:
for t in tables_n:
engram_embed_numel += t.weight.numel()
nparams_exclude = (self.transformer.wte.weight.numel() + value_embeds_numel + engram_embed_numel +
self.resid_lambdas.numel() + self.x0_lambdas.numel() +
self.smear_gate.weight.numel() + self.smear_lambda.numel() + self.backout_lambda.numel())
@ -582,8 +590,9 @@ class GPT(nn.Module):
for block in self.transformer.h:
if block.engram is not None:
em = block.engram
for t in em.embed_tables.values():
engram_embed_params.append(t.weight)
for tables_n in em.embed_tables:
for t in tables_n:
engram_embed_params.append(t.weight)
engram_matrix_params_set.add(id(em.key_proj.weight))
engram_matrix_params_set.add(id(em.value_proj.weight))
engram_matrix_params_set.add(id(em.short_conv.weight))

View File

@ -139,8 +139,9 @@ class TestOptimizerNoDuplicates:
engram_embed_ids = set()
for block in model.transformer.h:
if block.engram is not None:
for t in block.engram.embed_tables.values():
engram_embed_ids.add(id(t.weight))
for tables_n in block.engram.embed_tables:
for t in tables_n:
engram_embed_ids.add(id(t.weight))
muon_params = set()
for group in opt.param_groups: