nanochat/dev/runmps_evals.sh
William Thurston b1d49aade5 Add scripts for running evaluations and training with W&B integration
- Added `dev/runmps_evals.sh` for evaluating checkpoints and logging results to W&B.
- Introduced `dev/runmps.sh` for orchestrating training stages with W&B support.
- Updated `.gitignore` to include `wandb/` and `.runmps_wandb_ids`.
- Changed permissions for `dev/runcpu.sh` and added executable flag.
- Enhanced existing scripts to log metrics to W&B during training and evaluation processes.
2025-11-05 11:49:50 -08:00

224 lines
6.2 KiB
Bash
Executable File

#!/bin/bash
set -euo pipefail
shopt -s nullglob

# Evaluate every checkpoint generated by dev/runmps.sh and log back to the
# original W&B runs.
#
# Usage:
#   bash dev/runmps_evals.sh

STATE_FILE=".runmps_wandb_ids"

# dev/runmps.sh is responsible for creating the virtualenv; refuse to run
# without it so we fail with a hint instead of a confusing import error.
if [[ ! -d .venv ]]; then
  echo "[runmps_evals] .venv not found. Run dev/runmps.sh first." >&2
  exit 1
fi
source .venv/bin/activate

# Pick up the W&B run ids recorded by the training script, when available.
if [[ -f "$STATE_FILE" ]]; then
  . "$STATE_FILE"
fi
# Environment defaults; every value may be overridden by the caller.
export OMP_NUM_THREADS="${OMP_NUM_THREADS:-1}"
export NANOCHAT_BASE_DIR="${NANOCHAT_BASE_DIR:-$HOME/.cache/nanochat}"
export WANDB_PROJECT="${WANDB_PROJECT:-nanochat}"
WANDB_RUN="${WANDB_RUN:-dummy}"
# A non-dummy run name implies online logging unless the caller already
# chose a WANDB_MODE.
if [[ "$WANDB_RUN" != "dummy" && -z "${WANDB_MODE:-}" ]]; then
  export WANDB_MODE=online
fi
# Auto-populate WANDB_API_KEY from ~/.netrc when talking to a local W&B server,
# mirroring the training script behaviour so evals can log without manual export.
if [ -z "${WANDB_API_KEY:-}" ] && [ -f "$HOME/.netrc" ]; then
  WANDB_BASE_URL_DEFAULT=${WANDB_BASE_URL:-http://localhost:8080}
  if printf '%s' "$WANDB_BASE_URL_DEFAULT" | grep -q "localhost"; then
    # Strip the scheme and any trailing path, leaving host[:port] for the
    # netrc lookup. NOTE(review): HOST keeps any :port suffix, so the
    # ~/.netrc "machine" entry must match it exactly — confirm.
    HOST=$(printf '%s\n' "$WANDB_BASE_URL_DEFAULT" | sed -E 's#https?://##' | cut -d/ -f1)
    # Print the netrc password for HOST (empty when absent). The heredoc
    # body must be valid Python: the 'if' body is indented.
    API_KEY=$(python - "$HOST" <<'PY'
import sys
from netrc import netrc

host = sys.argv[1]
auth = netrc().authenticators(host)
if auth and auth[2]:
    print(auth[2], end="")
PY
)
    if [ -n "$API_KEY" ]; then
      export WANDB_BASE_URL="$WANDB_BASE_URL_DEFAULT"
      export WANDB_API_KEY="$API_KEY"
      echo "[runmps_evals] Loaded WANDB_API_KEY for $WANDB_BASE_URL from ~/.netrc"
    fi
  fi
fi
# Emit a short random run id (first 8 hex chars of a UUID4) on stdout.
generate_wandb_id() {
  python -c 'import uuid; print(uuid.uuid4().hex[:8])'
}
#######################################
# Create a fresh W&B eval run id/name pair for one eval stage.
# Globals:   writes ${prefix}_EVAL_WANDB_ID and ${prefix}_EVAL_WANDB_NAME
#            (global so callers can export them afterwards)
# Arguments: $1 - variable prefix (e.g. BASE, MID, SFT)
#            $2 - default run name
#######################################
set_eval_run() {
  local prefix="$1"
  local default_name="$2"
  local id_var="${prefix}_EVAL_WANDB_ID"
  local name_var="${prefix}_EVAL_WANDB_NAME"
  # Split declaration from assignment so a failing generate_wandb_id is not
  # masked by `local`'s own exit status.
  local current_id
  current_id=$(generate_wandb_id)
  local current_name="$default_name"
  # printf -v does an indirect assignment without eval's quoting pitfalls
  # (the old eval broke if the name contained a single quote).
  printf -v "$id_var" '%s' "$current_id"
  printf -v "$name_var" '%s' "$current_name"
}
# Batch geometry for the eval passes; SEQ_LEN/DEVICE_BATCH may be overridden
# from the environment.
: "${SEQ_LEN:=1024}"
: "${DEVICE_BATCH:=16}"
TOTAL_BATCH=$(( DEVICE_BATCH * SEQ_LEN ))
# NOTE(review): EVAL_SEQUENCES is not referenced elsewhere in this script —
# confirm before removing.
EVAL_SEQUENCES=10000
EVAL_BATCH_MULT=4
EVAL_TOKENS=$(( TOTAL_BATCH * EVAL_BATCH_MULT ))
#######################################
# Print the numeric steps of all "model_<step>.pt" checkpoints in a directory,
# one per line, ascending. Prints nothing when the directory is missing or no
# filename parses. The heredoc body must be valid Python (compound-statement
# bodies indented), so it is formatted here, not flattened.
# Arguments: $1 - checkpoint directory
#######################################
checkpoint_steps() {
  local dir="$1"
  python - "$dir" <<'PY'
import os
import sys

dir_path = sys.argv[1]
if not os.path.isdir(dir_path):
    sys.exit(0)
steps = []
for fname in os.listdir(dir_path):
    if fname.startswith("model_") and fname.endswith(".pt"):
        # Filenames like "model_junk.pt" fail int() and are skipped.
        try:
            step = int(fname.split("_")[1].split(".")[0])
        except ValueError:
            continue
        steps.append(step)
for step in sorted(steps):
    print(step)
PY
}
#######################################
# Print the tag of the largest model under a checkpoints directory, using
# nanochat.checkpoint_manager.find_largest_model. Prints nothing when the
# directory is absent or no model is found. The heredoc body must be valid
# Python (try/except bodies indented), so it is formatted here, not flattened.
# Arguments: $1 - base checkpoints directory (e.g. .../base_checkpoints)
#######################################
checkpoint_tag() {
  local dir="$1"
  python - "$dir" <<'PY'
import os
import sys

from nanochat.checkpoint_manager import find_largest_model

base_dir = sys.argv[1]
if not os.path.isdir(base_dir):
    sys.exit(0)
try:
    tag = find_largest_model(base_dir)
except FileNotFoundError:
    sys.exit(0)
print(tag)
PY
}
# Evaluate every base-pretraining checkpoint: base_loss (sample completions,
# deliberately off W&B) followed by base_eval (logged to a dedicated
# "<run>-eval" W&B run). No-ops when no checkpoints exist.
run_base_evals() {
  local tag=$(checkpoint_tag "$NANOCHAT_BASE_DIR/base_checkpoints")
  if [ -z "$tag" ]; then
    return
  fi
  local dir="$NANOCHAT_BASE_DIR/base_checkpoints/$tag"
  local steps="$(checkpoint_steps "$dir")"
  if [ -z "$steps" ]; then
    return
  fi
  echo "[runmps_evals] Running base eval on checkpoints in $dir..."
  # Mint the BASE_EVAL_WANDB_ID / BASE_EVAL_WANDB_NAME globals, then expose
  # them (plus the project) to the eval scripts via the environment.
  set_eval_run BASE "${WANDB_RUN:-base}-eval"
  export WANDB_RUN_ID=$BASE_EVAL_WANDB_ID
  export WANDB_EVAL_RUN=$BASE_EVAL_WANDB_NAME
  export WANDB_PROJECT
  while read -r step; do
    [ -z "$step" ] && continue
    echo " • base checkpoint step $step"
    # run base_loss to print sample completions, but keep it off W&B:
    # stash the W&B ids, unset them for base_loss, restore afterwards.
    # NOTE: saved_wandb_* are intentionally not local; they leak globally.
    saved_wandb_run_id=${WANDB_RUN_ID:-}
    saved_wandb_eval_run=${WANDB_EVAL_RUN:-}
    unset WANDB_RUN_ID
    unset WANDB_EVAL_RUN
    python -m scripts.base_loss \
      --device_batch_size=$DEVICE_BATCH \
      --split_tokens=$EVAL_TOKENS \
      --model_step=$step
    if [ -n "$saved_wandb_run_id" ]; then
      export WANDB_RUN_ID=$saved_wandb_run_id
    fi
    if [ -n "$saved_wandb_eval_run" ]; then
      export WANDB_EVAL_RUN=$saved_wandb_eval_run
    fi
    # base_eval logs to the eval W&B run restored above.
    python -m scripts.base_eval \
      --max-per-task=16 \
      --model-step=$step
  done <<< "$steps"
  # Leave the environment clean for the next eval stage.
  unset WANDB_RUN_ID
  unset WANDB_EVAL_RUN
}
# Evaluate every mid-training checkpoint with chat_eval, logging to a
# dedicated "<run>-mid-eval" W&B run. No-ops when no checkpoints exist.
run_mid_evals() {
  local root="$NANOCHAT_BASE_DIR/mid_checkpoints"
  local tag=$(checkpoint_tag "$root")
  [ -n "$tag" ] || return 0
  local dir="$root/$tag"
  local steps="$(checkpoint_steps "$dir")"
  [ -n "$steps" ] || return 0
  echo "[runmps_evals] Running mid-stage chat eval on checkpoints in $dir..."
  # Mint the MID_EVAL_WANDB_* globals and hand them to chat_eval via the env.
  set_eval_run MID "${WANDB_RUN:-mid}-mid-eval"
  export WANDB_RUN_ID=$MID_EVAL_WANDB_ID
  export WANDB_EVAL_RUN=$MID_EVAL_WANDB_NAME
  export WANDB_PROJECT
  local step
  for step in $steps; do
    echo " • mid checkpoint step $step"
    python -m scripts.chat_eval \
      --source=mid \
      --step=$step \
      --batch-size=$DEVICE_BATCH \
      --max-new-tokens=128 \
      --max-problems=20
  done
  # Leave the environment clean for the next eval stage.
  unset WANDB_RUN_ID
  unset WANDB_EVAL_RUN
}
# Evaluate every SFT checkpoint with chat_eval, logging to a dedicated
# "<run>-sft-eval" W&B run. No-ops when no checkpoints exist.
run_sft_evals() {
  local root="$NANOCHAT_BASE_DIR/chatsft_checkpoints"
  local tag=$(checkpoint_tag "$root")
  [ -n "$tag" ] || return 0
  local dir="$root/$tag"
  local steps="$(checkpoint_steps "$dir")"
  [ -n "$steps" ] || return 0
  echo "[runmps_evals] Running SFT chat eval on checkpoints in $dir..."
  # Mint the SFT_EVAL_WANDB_* globals and hand them to chat_eval via the env.
  set_eval_run SFT "${WANDB_RUN:-sft}-sft-eval"
  export WANDB_RUN_ID=$SFT_EVAL_WANDB_ID
  export WANDB_EVAL_RUN=$SFT_EVAL_WANDB_NAME
  export WANDB_PROJECT
  local step
  for step in $steps; do
    echo " • SFT checkpoint step $step"
    python -m scripts.chat_eval \
      --source=sft \
      --step=$step \
      --batch-size=$DEVICE_BATCH \
      --max-new-tokens=128 \
      --max-problems=20
  done
  # Leave the environment clean for the next eval stage.
  unset WANDB_RUN_ID
  unset WANDB_EVAL_RUN
}
# Drive all three eval stages in training order, then rebuild the report so
# it picks up the freshly logged metrics.
run_base_evals
run_mid_evals
run_sft_evals
echo "[runmps_evals] Regenerating report with fresh metrics..."
python -m nanochat.report generate
echo "[runmps_evals] Done."