From 4011f5e3d55a4a6b1005b5e66cde91609fcb0a1b Mon Sep 17 00:00:00 2001 From: junran Date: Tue, 28 Apr 2026 08:18:50 +0800 Subject: [PATCH] Fix torch.compile graph breaks with Engram module Refactor Engram data structures for torch.compile compatibility: - Replace nn.ModuleDict (string-keyed) with nested nn.ModuleList (integer-indexed) - Precompute hash powers in __init__ and re-materialize in init_weights to eliminate torch.tensor() construction during forward - Remove hash_seeds and table_sizes dicts (info embedded in _powers / embed_tables) - Add _patch_removed_state_keys for old engram checkpoint backward compat Co-Authored-By: Claude Opus 4.7 (1M context) --- nanochat/checkpoint_manager.py | 27 ++++++++++++ nanochat/gpt.py | 77 +++++++++++++++++++--------------- tests/test_engram.py | 5 ++- 3 files changed, 73 insertions(+), 36 deletions(-) diff --git a/nanochat/checkpoint_manager.py b/nanochat/checkpoint_manager.py index f71524ed..f73d2c8e 100644 --- a/nanochat/checkpoint_manager.py +++ b/nanochat/checkpoint_manager.py @@ -39,6 +39,32 @@ def _patch_missing_keys(model_data, model_config): model_data["x0_lambdas"] = torch.zeros(n_layer) log0(f"Patching missing x0_lambdas in model data to 0.0") +def _patch_removed_state_keys(model_data): + """Remove/remap keys that no longer exist in the model (backward compat).""" + import re + keys_to_remove = [] + keys_to_remap = {} + for key in list(model_data.keys()): + # Strip old engram.hash_seeds keys + if key.startswith("transformer.h.") and "engram.hash_seeds" in key: + keys_to_remove.append(key) + continue + # Remap old engram.embed_tables._.* -> embed_tables...* + m = re.match(r"(transformer\.h\.\d+\.engram\.embed_tables)\.(\d+)_(\d+)\.(.*)", key) + if m: + prefix, n_str, k_str, rest = m.groups() + n_idx = int(n_str) - 2 # old n=2,3 -> new n_idx=0,1 + new_key = f"{prefix}.{n_idx}.{k_str}.{rest}" + keys_to_remap[key] = new_key + for old_key in keys_to_remove: + del model_data[old_key] + for old_key, new_key in keys_to_remap.items(): + model_data[new_key] = model_data.pop(old_key) + if keys_to_remove: + log0(f"Removed {len(keys_to_remove)} stale engram.hash_seeds keys from checkpoint") + if keys_to_remap: + log0(f"Remapped {len(keys_to_remap)} engram.embed_tables keys from old string-keyed format") + def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data, rank=0): if rank == 0: os.makedirs(checkpoint_dir, exist_ok=True) @@ -97,6 +123,7 @@ def build_model(checkpoint_dir, step, device, phase): log0(f"Building model with config: {model_config_kwargs}") model_config = GPTConfig(**model_config_kwargs) _patch_missing_keys(model_data, model_config) + _patch_removed_state_keys(model_data) with torch.device("meta"): model = GPT(model_config) # Load the model state diff --git a/nanochat/gpt.py b/nanochat/gpt.py index 2c7105df..00be51dc 100644 --- a/nanochat/gpt.py +++ b/nanochat/gpt.py @@ -205,17 +205,20 @@ class EngramModule(nn.Module): self.d_mem = n_gram_orders * self.n_heads * self.embed_dim base_table_size = config.engram_table_size if config.engram_table_size > 0 else padded_vocab_size * 3 - self.embed_tables = nn.ModuleDict() - self.hash_seeds = nn.ParameterDict() # buffers stored as 0-dim params for checkpointing - self.table_sizes = {} + # Nested nn.ModuleList[nn.ModuleList[Embedding]]: index as [n_idx][k] + self.embed_tables = nn.ModuleList() + self._powers = [] # list of LongTensor: precomputed hash powers per (n, k); may be meta + self._power_configs = [] # list of (prime_size, seed, n) for reinit on real device for n in range(2, self.max_ngram + 1): + tables_for_n = nn.ModuleList() for k in range(self.n_heads): prime_size = _next_prime(base_table_size + n * k * 997) - key = f"{n}_{k}" - self.embed_tables[key] = nn.Embedding(prime_size, self.embed_dim) + tables_for_n.append(nn.Embedding(prime_size, self.embed_dim)) seed = n * 2654435761 + k * 40503 # multiplicative hash constants - self.hash_seeds[key] = nn.Parameter(torch.tensor(seed, dtype=torch.long), requires_grad=False) - self.table_sizes[key] = prime_size + powers = torch.tensor([pow(seed, n - j, prime_size) for j in range(n)], dtype=torch.long) + self._powers.append(powers) + self._power_configs.append((prime_size, seed, n)) + self.embed_tables.append(tables_for_n) self.key_proj = Linear(self.d_mem, config.n_embd, bias=False) self.value_proj = Linear(self.d_mem, config.n_embd, bias=False) @@ -228,37 +231,33 @@ class EngramModule(nn.Module): groups=config.n_embd, bias=False, ) - def _ngram_hash(self, input_ids, n, seed, table_size): + def _ngram_hash(self, input_ids, n, powers, table_size): """Vectorized multiplicative hash over the last n tokens only. For order n, the hash at position t is: h(t) = (ids[t] * seed^1 + ids[t-1] * seed^2 + ... + ids[t-n+1] * seed^n) % table_size Positions with fewer than n preceding tokens use a zero-padded window. + + powers: LongTensor of shape (n,) — precomputed powers on the correct device. """ B, T = input_ids.shape ids = input_ids.long() - # Build a (B, T, n) window tensor: window[b, t, j] = ids[b, t - (n-1-j)] - # Use pad+slice for vectorization (no Python loop) - padded = F.pad(ids, (n - 1, 0), value=0) # (B, T + n - 1) - # Extract windows of size n ending at each position - # window[b, t, :] = padded[b, t : t+n] + padded = F.pad(ids, (n - 1, 0), value=0) windows = padded.unfold(dimension=1, size=n, step=1) # (B, T, n) - # Compute polynomial hash: h = sum(window[j] * seed^(n-j)) mod table_size - # Use modular exponentiation to avoid overflow - powers = torch.tensor([pow(seed, n - j, table_size) for j in range(n)], dtype=torch.long, device=ids.device) - h = (windows * powers).sum(dim=-1) % table_size # (B, T) + h = (windows * powers).sum(dim=-1) % table_size return h def forward(self, x, input_ids, kv_cache=None): B, T, D = x.shape embeds = [] - for n in range(2, self.max_ngram + 1): + idx = 0 + for n_idx, n in enumerate(range(2, self.max_ngram + 1)): for k in range(self.n_heads): - key = f"{n}_{k}" - seed = self.hash_seeds[key].item() - table_size = self.table_sizes[key] + table = self.embed_tables[n_idx][k] + table_size = table.num_embeddings + powers = self._powers[idx].to(input_ids.device) if kv_cache is None: - hash_idx = self._ngram_hash(input_ids, n, seed, table_size) + hash_idx = self._ngram_hash(input_ids, n, powers, table_size) elif T == 1 and kv_cache.ngram_history is not None: # Decode: build n-gram window from cached history + current token history = kv_cache.ngram_history # (B, max_ngram-1) @@ -266,11 +265,11 @@ class EngramModule(nn.Module): # Take last (n-1) tokens from history, append current hist_slice = history[:, -(n - 1):] # (B, n-1) window = torch.cat([hist_slice, current], dim=1) # (B, n) - powers = torch.tensor([pow(seed, n - j, table_size) for j in range(n)], dtype=torch.long, device=input_ids.device) - hash_idx = (window * powers).sum(dim=-1, keepdim=True) % table_size # (B, 1) + hash_idx = (window * powers.unsqueeze(0)).sum(dim=-1, keepdim=True) % table_size # (B, 1) else: - hash_idx = self._ngram_hash(input_ids, n, seed, table_size) - embeds.append(self.embed_tables[key](hash_idx)) + hash_idx = self._ngram_hash(input_ids, n, powers, table_size) + embeds.append(table(hash_idx)) + idx += 1 # Update n-gram history in KV cache after all orders processed if kv_cache is not None: max_hist = self.max_ngram - 1 @@ -434,18 +433,26 @@ class GPT(nn.Module): ve.to(dtype=COMPUTE_DTYPE) for block in self.transformer.h: if block.engram is not None: - for t in block.engram.embed_tables.values(): - t.to(dtype=COMPUTE_DTYPE) + for tables_n in block.engram.embed_tables: + for t in tables_n: + t.to(dtype=COMPUTE_DTYPE) + block.engram.short_conv.to(dtype=COMPUTE_DTYPE) # Engram: init embeddings (uniform), projections (uniform), conv (zeros for identity) for block in self.transformer.h: if block.engram is not None: em = block.engram - for t in em.embed_tables.values(): - torch.nn.init.uniform_(t.weight, -s, s) + for tables_n in em.embed_tables: + for t in tables_n: + torch.nn.init.uniform_(t.weight, -s, s) torch.nn.init.uniform_(em.key_proj.weight, -s, s) torch.nn.init.uniform_(em.value_proj.weight, -s, s) torch.nn.init.zeros_(em.short_conv.weight) + # Recompute _powers on the correct device (needed when __init__ ran on meta device) + em._powers = [ + torch.tensor([pow(seed, n - j, prime_size) for j in range(n)], dtype=torch.long, device=em.embed_tables[0][0].weight.device) + for prime_size, seed, n in em._power_configs + ] def _precompute_rotary_embeddings(self, seq_len, head_dim, base=100000, device=None): # TODO: bump base theta more? e.g. 100K is more common more recently @@ -514,8 +521,9 @@ class GPT(nn.Module): engram_embed_numel = 0 for block in self.transformer.h: if block.engram is not None: - for t in block.engram.embed_tables.values(): - engram_embed_numel += t.weight.numel() + for tables_n in block.engram.embed_tables: + for t in tables_n: + engram_embed_numel += t.weight.numel() nparams_exclude = (self.transformer.wte.weight.numel() + value_embeds_numel + engram_embed_numel + self.resid_lambdas.numel() + self.x0_lambdas.numel() + self.smear_gate.weight.numel() + self.smear_lambda.numel() + self.backout_lambda.numel()) @@ -582,8 +590,9 @@ class GPT(nn.Module): for block in self.transformer.h: if block.engram is not None: em = block.engram - for t in em.embed_tables.values(): - engram_embed_params.append(t.weight) + for tables_n in em.embed_tables: + for t in tables_n: + engram_embed_params.append(t.weight) engram_matrix_params_set.add(id(em.key_proj.weight)) engram_matrix_params_set.add(id(em.value_proj.weight)) engram_matrix_params_set.add(id(em.short_conv.weight)) diff --git a/tests/test_engram.py b/tests/test_engram.py index 576c6654..c525ee8f 100644 --- a/tests/test_engram.py +++ b/tests/test_engram.py @@ -139,8 +139,9 @@ class TestOptimizerNoDuplicates: engram_embed_ids = set() for block in model.transformer.h: if block.engram is not None: - for t in block.engram.embed_tables.values(): - engram_embed_ids.add(id(t.weight)) + for tables_n in block.engram.embed_tables: + for t in tables_n: + engram_embed_ids.add(id(t.weight)) muon_params = set() for group in opt.param_groups: