#!/bin/bash
set -euo pipefail
shopt -s nullglob

# Evaluate every checkpoint generated by dev/runmps.sh and log back to the
# original W&B runs.
# Usage:
#   bash dev/runmps_evals.sh

STATE_FILE=".runmps_wandb_ids"

# The training script (dev/runmps.sh) is responsible for creating the venv;
# refuse to run without it so we fail fast with a clear message.
if [ ! -d ".venv" ]; then
  echo "[runmps_evals] .venv not found. Run dev/runmps.sh first." >&2
  exit 1
fi
source .venv/bin/activate

# Restore any W&B run ids persisted by the training script.
if [ -f "$STATE_FILE" ]; then
  . "$STATE_FILE"
fi

export OMP_NUM_THREADS=${OMP_NUM_THREADS:-1}
export NANOCHAT_BASE_DIR=${NANOCHAT_BASE_DIR:-$HOME/.cache/nanochat}
export WANDB_PROJECT=${WANDB_PROJECT:-nanochat}
WANDB_RUN=${WANDB_RUN:-dummy}
# Only force online logging when the caller named a real run and did not
# already pin a WANDB_MODE themselves.
if [ "$WANDB_RUN" != "dummy" ] && [ -z "${WANDB_MODE:-}" ]; then
  export WANDB_MODE=online
fi

# Auto-populate WANDB_API_KEY from ~/.netrc when talking to a local W&B server,
# mirroring the training script behaviour so evals can log without manual export.
if [ -z "${WANDB_API_KEY:-}" ] && [ -f "$HOME/.netrc" ]; then
  WANDB_BASE_URL_DEFAULT=${WANDB_BASE_URL:-http://localhost:8080}
  if printf '%s' "$WANDB_BASE_URL_DEFAULT" | grep -q "localhost"; then
    # Strip the scheme and any path to recover the netrc machine name.
    HOST=$(printf '%s\n' "$WANDB_BASE_URL_DEFAULT" | sed -E 's#https?://##' | cut -d/ -f1)
    API_KEY=$(python - "$HOST" <<'PY'
import sys
from netrc import netrc
host = sys.argv[1]
auth = netrc().authenticators(host)
if auth and auth[2]:
    print(auth[2], end="")
PY
)
    if [ -n "$API_KEY" ]; then
      export WANDB_BASE_URL="$WANDB_BASE_URL_DEFAULT"
      export WANDB_API_KEY="$API_KEY"
      echo "[runmps_evals] Loaded WANDB_API_KEY for $WANDB_BASE_URL from ~/.netrc"
    fi
  fi
fi

# Print a short random hex string usable as a W&B run id.
generate_wandb_id() {
  python - <<'PY'
import uuid
print(uuid.uuid4().hex[:8])
PY
}

# set_eval_run PREFIX DEFAULT_NAME
# Sets two global variables, ${PREFIX}_EVAL_WANDB_ID (fresh random id) and
# ${PREFIX}_EVAL_WANDB_NAME (DEFAULT_NAME), for the eval runners below.
set_eval_run() {
  local prefix="$1"
  local default_name="$2"
  local id_var="${prefix}_EVAL_WANDB_ID"
  local name_var="${prefix}_EVAL_WANDB_NAME"
  # Declare separately from the command substitution so a failure in
  # generate_wandb_id is not masked by `local` (SC2155) and trips set -e.
  local current_id
  current_id=$(generate_wandb_id)
  # printf -v assigns through the indirect name without eval, so arbitrary
  # name content cannot be interpreted as shell code.
  printf -v "$id_var" '%s' "$current_id"
  printf -v "$name_var" '%s' "$default_name"
}

SEQ_LEN=${SEQ_LEN:-1024}
DEVICE_BATCH=${DEVICE_BATCH:-16}
TOTAL_BATCH=$((DEVICE_BATCH * SEQ_LEN))
EVAL_SEQUENCES=10000
EVAL_BATCH_MULT=4
EVAL_TOKENS=$((TOTAL_BATCH * EVAL_BATCH_MULT))

# checkpoint_steps DIR
# Print the numeric step of every model_<step>.pt file in DIR, ascending,
# one per line. Prints nothing when DIR is missing or holds no checkpoints.
checkpoint_steps() {
  local dir="$1"
  python - "$dir" <<'PY'
import os, sys
dir_path = sys.argv[1]
if not os.path.isdir(dir_path):
    sys.exit(0)
steps = []
for fname in os.listdir(dir_path):
    if fname.startswith("model_") and fname.endswith(".pt"):
        try:
            step = int(fname.split("_")[1].split(".")[0])
        except ValueError:
            continue
        steps.append(step)
for step in sorted(steps):
    print(step)
PY
}

# checkpoint_tag BASE_DIR
# Print the tag of the largest model under BASE_DIR, as chosen by
# nanochat.checkpoint_manager.find_largest_model; prints nothing when the
# directory is missing or contains no model.
checkpoint_tag() {
  local dir="$1"
  python - "$dir" <<'PY'
import os, sys
from nanochat.checkpoint_manager import find_largest_model
base_dir = sys.argv[1]
if not os.path.isdir(base_dir):
    sys.exit(0)
try:
    tag = find_largest_model(base_dir)
except FileNotFoundError:
    sys.exit(0)
print(tag)
PY
}

# Run base_loss (off W&B) and base_eval (logged) on every base checkpoint.
run_base_evals() {
  local tag steps
  # Split declaration from assignment so the helpers' exit codes are not
  # masked by `local` (SC2155).
  tag=$(checkpoint_tag "$NANOCHAT_BASE_DIR/base_checkpoints")
  if [ -z "$tag" ]; then
    return
  fi
  local dir="$NANOCHAT_BASE_DIR/base_checkpoints/$tag"
  steps=$(checkpoint_steps "$dir")
  if [ -z "$steps" ]; then
    return
  fi
  echo "[runmps_evals] Running base eval on checkpoints in $dir..."
  set_eval_run BASE "${WANDB_RUN:-base}-eval"
  export WANDB_RUN_ID=$BASE_EVAL_WANDB_ID
  export WANDB_EVAL_RUN=$BASE_EVAL_WANDB_NAME
  export WANDB_PROJECT
  while read -r step; do
    [ -z "$step" ] && continue
    echo " • base checkpoint step $step"
    # run base_loss to print sample completions, but keep it off W&B
    saved_wandb_run_id=${WANDB_RUN_ID:-}
    saved_wandb_eval_run=${WANDB_EVAL_RUN:-}
    unset WANDB_RUN_ID
    unset WANDB_EVAL_RUN
    python -m scripts.base_loss \
      --device_batch_size="$DEVICE_BATCH" \
      --split_tokens="$EVAL_TOKENS" \
      --model_step="$step"
    if [ -n "$saved_wandb_run_id" ]; then
      export WANDB_RUN_ID=$saved_wandb_run_id
    fi
    if [ -n "$saved_wandb_eval_run" ]; then
      export WANDB_EVAL_RUN=$saved_wandb_eval_run
    fi
    python -m scripts.base_eval \
      --max-per-task=16 \
      --model-step="$step"
  done <<< "$steps"
  unset WANDB_RUN_ID
  unset WANDB_EVAL_RUN
}

# run_chat_evals PREFIX CKPT_SUBDIR SOURCE STAGE_LABEL STEP_LABEL
# Shared driver for the mid/sft chat evals (they only differed in the W&B
# prefix, checkpoint directory, --source value and log labels).
run_chat_evals() {
  local prefix="$1"
  local ckpt_subdir="$2"
  local eval_source="$3"
  local stage_label="$4"
  local step_label="$5"
  local tag steps
  # Split declaration from assignment to avoid masking exit codes (SC2155).
  tag=$(checkpoint_tag "$NANOCHAT_BASE_DIR/$ckpt_subdir")
  if [ -z "$tag" ]; then
    return
  fi
  local dir="$NANOCHAT_BASE_DIR/$ckpt_subdir/$tag"
  steps=$(checkpoint_steps "$dir")
  if [ -z "$steps" ]; then
    return
  fi
  echo "[runmps_evals] Running $stage_label chat eval on checkpoints in $dir..."
  set_eval_run "$prefix" "${WANDB_RUN:-$eval_source}-$eval_source-eval"
  # set_eval_run filled ${PREFIX}_EVAL_WANDB_{ID,NAME}; read them indirectly.
  local id_var="${prefix}_EVAL_WANDB_ID"
  local name_var="${prefix}_EVAL_WANDB_NAME"
  export WANDB_RUN_ID=${!id_var}
  export WANDB_EVAL_RUN=${!name_var}
  export WANDB_PROJECT
  while read -r step; do
    [ -z "$step" ] && continue
    echo " • $step_label checkpoint step $step"
    python -m scripts.chat_eval \
      --source="$eval_source" \
      --step="$step" \
      --batch-size="$DEVICE_BATCH" \
      --max-new-tokens=128 \
      --max-problems=20
  done <<< "$steps"
  unset WANDB_RUN_ID
  unset WANDB_EVAL_RUN
}

# Run chat eval on every mid-training checkpoint.
run_mid_evals() {
  run_chat_evals MID mid_checkpoints mid "mid-stage" "mid"
}

# Run chat eval on every SFT checkpoint.
run_sft_evals() {
  run_chat_evals SFT chatsft_checkpoints sft "SFT" "SFT"
}

run_base_evals
run_mid_evals
run_sft_evals

echo "[runmps_evals] Regenerating report with fresh metrics..."
python -m nanochat.report generate
echo "[runmps_evals] Done."