mirror of
https://github.com/karpathy/nanochat.git
synced 2025-12-07 21:02:15 +00:00
Update speedrun.sh
This commit is contained in:
parent
b6da6982f6
commit
83ce1af08e
36
speedrun.sh
36
speedrun.sh
|
|
@ -15,6 +15,9 @@ export OMP_NUM_THREADS=1
|
||||||
export NANOCHAT_BASE_DIR="$HOME/.cache/nanochat"
|
export NANOCHAT_BASE_DIR="$HOME/.cache/nanochat"
|
||||||
mkdir -p $NANOCHAT_BASE_DIR
|
mkdir -p $NANOCHAT_BASE_DIR
|
||||||
|
|
||||||
|
# Number of processes per node for distributed training
|
||||||
|
NPROC_PER_NODE=4
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
# Python venv setup with uv
|
# Python venv setup with uv
|
||||||
|
|
||||||
|
|
@ -23,7 +26,7 @@ command -v uv &> /dev/null || curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||||
# create a .venv local virtual environment (if it doesn't exist)
|
# create a .venv local virtual environment (if it doesn't exist)
|
||||||
[ -d ".venv" ] || uv venv
|
[ -d ".venv" ] || uv venv
|
||||||
# install the repo dependencies
|
# install the repo dependencies
|
||||||
uv sync --extra gpu
|
uv sync
|
||||||
# activate venv so that `python` uses the project's venv instead of system python
|
# activate venv so that `python` uses the project's venv instead of system python
|
||||||
source .venv/bin/activate
|
source .venv/bin/activate
|
||||||
|
|
||||||
|
|
@ -73,6 +76,15 @@ python -m scripts.tok_eval
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
# Base model (pretraining)
|
# Base model (pretraining)
|
||||||
|
|
||||||
|
# Download the eval_bundle from s3 to evaluate CORE metric during training (~162MB)
|
||||||
|
EVAL_BUNDLE_URL=https://karpathy-public.s3.us-west-2.amazonaws.com/eval_bundle.zip
|
||||||
|
if [ ! -d "$NANOCHAT_BASE_DIR/eval_bundle" ]; then
|
||||||
|
curl -L -o eval_bundle.zip $EVAL_BUNDLE_URL
|
||||||
|
unzip -q eval_bundle.zip
|
||||||
|
rm eval_bundle.zip
|
||||||
|
mv eval_bundle $NANOCHAT_BASE_DIR
|
||||||
|
fi
|
||||||
|
|
||||||
# The d20 model is 561M parameters.
|
# The d20 model is 561M parameters.
|
||||||
# Chinchilla says #tokens = 20X #params, so we need 561e6 * 20 = 11.2B tokens.
|
# Chinchilla says #tokens = 20X #params, so we need 561e6 * 20 = 11.2B tokens.
|
||||||
# Assume our tokenizer is 4.8 chars/token, this is 11.2B * 4.8 ~= 54B chars.
|
# Assume our tokenizer is 4.8 chars/token, this is 11.2B * 4.8 ~= 54B chars.
|
||||||
|
|
@ -83,29 +95,25 @@ echo "Waiting for dataset download to complete..."
|
||||||
wait $DATASET_DOWNLOAD_PID
|
wait $DATASET_DOWNLOAD_PID
|
||||||
|
|
||||||
# pretrain the d20 model
|
# pretrain the d20 model
|
||||||
torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- --depth=20 --run=$WANDB_RUN
|
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train -- --depth=20 --run=$WANDB_RUN
|
||||||
# evaluate the model on a larger chunk of train/val data and draw some samples
|
# evaluate the model on a larger chunk of train/val data and draw some samples
|
||||||
torchrun --standalone --nproc_per_node=8 -m scripts.base_loss
|
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_loss
|
||||||
# evaluate the model on CORE tasks
|
# evaluate the model on CORE tasks
|
||||||
torchrun --standalone --nproc_per_node=8 -m scripts.base_eval
|
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_eval
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
# Midtraining (teach the model conversation special tokens, tool use, multiple choice)
|
# Midtraining (teach the model conversation special tokens, tool use, multiple choice)
|
||||||
|
|
||||||
# download 2.3MB of synthetic identity conversations to impart a personality to nanochat
|
|
||||||
# see dev/gen_sft_data.py for details on how this data was prepared and to get a sense of how you can easily tune it
|
|
||||||
curl -L -o $NANOCHAT_BASE_DIR/identity_conversations.jsonl https://karpathy-public.s3.us-west-2.amazonaws.com/identity_conversations.jsonl
|
|
||||||
|
|
||||||
# run midtraining and eval the model
|
# run midtraining and eval the model
|
||||||
torchrun --standalone --nproc_per_node=8 -m scripts.mid_train -- --run=$WANDB_RUN
|
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.mid_train -- --run=$WANDB_RUN
|
||||||
torchrun --standalone --nproc_per_node=8 -m scripts.chat_eval -- -i mid
|
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_eval -- -i mid
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
# Supervised Finetuning (domain adaptation to each sequence all by itself per row)
|
# Supervised Finetuning (domain adaptation to each sequence all by itself per row)
|
||||||
|
|
||||||
# train sft and re-eval right away (should see a small bump)
|
# train sft and re-eval right away (should see a small bump)
|
||||||
torchrun --standalone --nproc_per_node=8 -m scripts.chat_sft -- --run=$WANDB_RUN
|
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_sft -- --run=$WANDB_RUN
|
||||||
torchrun --standalone --nproc_per_node=8 -m scripts.chat_eval -- -i sft
|
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_eval -- -i sft
|
||||||
|
|
||||||
# chat with the model over CLI! Leave out the -p to chat interactively
|
# chat with the model over CLI! Leave out the -p to chat interactively
|
||||||
# python -m scripts.chat_cli -p "Why is the sky blue?"
|
# python -m scripts.chat_cli -p "Why is the sky blue?"
|
||||||
|
|
@ -118,9 +126,9 @@ torchrun --standalone --nproc_per_node=8 -m scripts.chat_eval -- -i sft
|
||||||
# (optional)
|
# (optional)
|
||||||
|
|
||||||
# run reinforcement learning
|
# run reinforcement learning
|
||||||
# torchrun --standalone --nproc_per_node=8 -m scripts.chat_rl -- --run=$WANDB_RUN
|
# torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_rl -- --run=$WANDB_RUN
|
||||||
# eval the RL model only on GSM8K
|
# eval the RL model only on GSM8K
|
||||||
# torchrun --standalone --nproc_per_node=8 -m scripts.chat_eval -- -i rl -a GSM8K
|
# torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_eval -- -i rl -a GSM8K
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
# Generate the full report by putting together all the sections
|
# Generate the full report by putting together all the sections
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user