From aab331dfd4c0282aa38af5c72e7d0fa1f7506331 Mon Sep 17 00:00:00 2001 From: Codex Date: Wed, 6 May 2026 12:19:07 +0000 Subject: [PATCH 1/8] Add minimal bigram speedrun recipe --- dev/bigram_speedrun_results.md | 66 +++++++++++++++ nanochat/engine.py | 5 ++ nanochat/gpt.py | 85 +++++++++++++++++--- nanochat/optim.py | 18 ++++- runs/speedrun.sh | 14 +++- scripts/base_train.py | 141 +++++++++++++++++++++++---------- tests/test_engine.py | 44 ++++++++++ 7 files changed, 321 insertions(+), 52 deletions(-) create mode 100644 dev/bigram_speedrun_results.md diff --git a/dev/bigram_speedrun_results.md b/dev/bigram_speedrun_results.md new file mode 100644 index 00000000..61f6ec75 --- /dev/null +++ b/dev/bigram_speedrun_results.md @@ -0,0 +1,66 @@ +# Bigram Speedrun Verification Notes + +This branch is based on upstream nanochat master at `dc54a1a` and keeps the +submission implementation focused on the winning recipe: + +- per-layer hashed bigram residual embeddings +- Muon+ post-orthogonalization normalization +- row equilibration before Muon orthogonalization +- lower scalar LR (`--scalar-lr=0.3`) +- batched training logging (`--train-log-every=50`) +- `torch.compile(..., mode="max-autotune-no-cudagraphs")` for the speedrun script + +It intentionally excludes the experimental branches that were not part of the +final candidate: sparse layers, MoE/TOP losses, train-time logit bias losses, +post-hoc fitting, NorMuon, and checkpoint merging. + +## Reproduction Sanity Check + +Minimal branch d4/20 matched the prior experimental branch: + +| Run | Step 0 BPB | Step 10 BPB | Final BPB | +| --- | ---: | ---: | ---: | +| Prior candidate branch | `3.237224` | `3.234722` | `3.223259` | +| Minimal PR branch | `3.237224` | `3.234722` | `3.223286` | + +The final difference is `0.000027` BPB on a tiny run, consistent with small +compile/graph differences after removing unused experimental code. + +## Full d16 Verification + +Both runs used d16, FP8, target param/data ratio 8, total batch `524288`, and +device batch `32` on the same machine. + +| Run | Final BPB | Train time | Avg logged tok/s, excluding first | Avg logged step time, excluding first | +| --- | ---: | ---: | ---: | ---: | +| Upstream master dense | `0.800673` | `94.64m` | `329,904` | `1589.232ms` | +| Bigram/Muon+ candidate | `0.798000` | `93.61m` | `333,507` | `1572.058ms` | + +Candidate delta versus upstream master dense: + +- BPB: `-0.002673` +- train time: `-1.03m` (`1.09%` faster) +- logged throughput: `+3,603 tok/s` (`1.09%` higher) + +Important caveat: this is a full recipe comparison, not an architecture-only +comparison. The candidate also uses `--train-log-every=50` and +`--compile-mode=max-autotune-no-cudagraphs`, while upstream master logs every +step and uses the default compile mode. + +## Compile Mode Probe + +Short d16/40 throughput probes on the minimal branch: + +| Compile mode | Avg logged tok/s, excluding first | Avg logged step time, excluding first | Total time | +| --- | ---: | ---: | ---: | +| default `torch.compile` | `324,995` | `1613.250ms` | `0.78m` | +| `max-autotune-no-cudagraphs` | `333,261` | `1573.250ms` | `0.76m` | + +On this d16 probe, `max-autotune-no-cudagraphs` was about `2.5%` faster than +the default compile mode. The speedrun script keeps this compile mode for that +reason. + +## Test Status + +- `python -m pytest tests/test_engine.py -q`: `9 passed` +- `python -m py_compile nanochat/gpt.py nanochat/optim.py scripts/base_train.py nanochat/engine.py`: passed diff --git a/nanochat/engine.py b/nanochat/engine.py index aa2e6a98..2a7d12d2 100644 --- a/nanochat/engine.py +++ b/nanochat/engine.py @@ -102,11 +102,14 @@ class KVCache: self.cache_seqlens = torch.zeros(batch_size, dtype=torch.int32, device=device) # Previous token's normalized embedding for smear (set by model forward pass) self.prev_embedding = None + # Previous token id for hashed bigram embeddings (set by model forward pass) + self.prev_token = None def reset(self): """Reset cache to empty state.""" self.cache_seqlens.zero_() self.prev_embedding = None + self.prev_token = None def get_pos(self): """Get current position (assumes all batch elements at same position).""" @@ -135,6 +138,8 @@ class KVCache: # Copy smear state: expand batch=1 prev_embedding to num_samples if other.prev_embedding is not None: self.prev_embedding = other.prev_embedding.expand(self.batch_size, -1, -1).clone() + if other.prev_token is not None: + self.prev_token = other.prev_token.expand(self.batch_size, -1).clone() # ----------------------------------------------------------------------------- @torch.inference_mode() diff --git a/nanochat/gpt.py b/nanochat/gpt.py index 07a1eae8..24a01c64 100644 --- a/nanochat/gpt.py +++ b/nanochat/gpt.py @@ -37,6 +37,8 @@ class GPTConfig: # Characters: L=long (full context), S=short (quarter context) # Examples: "L"=all full context, "SL"=alternating, "SSL"=two short then one long window_pattern: str = "SSSL" + bigram_embed_factor: int = 0 + bigram_lambda_init: float = 0.05 def norm(x): @@ -172,6 +174,8 @@ class GPT(nn.Module): "wte": nn.Embedding(padded_vocab_size, config.n_embd), "h": nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)]), }) + self.bigram_vocab_size = int(config.vocab_size * max(0, int(config.bigram_embed_factor))) + self.bigram_embed = nn.Embedding(self.bigram_vocab_size, config.n_embd) if self.bigram_vocab_size > 0 else None self.lm_head = Linear(config.n_embd, padded_vocab_size, bias=False) # Per-layer learnable scalars (inspired by modded-nanogpt) # resid_lambdas: scales the residual stream at each layer (init 1.0 = neutral) @@ -179,6 +183,10 @@ class GPT(nn.Module): # Separate parameters so they can have different optimizer treatment self.resid_lambdas = nn.Parameter(torch.ones(config.n_layer)) # fake init, real init in init_weights() self.x0_lambdas = nn.Parameter(torch.zeros(config.n_layer)) # fake init, real init in init_weights() + if self.bigram_embed is not None: + self.bigram_lambdas = nn.Parameter(torch.zeros(config.n_layer)) + else: + self.register_buffer("bigram_lambdas", torch.zeros(0), persistent=False) # Smear: mix previous token's embedding into current token (cheap bigram-like info) self.smear_gate = Linear(24, 1, bias=False) self.smear_lambda = nn.Parameter(torch.zeros(1)) @@ -216,6 +224,8 @@ class GPT(nn.Module): # Embedding and unembedding torch.nn.init.normal_(self.transformer.wte.weight, mean=0.0, std=0.8) + if self.bigram_embed is not None: + torch.nn.init.zeros_(self.bigram_embed.weight) torch.nn.init.normal_(self.lm_head.weight, mean=0.0, std=0.001) # Transformer blocks: uniform init with bound = sqrt(3) * std (same standard deviation as normal) @@ -237,6 +247,8 @@ class GPT(nn.Module): # Decaying x0 init: earlier layers get more input embedding blending for i in range(n_layer): self.x0_lambdas.data[i] = 0.20 - (0.15 * i / max(n_layer - 1, 1)) + if self.bigram_embed is not None: + torch.nn.init.constant_(self.bigram_lambdas, self.config.bigram_lambda_init) # Smear/backout scalars and smear gate must be explicitly initialized torch.nn.init.zeros_(self.smear_lambda) @@ -262,9 +274,25 @@ class GPT(nn.Module): # because GradScaler cannot unscale fp16 gradients. if COMPUTE_DTYPE != torch.float16: self.transformer.wte.to(dtype=COMPUTE_DTYPE) + if self.bigram_embed is not None: + self.bigram_embed.to(dtype=COMPUTE_DTYPE) for ve in self.value_embeds.values(): ve.to(dtype=COMPUTE_DTYPE) + def _bigram_hash(self, idx, prev_idx=None): + mod = self.bigram_vocab_size - 1 + if mod <= 0: + raise RuntimeError("bigram hash requested with disabled bigram embedding") + idx_i32 = idx.to(torch.int32) + out = torch.empty_like(idx_i32) + if prev_idx is None: + out[:, :1].fill_(mod) + out[:, 1:] = torch.bitwise_xor(36313 * idx_i32[:, 1:], 27191 * idx_i32[:, :-1]) % mod + else: + prev_i32 = prev_idx.to(torch.int32) + out[:] = torch.bitwise_xor(36313 * idx_i32, 27191 * prev_i32) % mod + return out.to(torch.long) + def _precompute_rotary_embeddings(self, seq_len, head_dim, base=100000, device=None): # TODO: bump base theta more? e.g. 100K is more common more recently # autodetect the device from model embeddings @@ -329,8 +357,9 @@ class GPT(nn.Module): nparams = sum(p.numel() for p in self.parameters()) # Exclude non-matmul params: embeddings and per-layer scalars value_embeds_numel = sum(ve.weight.numel() for ve in self.value_embeds.values()) - nparams_exclude = (self.transformer.wte.weight.numel() + value_embeds_numel + - self.resid_lambdas.numel() + self.x0_lambdas.numel() + + bigram_embed_numel = self.bigram_embed.weight.numel() if self.bigram_embed is not None else 0 + nparams_exclude = (self.transformer.wte.weight.numel() + value_embeds_numel + bigram_embed_numel + + self.resid_lambdas.numel() + self.x0_lambdas.numel() + self.bigram_lambdas.numel() + self.smear_gate.weight.numel() + self.smear_lambda.numel() + self.backout_lambda.numel()) h, q, t = self.config.n_head, self.config.n_embd // self.config.n_head, self.config.sequence_len # Sum attention FLOPs per layer, accounting for sliding window @@ -356,14 +385,17 @@ class GPT(nn.Module): """ # Count each group separately (mirrors the grouping in setup_optimizers) wte = sum(p.numel() for p in self.transformer.wte.parameters()) + bigram_embed = self.bigram_embed.weight.numel() if self.bigram_embed is not None else 0 value_embeds = sum(p.numel() for p in self.value_embeds.parameters()) lm_head = sum(p.numel() for p in self.lm_head.parameters()) transformer_matrices = sum(p.numel() for p in self.transformer.h.parameters()) - scalars = self.resid_lambdas.numel() + self.x0_lambdas.numel() + self.smear_gate.weight.numel() + self.smear_lambda.numel() + self.backout_lambda.numel() - total = wte + value_embeds + lm_head + transformer_matrices + scalars + bigram_lambdas = self.bigram_lambdas.numel() if isinstance(self.bigram_lambdas, nn.Parameter) else 0 + scalars = self.resid_lambdas.numel() + self.x0_lambdas.numel() + bigram_lambdas + self.smear_gate.weight.numel() + self.smear_lambda.numel() + self.backout_lambda.numel() + total = wte + bigram_embed + value_embeds + lm_head + transformer_matrices + scalars assert total == sum(p.numel() for p in self.parameters()), "Parameter count mismatch" return { 'wte': wte, + 'bigram_embed': bigram_embed, 'value_embeds': value_embeds, 'lm_head': lm_head, 'transformer_matrices': transformer_matrices, @@ -371,40 +403,60 @@ class GPT(nn.Module): 'total': total, } - def setup_optimizer(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0, scalar_lr=0.5): + def setup_optimizer( + self, + unembedding_lr=0.004, + embedding_lr=0.2, + bigram_embedding_lr_mult=1.0, + bigram_lambda_lr=0.004, + matrix_lr=0.02, + weight_decay=0.0, + scalar_lr=0.5, + muon_plus=False, + muon_eq_axis=0, + ): model_dim = self.config.n_embd ddp, rank, local_rank, world_size = get_dist_info() # Separate out all parameters into groups matrix_params = list(self.transformer.h.parameters()) value_embeds_params = list(self.value_embeds.parameters()) + bigram_embed_params = list(self.bigram_embed.parameters()) if self.bigram_embed is not None else [] embedding_params = list(self.transformer.wte.parameters()) lm_head_params = list(self.lm_head.parameters()) resid_params = [self.resid_lambdas] x0_params = [self.x0_lambdas] + bigram_lambda_params = [self.bigram_lambdas] if isinstance(self.bigram_lambdas, nn.Parameter) else [] smear_params = [self.smear_gate.weight, self.smear_lambda, self.backout_lambda] - assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params) + len(smear_params) + assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(value_embeds_params) + len(bigram_embed_params) + len(resid_params) + len(x0_params) + len(bigram_lambda_params) + len(smear_params) # Scale the LR for the AdamW parameters by ∝1/√dmodel (tuned for 768 dim model) dmodel_lr_scale = (model_dim / 768) ** -0.5 print0(f"Scaling the LR for the AdamW parameters ∝1/√({model_dim}/768) = {dmodel_lr_scale:.6f}") # Build param_groups with all required fields explicit + # AdamW groups (embeddings, lm_head, scalars) param_groups = [ - # AdamW groups (embeddings, lm_head, scalars) dict(kind='adamw', params=lm_head_params, lr=unembedding_lr * dmodel_lr_scale, betas=(0.8, 0.96), eps=1e-10, weight_decay=0.01), dict(kind='adamw', params=embedding_params, lr=embedding_lr * dmodel_lr_scale, betas=(0.8, 0.995), eps=1e-10, weight_decay=0.001), dict(kind='adamw', params=value_embeds_params, lr=embedding_lr * dmodel_lr_scale * 0.5, betas=(0.8, 0.995), eps=1e-10, weight_decay=0.01), + ] + if bigram_embed_params: + param_groups.append(dict(kind='adamw', params=bigram_embed_params, lr=embedding_lr * dmodel_lr_scale * bigram_embedding_lr_mult, betas=(0.75, 0.95), eps=1e-10, weight_decay=0.01)) + param_groups.extend([ dict(kind='adamw', params=resid_params, lr=scalar_lr * 0.01, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.05), dict(kind='adamw', params=x0_params, lr=scalar_lr, betas=(0.96, 0.95), eps=1e-10, weight_decay=0.0), # higher beta1 for x0 - dict(kind='adamw', params=smear_params, lr=0.2, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0), - ] + ]) + if bigram_embed_params: + param_groups.append(dict(kind='adamw', params=bigram_lambda_params, lr=bigram_lambda_lr * dmodel_lr_scale, betas=(0.9, 0.95), eps=1e-10, weight_decay=0.0)) + param_groups.append(dict(kind='adamw', params=smear_params, lr=0.2, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0)) # Muon groups (matrix params, grouped by shape for stacking) for shape in sorted({p.shape for p in matrix_params}): group_params = [p for p in matrix_params if p.shape == shape] param_groups.append(dict( kind='muon', params=group_params, lr=matrix_lr, momentum=0.95, ns_steps=5, beta2=0.9, weight_decay=weight_decay, + muon_plus=muon_plus, muon_eq_axis=muon_eq_axis, )) Factory = DistMuonAdamW if ddp else MuonAdamW @@ -448,6 +500,19 @@ class GPT(nn.Module): gate = self.smear_lambda.to(x.dtype) * torch.sigmoid(self.smear_gate(x[:, :, :24])) x = x + gate * x_pre_smear + # Optional hashed bigram embedding residual. During KV-cache decoding we need the + # previous token id because the sequence length is one. + if self.bigram_embed is not None: + if kv_cache is None or T > 1: + bigram_idx = self._bigram_hash(idx) + else: + bigram_idx = self._bigram_hash(idx, kv_cache.prev_token) + x0_bigram = self.bigram_embed(bigram_idx).to(x.dtype) + else: + x0_bigram = None + if kv_cache is not None: + kv_cache.prev_token = idx[:, -1:].clone() + # Forward the trunk of the Transformer x0 = x # save initial normalized embedding for x0 residual n_layer = self.config.n_layer @@ -455,6 +520,8 @@ class GPT(nn.Module): x_backout = None for i, block in enumerate(self.transformer.h): x = self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0 + if x0_bigram is not None: + x = x + self.bigram_lambdas[i].to(x.dtype) * x0_bigram ve = self.value_embeds[str(i)](idx).to(x.dtype) if str(i) in self.value_embeds else None x = block(x, ve, cos_sin, self.window_sizes[i], kv_cache) if i == backout_layer: diff --git a/nanochat/optim.py b/nanochat/optim.py index 56e85e14..cd10688a 100644 --- a/nanochat/optim.py +++ b/nanochat/optim.py @@ -100,6 +100,8 @@ def muon_step_fused( beta2_t: Tensor, # () - 0-D CPU tensor, beta2 for second moment ns_steps: int, # 5 - number of Newton-Schulz/Polar Express iterations red_dim: int, # -1 or -2 - reduction dimension for variance + muon_plus: bool, # add one Frobenius renormalization after orthogonalization + muon_eq_axis: int, # 0 none, 1 row, 2 column equilibration before orthogonalization ) -> None: """ Fused Muon step: momentum -> polar_express -> variance_reduction -> cautious_update @@ -115,6 +117,14 @@ def muon_step_fused( # Polar express # Cast to bf16 for speed when available; skip cast otherwise (fp16 is unstable here due to limited exponent range) X = g.bfloat16() if COMPUTE_DTYPE == torch.bfloat16 else g + if muon_eq_axis == 1: + target = X.float().norm(dim=(-2, -1), keepdim=True) / (X.size(-2) ** 0.5) + row_norm = X.float().norm(dim=-1, keepdim=True).clamp_min(1e-6) + X = X * (target / row_norm).to(X.dtype) + elif muon_eq_axis == 2: + target = X.float().norm(dim=(-2, -1), keepdim=True) / (X.size(-1) ** 0.5) + col_norm = X.float().norm(dim=-2, keepdim=True).clamp_min(1e-6) + X = X * (target / col_norm).to(X.dtype) X = X / (X.norm(dim=(-2, -1), keepdim=True) * 1.01 + 1e-6) if g.size(-2) > g.size(-1): # Tall matrix for a, b, c in polar_express_coeffs[:ns_steps]: @@ -127,6 +137,10 @@ def muon_step_fused( B = b * A + c * (A @ A) X = a * X + B @ X g = X + if muon_plus: + target_norm = min(g.size(-2), g.size(-1)) ** 0.5 + current_norm = g.float().norm(dim=(-2, -1), keepdim=True).clamp_min(1e-6) + g = g * (target_norm / current_norm).to(g.dtype) # Variance reduction beta2 = beta2_t.to(g.dtype) @@ -277,6 +291,8 @@ class MuonAdamW(torch.optim.Optimizer): self._muon_beta2_t, group["ns_steps"], red_dim, + group.get("muon_plus", False), + group.get("muon_eq_axis", 0), ) # Copy back to original params @@ -486,7 +502,7 @@ class DistMuonAdamW(torch.optim.Optimizer): grad_chunk[:num_owned], stacked_owned, state["momentum_buffer"][:num_owned], state["second_momentum_buffer"][:num_owned], self._muon_momentum_t, self._muon_lr_t, self._muon_wd_t, self._muon_beta2_t, - group["ns_steps"], red_dim, + group["ns_steps"], red_dim, group.get("muon_plus", False), group.get("muon_eq_axis", 0), ) updated_params[:num_owned].copy_(stacked_owned) diff --git a/runs/speedrun.sh b/runs/speedrun.sh index 48fcc68a..9f780faf 100644 --- a/runs/speedrun.sh +++ b/runs/speedrun.sh @@ -70,7 +70,19 @@ echo "Waiting for dataset download to complete..." wait $DATASET_DOWNLOAD_PID # d24 model (slightly undertrained to beat GPT-2 => decrease data:params ratio from compute optimal 10.5 (default) to 8) -torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- --depth=24 --target-param-data-ratio=8 --device-batch-size=16 --fp8 --run=$WANDB_RUN +torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- \ + --depth=24 \ + --target-param-data-ratio=8 \ + --device-batch-size=16 \ + --total-batch-size=1048576 \ + --fp8 \ + --compile-mode=max-autotune-no-cudagraphs \ + --muon-plus \ + --muon-eq=row \ + --bigram-embed-factor=5 \ + --scalar-lr=0.3 \ + --train-log-every=50 \ + --run=$WANDB_RUN # evaluate the model: CORE metric, BPB on train/val, and draw samples torchrun --standalone --nproc_per_node=8 -m scripts.base_eval -- --device-batch-size=16 diff --git a/scripts/base_train.py b/scripts/base_train.py index a161c477..56fbae2b 100644 --- a/scripts/base_train.py +++ b/scripts/base_train.py @@ -41,17 +41,23 @@ print_banner() parser = argparse.ArgumentParser(description="Pretrain base model") # Logging parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)") +parser.add_argument("--train-log-every", type=int, default=1, help="print training metrics every N steps; values >1 avoid per-step CPU/GPU sync") # Runtime parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)") # FP8 training parser.add_argument("--fp8", action="store_true", help="enable FP8 training (requires H100+ GPU and torchao)") parser.add_argument("--fp8-recipe", type=str, default="tensorwise", choices=["rowwise", "tensorwise"], help="FP8 scaling recipe: tensorwise (faster, recommended) or rowwise (more accurate but slower)") +parser.add_argument("--compile-mode", type=str, default="", choices=["", "default", "reduce-overhead", "max-autotune", "max-autotune-no-cudagraphs"], help="optional torch.compile mode") # Model architecture parser.add_argument("--depth", type=int, default=20, help="depth of the Transformer model") parser.add_argument("--aspect-ratio", type=int, default=64, help="model_dim = depth * aspect_ratio") parser.add_argument("--head-dim", type=int, default=128, help="target head dimension for attention") parser.add_argument("--max-seq-len", type=int, default=2048, help="max context length") parser.add_argument("--window-pattern", type=str, default="SSSL", help="sliding window pattern tiled across layers: L=full, S=half context (e.g. 'SSL')") +parser.add_argument("--bigram-embed-factor", type=int, default=0, help="if >0, add a hashed bigram embedding residual") +parser.add_argument("--bigram-lambda-init", type=float, default=0.05, help="initial layer residual scale for --bigram-embed-factor") +parser.add_argument("--bigram-embedding-lr-mult", type=float, default=1.0, help="bigram embedding LR multiplier relative to --embedding-lr") +parser.add_argument("--bigram-lambda-lr", type=float, default=0.004, help="AdamW LR for bigram layer lambdas before dmodel schedule scaling") # Training horizon (only one used, in order of precedence) parser.add_argument("--num-iterations", type=int, default=-1, help="explicit number of optimization steps (-1 = disable)") parser.add_argument("--target-flops", type=float, default=-1.0, help="calculate num_iterations to reach target_flops (-1 = disable)") @@ -64,6 +70,8 @@ parser.add_argument("--unembedding-lr", type=float, default=0.008, help="learnin parser.add_argument("--weight-decay", type=float, default=0.28, help="cautious weight decay for the Muon optimizer (for weights)") parser.add_argument("--matrix-lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)") parser.add_argument("--scalar-lr", type=float, default=0.5, help="learning rate for scalars (resid_lambdas, x0_lambdas)") +parser.add_argument("--muon-plus", action="store_true", help="apply Muon+ style post-orthogonalization Frobenius renormalization") +parser.add_argument("--muon-eq", type=str, default="none", choices=["none", "row", "col"], help="apply MuonEq-style row/column equilibration before orthogonalization") parser.add_argument("--warmup-steps", type=int, default=40, help="number of steps for LR warmup") parser.add_argument("--warmdown-ratio", type=float, default=0.65, help="ratio of iterations for LR warmdown") parser.add_argument("--final-lr-frac", type=float, default=0.05, help="final LR as fraction of initial LR") @@ -71,6 +79,7 @@ parser.add_argument("--resume-from-step", type=int, default=-1, help="resume tra # Evaluation parser.add_argument("--eval-every", type=int, default=250, help="evaluate val bpb every N steps (-1 = disable)") parser.add_argument("--eval-tokens", type=int, default=80*524288, help="number of tokens to evaluate val loss on") +parser.add_argument("--skip-initial-eval", action="store_true", help="skip the step 0 validation pass; final validation still runs") parser.add_argument("--core-metric-every", type=int, default=2000, help="evaluate CORE metric every N steps (-1 = disable)") parser.add_argument("--core-metric-max-per-task", type=int, default=500, help="examples per task for CORE metric") parser.add_argument("--sample-every", type=int, default=2000, help="sample from model every N steps (-1 = disable)") @@ -79,6 +88,14 @@ parser.add_argument("--save-every", type=int, default=-1, help="save checkpoints parser.add_argument("--model-tag", type=str, default=None, help="override model tag for checkpoint directory name") args = parser.parse_args() user_config = vars(args).copy() # for logging +if args.train_log_every <= 0: + parser.error("--train-log-every must be positive") +if args.bigram_embed_factor < 0: + parser.error("--bigram-embed-factor must be non-negative") +if args.bigram_lambda_lr < 0: + parser.error("--bigram-lambda-lr must be non-negative") +if args.bigram_embedding_lr_mult <= 0: + parser.error("--bigram-embedding-lr-mult must be positive") # ----------------------------------------------------------------------------- # Compute init and wandb logging @@ -137,6 +154,8 @@ def build_model_meta(depth): sequence_len=args.max_seq_len, vocab_size=vocab_size, n_layer=depth, n_head=num_heads, n_kv_head=num_heads, n_embd=model_dim, window_pattern=args.window_pattern, + bigram_embed_factor=args.bigram_embed_factor, + bigram_lambda_init=args.bigram_lambda_init, ) with torch.device("meta"): model_meta = GPT(config) @@ -243,7 +262,10 @@ def disable_fp8(model): # Compile the model orig_model = model # original, uncompiled model, for saving raw model state_dict and for inference/evaluation (because the shapes may change shape) -model = torch.compile(model, dynamic=False) # the inputs to model will never change shape so dynamic=False is safe +compile_kwargs = {"dynamic": False} +if args.compile_mode: + compile_kwargs["mode"] = args.compile_mode +model = torch.compile(model, **compile_kwargs) # the inputs to model will never change shape so dynamic=False is safe # ----------------------------------------------------------------------------- # Scaling laws and muP extrapolations to determine the optimal training horizon, batch size, learning rates, weight decay. @@ -305,14 +327,20 @@ if weight_decay_scaled != args.weight_decay: # ----------------------------------------------------------------------------- # Initialize the Optimizer (combined MuonAdamW: Muon for matrix params, AdamW for rest) +muon_eq_axis = {"none": 0, "row": 1, "col": 2}[args.muon_eq] +print0(f"Muon options: muon_plus={args.muon_plus}, muon_eq={args.muon_eq}") optimizer = model.setup_optimizer( # AdamW hyperparameters unembedding_lr=args.unembedding_lr * batch_lr_scale, embedding_lr=args.embedding_lr * batch_lr_scale, + bigram_embedding_lr_mult=args.bigram_embedding_lr_mult, + bigram_lambda_lr=args.bigram_lambda_lr * batch_lr_scale, scalar_lr=args.scalar_lr * batch_lr_scale, # Muon hyperparameters matrix_lr=args.matrix_lr * batch_lr_scale, weight_decay=weight_decay_scaled, + muon_plus=args.muon_plus, + muon_eq_axis=muon_eq_axis, ) if resuming: @@ -411,6 +439,11 @@ grad_accum_steps = total_batch_size // world_tokens_per_fwdbwd print0(f"Tokens / micro-batch / rank: {args.device_batch_size} x {args.max_seq_len} = {tokens_per_fwdbwd:,}") print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}") print0(f"Total batch size {total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}") +train_log_every = args.train_log_every +batched_train_timing = train_log_every > 1 +train_timing_interval_start = None +train_timing_interval_first_step = step +train_log_count = 0 # Go! while True: @@ -418,7 +451,7 @@ while True: flops_so_far = num_flops_per_token * total_batch_size * step # once in a while: evaluate the val bpb (all ranks participate) - if args.eval_every > 0 and (last_step or step % args.eval_every == 0): + if args.eval_every > 0 and (last_step or (step % args.eval_every == 0 and (step > 0 or not args.skip_initial_eval))): model.eval() val_loader = build_val_loader() eval_steps = args.eval_tokens // (args.device_batch_size * args.max_seq_len * ddp_world_size) @@ -505,8 +538,14 @@ while True: # ------------------------------------------------------------------------- # single training step # evaluate the gradient - synchronize() - t0 = time.time() + if batched_train_timing: + if train_timing_interval_start is None: + synchronize() + train_timing_interval_start = time.time() + train_timing_interval_first_step = step + else: + synchronize() + t0 = time.time() for micro_step in range(grad_accum_steps): loss = model(x, y) train_loss = loss.detach() # for logging @@ -538,46 +577,66 @@ while True: else: optimizer.step() model.zero_grad(set_to_none=True) - train_loss_f = train_loss.item() # .item() is a CPU-GPU sync point - synchronize() - t1 = time.time() - dt = t1 - t0 + should_log_train = step == 0 or (step + 1) % train_log_every == 0 or (step + 1) == num_iterations + if batched_train_timing: + if should_log_train: + synchronize() + t1 = time.time() + interval_steps = step - train_timing_interval_first_step + 1 + interval_dt = t1 - train_timing_interval_start + dt = interval_dt / interval_steps + counted_start = max(train_timing_interval_first_step, 11) + counted_steps = max(0, step - counted_start + 1) + if counted_steps > 0: + total_training_time += interval_dt * counted_steps / interval_steps + train_loss_f = train_loss.item() + train_timing_interval_start = None + else: + dt = None + train_loss_f = None + else: + train_loss_f = train_loss.item() # .item() is a CPU-GPU sync point + synchronize() + t1 = time.time() + dt = t1 - t0 + if step > 10: + total_training_time += dt # only count the time after the first 10 steps # ------------------------------------------------------------------------- # logging (CPU action only) - ema_beta = 0.9 # EMA decay factor for some smoothing just for nicer logging - smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss_f # EMA the training loss - debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**(step + 1)) # debias the EMA - pct_done = 100 * step / num_iterations - tok_per_sec = int(total_batch_size / dt) - flops_per_sec = num_flops_per_token * total_batch_size / dt - mfu = 100 * flops_per_sec / (gpu_peak_flops * ddp_world_size) - if step > 10: - total_training_time += dt # only count the time after the first 10 steps - # Calculate ETA based on average time per step (excluding first 10 steps) - steps_done = step - 10 - if steps_done > 0: - avg_time_per_step = total_training_time / steps_done - remaining_steps = num_iterations - step - eta_seconds = remaining_steps * avg_time_per_step - eta_str = f" | eta: {eta_seconds/60:.1f}m" - else: - eta_str = "" - epoch = f"{dataloader_state_dict['epoch']} pq: {dataloader_state_dict['pq_idx']} rg: {dataloader_state_dict['rg_idx']}" - print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | bf16_mfu: {mfu:.2f} | epoch: {epoch} | total time: {total_training_time/60:.2f}m{eta_str}") - if step % 100 == 0: - log_data = { - "step": step, - "total_training_flops": flops_so_far, - "total_training_time": total_training_time, - "train/loss": debiased_smooth_loss, - "train/lrm": lrm, - "train/dt": dt, - "train/tok_per_sec": tok_per_sec, - "train/mfu": mfu, - "train/epoch": epoch, - } - wandb_run.log(log_data) + if should_log_train: + ema_beta = 0.9 # EMA decay factor for some smoothing just for nicer logging + smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss_f # EMA the training loss + train_log_count += 1 + debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**train_log_count) # debias the EMA + pct_done = 100 * step / num_iterations + tok_per_sec = int(total_batch_size / dt) + flops_per_sec = num_flops_per_token * total_batch_size / dt + mfu = 100 * flops_per_sec / (gpu_peak_flops * ddp_world_size) + # Calculate ETA based on average time per step (excluding first 10 steps) + steps_done = step - 10 + if steps_done > 0: + avg_time_per_step = total_training_time / steps_done + remaining_steps = num_iterations - step + eta_seconds = remaining_steps * avg_time_per_step + eta_str = f" | eta: {eta_seconds/60:.1f}m" + else: + eta_str = "" + epoch = f"{dataloader_state_dict['epoch']} pq: {dataloader_state_dict['pq_idx']} rg: {dataloader_state_dict['rg_idx']}" + print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | bf16_mfu: {mfu:.2f} | epoch: {epoch} | total time: {total_training_time/60:.2f}m{eta_str}") + if step % 100 == 0 or (step + 1) % 100 == 0: + log_data = { + "step": step, + "total_training_flops": flops_so_far, + "total_training_time": total_training_time, + "train/loss": debiased_smooth_loss, + "train/lrm": lrm, + "train/dt": dt, + "train/tok_per_sec": tok_per_sec, + "train/mfu": mfu, + "train/epoch": epoch, + } + wandb_run.log(log_data) # state update first_step_of_run = (step == 0) or (resuming and step == args.resume_from_step) diff --git a/tests/test_engine.py b/tests/test_engine.py index 784ffcb9..52d5b785 100644 --- a/tests/test_engine.py +++ b/tests/test_engine.py @@ -47,6 +47,25 @@ class MockModel: return logits +class BigramStateModel(MockModel): + """Mock model whose greedy next token depends on current and previous token ids.""" + def forward(self, ids, kv_cache=None): + B, T = ids.shape + if kv_cache is None: + prev = torch.cat([torch.zeros(B, 1, dtype=ids.dtype), ids[:, :-1]], dim=1) + else: + if T > 1 or kv_cache.prev_token is None: + prev = torch.cat([torch.zeros(B, 1, dtype=ids.dtype), ids[:, :-1]], dim=1) + else: + prev = kv_cache.prev_token + kv_cache.prev_token = ids[:, -1:].clone() + kv_cache.advance(T) + next_token = ((ids + prev + 1) % 256).long() + logits = torch.full((B, T, self.vocab_size), -1000.0) + logits.scatter_(2, next_token.unsqueeze(-1), 1000.0) + return logits + + class ByteTokenizer: """ Simple byte-level tokenizer for testing. @@ -114,6 +133,7 @@ def test_kv_cache_basic(): # Test reset kv_cache.reset() assert kv_cache.get_pos() == 0 + assert kv_cache.prev_token is None # Test get_layer_cache returns correct views k_layer0, v_layer0 = kv_cache.get_layer_cache(0) @@ -136,6 +156,7 @@ def test_kv_cache_prefill(): # Write some data to source cache src_cache.k_cache[0, 0, :16, :, :] = 1.0 src_cache.v_cache[0, 0, :16, :, :] = 2.0 + src_cache.prev_token = torch.tensor([[123]]) src_cache.advance(16) # Create destination cache with larger seq_len @@ -153,6 +174,29 @@ def test_kv_cache_prefill(): # Check data was copied assert (dst_cache.k_cache[0, 0, :16, :, :] == 1.0).all() assert (dst_cache.v_cache[0, 0, :16, :, :] == 2.0).all() + assert dst_cache.prev_token.tolist() == [[123]] + + +def test_engine_preserves_bigram_prev_token_state(): + """Engine KV-cache generation should match naive generation for previous-token state.""" + model = BigramStateModel() + tokenizer = ByteTokenizer() + engine = Engine(model, tokenizer) + prompt = [261, 17, 23, 42] + max_tokens = 8 + + def naive_generate(tokens): + ids = torch.tensor([tokens], dtype=torch.long) + out = [] + for _ in range(max_tokens): + logits = model.forward(ids) + next_id = int(logits[:, -1, :].argmax(dim=-1).item()) + out.append(next_id) + ids = torch.cat([ids, torch.tensor([[next_id]], dtype=torch.long)], dim=1) + return tokens + out + + results, _ = engine.generate_batch(prompt, temperature=0.0, max_tokens=max_tokens) + assert results[0] == naive_generate(prompt) def test_multi_sample_first_token_diversity(): From 9118a3d15ef81cf7da9fb2804832ab171befa86b Mon Sep 17 00:00:00 2001 From: Codex Date: Wed, 6 May 2026 12:37:44 +0000 Subject: [PATCH 2/8] Document controlled bigram throughput --- dev/bigram_speedrun_results.md | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/dev/bigram_speedrun_results.md b/dev/bigram_speedrun_results.md index 61f6ec75..436da647 100644 --- a/dev/bigram_speedrun_results.md +++ b/dev/bigram_speedrun_results.md @@ -47,6 +47,23 @@ comparison. The candidate also uses `--train-log-every=50` and `--compile-mode=max-autotune-no-cudagraphs`, while upstream master logs every step and uses the default compile mode. +## Controlled d16 Throughput + +A denser control run with the same log50/compile-control style is the better +way to estimate the per-step overhead of the bigram path. + +| Run | Final BPB | Train time | Avg logged tok/s, excluding first | Avg logged step time, excluding first | +| --- | ---: | ---: | ---: | ---: | +| Dense log50 compile control | `0.800604` | `92.85m` | `336,247` | `1559.258ms` | +| Bigram/Muon+ candidate, full 3584 | `0.798000` | `93.61m` | `333,507` | `1572.058ms` | + +Against this controlled dense run, the bigram candidate is about `0.81%` slower +per step, but `0.002604` BPB better at the same horizon. + +A shortened bigram run at 3400 steps landed at `0.800232` BPB in `88.92m`, +which is `0.000372` BPB better than the dense log50 compile control while using +about `4.23%` less training time. + ## Compile Mode Probe Short d16/40 throughput probes on the minimal branch: From e014abacc6c8c6f766e9ce9a2089edcaa5358e22 Mon Sep 17 00:00:00 2001 From: Codex Date: Thu, 7 May 2026 05:54:12 +0000 Subject: [PATCH 3/8] Document minimal PR changes --- dev/bigram_minimal_pr_changes.md | 208 +++++++++++++++++++++++++++++++ 1 file changed, 208 insertions(+) create mode 100644 dev/bigram_minimal_pr_changes.md diff --git a/dev/bigram_minimal_pr_changes.md b/dev/bigram_minimal_pr_changes.md new file mode 100644 index 00000000..24d8ecf1 --- /dev/null +++ b/dev/bigram_minimal_pr_changes.md @@ -0,0 +1,208 @@ +# Minimal Bigram Speedrun PR Changes + +This branch is based on upstream nanochat master at `dc54a1a`. The goal is to +keep the submission patch limited to the changes needed to reproduce the +best-performing speedrun recipe: + +```bash +--fp8 +--bigram-embed-factor=5 +--muon-plus +--muon-eq=row +--scalar-lr=0.3 +--train-log-every=50 +--compile-mode=max-autotune-no-cudagraphs +``` + +It does not include the experimental branches that were tested and rejected: +sparse architecture changes, MoE/TOP auxiliary losses, train-time logit-bias +losses, post-hoc calibration, NorMuon variants, checkpoint merging, or d22/d24 +run-management scripts. + +## `nanochat/gpt.py` + +### Hashed Bigram Residual Embedding + +Adds two config fields: + +- `bigram_embed_factor`, default `0` +- `bigram_lambda_init`, default `0.05` + +When `bigram_embed_factor > 0`, the model creates a separate bigram embedding +table with `vocab_size * bigram_embed_factor` entries. For each token position, +the current token id and previous token id are hashed into that table. The +resulting embedding is added as a residual input before every transformer block: + +```python +x = x + bigram_lambdas[i] * x0_bigram +``` + +The first token in each sequence uses a sentinel bucket because it has no +previous token. During KV-cache decoding, the previous token is read from the +cache so generation matches the training-time bigram definition. + +Why this helps: it gives the model a cheap, direct representation of adjacent +token pairs without adding attention or MLP compute. The bigram table is +zero-initialized, so the model starts from the original network function, while +the per-layer `bigram_lambdas` start at `0.05` to let the residual learn quickly. + +### Parameter Counting and FLOP Accounting + +The bigram embedding table and bigram lambdas are excluded from the main matmul +FLOP/scaling parameter count. They are not transformer matrix weights, and +including them would distort the target param/data ratio logic. + +### Optimizer Groups + +Adds dedicated optimizer groups for: + +- `bigram_embed` +- `bigram_lambdas` + +The bigram embedding uses AdamW with a configurable multiplier relative to the +main embedding LR. The layer lambdas use a small AdamW LR. This keeps the bigram +residual trainable without mixing it into the Muon-managed transformer matrices. + +### Muon Options Plumbed Through + +`setup_optimizer()` accepts: + +- `muon_plus` +- `muon_eq_axis` + +These are forwarded into the Muon parameter groups so the optimizer can apply +the selected Muon variants to matrix weights. + +## `nanochat/optim.py` + +### Muon+ Renormalization + +After Newton-Schulz orthogonalization, Muon+ rescales the update by its +Frobenius norm. This is a small post-processing step on the Muon update and was +the strongest optimizer-side change in the experiments. + +Why this helps: it stabilizes update scale after orthogonalization without +changing the model architecture or adding optimizer state. + +### Row/Column Equilibration + +Adds optional row or column norm equilibration before orthogonalization: + +- `muon_eq_axis=1`: row equilibration +- `muon_eq_axis=2`: column equilibration +- `muon_eq_axis=0`: disabled + +The speedrun recipe uses row equilibration. It normalizes rows toward a common +target norm before the polar/Newton-Schulz step, then continues through the +existing Muon update path. + +Why this helps: row equilibration was a small but positive companion to Muon+ in +the winning recipe, with minimal extra code and no extra persistent optimizer +state. + +## `nanochat/engine.py` + +### Previous Token in KV Cache + +Adds `prev_token` to `KVCache`, resets it with the rest of the cache, and copies +it during prefill expansion. + +Why this is needed: full-sequence training can compute bigram hashes from +`idx[:, :-1]`, but one-token decode does not have the previous token in the +current input tensor. Keeping `prev_token` in the cache makes generation use the +same bigram feature as training. + +## `scripts/base_train.py` + +### Bigram CLI Flags + +Adds: + +- `--bigram-embed-factor` +- `--bigram-lambda-init` +- `--bigram-embedding-lr-mult` +- `--bigram-lambda-lr` + +These configure the bigram residual and its optimizer treatment from the +training script without changing defaults. With default values, upstream +behavior is unchanged because `--bigram-embed-factor` defaults to `0`. + +### Muon Variant Flags + +Adds: + +- `--muon-plus` +- `--muon-eq` + +These expose the optimizer variants used in the recipe. Defaults preserve the +original optimizer behavior. + +### Train Logging Cadence + +Adds `--train-log-every`. Values greater than 1 avoid converting the loss tensor +to a Python scalar every step. + +Why this helps: per-step logging creates extra synchronization overhead. The +speedrun uses `--train-log-every=50`, which keeps useful progress reporting +while reducing logging overhead. + +### Compile Mode + +Adds `--compile-mode` so the speedrun can request: + +```bash +--compile-mode=max-autotune-no-cudagraphs +``` + +Why this helps: on the d16 probe, this compile mode was about 2.5% faster than +default `torch.compile` for the candidate recipe. + +### Skip Initial Eval + +Adds `--skip-initial-eval`. This avoids spending benchmark wall time on the +step-0 validation pass when it is not needed for a speedrun submission. + +## `runs/speedrun.sh` + +Updates the default speedrun command to use the winning recipe flags: + +- FP8 +- total batch size `1048576` +- Muon+ +- row equilibration +- bigram factor 5 +- scalar LR `0.3` +- log every 50 training steps +- `max-autotune-no-cudagraphs` compile mode + +This script is the intended entry point for reproducing the submitted run. + +## `tests/test_engine.py` + +Adds coverage for preserving `prev_token` through KV-cache prefill/expansion. + +Why this matters: the bigram feature must behave consistently during generation. +The test guards the cache state required for single-token decode. + +## `dev/bigram_speedrun_results.md` + +Records the validation and throughput evidence used to justify the recipe: + +- minimal branch sanity check against the prior candidate branch +- full d16 comparison against upstream dense +- controlled d16 throughput comparison +- compile-mode probe +- test status + +This is supporting documentation for the PR, not code required at runtime. + +## Submission Readiness + +Completed checks: + +- `python -m pytest tests/test_engine.py -q` +- `python -m py_compile nanochat/gpt.py nanochat/optim.py scripts/base_train.py nanochat/engine.py` +- `git diff --check` + +The remaining work is operational: run the final benchmark on the 8xH100 system +from this branch and include the measured result in the submission PR. From 0de3a399108c49e4c7a4a7bebebc5a9ede5c18ab Mon Sep 17 00:00:00 2001 From: Codex Date: Thu, 7 May 2026 09:02:55 +0000 Subject: [PATCH 4/8] Set speedrun default to d22 bigram recipe --- dev/bigram_minimal_pr_changes.md | 11 ++++++++++- runs/speedrun.sh | 15 ++++++++++----- 2 files changed, 20 insertions(+), 6 deletions(-) diff --git a/dev/bigram_minimal_pr_changes.md b/dev/bigram_minimal_pr_changes.md index 24d8ecf1..bdc11ff2 100644 --- a/dev/bigram_minimal_pr_changes.md +++ b/dev/bigram_minimal_pr_changes.md @@ -6,12 +6,17 @@ best-performing speedrun recipe: ```bash --fp8 +--depth=22 +--num-iterations=11600 +--total-batch-size=524288 --bigram-embed-factor=5 --muon-plus --muon-eq=row --scalar-lr=0.3 --train-log-every=50 --compile-mode=max-autotune-no-cudagraphs +--eval-every=250 +--core-metric-every=5800 ``` It does not include the experimental branches that were tested and rejected: @@ -167,13 +172,17 @@ step-0 validation pass when it is not needed for a speedrun submission. Updates the default speedrun command to use the winning recipe flags: - FP8 -- total batch size `1048576` +- depth `22` +- fixed `11600` optimizer steps +- total batch size `524288` - Muon+ - row equilibration - bigram factor 5 - scalar LR `0.3` - log every 50 training steps - `max-autotune-no-cudagraphs` compile mode +- validation every 250 steps +- one CORE metric pass halfway through at step 5800 This script is the intended entry point for reproducing the submitted run. diff --git a/runs/speedrun.sh b/runs/speedrun.sh index 9f780faf..8dab8cf0 100644 --- a/runs/speedrun.sh +++ b/runs/speedrun.sh @@ -69,12 +69,15 @@ python -m scripts.tok_eval echo "Waiting for dataset download to complete..." wait $DATASET_DOWNLOAD_PID -# d24 model (slightly undertrained to beat GPT-2 => decrease data:params ratio from compute optimal 10.5 (default) to 8) +# d22 Muon+/row-eq + hashed bigram recipe. +# This is the submission default: fixed 11,600 optimizer steps, eval every 250, +# and one in-training CORE pass halfway through. torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- \ - --depth=24 \ - --target-param-data-ratio=8 \ - --device-batch-size=16 \ - --total-batch-size=1048576 \ + --depth=22 \ + --num-iterations=11600 \ + --target-param-data-ratio=11 \ + --device-batch-size=32 \ + --total-batch-size=524288 \ --fp8 \ --compile-mode=max-autotune-no-cudagraphs \ --muon-plus \ @@ -82,6 +85,8 @@ torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- \ --bigram-embed-factor=5 \ --scalar-lr=0.3 \ --train-log-every=50 \ + --eval-every=250 \ + --core-metric-every=5800 \ --run=$WANDB_RUN # evaluate the model: CORE metric, BPB on train/val, and draw samples torchrun --standalone --nproc_per_node=8 -m scripts.base_eval -- --device-batch-size=16 From 0393a2c13f0bcd93bac193b34d556aa9566f08c5 Mon Sep 17 00:00:00 2001 From: Codex Date: Thu, 7 May 2026 09:15:47 +0000 Subject: [PATCH 5/8] Make d22 bigram recipe the training default --- dev/bigram_minimal_pr_changes.md | 22 +++++++++++---------- runs/speedrun.sh | 21 +++----------------- scripts/base_train.py | 33 +++++++++++++++++--------------- 3 files changed, 33 insertions(+), 43 deletions(-) diff --git a/dev/bigram_minimal_pr_changes.md b/dev/bigram_minimal_pr_changes.md index bdc11ff2..59c3a2d1 100644 --- a/dev/bigram_minimal_pr_changes.md +++ b/dev/bigram_minimal_pr_changes.md @@ -2,7 +2,7 @@ This branch is based on upstream nanochat master at `dc54a1a`. The goal is to keep the submission patch limited to the changes needed to reproduce the -best-performing speedrun recipe: +best-performing speedrun recipe. These are the `scripts/base_train.py` defaults: ```bash --fp8 @@ -129,18 +129,19 @@ Adds: - `--bigram-lambda-lr` These configure the bigram residual and its optimizer treatment from the -training script without changing defaults. With default values, upstream -behavior is unchanged because `--bigram-embed-factor` defaults to `0`. +training script. The submission default is `--bigram-embed-factor=5`. ### Muon Variant Flags Adds: - `--muon-plus` +- `--no-muon-plus` - `--muon-eq` -These expose the optimizer variants used in the recipe. Defaults preserve the -original optimizer behavior. +These expose the optimizer variants used in the recipe. The submission defaults +are Muon+ enabled and `--muon-eq=row`. `--no-muon-plus --muon-eq=none` restores +the original Muon path. ### Train Logging Cadence @@ -148,7 +149,7 @@ Adds `--train-log-every`. Values greater than 1 avoid converting the loss tensor to a Python scalar every step. Why this helps: per-step logging creates extra synchronization overhead. The -speedrun uses `--train-log-every=50`, which keeps useful progress reporting +submission default is `--train-log-every=50`, which keeps useful progress reporting while reducing logging overhead. ### Compile Mode @@ -160,16 +161,17 @@ Adds `--compile-mode` so the speedrun can request: ``` Why this helps: on the d16 probe, this compile mode was about 2.5% faster than -default `torch.compile` for the candidate recipe. +default `torch.compile` for the candidate recipe. It is now the submission +default. ### Skip Initial Eval -Adds `--skip-initial-eval`. This avoids spending benchmark wall time on the -step-0 validation pass when it is not needed for a speedrun submission. +Adds `--skip-initial-eval` and `--initial-eval`. The submission default skips +the step-0 validation pass; `--initial-eval` restores the original behavior. ## `runs/speedrun.sh` -Updates the default speedrun command to use the winning recipe flags: +Uses the `scripts/base_train.py` submission defaults: - FP8 - depth `22` diff --git a/runs/speedrun.sh b/runs/speedrun.sh index 8dab8cf0..9a4c3977 100644 --- a/runs/speedrun.sh +++ b/runs/speedrun.sh @@ -70,24 +70,9 @@ echo "Waiting for dataset download to complete..." wait $DATASET_DOWNLOAD_PID # d22 Muon+/row-eq + hashed bigram recipe. -# This is the submission default: fixed 11,600 optimizer steps, eval every 250, -# and one in-training CORE pass halfway through. -torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- \ - --depth=22 \ - --num-iterations=11600 \ - --target-param-data-ratio=11 \ - --device-batch-size=32 \ - --total-batch-size=524288 \ - --fp8 \ - --compile-mode=max-autotune-no-cudagraphs \ - --muon-plus \ - --muon-eq=row \ - --bigram-embed-factor=5 \ - --scalar-lr=0.3 \ - --train-log-every=50 \ - --eval-every=250 \ - --core-metric-every=5800 \ - --run=$WANDB_RUN +# scripts/base_train defaults are the submission defaults: fixed 11,600 +# optimizer steps, eval every 250, and one in-training CORE pass halfway through. +torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- --run=$WANDB_RUN # evaluate the model: CORE metric, BPB on train/val, and draw samples torchrun --standalone --nproc_per_node=8 -m scripts.base_eval -- --device-batch-size=16 diff --git a/scripts/base_train.py b/scripts/base_train.py index 56fbae2b..d0e3780c 100644 --- a/scripts/base_train.py +++ b/scripts/base_train.py @@ -8,7 +8,7 @@ or distributed as: torchrun --nproc_per_node=8 -m scripts.base_train If you are only on CPU/Macbook, you'll want to train a much much smaller LLM. Example: -python -m scripts.base_train --depth=4 --max-seq-len=512 --device-batch-size=1 --eval-tokens=512 --core-metric-every=-1 --total-batch-size=512 --num-iterations=20 +python -m scripts.base_train --depth=4 --max-seq-len=512 --device-batch-size=1 --eval-tokens=512 --core-metric-every=-1 --total-batch-size=512 --num-iterations=20 --no-fp8 --no-muon-plus --muon-eq=none --bigram-embed-factor=0 """ import os @@ -41,37 +41,39 @@ print_banner() parser = argparse.ArgumentParser(description="Pretrain base model") # Logging parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)") -parser.add_argument("--train-log-every", type=int, default=1, help="print training metrics every N steps; values >1 avoid per-step CPU/GPU sync") +parser.add_argument("--train-log-every", type=int, default=50, help="print training metrics every N steps; values >1 avoid per-step CPU/GPU sync") # Runtime parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)") # FP8 training -parser.add_argument("--fp8", action="store_true", help="enable FP8 training (requires H100+ GPU and torchao)") +parser.add_argument("--fp8", dest="fp8", action="store_true", default=True, help="enable FP8 training (requires H100+ GPU and torchao)") +parser.add_argument("--no-fp8", dest="fp8", action="store_false", help="disable FP8 training") parser.add_argument("--fp8-recipe", type=str, default="tensorwise", choices=["rowwise", "tensorwise"], help="FP8 scaling recipe: tensorwise (faster, recommended) or rowwise (more accurate but slower)") -parser.add_argument("--compile-mode", type=str, default="", choices=["", "default", "reduce-overhead", "max-autotune", "max-autotune-no-cudagraphs"], help="optional torch.compile mode") +parser.add_argument("--compile-mode", type=str, default="max-autotune-no-cudagraphs", choices=["", "default", "reduce-overhead", "max-autotune", "max-autotune-no-cudagraphs"], help="optional torch.compile mode") # Model architecture -parser.add_argument("--depth", type=int, default=20, help="depth of the Transformer model") +parser.add_argument("--depth", type=int, default=22, help="depth of the Transformer model") parser.add_argument("--aspect-ratio", type=int, default=64, help="model_dim = depth * aspect_ratio") parser.add_argument("--head-dim", type=int, default=128, help="target head dimension for attention") parser.add_argument("--max-seq-len", type=int, default=2048, help="max context length") parser.add_argument("--window-pattern", type=str, default="SSSL", help="sliding window pattern tiled across layers: L=full, S=half context (e.g. 'SSL')") -parser.add_argument("--bigram-embed-factor", type=int, default=0, help="if >0, add a hashed bigram embedding residual") +parser.add_argument("--bigram-embed-factor", type=int, default=5, help="if >0, add a hashed bigram embedding residual") parser.add_argument("--bigram-lambda-init", type=float, default=0.05, help="initial layer residual scale for --bigram-embed-factor") parser.add_argument("--bigram-embedding-lr-mult", type=float, default=1.0, help="bigram embedding LR multiplier relative to --embedding-lr") parser.add_argument("--bigram-lambda-lr", type=float, default=0.004, help="AdamW LR for bigram layer lambdas before dmodel schedule scaling") # Training horizon (only one used, in order of precedence) -parser.add_argument("--num-iterations", type=int, default=-1, help="explicit number of optimization steps (-1 = disable)") +parser.add_argument("--num-iterations", type=int, default=11600, help="explicit number of optimization steps (-1 = disable)") parser.add_argument("--target-flops", type=float, default=-1.0, help="calculate num_iterations to reach target_flops (-1 = disable)") -parser.add_argument("--target-param-data-ratio", type=float, default=12, help="calculate num_iterations to maintain data:param ratio (Chinchilla=20, -1 = disable)") +parser.add_argument("--target-param-data-ratio", type=float, default=11, help="calculate num_iterations to maintain data:param ratio (Chinchilla=20, -1 = disable)") # Optimization parser.add_argument("--device-batch-size", type=int, default=32, help="per-device batch size. good number to reduce to 16,8,4,... if you OOM on VRAM.") -parser.add_argument("--total-batch-size", type=int, default=-1, help="total batch size in tokens. decent numbers are e.g. 524288. (-1 = auto-compute optimal)") +parser.add_argument("--total-batch-size", type=int, default=524288, help="total batch size in tokens. decent numbers are e.g. 524288. (-1 = auto-compute optimal)") parser.add_argument("--embedding-lr", type=float, default=0.3, help="learning rate for embedding parameters (Adam)") parser.add_argument("--unembedding-lr", type=float, default=0.008, help="learning rate for unembedding parameters (Adam)") parser.add_argument("--weight-decay", type=float, default=0.28, help="cautious weight decay for the Muon optimizer (for weights)") parser.add_argument("--matrix-lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)") -parser.add_argument("--scalar-lr", type=float, default=0.5, help="learning rate for scalars (resid_lambdas, x0_lambdas)") -parser.add_argument("--muon-plus", action="store_true", help="apply Muon+ style post-orthogonalization Frobenius renormalization") -parser.add_argument("--muon-eq", type=str, default="none", choices=["none", "row", "col"], help="apply MuonEq-style row/column equilibration before orthogonalization") +parser.add_argument("--scalar-lr", type=float, default=0.3, help="learning rate for scalars (resid_lambdas, x0_lambdas)") +parser.add_argument("--muon-plus", dest="muon_plus", action="store_true", default=True, help="apply Muon+ style post-orthogonalization Frobenius renormalization") +parser.add_argument("--no-muon-plus", dest="muon_plus", action="store_false", help="disable Muon+ post-orthogonalization renormalization") +parser.add_argument("--muon-eq", type=str, default="row", choices=["none", "row", "col"], help="apply MuonEq-style row/column equilibration before orthogonalization") parser.add_argument("--warmup-steps", type=int, default=40, help="number of steps for LR warmup") parser.add_argument("--warmdown-ratio", type=float, default=0.65, help="ratio of iterations for LR warmdown") parser.add_argument("--final-lr-frac", type=float, default=0.05, help="final LR as fraction of initial LR") @@ -79,10 +81,11 @@ parser.add_argument("--resume-from-step", type=int, default=-1, help="resume tra # Evaluation parser.add_argument("--eval-every", type=int, default=250, help="evaluate val bpb every N steps (-1 = disable)") parser.add_argument("--eval-tokens", type=int, default=80*524288, help="number of tokens to evaluate val loss on") -parser.add_argument("--skip-initial-eval", action="store_true", help="skip the step 0 validation pass; final validation still runs") -parser.add_argument("--core-metric-every", type=int, default=2000, help="evaluate CORE metric every N steps (-1 = disable)") +parser.add_argument("--skip-initial-eval", dest="skip_initial_eval", action="store_true", default=True, help="skip the step 0 validation pass; final validation still runs") +parser.add_argument("--initial-eval", dest="skip_initial_eval", action="store_false", help="run validation at step 0") +parser.add_argument("--core-metric-every", type=int, default=5800, help="evaluate CORE metric every N steps (-1 = disable)") parser.add_argument("--core-metric-max-per-task", type=int, default=500, help="examples per task for CORE metric") -parser.add_argument("--sample-every", type=int, default=2000, help="sample from model every N steps (-1 = disable)") +parser.add_argument("--sample-every", type=int, default=-1, help="sample from model every N steps (-1 = disable)") parser.add_argument("--save-every", type=int, default=-1, help="save checkpoints every N steps (-1 = only at end)") # Output parser.add_argument("--model-tag", type=str, default=None, help="override model tag for checkpoint directory name") From d6a169b3290cd60a7bd84a254d98d96676180dc8 Mon Sep 17 00:00:00 2001 From: Codex Date: Thu, 7 May 2026 12:15:53 +0000 Subject: [PATCH 6/8] Remove dev PR notes --- dev/bigram_minimal_pr_changes.md | 219 ------------------------------- dev/bigram_speedrun_results.md | 83 ------------ 2 files changed, 302 deletions(-) delete mode 100644 dev/bigram_minimal_pr_changes.md delete mode 100644 dev/bigram_speedrun_results.md diff --git a/dev/bigram_minimal_pr_changes.md b/dev/bigram_minimal_pr_changes.md deleted file mode 100644 index 59c3a2d1..00000000 --- a/dev/bigram_minimal_pr_changes.md +++ /dev/null @@ -1,219 +0,0 @@ -# Minimal Bigram Speedrun PR Changes - -This branch is based on upstream nanochat master at `dc54a1a`. The goal is to -keep the submission patch limited to the changes needed to reproduce the -best-performing speedrun recipe. These are the `scripts/base_train.py` defaults: - -```bash ---fp8 ---depth=22 ---num-iterations=11600 ---total-batch-size=524288 ---bigram-embed-factor=5 ---muon-plus ---muon-eq=row ---scalar-lr=0.3 ---train-log-every=50 ---compile-mode=max-autotune-no-cudagraphs ---eval-every=250 ---core-metric-every=5800 -``` - -It does not include the experimental branches that were tested and rejected: -sparse architecture changes, MoE/TOP auxiliary losses, train-time logit-bias -losses, post-hoc calibration, NorMuon variants, checkpoint merging, or d22/d24 -run-management scripts. - -## `nanochat/gpt.py` - -### Hashed Bigram Residual Embedding - -Adds two config fields: - -- `bigram_embed_factor`, default `0` -- `bigram_lambda_init`, default `0.05` - -When `bigram_embed_factor > 0`, the model creates a separate bigram embedding -table with `vocab_size * bigram_embed_factor` entries. For each token position, -the current token id and previous token id are hashed into that table. The -resulting embedding is added as a residual input before every transformer block: - -```python -x = x + bigram_lambdas[i] * x0_bigram -``` - -The first token in each sequence uses a sentinel bucket because it has no -previous token. During KV-cache decoding, the previous token is read from the -cache so generation matches the training-time bigram definition. - -Why this helps: it gives the model a cheap, direct representation of adjacent -token pairs without adding attention or MLP compute. The bigram table is -zero-initialized, so the model starts from the original network function, while -the per-layer `bigram_lambdas` start at `0.05` to let the residual learn quickly. - -### Parameter Counting and FLOP Accounting - -The bigram embedding table and bigram lambdas are excluded from the main matmul -FLOP/scaling parameter count. They are not transformer matrix weights, and -including them would distort the target param/data ratio logic. - -### Optimizer Groups - -Adds dedicated optimizer groups for: - -- `bigram_embed` -- `bigram_lambdas` - -The bigram embedding uses AdamW with a configurable multiplier relative to the -main embedding LR. The layer lambdas use a small AdamW LR. This keeps the bigram -residual trainable without mixing it into the Muon-managed transformer matrices. - -### Muon Options Plumbed Through - -`setup_optimizer()` accepts: - -- `muon_plus` -- `muon_eq_axis` - -These are forwarded into the Muon parameter groups so the optimizer can apply -the selected Muon variants to matrix weights. - -## `nanochat/optim.py` - -### Muon+ Renormalization - -After Newton-Schulz orthogonalization, Muon+ rescales the update by its -Frobenius norm. This is a small post-processing step on the Muon update and was -the strongest optimizer-side change in the experiments. - -Why this helps: it stabilizes update scale after orthogonalization without -changing the model architecture or adding optimizer state. - -### Row/Column Equilibration - -Adds optional row or column norm equilibration before orthogonalization: - -- `muon_eq_axis=1`: row equilibration -- `muon_eq_axis=2`: column equilibration -- `muon_eq_axis=0`: disabled - -The speedrun recipe uses row equilibration. It normalizes rows toward a common -target norm before the polar/Newton-Schulz step, then continues through the -existing Muon update path. - -Why this helps: row equilibration was a small but positive companion to Muon+ in -the winning recipe, with minimal extra code and no extra persistent optimizer -state. - -## `nanochat/engine.py` - -### Previous Token in KV Cache - -Adds `prev_token` to `KVCache`, resets it with the rest of the cache, and copies -it during prefill expansion. - -Why this is needed: full-sequence training can compute bigram hashes from -`idx[:, :-1]`, but one-token decode does not have the previous token in the -current input tensor. Keeping `prev_token` in the cache makes generation use the -same bigram feature as training. - -## `scripts/base_train.py` - -### Bigram CLI Flags - -Adds: - -- `--bigram-embed-factor` -- `--bigram-lambda-init` -- `--bigram-embedding-lr-mult` -- `--bigram-lambda-lr` - -These configure the bigram residual and its optimizer treatment from the -training script. The submission default is `--bigram-embed-factor=5`. - -### Muon Variant Flags - -Adds: - -- `--muon-plus` -- `--no-muon-plus` -- `--muon-eq` - -These expose the optimizer variants used in the recipe. The submission defaults -are Muon+ enabled and `--muon-eq=row`. `--no-muon-plus --muon-eq=none` restores -the original Muon path. - -### Train Logging Cadence - -Adds `--train-log-every`. Values greater than 1 avoid converting the loss tensor -to a Python scalar every step. - -Why this helps: per-step logging creates extra synchronization overhead. The -submission default is `--train-log-every=50`, which keeps useful progress reporting -while reducing logging overhead. - -### Compile Mode - -Adds `--compile-mode` so the speedrun can request: - -```bash ---compile-mode=max-autotune-no-cudagraphs -``` - -Why this helps: on the d16 probe, this compile mode was about 2.5% faster than -default `torch.compile` for the candidate recipe. It is now the submission -default. - -### Skip Initial Eval - -Adds `--skip-initial-eval` and `--initial-eval`. The submission default skips -the step-0 validation pass; `--initial-eval` restores the original behavior. - -## `runs/speedrun.sh` - -Uses the `scripts/base_train.py` submission defaults: - -- FP8 -- depth `22` -- fixed `11600` optimizer steps -- total batch size `524288` -- Muon+ -- row equilibration -- bigram factor 5 -- scalar LR `0.3` -- log every 50 training steps -- `max-autotune-no-cudagraphs` compile mode -- validation every 250 steps -- one CORE metric pass halfway through at step 5800 - -This script is the intended entry point for reproducing the submitted run. - -## `tests/test_engine.py` - -Adds coverage for preserving `prev_token` through KV-cache prefill/expansion. - -Why this matters: the bigram feature must behave consistently during generation. -The test guards the cache state required for single-token decode. - -## `dev/bigram_speedrun_results.md` - -Records the validation and throughput evidence used to justify the recipe: - -- minimal branch sanity check against the prior candidate branch -- full d16 comparison against upstream dense -- controlled d16 throughput comparison -- compile-mode probe -- test status - -This is supporting documentation for the PR, not code required at runtime. - -## Submission Readiness - -Completed checks: - -- `python -m pytest tests/test_engine.py -q` -- `python -m py_compile nanochat/gpt.py nanochat/optim.py scripts/base_train.py nanochat/engine.py` -- `git diff --check` - -The remaining work is operational: run the final benchmark on the 8xH100 system -from this branch and include the measured result in the submission PR. diff --git a/dev/bigram_speedrun_results.md b/dev/bigram_speedrun_results.md deleted file mode 100644 index 436da647..00000000 --- a/dev/bigram_speedrun_results.md +++ /dev/null @@ -1,83 +0,0 @@ -# Bigram Speedrun Verification Notes - -This branch is based on upstream nanochat master at `dc54a1a` and keeps the -submission implementation focused on the winning recipe: - -- per-layer hashed bigram residual embeddings -- Muon+ post-orthogonalization normalization -- row equilibration before Muon orthogonalization -- lower scalar LR (`--scalar-lr=0.3`) -- batched training logging (`--train-log-every=50`) -- `torch.compile(..., mode="max-autotune-no-cudagraphs")` for the speedrun script - -It intentionally excludes the experimental branches that were not part of the -final candidate: sparse layers, MoE/TOP losses, train-time logit bias losses, -post-hoc fitting, NorMuon, and checkpoint merging. - -## Reproduction Sanity Check - -Minimal branch d4/20 matched the prior experimental branch: - -| Run | Step 0 BPB | Step 10 BPB | Final BPB | -| --- | ---: | ---: | ---: | -| Prior candidate branch | `3.237224` | `3.234722` | `3.223259` | -| Minimal PR branch | `3.237224` | `3.234722` | `3.223286` | - -The final difference is `0.000027` BPB on a tiny run, consistent with small -compile/graph differences after removing unused experimental code. - -## Full d16 Verification - -Both runs used d16, FP8, target param/data ratio 8, total batch `524288`, and -device batch `32` on the same machine. - -| Run | Final BPB | Train time | Avg logged tok/s, excluding first | Avg logged step time, excluding first | -| --- | ---: | ---: | ---: | ---: | -| Upstream master dense | `0.800673` | `94.64m` | `329,904` | `1589.232ms` | -| Bigram/Muon+ candidate | `0.798000` | `93.61m` | `333,507` | `1572.058ms` | - -Candidate delta versus upstream master dense: - -- BPB: `-0.002673` -- train time: `-1.03m` (`1.09%` faster) -- logged throughput: `+3,603 tok/s` (`1.09%` higher) - -Important caveat: this is a full recipe comparison, not an architecture-only -comparison. The candidate also uses `--train-log-every=50` and -`--compile-mode=max-autotune-no-cudagraphs`, while upstream master logs every -step and uses the default compile mode. - -## Controlled d16 Throughput - -A denser control run with the same log50/compile-control style is the better -way to estimate the per-step overhead of the bigram path. - -| Run | Final BPB | Train time | Avg logged tok/s, excluding first | Avg logged step time, excluding first | -| --- | ---: | ---: | ---: | ---: | -| Dense log50 compile control | `0.800604` | `92.85m` | `336,247` | `1559.258ms` | -| Bigram/Muon+ candidate, full 3584 | `0.798000` | `93.61m` | `333,507` | `1572.058ms` | - -Against this controlled dense run, the bigram candidate is about `0.81%` slower -per step, but `0.002604` BPB better at the same horizon. - -A shortened bigram run at 3400 steps landed at `0.800232` BPB in `88.92m`, -which is `0.000372` BPB better than the dense log50 compile control while using -about `4.23%` less training time. - -## Compile Mode Probe - -Short d16/40 throughput probes on the minimal branch: - -| Compile mode | Avg logged tok/s, excluding first | Avg logged step time, excluding first | Total time | -| --- | ---: | ---: | ---: | -| default `torch.compile` | `324,995` | `1613.250ms` | `0.78m` | -| `max-autotune-no-cudagraphs` | `333,261` | `1573.250ms` | `0.76m` | - -On this d16 probe, `max-autotune-no-cudagraphs` was about `2.5%` faster than -the default compile mode. The speedrun script keeps this compile mode for that -reason. - -## Test Status - -- `python -m pytest tests/test_engine.py -q`: `9 passed` -- `python -m py_compile nanochat/gpt.py nanochat/optim.py scripts/base_train.py nanochat/engine.py`: passed From 4680e799fbccdb861e0d2ef11f8e2d7e99805e3e Mon Sep 17 00:00:00 2001 From: Codex Date: Thu, 7 May 2026 12:24:51 +0000 Subject: [PATCH 7/8] Keep speedrun changes functional --- runs/speedrun.sh | 16 ++++- scripts/base_train.py | 159 ++++++++++++++---------------------------- 2 files changed, 67 insertions(+), 108 deletions(-) diff --git a/runs/speedrun.sh b/runs/speedrun.sh index 9a4c3977..20e62488 100644 --- a/runs/speedrun.sh +++ b/runs/speedrun.sh @@ -70,9 +70,19 @@ echo "Waiting for dataset download to complete..." wait $DATASET_DOWNLOAD_PID # d22 Muon+/row-eq + hashed bigram recipe. -# scripts/base_train defaults are the submission defaults: fixed 11,600 -# optimizer steps, eval every 250, and one in-training CORE pass halfway through. -torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- --run=$WANDB_RUN +torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- \ + --run=$WANDB_RUN \ + --fp8 \ + --depth=22 \ + --num-iterations=11600 \ + --target-param-data-ratio=11 \ + --total-batch-size=524288 \ + --scalar-lr=0.3 \ + --bigram-embed-factor=5 \ + --muon-plus \ + --muon-eq=row \ + --core-metric-every=5800 \ + --sample-every=-1 # evaluate the model: CORE metric, BPB on train/val, and draw samples torchrun --standalone --nproc_per_node=8 -m scripts.base_eval -- --device-batch-size=16 diff --git a/scripts/base_train.py b/scripts/base_train.py index d0e3780c..b415005c 100644 --- a/scripts/base_train.py +++ b/scripts/base_train.py @@ -8,7 +8,7 @@ or distributed as: torchrun --nproc_per_node=8 -m scripts.base_train If you are only on CPU/Macbook, you'll want to train a much much smaller LLM. Example: -python -m scripts.base_train --depth=4 --max-seq-len=512 --device-batch-size=1 --eval-tokens=512 --core-metric-every=-1 --total-batch-size=512 --num-iterations=20 --no-fp8 --no-muon-plus --muon-eq=none --bigram-embed-factor=0 +python -m scripts.base_train --depth=4 --max-seq-len=512 --device-batch-size=1 --eval-tokens=512 --core-metric-every=-1 --total-batch-size=512 --num-iterations=20 """ import os @@ -41,39 +41,32 @@ print_banner() parser = argparse.ArgumentParser(description="Pretrain base model") # Logging parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)") -parser.add_argument("--train-log-every", type=int, default=50, help="print training metrics every N steps; values >1 avoid per-step CPU/GPU sync") # Runtime parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)") # FP8 training -parser.add_argument("--fp8", dest="fp8", action="store_true", default=True, help="enable FP8 training (requires H100+ GPU and torchao)") -parser.add_argument("--no-fp8", dest="fp8", action="store_false", help="disable FP8 training") +parser.add_argument("--fp8", action="store_true", help="enable FP8 training (requires H100+ GPU and torchao)") parser.add_argument("--fp8-recipe", type=str, default="tensorwise", choices=["rowwise", "tensorwise"], help="FP8 scaling recipe: tensorwise (faster, recommended) or rowwise (more accurate but slower)") -parser.add_argument("--compile-mode", type=str, default="max-autotune-no-cudagraphs", choices=["", "default", "reduce-overhead", "max-autotune", "max-autotune-no-cudagraphs"], help="optional torch.compile mode") # Model architecture -parser.add_argument("--depth", type=int, default=22, help="depth of the Transformer model") +parser.add_argument("--depth", type=int, default=20, help="depth of the Transformer model") parser.add_argument("--aspect-ratio", type=int, default=64, help="model_dim = depth * aspect_ratio") parser.add_argument("--head-dim", type=int, default=128, help="target head dimension for attention") parser.add_argument("--max-seq-len", type=int, default=2048, help="max context length") parser.add_argument("--window-pattern", type=str, default="SSSL", help="sliding window pattern tiled across layers: L=full, S=half context (e.g. 'SSL')") -parser.add_argument("--bigram-embed-factor", type=int, default=5, help="if >0, add a hashed bigram embedding residual") -parser.add_argument("--bigram-lambda-init", type=float, default=0.05, help="initial layer residual scale for --bigram-embed-factor") -parser.add_argument("--bigram-embedding-lr-mult", type=float, default=1.0, help="bigram embedding LR multiplier relative to --embedding-lr") -parser.add_argument("--bigram-lambda-lr", type=float, default=0.004, help="AdamW LR for bigram layer lambdas before dmodel schedule scaling") +parser.add_argument("--bigram-embed-factor", type=int, default=0, help="if >0, add a hashed bigram embedding residual") # Training horizon (only one used, in order of precedence) -parser.add_argument("--num-iterations", type=int, default=11600, help="explicit number of optimization steps (-1 = disable)") +parser.add_argument("--num-iterations", type=int, default=-1, help="explicit number of optimization steps (-1 = disable)") parser.add_argument("--target-flops", type=float, default=-1.0, help="calculate num_iterations to reach target_flops (-1 = disable)") -parser.add_argument("--target-param-data-ratio", type=float, default=11, help="calculate num_iterations to maintain data:param ratio (Chinchilla=20, -1 = disable)") +parser.add_argument("--target-param-data-ratio", type=float, default=12, help="calculate num_iterations to maintain data:param ratio (Chinchilla=20, -1 = disable)") # Optimization parser.add_argument("--device-batch-size", type=int, default=32, help="per-device batch size. good number to reduce to 16,8,4,... if you OOM on VRAM.") -parser.add_argument("--total-batch-size", type=int, default=524288, help="total batch size in tokens. decent numbers are e.g. 524288. (-1 = auto-compute optimal)") +parser.add_argument("--total-batch-size", type=int, default=-1, help="total batch size in tokens. decent numbers are e.g. 524288. (-1 = auto-compute optimal)") parser.add_argument("--embedding-lr", type=float, default=0.3, help="learning rate for embedding parameters (Adam)") parser.add_argument("--unembedding-lr", type=float, default=0.008, help="learning rate for unembedding parameters (Adam)") parser.add_argument("--weight-decay", type=float, default=0.28, help="cautious weight decay for the Muon optimizer (for weights)") parser.add_argument("--matrix-lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)") -parser.add_argument("--scalar-lr", type=float, default=0.3, help="learning rate for scalars (resid_lambdas, x0_lambdas)") -parser.add_argument("--muon-plus", dest="muon_plus", action="store_true", default=True, help="apply Muon+ style post-orthogonalization Frobenius renormalization") -parser.add_argument("--no-muon-plus", dest="muon_plus", action="store_false", help="disable Muon+ post-orthogonalization renormalization") -parser.add_argument("--muon-eq", type=str, default="row", choices=["none", "row", "col"], help="apply MuonEq-style row/column equilibration before orthogonalization") +parser.add_argument("--scalar-lr", type=float, default=0.5, help="learning rate for scalars (resid_lambdas, x0_lambdas)") +parser.add_argument("--muon-plus", action="store_true", help="apply Muon+ style post-orthogonalization Frobenius renormalization") +parser.add_argument("--muon-eq", type=str, default="none", choices=["none", "row", "col"], help="apply MuonEq-style row/column equilibration before orthogonalization") parser.add_argument("--warmup-steps", type=int, default=40, help="number of steps for LR warmup") parser.add_argument("--warmdown-ratio", type=float, default=0.65, help="ratio of iterations for LR warmdown") parser.add_argument("--final-lr-frac", type=float, default=0.05, help="final LR as fraction of initial LR") @@ -81,24 +74,16 @@ parser.add_argument("--resume-from-step", type=int, default=-1, help="resume tra # Evaluation parser.add_argument("--eval-every", type=int, default=250, help="evaluate val bpb every N steps (-1 = disable)") parser.add_argument("--eval-tokens", type=int, default=80*524288, help="number of tokens to evaluate val loss on") -parser.add_argument("--skip-initial-eval", dest="skip_initial_eval", action="store_true", default=True, help="skip the step 0 validation pass; final validation still runs") -parser.add_argument("--initial-eval", dest="skip_initial_eval", action="store_false", help="run validation at step 0") -parser.add_argument("--core-metric-every", type=int, default=5800, help="evaluate CORE metric every N steps (-1 = disable)") +parser.add_argument("--core-metric-every", type=int, default=2000, help="evaluate CORE metric every N steps (-1 = disable)") parser.add_argument("--core-metric-max-per-task", type=int, default=500, help="examples per task for CORE metric") -parser.add_argument("--sample-every", type=int, default=-1, help="sample from model every N steps (-1 = disable)") +parser.add_argument("--sample-every", type=int, default=2000, help="sample from model every N steps (-1 = disable)") parser.add_argument("--save-every", type=int, default=-1, help="save checkpoints every N steps (-1 = only at end)") # Output parser.add_argument("--model-tag", type=str, default=None, help="override model tag for checkpoint directory name") args = parser.parse_args() user_config = vars(args).copy() # for logging -if args.train_log_every <= 0: - parser.error("--train-log-every must be positive") if args.bigram_embed_factor < 0: parser.error("--bigram-embed-factor must be non-negative") -if args.bigram_lambda_lr < 0: - parser.error("--bigram-lambda-lr must be non-negative") -if args.bigram_embedding_lr_mult <= 0: - parser.error("--bigram-embedding-lr-mult must be positive") # ----------------------------------------------------------------------------- # Compute init and wandb logging @@ -158,7 +143,6 @@ def build_model_meta(depth): n_layer=depth, n_head=num_heads, n_kv_head=num_heads, n_embd=model_dim, window_pattern=args.window_pattern, bigram_embed_factor=args.bigram_embed_factor, - bigram_lambda_init=args.bigram_lambda_init, ) with torch.device("meta"): model_meta = GPT(config) @@ -265,10 +249,7 @@ def disable_fp8(model): # Compile the model orig_model = model # original, uncompiled model, for saving raw model state_dict and for inference/evaluation (because the shapes may change shape) -compile_kwargs = {"dynamic": False} -if args.compile_mode: - compile_kwargs["mode"] = args.compile_mode -model = torch.compile(model, **compile_kwargs) # the inputs to model will never change shape so dynamic=False is safe +model = torch.compile(model, dynamic=False) # the inputs to model will never change shape so dynamic=False is safe # ----------------------------------------------------------------------------- # Scaling laws and muP extrapolations to determine the optimal training horizon, batch size, learning rates, weight decay. @@ -336,8 +317,7 @@ optimizer = model.setup_optimizer( # AdamW hyperparameters unembedding_lr=args.unembedding_lr * batch_lr_scale, embedding_lr=args.embedding_lr * batch_lr_scale, - bigram_embedding_lr_mult=args.bigram_embedding_lr_mult, - bigram_lambda_lr=args.bigram_lambda_lr * batch_lr_scale, + bigram_lambda_lr=0.004 * batch_lr_scale, scalar_lr=args.scalar_lr * batch_lr_scale, # Muon hyperparameters matrix_lr=args.matrix_lr * batch_lr_scale, @@ -442,11 +422,6 @@ grad_accum_steps = total_batch_size // world_tokens_per_fwdbwd print0(f"Tokens / micro-batch / rank: {args.device_batch_size} x {args.max_seq_len} = {tokens_per_fwdbwd:,}") print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}") print0(f"Total batch size {total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}") -train_log_every = args.train_log_every -batched_train_timing = train_log_every > 1 -train_timing_interval_start = None -train_timing_interval_first_step = step -train_log_count = 0 # Go! while True: @@ -454,7 +429,7 @@ while True: flops_so_far = num_flops_per_token * total_batch_size * step # once in a while: evaluate the val bpb (all ranks participate) - if args.eval_every > 0 and (last_step or (step % args.eval_every == 0 and (step > 0 or not args.skip_initial_eval))): + if args.eval_every > 0 and (last_step or step % args.eval_every == 0): model.eval() val_loader = build_val_loader() eval_steps = args.eval_tokens // (args.device_batch_size * args.max_seq_len * ddp_world_size) @@ -541,14 +516,8 @@ while True: # ------------------------------------------------------------------------- # single training step # evaluate the gradient - if batched_train_timing: - if train_timing_interval_start is None: - synchronize() - train_timing_interval_start = time.time() - train_timing_interval_first_step = step - else: - synchronize() - t0 = time.time() + synchronize() + t0 = time.time() for micro_step in range(grad_accum_steps): loss = model(x, y) train_loss = loss.detach() # for logging @@ -580,66 +549,46 @@ while True: else: optimizer.step() model.zero_grad(set_to_none=True) - should_log_train = step == 0 or (step + 1) % train_log_every == 0 or (step + 1) == num_iterations - if batched_train_timing: - if should_log_train: - synchronize() - t1 = time.time() - interval_steps = step - train_timing_interval_first_step + 1 - interval_dt = t1 - train_timing_interval_start - dt = interval_dt / interval_steps - counted_start = max(train_timing_interval_first_step, 11) - counted_steps = max(0, step - counted_start + 1) - if counted_steps > 0: - total_training_time += interval_dt * counted_steps / interval_steps - train_loss_f = train_loss.item() - train_timing_interval_start = None - else: - dt = None - train_loss_f = None - else: - train_loss_f = train_loss.item() # .item() is a CPU-GPU sync point - synchronize() - t1 = time.time() - dt = t1 - t0 - if step > 10: - total_training_time += dt # only count the time after the first 10 steps + train_loss_f = train_loss.item() # .item() is a CPU-GPU sync point + synchronize() + t1 = time.time() + dt = t1 - t0 # ------------------------------------------------------------------------- # logging (CPU action only) - if should_log_train: - ema_beta = 0.9 # EMA decay factor for some smoothing just for nicer logging - smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss_f # EMA the training loss - train_log_count += 1 - debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**train_log_count) # debias the EMA - pct_done = 100 * step / num_iterations - tok_per_sec = int(total_batch_size / dt) - flops_per_sec = num_flops_per_token * total_batch_size / dt - mfu = 100 * flops_per_sec / (gpu_peak_flops * ddp_world_size) - # Calculate ETA based on average time per step (excluding first 10 steps) - steps_done = step - 10 - if steps_done > 0: - avg_time_per_step = total_training_time / steps_done - remaining_steps = num_iterations - step - eta_seconds = remaining_steps * avg_time_per_step - eta_str = f" | eta: {eta_seconds/60:.1f}m" - else: - eta_str = "" - epoch = f"{dataloader_state_dict['epoch']} pq: {dataloader_state_dict['pq_idx']} rg: {dataloader_state_dict['rg_idx']}" - print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | bf16_mfu: {mfu:.2f} | epoch: {epoch} | total time: {total_training_time/60:.2f}m{eta_str}") - if step % 100 == 0 or (step + 1) % 100 == 0: - log_data = { - "step": step, - "total_training_flops": flops_so_far, - "total_training_time": total_training_time, - "train/loss": debiased_smooth_loss, - "train/lrm": lrm, - "train/dt": dt, - "train/tok_per_sec": tok_per_sec, - "train/mfu": mfu, - "train/epoch": epoch, - } - wandb_run.log(log_data) + ema_beta = 0.9 # EMA decay factor for some smoothing just for nicer logging + smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss_f # EMA the training loss + debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**(step + 1)) # debias the EMA + pct_done = 100 * step / num_iterations + tok_per_sec = int(total_batch_size / dt) + flops_per_sec = num_flops_per_token * total_batch_size / dt + mfu = 100 * flops_per_sec / (gpu_peak_flops * ddp_world_size) + if step > 10: + total_training_time += dt # only count the time after the first 10 steps + # Calculate ETA based on average time per step (excluding first 10 steps) + steps_done = step - 10 + if steps_done > 0: + avg_time_per_step = total_training_time / steps_done + remaining_steps = num_iterations - step + eta_seconds = remaining_steps * avg_time_per_step + eta_str = f" | eta: {eta_seconds/60:.1f}m" + else: + eta_str = "" + epoch = f"{dataloader_state_dict['epoch']} pq: {dataloader_state_dict['pq_idx']} rg: {dataloader_state_dict['rg_idx']}" + print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | bf16_mfu: {mfu:.2f} | epoch: {epoch} | total time: {total_training_time/60:.2f}m{eta_str}") + if step % 100 == 0: + log_data = { + "step": step, + "total_training_flops": flops_so_far, + "total_training_time": total_training_time, + "train/loss": debiased_smooth_loss, + "train/lrm": lrm, + "train/dt": dt, + "train/tok_per_sec": tok_per_sec, + "train/mfu": mfu, + "train/epoch": epoch, + } + wandb_run.log(log_data) # state update first_step_of_run = (step == 0) or (resuming and step == args.resume_from_step) From 412e9a1cbc0195ae62bda6ae8969fce8e58577b2 Mon Sep 17 00:00:00 2001 From: Codex Date: Thu, 7 May 2026 12:26:57 +0000 Subject: [PATCH 8/8] Use speedrun compile mode --- scripts/base_train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/base_train.py b/scripts/base_train.py index b415005c..adf28d35 100644 --- a/scripts/base_train.py +++ b/scripts/base_train.py @@ -249,7 +249,7 @@ def disable_fp8(model): # Compile the model orig_model = model # original, uncompiled model, for saving raw model state_dict and for inference/evaluation (because the shapes may change shape) -model = torch.compile(model, dynamic=False) # the inputs to model will never change shape so dynamic=False is safe +model = torch.compile(model, dynamic=False, mode="max-autotune-no-cudagraphs") # the inputs to model will never change shape so dynamic=False is safe # ----------------------------------------------------------------------------- # Scaling laws and muP extrapolations to determine the optimal training horizon, batch size, learning rates, weight decay.