nanochat/dev/runmps.sh
William Thurston b1d49aade5 Add scripts for running evaluations and training with W&B integration
- Added `dev/runmps_evals.sh` for evaluating checkpoints and logging results to W&B.
- Introduced `dev/runmps.sh` for orchestrating training stages with W&B support.
- Updated `.gitignore` to include `wandb/` and `.runmps_wandb_ids`.
- Changed permissions for `dev/runcpu.sh` and added executable flag.
- Enhanced existing scripts to log metrics to W&B during training and evaluation processes.
2025-11-05 11:49:50 -08:00

337 lines
10 KiB
Bash
Executable File

#!/bin/bash
# End-to-end demo run (tokenizer -> base -> mid -> SFT -> report) that exercises
# the training code paths on the CPU (or MPS on Macbooks), optionally logging
# each stage to Weights & Biases.
# Run as:
# bash dev/runmps.sh [options]        (see --help for stage selection)
# NOTE: Training LLMs requires GPU compute and $$$. You will not get far on your Macbook.
# Think of this run as educational/fun demo, not something you should expect to work well.
# This is also why I hide this script away in dev/

# Stage selection (allow running only a subset, e.g. --stage=sft or --from=mid).
# RUN_* are 1/0 toggles consumed by the stage sections further down.
RUN_BASE=1
RUN_MID=1
RUN_SFT=1
RUN_REPORT=1
STAGE_ONLY=""  # set by --stage=<stage> or --<stage>
FROM_STAGE=""  # set by --from=<stage> or --from-<stage>
while [[ $# -gt 0 ]]; do
  case "$1" in
    --stage=*)
      STAGE_ONLY="${1#*=}"
      # An empty value would otherwise silently fall back to running everything.
      if [[ -z "$STAGE_ONLY" ]]; then
        echo "Missing value for --stage (expected base|mid|sft|report)" >&2
        exit 1
      fi
      ;;
    --base|--mid|--sft|--report)
      STAGE_ONLY="${1#--}"
      ;;
    --from=*)
      FROM_STAGE="${1#*=}"
      if [[ -z "$FROM_STAGE" ]]; then
        echo "Missing value for --from (expected base|mid|sft)" >&2
        exit 1
      fi
      ;;
    --from-base|--from-mid|--from-sft)
      FROM_STAGE="${1#--from-}"
      ;;
    --help|-h)
      cat <<'EOF'
Usage: bash dev/runmps.sh [options]
Options:
  --stage=<base|mid|sft|report>   Run only the specified stage.
  --base, --mid, --sft, --report  Shorthand for --stage=<stage>.
  --from=<base|mid|sft>           Run from the specified stage through SFT;
                                  the report is regenerated only for --from=base.
  --from-base, --from-mid, --from-sft
                                  Shorthand for --from=<stage>.
  -h, --help                      Show this help message.
If both --stage and --from are given, --stage wins.
Environment variables (e.g. WANDB_RUN, WANDB_PROJECT, BASE_DEPTH, SEQ_LEN,
DEVICE_BATCH, TOTAL_BATCH, SFT_NUM_STEPS, CHECKPOINT_EVERY_SEQ, RUN_STAGE_EVALS)
control model size, batch sizes, W&B run names, etc.
EOF
      exit 0
      ;;
    *)
      echo "Unknown argument: $1" >&2
      exit 1
      ;;
  esac
  shift
done
# Translate the stage options into RUN_* toggles.
#   --from=<stage>  runs that stage and every later training stage; the report
#                   is regenerated only when starting from base.
#   --stage=<stage> runs exactly one stage and is applied last, so it overrides
#                   --from; the report runs only for --stage=report.
# Either way, subset runs never regenerate the report unless asked to.
if [[ -n "$FROM_STAGE" ]]; then
  case "$FROM_STAGE" in
    base) RUN_BASE=1; RUN_MID=1; RUN_SFT=1; RUN_REPORT=1 ;;
    mid)  RUN_BASE=0; RUN_MID=1; RUN_SFT=1; RUN_REPORT=0 ;;
    sft)  RUN_BASE=0; RUN_MID=0; RUN_SFT=1; RUN_REPORT=0 ;;
    *)
      echo "Unknown --from stage: $FROM_STAGE" >&2
      exit 1
      ;;
  esac
fi
if [[ -n "$STAGE_ONLY" ]]; then
  RUN_BASE=0; RUN_MID=0; RUN_SFT=0; RUN_REPORT=0
  case "$STAGE_ONLY" in
    base)   RUN_BASE=1 ;;
    mid)    RUN_MID=1 ;;
    sft)    RUN_SFT=1 ;;
    report) RUN_REPORT=1 ;;
    *)
      echo "Unknown --stage value: $STAGE_ONLY" >&2
      exit 1
      ;;
  esac
fi
# all the setup stuff
export OMP_NUM_THREADS=1
NANOCHAT_BASE_DIR="$HOME/.cache/nanochat"
mkdir -p "$NANOCHAT_BASE_DIR"
# install uv (if we don't already have it)
if ! command -v uv &> /dev/null; then
  curl -LsSf https://astral.sh/uv/install.sh | sh
  # The installer runs as a child process and cannot change this shell's PATH,
  # so add its default install locations (~/.local/bin; ~/.cargo/bin for older
  # installers) before calling uv below. Adjust if the install dir is customized.
  export PATH="$HOME/.local/bin:$HOME/.cargo/bin:$PATH"
fi
# create a .venv local virtual environment (if it doesn't exist), install the
# "cpu" extra, and activate it for the python calls below
[ -d ".venv" ] || uv venv
uv sync --extra cpu
source .venv/bin/activate
# W&B defaults: an unset/empty WANDB_RUN becomes the sentinel "dummy", which the
# stage sections check before exporting W&B run IDs. A real run name selects
# WANDB_MODE=online unless the caller already chose a mode.
: "${WANDB_RUN:=dummy}"
if [[ "$WANDB_RUN" != dummy && -z "$WANDB_MODE" ]]; then
  export WANDB_MODE=online
fi
# Batch/sequence configuration. Every setting written as ${VAR:-default} can be
# overridden from the environment, e.g.
#   DEVICE_BATCH=8 MID_NUM_STEPS=1000 bash dev/runmps.sh
BASE_DEPTH=${BASE_DEPTH:-4}
SEQ_LEN=${SEQ_LEN:-1024}
DEVICE_BATCH=${DEVICE_BATCH:-16}
# Fail fast instead of hitting a division by zero in the arithmetic below.
if ! (( DEVICE_BATCH > 0 && SEQ_LEN > 0 )); then
  echo "DEVICE_BATCH and SEQ_LEN must be positive integers (got DEVICE_BATCH=$DEVICE_BATCH, SEQ_LEN=$SEQ_LEN)" >&2
  exit 1
fi
TOTAL_BATCH=${TOTAL_BATCH:-$((DEVICE_BATCH * SEQ_LEN))} # tokens per optimizer step
# EVAL_STEPS is passed as --eval_every, i.e. an interval in steps (not a count):
# ceil(EVAL_SEQUENCES / DEVICE_BATCH). This equals "every EVAL_SEQUENCES
# sequences" when one optimizer step is one micro-batch (the TOTAL_BATCH default).
EVAL_SEQUENCES=${EVAL_SEQUENCES:-10000}
EVAL_STEPS=$(((EVAL_SEQUENCES + DEVICE_BATCH - 1) / DEVICE_BATCH))
# Each validation pass covers EVAL_BATCH_MULT optimizer steps' worth of tokens.
EVAL_BATCH_MULT=${EVAL_BATCH_MULT:-4}
EVAL_TOKENS=$((TOTAL_BATCH * EVAL_BATCH_MULT))
MID_NUM_STEPS=${MID_NUM_STEPS:-6144}
SFT_NUM_STEPS=${SFT_NUM_STEPS:-3072}
CHECKPOINT_EVERY_SEQ=${CHECKPOINT_EVERY_SEQ:-10000} # <= 0 disables periodic checkpoints
RUN_STAGE_EVALS=${RUN_STAGE_EVALS:-0}
WANDB_PROJECT=${WANDB_PROJECT:-nanochat}
TARGET_PARAM_DATA_RATIO=${TARGET_PARAM_DATA_RATIO:-20}
SFT_DEVICE_BATCH=${SFT_DEVICE_BATCH:-$DEVICE_BATCH}
SFT_TARGET_EXAMPLES=${SFT_TARGET_EXAMPLES:-$DEVICE_BATCH}
SFT_EVAL_EVERY=${SFT_EVAL_EVERY:-0}
SFT_EVAL_STEPS=${SFT_EVAL_STEPS:-0}
SFT_EVAL_METRICS_EVERY=${SFT_EVAL_METRICS_EVERY:-0}
SFT_EVAL_METRICS_MAX=${SFT_EVAL_METRICS_MAX:-0}
# KEY=VALUE record of the W&B project and per-stage run IDs/names written by
# this invocation.
STATE_FILE=".runmps_wandb_ids"
# NOTE(review): the state file is truncated on every invocation, including
# partial --stage/--from runs, so IDs of stages that are not re-run are dropped.
: > "$STATE_FILE"
echo "WANDB_PROJECT=$WANDB_PROJECT" >> "$STATE_FILE"
# Print a fresh random run ID: the first 8 hex characters of a uuid4.
generate_wandb_id() {
  python -c 'import uuid; print(uuid.uuid4().hex[:8])'
}
#######################################
# Convert the CHECKPOINT_EVERY_SEQ cadence (in sequences) into optimizer steps.
# Globals:   CHECKPOINT_EVERY_SEQ (read)
# Arguments: $1 - sequences processed per optimizer step (<= 0 treated as 1)
# Outputs:   ceil(CHECKPOINT_EVERY_SEQ / $1) on stdout, or 0 (periodic
#            checkpointing disabled) when CHECKPOINT_EVERY_SEQ <= 0
#######################################
checkpoint_interval_steps() {
  local seqs_per_step=$1
  if (( seqs_per_step <= 0 )); then
    seqs_per_step=1
  fi
  if (( CHECKPOINT_EVERY_SEQ <= 0 )); then
    echo 0
  else
    echo $(( (CHECKPOINT_EVERY_SEQ + seqs_per_step - 1) / seqs_per_step ))
  fi
}
# Base and mid share one geometry: TOTAL_BATCH tokens per optimizer step split
# into SEQ_LEN-token sequences (guarded so SEQ_LEN <= 0 cannot divide by zero).
BASE_SEQS_PER_STEP=0
if (( SEQ_LEN > 0 )); then
  BASE_SEQS_PER_STEP=$((TOTAL_BATCH / SEQ_LEN))
fi
if (( BASE_SEQS_PER_STEP <= 0 )); then BASE_SEQS_PER_STEP=1; fi
MID_SEQS_PER_STEP=$BASE_SEQS_PER_STEP
# One SFT optimizer step covers --target_examples_per_step examples (device
# batches are presumably accumulated up to it -- confirm in scripts/chat_sft.py),
# mirroring how base/mid use the full per-step batch rather than the device
# batch. With default settings SFT_TARGET_EXAMPLES == SFT_DEVICE_BATCH.
SFT_SEQS_PER_STEP=${SFT_TARGET_EXAMPLES:-$SFT_DEVICE_BATCH}
if (( SFT_SEQS_PER_STEP <= 0 )); then SFT_SEQS_PER_STEP=1; fi
BASE_CHECKPOINT_STEPS=$(checkpoint_interval_steps "$BASE_SEQS_PER_STEP")
MID_CHECKPOINT_STEPS=$(checkpoint_interval_steps "$MID_SEQS_PER_STEP")
SFT_CHECKPOINT_STEPS=$(checkpoint_interval_steps "$SFT_SEQS_PER_STEP")
# Auto-populate WANDB_API_KEY from ~/.netrc when talking to a local W&B server.
# Mirrors the helper used in TinyRecursiveModels/pretrain_text.py so we can log
# to a self-hosted instance without manual export each time.
if [ -z "$WANDB_API_KEY" ] && [ -f "$HOME/.netrc" ]; then
# Allow custom WANDB_BASE_URL; default to localhost if user sets WANDB
WANDB_BASE_URL_DEFAULT=${WANDB_BASE_URL:-http://localhost:8080}
if printf '%s' "$WANDB_BASE_URL_DEFAULT" | grep -q "localhost"; then
HOST=$(printf '%s\n' "$WANDB_BASE_URL_DEFAULT" | sed -E 's#https?://##' | cut -d/ -f1)
API_KEY=$(python - "$HOST" <<'PY'
import sys
from netrc import netrc
host = sys.argv[1]
auth = netrc().authenticators(host)
if auth and auth[2]:
print(auth[2], end="")
PY
)
if [ -n "$API_KEY" ]; then
export WANDB_BASE_URL="$WANDB_BASE_URL_DEFAULT"
export WANDB_API_KEY="$API_KEY"
echo "[runmps] Loaded WANDB_API_KEY for $WANDB_BASE_URL from ~/.netrc"
fi
fi
fi
# Rust toolchain for building the rustbpe tokenizer extension. Only run the
# rustup installer when no cargo is available (on PATH or in rustup's default
# ~/.cargo location), so reruns don't re-download the installer every time.
if ! command -v cargo &> /dev/null && [ ! -x "$HOME/.cargo/bin/cargo" ]; then
  curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
fi
# ~/.cargo/env puts rustup's toolchain on PATH; it does not exist when rust was
# installed some other way (e.g. a package manager), in which case skip it.
if [ -f "$HOME/.cargo/env" ]; then
  source "$HOME/.cargo/env"
fi
uv run maturin develop --release --manifest-path rustbpe/Cargo.toml
# Download the eval bundle into $NANOCHAT_BASE_DIR once. The archive is fetched
# and unpacked in a private temp dir so a stale eval_bundle.zip or eval_bundle/
# in the working directory cannot collide with it (unzip would otherwise stop
# and prompt before overwriting). Failures are reported but not fatal.
EVAL_BUNDLE_URL=https://karpathy-public.s3.us-west-2.amazonaws.com/eval_bundle.zip
if [ ! -d "$NANOCHAT_BASE_DIR/eval_bundle" ]; then
  if EVAL_BUNDLE_TMP=$(mktemp -d); then
    # -f: fail on HTTP errors instead of saving the error page as the "zip".
    if curl -fL -o "$EVAL_BUNDLE_TMP/eval_bundle.zip" "$EVAL_BUNDLE_URL" \
      && unzip -q "$EVAL_BUNDLE_TMP/eval_bundle.zip" -d "$EVAL_BUNDLE_TMP"; then
      mv "$EVAL_BUNDLE_TMP/eval_bundle" "$NANOCHAT_BASE_DIR"
    else
      echo "[runmps] Failed to fetch/unpack eval bundle from $EVAL_BUNDLE_URL" >&2
    fi
    rm -rf -- "$EVAL_BUNDLE_TMP"
  else
    echo "[runmps] mktemp failed; cannot download eval bundle" >&2
  fi
fi
if (( RUN_BASE )); then
  # wipe the report
  python -m nanochat.report reset
  # train tokenizer on ~2B characters (download 240 data shards up front for extended training)
  python -m nanochat.dataset -n 240
  python -m scripts.tok_train --max_chars=2000000000
  python -m scripts.tok_eval
  # Train the base model: BASE_DEPTH layers (default 4), SEQ_LEN-token
  # sequences, DEVICE_BATCH sequences per micro-batch and TOTAL_BATCH tokens per
  # optimizer step. No explicit step count is passed; presumably base_train
  # derives it from --target_param_data_ratio (confirm in scripts/base_train.py).
  # The -1 values presumably disable CORE-metric evals and sampling.
  #
  # With a real W&B run, pin this stage's run ID (reusing BASE_WANDB_ID when
  # provided) and record it in the state file; the exports are cleared again
  # after training so they do not leak into later stages.
  if [[ "$WANDB_RUN" != "dummy" ]]; then
    BASE_WANDB_ID=${BASE_WANDB_ID:-$(generate_wandb_id)}
    echo "BASE_WANDB_ID=$BASE_WANDB_ID" >> "$STATE_FILE"
    echo "BASE_WANDB_NAME=$WANDB_RUN" >> "$STATE_FILE"
    export WANDB_RUN_ID="$BASE_WANDB_ID"
    export WANDB_EVAL_RUN="$WANDB_RUN"
    export WANDB_PROJECT
  fi
  python -m scripts.base_train \
    --depth="$BASE_DEPTH" \
    --max_seq_len="$SEQ_LEN" \
    --device_batch_size="$DEVICE_BATCH" \
    --total_batch_size="$TOTAL_BATCH" \
    --target_param_data_ratio="$TARGET_PARAM_DATA_RATIO" \
    --run="$WANDB_RUN" \
    --eval_every="$EVAL_STEPS" \
    --eval_tokens="$EVAL_TOKENS" \
    --core_metric_every=-1 \
    --sample_every=-1 \
    --checkpoint_every_steps="$BASE_CHECKPOINT_STEPS"
  if [[ "$WANDB_RUN" != "dummy" ]]; then
    unset WANDB_RUN_ID WANDB_EVAL_RUN
  fi
  # Optional inline evals (otherwise use dev/runmps_evals.sh afterwards).
  if [[ "$RUN_STAGE_EVALS" = "1" ]]; then
    python -m scripts.base_loss --device_batch_size="$DEVICE_BATCH" --split_tokens="$EVAL_TOKENS"
    python -m scripts.base_eval --max-per-task=16
  fi
fi
if (( RUN_MID )); then
  # midtraining: MID_NUM_STEPS iterations with the same batch geometry as base.
  # With a real W&B run, pin this stage's run ID (reusing MID_WANDB_ID when
  # provided) and record it in the state file; the exports are cleared again
  # after training so they do not leak into later stages.
  if [[ "$WANDB_RUN" != "dummy" ]]; then
    MID_WANDB_ID=${MID_WANDB_ID:-$(generate_wandb_id)}
    echo "MID_WANDB_ID=$MID_WANDB_ID" >> "$STATE_FILE"
    echo "MID_WANDB_NAME=${WANDB_RUN}-mid" >> "$STATE_FILE"
    export WANDB_RUN_ID="$MID_WANDB_ID"
    export WANDB_EVAL_RUN="${WANDB_RUN}-mid"
    export WANDB_PROJECT
  fi
  python -m scripts.mid_train \
    --max_seq_len="$SEQ_LEN" \
    --device_batch_size="$DEVICE_BATCH" \
    --total_batch_size="$TOTAL_BATCH" \
    --run="${WANDB_RUN}-mid" \
    --eval_every="$EVAL_STEPS" \
    --eval_tokens="$EVAL_TOKENS" \
    --checkpoint_every_steps="$MID_CHECKPOINT_STEPS" \
    --num_iterations="$MID_NUM_STEPS"
  if [[ "$WANDB_RUN" != "dummy" ]]; then
    unset WANDB_RUN_ID WANDB_EVAL_RUN
  fi
  if [[ "$RUN_STAGE_EVALS" = "1" ]]; then
    # eval results will be terrible, this is just to execute the code paths.
    # Capped at 20 problems / 128 new tokens to keep it quick.
    python -m scripts.chat_eval --source=mid --max-new-tokens=128 --max-problems=20
  fi
fi
if (( RUN_SFT )); then
  # SFT: SFT_NUM_STEPS iterations; SFT_DEVICE_BATCH examples per micro-batch
  # and SFT_TARGET_EXAMPLES examples per optimizer step.
  # With a real W&B run, pin this stage's run ID (reusing SFT_WANDB_ID when
  # provided) and record it in the state file; the exports are cleared again
  # after training.
  if [[ "$WANDB_RUN" != "dummy" ]]; then
    SFT_WANDB_ID=${SFT_WANDB_ID:-$(generate_wandb_id)}
    echo "SFT_WANDB_ID=$SFT_WANDB_ID" >> "$STATE_FILE"
    echo "SFT_WANDB_NAME=${WANDB_RUN}-sft" >> "$STATE_FILE"
    export WANDB_RUN_ID="$SFT_WANDB_ID"
    export WANDB_EVAL_RUN="${WANDB_RUN}-sft"
    export WANDB_PROJECT
  fi
  python -m scripts.chat_sft \
    --device_batch_size="$SFT_DEVICE_BATCH" \
    --target_examples_per_step="$SFT_TARGET_EXAMPLES" \
    --run="${WANDB_RUN}-sft" \
    --num_iterations="$SFT_NUM_STEPS" \
    --eval_every="$SFT_EVAL_EVERY" \
    --eval_steps="$SFT_EVAL_STEPS" \
    --eval_metrics_every="$SFT_EVAL_METRICS_EVERY" \
    --eval_metrics_max_problems="$SFT_EVAL_METRICS_MAX" \
    --checkpoint_every_steps="$SFT_CHECKPOINT_STEPS"
  if [[ "$WANDB_RUN" != "dummy" ]]; then
    unset WANDB_RUN_ID WANDB_EVAL_RUN
  fi
fi
# Interactive follow-ups (not run automatically):
#   Chat CLI: python -m scripts.chat_cli -p "Why is the sky blue?"
#   Chat Web: python -m scripts.chat_web
if (( RUN_REPORT )); then
  python -m nanochat.report generate
fi
# If any training stage ran without inline evals, point at the offline eval script.
if (( RUN_BASE || RUN_MID || RUN_SFT )) && [[ "$RUN_STAGE_EVALS" != 1 ]]; then
  echo "[runmps] Inline evals disabled. Run 'bash dev/runmps_evals.sh' to compute metrics from saved checkpoints."
fi