Mirror of https://github.com/karpathy/nanochat.git, synced 2026-01-24 04:14:27 +00:00
fix buggy midtrain and update all kwargs to be idiomatic; that is, argparse uses dashes, variables use underscores. The underscores are just a remnant of the previous Configurator object. This is the right way
This commit is contained in:
parent 3b50b77ed3
commit 7312ec9898
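
The flag renames are safe because argparse normalizes dashes in long option names to underscores when deriving the attribute name, so existing code that reads e.g. `args.max_chars` keeps working once the flag becomes `--max-chars` (which is why lines like tok_train's `print(f"max_chars: {args.max_chars:,}")` appear unchanged as context below). A minimal sketch of this argparse behavior:

    import argparse

    parser = argparse.ArgumentParser()
    # dashes in the option string; the derived dest uses underscores
    parser.add_argument("--max-chars", type=int, default=10_000_000_000)
    args = parser.parse_args(["--max-chars", "1000000000"])
    print(args.max_chars)  # 1000000000 -- accessed with underscores
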
@@ -25,7 +25,7 @@ python -m nanochat.report reset
 # train tokenizer on ~1B characters
 python -m nanochat.dataset -n 4
-python -m scripts.tok_train --max_chars=1000000000
+python -m scripts.tok_train --max-chars=1000000000
 python -m scripts.tok_eval

 # train a very small 4 layer model on the CPU
@@ -33,37 +33,37 @@ python -m scripts.tok_eval
 # we only run 50 steps of optimization (bump this to get better results)
 python -m scripts.base_train \
     --depth=4 \
-    --max_seq_len=1024 \
-    --device_batch_size=1 \
-    --total_batch_size=1024 \
-    --eval_every=50 \
-    --eval_tokens=4096 \
-    --core_metric_every=50 \
-    --core_metric_max_per_task=12 \
-    --sample_every=50 \
-    --num_iterations=50
-python -m scripts.base_loss --device_batch_size=1 --split_tokens=4096
+    --max-seq-len=1024 \
+    --device-batch-size=1 \
+    --total-batch-size=1024 \
+    --eval-every=50 \
+    --eval-tokens=4096 \
+    --core-metric-every=50 \
+    --core-metric-max-per-task=12 \
+    --sample-every=50 \
+    --num-iterations=50
+python -m scripts.base_loss --device-batch-size=1 --split-tokens=4096
 python -m scripts.base_eval --max-per-task=16

 # midtraining
 python -m scripts.mid_train \
-    --max_seq_len=1024 \
-    --device_batch_size=1 \
-    --eval_every=50 \
-    --eval_tokens=4096 \
-    --total_batch_size=1024 \
-    --num_iterations=100
+    --max-seq-len=1024 \
+    --device-batch-size=1 \
+    --eval-every=50 \
+    --eval-tokens=4096 \
+    --total-batch-size=1024 \
+    --num-iterations=100
 # eval results will be terrible, this is just to execute the code paths.
 # note that we lower the execution memory limit to 1MB to avoid warnings on smaller systems
 python -m scripts.chat_eval --source=mid --max-new-tokens=128 --max-problems=20

 # SFT
 python -m scripts.chat_sft \
-    --device_batch_size=1 \
-    --target_examples_per_step=4 \
-    --num_iterations=100 \
-    --eval_steps=4 \
-    --eval_metrics_max_problems=16
+    --device-batch-size=1 \
+    --target-examples-per-step=4 \
+    --num-iterations=100 \
+    --eval-steps=4 \
+    --eval-metrics-max-problems=16

 # Chat CLI
 # python -m scripts.chat_cli -p "Why is the sky blue?"
@@ -20,7 +20,7 @@ if [ -z "$SKIP_SETUP" ]; then
     # Tokenizer, download 1000 shards for pretraining
     # (probably this can be reduced but it's tricky to determine the exact right number, TODO).
     python -m nanochat.dataset -n 1000
-    python -m scripts.tok_train --max_chars=2000000000 --vocab_size=32768
+    python -m scripts.tok_train --max-chars=2000000000 --vocab-size=32768
 else
     source .venv/bin/activate
 fi
@@ -58,16 +58,16 @@ for d in "${DEPTHS[@]}"; do
     START_TIME=$(date +%s)

     # Train the model with natural horizon (target_param_data_ratio default)
-    # No --target_flops, let it use the default ratio from base_train
+    # No --target-flops, let it use the default ratio from base_train
     torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train -- \
         --depth=$d \
-        --target_param_data_ratio=8 \
+        --target-param-data-ratio=8 \
         --run="${WANDB_RUN}_d${d}" \
-        --model_tag="${TAG}" \
-        --core_metric_every=999999 \
-        --core_metric_max_per_task=-1 \
-        --sample_every=-1 \
-        --save_every=-1 \
+        --model-tag="${TAG}" \
+        --core-metric-every=999999 \
+        --core-metric-max-per-task=-1 \
+        --sample-every=-1 \
+        --save-every=-1 \
         2>&1 | tee "$RESULTS_DIR/${TAG}_train.log"

     END_TIME=$(date +%s)
run1000.sh (+10 −10)
@@ -23,15 +23,15 @@ python -m nanochat.dataset -n 16
 # start downloading the rest of the shards for a total of 1200 (see below why 1200)
 python -m nanochat.dataset -n 1200 &
 # todo: download the rest of it
-python -m scripts.tok_train --max_chars=4000000000 --vocab_size=65536
+python -m scripts.tok_train --max-chars=4000000000 --vocab-size=65536
 python -m scripts.tok_eval

 # Documenting my process for determining the hyperparameters for this run1000.sh script:
 # We want a budget of approx. $1000 ~= 41.6 hours of 8XH100 compute
 # 1) I guessed the model size for this to be about depth=32
 # 2) Determine the device_batch_size that fits:
-# Running the base_train.py script with --depth=32, I saw that --device_batch_size=16
-# runs out of memory, but --device_batch_size=8 fits. Inspecting `nvidia-smi` during training,
+# Running the base_train.py script with --depth=32, I saw that --device-batch-size=16
+# runs out of memory, but --device-batch-size=8 fits. Inspecting `nvidia-smi` during training,
 # I saw all GPUs were at about 78/80GB VRAM, so it just barely fits and we have good MFU at ~50%.
 # So the training script was running ok and showed:
 # Vocab size: 65,536
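
(Budget arithmetic implied by the comment above: $1000 / 41.6 h ≈ $24 per 8XH100-hour, i.e. about $3 per GPU-hour.)
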
@@ -73,13 +73,13 @@ python -m scripts.tok_eval
 # Number of processes/GPUs to use
 NPROC_PER_NODE=8

-torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train -- --depth=32 --target_param_data_ratio=20 --device_batch_size=8 --run=$WANDB_RUN
+torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train -- --depth=32 --target-param-data-ratio=20 --device-batch-size=8 --run=$WANDB_RUN
 torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_loss
 torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_eval

 # midtrain
 # NOTE: ensure that we use the same device_batch_size here as the base training script.
-torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.mid_train -- --device_batch_size=8 --run=$WANDB_RUN
+torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.mid_train -- --device-batch-size=8 --run=$WANDB_RUN
 torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_eval -- -i mid

 # sft
@@ -64,15 +64,15 @@ for flops in "${FLOPS_BUDGETS[@]}"; do
     # CORE eval happens once at the end (999999 ensures only final step)
     torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train -- \
         --depth=$d \
-        --target_flops=$flops \
-        --target_param_data_ratio=-1 \
+        --target-flops=$flops \
+        --target-param-data-ratio=-1 \
         --run="${WANDB_RUN}_${TAG}" \
-        --model_tag="${TAG}" \
-        --eval_tokens=$EVAL_TOKENS \
-        --core_metric_every=999999 \
-        --core_metric_max_per_task=-1 \
-        --sample_every=-1 \
-        --save_every=-1 \
+        --model-tag="${TAG}" \
+        --eval-tokens=$EVAL_TOKENS \
+        --core-metric-every=999999 \
+        --core-metric-max-per-task=-1 \
+        --sample-every=-1 \
+        --save-every=-1 \
         2>&1 | tee "$RESULTS_DIR/${TAG}_train.log"

     END_TIME=$(date +%s)
@@ -7,7 +7,7 @@ Example run as:
 torchrun --standalone --nproc_per_node=8 -m scripts.base_loss

 To evaluate a HuggingFace model:
-python -m scripts.base_loss --hf_path openai-community/gpt2
+python -m scripts.base_loss --hf-path openai-community/gpt2
 """
 import argparse
 from contextlib import nullcontext
@@ -61,12 +61,12 @@ def get_hf_token_bytes(tokenizer, device="cpu"):

 # CLI arguments
 parser = argparse.ArgumentParser(description="Evaluate loss on train/val splits and sample from model")
-parser.add_argument("--device_batch_size", type=int, default=32, help="per-device batch size")
-parser.add_argument("--split_tokens", type=int, default=40*524288, help="number of tokens to evaluate per split")
-parser.add_argument("--model_tag", type=str, default=None, help="model tag for checkpoint directory")
-parser.add_argument("--model_step", type=int, default=None, help="model step to load")
-parser.add_argument("--device_type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
-parser.add_argument("--hf_path", type=str, default=None, help="HuggingFace model path (e.g. openai-community/gpt2)")
+parser.add_argument("--device-batch-size", type=int, default=32, help="per-device batch size")
+parser.add_argument("--split-tokens", type=int, default=40*524288, help="number of tokens to evaluate per split")
+parser.add_argument("--model-tag", type=str, default=None, help="model tag for checkpoint directory")
+parser.add_argument("--model-step", type=int, default=None, help="model step to load")
+parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
+parser.add_argument("--hf-path", type=str, default=None, help="HuggingFace model path (e.g. openai-community/gpt2)")
 args = parser.parse_args()

 # Load the base model and the tokenizer
@@ -8,7 +8,7 @@ or distributed as:
 torchrun --nproc_per_node=8 -m scripts.base_train.py

 If you are only on CPU/Macbook, you'll want to train a much much smaller LLM. Example:
-python -m scripts.base_train --depth=4 --max_seq_len=512 --device_batch_size=1 --eval_tokens=512 --core_metric_every=-1 --total_batch_size=512 --num_iterations=20
+python -m scripts.base_train --depth=4 --max-seq-len=512 --device-batch-size=1 --eval-tokens=512 --core-metric-every=-1 --total-batch-size=512 --num-iterations=20
 """

 import os
@@ -36,40 +36,40 @@ parser = argparse.ArgumentParser(description="Pretrain base model")
 # Logging
 parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)")
 # Runtime
-parser.add_argument("--device_type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
+parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
 # Model architecture
 parser.add_argument("--depth", type=int, default=20, help="depth of the Transformer model")
-parser.add_argument("--aspect_ratio", type=int, default=64, help="model_dim = depth * aspect_ratio")
-parser.add_argument("--head_dim", type=int, default=128, help="target head dimension for attention")
-parser.add_argument("--max_seq_len", type=int, default=2048, help="max context length")
-parser.add_argument("--window_pattern", type=str, default="SSSL", help="sliding window pattern tiled across layers: L=full, S=half context (e.g. 'SSL')")
+parser.add_argument("--aspect-ratio", type=int, default=64, help="model_dim = depth * aspect_ratio")
+parser.add_argument("--head-dim", type=int, default=128, help="target head dimension for attention")
+parser.add_argument("--max-seq-len", type=int, default=2048, help="max context length")
+parser.add_argument("--window-pattern", type=str, default="SSSL", help="sliding window pattern tiled across layers: L=full, S=half context (e.g. 'SSL')")
 # Training horizon (only one used, in order of precedence)
-parser.add_argument("--num_iterations", type=int, default=-1, help="explicit number of optimization steps (-1 = disable)")
-parser.add_argument("--target_flops", type=float, default=-1.0, help="calculate num_iterations to reach target_flops (-1 = disable)")
-parser.add_argument("--target_param_data_ratio", type=int, default=8, help="calculate num_iterations to maintain data:param ratio (Chinchilla=20, -1 = disable)")
+parser.add_argument("--num-iterations", type=int, default=-1, help="explicit number of optimization steps (-1 = disable)")
+parser.add_argument("--target-flops", type=float, default=-1.0, help="calculate num_iterations to reach target_flops (-1 = disable)")
+parser.add_argument("--target-param-data-ratio", type=int, default=8, help="calculate num_iterations to maintain data:param ratio (Chinchilla=20, -1 = disable)")
 # Optimization
-parser.add_argument("--device_batch_size", type=int, default=32, help="per-device batch size")
-parser.add_argument("--total_batch_size", type=int, default=524288, help="total batch size in tokens")
-parser.add_argument("--embedding_lr", type=float, default=0.3, help="learning rate for embedding parameters (Adam)")
-parser.add_argument("--unembedding_lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)")
-parser.add_argument("--weight_decay", type=float, default=0.2, help="cautious weight decay for the Muon optimizer (for weights)")
-parser.add_argument("--matrix_lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)")
-parser.add_argument("--scalar_lr", type=float, default=0.5, help="learning rate for scalars (resid_lambdas, x0_lambdas)")
-parser.add_argument("--adam_beta1", type=float, default=0.8, help="Adam beta1 for embedding/unembedding")
-parser.add_argument("--adam_beta2", type=float, default=0.95, help="Adam beta2 for embedding/unembedding")
-parser.add_argument("--warmup_ratio", type=float, default=0.0, help="ratio of iterations for LR warmup")
-parser.add_argument("--warmdown_ratio", type=float, default=0.4, help="ratio of iterations for LR warmdown")
-parser.add_argument("--final_lr_frac", type=float, default=0.0, help="final LR as fraction of initial LR")
-parser.add_argument("--resume_from_step", type=int, default=-1, help="resume training from this step (-1 = disable)")
+parser.add_argument("--device-batch-size", type=int, default=32, help="per-device batch size")
+parser.add_argument("--total-batch-size", type=int, default=524288, help="total batch size in tokens")
+parser.add_argument("--embedding-lr", type=float, default=0.3, help="learning rate for embedding parameters (Adam)")
+parser.add_argument("--unembedding-lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)")
+parser.add_argument("--weight-decay", type=float, default=0.2, help="cautious weight decay for the Muon optimizer (for weights)")
+parser.add_argument("--matrix-lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)")
+parser.add_argument("--scalar-lr", type=float, default=0.5, help="learning rate for scalars (resid_lambdas, x0_lambdas)")
+parser.add_argument("--adam-beta1", type=float, default=0.8, help="Adam beta1 for embedding/unembedding")
+parser.add_argument("--adam-beta2", type=float, default=0.95, help="Adam beta2 for embedding/unembedding")
+parser.add_argument("--warmup-ratio", type=float, default=0.0, help="ratio of iterations for LR warmup")
+parser.add_argument("--warmdown-ratio", type=float, default=0.4, help="ratio of iterations for LR warmdown")
+parser.add_argument("--final-lr-frac", type=float, default=0.0, help="final LR as fraction of initial LR")
+parser.add_argument("--resume-from-step", type=int, default=-1, help="resume training from this step (-1 = disable)")
 # Evaluation
-parser.add_argument("--eval_every", type=int, default=250, help="evaluate val bpb every N steps (-1 = disable)")
-parser.add_argument("--eval_tokens", type=int, default=20*524288, help="number of tokens to evaluate val loss on")
-parser.add_argument("--core_metric_every", type=int, default=2000, help="evaluate CORE metric every N steps (-1 = disable)")
-parser.add_argument("--core_metric_max_per_task", type=int, default=500, help="examples per task for CORE metric")
-parser.add_argument("--sample_every", type=int, default=2000, help="sample from model every N steps (-1 = disable)")
-parser.add_argument("--save_every", type=int, default=-1, help="save checkpoints every N steps (-1 = only at end)")
+parser.add_argument("--eval-every", type=int, default=250, help="evaluate val bpb every N steps (-1 = disable)")
+parser.add_argument("--eval-tokens", type=int, default=20*524288, help="number of tokens to evaluate val loss on")
+parser.add_argument("--core-metric-every", type=int, default=2000, help="evaluate CORE metric every N steps (-1 = disable)")
+parser.add_argument("--core-metric-max-per-task", type=int, default=500, help="examples per task for CORE metric")
+parser.add_argument("--sample-every", type=int, default=2000, help="sample from model every N steps (-1 = disable)")
+parser.add_argument("--save-every", type=int, default=-1, help="save checkpoints every N steps (-1 = only at end)")
 # Output
-parser.add_argument("--model_tag", type=str, default=None, help="override model tag for checkpoint directory name")
+parser.add_argument("--model-tag", type=str, default=None, help="override model tag for checkpoint directory name")
 args = parser.parse_args()
 user_config = vars(args).copy() # for logging
 # -----------------------------------------------------------------------------
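
For context on the three horizon flags above: only one is used, in order of precedence. A minimal sketch of how that precedence could resolve into a step count (an illustration based on the help strings, not nanochat's exact code; `num_params` and `flops_per_token` are assumed inputs):

    # horizon precedence: num_iterations > target_flops > target_param_data_ratio
    def resolve_num_iterations(args, num_params: int, flops_per_token: float) -> int:
        if args.num_iterations > 0:
            return args.num_iterations  # explicit step count wins
        if args.target_flops > 0:
            target_tokens = args.target_flops / flops_per_token
        else:
            # e.g. ratio=8 trains on 8 tokens per parameter (Chinchilla would be 20)
            target_tokens = args.target_param_data_ratio * num_params
        # total_batch_size is in tokens, so steps = target tokens / tokens per step
        return round(target_tokens / args.total_batch_size)
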
@@ -35,32 +35,32 @@ parser = argparse.ArgumentParser(description="Reinforcement learning on GSM8K")
 # Logging
 parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)")
 # Runtime
-parser.add_argument("--device_type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
+parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
 parser.add_argument("--dtype", type=str, default="bfloat16", help="float32|bfloat16")
 # Model loading
 parser.add_argument("--source", type=str, default="sft", help="mid|sft - which checkpoint to load from")
-parser.add_argument("--model_tag", type=str, default=None, help="model tag to load from")
-parser.add_argument("--model_step", type=int, default=None, help="model step to load from")
+parser.add_argument("--model-tag", type=str, default=None, help="model tag to load from")
+parser.add_argument("--model-step", type=int, default=None, help="model step to load from")
 # Training horizon
-parser.add_argument("--num_epochs", type=int, default=1, help="number of epochs over GSM8K")
+parser.add_argument("--num-epochs", type=int, default=1, help="number of epochs over GSM8K")
 # Batch sizes / sampling
-parser.add_argument("--device_batch_size", type=int, default=8, help="max batch size per forward pass")
-parser.add_argument("--examples_per_step", type=int, default=16, help="total examples per optimization step across all ranks")
-parser.add_argument("--num_samples", type=int, default=16, help="number of samples per example/question")
+parser.add_argument("--device-batch-size", type=int, default=8, help="max batch size per forward pass")
+parser.add_argument("--examples-per-step", type=int, default=16, help="total examples per optimization step across all ranks")
+parser.add_argument("--num-samples", type=int, default=16, help="number of samples per example/question")
 # Generation
-parser.add_argument("--max_new_tokens", type=int, default=256, help="max tokens to generate per sample")
+parser.add_argument("--max-new-tokens", type=int, default=256, help="max tokens to generate per sample")
 parser.add_argument("--temperature", type=float, default=1.0, help="sampling temperature")
-parser.add_argument("--top_k", type=int, default=50, help="top-k sampling (0 = disabled)")
+parser.add_argument("--top-k", type=int, default=50, help="top-k sampling (0 = disabled)")
 # Optimization
-parser.add_argument("--embedding_lr", type=float, default=0.2, help="learning rate for embedding parameters (Adam)")
-parser.add_argument("--unembedding_lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)")
-parser.add_argument("--matrix_lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)")
-parser.add_argument("--weight_decay", type=float, default=0.0, help="weight decay for embedding/unembedding parameters (Adam)")
-parser.add_argument("--init_lr_frac", type=float, default=0.05, help="initial LR as fraction of base LR")
+parser.add_argument("--embedding-lr", type=float, default=0.2, help="learning rate for embedding parameters (Adam)")
+parser.add_argument("--unembedding-lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)")
+parser.add_argument("--matrix-lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)")
+parser.add_argument("--weight-decay", type=float, default=0.0, help="weight decay for embedding/unembedding parameters (Adam)")
+parser.add_argument("--init-lr-frac", type=float, default=0.05, help="initial LR as fraction of base LR")
 # Evaluation / checkpointing
-parser.add_argument("--eval_every", type=int, default=60, help="evaluate pass@k every N steps")
-parser.add_argument("--eval_examples", type=int, default=400, help="number of examples for pass@k evaluation")
-parser.add_argument("--save_every", type=int, default=60, help="save checkpoint every N steps")
+parser.add_argument("--eval-every", type=int, default=60, help="evaluate pass@k every N steps")
+parser.add_argument("--eval-examples", type=int, default=400, help="number of examples for pass@k evaluation")
+parser.add_argument("--save-every", type=int, default=60, help="save checkpoint every N steps")
 args = parser.parse_args()
 user_config = vars(args).copy()
 # -----------------------------------------------------------------------------
@@ -37,29 +37,29 @@ parser = argparse.ArgumentParser(description="Supervised finetuning for chat")
 # Logging
 parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)")
 # Runtime
-parser.add_argument("--device_type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
+parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
 parser.add_argument("--dtype", type=str, default="bfloat16", help="float32|bfloat16")
 # Model loading
 parser.add_argument("--source", type=str, default="mid", help="base|mid - which checkpoint to load from")
-parser.add_argument("--model_tag", type=str, default=None, help="model tag to load from")
-parser.add_argument("--model_step", type=int, default=None, help="model step to load from")
+parser.add_argument("--model-tag", type=str, default=None, help="model tag to load from")
+parser.add_argument("--model-step", type=int, default=None, help="model step to load from")
 # Training horizon
-parser.add_argument("--num_epochs", type=int, default=1, help="number of epochs")
-parser.add_argument("--num_iterations", type=int, default=-1, help="override number of iterations (-1 = use num_epochs)")
+parser.add_argument("--num-epochs", type=int, default=1, help="number of epochs")
+parser.add_argument("--num-iterations", type=int, default=-1, help="override number of iterations (-1 = use num_epochs)")
 # Batch sizes
-parser.add_argument("--device_batch_size", type=int, default=4, help="per-device batch size")
-parser.add_argument("--target_examples_per_step", type=int, default=32, help="target examples per optimization step")
+parser.add_argument("--device-batch-size", type=int, default=4, help="per-device batch size")
+parser.add_argument("--target-examples-per-step", type=int, default=32, help="target examples per optimization step")
 # Optimization
-parser.add_argument("--embedding_lr", type=float, default=0.2, help="learning rate for embedding parameters (Adam)")
-parser.add_argument("--unembedding_lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)")
-parser.add_argument("--matrix_lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)")
-parser.add_argument("--weight_decay", type=float, default=0.0, help="weight decay for embedding/unembedding parameters (Adam)")
-parser.add_argument("--init_lr_frac", type=float, default=0.02, help="initial LR as fraction of base LR")
+parser.add_argument("--embedding-lr", type=float, default=0.2, help="learning rate for embedding parameters (Adam)")
+parser.add_argument("--unembedding-lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)")
+parser.add_argument("--matrix-lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)")
+parser.add_argument("--weight-decay", type=float, default=0.0, help="weight decay for embedding/unembedding parameters (Adam)")
+parser.add_argument("--init-lr-frac", type=float, default=0.02, help="initial LR as fraction of base LR")
 # Evaluation
-parser.add_argument("--eval_every", type=int, default=100, help="evaluate val loss every N steps")
-parser.add_argument("--eval_steps", type=int, default=100, help="number of batches for val loss evaluation")
-parser.add_argument("--eval_metrics_every", type=int, default=200, help="evaluate accuracy metrics every N steps")
-parser.add_argument("--eval_metrics_max_problems", type=int, default=1024, help="max problems per metric evaluation")
+parser.add_argument("--eval-every", type=int, default=100, help="evaluate val loss every N steps")
+parser.add_argument("--eval-steps", type=int, default=100, help="number of batches for val loss evaluation")
+parser.add_argument("--eval-metrics-every", type=int, default=200, help="evaluate accuracy metrics every N steps")
+parser.add_argument("--eval-metrics-max-problems", type=int, default=1024, help="max problems per metric evaluation")
 args = parser.parse_args()
 user_config = vars(args).copy()
 # -----------------------------------------------------------------------------
@@ -6,7 +6,7 @@ python -m scripts.mid_train

 Or torchrun for training:

-torchrun --standalone --nproc_per_node=8 -m scripts.mid_train -- --device_batch_size=16
+torchrun --standalone --nproc_per_node=8 -m scripts.mid_train -- --device-batch-size=16
 """

 import argparse
@@ -36,28 +36,28 @@ parser = argparse.ArgumentParser(description="Midtrain the model")
 # Logging
 parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)")
 # Runtime
-parser.add_argument("--device_type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
+parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
 parser.add_argument("--dtype", type=str, default="bfloat16", help="float32|bfloat16")
 # Model loading
-parser.add_argument("--model_tag", type=str, default=None, help="model tag to load from")
-parser.add_argument("--model_step", type=int, default=None, help="model step to load from")
+parser.add_argument("--model-tag", type=str, default=None, help="model tag to load from")
+parser.add_argument("--model-step", type=int, default=None, help="model step to load from")
 # Training horizon
-parser.add_argument("--num_iterations", type=int, default=-1, help="number of optimization steps (-1 = full epoch)")
+parser.add_argument("--num-iterations", type=int, default=-1, help="number of optimization steps (-1 = full epoch)")
 # Batch sizes
-parser.add_argument("--max_seq_len", type=int, default=2048, help="max context length")
-parser.add_argument("--device_batch_size", type=int, default=32, help="per-device batch size")
-parser.add_argument("--total_batch_size", type=int, default=524288, help="total batch size in tokens")
+parser.add_argument("--max-seq-len", type=int, default=2048, help="max context length")
+parser.add_argument("--device-batch-size", type=int, default=32, help="per-device batch size")
+parser.add_argument("--total-batch-size", type=int, default=524288, help="total batch size in tokens")
 # Optimization
-parser.add_argument("--embedding_lr", type=float, default=0.2, help="learning rate for embedding parameters (Adam)")
-parser.add_argument("--unembedding_lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)")
-parser.add_argument("--matrix_lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)")
-parser.add_argument("--weight_decay", type=float, default=0.0, help="weight decay for embedding/unembedding parameters (Adam)")
-parser.add_argument("--init_lr_frac", type=float, default=1.0, help="initial LR as fraction of base LR")
+parser.add_argument("--embedding-lr", type=float, default=0.2, help="learning rate for embedding parameters (Adam)")
+parser.add_argument("--unembedding-lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)")
+parser.add_argument("--matrix-lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)")
+parser.add_argument("--weight-decay", type=float, default=0.0, help="weight decay for embedding/unembedding parameters (Adam)")
+parser.add_argument("--init-lr-frac", type=float, default=1.0, help="initial LR as fraction of base LR")
 # Evaluation
-parser.add_argument("--eval_every", type=int, default=150, help="evaluate val bpb every N steps (-1 = disable)")
-parser.add_argument("--eval_tokens", type=int, default=20*524288, help="number of tokens to evaluate val loss on")
+parser.add_argument("--eval-every", type=int, default=150, help="evaluate val bpb every N steps (-1 = disable)")
+parser.add_argument("--eval-tokens", type=int, default=20*524288, help="number of tokens to evaluate val loss on")
 # Output
-parser.add_argument("--dry_run", action="store_true", help="log to wandb but skip checkpoints/report")
+parser.add_argument("--dry-run", action="store_true", help="log to wandb but skip checkpoints/report")
 args = parser.parse_args()
 user_config = vars(args).copy()
 # -----------------------------------------------------------------------------
@@ -79,7 +79,7 @@ wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-mi
 model, tokenizer, meta = load_model("base", device, phase="train", model_tag=args.model_tag, step=args.model_step)
 pretrain_batch_size = meta.get("device_batch_size", None)
 if pretrain_batch_size is not None and args.device_batch_size > pretrain_batch_size:
-    print0(f"FOOTGUN WARNING: base model training used device_batch_size {pretrain_batch_size}, did you pass in a good --device_batch_size to this script?")
+    print0(f"FOOTGUN WARNING: base model training used device_batch_size {pretrain_batch_size}, did you pass in a good --device-batch-size to this script?")
 orig_model = model
 model = torch.compile(model, dynamic=False)
 depth = model.config.n_layer
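
(The check above matters because mid_train's own default is --device-batch-size=32; if the base model was pretrained with a smaller per-device batch, silently keeping a larger default here can exceed the memory headroom that pretraining already established, hence the footgun warning.)
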
@@ -142,7 +142,8 @@ def mid_data_generator_bos_bestfit(split, buffer_size=100):

     # Conversation buffer: list of token lists
     conv_buffer = []
-    cursor = ddp_rank # Each rank processes different conversations
+    cursor = ddp_rank # Each rank processes different conversations (for fetching)
+    consumed = ddp_rank # Track actual consumption separately from buffering
     epoch = 1
     it = 0 # iteration counter

@@ -156,8 +157,7 @@ def mid_data_generator_bos_bestfit(split, buffer_size=100):
         if cursor >= dataset_size:
             cursor = cursor % dataset_size
             epoch += 1
-            if split == "train":
-                last_step = True # toggle last_step to True, which will terminate the training loop
+            # Note: last_step is now triggered based on consumption, not fetching

         while True:
             rows = []
@@ -183,10 +183,12 @@ def mid_data_generator_bos_bestfit(split, buffer_size=100):
                 # Found a conversation that fits - use it entirely
                 conv = conv_buffer.pop(best_idx)
                 row.extend(conv)
+                consumed += ddp_world_size # Track actual consumption
             else:
                 # No conversation fits - crop first conversation to fill remaining
                 conv = conv_buffer.pop(0)
                 row.extend(conv[:remaining])
+                consumed += ddp_world_size # Track actual consumption

             rows.append(row[:row_capacity])

@@ -195,13 +197,16 @@ def mid_data_generator_bos_bestfit(split, buffer_size=100):
         if 0 < args.num_iterations <= it and split == "train":
             last_step = True

-        # Update progress tracking
+        # Update progress tracking (based on consumed, not cursor, to account for buffering)
         if split == "train":
             current_epoch = epoch
             if args.num_iterations > 0:
                 approx_progress = it / args.num_iterations
             else:
-                approx_progress = cursor / dataset_size
+                approx_progress = consumed / dataset_size
+            # Trigger last_step when we've consumed enough (instead of when cursor wraps)
+            if consumed >= dataset_size:
+                last_step = True

         # Build tensors
         use_cuda = device_type == "cuda"
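
The actual midtrain bug fix is in the hunks above: with best-fit packing, conversations are fetched into `conv_buffer` ahead of use, so `cursor` (the fetch position) runs ahead of what has actually been packed into batches. Keying the epoch wrap-around and `approx_progress` off `cursor` could therefore end training while up to `buffer_size` conversations were still sitting unconsumed in the buffer; the new `consumed` counter tracks what was really used. A toy illustration of the gap (hypothetical numbers, not nanochat's code):

    # fetch-side vs consumption-side accounting with a lookahead buffer
    dataset_size = 1000
    buffer_size = 100
    conv_buffer = []
    cursor = 0    # advances on fetch (fills the buffer)
    consumed = 0  # advances only when a conversation is packed into a batch
    for _ in range(500):
        cursor += 1                        # fetch one conversation into the buffer
        conv_buffer.append("conv")
        if len(conv_buffer) > buffer_size:
            conv_buffer.pop(0)             # pack the oldest buffered conversation
            consumed += 1
    print(cursor / dataset_size)    # 0.5: progress as the old code measured it
    print(consumed / dataset_size)  # 0.4: fraction of data actually trained on
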
@@ -14,9 +14,9 @@ from nanochat.dataset import parquets_iter_batched
 # Parse command line arguments

 parser = argparse.ArgumentParser(description='Train a BPE tokenizer')
-parser.add_argument('--max_chars', type=int, default=10_000_000_000, help='Maximum characters to train on (default: 10B)')
-parser.add_argument('--doc_cap', type=int, default=10_000, help='Maximum characters per document (default: 10,000)')
-parser.add_argument('--vocab_size', type=int, default=32768, help='Vocabulary size (default: 32768 = 2^15)')
+parser.add_argument('--max-chars', type=int, default=10_000_000_000, help='Maximum characters to train on (default: 10B)')
+parser.add_argument('--doc-cap', type=int, default=10_000, help='Maximum characters per document (default: 10,000)')
+parser.add_argument('--vocab-size', type=int, default=32768, help='Vocabulary size (default: 32768 = 2^15)')
 args = parser.parse_args()
 print(f"max_chars: {args.max_chars:,}")
 print(f"doc_cap: {args.doc_cap:,}")
@@ -59,7 +59,7 @@ python -m nanochat.dataset -n 8
 python -m nanochat.dataset -n 370 &
 DATASET_DOWNLOAD_PID=$!
 # train the tokenizer with vocab size 2**16 = 65536 on ~2B characters of data
-python -m scripts.tok_train --max_chars=2000000000 --vocab_size=65536
+python -m scripts.tok_train --max-chars=2000000000 --vocab-size=65536
 # evaluate the tokenizer (report compression ratio etc.)
 python -m scripts.tok_eval

@@ -81,7 +81,7 @@ wait $DATASET_DOWNLOAD_PID
 NPROC_PER_NODE=8

 # pretrain the d20 model
-torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train -- --depth=20 --target_param_data_ratio=20 --run=$WANDB_RUN
+torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train -- --depth=20 --target-param-data-ratio=20 --run=$WANDB_RUN
 # evaluate the model on a larger chunk of train/val data and draw some samples
 torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_loss
 # evaluate the model on CORE tasks