diff --git a/speedrun.sh b/speedrun.sh index 32c8870..6aaf92b 100644 --- a/speedrun.sh +++ b/speedrun.sh @@ -15,6 +15,9 @@ export OMP_NUM_THREADS=1 export NANOCHAT_BASE_DIR="$HOME/.cache/nanochat" mkdir -p $NANOCHAT_BASE_DIR +# Number of processes per node for distributed training +NPROC_PER_NODE=4 + # ----------------------------------------------------------------------------- # Python venv setup with uv @@ -23,7 +26,7 @@ command -v uv &> /dev/null || curl -LsSf https://astral.sh/uv/install.sh | sh # create a .venv local virtual environment (if it doesn't exist) [ -d ".venv" ] || uv venv # install the repo dependencies -uv sync --extra gpu +uv sync # activate venv so that `python` uses the project's venv instead of system python source .venv/bin/activate @@ -73,6 +76,15 @@ python -m scripts.tok_eval # ----------------------------------------------------------------------------- # Base model (pretraining) +# Download the eval_bundle from s3 to evaluate CORE metric during training (~162MB) +EVAL_BUNDLE_URL=https://karpathy-public.s3.us-west-2.amazonaws.com/eval_bundle.zip +if [ ! -d "$NANOCHAT_BASE_DIR/eval_bundle" ]; then + curl -L -o eval_bundle.zip $EVAL_BUNDLE_URL + unzip -q eval_bundle.zip + rm eval_bundle.zip + mv eval_bundle $NANOCHAT_BASE_DIR +fi + # The d20 model is 561M parameters. # Chinchilla says #tokens = 20X #params, so we need 561e6 * 20 = 11.2B tokens. # Assume our tokenizer is 4.8 chars/token, this is 11.2B * 4.8 ~= 54B chars. @@ -83,29 +95,25 @@ echo "Waiting for dataset download to complete..." wait $DATASET_DOWNLOAD_PID # pretrain the d20 model -torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- --depth=20 --run=$WANDB_RUN +torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train -- --depth=20 --run=$WANDB_RUN # evaluate the model on a larger chunk of train/val data and draw some samples -torchrun --standalone --nproc_per_node=8 -m scripts.base_loss +torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_loss # evaluate the model on CORE tasks -torchrun --standalone --nproc_per_node=8 -m scripts.base_eval +torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_eval # ----------------------------------------------------------------------------- # Midtraining (teach the model conversation special tokens, tool use, multiple choice) -# download 2.3MB of synthetic identity conversations to impart a personality to nanochat -# see dev/gen_sft_data.py for details on how this data was prepared and to get a sense of how you can easily tune it -curl -L -o $NANOCHAT_BASE_DIR/identity_conversations.jsonl https://karpathy-public.s3.us-west-2.amazonaws.com/identity_conversations.jsonl - # run midtraining and eval the model -torchrun --standalone --nproc_per_node=8 -m scripts.mid_train -- --run=$WANDB_RUN -torchrun --standalone --nproc_per_node=8 -m scripts.chat_eval -- -i mid +torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.mid_train -- --run=$WANDB_RUN +torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_eval -- -i mid # ----------------------------------------------------------------------------- # Supervised Finetuning (domain adaptation to each sequence all by itself per row) # train sft and re-eval right away (should see a small bump) -torchrun --standalone --nproc_per_node=8 -m scripts.chat_sft -- --run=$WANDB_RUN -torchrun --standalone --nproc_per_node=8 -m scripts.chat_eval -- -i sft +torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_sft -- --run=$WANDB_RUN +torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_eval -- -i sft # chat with the model over CLI! Leave out the -p to chat interactively # python -m scripts.chat_cli -p "Why is the sky blue?" @@ -118,9 +126,9 @@ torchrun --standalone --nproc_per_node=8 -m scripts.chat_eval -- -i sft # (optional) # run reinforcement learning -# torchrun --standalone --nproc_per_node=8 -m scripts.chat_rl -- --run=$WANDB_RUN +# torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_rl -- --run=$WANDB_RUN # eval the RL model only on GSM8K -# torchrun --standalone --nproc_per_node=8 -m scripts.chat_eval -- -i rl -a GSM8K +# torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_eval -- -i rl -a GSM8K # ----------------------------------------------------------------------------- # Generate the full report by putting together all the sections