mirror of
https://github.com/karpathy/nanochat.git
synced 2026-03-13 16:33:41 +00:00
Add scripts for running evaluations and training with W&B integration
- Added `dev/runmps_evals.sh` for evaluating checkpoints and logging results to W&B.
- Introduced `dev/runmps.sh` for orchestrating training stages with W&B support.
- Updated `.gitignore` to include `wandb/` and `.runmps_wandb_ids`.
- Changed permissions for `dev/runcpu.sh` and added executable flag.
- Enhanced existing scripts to log metrics to W&B during training and evaluation processes.
This commit is contained in:
parent
c75fe54aa7
commit
b1d49aade5
4
.gitignore
vendored
4
.gitignore
vendored
|
|
@ -4,4 +4,6 @@ __pycache__/
|
|||
rustbpe/target/
|
||||
dev-ignore/
|
||||
report.md
|
||||
eval_bundle/
|
||||
eval_bundle/
|
||||
wandb/
|
||||
.runmps_wandb_ids
|
||||
|
|
|
|||
0
dev/runcpu.sh
Normal file → Executable file
0
dev/runcpu.sh
Normal file → Executable file
336
dev/runmps.sh
Executable file
336
dev/runmps.sh
Executable file
|
|
@ -0,0 +1,336 @@
|
|||
#!/bin/bash

# Showing an example run for exercising some of the code paths on the CPU (or MPS on Macbooks)
# Run as:
# bash dev/runmps.sh

# NOTE: Training LLMs requires GPU compute and $$$. You will not get far on your Macbook.
# Think of this run as educational/fun demo, not something you should expect to work well.
# This is also why I hide this script away in dev/

# Stage selection (allow running only a subset, e.g. --stage=sft or --from=mid)
RUN_BASE=1
RUN_MID=1
RUN_SFT=1
RUN_REPORT=1
STAGE_ONLY=""
FROM_STAGE=""

while [[ $# -gt 0 ]]; do
  case "$1" in
    --stage=*)
      STAGE_ONLY="${1#*=}"
      ;;
    --base|--mid|--sft|--report)
      STAGE_ONLY="${1#--}"
      ;;
    --from=*)
      FROM_STAGE="${1#*=}"
      ;;
    --from-base|--from-mid|--from-sft)
      FROM_STAGE="${1#--from-}"
      ;;
    --help|-h)
      cat <<'EOF'
Usage: bash dev/runmps.sh [options]

Options:
  --stage=<base|mid|sft|report>  Run only the specified stage.
  --from=<base|mid|sft>          Run from the specified stage through the end.
  --help                         Show this help message.

Environment variables (same as before) control batch sizes, WANDB run names, etc.
EOF
      exit 0
      ;;
    *)
      echo "Unknown argument: $1" >&2
      exit 1
      ;;
  esac
  shift
done

# Resolve --from=<stage>: disable everything, then re-enable that stage onward.
if [[ -n "$FROM_STAGE" ]]; then
  RUN_BASE=0
  RUN_MID=0
  RUN_SFT=0
  RUN_REPORT=0
  case "$FROM_STAGE" in
    base)
      RUN_BASE=1
      RUN_MID=1
      RUN_SFT=1
      RUN_REPORT=1
      ;;
    mid)
      RUN_MID=1
      RUN_SFT=1
      ;;
    sft)
      RUN_SFT=1
      ;;
    *)
      echo "Unknown --from stage: $FROM_STAGE" >&2
      exit 1
      ;;
  esac
fi

# Resolve --stage=<stage>: run exactly that one stage.
if [[ -n "$STAGE_ONLY" ]]; then
  RUN_BASE=0
  RUN_MID=0
  RUN_SFT=0
  RUN_REPORT=0
  case "$STAGE_ONLY" in
    base)
      RUN_BASE=1
      ;;
    mid)
      RUN_MID=1
      ;;
    sft)
      RUN_SFT=1
      ;;
    report)
      RUN_REPORT=1
      ;;
    *)
      echo "Unknown --stage value: $STAGE_ONLY" >&2
      exit 1
      ;;
  esac
fi

if [[ -n "$STAGE_ONLY" || -n "$FROM_STAGE" ]]; then
  # avoid regenerating reports when running a subset unless specifically requested
  if [[ "$STAGE_ONLY" != "report" && "$FROM_STAGE" != "base" ]]; then
    RUN_REPORT=0
  fi
fi
|
||||
|
||||
# all the setup stuff: threads, cache dir, uv-managed venv, and W&B mode defaults
export OMP_NUM_THREADS=1
NANOCHAT_BASE_DIR="$HOME/.cache/nanochat"
mkdir -p "$NANOCHAT_BASE_DIR"
# install uv only when it is not already on PATH
command -v uv &> /dev/null || curl -LsSf https://astral.sh/uv/install.sh | sh
[ -d ".venv" ] || uv venv
uv sync --extra cpu
source .venv/bin/activate
# "dummy" disables W&B logging downstream
if [ -z "${WANDB_RUN:-}" ]; then
    WANDB_RUN=dummy
fi
# only force online mode when the user actually asked for a real W&B run
if [ "$WANDB_RUN" != "dummy" ] && [ -z "${WANDB_MODE:-}" ]; then
    export WANDB_MODE=online
fi
|
||||
|
||||
# ----------------------------------------------------------------------------
# Batch/sequence configuration; every `:=` knob can be overridden via the env.
: "${BASE_DEPTH:=4}"
: "${SEQ_LEN:=1024}"
: "${DEVICE_BATCH:=16}"
: "${TOTAL_BATCH:=$((DEVICE_BATCH * SEQ_LEN))}" # tokens per optimizer step
EVAL_SEQUENCES=10000
EVAL_STEPS=$(((EVAL_SEQUENCES + DEVICE_BATCH - 1) / DEVICE_BATCH))
EVAL_BATCH_MULT=4 # evaluate on 4 full batches
EVAL_TOKENS=$((TOTAL_BATCH * EVAL_BATCH_MULT))
MID_NUM_STEPS=6144
: "${SFT_NUM_STEPS:=3072}"
: "${CHECKPOINT_EVERY_SEQ:=10000}"
: "${RUN_STAGE_EVALS:=0}"
: "${WANDB_PROJECT:=nanochat}"
: "${TARGET_PARAM_DATA_RATIO:=20}"
: "${SFT_DEVICE_BATCH:=$DEVICE_BATCH}"
: "${SFT_TARGET_EXAMPLES:=$DEVICE_BATCH}"
: "${SFT_EVAL_EVERY:=0}"
: "${SFT_EVAL_STEPS:=0}"
: "${SFT_EVAL_METRICS_EVERY:=0}"
: "${SFT_EVAL_METRICS_MAX:=0}"

# Seed the state file that dev/runmps_evals.sh later sources for W&B run ids.
STATE_FILE=".runmps_wandb_ids"
printf '' > "$STATE_FILE"
echo "WANDB_PROJECT=$WANDB_PROJECT" >> "$STATE_FILE"
|
||||
|
||||
# Emit a short random hex id suitable as a W&B run id.
generate_wandb_id() {
    python -c 'import uuid; print(uuid.uuid4().hex[:8])'
}
|
||||
|
||||
# Translate CHECKPOINT_EVERY_SEQ (sequences) into optimizer steps per stage.
# checkpoint_steps_for EVERY_SEQ SEQS_PER_STEP
#   echoes ceil(EVERY_SEQ / SEQS_PER_STEP), or 0 when EVERY_SEQ <= 0 (disabled).
checkpoint_steps_for() {
    local every_seq=$1
    local seqs_per_step=$2
    if [ "$seqs_per_step" -le 0 ]; then seqs_per_step=1; fi
    if [ "$every_seq" -le 0 ]; then
        echo 0
    else
        echo $(((every_seq + seqs_per_step - 1) / seqs_per_step))
    fi
}

# base / mid stages pack TOTAL_BATCH tokens per step as SEQ_LEN-token sequences
BASE_SEQS_PER_STEP=$((${TOTAL_BATCH:-0} / ${SEQ_LEN:-1}))
if [ "$BASE_SEQS_PER_STEP" -le 0 ]; then BASE_SEQS_PER_STEP=1; fi
BASE_CHECKPOINT_STEPS=$(checkpoint_steps_for "$CHECKPOINT_EVERY_SEQ" "$BASE_SEQS_PER_STEP")

MID_SEQS_PER_STEP=$((${TOTAL_BATCH:-0} / ${SEQ_LEN:-1}))
if [ "$MID_SEQS_PER_STEP" -le 0 ]; then MID_SEQS_PER_STEP=1; fi
MID_CHECKPOINT_STEPS=$(checkpoint_steps_for "$CHECKPOINT_EVERY_SEQ" "$MID_SEQS_PER_STEP")

# SFT counts whole examples, so a step is simply the device batch size
SFT_SEQS_PER_STEP=$SFT_DEVICE_BATCH
if [ "$SFT_SEQS_PER_STEP" -le 0 ]; then SFT_SEQS_PER_STEP=1; fi
SFT_CHECKPOINT_STEPS=$(checkpoint_steps_for "$CHECKPOINT_EVERY_SEQ" "$SFT_SEQS_PER_STEP")
|
||||
|
||||
# Auto-populate WANDB_API_KEY from ~/.netrc when talking to a local W&B server.
# Mirrors the helper used in TinyRecursiveModels/pretrain_text.py so we can log
# to a self-hosted instance without manual export each time.
# (uses ${VAR:-} expansions for consistency with dev/runmps_evals.sh)
if [ -z "${WANDB_API_KEY:-}" ] && [ -f "$HOME/.netrc" ]; then
    # Allow custom WANDB_BASE_URL; default to localhost if user sets WANDB
    WANDB_BASE_URL_DEFAULT=${WANDB_BASE_URL:-http://localhost:8080}
    if printf '%s' "$WANDB_BASE_URL_DEFAULT" | grep -q "localhost"; then
        # strip scheme and path to get the bare host for the netrc lookup
        HOST=$(printf '%s\n' "$WANDB_BASE_URL_DEFAULT" | sed -E 's#https?://##' | cut -d/ -f1)
        API_KEY=$(python - "$HOST" <<'PY'
import sys
from netrc import netrc

host = sys.argv[1]
auth = netrc().authenticators(host)
if auth and auth[2]:
    print(auth[2], end="")
PY
)
        if [ -n "$API_KEY" ]; then
            export WANDB_BASE_URL="$WANDB_BASE_URL_DEFAULT"
            export WANDB_API_KEY="$API_KEY"
            echo "[runmps] Loaded WANDB_API_KEY for $WANDB_BASE_URL from ~/.netrc"
        fi
    fi
fi
|
||||
# Build the rustbpe extension; install the Rust toolchain only when missing
# (same guard pattern as the uv bootstrap above, avoids re-running rustup every time).
command -v cargo &> /dev/null || curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
if [ -f "$HOME/.cargo/env" ]; then
    source "$HOME/.cargo/env"
fi
uv run maturin develop --release --manifest-path rustbpe/Cargo.toml

# Download the evaluation bundle once and cache it under the nanochat base dir.
EVAL_BUNDLE_URL=https://karpathy-public.s3.us-west-2.amazonaws.com/eval_bundle.zip
if [ ! -d "$NANOCHAT_BASE_DIR/eval_bundle" ]; then
    curl -L -o eval_bundle.zip "$EVAL_BUNDLE_URL"
    unzip -q eval_bundle.zip
    rm eval_bundle.zip
    mv eval_bundle "$NANOCHAT_BASE_DIR"
fi
|
||||
|
||||
if (( RUN_BASE )); then
    # start this run with a clean report
    python -m nanochat.report reset

    # train tokenizer on ~2B characters (download full shard set for extended training)
    python -m nanochat.dataset -n 240
    python -m scripts.tok_train --max_chars=2000000000
    python -m scripts.tok_eval

    # base pretraining of a tiny model; step counts and batch geometry come from
    # the env-configurable knobs above — expect demo-quality results on CPU/MPS
    if [ "$WANDB_RUN" != "dummy" ]; then
        # fix the run id (resumable) and record it for dev/runmps_evals.sh
        BASE_WANDB_ID=${BASE_WANDB_ID:-$(generate_wandb_id)}
        {
            echo "BASE_WANDB_ID=$BASE_WANDB_ID"
            echo "BASE_WANDB_NAME=$WANDB_RUN"
        } >> "$STATE_FILE"
        export WANDB_RUN_ID=$BASE_WANDB_ID
        export WANDB_EVAL_RUN=$WANDB_RUN
        export WANDB_PROJECT
    fi

    python -m scripts.base_train \
        --depth="$BASE_DEPTH" \
        --max_seq_len="$SEQ_LEN" \
        --device_batch_size="$DEVICE_BATCH" \
        --total_batch_size="$TOTAL_BATCH" \
        --target_param_data_ratio="$TARGET_PARAM_DATA_RATIO" \
        --run="$WANDB_RUN" \
        --eval_every="$EVAL_STEPS" \
        --eval_tokens="$EVAL_TOKENS" \
        --core_metric_every=-1 \
        --sample_every=-1 \
        --checkpoint_every_steps="$BASE_CHECKPOINT_STEPS"

    if [ "$WANDB_RUN" != "dummy" ]; then
        unset WANDB_RUN_ID WANDB_EVAL_RUN
    fi

    if [ "$RUN_STAGE_EVALS" = "1" ]; then
        python -m scripts.base_loss --device_batch_size="$DEVICE_BATCH" --split_tokens="$EVAL_TOKENS"
        python -m scripts.base_eval --max-per-task=16
    fi
fi
|
||||
|
||||
if (( RUN_MID )); then
    # midtraining stage
    if [ "$WANDB_RUN" != "dummy" ]; then
        # fix the run id (resumable) and record it for dev/runmps_evals.sh
        MID_WANDB_ID=${MID_WANDB_ID:-$(generate_wandb_id)}
        {
            echo "MID_WANDB_ID=$MID_WANDB_ID"
            echo "MID_WANDB_NAME=${WANDB_RUN}-mid"
        } >> "$STATE_FILE"
        export WANDB_RUN_ID=$MID_WANDB_ID
        export WANDB_EVAL_RUN="${WANDB_RUN}-mid"
        export WANDB_PROJECT
    fi

    python -m scripts.mid_train \
        --max_seq_len="$SEQ_LEN" \
        --device_batch_size="$DEVICE_BATCH" \
        --total_batch_size="$TOTAL_BATCH" \
        --run="${WANDB_RUN}-mid" \
        --eval_every="$EVAL_STEPS" \
        --eval_tokens="$EVAL_TOKENS" \
        --checkpoint_every_steps="$MID_CHECKPOINT_STEPS" \
        --num_iterations="$MID_NUM_STEPS"

    if [ "$WANDB_RUN" != "dummy" ]; then
        unset WANDB_RUN_ID WANDB_EVAL_RUN
    fi

    if [ "$RUN_STAGE_EVALS" = "1" ]; then
        # eval results will be terrible; this only exercises the code paths
        python -m scripts.chat_eval --source=mid --max-new-tokens=128 --max-problems=20
    fi
fi
|
||||
|
||||
if (( RUN_SFT )); then
    # supervised finetuning stage
    if [ "$WANDB_RUN" != "dummy" ]; then
        # fix the run id (resumable) and record it for dev/runmps_evals.sh
        SFT_WANDB_ID=${SFT_WANDB_ID:-$(generate_wandb_id)}
        {
            echo "SFT_WANDB_ID=$SFT_WANDB_ID"
            echo "SFT_WANDB_NAME=${WANDB_RUN}-sft"
        } >> "$STATE_FILE"
        export WANDB_RUN_ID=$SFT_WANDB_ID
        export WANDB_EVAL_RUN="${WANDB_RUN}-sft"
        export WANDB_PROJECT
    fi

    python -m scripts.chat_sft \
        --device_batch_size="$SFT_DEVICE_BATCH" \
        --target_examples_per_step="$SFT_TARGET_EXAMPLES" \
        --run="${WANDB_RUN}-sft" \
        --num_iterations="$SFT_NUM_STEPS" \
        --eval_every="$SFT_EVAL_EVERY" \
        --eval_steps="$SFT_EVAL_STEPS" \
        --eval_metrics_every="$SFT_EVAL_METRICS_EVERY" \
        --eval_metrics_max_problems="$SFT_EVAL_METRICS_MAX" \
        --checkpoint_every_steps="$SFT_CHECKPOINT_STEPS"

    if [ "$WANDB_RUN" != "dummy" ]; then
        unset WANDB_RUN_ID WANDB_EVAL_RUN
    fi
fi
|
||||
|
||||
# Chat CLI
# python -m scripts.chat_cli -p "Why is the sky blue?"

# Chat Web
# python -m scripts.chat_web

# Final report (skipped when a stage subset disabled it above).
if (( RUN_REPORT )); then
    python -m nanochat.report generate
fi

# When inline evals were off, point the user at the offline eval helper.
if [[ "$RUN_STAGE_EVALS" != "1" ]] && (( RUN_BASE || RUN_MID || RUN_SFT )); then
    echo "[runmps] Inline evals disabled. Run 'bash dev/runmps_evals.sh' to compute metrics from saved checkpoints."
fi
|
||||
223
dev/runmps_evals.sh
Executable file
223
dev/runmps_evals.sh
Executable file
|
|
@ -0,0 +1,223 @@
|
|||
#!/bin/bash

set -euo pipefail
shopt -s nullglob

# Evaluate every checkpoint generated by dev/runmps.sh and log back to the
# original W&B runs.
# Usage:
#   bash dev/runmps_evals.sh

STATE_FILE=".runmps_wandb_ids"

# The training script must have been run first so the venv exists.
if [[ ! -d .venv ]]; then
    echo "[runmps_evals] .venv not found. Run dev/runmps.sh first." >&2
    exit 1
fi

source .venv/bin/activate

# Pick up the W&B run ids recorded by dev/runmps.sh, when available.
if [[ -f "$STATE_FILE" ]]; then
    . "$STATE_FILE"
fi

export OMP_NUM_THREADS=${OMP_NUM_THREADS:-1}
export NANOCHAT_BASE_DIR=${NANOCHAT_BASE_DIR:-$HOME/.cache/nanochat}
export WANDB_PROJECT=${WANDB_PROJECT:-nanochat}
WANDB_RUN=${WANDB_RUN:-dummy}
# only force online mode for a real (non-dummy) W&B run
if [[ "$WANDB_RUN" != "dummy" && -z "${WANDB_MODE:-}" ]]; then
    export WANDB_MODE=online
fi
||||
|
||||
# Auto-populate WANDB_API_KEY from ~/.netrc when talking to a local W&B server,
# mirroring the training script behaviour so evals can log without manual export.
if [[ -z "${WANDB_API_KEY:-}" && -f "$HOME/.netrc" ]]; then
    WANDB_BASE_URL_DEFAULT=${WANDB_BASE_URL:-http://localhost:8080}
    if printf '%s' "$WANDB_BASE_URL_DEFAULT" | grep -q "localhost"; then
        # bare host (scheme and path stripped) for the netrc lookup
        HOST=$(printf '%s\n' "$WANDB_BASE_URL_DEFAULT" | sed -E 's#https?://##' | cut -d/ -f1)
        API_KEY=$(python - "$HOST" <<'PY'
import sys
from netrc import netrc

auth = netrc().authenticators(sys.argv[1])
if auth and auth[2]:
    print(auth[2], end="")
PY
)
        if [[ -n "$API_KEY" ]]; then
            export WANDB_BASE_URL="$WANDB_BASE_URL_DEFAULT"
            export WANDB_API_KEY="$API_KEY"
            echo "[runmps_evals] Loaded WANDB_API_KEY for $WANDB_BASE_URL from ~/.netrc"
        fi
    fi
fi
||||
|
||||
# Emit a short random hex id suitable as a W&B run id.
generate_wandb_id() {
    python -c 'import uuid; print(uuid.uuid4().hex[:8])'
}
|
||||
|
||||
# set_eval_run PREFIX DEFAULT_NAME
# Populates <PREFIX>_EVAL_WANDB_ID with a fresh random id and
# <PREFIX>_EVAL_WANDB_NAME with DEFAULT_NAME, in the caller's scope.
set_eval_run() {
    local prefix="$1"
    local default_name="$2"
    local id_var="${prefix}_EVAL_WANDB_ID"
    local name_var="${prefix}_EVAL_WANDB_NAME"
    # declare separately so the command substitution's exit status is not masked
    local current_id
    current_id=$(generate_wandb_id)
    # printf -v sidesteps the quoting/injection pitfalls of eval
    # (the old eval form broke on names containing a single quote)
    printf -v "$id_var" '%s' "$current_id"
    printf -v "$name_var" '%s' "$default_name"
}
|
||||
|
||||
# Mirror the batch geometry used during training so eval token counts line up.
: "${SEQ_LEN:=1024}"
: "${DEVICE_BATCH:=16}"
TOTAL_BATCH=$((DEVICE_BATCH * SEQ_LEN))
EVAL_SEQUENCES=10000
EVAL_BATCH_MULT=4
EVAL_TOKENS=$((TOTAL_BATCH * EVAL_BATCH_MULT))
|
||||
|
||||
# checkpoint_steps DIR
# Print the numeric step of every model_<step>.pt checkpoint in DIR, one per
# line in ascending order; prints nothing if DIR is missing.
checkpoint_steps() {
    local dir="$1"
    python - "$dir" <<'PY'
import os, sys

dir_path = sys.argv[1]
if not os.path.isdir(dir_path):
    sys.exit(0)

steps = []
for fname in os.listdir(dir_path):
    if not (fname.startswith("model_") and fname.endswith(".pt")):
        continue
    try:
        steps.append(int(fname.split("_")[1].split(".")[0]))
    except ValueError:
        continue

for step in sorted(steps):
    print(step)
PY
}
|
||||
|
||||
# checkpoint_tag DIR
# Print the tag (model subdirectory name) of the largest model under DIR;
# prints nothing when DIR is missing or holds no checkpoints.
checkpoint_tag() {
    local dir="$1"
    python - "$dir" <<'PY'
import os, sys

from nanochat.checkpoint_manager import find_largest_model

base_dir = sys.argv[1]
if not os.path.isdir(base_dir):
    sys.exit(0)
try:
    print(find_largest_model(base_dir))
except FileNotFoundError:
    sys.exit(0)
PY
}
|
||||
|
||||
# Evaluate every saved base-model checkpoint (loss + CORE) and log to W&B.
run_base_evals() {
    local tag dir steps step saved_wandb_run_id saved_wandb_eval_run
    tag=$(checkpoint_tag "$NANOCHAT_BASE_DIR/base_checkpoints") || true
    [ -n "$tag" ] || return 0
    dir="$NANOCHAT_BASE_DIR/base_checkpoints/$tag"
    steps=$(checkpoint_steps "$dir") || true
    [ -n "$steps" ] || return 0

    echo "[runmps_evals] Running base eval on checkpoints in $dir..."
    set_eval_run BASE "${WANDB_RUN:-base}-eval"
    export WANDB_RUN_ID=$BASE_EVAL_WANDB_ID
    export WANDB_EVAL_RUN=$BASE_EVAL_WANDB_NAME
    export WANDB_PROJECT

    while read -r step; do
        [ -n "$step" ] || continue
        echo " • base checkpoint step $step"
        # base_loss prints sample completions; keep it off W&B by hiding
        # the run ids for the duration of the call
        saved_wandb_run_id=${WANDB_RUN_ID:-}
        saved_wandb_eval_run=${WANDB_EVAL_RUN:-}
        unset WANDB_RUN_ID WANDB_EVAL_RUN
        python -m scripts.base_loss \
            --device_batch_size="$DEVICE_BATCH" \
            --split_tokens="$EVAL_TOKENS" \
            --model_step="$step"
        if [ -n "$saved_wandb_run_id" ]; then
            export WANDB_RUN_ID=$saved_wandb_run_id
        fi
        if [ -n "$saved_wandb_eval_run" ]; then
            export WANDB_EVAL_RUN=$saved_wandb_eval_run
        fi
        python -m scripts.base_eval \
            --max-per-task=16 \
            --model-step="$step"
    done <<< "$steps"

    unset WANDB_RUN_ID WANDB_EVAL_RUN
}
|
||||
|
||||
# Evaluate every saved mid-training checkpoint via chat_eval and log to W&B.
run_mid_evals() {
    local tag dir steps step
    tag=$(checkpoint_tag "$NANOCHAT_BASE_DIR/mid_checkpoints") || true
    [ -n "$tag" ] || return 0
    dir="$NANOCHAT_BASE_DIR/mid_checkpoints/$tag"
    steps=$(checkpoint_steps "$dir") || true
    [ -n "$steps" ] || return 0

    echo "[runmps_evals] Running mid-stage chat eval on checkpoints in $dir..."
    set_eval_run MID "${WANDB_RUN:-mid}-mid-eval"
    export WANDB_RUN_ID=$MID_EVAL_WANDB_ID
    export WANDB_EVAL_RUN=$MID_EVAL_WANDB_NAME
    export WANDB_PROJECT

    while read -r step; do
        [ -n "$step" ] || continue
        echo " • mid checkpoint step $step"
        python -m scripts.chat_eval \
            --source=mid \
            --step="$step" \
            --batch-size="$DEVICE_BATCH" \
            --max-new-tokens=128 \
            --max-problems=20
    done <<< "$steps"

    unset WANDB_RUN_ID WANDB_EVAL_RUN
}
|
||||
|
||||
# Evaluate every saved SFT checkpoint via chat_eval and log to W&B.
run_sft_evals() {
    local tag dir steps step
    tag=$(checkpoint_tag "$NANOCHAT_BASE_DIR/chatsft_checkpoints") || true
    [ -n "$tag" ] || return 0
    dir="$NANOCHAT_BASE_DIR/chatsft_checkpoints/$tag"
    steps=$(checkpoint_steps "$dir") || true
    [ -n "$steps" ] || return 0

    echo "[runmps_evals] Running SFT chat eval on checkpoints in $dir..."
    set_eval_run SFT "${WANDB_RUN:-sft}-sft-eval"
    export WANDB_RUN_ID=$SFT_EVAL_WANDB_ID
    export WANDB_EVAL_RUN=$SFT_EVAL_WANDB_NAME
    export WANDB_PROJECT

    while read -r step; do
        [ -n "$step" ] || continue
        echo " • SFT checkpoint step $step"
        python -m scripts.chat_eval \
            --source=sft \
            --step="$step" \
            --batch-size="$DEVICE_BATCH" \
            --max-new-tokens=128 \
            --max-problems=20
    done <<< "$steps"

    unset WANDB_RUN_ID WANDB_EVAL_RUN
}
|
||||
|
||||
# Run the three stage evals in training order, then rebuild the report.
run_base_evals
run_mid_evals
run_sft_evals

echo "[runmps_evals] Regenerating report with fresh metrics..."
python -m nanochat.report generate

echo "[runmps_evals] Done."
||||
|
|
@ -19,8 +19,9 @@ from contextlib import nullcontext
|
|||
|
||||
import pandas as pd
|
||||
import torch
|
||||
import wandb
|
||||
|
||||
from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, autodetect_device_type
|
||||
from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, autodetect_device_type, DummyWandb
|
||||
from nanochat.tokenizer import HuggingFaceTokenizer
|
||||
from nanochat.checkpoint_manager import load_model
|
||||
from nanochat.core_eval import evaluate_task
|
||||
|
|
@ -123,12 +124,15 @@ def main():
|
|||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--hf-path', type=str, default=None, help='HuggingFace model path to evaluate')
|
||||
parser.add_argument('--max-per-task', type=int, default=-1, help='Max examples per task to evaluate (-1 = disable)')
|
||||
parser.add_argument('--model-step', type=int, default=None, help='Checkpoint step to evaluate when using local models')
|
||||
args = parser.parse_args()
|
||||
|
||||
# distributed / precision setup
|
||||
device_type = autodetect_device_type()
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
|
||||
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
|
||||
wandb_run = DummyWandb()
|
||||
use_wandb = bool(os.environ.get("WANDB_RUN_ID")) and args.hf_path is None and ddp_rank == 0
|
||||
|
||||
# Load model and tokenizer from command line or from file system
|
||||
if args.hf_path is not None:
|
||||
|
|
@ -140,24 +144,32 @@ def main():
|
|||
model_slug = hf_path.replace("/", "-") # for the output csv file
|
||||
else:
|
||||
# load a local model from the file system
|
||||
model, tokenizer, meta = load_model("base", device, phase="eval")
|
||||
model, tokenizer, meta = load_model("base", device, phase="eval", step=args.model_step)
|
||||
model_name = f"base_model (step {meta['step']})" # just for logging
|
||||
model_slug = f"base_model_{meta['step']:06d}" # for the output csv file
|
||||
if use_wandb:
|
||||
wandb_kwargs = {
|
||||
"project": os.environ.get("WANDB_PROJECT", "nanochat"),
|
||||
"name": os.environ.get("WANDB_EVAL_RUN", model_name),
|
||||
"id": os.environ.get("WANDB_RUN_ID"),
|
||||
"resume": "allow",
|
||||
"reinit": True,
|
||||
}
|
||||
wandb_kwargs = {k: v for k, v in wandb_kwargs.items() if v is not None}
|
||||
wandb_run = wandb.init(**wandb_kwargs)
|
||||
|
||||
# Evaluate the model
|
||||
with autocast_ctx:
|
||||
out = evaluate_model(model, tokenizer, device, max_per_task=args.max_per_task)
|
||||
|
||||
# Write out the results to a csv file
|
||||
core_metric = None
|
||||
centered_results = {}
|
||||
results = out["results"]
|
||||
centered_results = out["centered_results"]
|
||||
core_metric = out["core_metric"]
|
||||
if ddp_rank == 0:
|
||||
base_dir = get_base_dir()
|
||||
output_csv_path = os.path.join(base_dir, "base_eval", f"{model_slug}.csv")
|
||||
os.makedirs(os.path.dirname(output_csv_path), exist_ok=True)
|
||||
results = out["results"]
|
||||
centered_results = out["centered_results"]
|
||||
core_metric = out["core_metric"]
|
||||
with open(output_csv_path, 'w') as f:
|
||||
f.write(f"{'Task':<35}, {'Accuracy':<10}, {'Centered':<10}\n")
|
||||
for label in results:
|
||||
|
|
@ -180,6 +192,15 @@ def main():
|
|||
centered_results, # the full table
|
||||
])
|
||||
|
||||
if use_wandb:
|
||||
wandb_payload = {"core/metric": core_metric}
|
||||
wandb_payload.update({f"core/{k}": v for k, v in centered_results.items()})
|
||||
wandb_payload.update({f"core_raw/{k}": v for k, v in results.items()})
|
||||
for key, value in wandb_payload.items():
|
||||
wandb_run.summary[key] = value
|
||||
wandb_run.log(wandb_payload, step=meta["step"])
|
||||
wandb_run.finish()
|
||||
|
||||
compute_cleanup()
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -9,6 +9,9 @@ torchrun --standalone --nproc_per_node=8 -m scripts.base_loss
|
|||
import os
|
||||
from contextlib import nullcontext
|
||||
import torch
|
||||
import wandb
|
||||
|
||||
from nanochat.common import DummyWandb
|
||||
from nanochat.checkpoint_manager import load_model
|
||||
from nanochat.common import compute_init, print0, compute_cleanup, autodetect_device_type
|
||||
from nanochat.dataloader import tokenizing_distributed_data_loader
|
||||
|
|
@ -36,6 +39,19 @@ tokens_per_step = device_batch_size * sequence_len * ddp_world_size
|
|||
assert split_tokens % tokens_per_step == 0, "split_tokens must be divisible by tokens_per_step"
|
||||
steps = split_tokens // tokens_per_step
|
||||
token_bytes = get_token_bytes(device=device)
|
||||
use_wandb = bool(os.environ.get("WANDB_RUN_ID"))
|
||||
wandb_run = DummyWandb()
|
||||
if use_wandb:
|
||||
wandb_kwargs = {
|
||||
"project": os.environ.get("WANDB_PROJECT", "nanochat"),
|
||||
"name": os.environ.get("WANDB_EVAL_RUN", "base-eval"),
|
||||
"id": os.environ.get("WANDB_RUN_ID"),
|
||||
"resume": "allow",
|
||||
"reinit": True,
|
||||
}
|
||||
wandb_kwargs = {k: v for k, v in wandb_kwargs.items() if v is not None}
|
||||
wandb_run = wandb.init(**wandb_kwargs)
|
||||
|
||||
bpb_results = {}
|
||||
for split_name in ["train", "val"]:
|
||||
loader = tokenizing_distributed_data_loader(device_batch_size, sequence_len, split_name, device=device)
|
||||
|
|
@ -63,7 +79,7 @@ if ddp_rank == 0:
|
|||
sample, _ = engine.generate_batch(tokens, num_samples=1, max_tokens=16, temperature=0)
|
||||
sample_str = tokenizer.decode(sample[0])
|
||||
print0(sample_str)
|
||||
samples.append(sample_str)
|
||||
samples.append(sample_str)
|
||||
|
||||
# Log to report
|
||||
from nanochat.report import get_report
|
||||
|
|
@ -75,5 +91,12 @@ get_report().log(section="Base model loss", data=[
|
|||
{f"sample {i}": sample for i, sample in enumerate(samples)},
|
||||
])
|
||||
|
||||
if use_wandb:
|
||||
wandb_run.log({
|
||||
"base_loss/train_bpb": bpb_results["train"],
|
||||
"base_loss/val_bpb": bpb_results["val"],
|
||||
}, step=meta.get("step"))
|
||||
wandb_run.finish()
|
||||
|
||||
# Cleanup
|
||||
compute_cleanup()
|
||||
|
|
|
|||
|
|
@ -60,6 +60,7 @@ core_metric_max_per_task = 500 # examples per task in estimating the core metric
|
|||
sample_every = 2000 # every how many steps to sample from the model
|
||||
# Output
|
||||
model_tag = "" # optionally override the model tag for the output checkpoint directory name
|
||||
checkpoint_every_steps = 0 # save intermediate checkpoints every N optimization steps (0 = disable)
|
||||
# now allow CLI to override the settings via the configurator lol
|
||||
config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
|
||||
exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
|
||||
|
|
@ -76,7 +77,19 @@ get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else l
|
|||
|
||||
# wandb logging init
|
||||
use_dummy_wandb = run == "dummy" or not master_process
|
||||
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat", name=run, config=user_config)
|
||||
if use_dummy_wandb:
|
||||
wandb_run = DummyWandb()
|
||||
else:
|
||||
wandb_kwargs = {
|
||||
"project": os.environ.get("WANDB_PROJECT", "nanochat"),
|
||||
"name": run,
|
||||
"config": user_config,
|
||||
"reinit": True,
|
||||
}
|
||||
wandb_id = os.environ.get("WANDB_RUN_ID")
|
||||
if wandb_id:
|
||||
wandb_kwargs.update({"id": wandb_id, "resume": "allow"})
|
||||
wandb_run = wandb.init(**wandb_kwargs)
|
||||
|
||||
# Tokenizer will be useful for evaluation, also we need the vocab size
|
||||
tokenizer = get_tokenizer()
|
||||
|
|
@ -138,6 +151,10 @@ print0(f"Total number of training tokens: {total_tokens:,}")
|
|||
print0(f"Tokens : Params ratio: {total_batch_size * num_iterations / num_params:.2f}") # Chinchilla is ~20
|
||||
print0(f"Total training FLOPs estimate: {num_flops_per_token * total_tokens:e}")
|
||||
|
||||
sequences_per_step = max(1, total_batch_size // max_seq_len)
|
||||
checkpoint_every_steps = int(checkpoint_every_steps)
|
||||
checkpoint_enabled = checkpoint_every_steps > 0
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Initialize the Optimizer (Muon for Linear layers, AdamW for embedding and lm_head)
|
||||
optimizers = model.setup_optimizers(unembedding_lr=unembedding_lr, embedding_lr=embedding_lr, matrix_lr=matrix_lr, weight_decay=weight_decay)
|
||||
|
|
@ -150,6 +167,10 @@ train_loader = tokenizing_distributed_data_loader(device_batch_size, max_seq_len
|
|||
build_val_loader = lambda: tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="val", device=device)
|
||||
x, y = next(train_loader) # kick off load of the very first batch of data
|
||||
|
||||
# Checkpoint output location
|
||||
checkpoint_dirname = model_tag if model_tag else f"d{depth}"
|
||||
checkpoint_dir = os.path.join(base_dir, "base_checkpoints", checkpoint_dirname)
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Set up hyperparameter schedulers
|
||||
|
||||
|
|
@ -177,6 +198,33 @@ min_val_bpb = float("inf")
|
|||
smooth_train_loss = 0 # EMA of training loss
|
||||
ema_beta = 0.9 # EMA decay factor
|
||||
total_training_time = 0 # total wall-clock time of training
|
||||
# Keep track of total tokens/sequences processed for logging
|
||||
tokens_per_step = total_batch_size
|
||||
total_tokens_seen = 0
|
||||
total_sequences_seen = 0
|
||||
last_val_bpb = None
|
||||
|
||||
def save_base_checkpoint(step_idx):
|
||||
output_dirname = model_tag if model_tag else f"d{depth}"
|
||||
checkpoint_dir = os.path.join(base_dir, "base_checkpoints", output_dirname)
|
||||
optimizer_state = [opt.state_dict() for opt in optimizers]
|
||||
meta = {
|
||||
"step": step_idx,
|
||||
"val_bpb": last_val_bpb,
|
||||
"model_config": model_config_kwargs,
|
||||
"user_config": user_config,
|
||||
"device_batch_size": device_batch_size,
|
||||
"max_seq_len": max_seq_len,
|
||||
"total_tokens_seen": total_tokens_seen,
|
||||
"total_sequences_seen": total_sequences_seen,
|
||||
}
|
||||
save_checkpoint(
|
||||
checkpoint_dir,
|
||||
step_idx,
|
||||
orig_model.state_dict(),
|
||||
optimizer_state,
|
||||
meta,
|
||||
)
|
||||
# note that we run +1 steps only so that we can eval and save at the end
|
||||
for step in range(num_iterations + 1):
|
||||
last_step = step == num_iterations
|
||||
|
|
@ -192,11 +240,14 @@ for step in range(num_iterations + 1):
|
|||
print0(f"Step {step:05d} | Validation bpb: {val_bpb:.4f}")
|
||||
if val_bpb < min_val_bpb:
|
||||
min_val_bpb = val_bpb
|
||||
last_val_bpb = val_bpb
|
||||
wandb_run.log({
|
||||
"step": step,
|
||||
"total_training_flops": flops_so_far,
|
||||
"total_training_time": total_training_time,
|
||||
"val/bpb": val_bpb,
|
||||
"train/total_tokens": total_tokens_seen,
|
||||
"train/total_sequences": total_sequences_seen,
|
||||
})
|
||||
model.train()
|
||||
|
||||
|
|
@ -213,12 +264,14 @@ for step in range(num_iterations + 1):
|
|||
"total_training_flops": flops_so_far,
|
||||
"core_metric": results["core_metric"],
|
||||
"centered_results": results["centered_results"],
|
||||
"train/total_tokens": total_tokens_seen,
|
||||
"train/total_sequences": total_sequences_seen,
|
||||
})
|
||||
model.train()
|
||||
|
||||
# once in a while: sample from the model (only on master process)
|
||||
# use the original uncompiled model because the inputs keep changing shape
|
||||
if master_process and (last_step or (step > 0 and step % sample_every == 0)):
|
||||
if master_process and sample_every > 0 and (last_step or (step > 0 and step % sample_every == 0)):
|
||||
model.eval()
|
||||
prompts = [
|
||||
"The capital of France is",
|
||||
|
|
@ -239,22 +292,7 @@ for step in range(num_iterations + 1):
|
|||
|
||||
# save checkpoint at the end of the run (only on master process)
|
||||
if master_process and last_step:
|
||||
output_dirname = model_tag if model_tag else f"d{depth}" # e.g. d12
|
||||
checkpoint_dir = os.path.join(base_dir, "base_checkpoints", output_dirname)
|
||||
save_checkpoint(
|
||||
checkpoint_dir,
|
||||
step,
|
||||
orig_model.state_dict(),
|
||||
[opt.state_dict() for opt in optimizers], # TODO: make sure saving across ranks is done correctly
|
||||
{
|
||||
"step": step,
|
||||
"val_bpb": val_bpb, # loss at last step
|
||||
"model_config": model_config_kwargs,
|
||||
"user_config": user_config, # inputs to the training script
|
||||
"device_batch_size": device_batch_size,
|
||||
"max_seq_len": max_seq_len,
|
||||
}
|
||||
)
|
||||
save_base_checkpoint(step)
|
||||
|
||||
if last_step:
|
||||
break
|
||||
|
|
@ -287,6 +325,11 @@ for step in range(num_iterations + 1):
|
|||
model.zero_grad(set_to_none=True)
|
||||
synchronize()
|
||||
t1 = time.time()
|
||||
total_tokens_seen += tokens_per_step
|
||||
total_sequences_seen += sequences_per_step
|
||||
current_step = step + 1
|
||||
if master_process and checkpoint_enabled and not last_step and checkpoint_every_steps > 0 and current_step % checkpoint_every_steps == 0:
|
||||
save_base_checkpoint(current_step)
|
||||
dt = t1 - t0
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
|
|
@ -311,6 +354,8 @@ for step in range(num_iterations + 1):
|
|||
"train/dt": dt,
|
||||
"train/tok_per_sec": tok_per_sec,
|
||||
"train/mfu": mfu,
|
||||
"train/total_tokens": total_tokens_seen,
|
||||
"train/total_sequences": total_sequences_seen,
|
||||
})
|
||||
|
||||
# print a few more stats
|
||||
|
|
@ -335,12 +380,14 @@ get_report().log(section="Base model training", data=[
|
|||
},
|
||||
{ # stats about training outcomes
|
||||
"Minimum validation bpb": min_val_bpb,
|
||||
"Final validation bpb": val_bpb,
|
||||
"Final validation bpb": last_val_bpb,
|
||||
"CORE metric estimate": results.get("core_metric", None),
|
||||
"MFU %": f"{mfu:.2f}%",
|
||||
"Total training flops": f"{flops_so_far:e}",
|
||||
"Total training time": f"{total_training_time/60:.2f}m",
|
||||
"Peak memory usage": f"{get_max_memory() / 1024 / 1024:.2f}MiB",
|
||||
"Total tokens processed": total_tokens_seen,
|
||||
"Total sequences processed": total_sequences_seen,
|
||||
}
|
||||
])
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
"""
|
||||
Evaluate the Chat model.
|
||||
All the generic code lives here, and all the evlauation-specific
|
||||
All the generic code lives here, and all the evaluation-specific
|
||||
code lives in nanochat directory and is imported from here.
|
||||
|
||||
Example runs:
|
||||
|
|
@ -9,13 +9,15 @@ torchrun --nproc_per_node=8 -m scripts.chat_eval -- -a ARC-Easy
|
|||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
from functools import partial
|
||||
from contextlib import nullcontext
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import wandb
|
||||
|
||||
from nanochat.common import compute_init, compute_cleanup, get_dist_info, print0, autodetect_device_type
|
||||
from nanochat.common import compute_init, compute_cleanup, get_dist_info, print0, autodetect_device_type, DummyWandb
|
||||
from nanochat.checkpoint_manager import load_model
|
||||
from nanochat.engine import Engine
|
||||
|
||||
|
|
@ -201,9 +203,21 @@ if __name__ == "__main__":
|
|||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
|
||||
ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16
|
||||
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
|
||||
wandb_run = DummyWandb()
|
||||
use_wandb = bool(os.environ.get("WANDB_RUN_ID")) and ddp_rank == 0
|
||||
|
||||
model, tokenizer, meta = load_model(args.source, device, phase="eval", model_tag=args.model_tag, step=args.step)
|
||||
engine = Engine(model, tokenizer)
|
||||
if use_wandb:
|
||||
wandb_kwargs = {
|
||||
"project": os.environ.get("WANDB_PROJECT", "nanochat"),
|
||||
"name": os.environ.get("WANDB_EVAL_RUN", f"{args.source}-eval"),
|
||||
"id": os.environ.get("WANDB_RUN_ID"),
|
||||
"resume": "allow",
|
||||
"reinit": True,
|
||||
}
|
||||
wandb_kwargs = {k: v for k, v in wandb_kwargs.items() if v is not None}
|
||||
wandb_run = wandb.init(**wandb_kwargs)
|
||||
|
||||
# Get the tasks to evaluate on
|
||||
all_tasks = ['ARC-Easy', 'ARC-Challenge', 'MMLU', 'GSM8K', 'HumanEval', 'SpellingBee']
|
||||
|
|
@ -254,4 +268,11 @@ if __name__ == "__main__":
|
|||
chatcore_metric_dict,
|
||||
])
|
||||
|
||||
if use_wandb:
|
||||
wandb_payload = {f"chat_eval/{task}": acc for task, acc in results.items()}
|
||||
if chatcore_metric_dict:
|
||||
wandb_payload.update({"chat_eval/chatcore": chatcore_metric_dict["ChatCORE metric"]})
|
||||
wandb_run.log(wandb_payload, step=meta.get("step"))
|
||||
wandb_run.finish()
|
||||
|
||||
compute_cleanup()
|
||||
|
|
|
|||
|
|
@ -55,10 +55,21 @@ eval_every = 100
|
|||
eval_steps = 100
|
||||
eval_metrics_every = 200
|
||||
eval_metrics_max_problems = 1024
|
||||
checkpoint_every_steps = 0 # save intermediate checkpoints every N optimization steps (0 = disable)
|
||||
# now allow CLI to override the settings via the configurator lol
|
||||
config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
|
||||
exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
|
||||
user_config = {k: globals()[k] for k in config_keys} # possibly useful for logging
|
||||
|
||||
# Normalize evaluation knobs: <=0 disables inline evaluation
|
||||
if eval_every <= 0:
|
||||
eval_every = None
|
||||
if eval_steps <= 0:
|
||||
eval_steps = 0
|
||||
if eval_metrics_every <= 0:
|
||||
eval_metrics_every = None
|
||||
if eval_metrics_max_problems <= 0:
|
||||
eval_metrics_max_problems = 0
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
# Compute init
|
||||
|
|
@ -70,13 +81,27 @@ autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if dev
|
|||
|
||||
# wandb logging init
|
||||
use_dummy_wandb = run == "dummy" or not master_process
|
||||
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-sft", name=run, config=user_config, save_code=True)
|
||||
if use_dummy_wandb:
|
||||
wandb_run = DummyWandb()
|
||||
else:
|
||||
wandb_kwargs = {
|
||||
"project": os.environ.get("WANDB_PROJECT", "nanochat-sft"),
|
||||
"name": run,
|
||||
"config": user_config,
|
||||
"save_code": True,
|
||||
"reinit": True,
|
||||
}
|
||||
wandb_id = os.environ.get("WANDB_RUN_ID")
|
||||
if wandb_id:
|
||||
wandb_kwargs.update({"id": wandb_id, "resume": "allow"})
|
||||
wandb_run = wandb.init(**wandb_kwargs)
|
||||
|
||||
# Load the model and tokenizer
|
||||
model, tokenizer, meta = load_model(source, device, phase="train", model_tag=model_tag, step=step)
|
||||
orig_model = model # original, uncompiled model
|
||||
# model = torch.compile(model, dynamic=True) # doesn't work super well because of variable lengths of inputs
|
||||
engine = Engine(model, tokenizer) # will be used for inline model evaluation only
|
||||
model_config_kwargs = model.config.__dict__ # slightly naughty, abusing the simplicity of GPTConfig, TODO nicer
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Task data mixture we'll train on
|
||||
|
|
@ -142,6 +167,14 @@ if num_iterations == -1:
|
|||
train_loader = sft_data_generator(train_ds, batch_size=device_batch_size)
|
||||
build_val_loader = lambda: sft_data_generator(val_ds, batch_size=device_batch_size)
|
||||
|
||||
sequences_per_step = target_examples_per_step
|
||||
checkpoint_every_steps = int(checkpoint_every_steps)
|
||||
checkpoint_enabled = checkpoint_every_steps > 0
|
||||
|
||||
base_dir = get_base_dir()
|
||||
checkpoint_dirname = model_tag if model_tag else f"d{model.config.n_layer}"
|
||||
checkpoint_dir = os.path.join(base_dir, "chatsft_checkpoints", checkpoint_dirname)
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Initialize the Optimizer
|
||||
|
||||
|
|
@ -168,11 +201,34 @@ def get_lr_multiplier(it):
|
|||
# Go!
|
||||
step = 0
|
||||
train_iter = iter(train_loader)
|
||||
total_tokens_seen = 0
|
||||
total_sequences_seen = 0
|
||||
last_val_loss = None
|
||||
last_eval_metrics = {}
|
||||
|
||||
def save_sft_checkpoint(step_idx):
|
||||
meta = {
|
||||
"step": step_idx,
|
||||
"val_loss": last_val_loss,
|
||||
**last_eval_metrics,
|
||||
"model_config": model_config_kwargs,
|
||||
"total_tokens_seen": total_tokens_seen,
|
||||
"total_sequences_seen": total_sequences_seen,
|
||||
}
|
||||
save_checkpoint(
|
||||
checkpoint_dir,
|
||||
step_idx,
|
||||
model.state_dict(),
|
||||
None,
|
||||
meta,
|
||||
)
|
||||
|
||||
for step in range(num_iterations):
|
||||
last_step = step == num_iterations - 1
|
||||
|
||||
# evaluate the validation loss
|
||||
if last_step or step % eval_every == 0:
|
||||
# evaluate the validation loss (if enabled)
|
||||
run_val = eval_every is not None and eval_steps > 0 and (last_step or step % eval_every == 0)
|
||||
if run_val:
|
||||
model.eval()
|
||||
val_iter = iter(build_val_loader())
|
||||
losses = []
|
||||
|
|
@ -186,14 +242,20 @@ for step in range(num_iterations):
|
|||
dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) # average over ranks
|
||||
val_loss = val_loss.item()
|
||||
print0(f"Step {step:05d} | Validation loss: {val_loss:.6f}")
|
||||
last_val_loss = val_loss
|
||||
wandb_run.log({
|
||||
"step": step,
|
||||
"val_loss": val_loss,
|
||||
"train/total_tokens": total_tokens_seen,
|
||||
"train/total_sequences": total_sequences_seen,
|
||||
})
|
||||
model.train()
|
||||
|
||||
# evlauate accuracy of the multiple choice tasks (which are quick to run)
|
||||
if last_step or (step > 0 and step % eval_metrics_every == 0):
|
||||
# evaluate accuracy of the multiple choice tasks (if enabled)
|
||||
run_metrics = eval_metrics_every is not None and eval_metrics_max_problems > 0 and (
|
||||
last_step or (step > 0 and step % eval_metrics_every == 0)
|
||||
)
|
||||
if run_metrics:
|
||||
model.eval()
|
||||
metrics = {}
|
||||
with torch.no_grad(), autocast_ctx:
|
||||
|
|
@ -202,9 +264,12 @@ for step in range(num_iterations):
|
|||
metrics["arc_easy_acc"] = run_chat_eval("ARC-Easy", model, tokenizer, engine, batch_size=device_batch_size*2, max_problems=eval_metrics_max_problems)
|
||||
metrics_str = ', '.join(f'{k}: {v:.6f}' for k, v in metrics.items())
|
||||
print0(f"Step {step:05d} | {metrics_str}")
|
||||
last_eval_metrics = metrics.copy()
|
||||
wandb_run.log({
|
||||
"step": step,
|
||||
**metrics,
|
||||
"train/total_tokens": total_tokens_seen,
|
||||
"train/total_sequences": total_sequences_seen,
|
||||
})
|
||||
model.train()
|
||||
|
||||
|
|
@ -238,34 +303,25 @@ for step in range(num_iterations):
|
|||
# logging
|
||||
train_loss_item = train_loss.item()
|
||||
num_tokens_item = num_tokens.item()
|
||||
total_tokens_seen += num_tokens_item
|
||||
total_sequences_seen += sequences_per_step
|
||||
current_step = step + 1
|
||||
if master_process and checkpoint_enabled and not last_step and current_step % checkpoint_every_steps == 0:
|
||||
save_sft_checkpoint(current_step)
|
||||
print0(f"Step {step:05d}/{num_iterations:05d} | Training loss: {train_loss_item:.6f}| lrm: {lrm:.6f}| num_tokens: {num_tokens_item:,}")
|
||||
wandb_run.log({
|
||||
"step": step,
|
||||
"lrm": lrm,
|
||||
"train_loss": train_loss_item,
|
||||
"num_tokens": num_tokens_item,
|
||||
"train/total_tokens": total_tokens_seen,
|
||||
"train/total_sequences": total_sequences_seen,
|
||||
})
|
||||
step += 1
|
||||
|
||||
# Save the model at the end of the run
|
||||
if master_process:
|
||||
base_dir = get_base_dir()
|
||||
depth = model.config.n_layer
|
||||
model_tag = f"d{depth}" # base the model tag on the depth of the base model
|
||||
checkpoint_dir = os.path.join(base_dir, "chatsft_checkpoints", model_tag)
|
||||
model_config_kwargs = model.config.__dict__ # slightly naughty, abusing the simplicity of GPTConfig, TODO nicer
|
||||
save_checkpoint(
|
||||
checkpoint_dir,
|
||||
step,
|
||||
model.state_dict(),
|
||||
None, # note: we don't bother to save the optimizer state
|
||||
{
|
||||
"step": step,
|
||||
"val_loss": val_loss,
|
||||
**metrics,
|
||||
"model_config": model_config_kwargs,
|
||||
}
|
||||
)
|
||||
save_sft_checkpoint(step)
|
||||
print(f"✅ Saved model checkpoint to {checkpoint_dir}")
|
||||
|
||||
# Log to report
|
||||
|
|
@ -276,7 +332,9 @@ get_report().log(section="Chat SFT", data=[
|
|||
"Training rows": len(train_ds),
|
||||
"Number of iterations": num_iterations,
|
||||
"Training loss": train_loss_item,
|
||||
"Validation loss": val_loss,
|
||||
"Validation loss": last_val_loss,
|
||||
"Total tokens processed": total_tokens_seen,
|
||||
"Total sequences processed": total_sequences_seen,
|
||||
},
|
||||
])
|
||||
|
||||
|
|
|
|||
|
|
@ -48,6 +48,7 @@ eval_every = 150 # -1 = disable
|
|||
eval_tokens = 20*524288
|
||||
total_batch_size = 524288
|
||||
dry_run = 0 # dry_run=1 is for experiments: we will log to wandb but we won't write checkpoints or report
|
||||
checkpoint_every_steps = 0 # save intermediate checkpoints every N optimization steps (0 = disable)
|
||||
config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
|
||||
exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
|
||||
user_config = {k: globals()[k] for k in config_keys} # possibly useful for logging
|
||||
|
|
@ -63,7 +64,19 @@ get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else l
|
|||
|
||||
# wandb logging init
|
||||
use_dummy_wandb = run == "dummy" or not master_process
|
||||
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-mid", name=run, config=user_config)
|
||||
if use_dummy_wandb:
|
||||
wandb_run = DummyWandb()
|
||||
else:
|
||||
wandb_kwargs = {
|
||||
"project": os.environ.get("WANDB_PROJECT", "nanochat-mid"),
|
||||
"name": run,
|
||||
"config": user_config,
|
||||
"reinit": True,
|
||||
}
|
||||
wandb_id = os.environ.get("WANDB_RUN_ID")
|
||||
if wandb_id:
|
||||
wandb_kwargs.update({"id": wandb_id, "resume": "allow"})
|
||||
wandb_run = wandb.init(**wandb_kwargs)
|
||||
|
||||
# Load the model and tokenizer
|
||||
model, tokenizer, meta = load_model("base", device, phase="train", model_tag=model_tag, step=step)
|
||||
|
|
@ -83,6 +96,10 @@ print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}")
|
|||
print0(f"Total batch size {total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}")
|
||||
token_bytes = get_token_bytes(device=device)
|
||||
|
||||
sequences_per_step = max(1, total_batch_size // max_seq_len)
|
||||
checkpoint_every_steps = int(checkpoint_every_steps)
|
||||
checkpoint_enabled = checkpoint_every_steps > 0
|
||||
|
||||
# Initialize the Optimizer (Muon for Linear layers, AdamW for embedding and lm_head)
|
||||
optimizers = model.setup_optimizers(unembedding_lr=unembedding_lr, embedding_lr=embedding_lr, matrix_lr=matrix_lr, weight_decay=weight_decay)
|
||||
adamw_optimizer, muon_optimizer = optimizers
|
||||
|
|
@ -159,6 +176,9 @@ train_loader = mid_data_generator("train")
|
|||
build_val_loader = lambda: mid_data_generator("val")
|
||||
progress = 0 # will go from 0 to 1 over the course of the epoch
|
||||
|
||||
checkpoint_dirname = f"d{depth}"
|
||||
checkpoint_dir = os.path.join(base_dir, "mid_checkpoints", checkpoint_dirname)
|
||||
|
||||
# Learning rate scheduler
|
||||
def get_lr_multiplier(progress):
|
||||
# first 80% of training: no decay, then linearly ramp down to 0.
|
||||
|
|
@ -177,6 +197,37 @@ min_val_bpb = float("inf")
|
|||
smooth_train_loss = 0 # EMA of training loss
|
||||
ema_beta = 0.9 # EMA decay factor
|
||||
total_training_time = 0 # total wall-clock time of training
|
||||
tokens_per_step = total_batch_size
|
||||
total_tokens_seen = 0
|
||||
total_sequences_seen = 0
|
||||
last_val_bpb = None
|
||||
|
||||
def save_mid_checkpoint(step_idx):
|
||||
if dry_run:
|
||||
return
|
||||
meta = {
|
||||
"step": step_idx,
|
||||
"val_bpb": last_val_bpb,
|
||||
"model_config": {
|
||||
"sequence_len": max_seq_len,
|
||||
"vocab_size": tokenizer.get_vocab_size(),
|
||||
"n_layer": depth,
|
||||
"n_head": model.config.n_head,
|
||||
"n_kv_head": model.config.n_kv_head,
|
||||
"n_embd": model.config.n_embd,
|
||||
},
|
||||
"user_config": user_config,
|
||||
"total_tokens_seen": total_tokens_seen,
|
||||
"total_sequences_seen": total_sequences_seen,
|
||||
}
|
||||
optimizer_state = [opt.state_dict() for opt in optimizers]
|
||||
save_checkpoint(
|
||||
checkpoint_dir,
|
||||
step_idx,
|
||||
orig_model.state_dict(),
|
||||
optimizer_state,
|
||||
meta,
|
||||
)
|
||||
step = 0
|
||||
while True:
|
||||
flops_so_far = num_flops_per_token * total_batch_size * step
|
||||
|
|
@ -197,37 +248,20 @@ while True:
|
|||
print0(f"Step {step:05d} | Validation bpb: {val_bpb:.4f}")
|
||||
if val_bpb < min_val_bpb:
|
||||
min_val_bpb = val_bpb
|
||||
last_val_bpb = val_bpb
|
||||
wandb_run.log({
|
||||
"step": step,
|
||||
"total_training_flops": flops_so_far,
|
||||
"total_training_time": total_training_time,
|
||||
"val/bpb": val_bpb,
|
||||
"train/total_tokens": total_tokens_seen,
|
||||
"train/total_sequences": total_sequences_seen,
|
||||
})
|
||||
model.train()
|
||||
|
||||
# save checkpoint at the end of the run (only on master process)
|
||||
if master_process and last_step and not dry_run:
|
||||
output_dirname = f"d{depth}" # e.g. d12
|
||||
checkpoint_dir = os.path.join(base_dir, "mid_checkpoints", output_dirname)
|
||||
save_checkpoint(
|
||||
checkpoint_dir,
|
||||
step,
|
||||
orig_model.state_dict(),
|
||||
[opt.state_dict() for opt in optimizers], # TODO: make sure saving across ranks is done correctly
|
||||
{
|
||||
"step": step,
|
||||
"val_bpb": val_bpb, # loss at last step
|
||||
"model_config": {
|
||||
"sequence_len": max_seq_len,
|
||||
"vocab_size": tokenizer.get_vocab_size(),
|
||||
"n_layer": depth,
|
||||
"n_head": model.config.n_head,
|
||||
"n_kv_head": model.config.n_kv_head,
|
||||
"n_embd": model.config.n_embd,
|
||||
},
|
||||
"user_config": user_config, # inputs to the training script
|
||||
}
|
||||
)
|
||||
save_mid_checkpoint(step)
|
||||
|
||||
if last_step:
|
||||
break
|
||||
|
|
@ -258,12 +292,17 @@ while True:
|
|||
model.zero_grad(set_to_none=True)
|
||||
synchronize()
|
||||
t1 = time.time()
|
||||
total_tokens_seen += tokens_per_step
|
||||
total_sequences_seen += sequences_per_step
|
||||
dt = t1 - t0
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
# State
|
||||
step += 1
|
||||
|
||||
if master_process and checkpoint_enabled and not last_step and checkpoint_every_steps > 0 and step % checkpoint_every_steps == 0:
|
||||
save_mid_checkpoint(step)
|
||||
|
||||
# logging
|
||||
smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss.item() # EMA the training loss
|
||||
debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**(step + 1)) # debias the EMA
|
||||
|
|
@ -285,6 +324,8 @@ while True:
|
|||
"train/dt": dt,
|
||||
"train/tok_per_sec": tok_per_sec,
|
||||
"train/mfu": mfu,
|
||||
"train/total_tokens": total_tokens_seen,
|
||||
"train/total_sequences": total_sequences_seen,
|
||||
})
|
||||
|
||||
# print a few more stats
|
||||
|
|
@ -303,6 +344,9 @@ if not dry_run:
|
|||
},
|
||||
{ # stats about training outcomes
|
||||
"Minimum validation bpb": min_val_bpb,
|
||||
"Final validation bpb": last_val_bpb,
|
||||
"Total tokens processed": total_tokens_seen,
|
||||
"Total sequences processed": total_sequences_seen,
|
||||
}
|
||||
])
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user