mirror of https://github.com/karpathy/nanochat.git

Make d22 bigram recipe the training default

parent 0de3a39910
commit 0393a2c13f
@@ -2,7 +2,7 @@

 This branch is based on upstream nanochat master at `dc54a1a`. The goal is to
 keep the submission patch limited to the changes needed to reproduce the
-best-performing speedrun recipe:
+best-performing speedrun recipe. These are the `scripts/base_train.py` defaults:

 ```bash
 --fp8
@@ -129,18 +129,19 @@ Adds:
 - `--bigram-lambda-lr`

 These configure the bigram residual and its optimizer treatment from the
-training script without changing defaults. With default values, upstream
-behavior is unchanged because `--bigram-embed-factor` defaults to `0`.
+training script. The submission default is `--bigram-embed-factor=5`.

 ### Muon Variant Flags

 Adds:

 - `--muon-plus`
+- `--no-muon-plus`
 - `--muon-eq`

-These expose the optimizer variants used in the recipe. Defaults preserve the
-original optimizer behavior.
+These expose the optimizer variants used in the recipe. The submission defaults
+are Muon+ enabled and `--muon-eq=row`. `--no-muon-plus --muon-eq=none` restores
+the original Muon path.

 ### Train Logging Cadence

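Note: to make `--bigram-embed-factor` concrete, here is a minimal sketch of a hashed bigram embedding residual. The bucket sizing (`factor * vocab_size`), the hash constants, and the module layout are illustrative assumptions, not nanochat's actual implementation; only the flag semantics above come from the patch.

```python
# Illustrative sketch of a hashed bigram embedding residual.
# The hash constants and bucket sizing are assumptions, not nanochat's code.
import torch
import torch.nn as nn

class HashedBigramEmbedding(nn.Module):
    def __init__(self, vocab_size: int, model_dim: int, factor: int = 5, lambda_init: float = 0.05):
        super().__init__()
        self.num_buckets = factor * vocab_size
        self.embed = nn.Embedding(self.num_buckets, model_dim)
        # Learnable residual scale; cf. --bigram-lambda-init and --bigram-lambda-lr.
        self.lam = nn.Parameter(torch.tensor(lambda_init))

    def forward(self, ids: torch.Tensor) -> torch.Tensor:
        # Pair each token with its predecessor; the first position pairs with itself.
        prev = torch.roll(ids, shifts=1, dims=-1)
        prev[..., 0] = ids[..., 0]
        # Hash the (prev, cur) pair into a bucket with arbitrary odd multipliers.
        h = (prev.long() * 0x9E3779B1 + ids.long() * 0x85EBCA77) % self.num_buckets
        return self.lam * self.embed(h)

# Toy usage: the result is added to the token-embedding residual stream.
emb = HashedBigramEmbedding(vocab_size=65536, model_dim=64)
print(emb(torch.randint(0, 65536, (2, 16))).shape)  # torch.Size([2, 16, 64])
```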
@@ -148,7 +149,7 @@ Adds `--train-log-every`. Values greater than 1 avoid converting the loss tensor
 to a Python scalar every step.

 Why this helps: per-step logging creates extra synchronization overhead. The
-speedrun uses `--train-log-every=50`, which keeps useful progress reporting
+submission default is `--train-log-every=50`, which keeps useful progress reporting
 while reducing logging overhead.

 ### Compile Mode
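Note: a self-contained sketch of why values greater than 1 help; `train_step` is a hypothetical stand-in, not a nanochat function. Calling `.item()` on a CUDA tensor blocks the host on the device, so the windowed pattern pays that cost once per window instead of every step.

```python
# Windowed logging sketch; train_step is a stand-in for one optimizer step.
import torch

def train_step() -> torch.Tensor:
    return torch.rand(())  # returns the loss as a 0-dim tensor

log_every = 50              # cf. --train-log-every=50
num_iterations = 200        # toy horizon for the sketch
loss_acc = torch.zeros(())  # on CUDA this stays device-side between logs

for step in range(num_iterations):
    loss_acc += train_step().detach()  # tensor-to-tensor add: no host sync
    if (step + 1) % log_every == 0:
        # .item() is the synchronization point; pay it once per window.
        print(f"step {step + 1}: mean loss {(loss_acc / log_every).item():.4f}")
        loss_acc.zero_()
```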
@@ -160,16 +161,17 @@ Adds `--compile-mode` so the speedrun can request:
 ```

 Why this helps: on the d16 probe, this compile mode was about 2.5% faster than
-default `torch.compile` for the candidate recipe.
+default `torch.compile` for the candidate recipe. It is now the submission
+default.

 ### Skip Initial Eval

-Adds `--skip-initial-eval`. This avoids spending benchmark wall time on the
-step-0 validation pass when it is not needed for a speedrun submission.
+Adds `--skip-initial-eval` and `--initial-eval`. The submission default skips
+the step-0 validation pass; `--initial-eval` restores the original behavior.

 ## `runs/speedrun.sh`

-Updates the default speedrun command to use the winning recipe flags:
+Uses the `scripts/base_train.py` submission defaults:

 - FP8
 - depth `22`
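Note: the requested mode maps directly onto the standard `torch.compile` argument. A minimal standalone example on a toy model (the model itself is an assumption for illustration); the 2.5% figure above is specific to the d16 probe and this recipe.

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(1024, 4096), nn.GELU(), nn.Linear(4096, 1024))
# Autotuned kernel selection without CUDA graph capture; cf. --compile-mode.
compiled = torch.compile(model, mode="max-autotune-no-cudagraphs")
out = compiled(torch.randn(8, 1024))
```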
runs/speedrun.sh
@@ -70,24 +70,9 @@ echo "Waiting for dataset download to complete..."
 wait $DATASET_DOWNLOAD_PID

 # d22 Muon+/row-eq + hashed bigram recipe.
-# This is the submission default: fixed 11,600 optimizer steps, eval every 250,
-# and one in-training CORE pass halfway through.
-torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- \
-    --depth=22 \
-    --num-iterations=11600 \
-    --target-param-data-ratio=11 \
-    --device-batch-size=32 \
-    --total-batch-size=524288 \
-    --fp8 \
-    --compile-mode=max-autotune-no-cudagraphs \
-    --muon-plus \
-    --muon-eq=row \
-    --bigram-embed-factor=5 \
-    --scalar-lr=0.3 \
-    --train-log-every=50 \
-    --eval-every=250 \
-    --core-metric-every=5800 \
-    --run=$WANDB_RUN
+# scripts/base_train defaults are the submission defaults: fixed 11,600
+# optimizer steps, eval every 250, and one in-training CORE pass halfway through.
+torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- --run=$WANDB_RUN

 # evaluate the model: CORE metric, BPB on train/val, and draw samples
 torchrun --standalone --nproc_per_node=8 -m scripts.base_eval -- --device-batch-size=16
scripts/base_train.py
@@ -8,7 +8,7 @@ or distributed as:
 torchrun --nproc_per_node=8 -m scripts.base_train

 If you are only on CPU/Macbook, you'll want to train a much much smaller LLM. Example:
-python -m scripts.base_train --depth=4 --max-seq-len=512 --device-batch-size=1 --eval-tokens=512 --core-metric-every=-1 --total-batch-size=512 --num-iterations=20
+python -m scripts.base_train --depth=4 --max-seq-len=512 --device-batch-size=1 --eval-tokens=512 --core-metric-every=-1 --total-batch-size=512 --num-iterations=20 --no-fp8 --no-muon-plus --muon-eq=none --bigram-embed-factor=0
 """

 import os
@@ -41,37 +41,39 @@ print_banner()
 parser = argparse.ArgumentParser(description="Pretrain base model")
 # Logging
 parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)")
-parser.add_argument("--train-log-every", type=int, default=1, help="print training metrics every N steps; values >1 avoid per-step CPU/GPU sync")
+parser.add_argument("--train-log-every", type=int, default=50, help="print training metrics every N steps; values >1 avoid per-step CPU/GPU sync")
 # Runtime
 parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
 # FP8 training
-parser.add_argument("--fp8", action="store_true", help="enable FP8 training (requires H100+ GPU and torchao)")
+parser.add_argument("--fp8", dest="fp8", action="store_true", default=True, help="enable FP8 training (requires H100+ GPU and torchao)")
+parser.add_argument("--no-fp8", dest="fp8", action="store_false", help="disable FP8 training")
 parser.add_argument("--fp8-recipe", type=str, default="tensorwise", choices=["rowwise", "tensorwise"], help="FP8 scaling recipe: tensorwise (faster, recommended) or rowwise (more accurate but slower)")
-parser.add_argument("--compile-mode", type=str, default="", choices=["", "default", "reduce-overhead", "max-autotune", "max-autotune-no-cudagraphs"], help="optional torch.compile mode")
+parser.add_argument("--compile-mode", type=str, default="max-autotune-no-cudagraphs", choices=["", "default", "reduce-overhead", "max-autotune", "max-autotune-no-cudagraphs"], help="optional torch.compile mode")
 # Model architecture
-parser.add_argument("--depth", type=int, default=20, help="depth of the Transformer model")
+parser.add_argument("--depth", type=int, default=22, help="depth of the Transformer model")
 parser.add_argument("--aspect-ratio", type=int, default=64, help="model_dim = depth * aspect_ratio")
 parser.add_argument("--head-dim", type=int, default=128, help="target head dimension for attention")
 parser.add_argument("--max-seq-len", type=int, default=2048, help="max context length")
 parser.add_argument("--window-pattern", type=str, default="SSSL", help="sliding window pattern tiled across layers: L=full, S=half context (e.g. 'SSL')")
-parser.add_argument("--bigram-embed-factor", type=int, default=0, help="if >0, add a hashed bigram embedding residual")
+parser.add_argument("--bigram-embed-factor", type=int, default=5, help="if >0, add a hashed bigram embedding residual")
 parser.add_argument("--bigram-lambda-init", type=float, default=0.05, help="initial layer residual scale for --bigram-embed-factor")
 parser.add_argument("--bigram-embedding-lr-mult", type=float, default=1.0, help="bigram embedding LR multiplier relative to --embedding-lr")
 parser.add_argument("--bigram-lambda-lr", type=float, default=0.004, help="AdamW LR for bigram layer lambdas before dmodel schedule scaling")
 # Training horizon (only one used, in order of precedence)
-parser.add_argument("--num-iterations", type=int, default=-1, help="explicit number of optimization steps (-1 = disable)")
+parser.add_argument("--num-iterations", type=int, default=11600, help="explicit number of optimization steps (-1 = disable)")
 parser.add_argument("--target-flops", type=float, default=-1.0, help="calculate num_iterations to reach target_flops (-1 = disable)")
-parser.add_argument("--target-param-data-ratio", type=float, default=12, help="calculate num_iterations to maintain data:param ratio (Chinchilla=20, -1 = disable)")
+parser.add_argument("--target-param-data-ratio", type=float, default=11, help="calculate num_iterations to maintain data:param ratio (Chinchilla=20, -1 = disable)")
 # Optimization
 parser.add_argument("--device-batch-size", type=int, default=32, help="per-device batch size. good number to reduce to 16,8,4,... if you OOM on VRAM.")
-parser.add_argument("--total-batch-size", type=int, default=-1, help="total batch size in tokens. decent numbers are e.g. 524288. (-1 = auto-compute optimal)")
+parser.add_argument("--total-batch-size", type=int, default=524288, help="total batch size in tokens. decent numbers are e.g. 524288. (-1 = auto-compute optimal)")
 parser.add_argument("--embedding-lr", type=float, default=0.3, help="learning rate for embedding parameters (Adam)")
 parser.add_argument("--unembedding-lr", type=float, default=0.008, help="learning rate for unembedding parameters (Adam)")
 parser.add_argument("--weight-decay", type=float, default=0.28, help="cautious weight decay for the Muon optimizer (for weights)")
 parser.add_argument("--matrix-lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)")
-parser.add_argument("--scalar-lr", type=float, default=0.5, help="learning rate for scalars (resid_lambdas, x0_lambdas)")
-parser.add_argument("--muon-plus", action="store_true", help="apply Muon+ style post-orthogonalization Frobenius renormalization")
-parser.add_argument("--muon-eq", type=str, default="none", choices=["none", "row", "col"], help="apply MuonEq-style row/column equilibration before orthogonalization")
+parser.add_argument("--scalar-lr", type=float, default=0.3, help="learning rate for scalars (resid_lambdas, x0_lambdas)")
+parser.add_argument("--muon-plus", dest="muon_plus", action="store_true", default=True, help="apply Muon+ style post-orthogonalization Frobenius renormalization")
+parser.add_argument("--no-muon-plus", dest="muon_plus", action="store_false", help="disable Muon+ post-orthogonalization renormalization")
+parser.add_argument("--muon-eq", type=str, default="row", choices=["none", "row", "col"], help="apply MuonEq-style row/column equilibration before orthogonalization")
 parser.add_argument("--warmup-steps", type=int, default=40, help="number of steps for LR warmup")
 parser.add_argument("--warmdown-ratio", type=float, default=0.65, help="ratio of iterations for LR warmdown")
 parser.add_argument("--final-lr-frac", type=float, default=0.05, help="final LR as fraction of initial LR")
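Note: a rough sketch of what the `--muon-plus` and `--muon-eq` help strings describe: row equilibration before orthogonalization, then a Frobenius renormalization of the orthogonalized update. The cubic Newton-Schulz iteration and the exact normalization choices here are simplified assumptions, not the repository's Muon code.

```python
# Rough sketch of the Muon variants named above (assumptions, not nanochat's code).
import torch

def newton_schulz_orthogonalize(g: torch.Tensor, steps: int = 5) -> torch.Tensor:
    # Iterate toward the nearest semi-orthogonal matrix, as Muon does
    # (simplified cubic iteration; Muon itself uses a tuned quintic).
    x = g / (g.norm() + 1e-8)
    for _ in range(steps):
        x = 1.5 * x - 0.5 * (x @ x.T) @ x
    return x

def muon_variant_update(g: torch.Tensor, muon_plus: bool = True, muon_eq: str = "row") -> torch.Tensor:
    if muon_eq == "row":
        # MuonEq-style equilibration: scale each row to unit RMS first.
        g = g / (g.pow(2).mean(dim=1, keepdim=True).sqrt() + 1e-8)
    u = newton_schulz_orthogonalize(g)
    if muon_plus:
        # Muon+ style: renormalize the update to the gradient's Frobenius norm.
        u = u * (g.norm() / (u.norm() + 1e-8))
    return u

print(muon_variant_update(torch.randn(128, 256)).shape)  # torch.Size([128, 256])
```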
@@ -79,10 +81,11 @@ parser.add_argument("--resume-from-step", type=int, default=-1, help="resume tra
 # Evaluation
 parser.add_argument("--eval-every", type=int, default=250, help="evaluate val bpb every N steps (-1 = disable)")
 parser.add_argument("--eval-tokens", type=int, default=80*524288, help="number of tokens to evaluate val loss on")
-parser.add_argument("--skip-initial-eval", action="store_true", help="skip the step 0 validation pass; final validation still runs")
-parser.add_argument("--core-metric-every", type=int, default=2000, help="evaluate CORE metric every N steps (-1 = disable)")
+parser.add_argument("--skip-initial-eval", dest="skip_initial_eval", action="store_true", default=True, help="skip the step 0 validation pass; final validation still runs")
+parser.add_argument("--initial-eval", dest="skip_initial_eval", action="store_false", help="run validation at step 0")
+parser.add_argument("--core-metric-every", type=int, default=5800, help="evaluate CORE metric every N steps (-1 = disable)")
 parser.add_argument("--core-metric-max-per-task", type=int, default=500, help="examples per task for CORE metric")
-parser.add_argument("--sample-every", type=int, default=2000, help="sample from model every N steps (-1 = disable)")
+parser.add_argument("--sample-every", type=int, default=-1, help="sample from model every N steps (-1 = disable)")
 parser.add_argument("--save-every", type=int, default=-1, help="save checkpoints every N steps (-1 = only at end)")
 # Output
 parser.add_argument("--model-tag", type=str, default=None, help="override model tag for checkpoint directory name")
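Note: the flag pairs introduced in this commit (`--fp8`/`--no-fp8`, `--muon-plus`/`--no-muon-plus`, `--skip-initial-eval`/`--initial-eval`) all use the standard argparse shared-`dest` idiom, which behaves like this:

```python
import argparse

p = argparse.ArgumentParser()
# Two flags share one destination; the default lives on the "on" flag.
p.add_argument("--fp8", dest="fp8", action="store_true", default=True)
p.add_argument("--no-fp8", dest="fp8", action="store_false")

print(p.parse_args([]).fp8)             # True  (submission default)
print(p.parse_args(["--no-fp8"]).fp8)   # False (restores upstream behavior)
```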