fix buggy midtrain and update all kwargs to be idiomatic. that is, argparse uses dashes variables use underscores. the underscores are just a remnant of the previous Configurator object. This is the right way

This commit is contained in:
Andrej Karpathy 2026-01-13 22:45:27 +00:00
parent 3b50b77ed3
commit 7312ec9898
11 changed files with 144 additions and 139 deletions

View File

@ -25,7 +25,7 @@ python -m nanochat.report reset
# train tokenizer on ~1B characters
python -m nanochat.dataset -n 4
python -m scripts.tok_train --max_chars=1000000000
python -m scripts.tok_train --max-chars=1000000000
python -m scripts.tok_eval
# train a very small 4 layer model on the CPU
@ -33,37 +33,37 @@ python -m scripts.tok_eval
# we only run 50 steps of optimization (bump this to get better results)
python -m scripts.base_train \
--depth=4 \
--max_seq_len=1024 \
--device_batch_size=1 \
--total_batch_size=1024 \
--eval_every=50 \
--eval_tokens=4096 \
--core_metric_every=50 \
--core_metric_max_per_task=12 \
--sample_every=50 \
--num_iterations=50
python -m scripts.base_loss --device_batch_size=1 --split_tokens=4096
--max-seq-len=1024 \
--device-batch-size=1 \
--total-batch-size=1024 \
--eval-every=50 \
--eval-tokens=4096 \
--core-metric-every=50 \
--core-metric-max-per-task=12 \
--sample-every=50 \
--num-iterations=50
python -m scripts.base_loss --device-batch-size=1 --split-tokens=4096
python -m scripts.base_eval --max-per-task=16
# midtraining
python -m scripts.mid_train \
--max_seq_len=1024 \
--device_batch_size=1 \
--eval_every=50 \
--eval_tokens=4096 \
--total_batch_size=1024 \
--num_iterations=100
--max-seq-len=1024 \
--device-batch-size=1 \
--eval-every=50 \
--eval-tokens=4096 \
--total-batch-size=1024 \
--num-iterations=100
# eval results will be terrible, this is just to execute the code paths.
# note that we lower the execution memory limit to 1MB to avoid warnings on smaller systems
python -m scripts.chat_eval --source=mid --max-new-tokens=128 --max-problems=20
# SFT
python -m scripts.chat_sft \
--device_batch_size=1 \
--target_examples_per_step=4 \
--num_iterations=100 \
--eval_steps=4 \
--eval_metrics_max_problems=16
--device-batch-size=1 \
--target-examples-per-step=4 \
--num-iterations=100 \
--eval-steps=4 \
--eval-metrics-max-problems=16
# Chat CLI
# python -m scripts.chat_cli -p "Why is the sky blue?"

View File

@ -20,7 +20,7 @@ if [ -z "$SKIP_SETUP" ]; then
# Tokenizer, download 1000 shards for pretraining
# (probably this can be reduced but it's tricky to determine the exact right number, TODO).
python -m nanochat.dataset -n 1000
python -m scripts.tok_train --max_chars=2000000000 --vocab_size=32768
python -m scripts.tok_train --max-chars=2000000000 --vocab-size=32768
else
source .venv/bin/activate
fi
@ -58,16 +58,16 @@ for d in "${DEPTHS[@]}"; do
START_TIME=$(date +%s)
# Train the model with natural horizon (target_param_data_ratio default)
# No --target_flops, let it use the default ratio from base_train
# No --target-flops, let it use the default ratio from base_train
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train -- \
--depth=$d \
--target_param_data_ratio=8 \
--target-param-data-ratio=8 \
--run="${WANDB_RUN}_d${d}" \
--model_tag="${TAG}" \
--core_metric_every=999999 \
--core_metric_max_per_task=-1 \
--sample_every=-1 \
--save_every=-1 \
--model-tag="${TAG}" \
--core-metric-every=999999 \
--core-metric-max-per-task=-1 \
--sample-every=-1 \
--save-every=-1 \
2>&1 | tee "$RESULTS_DIR/${TAG}_train.log"
END_TIME=$(date +%s)

View File

@ -23,15 +23,15 @@ python -m nanochat.dataset -n 16
# start downloading the rest of the shards for a total of 1200 (see below why 1200)
python -m nanochat.dataset -n 1200 &
# todo: download the rest of it
python -m scripts.tok_train --max_chars=4000000000 --vocab_size=65536
python -m scripts.tok_train --max-chars=4000000000 --vocab-size=65536
python -m scripts.tok_eval
# Documenting my process for determining the hyperparameters for this run1000.sh script:
# We want a budget of approx. $1000 ~= 41.6 hours of 8XH100 compute
# 1) I guessed the model size for this to be about depth=32
# 2) Determine the device_batch_size that fits:
# Running the base_train.py script with --depth=32, I saw that --device_batch_size=16
# runs out of memory, but --device_batch_size=8 fits. Inspecting `nvidia-smi` during training,
# Running the base_train.py script with --depth=32, I saw that --device-batch-size=16
# runs out of memory, but --device-batch-size=8 fits. Inspecting `nvidia-smi` during training,
# I saw all GPUs were at about 78/80GB VRAM, so it just barely fits and we have good MFU at ~50%.
# So the training script was running ok and showed:
# Vocab size: 65,536
@ -73,13 +73,13 @@ python -m scripts.tok_eval
# Number of processes/GPUs to use
NPROC_PER_NODE=8
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train -- --depth=32 --target_param_data_ratio=20 --device_batch_size=8 --run=$WANDB_RUN
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train -- --depth=32 --target-param-data-ratio=20 --device-batch-size=8 --run=$WANDB_RUN
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_loss
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_eval
# midtrain
# NOTE: ensure that we use the same device_batch_size here as the base training script.
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.mid_train -- --device_batch_size=8 --run=$WANDB_RUN
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.mid_train -- --device-batch-size=8 --run=$WANDB_RUN
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_eval -- -i mid
# sft

View File

@ -64,15 +64,15 @@ for flops in "${FLOPS_BUDGETS[@]}"; do
# CORE eval happens once at the end (999999 ensures only final step)
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train -- \
--depth=$d \
--target_flops=$flops \
--target_param_data_ratio=-1 \
--target-flops=$flops \
--target-param-data-ratio=-1 \
--run="${WANDB_RUN}_${TAG}" \
--model_tag="${TAG}" \
--eval_tokens=$EVAL_TOKENS \
--core_metric_every=999999 \
--core_metric_max_per_task=-1 \
--sample_every=-1 \
--save_every=-1 \
--model-tag="${TAG}" \
--eval-tokens=$EVAL_TOKENS \
--core-metric-every=999999 \
--core-metric-max-per-task=-1 \
--sample-every=-1 \
--save-every=-1 \
2>&1 | tee "$RESULTS_DIR/${TAG}_train.log"
END_TIME=$(date +%s)

View File

@ -7,7 +7,7 @@ Example run as:
torchrun --standalone --nproc_per_node=8 -m scripts.base_loss
To evaluate a HuggingFace model:
python -m scripts.base_loss --hf_path openai-community/gpt2
python -m scripts.base_loss --hf-path openai-community/gpt2
"""
import argparse
from contextlib import nullcontext
@ -61,12 +61,12 @@ def get_hf_token_bytes(tokenizer, device="cpu"):
# CLI arguments
parser = argparse.ArgumentParser(description="Evaluate loss on train/val splits and sample from model")
parser.add_argument("--device_batch_size", type=int, default=32, help="per-device batch size")
parser.add_argument("--split_tokens", type=int, default=40*524288, help="number of tokens to evaluate per split")
parser.add_argument("--model_tag", type=str, default=None, help="model tag for checkpoint directory")
parser.add_argument("--model_step", type=int, default=None, help="model step to load")
parser.add_argument("--device_type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
parser.add_argument("--hf_path", type=str, default=None, help="HuggingFace model path (e.g. openai-community/gpt2)")
parser.add_argument("--device-batch-size", type=int, default=32, help="per-device batch size")
parser.add_argument("--split-tokens", type=int, default=40*524288, help="number of tokens to evaluate per split")
parser.add_argument("--model-tag", type=str, default=None, help="model tag for checkpoint directory")
parser.add_argument("--model-step", type=int, default=None, help="model step to load")
parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
parser.add_argument("--hf-path", type=str, default=None, help="HuggingFace model path (e.g. openai-community/gpt2)")
args = parser.parse_args()
# Load the base model and the tokenizer

View File

@ -8,7 +8,7 @@ or distributed as:
torchrun --nproc_per_node=8 -m scripts.base_train.py
If you are only on CPU/Macbook, you'll want to train a much much smaller LLM. Example:
python -m scripts.base_train --depth=4 --max_seq_len=512 --device_batch_size=1 --eval_tokens=512 --core_metric_every=-1 --total_batch_size=512 --num_iterations=20
python -m scripts.base_train --depth=4 --max-seq-len=512 --device-batch-size=1 --eval-tokens=512 --core-metric-every=-1 --total-batch-size=512 --num-iterations=20
"""
import os
@ -36,40 +36,40 @@ parser = argparse.ArgumentParser(description="Pretrain base model")
# Logging
parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)")
# Runtime
parser.add_argument("--device_type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
# Model architecture
parser.add_argument("--depth", type=int, default=20, help="depth of the Transformer model")
parser.add_argument("--aspect_ratio", type=int, default=64, help="model_dim = depth * aspect_ratio")
parser.add_argument("--head_dim", type=int, default=128, help="target head dimension for attention")
parser.add_argument("--max_seq_len", type=int, default=2048, help="max context length")
parser.add_argument("--window_pattern", type=str, default="SSSL", help="sliding window pattern tiled across layers: L=full, S=half context (e.g. 'SSL')")
parser.add_argument("--aspect-ratio", type=int, default=64, help="model_dim = depth * aspect_ratio")
parser.add_argument("--head-dim", type=int, default=128, help="target head dimension for attention")
parser.add_argument("--max-seq-len", type=int, default=2048, help="max context length")
parser.add_argument("--window-pattern", type=str, default="SSSL", help="sliding window pattern tiled across layers: L=full, S=half context (e.g. 'SSL')")
# Training horizon (only one used, in order of precedence)
parser.add_argument("--num_iterations", type=int, default=-1, help="explicit number of optimization steps (-1 = disable)")
parser.add_argument("--target_flops", type=float, default=-1.0, help="calculate num_iterations to reach target_flops (-1 = disable)")
parser.add_argument("--target_param_data_ratio", type=int, default=8, help="calculate num_iterations to maintain data:param ratio (Chinchilla=20, -1 = disable)")
parser.add_argument("--num-iterations", type=int, default=-1, help="explicit number of optimization steps (-1 = disable)")
parser.add_argument("--target-flops", type=float, default=-1.0, help="calculate num_iterations to reach target_flops (-1 = disable)")
parser.add_argument("--target-param-data-ratio", type=int, default=8, help="calculate num_iterations to maintain data:param ratio (Chinchilla=20, -1 = disable)")
# Optimization
parser.add_argument("--device_batch_size", type=int, default=32, help="per-device batch size")
parser.add_argument("--total_batch_size", type=int, default=524288, help="total batch size in tokens")
parser.add_argument("--embedding_lr", type=float, default=0.3, help="learning rate for embedding parameters (Adam)")
parser.add_argument("--unembedding_lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)")
parser.add_argument("--weight_decay", type=float, default=0.2, help="cautious weight decay for the Muon optimizer (for weights)")
parser.add_argument("--matrix_lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)")
parser.add_argument("--scalar_lr", type=float, default=0.5, help="learning rate for scalars (resid_lambdas, x0_lambdas)")
parser.add_argument("--adam_beta1", type=float, default=0.8, help="Adam beta1 for embedding/unembedding")
parser.add_argument("--adam_beta2", type=float, default=0.95, help="Adam beta2 for embedding/unembedding")
parser.add_argument("--warmup_ratio", type=float, default=0.0, help="ratio of iterations for LR warmup")
parser.add_argument("--warmdown_ratio", type=float, default=0.4, help="ratio of iterations for LR warmdown")
parser.add_argument("--final_lr_frac", type=float, default=0.0, help="final LR as fraction of initial LR")
parser.add_argument("--resume_from_step", type=int, default=-1, help="resume training from this step (-1 = disable)")
parser.add_argument("--device-batch-size", type=int, default=32, help="per-device batch size")
parser.add_argument("--total-batch-size", type=int, default=524288, help="total batch size in tokens")
parser.add_argument("--embedding-lr", type=float, default=0.3, help="learning rate for embedding parameters (Adam)")
parser.add_argument("--unembedding-lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)")
parser.add_argument("--weight-decay", type=float, default=0.2, help="cautious weight decay for the Muon optimizer (for weights)")
parser.add_argument("--matrix-lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)")
parser.add_argument("--scalar-lr", type=float, default=0.5, help="learning rate for scalars (resid_lambdas, x0_lambdas)")
parser.add_argument("--adam-beta1", type=float, default=0.8, help="Adam beta1 for embedding/unembedding")
parser.add_argument("--adam-beta2", type=float, default=0.95, help="Adam beta2 for embedding/unembedding")
parser.add_argument("--warmup-ratio", type=float, default=0.0, help="ratio of iterations for LR warmup")
parser.add_argument("--warmdown-ratio", type=float, default=0.4, help="ratio of iterations for LR warmdown")
parser.add_argument("--final-lr-frac", type=float, default=0.0, help="final LR as fraction of initial LR")
parser.add_argument("--resume-from-step", type=int, default=-1, help="resume training from this step (-1 = disable)")
# Evaluation
parser.add_argument("--eval_every", type=int, default=250, help="evaluate val bpb every N steps (-1 = disable)")
parser.add_argument("--eval_tokens", type=int, default=20*524288, help="number of tokens to evaluate val loss on")
parser.add_argument("--core_metric_every", type=int, default=2000, help="evaluate CORE metric every N steps (-1 = disable)")
parser.add_argument("--core_metric_max_per_task", type=int, default=500, help="examples per task for CORE metric")
parser.add_argument("--sample_every", type=int, default=2000, help="sample from model every N steps (-1 = disable)")
parser.add_argument("--save_every", type=int, default=-1, help="save checkpoints every N steps (-1 = only at end)")
parser.add_argument("--eval-every", type=int, default=250, help="evaluate val bpb every N steps (-1 = disable)")
parser.add_argument("--eval-tokens", type=int, default=20*524288, help="number of tokens to evaluate val loss on")
parser.add_argument("--core-metric-every", type=int, default=2000, help="evaluate CORE metric every N steps (-1 = disable)")
parser.add_argument("--core-metric-max-per-task", type=int, default=500, help="examples per task for CORE metric")
parser.add_argument("--sample-every", type=int, default=2000, help="sample from model every N steps (-1 = disable)")
parser.add_argument("--save-every", type=int, default=-1, help="save checkpoints every N steps (-1 = only at end)")
# Output
parser.add_argument("--model_tag", type=str, default=None, help="override model tag for checkpoint directory name")
parser.add_argument("--model-tag", type=str, default=None, help="override model tag for checkpoint directory name")
args = parser.parse_args()
user_config = vars(args).copy() # for logging
# -----------------------------------------------------------------------------

View File

@ -35,32 +35,32 @@ parser = argparse.ArgumentParser(description="Reinforcement learning on GSM8K")
# Logging
parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)")
# Runtime
parser.add_argument("--device_type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
parser.add_argument("--dtype", type=str, default="bfloat16", help="float32|bfloat16")
# Model loading
parser.add_argument("--source", type=str, default="sft", help="mid|sft - which checkpoint to load from")
parser.add_argument("--model_tag", type=str, default=None, help="model tag to load from")
parser.add_argument("--model_step", type=int, default=None, help="model step to load from")
parser.add_argument("--model-tag", type=str, default=None, help="model tag to load from")
parser.add_argument("--model-step", type=int, default=None, help="model step to load from")
# Training horizon
parser.add_argument("--num_epochs", type=int, default=1, help="number of epochs over GSM8K")
parser.add_argument("--num-epochs", type=int, default=1, help="number of epochs over GSM8K")
# Batch sizes / sampling
parser.add_argument("--device_batch_size", type=int, default=8, help="max batch size per forward pass")
parser.add_argument("--examples_per_step", type=int, default=16, help="total examples per optimization step across all ranks")
parser.add_argument("--num_samples", type=int, default=16, help="number of samples per example/question")
parser.add_argument("--device-batch-size", type=int, default=8, help="max batch size per forward pass")
parser.add_argument("--examples-per-step", type=int, default=16, help="total examples per optimization step across all ranks")
parser.add_argument("--num-samples", type=int, default=16, help="number of samples per example/question")
# Generation
parser.add_argument("--max_new_tokens", type=int, default=256, help="max tokens to generate per sample")
parser.add_argument("--max-new-tokens", type=int, default=256, help="max tokens to generate per sample")
parser.add_argument("--temperature", type=float, default=1.0, help="sampling temperature")
parser.add_argument("--top_k", type=int, default=50, help="top-k sampling (0 = disabled)")
parser.add_argument("--top-k", type=int, default=50, help="top-k sampling (0 = disabled)")
# Optimization
parser.add_argument("--embedding_lr", type=float, default=0.2, help="learning rate for embedding parameters (Adam)")
parser.add_argument("--unembedding_lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)")
parser.add_argument("--matrix_lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)")
parser.add_argument("--weight_decay", type=float, default=0.0, help="weight decay for embedding/unembedding parameters (Adam)")
parser.add_argument("--init_lr_frac", type=float, default=0.05, help="initial LR as fraction of base LR")
parser.add_argument("--embedding-lr", type=float, default=0.2, help="learning rate for embedding parameters (Adam)")
parser.add_argument("--unembedding-lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)")
parser.add_argument("--matrix-lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)")
parser.add_argument("--weight-decay", type=float, default=0.0, help="weight decay for embedding/unembedding parameters (Adam)")
parser.add_argument("--init-lr-frac", type=float, default=0.05, help="initial LR as fraction of base LR")
# Evaluation / checkpointing
parser.add_argument("--eval_every", type=int, default=60, help="evaluate pass@k every N steps")
parser.add_argument("--eval_examples", type=int, default=400, help="number of examples for pass@k evaluation")
parser.add_argument("--save_every", type=int, default=60, help="save checkpoint every N steps")
parser.add_argument("--eval-every", type=int, default=60, help="evaluate pass@k every N steps")
parser.add_argument("--eval-examples", type=int, default=400, help="number of examples for pass@k evaluation")
parser.add_argument("--save-every", type=int, default=60, help="save checkpoint every N steps")
args = parser.parse_args()
user_config = vars(args).copy()
# -----------------------------------------------------------------------------

View File

@ -37,29 +37,29 @@ parser = argparse.ArgumentParser(description="Supervised finetuning for chat")
# Logging
parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)")
# Runtime
parser.add_argument("--device_type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
parser.add_argument("--dtype", type=str, default="bfloat16", help="float32|bfloat16")
# Model loading
parser.add_argument("--source", type=str, default="mid", help="base|mid - which checkpoint to load from")
parser.add_argument("--model_tag", type=str, default=None, help="model tag to load from")
parser.add_argument("--model_step", type=int, default=None, help="model step to load from")
parser.add_argument("--model-tag", type=str, default=None, help="model tag to load from")
parser.add_argument("--model-step", type=int, default=None, help="model step to load from")
# Training horizon
parser.add_argument("--num_epochs", type=int, default=1, help="number of epochs")
parser.add_argument("--num_iterations", type=int, default=-1, help="override number of iterations (-1 = use num_epochs)")
parser.add_argument("--num-epochs", type=int, default=1, help="number of epochs")
parser.add_argument("--num-iterations", type=int, default=-1, help="override number of iterations (-1 = use num_epochs)")
# Batch sizes
parser.add_argument("--device_batch_size", type=int, default=4, help="per-device batch size")
parser.add_argument("--target_examples_per_step", type=int, default=32, help="target examples per optimization step")
parser.add_argument("--device-batch-size", type=int, default=4, help="per-device batch size")
parser.add_argument("--target-examples-per-step", type=int, default=32, help="target examples per optimization step")
# Optimization
parser.add_argument("--embedding_lr", type=float, default=0.2, help="learning rate for embedding parameters (Adam)")
parser.add_argument("--unembedding_lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)")
parser.add_argument("--matrix_lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)")
parser.add_argument("--weight_decay", type=float, default=0.0, help="weight decay for embedding/unembedding parameters (Adam)")
parser.add_argument("--init_lr_frac", type=float, default=0.02, help="initial LR as fraction of base LR")
parser.add_argument("--embedding-lr", type=float, default=0.2, help="learning rate for embedding parameters (Adam)")
parser.add_argument("--unembedding-lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)")
parser.add_argument("--matrix-lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)")
parser.add_argument("--weight-decay", type=float, default=0.0, help="weight decay for embedding/unembedding parameters (Adam)")
parser.add_argument("--init-lr-frac", type=float, default=0.02, help="initial LR as fraction of base LR")
# Evaluation
parser.add_argument("--eval_every", type=int, default=100, help="evaluate val loss every N steps")
parser.add_argument("--eval_steps", type=int, default=100, help="number of batches for val loss evaluation")
parser.add_argument("--eval_metrics_every", type=int, default=200, help="evaluate accuracy metrics every N steps")
parser.add_argument("--eval_metrics_max_problems", type=int, default=1024, help="max problems per metric evaluation")
parser.add_argument("--eval-every", type=int, default=100, help="evaluate val loss every N steps")
parser.add_argument("--eval-steps", type=int, default=100, help="number of batches for val loss evaluation")
parser.add_argument("--eval-metrics-every", type=int, default=200, help="evaluate accuracy metrics every N steps")
parser.add_argument("--eval-metrics-max-problems", type=int, default=1024, help="max problems per metric evaluation")
args = parser.parse_args()
user_config = vars(args).copy()
# -----------------------------------------------------------------------------

View File

@ -6,7 +6,7 @@ python -m scripts.mid_train
Or torchrun for training:
torchrun --standalone --nproc_per_node=8 -m scripts.mid_train -- --device_batch_size=16
torchrun --standalone --nproc_per_node=8 -m scripts.mid_train -- --device-batch-size=16
"""
import argparse
@ -36,28 +36,28 @@ parser = argparse.ArgumentParser(description="Midtrain the model")
# Logging
parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)")
# Runtime
parser.add_argument("--device_type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
parser.add_argument("--dtype", type=str, default="bfloat16", help="float32|bfloat16")
# Model loading
parser.add_argument("--model_tag", type=str, default=None, help="model tag to load from")
parser.add_argument("--model_step", type=int, default=None, help="model step to load from")
parser.add_argument("--model-tag", type=str, default=None, help="model tag to load from")
parser.add_argument("--model-step", type=int, default=None, help="model step to load from")
# Training horizon
parser.add_argument("--num_iterations", type=int, default=-1, help="number of optimization steps (-1 = full epoch)")
parser.add_argument("--num-iterations", type=int, default=-1, help="number of optimization steps (-1 = full epoch)")
# Batch sizes
parser.add_argument("--max_seq_len", type=int, default=2048, help="max context length")
parser.add_argument("--device_batch_size", type=int, default=32, help="per-device batch size")
parser.add_argument("--total_batch_size", type=int, default=524288, help="total batch size in tokens")
parser.add_argument("--max-seq-len", type=int, default=2048, help="max context length")
parser.add_argument("--device-batch-size", type=int, default=32, help="per-device batch size")
parser.add_argument("--total-batch-size", type=int, default=524288, help="total batch size in tokens")
# Optimization
parser.add_argument("--embedding_lr", type=float, default=0.2, help="learning rate for embedding parameters (Adam)")
parser.add_argument("--unembedding_lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)")
parser.add_argument("--matrix_lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)")
parser.add_argument("--weight_decay", type=float, default=0.0, help="weight decay for embedding/unembedding parameters (Adam)")
parser.add_argument("--init_lr_frac", type=float, default=1.0, help="initial LR as fraction of base LR")
parser.add_argument("--embedding-lr", type=float, default=0.2, help="learning rate for embedding parameters (Adam)")
parser.add_argument("--unembedding-lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)")
parser.add_argument("--matrix-lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)")
parser.add_argument("--weight-decay", type=float, default=0.0, help="weight decay for embedding/unembedding parameters (Adam)")
parser.add_argument("--init-lr-frac", type=float, default=1.0, help="initial LR as fraction of base LR")
# Evaluation
parser.add_argument("--eval_every", type=int, default=150, help="evaluate val bpb every N steps (-1 = disable)")
parser.add_argument("--eval_tokens", type=int, default=20*524288, help="number of tokens to evaluate val loss on")
parser.add_argument("--eval-every", type=int, default=150, help="evaluate val bpb every N steps (-1 = disable)")
parser.add_argument("--eval-tokens", type=int, default=20*524288, help="number of tokens to evaluate val loss on")
# Output
parser.add_argument("--dry_run", action="store_true", help="log to wandb but skip checkpoints/report")
parser.add_argument("--dry-run", action="store_true", help="log to wandb but skip checkpoints/report")
args = parser.parse_args()
user_config = vars(args).copy()
# -----------------------------------------------------------------------------
@ -79,7 +79,7 @@ wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-mi
model, tokenizer, meta = load_model("base", device, phase="train", model_tag=args.model_tag, step=args.model_step)
pretrain_batch_size = meta.get("device_batch_size", None)
if pretrain_batch_size is not None and args.device_batch_size > pretrain_batch_size:
print0(f"FOOTGUN WARNING: base model training used device_batch_size {pretrain_batch_size}, did you pass in a good --device_batch_size to this script?")
print0(f"FOOTGUN WARNING: base model training used device_batch_size {pretrain_batch_size}, did you pass in a good --device-batch-size to this script?")
orig_model = model
model = torch.compile(model, dynamic=False)
depth = model.config.n_layer
@ -142,7 +142,8 @@ def mid_data_generator_bos_bestfit(split, buffer_size=100):
# Conversation buffer: list of token lists
conv_buffer = []
cursor = ddp_rank # Each rank processes different conversations
cursor = ddp_rank # Each rank processes different conversations (for fetching)
consumed = ddp_rank # Track actual consumption separately from buffering
epoch = 1
it = 0 # iteration counter
@ -156,8 +157,7 @@ def mid_data_generator_bos_bestfit(split, buffer_size=100):
if cursor >= dataset_size:
cursor = cursor % dataset_size
epoch += 1
if split == "train":
last_step = True # toggle last_step to True, which will terminate the training loop
# Note: last_step is now triggered based on consumption, not fetching
while True:
rows = []
@ -183,10 +183,12 @@ def mid_data_generator_bos_bestfit(split, buffer_size=100):
# Found a conversation that fits - use it entirely
conv = conv_buffer.pop(best_idx)
row.extend(conv)
consumed += ddp_world_size # Track actual consumption
else:
# No conversation fits - crop first conversation to fill remaining
conv = conv_buffer.pop(0)
row.extend(conv[:remaining])
consumed += ddp_world_size # Track actual consumption
rows.append(row[:row_capacity])
@ -195,13 +197,16 @@ def mid_data_generator_bos_bestfit(split, buffer_size=100):
if 0 < args.num_iterations <= it and split == "train":
last_step = True
# Update progress tracking
# Update progress tracking (based on consumed, not cursor, to account for buffering)
if split == "train":
current_epoch = epoch
if args.num_iterations > 0:
approx_progress = it / args.num_iterations
else:
approx_progress = cursor / dataset_size
approx_progress = consumed / dataset_size
# Trigger last_step when we've consumed enough (instead of when cursor wraps)
if consumed >= dataset_size:
last_step = True
# Build tensors
use_cuda = device_type == "cuda"

View File

@ -14,9 +14,9 @@ from nanochat.dataset import parquets_iter_batched
# Parse command line arguments
parser = argparse.ArgumentParser(description='Train a BPE tokenizer')
parser.add_argument('--max_chars', type=int, default=10_000_000_000, help='Maximum characters to train on (default: 10B)')
parser.add_argument('--doc_cap', type=int, default=10_000, help='Maximum characters per document (default: 10,000)')
parser.add_argument('--vocab_size', type=int, default=32768, help='Vocabulary size (default: 32768 = 2^15)')
parser.add_argument('--max-chars', type=int, default=10_000_000_000, help='Maximum characters to train on (default: 10B)')
parser.add_argument('--doc-cap', type=int, default=10_000, help='Maximum characters per document (default: 10,000)')
parser.add_argument('--vocab-size', type=int, default=32768, help='Vocabulary size (default: 32768 = 2^15)')
args = parser.parse_args()
print(f"max_chars: {args.max_chars:,}")
print(f"doc_cap: {args.doc_cap:,}")

View File

@ -59,7 +59,7 @@ python -m nanochat.dataset -n 8
python -m nanochat.dataset -n 370 &
DATASET_DOWNLOAD_PID=$!
# train the tokenizer with vocab size 2**16 = 65536 on ~2B characters of data
python -m scripts.tok_train --max_chars=2000000000 --vocab_size=65536
python -m scripts.tok_train --max-chars=2000000000 --vocab-size=65536
# evaluate the tokenizer (report compression ratio etc.)
python -m scripts.tok_eval
@ -81,7 +81,7 @@ wait $DATASET_DOWNLOAD_PID
NPROC_PER_NODE=8
# pretrain the d20 model
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train -- --depth=20 --target_param_data_ratio=20 --run=$WANDB_RUN
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train -- --depth=20 --target-param-data-ratio=20 --run=$WANDB_RUN
# evaluate the model on a larger chunk of train/val data and draw some samples
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_loss
# evaluate the model on CORE tasks