From 0393a2c13f0bcd93bac193b34d556aa9566f08c5 Mon Sep 17 00:00:00 2001
From: Codex
Date: Thu, 7 May 2026 09:15:47 +0000
Subject: [PATCH] Make d22 bigram recipe the training default

---
 dev/bigram_minimal_pr_changes.md | 22 +++++++++++----------
 runs/speedrun.sh                 | 21 +++----------------
 scripts/base_train.py            | 33 +++++++++++++++++---------------
 3 files changed, 33 insertions(+), 43 deletions(-)

diff --git a/dev/bigram_minimal_pr_changes.md b/dev/bigram_minimal_pr_changes.md
index bdc11ff2..59c3a2d1 100644
--- a/dev/bigram_minimal_pr_changes.md
+++ b/dev/bigram_minimal_pr_changes.md
@@ -2,7 +2,7 @@
 
 This branch is based on upstream nanochat master at `dc54a1a`. The goal is to
 keep the submission patch limited to the changes needed to reproduce the
-best-performing speedrun recipe:
+best-performing speedrun recipe. These are the `scripts/base_train.py` defaults:
 
 ```bash
 --fp8
@@ -129,18 +129,19 @@ Adds:
 - `--bigram-lambda-lr`
 
 These configure the bigram residual and its optimizer treatment from the
-training script without changing defaults. With default values, upstream
-behavior is unchanged because `--bigram-embed-factor` defaults to `0`.
+training script. The submission default is `--bigram-embed-factor=5`.
 
 ### Muon Variant Flags
 
 Adds:
 
 - `--muon-plus`
+- `--no-muon-plus`
 - `--muon-eq`
 
-These expose the optimizer variants used in the recipe. Defaults preserve the
-original optimizer behavior.
+These expose the optimizer variants used in the recipe. The submission defaults
+are Muon+ enabled and `--muon-eq=row`. `--no-muon-plus --muon-eq=none` restores
+the original Muon path.
 
 ### Train Logging Cadence
 
@@ -148,7 +149,7 @@ Adds `--train-log-every`. Values greater than 1 avoid converting the loss
 tensor to a Python scalar every step.
 
 Why this helps: per-step logging creates extra synchronization overhead. The
-speedrun uses `--train-log-every=50`, which keeps useful progress reporting
+submission default is `--train-log-every=50`, which keeps useful progress reporting
 while reducing logging overhead.
 
 ### Compile Mode
@@ -160,16 +161,17 @@ Adds `--compile-mode` so the speedrun can request:
 
 ```
 --compile-mode=max-autotune-no-cudagraphs
 ```
 
 Why this helps: on the d16 probe, this compile mode was about 2.5% faster than
-default `torch.compile` for the candidate recipe.
+default `torch.compile` for the candidate recipe. It is now the submission
+default.
 
 ### Skip Initial Eval
 
-Adds `--skip-initial-eval`. This avoids spending benchmark wall time on the
-step-0 validation pass when it is not needed for a speedrun submission.
+Adds `--skip-initial-eval` and `--initial-eval`. The submission default skips
+the step-0 validation pass; `--initial-eval` restores the original behavior.
 
 ## `runs/speedrun.sh`
 
-Updates the default speedrun command to use the winning recipe flags:
+Uses the `scripts/base_train.py` submission defaults:
 
 - FP8
 - depth `22`
diff --git a/runs/speedrun.sh b/runs/speedrun.sh
index 8dab8cf0..9a4c3977 100644
--- a/runs/speedrun.sh
+++ b/runs/speedrun.sh
@@ -70,24 +70,9 @@ echo "Waiting for dataset download to complete..."
 wait $DATASET_DOWNLOAD_PID
 
 # d22 Muon+/row-eq + hashed bigram recipe.
-# This is the submission default: fixed 11,600 optimizer steps, eval every 250,
-# and one in-training CORE pass halfway through.
-torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- \
-    --depth=22 \
-    --num-iterations=11600 \
-    --target-param-data-ratio=11 \
-    --device-batch-size=32 \
-    --total-batch-size=524288 \
-    --fp8 \
-    --compile-mode=max-autotune-no-cudagraphs \
-    --muon-plus \
-    --muon-eq=row \
-    --bigram-embed-factor=5 \
-    --scalar-lr=0.3 \
-    --train-log-every=50 \
-    --eval-every=250 \
-    --core-metric-every=5800 \
-    --run=$WANDB_RUN
+# scripts/base_train defaults are the submission defaults: fixed 11,600
+# optimizer steps, eval every 250, and one in-training CORE pass halfway through.
+torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- --run=$WANDB_RUN
 
 # evaluate the model: CORE metric, BPB on train/val, and draw samples
 torchrun --standalone --nproc_per_node=8 -m scripts.base_eval -- --device-batch-size=16
diff --git a/scripts/base_train.py b/scripts/base_train.py
index 56fbae2b..d0e3780c 100644
--- a/scripts/base_train.py
+++ b/scripts/base_train.py
@@ -8,7 +8,7 @@
 or distributed as: torchrun --nproc_per_node=8 -m scripts.base_train
 
 If you are only on CPU/Macbook, you'll want to train a much much smaller LLM. Example:
-python -m scripts.base_train --depth=4 --max-seq-len=512 --device-batch-size=1 --eval-tokens=512 --core-metric-every=-1 --total-batch-size=512 --num-iterations=20
+python -m scripts.base_train --depth=4 --max-seq-len=512 --device-batch-size=1 --eval-tokens=512 --core-metric-every=-1 --total-batch-size=512 --num-iterations=20 --no-fp8 --no-muon-plus --muon-eq=none --bigram-embed-factor=0
 """
 
 import os
@@ -41,37 +41,39 @@ print_banner()
 parser = argparse.ArgumentParser(description="Pretrain base model")
 # Logging
 parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)")
-parser.add_argument("--train-log-every", type=int, default=1, help="print training metrics every N steps; values >1 avoid per-step CPU/GPU sync")
+parser.add_argument("--train-log-every", type=int, default=50, help="print training metrics every N steps; values >1 avoid per-step CPU/GPU sync")
 # Runtime
 parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
 # FP8 training
-parser.add_argument("--fp8", action="store_true", help="enable FP8 training (requires H100+ GPU and torchao)")
+parser.add_argument("--fp8", dest="fp8", action="store_true", default=True, help="enable FP8 training (requires H100+ GPU and torchao)")
+parser.add_argument("--no-fp8", dest="fp8", action="store_false", help="disable FP8 training")
 parser.add_argument("--fp8-recipe", type=str, default="tensorwise", choices=["rowwise", "tensorwise"], help="FP8 scaling recipe: tensorwise (faster, recommended) or rowwise (more accurate but slower)")
-parser.add_argument("--compile-mode", type=str, default="", choices=["", "default", "reduce-overhead", "max-autotune", "max-autotune-no-cudagraphs"], help="optional torch.compile mode")
+parser.add_argument("--compile-mode", type=str, default="max-autotune-no-cudagraphs", choices=["", "default", "reduce-overhead", "max-autotune", "max-autotune-no-cudagraphs"], help="optional torch.compile mode")
 # Model architecture
-parser.add_argument("--depth", type=int, default=20, help="depth of the Transformer model")
+parser.add_argument("--depth", type=int, default=22, help="depth of the Transformer model")
 parser.add_argument("--aspect-ratio", type=int, default=64, help="model_dim = depth * aspect_ratio")
 parser.add_argument("--head-dim", type=int, default=128, help="target head dimension for attention")
attention") parser.add_argument("--max-seq-len", type=int, default=2048, help="max context length") parser.add_argument("--window-pattern", type=str, default="SSSL", help="sliding window pattern tiled across layers: L=full, S=half context (e.g. 'SSL')") -parser.add_argument("--bigram-embed-factor", type=int, default=0, help="if >0, add a hashed bigram embedding residual") +parser.add_argument("--bigram-embed-factor", type=int, default=5, help="if >0, add a hashed bigram embedding residual") parser.add_argument("--bigram-lambda-init", type=float, default=0.05, help="initial layer residual scale for --bigram-embed-factor") parser.add_argument("--bigram-embedding-lr-mult", type=float, default=1.0, help="bigram embedding LR multiplier relative to --embedding-lr") parser.add_argument("--bigram-lambda-lr", type=float, default=0.004, help="AdamW LR for bigram layer lambdas before dmodel schedule scaling") # Training horizon (only one used, in order of precedence) -parser.add_argument("--num-iterations", type=int, default=-1, help="explicit number of optimization steps (-1 = disable)") +parser.add_argument("--num-iterations", type=int, default=11600, help="explicit number of optimization steps (-1 = disable)") parser.add_argument("--target-flops", type=float, default=-1.0, help="calculate num_iterations to reach target_flops (-1 = disable)") -parser.add_argument("--target-param-data-ratio", type=float, default=12, help="calculate num_iterations to maintain data:param ratio (Chinchilla=20, -1 = disable)") +parser.add_argument("--target-param-data-ratio", type=float, default=11, help="calculate num_iterations to maintain data:param ratio (Chinchilla=20, -1 = disable)") # Optimization parser.add_argument("--device-batch-size", type=int, default=32, help="per-device batch size. good number to reduce to 16,8,4,... if you OOM on VRAM.") -parser.add_argument("--total-batch-size", type=int, default=-1, help="total batch size in tokens. decent numbers are e.g. 524288. (-1 = auto-compute optimal)") +parser.add_argument("--total-batch-size", type=int, default=524288, help="total batch size in tokens. decent numbers are e.g. 524288. 
 parser.add_argument("--embedding-lr", type=float, default=0.3, help="learning rate for embedding parameters (Adam)")
 parser.add_argument("--unembedding-lr", type=float, default=0.008, help="learning rate for unembedding parameters (Adam)")
 parser.add_argument("--weight-decay", type=float, default=0.28, help="cautious weight decay for the Muon optimizer (for weights)")
 parser.add_argument("--matrix-lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)")
-parser.add_argument("--scalar-lr", type=float, default=0.5, help="learning rate for scalars (resid_lambdas, x0_lambdas)")
-parser.add_argument("--muon-plus", action="store_true", help="apply Muon+ style post-orthogonalization Frobenius renormalization")
-parser.add_argument("--muon-eq", type=str, default="none", choices=["none", "row", "col"], help="apply MuonEq-style row/column equilibration before orthogonalization")
+parser.add_argument("--scalar-lr", type=float, default=0.3, help="learning rate for scalars (resid_lambdas, x0_lambdas)")
+parser.add_argument("--muon-plus", dest="muon_plus", action="store_true", default=True, help="apply Muon+ style post-orthogonalization Frobenius renormalization")
+parser.add_argument("--no-muon-plus", dest="muon_plus", action="store_false", help="disable Muon+ post-orthogonalization renormalization")
+parser.add_argument("--muon-eq", type=str, default="row", choices=["none", "row", "col"], help="apply MuonEq-style row/column equilibration before orthogonalization")
 parser.add_argument("--warmup-steps", type=int, default=40, help="number of steps for LR warmup")
 parser.add_argument("--warmdown-ratio", type=float, default=0.65, help="ratio of iterations for LR warmdown")
 parser.add_argument("--final-lr-frac", type=float, default=0.05, help="final LR as fraction of initial LR")
@@ -79,10 +81,11 @@ parser.add_argument("--resume-from-step", type=int, default=-1, help="resume tra
 # Evaluation
 parser.add_argument("--eval-every", type=int, default=250, help="evaluate val bpb every N steps (-1 = disable)")
 parser.add_argument("--eval-tokens", type=int, default=80*524288, help="number of tokens to evaluate val loss on")
-parser.add_argument("--skip-initial-eval", action="store_true", help="skip the step 0 validation pass; final validation still runs")
-parser.add_argument("--core-metric-every", type=int, default=2000, help="evaluate CORE metric every N steps (-1 = disable)")
+parser.add_argument("--skip-initial-eval", dest="skip_initial_eval", action="store_true", default=True, help="skip the step 0 validation pass; final validation still runs")
+parser.add_argument("--initial-eval", dest="skip_initial_eval", action="store_false", help="run validation at step 0")
+parser.add_argument("--core-metric-every", type=int, default=5800, help="evaluate CORE metric every N steps (-1 = disable)")
 parser.add_argument("--core-metric-max-per-task", type=int, default=500, help="examples per task for CORE metric")
-parser.add_argument("--sample-every", type=int, default=2000, help="sample from model every N steps (-1 = disable)")
+parser.add_argument("--sample-every", type=int, default=-1, help="sample from model every N steps (-1 = disable)")
 parser.add_argument("--save-every", type=int, default=-1, help="save checkpoints every N steps (-1 = only at end)")
 # Output
 parser.add_argument("--model-tag", type=str, default=None, help="override model tag for checkpoint directory name")
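
For reference, the hashed bigram residual that `--bigram-embed-factor` switches on can be sketched in a few lines. This is a minimal illustration, not nanochat's actual implementation: the class name, the predecessor-pairing rule, the multiplicative hash constant, and the reading of the factor as a bucket-count multiplier are all assumptions; only the default values (`factor=5`, `lambda_init=0.05`) come from the flags in the patch.

```python
import torch
import torch.nn as nn

class HashedBigramEmbedding(nn.Module):
    """Illustrative sketch of a hashed bigram embedding residual."""

    def __init__(self, vocab_size: int, model_dim: int, factor: int = 5,
                 lambda_init: float = 0.05):
        super().__init__()
        # Assumption: `factor` multiplies the vocab size to set the number of
        # hash buckets; colliding bigrams simply share a table row.
        self.num_buckets = factor * vocab_size
        self.table = nn.Embedding(self.num_buckets, model_dim)
        # Learned residual scale, initialized per --bigram-lambda-init.
        self.lam = nn.Parameter(torch.tensor(lambda_init))

    def forward(self, ids: torch.Tensor) -> torch.Tensor:
        # Pair every token with its predecessor (position 0 pairs with itself).
        prev = torch.roll(ids, shifts=1, dims=-1)
        prev[..., 0] = ids[..., 0]
        # Illustrative multiplicative hash; the real mixing scheme may differ.
        buckets = (prev * 1_000_003 + ids) % self.num_buckets
        return self.lam * self.table(buckets)
```

Under this reading, `--bigram-embedding-lr-mult` scales the learning rate on `table.weight` relative to `--embedding-lr`, and `--bigram-lambda-lr` gives the per-layer lambdas their own AdamW learning rate, matching the flag help strings.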
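The Muon variant flags amount to two small modifications around Muon's orthogonalization step. A hedged sketch, assuming the standard quintic Newton-Schulz iteration from the public Muon write-up; the exact equilibration and renormalization math behind `--muon-eq` and `--muon-plus` in this patch may differ:

```python
import torch

def zeropower_via_newtonschulz(G: torch.Tensor, steps: int = 5) -> torch.Tensor:
    # Quintic Newton-Schulz iteration used by Muon to orthogonalize a 2D
    # update (coefficients from the public Muon write-up).
    assert G.ndim == 2
    a, b, c = 3.4445, -4.7750, 2.0315
    X = G / (G.norm() + 1e-7)
    transposed = X.size(0) > X.size(1)
    if transposed:
        X = X.T
    for _ in range(steps):
        A = X @ X.T
        X = a * X + (b * A + c * A @ A) @ X
    return X.T if transposed else X

def muon_update(grad: torch.Tensor, muon_plus: bool = True,
                muon_eq: str = "row", eps: float = 1e-7) -> torch.Tensor:
    g = grad
    if muon_eq == "row":    # equilibrate row norms before orthogonalization
        g = g / (g.norm(dim=1, keepdim=True) + eps)
    elif muon_eq == "col":  # or column norms
        g = g / (g.norm(dim=0, keepdim=True) + eps)
    o = zeropower_via_newtonschulz(g)
    if muon_plus:
        # One plausible reading of "post-orthogonalization Frobenius
        # renormalization": rescale the orthogonalized update back to the
        # raw gradient's Frobenius norm.
        o = o * (grad.norm() / (o.norm() + eps))
    return o
```

With the submission defaults this takes the `row` branch and applies the Muon+ rescale; `--no-muon-plus --muon-eq=none` reduces `muon_update` to plain orthogonalization, i.e. the original Muon path.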
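The `--train-log-every` rationale is easy to demonstrate in isolation: keep the running loss on the device and convert it to a Python float only once per logging window. A runnable toy loop, with a random scalar standing in for the real training step:

```python
import torch

def logging_cadence_demo(num_steps: int = 200, log_every: int = 50) -> None:
    device = "cuda" if torch.cuda.is_available() else "cpu"
    loss_acc = torch.zeros((), device=device)
    for step in range(num_steps):
        loss = torch.rand((), device=device)  # stand-in for one training step's loss
        loss_acc += loss.detach()             # accumulates on device: no host sync
        if (step + 1) % log_every == 0:
            # .item() forces a device-to-host sync; paying it once per window
            # instead of every step is the point of --train-log-every.
            print(f"step {step + 1}: mean loss {(loss_acc / log_every).item():.4f}")
            loss_acc.zero_()

if __name__ == "__main__":
    logging_cadence_demo()
```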
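The compile-mode flag feeds straight into the stock `torch.compile` API; only the mode string is recipe-specific. A minimal sketch with a stand-in module:

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(1024, 4096), nn.GELU(), nn.Linear(4096, 1024))
# Stock torch.compile API. This mode keeps the max-autotune kernel search
# but skips CUDA graph capture.
model = torch.compile(model, mode="max-autotune-no-cudagraphs")
```

Avoiding CUDA graph capture while keeping the autotuned kernels is the variant the d16 probe measured as roughly 2.5% faster for this recipe.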
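All of the default flips in `scripts/base_train.py` that need an off switch use the same argparse pattern: a `store_true` flag that now defaults to `True`, paired with a negating `store_false` flag (`--no-fp8`, `--no-muon-plus`, `--initial-eval`) writing to the same `dest`. A self-contained demonstration:

```python
import argparse

parser = argparse.ArgumentParser()
# Default-on flag plus a --no-* override, both writing the same dest.
parser.add_argument("--fp8", dest="fp8", action="store_true", default=True)
parser.add_argument("--no-fp8", dest="fp8", action="store_false")

print(parser.parse_args([]).fp8)            # True  -> submission default
print(parser.parse_args(["--no-fp8"]).fp8)  # False -> restores upstream behavior
```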