diff --git a/dev/runcpu.sh b/dev/runcpu.sh index c4a719e..c0b32a5 100755 --- a/dev/runcpu.sh +++ b/dev/runcpu.sh @@ -25,7 +25,7 @@ python -m nanochat.report reset # train tokenizer on ~1B characters python -m nanochat.dataset -n 4 -python -m scripts.tok_train --max_chars=1000000000 +python -m scripts.tok_train --max-chars=1000000000 python -m scripts.tok_eval # train a very small 4 layer model on the CPU @@ -33,37 +33,37 @@ python -m scripts.tok_eval # we only run 50 steps of optimization (bump this to get better results) python -m scripts.base_train \ --depth=4 \ - --max_seq_len=1024 \ - --device_batch_size=1 \ - --total_batch_size=1024 \ - --eval_every=50 \ - --eval_tokens=4096 \ - --core_metric_every=50 \ - --core_metric_max_per_task=12 \ - --sample_every=50 \ - --num_iterations=50 -python -m scripts.base_loss --device_batch_size=1 --split_tokens=4096 + --max-seq-len=1024 \ + --device-batch-size=1 \ + --total-batch-size=1024 \ + --eval-every=50 \ + --eval-tokens=4096 \ + --core-metric-every=50 \ + --core-metric-max-per-task=12 \ + --sample-every=50 \ + --num-iterations=50 +python -m scripts.base_loss --device-batch-size=1 --split-tokens=4096 python -m scripts.base_eval --max-per-task=16 # midtraining python -m scripts.mid_train \ - --max_seq_len=1024 \ - --device_batch_size=1 \ - --eval_every=50 \ - --eval_tokens=4096 \ - --total_batch_size=1024 \ - --num_iterations=100 + --max-seq-len=1024 \ + --device-batch-size=1 \ + --eval-every=50 \ + --eval-tokens=4096 \ + --total-batch-size=1024 \ + --num-iterations=100 # eval results will be terrible, this is just to execute the code paths. # note that we lower the execution memory limit to 1MB to avoid warnings on smaller systems python -m scripts.chat_eval --source=mid --max-new-tokens=128 --max-problems=20 # SFT python -m scripts.chat_sft \ - --device_batch_size=1 \ - --target_examples_per_step=4 \ - --num_iterations=100 \ - --eval_steps=4 \ - --eval_metrics_max_problems=16 + --device-batch-size=1 \ + --target-examples-per-step=4 \ + --num-iterations=100 \ + --eval-steps=4 \ + --eval-metrics-max-problems=16 # Chat CLI # python -m scripts.chat_cli -p "Why is the sky blue?" diff --git a/miniseries.sh b/miniseries.sh index 4d6f436..9a4512b 100644 --- a/miniseries.sh +++ b/miniseries.sh @@ -20,7 +20,7 @@ if [ -z "$SKIP_SETUP" ]; then # Tokenizer, download 1000 shards for pretraining # (probably this can be reduced but it's tricky to determine the exact right number, TODO). python -m nanochat.dataset -n 1000 - python -m scripts.tok_train --max_chars=2000000000 --vocab_size=32768 + python -m scripts.tok_train --max-chars=2000000000 --vocab-size=32768 else source .venv/bin/activate fi @@ -58,16 +58,16 @@ for d in "${DEPTHS[@]}"; do START_TIME=$(date +%s) # Train the model with natural horizon (target_param_data_ratio default) - # No --target_flops, let it use the default ratio from base_train + # No --target-flops, let it use the default ratio from base_train torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train -- \ --depth=$d \ - --target_param_data_ratio=8 \ + --target-param-data-ratio=8 \ --run="${WANDB_RUN}_d${d}" \ - --model_tag="${TAG}" \ - --core_metric_every=999999 \ - --core_metric_max_per_task=-1 \ - --sample_every=-1 \ - --save_every=-1 \ + --model-tag="${TAG}" \ + --core-metric-every=999999 \ + --core-metric-max-per-task=-1 \ + --sample-every=-1 \ + --save-every=-1 \ 2>&1 | tee "$RESULTS_DIR/${TAG}_train.log" END_TIME=$(date +%s) diff --git a/run1000.sh b/run1000.sh index fe92edf..5d0b7dc 100644 --- a/run1000.sh +++ b/run1000.sh @@ -23,15 +23,15 @@ python -m nanochat.dataset -n 16 # start downloading the rest of the shards for a total of 1200 (see below why 1200) python -m nanochat.dataset -n 1200 & # todo: download the rest of it -python -m scripts.tok_train --max_chars=4000000000 --vocab_size=65536 +python -m scripts.tok_train --max-chars=4000000000 --vocab-size=65536 python -m scripts.tok_eval # Documenting my process for determining the hyperparameters for this run1000.sh script: # We want a budget of approx. $1000 ~= 41.6 hours of 8XH100 compute # 1) I guessed the model size for this to be about depth=32 # 2) Determine the device_batch_size that fits: -# Running the base_train.py script with --depth=32, I saw that --device_batch_size=16 -# runs out of memory, but --device_batch_size=8 fits. Inspecting `nvidia-smi` during training, +# Running the base_train.py script with --depth=32, I saw that --device-batch-size=16 +# runs out of memory, but --device-batch-size=8 fits. Inspecting `nvidia-smi` during training, # I saw all GPUs were at about 78/80GB VRAM, so it just barely fits and we have good MFU at ~50%. # So the training script was running ok and showed: # Vocab size: 65,536 @@ -73,13 +73,13 @@ python -m scripts.tok_eval # Number of processes/GPUs to use NPROC_PER_NODE=8 -torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train -- --depth=32 --target_param_data_ratio=20 --device_batch_size=8 --run=$WANDB_RUN +torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train -- --depth=32 --target-param-data-ratio=20 --device-batch-size=8 --run=$WANDB_RUN torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_loss torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_eval # midtrain # NOTE: ensure that we use the same device_batch_size here as the base training script. -torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.mid_train -- --device_batch_size=8 --run=$WANDB_RUN +torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.mid_train -- --device-batch-size=8 --run=$WANDB_RUN torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_eval -- -i mid # sft diff --git a/scaling_laws.sh b/scaling_laws.sh index 102ba11..321b286 100644 --- a/scaling_laws.sh +++ b/scaling_laws.sh @@ -64,15 +64,15 @@ for flops in "${FLOPS_BUDGETS[@]}"; do # CORE eval happens once at the end (999999 ensures only final step) torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train -- \ --depth=$d \ - --target_flops=$flops \ - --target_param_data_ratio=-1 \ + --target-flops=$flops \ + --target-param-data-ratio=-1 \ --run="${WANDB_RUN}_${TAG}" \ - --model_tag="${TAG}" \ - --eval_tokens=$EVAL_TOKENS \ - --core_metric_every=999999 \ - --core_metric_max_per_task=-1 \ - --sample_every=-1 \ - --save_every=-1 \ + --model-tag="${TAG}" \ + --eval-tokens=$EVAL_TOKENS \ + --core-metric-every=999999 \ + --core-metric-max-per-task=-1 \ + --sample-every=-1 \ + --save-every=-1 \ 2>&1 | tee "$RESULTS_DIR/${TAG}_train.log" END_TIME=$(date +%s) diff --git a/scripts/base_loss.py b/scripts/base_loss.py index 46544d4..6b44a30 100644 --- a/scripts/base_loss.py +++ b/scripts/base_loss.py @@ -7,7 +7,7 @@ Example run as: torchrun --standalone --nproc_per_node=8 -m scripts.base_loss To evaluate a HuggingFace model: -python -m scripts.base_loss --hf_path openai-community/gpt2 +python -m scripts.base_loss --hf-path openai-community/gpt2 """ import argparse from contextlib import nullcontext @@ -61,12 +61,12 @@ def get_hf_token_bytes(tokenizer, device="cpu"): # CLI arguments parser = argparse.ArgumentParser(description="Evaluate loss on train/val splits and sample from model") -parser.add_argument("--device_batch_size", type=int, default=32, help="per-device batch size") -parser.add_argument("--split_tokens", type=int, default=40*524288, help="number of tokens to evaluate per split") -parser.add_argument("--model_tag", type=str, default=None, help="model tag for checkpoint directory") -parser.add_argument("--model_step", type=int, default=None, help="model step to load") -parser.add_argument("--device_type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)") -parser.add_argument("--hf_path", type=str, default=None, help="HuggingFace model path (e.g. openai-community/gpt2)") +parser.add_argument("--device-batch-size", type=int, default=32, help="per-device batch size") +parser.add_argument("--split-tokens", type=int, default=40*524288, help="number of tokens to evaluate per split") +parser.add_argument("--model-tag", type=str, default=None, help="model tag for checkpoint directory") +parser.add_argument("--model-step", type=int, default=None, help="model step to load") +parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)") +parser.add_argument("--hf-path", type=str, default=None, help="HuggingFace model path (e.g. openai-community/gpt2)") args = parser.parse_args() # Load the base model and the tokenizer diff --git a/scripts/base_train.py b/scripts/base_train.py index a432e7a..bf4b8cf 100644 --- a/scripts/base_train.py +++ b/scripts/base_train.py @@ -8,7 +8,7 @@ or distributed as: torchrun --nproc_per_node=8 -m scripts.base_train.py If you are only on CPU/Macbook, you'll want to train a much much smaller LLM. Example: -python -m scripts.base_train --depth=4 --max_seq_len=512 --device_batch_size=1 --eval_tokens=512 --core_metric_every=-1 --total_batch_size=512 --num_iterations=20 +python -m scripts.base_train --depth=4 --max-seq-len=512 --device-batch-size=1 --eval-tokens=512 --core-metric-every=-1 --total-batch-size=512 --num-iterations=20 """ import os @@ -36,40 +36,40 @@ parser = argparse.ArgumentParser(description="Pretrain base model") # Logging parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)") # Runtime -parser.add_argument("--device_type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)") +parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)") # Model architecture parser.add_argument("--depth", type=int, default=20, help="depth of the Transformer model") -parser.add_argument("--aspect_ratio", type=int, default=64, help="model_dim = depth * aspect_ratio") -parser.add_argument("--head_dim", type=int, default=128, help="target head dimension for attention") -parser.add_argument("--max_seq_len", type=int, default=2048, help="max context length") -parser.add_argument("--window_pattern", type=str, default="SSSL", help="sliding window pattern tiled across layers: L=full, S=half context (e.g. 'SSL')") +parser.add_argument("--aspect-ratio", type=int, default=64, help="model_dim = depth * aspect_ratio") +parser.add_argument("--head-dim", type=int, default=128, help="target head dimension for attention") +parser.add_argument("--max-seq-len", type=int, default=2048, help="max context length") +parser.add_argument("--window-pattern", type=str, default="SSSL", help="sliding window pattern tiled across layers: L=full, S=half context (e.g. 'SSL')") # Training horizon (only one used, in order of precedence) -parser.add_argument("--num_iterations", type=int, default=-1, help="explicit number of optimization steps (-1 = disable)") -parser.add_argument("--target_flops", type=float, default=-1.0, help="calculate num_iterations to reach target_flops (-1 = disable)") -parser.add_argument("--target_param_data_ratio", type=int, default=8, help="calculate num_iterations to maintain data:param ratio (Chinchilla=20, -1 = disable)") +parser.add_argument("--num-iterations", type=int, default=-1, help="explicit number of optimization steps (-1 = disable)") +parser.add_argument("--target-flops", type=float, default=-1.0, help="calculate num_iterations to reach target_flops (-1 = disable)") +parser.add_argument("--target-param-data-ratio", type=int, default=8, help="calculate num_iterations to maintain data:param ratio (Chinchilla=20, -1 = disable)") # Optimization -parser.add_argument("--device_batch_size", type=int, default=32, help="per-device batch size") -parser.add_argument("--total_batch_size", type=int, default=524288, help="total batch size in tokens") -parser.add_argument("--embedding_lr", type=float, default=0.3, help="learning rate for embedding parameters (Adam)") -parser.add_argument("--unembedding_lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)") -parser.add_argument("--weight_decay", type=float, default=0.2, help="cautious weight decay for the Muon optimizer (for weights)") -parser.add_argument("--matrix_lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)") -parser.add_argument("--scalar_lr", type=float, default=0.5, help="learning rate for scalars (resid_lambdas, x0_lambdas)") -parser.add_argument("--adam_beta1", type=float, default=0.8, help="Adam beta1 for embedding/unembedding") -parser.add_argument("--adam_beta2", type=float, default=0.95, help="Adam beta2 for embedding/unembedding") -parser.add_argument("--warmup_ratio", type=float, default=0.0, help="ratio of iterations for LR warmup") -parser.add_argument("--warmdown_ratio", type=float, default=0.4, help="ratio of iterations for LR warmdown") -parser.add_argument("--final_lr_frac", type=float, default=0.0, help="final LR as fraction of initial LR") -parser.add_argument("--resume_from_step", type=int, default=-1, help="resume training from this step (-1 = disable)") +parser.add_argument("--device-batch-size", type=int, default=32, help="per-device batch size") +parser.add_argument("--total-batch-size", type=int, default=524288, help="total batch size in tokens") +parser.add_argument("--embedding-lr", type=float, default=0.3, help="learning rate for embedding parameters (Adam)") +parser.add_argument("--unembedding-lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)") +parser.add_argument("--weight-decay", type=float, default=0.2, help="cautious weight decay for the Muon optimizer (for weights)") +parser.add_argument("--matrix-lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)") +parser.add_argument("--scalar-lr", type=float, default=0.5, help="learning rate for scalars (resid_lambdas, x0_lambdas)") +parser.add_argument("--adam-beta1", type=float, default=0.8, help="Adam beta1 for embedding/unembedding") +parser.add_argument("--adam-beta2", type=float, default=0.95, help="Adam beta2 for embedding/unembedding") +parser.add_argument("--warmup-ratio", type=float, default=0.0, help="ratio of iterations for LR warmup") +parser.add_argument("--warmdown-ratio", type=float, default=0.4, help="ratio of iterations for LR warmdown") +parser.add_argument("--final-lr-frac", type=float, default=0.0, help="final LR as fraction of initial LR") +parser.add_argument("--resume-from-step", type=int, default=-1, help="resume training from this step (-1 = disable)") # Evaluation -parser.add_argument("--eval_every", type=int, default=250, help="evaluate val bpb every N steps (-1 = disable)") -parser.add_argument("--eval_tokens", type=int, default=20*524288, help="number of tokens to evaluate val loss on") -parser.add_argument("--core_metric_every", type=int, default=2000, help="evaluate CORE metric every N steps (-1 = disable)") -parser.add_argument("--core_metric_max_per_task", type=int, default=500, help="examples per task for CORE metric") -parser.add_argument("--sample_every", type=int, default=2000, help="sample from model every N steps (-1 = disable)") -parser.add_argument("--save_every", type=int, default=-1, help="save checkpoints every N steps (-1 = only at end)") +parser.add_argument("--eval-every", type=int, default=250, help="evaluate val bpb every N steps (-1 = disable)") +parser.add_argument("--eval-tokens", type=int, default=20*524288, help="number of tokens to evaluate val loss on") +parser.add_argument("--core-metric-every", type=int, default=2000, help="evaluate CORE metric every N steps (-1 = disable)") +parser.add_argument("--core-metric-max-per-task", type=int, default=500, help="examples per task for CORE metric") +parser.add_argument("--sample-every", type=int, default=2000, help="sample from model every N steps (-1 = disable)") +parser.add_argument("--save-every", type=int, default=-1, help="save checkpoints every N steps (-1 = only at end)") # Output -parser.add_argument("--model_tag", type=str, default=None, help="override model tag for checkpoint directory name") +parser.add_argument("--model-tag", type=str, default=None, help="override model tag for checkpoint directory name") args = parser.parse_args() user_config = vars(args).copy() # for logging # ----------------------------------------------------------------------------- diff --git a/scripts/chat_rl.py b/scripts/chat_rl.py index ad557b9..b0697f3 100644 --- a/scripts/chat_rl.py +++ b/scripts/chat_rl.py @@ -35,32 +35,32 @@ parser = argparse.ArgumentParser(description="Reinforcement learning on GSM8K") # Logging parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)") # Runtime -parser.add_argument("--device_type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)") +parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)") parser.add_argument("--dtype", type=str, default="bfloat16", help="float32|bfloat16") # Model loading parser.add_argument("--source", type=str, default="sft", help="mid|sft - which checkpoint to load from") -parser.add_argument("--model_tag", type=str, default=None, help="model tag to load from") -parser.add_argument("--model_step", type=int, default=None, help="model step to load from") +parser.add_argument("--model-tag", type=str, default=None, help="model tag to load from") +parser.add_argument("--model-step", type=int, default=None, help="model step to load from") # Training horizon -parser.add_argument("--num_epochs", type=int, default=1, help="number of epochs over GSM8K") +parser.add_argument("--num-epochs", type=int, default=1, help="number of epochs over GSM8K") # Batch sizes / sampling -parser.add_argument("--device_batch_size", type=int, default=8, help="max batch size per forward pass") -parser.add_argument("--examples_per_step", type=int, default=16, help="total examples per optimization step across all ranks") -parser.add_argument("--num_samples", type=int, default=16, help="number of samples per example/question") +parser.add_argument("--device-batch-size", type=int, default=8, help="max batch size per forward pass") +parser.add_argument("--examples-per-step", type=int, default=16, help="total examples per optimization step across all ranks") +parser.add_argument("--num-samples", type=int, default=16, help="number of samples per example/question") # Generation -parser.add_argument("--max_new_tokens", type=int, default=256, help="max tokens to generate per sample") +parser.add_argument("--max-new-tokens", type=int, default=256, help="max tokens to generate per sample") parser.add_argument("--temperature", type=float, default=1.0, help="sampling temperature") -parser.add_argument("--top_k", type=int, default=50, help="top-k sampling (0 = disabled)") +parser.add_argument("--top-k", type=int, default=50, help="top-k sampling (0 = disabled)") # Optimization -parser.add_argument("--embedding_lr", type=float, default=0.2, help="learning rate for embedding parameters (Adam)") -parser.add_argument("--unembedding_lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)") -parser.add_argument("--matrix_lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)") -parser.add_argument("--weight_decay", type=float, default=0.0, help="weight decay for embedding/unembedding parameters (Adam)") -parser.add_argument("--init_lr_frac", type=float, default=0.05, help="initial LR as fraction of base LR") +parser.add_argument("--embedding-lr", type=float, default=0.2, help="learning rate for embedding parameters (Adam)") +parser.add_argument("--unembedding-lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)") +parser.add_argument("--matrix-lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)") +parser.add_argument("--weight-decay", type=float, default=0.0, help="weight decay for embedding/unembedding parameters (Adam)") +parser.add_argument("--init-lr-frac", type=float, default=0.05, help="initial LR as fraction of base LR") # Evaluation / checkpointing -parser.add_argument("--eval_every", type=int, default=60, help="evaluate pass@k every N steps") -parser.add_argument("--eval_examples", type=int, default=400, help="number of examples for pass@k evaluation") -parser.add_argument("--save_every", type=int, default=60, help="save checkpoint every N steps") +parser.add_argument("--eval-every", type=int, default=60, help="evaluate pass@k every N steps") +parser.add_argument("--eval-examples", type=int, default=400, help="number of examples for pass@k evaluation") +parser.add_argument("--save-every", type=int, default=60, help="save checkpoint every N steps") args = parser.parse_args() user_config = vars(args).copy() # ----------------------------------------------------------------------------- diff --git a/scripts/chat_sft.py b/scripts/chat_sft.py index 853a2bf..9277cf9 100644 --- a/scripts/chat_sft.py +++ b/scripts/chat_sft.py @@ -37,29 +37,29 @@ parser = argparse.ArgumentParser(description="Supervised finetuning for chat") # Logging parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)") # Runtime -parser.add_argument("--device_type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)") +parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)") parser.add_argument("--dtype", type=str, default="bfloat16", help="float32|bfloat16") # Model loading parser.add_argument("--source", type=str, default="mid", help="base|mid - which checkpoint to load from") -parser.add_argument("--model_tag", type=str, default=None, help="model tag to load from") -parser.add_argument("--model_step", type=int, default=None, help="model step to load from") +parser.add_argument("--model-tag", type=str, default=None, help="model tag to load from") +parser.add_argument("--model-step", type=int, default=None, help="model step to load from") # Training horizon -parser.add_argument("--num_epochs", type=int, default=1, help="number of epochs") -parser.add_argument("--num_iterations", type=int, default=-1, help="override number of iterations (-1 = use num_epochs)") +parser.add_argument("--num-epochs", type=int, default=1, help="number of epochs") +parser.add_argument("--num-iterations", type=int, default=-1, help="override number of iterations (-1 = use num_epochs)") # Batch sizes -parser.add_argument("--device_batch_size", type=int, default=4, help="per-device batch size") -parser.add_argument("--target_examples_per_step", type=int, default=32, help="target examples per optimization step") +parser.add_argument("--device-batch-size", type=int, default=4, help="per-device batch size") +parser.add_argument("--target-examples-per-step", type=int, default=32, help="target examples per optimization step") # Optimization -parser.add_argument("--embedding_lr", type=float, default=0.2, help="learning rate for embedding parameters (Adam)") -parser.add_argument("--unembedding_lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)") -parser.add_argument("--matrix_lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)") -parser.add_argument("--weight_decay", type=float, default=0.0, help="weight decay for embedding/unembedding parameters (Adam)") -parser.add_argument("--init_lr_frac", type=float, default=0.02, help="initial LR as fraction of base LR") +parser.add_argument("--embedding-lr", type=float, default=0.2, help="learning rate for embedding parameters (Adam)") +parser.add_argument("--unembedding-lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)") +parser.add_argument("--matrix-lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)") +parser.add_argument("--weight-decay", type=float, default=0.0, help="weight decay for embedding/unembedding parameters (Adam)") +parser.add_argument("--init-lr-frac", type=float, default=0.02, help="initial LR as fraction of base LR") # Evaluation -parser.add_argument("--eval_every", type=int, default=100, help="evaluate val loss every N steps") -parser.add_argument("--eval_steps", type=int, default=100, help="number of batches for val loss evaluation") -parser.add_argument("--eval_metrics_every", type=int, default=200, help="evaluate accuracy metrics every N steps") -parser.add_argument("--eval_metrics_max_problems", type=int, default=1024, help="max problems per metric evaluation") +parser.add_argument("--eval-every", type=int, default=100, help="evaluate val loss every N steps") +parser.add_argument("--eval-steps", type=int, default=100, help="number of batches for val loss evaluation") +parser.add_argument("--eval-metrics-every", type=int, default=200, help="evaluate accuracy metrics every N steps") +parser.add_argument("--eval-metrics-max-problems", type=int, default=1024, help="max problems per metric evaluation") args = parser.parse_args() user_config = vars(args).copy() # ----------------------------------------------------------------------------- diff --git a/scripts/mid_train.py b/scripts/mid_train.py index 0742c08..01d9f7d 100644 --- a/scripts/mid_train.py +++ b/scripts/mid_train.py @@ -6,7 +6,7 @@ python -m scripts.mid_train Or torchrun for training: -torchrun --standalone --nproc_per_node=8 -m scripts.mid_train -- --device_batch_size=16 +torchrun --standalone --nproc_per_node=8 -m scripts.mid_train -- --device-batch-size=16 """ import argparse @@ -36,28 +36,28 @@ parser = argparse.ArgumentParser(description="Midtrain the model") # Logging parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)") # Runtime -parser.add_argument("--device_type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)") +parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)") parser.add_argument("--dtype", type=str, default="bfloat16", help="float32|bfloat16") # Model loading -parser.add_argument("--model_tag", type=str, default=None, help="model tag to load from") -parser.add_argument("--model_step", type=int, default=None, help="model step to load from") +parser.add_argument("--model-tag", type=str, default=None, help="model tag to load from") +parser.add_argument("--model-step", type=int, default=None, help="model step to load from") # Training horizon -parser.add_argument("--num_iterations", type=int, default=-1, help="number of optimization steps (-1 = full epoch)") +parser.add_argument("--num-iterations", type=int, default=-1, help="number of optimization steps (-1 = full epoch)") # Batch sizes -parser.add_argument("--max_seq_len", type=int, default=2048, help="max context length") -parser.add_argument("--device_batch_size", type=int, default=32, help="per-device batch size") -parser.add_argument("--total_batch_size", type=int, default=524288, help="total batch size in tokens") +parser.add_argument("--max-seq-len", type=int, default=2048, help="max context length") +parser.add_argument("--device-batch-size", type=int, default=32, help="per-device batch size") +parser.add_argument("--total-batch-size", type=int, default=524288, help="total batch size in tokens") # Optimization -parser.add_argument("--embedding_lr", type=float, default=0.2, help="learning rate for embedding parameters (Adam)") -parser.add_argument("--unembedding_lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)") -parser.add_argument("--matrix_lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)") -parser.add_argument("--weight_decay", type=float, default=0.0, help="weight decay for embedding/unembedding parameters (Adam)") -parser.add_argument("--init_lr_frac", type=float, default=1.0, help="initial LR as fraction of base LR") +parser.add_argument("--embedding-lr", type=float, default=0.2, help="learning rate for embedding parameters (Adam)") +parser.add_argument("--unembedding-lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)") +parser.add_argument("--matrix-lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)") +parser.add_argument("--weight-decay", type=float, default=0.0, help="weight decay for embedding/unembedding parameters (Adam)") +parser.add_argument("--init-lr-frac", type=float, default=1.0, help="initial LR as fraction of base LR") # Evaluation -parser.add_argument("--eval_every", type=int, default=150, help="evaluate val bpb every N steps (-1 = disable)") -parser.add_argument("--eval_tokens", type=int, default=20*524288, help="number of tokens to evaluate val loss on") +parser.add_argument("--eval-every", type=int, default=150, help="evaluate val bpb every N steps (-1 = disable)") +parser.add_argument("--eval-tokens", type=int, default=20*524288, help="number of tokens to evaluate val loss on") # Output -parser.add_argument("--dry_run", action="store_true", help="log to wandb but skip checkpoints/report") +parser.add_argument("--dry-run", action="store_true", help="log to wandb but skip checkpoints/report") args = parser.parse_args() user_config = vars(args).copy() # ----------------------------------------------------------------------------- @@ -79,7 +79,7 @@ wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-mi model, tokenizer, meta = load_model("base", device, phase="train", model_tag=args.model_tag, step=args.model_step) pretrain_batch_size = meta.get("device_batch_size", None) if pretrain_batch_size is not None and args.device_batch_size > pretrain_batch_size: - print0(f"FOOTGUN WARNING: base model training used device_batch_size {pretrain_batch_size}, did you pass in a good --device_batch_size to this script?") + print0(f"FOOTGUN WARNING: base model training used device_batch_size {pretrain_batch_size}, did you pass in a good --device-batch-size to this script?") orig_model = model model = torch.compile(model, dynamic=False) depth = model.config.n_layer @@ -142,7 +142,8 @@ def mid_data_generator_bos_bestfit(split, buffer_size=100): # Conversation buffer: list of token lists conv_buffer = [] - cursor = ddp_rank # Each rank processes different conversations + cursor = ddp_rank # Each rank processes different conversations (for fetching) + consumed = ddp_rank # Track actual consumption separately from buffering epoch = 1 it = 0 # iteration counter @@ -156,8 +157,7 @@ def mid_data_generator_bos_bestfit(split, buffer_size=100): if cursor >= dataset_size: cursor = cursor % dataset_size epoch += 1 - if split == "train": - last_step = True # toggle last_step to True, which will terminate the training loop + # Note: last_step is now triggered based on consumption, not fetching while True: rows = [] @@ -183,10 +183,12 @@ def mid_data_generator_bos_bestfit(split, buffer_size=100): # Found a conversation that fits - use it entirely conv = conv_buffer.pop(best_idx) row.extend(conv) + consumed += ddp_world_size # Track actual consumption else: # No conversation fits - crop first conversation to fill remaining conv = conv_buffer.pop(0) row.extend(conv[:remaining]) + consumed += ddp_world_size # Track actual consumption rows.append(row[:row_capacity]) @@ -195,13 +197,16 @@ def mid_data_generator_bos_bestfit(split, buffer_size=100): if 0 < args.num_iterations <= it and split == "train": last_step = True - # Update progress tracking + # Update progress tracking (based on consumed, not cursor, to account for buffering) if split == "train": current_epoch = epoch if args.num_iterations > 0: approx_progress = it / args.num_iterations else: - approx_progress = cursor / dataset_size + approx_progress = consumed / dataset_size + # Trigger last_step when we've consumed enough (instead of when cursor wraps) + if consumed >= dataset_size: + last_step = True # Build tensors use_cuda = device_type == "cuda" diff --git a/scripts/tok_train.py b/scripts/tok_train.py index 4ab995c..9c7979d 100644 --- a/scripts/tok_train.py +++ b/scripts/tok_train.py @@ -14,9 +14,9 @@ from nanochat.dataset import parquets_iter_batched # Parse command line arguments parser = argparse.ArgumentParser(description='Train a BPE tokenizer') -parser.add_argument('--max_chars', type=int, default=10_000_000_000, help='Maximum characters to train on (default: 10B)') -parser.add_argument('--doc_cap', type=int, default=10_000, help='Maximum characters per document (default: 10,000)') -parser.add_argument('--vocab_size', type=int, default=32768, help='Vocabulary size (default: 32768 = 2^15)') +parser.add_argument('--max-chars', type=int, default=10_000_000_000, help='Maximum characters to train on (default: 10B)') +parser.add_argument('--doc-cap', type=int, default=10_000, help='Maximum characters per document (default: 10,000)') +parser.add_argument('--vocab-size', type=int, default=32768, help='Vocabulary size (default: 32768 = 2^15)') args = parser.parse_args() print(f"max_chars: {args.max_chars:,}") print(f"doc_cap: {args.doc_cap:,}") diff --git a/speedrun.sh b/speedrun.sh index 76ccf21..8fff564 100644 --- a/speedrun.sh +++ b/speedrun.sh @@ -59,7 +59,7 @@ python -m nanochat.dataset -n 8 python -m nanochat.dataset -n 370 & DATASET_DOWNLOAD_PID=$! # train the tokenizer with vocab size 2**16 = 65536 on ~2B characters of data -python -m scripts.tok_train --max_chars=2000000000 --vocab_size=65536 +python -m scripts.tok_train --max-chars=2000000000 --vocab-size=65536 # evaluate the tokenizer (report compression ratio etc.) python -m scripts.tok_eval @@ -81,7 +81,7 @@ wait $DATASET_DOWNLOAD_PID NPROC_PER_NODE=8 # pretrain the d20 model -torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train -- --depth=20 --target_param_data_ratio=20 --run=$WANDB_RUN +torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train -- --depth=20 --target-param-data-ratio=20 --run=$WANDB_RUN # evaluate the model on a larger chunk of train/val data and draw some samples torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_loss # evaluate the model on CORE tasks