mirror of
https://github.com/karpathy/nanochat.git
synced 2026-02-21 02:50:25 +00:00
Make speedrun.sh configurable for different GPU setups
Added --nproc-per-node and --device-batch-size arguments so the script can run on smaller hardware. Default is still 8 GPUs for the original speedrun, but now you can do --nproc-per-node=1 --device-batch-size=2 for a single 16GB GPU.
This commit is contained in:
parent
4346536ab2
commit
ed63b87682
62
speedrun.sh
62
speedrun.sh
|
|
@ -9,6 +9,48 @@
|
|||
# screen -L -Logfile speedrun.log -S speedrun bash speedrun.sh
|
||||
# 3) Example launch with wandb logging, but see below for setting up wandb first:
|
||||
# WANDB_RUN=speedrun screen -L -Logfile speedrun.log -S speedrun bash speedrun.sh
|
||||
# 4) Example launch for single GPU with 16GB VRAM:
|
||||
# bash speedrun.sh --nproc-per-node=1 --device-batch-size=2
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Parse command-line arguments
|
||||
|
||||
NPROC_PER_NODE=8
|
||||
DEVICE_BATCH_SIZE="" # empty means use default from training scripts
|
||||
|
||||
usage() {
|
||||
echo "Usage: $0 [OPTIONS]"
|
||||
echo "Options:"
|
||||
echo " --nproc-per-node=N Number of GPUs to use (default: 8)"
|
||||
echo " --device-batch-size=N Device batch size for training (default: script defaults)"
|
||||
echo " -h, --help Show this help message"
|
||||
exit 1
|
||||
}
|
||||
|
||||
for arg in "$@"; do
|
||||
case $arg in
|
||||
--nproc-per-node=*)
|
||||
NPROC_PER_NODE="${arg#*=}"
|
||||
shift
|
||||
;;
|
||||
--device-batch-size=*)
|
||||
DEVICE_BATCH_SIZE="${arg#*=}"
|
||||
shift
|
||||
;;
|
||||
-h|--help)
|
||||
usage
|
||||
;;
|
||||
*)
|
||||
echo "Unknown option: $arg"
|
||||
usage
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
echo "Configuration:"
|
||||
echo " NPROC_PER_NODE: $NPROC_PER_NODE"
|
||||
echo " DEVICE_BATCH_SIZE: ${DEVICE_BATCH_SIZE:-default}"
|
||||
echo ""
|
||||
|
||||
# Default intermediate artifacts directory is in ~/.cache/nanochat
|
||||
export OMP_NUM_THREADS=1
|
||||
|
|
@ -92,25 +134,27 @@ echo "Waiting for dataset download to complete..."
|
|||
wait $DATASET_DOWNLOAD_PID
|
||||
|
||||
# pretrain the d20 model
|
||||
torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- --depth=20 --run=$WANDB_RUN
|
||||
BATCH_SIZE_ARG=""
|
||||
[ -n "$DEVICE_BATCH_SIZE" ] && BATCH_SIZE_ARG="--device_batch_size=$DEVICE_BATCH_SIZE"
|
||||
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train -- --depth=20 $BATCH_SIZE_ARG --run=$WANDB_RUN
|
||||
# evaluate the model on a larger chunk of train/val data and draw some samples
|
||||
torchrun --standalone --nproc_per_node=8 -m scripts.base_loss
|
||||
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_loss
|
||||
# evaluate the model on CORE tasks
|
||||
torchrun --standalone --nproc_per_node=8 -m scripts.base_eval
|
||||
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_eval
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Midtraining (teach the model conversation special tokens, tool use, multiple choice)
|
||||
|
||||
# run midtraining and eval the model
|
||||
torchrun --standalone --nproc_per_node=8 -m scripts.mid_train -- --run=$WANDB_RUN
|
||||
torchrun --standalone --nproc_per_node=8 -m scripts.chat_eval -- -i mid
|
||||
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.mid_train -- $BATCH_SIZE_ARG --run=$WANDB_RUN
|
||||
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_eval -- -i mid
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Supervised Finetuning (domain adaptation to each sequence all by itself per row)
|
||||
|
||||
# train sft and re-eval right away (should see a small bump)
|
||||
torchrun --standalone --nproc_per_node=8 -m scripts.chat_sft -- --run=$WANDB_RUN
|
||||
torchrun --standalone --nproc_per_node=8 -m scripts.chat_eval -- -i sft
|
||||
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_sft -- $BATCH_SIZE_ARG --run=$WANDB_RUN
|
||||
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_eval -- -i sft
|
||||
|
||||
# chat with the model over CLI! Leave out the -p to chat interactively
|
||||
# python -m scripts.chat_cli -p "Why is the sky blue?"
|
||||
|
|
@ -123,9 +167,9 @@ torchrun --standalone --nproc_per_node=8 -m scripts.chat_eval -- -i sft
|
|||
# (optional)
|
||||
|
||||
# run reinforcement learning
|
||||
# torchrun --standalone --nproc_per_node=8 -m scripts.chat_rl -- --run=$WANDB_RUN
|
||||
# torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_rl -- $BATCH_SIZE_ARG --run=$WANDB_RUN
|
||||
# eval the RL model only on GSM8K
|
||||
# torchrun --standalone --nproc_per_node=8 -m scripts.chat_eval -- -i rl -a GSM8K
|
||||
# torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_eval -- -i rl -a GSM8K
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Generate the full report by putting together all the sections
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user