mirror of
https://github.com/karpathy/nanochat.git
synced 2026-03-13 16:33:41 +00:00
- Added `dev/runmps_evals.sh` for evaluating checkpoints and logging results to W&B.
- Introduced `dev/runmps.sh` for orchestrating training stages with W&B support.
- Updated `.gitignore` to include `wandb/` and `.runmps_wandb_ids`.
- Changed permissions for `dev/runcpu.sh` to make it executable.
- Enhanced existing scripts to log metrics to W&B during training and evaluation.
224 lines
6.2 KiB
Bash
Executable File
224 lines
6.2 KiB
Bash
Executable File
#!/bin/bash

set -euo pipefail
shopt -s nullglob

# Evaluate every checkpoint generated by dev/runmps.sh and log back to the
# original W&B runs.
#
# Usage:
#   bash dev/runmps_evals.sh
#
# Expected to be started from the repository root: the virtualenv (.venv) and
# the state file below are looked up relative to the current directory.
#
# Optional environment overrides read further down in this script:
#   OMP_NUM_THREADS, NANOCHAT_BASE_DIR, WANDB_PROJECT, WANDB_RUN, WANDB_MODE,
#   WANDB_API_KEY, WANDB_BASE_URL, SEQ_LEN, DEVICE_BATCH

# Shell fragment that is sourced (when it exists) before the defaults are
# applied, so anything it defines takes part in the configuration.
STATE_FILE=".runmps_wandb_ids"
# The evals run inside the project virtualenv; without it there is nothing
# useful to do, so stop with a hint on how to create it.
if [ ! -d ".venv" ]; then
  echo "[runmps_evals] .venv not found. Run dev/runmps.sh first." >&2
  exit 1
fi

# shellcheck source=/dev/null
source .venv/bin/activate

# Load whatever the state file defines into this shell. The file is optional:
# when it is missing the defaults applied later are used instead.
if [ -f "$STATE_FILE" ]; then
  # shellcheck source=/dev/null
  . "$STATE_FILE"
fi
# Defaults for the child Python processes. A value that is already set and
# non-empty (from the caller's environment or from the sourced state file)
# is kept as is.
export OMP_NUM_THREADS="${OMP_NUM_THREADS:-1}"
export NANOCHAT_BASE_DIR="${NANOCHAT_BASE_DIR:-$HOME/.cache/nanochat}"
export WANDB_PROJECT="${WANDB_PROJECT:-nanochat}"

# "dummy" is the placeholder used when no run name was chosen. It is not
# exported here; it only feeds the eval run names built later in this script.
WANDB_RUN="${WANDB_RUN:-dummy}"

# With a non-placeholder run name and no explicit mode, default W&B to online.
if [ "$WANDB_RUN" != "dummy" ] && [ -z "${WANDB_MODE:-}" ]; then
  export WANDB_MODE=online
fi
# Auto-populate WANDB_API_KEY from ~/.netrc when talking to a local W&B server,
# mirroring the training script behaviour so evals can log without manual export.
#
# Globals read:    WANDB_API_KEY, WANDB_BASE_URL, HOME
# Globals written: WANDB_BASE_URL_DEFAULT, HOST, API_KEY (plain shell variables);
#                  WANDB_BASE_URL and WANDB_API_KEY are exported when a key is
#                  found.
# Returns:         always 0. The lookup is best effort, so an unreadable or
#                  malformed ~/.netrc only produces a warning instead of
#                  aborting the whole script under `set -e`.
load_wandb_api_key_from_netrc() {
  # An explicit key always wins, and without a netrc file there is no source.
  if [ -n "${WANDB_API_KEY:-}" ] || [ ! -f "$HOME/.netrc" ]; then
    return 0
  fi

  WANDB_BASE_URL_DEFAULT="${WANDB_BASE_URL:-http://localhost:8080}"

  # Only URLs mentioning "localhost" are handled here.
  case "$WANDB_BASE_URL_DEFAULT" in
    *localhost*) ;;
    *) return 0 ;;
  esac

  # Reduce the URL to its authority part: drop the scheme, then the path.
  HOST="${WANDB_BASE_URL_DEFAULT#http://}"
  HOST="${HOST#https://}"
  HOST="${HOST%%/*}"

  # Look the host up as written in the URL first and then without the port,
  # because netrc entries may be stored under the bare hostname.
  API_KEY=$(python - "$HOST" <<'PY'
import sys
from netrc import netrc

host = sys.argv[1]
entries = netrc()
auth = entries.authenticators(host) or entries.authenticators(host.split(":")[0])
if auth and auth[2]:
    print(auth[2], end="")
PY
  ) || {
    API_KEY=""
    echo "[runmps_evals] Could not read a key from ~/.netrc; continuing without it." >&2
  }

  if [ -n "$API_KEY" ]; then
    export WANDB_BASE_URL="$WANDB_BASE_URL_DEFAULT"
    export WANDB_API_KEY="$API_KEY"
    echo "[runmps_evals] Loaded WANDB_API_KEY for $WANDB_BASE_URL from ~/.netrc"
  fi
}

load_wandb_api_key_from_netrc
# Print a new random identifier on stdout: the first 8 hexadecimal characters
# of a UUID4, followed by a newline.
# Arguments: none
# Returns:   the exit status of the Python interpreter.
generate_wandb_id() {
  python - <<'PY'
import uuid

print(uuid.uuid4().hex[:8])
PY
}
# Create the W&B identity for one eval stage and store it in two global
# variables named after the stage prefix.
# Arguments:
#   $1 - stage prefix, e.g. BASE; must be a valid shell identifier prefix
#   $2 - run name to store
# Globals written:
#   <prefix>_EVAL_WANDB_ID   - freshly generated id (see generate_wandb_id)
#   <prefix>_EVAL_WANDB_NAME - the run name, stored verbatim
# Returns: 0 on success; non-zero when no id could be generated, in which
#          case neither variable is touched.
set_eval_run() {
  local prefix="$1"
  local default_name="$2"
  local id_var="${prefix}_EVAL_WANDB_ID"
  local name_var="${prefix}_EVAL_WANDB_NAME"

  # Declared separately so a failing generator is not hidden by `local`.
  local current_id
  current_id=$(generate_wandb_id) || return
  [ -n "$current_id" ] || return 1

  # printf -v assigns by variable name without re-parsing the value, so run
  # names containing quotes or other shell metacharacters are stored intact.
  printf -v "$id_var" '%s' "$current_id"
  printf -v "$name_var" '%s' "$default_name"
}
# Batch geometry for the evals. SEQ_LEN and DEVICE_BATCH may be overridden from
# the environment; the remaining values are derived from them.
SEQ_LEN="${SEQ_LEN:-1024}"
DEVICE_BATCH="${DEVICE_BATCH:-16}"
TOTAL_BATCH=$((DEVICE_BATCH * SEQ_LEN))
# NOTE(review): EVAL_SEQUENCES is not referenced anywhere else in this script.
EVAL_SEQUENCES=10000
EVAL_BATCH_MULT=4
# Passed to scripts.base_loss as --split_tokens.
EVAL_TOKENS=$((TOTAL_BATCH * EVAL_BATCH_MULT))
# List the training steps for which a model checkpoint exists in a directory.
# Arguments:
#   $1 - directory to scan
# Outputs:
#   One integer step per line, ascending, taken from files named
#   model_<step>.pt. Files whose step part is not an integer are ignored.
#   Prints nothing when the directory does not exist.
# Returns: the exit status of the Python interpreter.
checkpoint_steps() {
  local dir="$1"
  python - "$dir" <<'PY'
import os
import sys

dir_path = sys.argv[1]
if not os.path.isdir(dir_path):
    sys.exit(0)

steps = []
for fname in os.listdir(dir_path):
    if not (fname.startswith("model_") and fname.endswith(".pt")):
        continue
    try:
        steps.append(int(fname.split("_")[1].split(".")[0]))
    except ValueError:
        continue

for step in sorted(steps):
    print(step)
PY
}
# Print the model tag that nanochat's find_largest_model() selects inside a
# checkpoints directory.
# Arguments:
#   $1 - checkpoints directory (e.g. "$NANOCHAT_BASE_DIR/base_checkpoints")
# Outputs:
#   The tag on stdout; nothing when the directory does not exist or
#   find_largest_model() raises FileNotFoundError.
# Returns: 0 in the "nothing found" cases, otherwise the exit status of the
#          Python interpreter.
checkpoint_tag() {
  local dir="$1"

  # A missing directory is answered right here, without starting Python and
  # importing the nanochat package.
  [ -d "$dir" ] || return 0

  python - "$dir" <<'PY'
import sys

from nanochat.checkpoint_manager import find_largest_model

try:
    tag = find_largest_model(sys.argv[1])
except FileNotFoundError:
    sys.exit(0)
print(tag)
PY
}
# Evaluate every base-model checkpoint of the tag reported by checkpoint_tag.
# For each step, scripts.base_loss runs first with the W&B run variables
# removed from its environment, then scripts.base_eval runs with them set.
# Globals read:    NANOCHAT_BASE_DIR, WANDB_RUN, WANDB_PROJECT, DEVICE_BATCH,
#                  EVAL_TOKENS
# Globals written: BASE_EVAL_WANDB_ID / BASE_EVAL_WANDB_NAME (via set_eval_run);
#                  WANDB_RUN_ID and WANDB_EVAL_RUN are exported for the
#                  duration of the loop and unset afterwards; WANDB_PROJECT is
#                  exported.
# Returns: 0 when there is nothing to evaluate; otherwise the status of the
#          last command (a failing eval aborts the script under `set -e`).
run_base_evals() {
  local ckpt_root="$NANOCHAT_BASE_DIR/base_checkpoints"

  # Declarations are kept apart from the assignments so that a failing helper
  # is not masked by `local`.
  local tag
  tag=$(checkpoint_tag "$ckpt_root") || return
  if [ -z "$tag" ]; then
    return 0
  fi

  local dir="$ckpt_root/$tag"
  local steps
  steps=$(checkpoint_steps "$dir") || return
  if [ -z "$steps" ]; then
    return 0
  fi

  echo "[runmps_evals] Running base eval on checkpoints in $dir..."
  set_eval_run BASE "${WANDB_RUN:-base}-eval"
  export WANDB_RUN_ID="$BASE_EVAL_WANDB_ID"
  export WANDB_EVAL_RUN="$BASE_EVAL_WANDB_NAME"
  export WANDB_PROJECT

  # The step list is fed through fd 3 so the Python processes keep the
  # script's own stdin and cannot consume the remaining steps.
  local step
  while read -r step <&3; do
    [ -z "$step" ] && continue
    echo " • base checkpoint step $step"
    # run base_loss to print sample completions, but keep it off W&B:
    # the subshell drops the run variables for this one command only.
    (
      unset WANDB_RUN_ID WANDB_EVAL_RUN
      python -m scripts.base_loss \
        --device_batch_size="$DEVICE_BATCH" \
        --split_tokens="$EVAL_TOKENS" \
        --model_step="$step"
    )
    python -m scripts.base_eval \
      --max-per-task=16 \
      --model-step="$step"
  done 3<<< "$steps"

  unset WANDB_RUN_ID
  unset WANDB_EVAL_RUN
}
# Run scripts.chat_eval with --source=mid on every checkpoint step of the tag
# reported by checkpoint_tag for the mid_checkpoints directory.
# Globals read:    NANOCHAT_BASE_DIR, WANDB_RUN, WANDB_PROJECT, DEVICE_BATCH
# Globals written: MID_EVAL_WANDB_ID / MID_EVAL_WANDB_NAME (via set_eval_run);
#                  WANDB_RUN_ID and WANDB_EVAL_RUN are exported for the
#                  duration of the loop and unset afterwards; WANDB_PROJECT is
#                  exported.
# Returns: 0 when there is nothing to evaluate; otherwise the status of the
#          last command (a failing eval aborts the script under `set -e`).
run_mid_evals() {
  local ckpt_root="$NANOCHAT_BASE_DIR/mid_checkpoints"

  # Declarations are kept apart from the assignments so that a failing helper
  # is not masked by `local`.
  local tag
  tag=$(checkpoint_tag "$ckpt_root") || return
  if [ -z "$tag" ]; then
    return 0
  fi

  local dir="$ckpt_root/$tag"
  local steps
  steps=$(checkpoint_steps "$dir") || return
  if [ -z "$steps" ]; then
    return 0
  fi

  echo "[runmps_evals] Running mid-stage chat eval on checkpoints in $dir..."
  set_eval_run MID "${WANDB_RUN:-mid}-mid-eval"
  export WANDB_RUN_ID="$MID_EVAL_WANDB_ID"
  export WANDB_EVAL_RUN="$MID_EVAL_WANDB_NAME"
  export WANDB_PROJECT

  # The step list is fed through fd 3 so the Python process keeps the
  # script's own stdin and cannot consume the remaining steps.
  local step
  while read -r step <&3; do
    [ -z "$step" ] && continue
    echo " • mid checkpoint step $step"
    python -m scripts.chat_eval \
      --source=mid \
      --step="$step" \
      --batch-size="$DEVICE_BATCH" \
      --max-new-tokens=128 \
      --max-problems=20
  done 3<<< "$steps"

  unset WANDB_RUN_ID
  unset WANDB_EVAL_RUN
}
# Run scripts.chat_eval with --source=sft on every checkpoint step of the tag
# reported by checkpoint_tag for the chatsft_checkpoints directory.
# Globals read:    NANOCHAT_BASE_DIR, WANDB_RUN, WANDB_PROJECT, DEVICE_BATCH
# Globals written: SFT_EVAL_WANDB_ID / SFT_EVAL_WANDB_NAME (via set_eval_run);
#                  WANDB_RUN_ID and WANDB_EVAL_RUN are exported for the
#                  duration of the loop and unset afterwards; WANDB_PROJECT is
#                  exported.
# Returns: 0 when there is nothing to evaluate; otherwise the status of the
#          last command (a failing eval aborts the script under `set -e`).
run_sft_evals() {
  local ckpt_root="$NANOCHAT_BASE_DIR/chatsft_checkpoints"

  # Declarations are kept apart from the assignments so that a failing helper
  # is not masked by `local`.
  local tag
  tag=$(checkpoint_tag "$ckpt_root") || return
  if [ -z "$tag" ]; then
    return 0
  fi

  local dir="$ckpt_root/$tag"
  local steps
  steps=$(checkpoint_steps "$dir") || return
  if [ -z "$steps" ]; then
    return 0
  fi

  echo "[runmps_evals] Running SFT chat eval on checkpoints in $dir..."
  set_eval_run SFT "${WANDB_RUN:-sft}-sft-eval"
  export WANDB_RUN_ID="$SFT_EVAL_WANDB_ID"
  export WANDB_EVAL_RUN="$SFT_EVAL_WANDB_NAME"
  export WANDB_PROJECT

  # The step list is fed through fd 3 so the Python process keeps the
  # script's own stdin and cannot consume the remaining steps.
  local step
  while read -r step <&3; do
    [ -z "$step" ] && continue
    echo " • SFT checkpoint step $step"
    python -m scripts.chat_eval \
      --source=sft \
      --step="$step" \
      --batch-size="$DEVICE_BATCH" \
      --max-new-tokens=128 \
      --max-problems=20
  done 3<<< "$steps"

  unset WANDB_RUN_ID
  unset WANDB_EVAL_RUN
}
# Evaluate the three checkpoint families in order; each function returns
# quietly when it finds no checkpoints.
run_base_evals
run_mid_evals
run_sft_evals

# Rebuild the report once all evals have finished.
echo "[runmps_evals] Regenerating report with fresh metrics..."
python -m nanochat.report generate

echo "[runmps_evals] Done."