diff --git a/run1000.sh b/run1000.sh
index 46325d9..58ee3bc 100644
--- a/run1000.sh
+++ b/run1000.sh
@@ -70,18 +70,22 @@ python -m scripts.tok_eval
 # which would decrease model performance. Possibly 2, 3 or so epochs is ~ok, but certainly not ideal and at 10+ epochs we'd
 # start to overfit hard.
 # 5) That's it, everything else (e.g. the learning rates) is adjusted automatically by the training script.
-torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- --depth=32 --device_batch_size=8 --run=$WANDB_RUN
-torchrun --standalone --nproc_per_node=8 -m scripts.base_loss
-torchrun --standalone --nproc_per_node=8 -m scripts.base_eval
+
+# Number of processes/GPUs to use
+NPROC_PER_NODE=8
+
+torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train -- --depth=32 --device_batch_size=8 --run=$WANDB_RUN
+torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_loss
+torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_eval
 
 # midtrain
 # NOTE: ensure that we use the same device_batch_size here as the base training script.
-torchrun --standalone --nproc_per_node=8 -m scripts.mid_train -- --device_batch_size=8 --run=$WANDB_RUN
-torchrun --standalone --nproc_per_node=8 -m scripts.chat_eval -- -i mid
+torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.mid_train -- --device_batch_size=8 --run=$WANDB_RUN
+torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_eval -- -i mid
 
 # sft
-torchrun --standalone --nproc_per_node=8 -m scripts.chat_sft -- --run=$WANDB_RUN
-torchrun --standalone --nproc_per_node=8 -m scripts.chat_eval -- -i sft
+torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_sft -- --run=$WANDB_RUN
+torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_eval -- -i sft
 
 # generate final report
 python -m nanochat.report generate
diff --git a/speedrun.sh b/speedrun.sh
index 32c8870..7955ec5 100644
--- a/speedrun.sh
+++ b/speedrun.sh
@@ -82,12 +82,15 @@ python -m scripts.tok_eval
 echo "Waiting for dataset download to complete..."
 wait $DATASET_DOWNLOAD_PID
 
+# Number of processes/GPUs to use
+NPROC_PER_NODE=8
+
 # pretrain the d20 model
-torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- --depth=20 --run=$WANDB_RUN
+torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train -- --depth=20 --run=$WANDB_RUN
 # evaluate the model on a larger chunk of train/val data and draw some samples
-torchrun --standalone --nproc_per_node=8 -m scripts.base_loss
+torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_loss
 # evaluate the model on CORE tasks
-torchrun --standalone --nproc_per_node=8 -m scripts.base_eval
+torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_eval
 
 # -----------------------------------------------------------------------------
 # Midtraining (teach the model conversation special tokens, tool use, multiple choice)
@@ -97,15 +100,15 @@ torchrun --standalone --nproc_per_node=8 -m scripts.base_eval
 curl -L -o $NANOCHAT_BASE_DIR/identity_conversations.jsonl https://karpathy-public.s3.us-west-2.amazonaws.com/identity_conversations.jsonl
 
 # run midtraining and eval the model
-torchrun --standalone --nproc_per_node=8 -m scripts.mid_train -- --run=$WANDB_RUN
-torchrun --standalone --nproc_per_node=8 -m scripts.chat_eval -- -i mid
+torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.mid_train -- --run=$WANDB_RUN
+torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_eval -- -i mid
 
 # -----------------------------------------------------------------------------
 # Supervised Finetuning (domain adaptation to each sequence all by itself per row)
 
 # train sft and re-eval right away (should see a small bump)
-torchrun --standalone --nproc_per_node=8 -m scripts.chat_sft -- --run=$WANDB_RUN
-torchrun --standalone --nproc_per_node=8 -m scripts.chat_eval -- -i sft
+torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_sft -- --run=$WANDB_RUN
+torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_eval -- -i sft
 
 # chat with the model over CLI! Leave out the -p to chat interactively
 # python -m scripts.chat_cli -p "Why is the sky blue?"
@@ -118,9 +121,9 @@ torchrun --standalone --nproc_per_node=8 -m scripts.chat_eval -- -i sft
 
 # (optional)
 # run reinforcement learning
-# torchrun --standalone --nproc_per_node=8 -m scripts.chat_rl -- --run=$WANDB_RUN
+# torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_rl -- --run=$WANDB_RUN
 # eval the RL model only on GSM8K
-# torchrun --standalone --nproc_per_node=8 -m scripts.chat_eval -- -i rl -a GSM8K
+# torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_eval -- -i rl -a GSM8K
 
 # -----------------------------------------------------------------------------
 # Generate the full report by putting together all the sections
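
With the process count now read from a single NPROC_PER_NODE variable, a natural follow-up (a sketch only, not part of the diff above) is to detect the GPU count at runtime instead of hardcoding 8. The snippet below assumes nvidia-smi is on PATH (its --list-gpus flag prints one line per visible GPU); the ${VAR:-default} expansion keeps the value overridable from the environment:

# Detect the number of GPUs at runtime, unless the caller already set it,
# e.g. NPROC_PER_NODE=4 bash speedrun.sh
# Assumption: nvidia-smi is installed; --list-gpus prints one line per GPU.
NPROC_PER_NODE=${NPROC_PER_NODE:-$(nvidia-smi --list-gpus | wc -l)}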