Add scripts for running evaluations and training with W&B integration

- Added `dev/runmps_evals.sh` for evaluating checkpoints and logging results to W&B.
- Introduced `dev/runmps.sh` for orchestrating training stages with W&B support.
- Updated `.gitignore` to include `wandb/` and `.runmps_wandb_ids`.
- Changed permissions for `dev/runcpu.sh` and added executable flag.
- Enhanced existing scripts to log metrics to W&B during training and evaluation processes.
This commit is contained in:
William Thurston 2025-11-05 11:49:50 -08:00
parent c75fe54aa7
commit b1d49aade5
10 changed files with 850 additions and 75 deletions

4
.gitignore vendored
View File

@ -4,4 +4,6 @@ __pycache__/
rustbpe/target/
dev-ignore/
report.md
eval_bundle/
eval_bundle/
wandb/
.runmps_wandb_ids

0
dev/runcpu.sh Normal file → Executable file
View File

336
dev/runmps.sh Executable file
View File

@ -0,0 +1,336 @@
#!/bin/bash
# Showing an example run for exercising some of the code paths on the CPU (or MPS on Macbooks)
# Run as:
# bash dev/runmps.sh
# NOTE: Training LLMs requires GPU compute and $$$. You will not get far on your Macbook.
# Think of this run as educational/fun demo, not something you should expect to work well.
# This is also why I hide this script away in dev/

# Stage selection (allow running only a subset, e.g. --stage=sft or --from=mid)
# Each RUN_* flag gates one pipeline stage below; all default to on.
RUN_BASE=1
RUN_MID=1
RUN_SFT=1
RUN_REPORT=1
STAGE_ONLY=""   # set via --stage=X (or shorthand --X): run exactly that one stage
FROM_STAGE=""   # set via --from=X (or shorthand --from-X): run from that stage onward
while [[ $# -gt 0 ]]; do
  case "$1" in
    --stage=*)
      STAGE_ONLY="${1#*=}"
      ;;
    --base|--mid|--sft|--report)
      # shorthand: --sft is equivalent to --stage=sft
      STAGE_ONLY="${1#--}"
      ;;
    --from=*)
      FROM_STAGE="${1#*=}"
      ;;
    --from-base|--from-mid|--from-sft)
      # shorthand: --from-mid is equivalent to --from=mid
      FROM_STAGE="${1#--from-}"
      ;;
    --help|-h)
      cat <<'EOF'
Usage: bash dev/runmps.sh [options]
Options:
--stage=<base|mid|sft|report> Run only the specified stage.
--from=<base|mid|sft> Run from the specified stage through the end.
--help Show this help message.
Environment variables (same as before) control batch sizes, WANDB run names, etc.
EOF
      exit 0
      ;;
    *)
      echo "Unknown argument: $1" >&2
      exit 1
      ;;
  esac
  shift
done
# Translate --from=<stage> into the RUN_* flags: enable the named stage and
# everything after it in the pipeline order base -> mid -> sft (-> report).
if [[ -n "$FROM_STAGE" ]]; then
  RUN_BASE=0
  RUN_MID=0
  RUN_SFT=0
  RUN_REPORT=0
  case "$FROM_STAGE" in
    base)
      RUN_BASE=1
      RUN_MID=1
      RUN_SFT=1
      RUN_REPORT=1
      ;;
    mid)
      RUN_MID=1
      RUN_SFT=1
      ;;
    sft)
      RUN_SFT=1
      ;;
    *)
      echo "Unknown --from stage: $FROM_STAGE" >&2
      exit 1
      ;;
  esac
fi
# Translate --stage=<stage> into exactly one enabled RUN_* flag.
# NOTE: --stage takes effect after --from, so if both are given, --stage wins.
if [[ -n "$STAGE_ONLY" ]]; then
  RUN_BASE=0
  RUN_MID=0
  RUN_SFT=0
  RUN_REPORT=0
  case "$STAGE_ONLY" in
    base)
      RUN_BASE=1
      ;;
    mid)
      RUN_MID=1
      ;;
    sft)
      RUN_SFT=1
      ;;
    report)
      RUN_REPORT=1
      ;;
    *)
      echo "Unknown --stage value: $STAGE_ONLY" >&2
      exit 1
      ;;
  esac
fi
if [[ -n "$STAGE_ONLY" || -n "$FROM_STAGE" ]]; then
  # avoid regenerating reports when running a subset unless specifically requested
  # (only --stage=report and --from=base keep the report stage enabled)
  if [[ "$STAGE_ONLY" != "report" && "$FROM_STAGE" != "base" ]]; then
    RUN_REPORT=0
  fi
fi
# ---------------------------------------------------------------------------
# Environment / dependency setup and tunable configuration
# ---------------------------------------------------------------------------
export OMP_NUM_THREADS=1
NANOCHAT_BASE_DIR="$HOME/.cache/nanochat"
mkdir -p "$NANOCHAT_BASE_DIR"   # quoted: $HOME may contain spaces
# Install uv (Python env manager) if missing, then sync the CPU-only venv.
command -v uv &> /dev/null || curl -LsSf https://astral.sh/uv/install.sh | sh
[ -d ".venv" ] || uv venv
uv sync --extra cpu
source .venv/bin/activate
# WANDB_RUN=dummy (the default) disables W&B logging entirely.
WANDB_RUN=${WANDB_RUN:-dummy}
if [ "$WANDB_RUN" != "dummy" ] && [ -z "$WANDB_MODE" ]; then
  export WANDB_MODE=online
fi
# Batch/sequence configuration (all env-overridable)
BASE_DEPTH=${BASE_DEPTH:-4}       # transformer depth for base pretraining
SEQ_LEN=${SEQ_LEN:-1024}
DEVICE_BATCH=${DEVICE_BATCH:-16}
TOTAL_BATCH=${TOTAL_BATCH:-$((DEVICE_BATCH * SEQ_LEN))} # tokens per optimizer step
EVAL_SEQUENCES=10000
EVAL_STEPS=$(((EVAL_SEQUENCES + DEVICE_BATCH - 1) / DEVICE_BATCH))  # ceil division
EVAL_BATCH_MULT=4 # evaluate on 4 full batches
EVAL_TOKENS=$((TOTAL_BATCH * EVAL_BATCH_MULT))
MID_NUM_STEPS=6144
SFT_NUM_STEPS=${SFT_NUM_STEPS:-3072}
CHECKPOINT_EVERY_SEQ=${CHECKPOINT_EVERY_SEQ:-10000}  # <=0 disables intermediate checkpoints
RUN_STAGE_EVALS=${RUN_STAGE_EVALS:-0}                # 1 = run evals inline after each stage
WANDB_PROJECT=${WANDB_PROJECT:-nanochat}
TARGET_PARAM_DATA_RATIO=${TARGET_PARAM_DATA_RATIO:-20}
SFT_DEVICE_BATCH=${SFT_DEVICE_BATCH:-$DEVICE_BATCH}
SFT_TARGET_EXAMPLES=${SFT_TARGET_EXAMPLES:-$DEVICE_BATCH}
SFT_EVAL_EVERY=${SFT_EVAL_EVERY:-0}
SFT_EVAL_STEPS=${SFT_EVAL_STEPS:-0}
SFT_EVAL_METRICS_EVERY=${SFT_EVAL_METRICS_EVERY:-0}
SFT_EVAL_METRICS_MAX=${SFT_EVAL_METRICS_MAX:-0}
# State file that records the W&B run ids of each stage so that
# dev/runmps_evals.sh can later attach eval metrics to the same runs.
STATE_FILE=".runmps_wandb_ids"
: > "$STATE_FILE"   # truncate any previous state
echo "WANDB_PROJECT=$WANDB_PROJECT" >> "$STATE_FILE"
generate_wandb_id() {
  # Emit a random 8-char lowercase hex id (same shape as python's
  # uuid4().hex[:8]) without spawning a python subprocess — avoids
  # depending on `python` being resolvable on PATH.
  od -An -N4 -tx1 /dev/urandom | tr -d ' \n'
}
# Convert "checkpoint every N sequences" into "checkpoint every K optimizer
# steps" for one stage. Replaces three copy-pasted ceil-division blocks.
#   $1 = CHECKPOINT_EVERY_SEQ (<=0 disables checkpointing -> prints 0)
#   $2 = sequences consumed per optimizer step (clamped to >=1)
# Prints the step interval (ceiling division) on stdout.
checkpoint_interval_steps() {
  local every=$1
  local per_step=$2
  if (( per_step <= 0 )); then per_step=1; fi
  if (( every <= 0 )); then
    echo 0
  else
    echo $(((every + per_step - 1) / per_step))
  fi
}
# Base/mid stages pack TOTAL_BATCH tokens per step as SEQ_LEN-token sequences;
# SFT steps consume one example per device-batch slot.
BASE_SEQS_PER_STEP=$((SEQ_LEN > 0 ? TOTAL_BATCH / SEQ_LEN : 0))
BASE_CHECKPOINT_STEPS=$(checkpoint_interval_steps "$CHECKPOINT_EVERY_SEQ" "$BASE_SEQS_PER_STEP")
MID_SEQS_PER_STEP=$((SEQ_LEN > 0 ? TOTAL_BATCH / SEQ_LEN : 0))
MID_CHECKPOINT_STEPS=$(checkpoint_interval_steps "$CHECKPOINT_EVERY_SEQ" "$MID_SEQS_PER_STEP")
SFT_SEQS_PER_STEP=$SFT_DEVICE_BATCH
SFT_CHECKPOINT_STEPS=$(checkpoint_interval_steps "$CHECKPOINT_EVERY_SEQ" "$SFT_SEQS_PER_STEP")
# Auto-populate WANDB_API_KEY from ~/.netrc when talking to a local W&B server.
# Mirrors the helper used in TinyRecursiveModels/pretrain_text.py so we can log
# to a self-hosted instance without manual export each time.
if [ -z "$WANDB_API_KEY" ] && [ -f "$HOME/.netrc" ]; then
  # Allow custom WANDB_BASE_URL; default to localhost if user sets WANDB
  WANDB_BASE_URL_DEFAULT=${WANDB_BASE_URL:-http://localhost:8080}
  if printf '%s' "$WANDB_BASE_URL_DEFAULT" | grep -q "localhost"; then
    # Strip scheme and path to get the bare host[:port] for the netrc lookup.
    HOST=$(printf '%s\n' "$WANDB_BASE_URL_DEFAULT" | sed -E 's#https?://##' | cut -d/ -f1)
    API_KEY=$(python - "$HOST" <<'PY'
import sys
from netrc import netrc
host = sys.argv[1]
auth = netrc().authenticators(host)
if auth and auth[2]:
    print(auth[2], end="")
PY
)
    if [ -n "$API_KEY" ]; then
      export WANDB_BASE_URL="$WANDB_BASE_URL_DEFAULT"
      export WANDB_API_KEY="$API_KEY"
      echo "[runmps] Loaded WANDB_API_KEY for $WANDB_BASE_URL from ~/.netrc"
    fi
  fi
fi
# Install the Rust toolchain only when cargo is not already available.
# (Previously the rustup installer was re-run on every invocation.)
if ! command -v cargo &> /dev/null && [ ! -x "$HOME/.cargo/bin/cargo" ]; then
  curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
fi
[ -f "$HOME/.cargo/env" ] && source "$HOME/.cargo/env"
# Build the rustbpe tokenizer extension into the venv.
uv run maturin develop --release --manifest-path rustbpe/Cargo.toml
# Download the eval bundle once and cache it under the nanochat base dir.
EVAL_BUNDLE_URL=https://karpathy-public.s3.us-west-2.amazonaws.com/eval_bundle.zip
if [ ! -d "$NANOCHAT_BASE_DIR/eval_bundle" ]; then
  curl -L -o eval_bundle.zip "$EVAL_BUNDLE_URL"
  unzip -q eval_bundle.zip
  rm eval_bundle.zip
  mv eval_bundle "$NANOCHAT_BASE_DIR"
fi
if (( RUN_BASE )); then
  # wipe the report
  python -m nanochat.report reset
  # train tokenizer on ~2B characters (download full shard set for extended training)
  python -m nanochat.dataset -n 240
  python -m scripts.tok_train --max_chars=2000000000
  python -m scripts.tok_eval
  # Base pretraining of a small model; geometry comes from the BASE_DEPTH /
  # SEQ_LEN / DEVICE_BATCH / TOTAL_BATCH knobs configured above, and the step
  # count is derived inside base_train from --target_param_data_ratio.
  if [ "$WANDB_RUN" != "dummy" ]; then
    # Pin a stable W&B run id and record it in STATE_FILE so that
    # dev/runmps_evals.sh can later attach eval metrics to this same run.
    BASE_WANDB_ID=${BASE_WANDB_ID:-$(generate_wandb_id)}
    echo "BASE_WANDB_ID=$BASE_WANDB_ID" >> "$STATE_FILE"
    echo "BASE_WANDB_NAME=$WANDB_RUN" >> "$STATE_FILE"
    export WANDB_RUN_ID=$BASE_WANDB_ID
    export WANDB_EVAL_RUN=$WANDB_RUN
    export WANDB_PROJECT
  fi
  python -m scripts.base_train \
    --depth=$BASE_DEPTH \
    --max_seq_len=$SEQ_LEN \
    --device_batch_size=$DEVICE_BATCH \
    --total_batch_size=$TOTAL_BATCH \
    --target_param_data_ratio=$TARGET_PARAM_DATA_RATIO \
    --run="$WANDB_RUN" \
    --eval_every=$EVAL_STEPS \
    --eval_tokens=$EVAL_TOKENS \
    --core_metric_every=-1 \
    --sample_every=-1 \
    --checkpoint_every_steps=$BASE_CHECKPOINT_STEPS
  if [ "$WANDB_RUN" != "dummy" ]; then
    # don't leak the base-stage run id/name into later stages
    unset WANDB_RUN_ID
    unset WANDB_EVAL_RUN
  fi
  if [ "$RUN_STAGE_EVALS" = "1" ]; then
    python -m scripts.base_loss --device_batch_size=$DEVICE_BATCH --split_tokens=$EVAL_TOKENS
    python -m scripts.base_eval --max-per-task=16
  fi
fi
if (( RUN_MID )); then
  # midtraining (same W&B bookkeeping as the base stage, separate run)
  if [ "$WANDB_RUN" != "dummy" ]; then
    MID_WANDB_ID=${MID_WANDB_ID:-$(generate_wandb_id)}
    echo "MID_WANDB_ID=$MID_WANDB_ID" >> "$STATE_FILE"
    echo "MID_WANDB_NAME=${WANDB_RUN}-mid" >> "$STATE_FILE"
    export WANDB_RUN_ID=$MID_WANDB_ID
    export WANDB_EVAL_RUN="${WANDB_RUN}-mid"
    export WANDB_PROJECT
  fi
  python -m scripts.mid_train \
    --max_seq_len=$SEQ_LEN \
    --device_batch_size=$DEVICE_BATCH \
    --total_batch_size=$TOTAL_BATCH \
    --run="${WANDB_RUN}-mid" \
    --eval_every=$EVAL_STEPS \
    --eval_tokens=$EVAL_TOKENS \
    --checkpoint_every_steps=$MID_CHECKPOINT_STEPS \
    --num_iterations=$MID_NUM_STEPS
  if [ "$WANDB_RUN" != "dummy" ]; then
    unset WANDB_RUN_ID
    unset WANDB_EVAL_RUN
  fi
  if [ "$RUN_STAGE_EVALS" = "1" ]; then
    # eval results will be terrible, this is just to execute the code paths.
    # NOTE(review): an earlier comment mentioned lowering an execution memory
    # limit to 1MB, but no such flag is passed here — confirm if still needed.
    python -m scripts.chat_eval --source=mid --max-new-tokens=128 --max-problems=20
  fi
fi
if (( RUN_SFT )); then
  # SFT (supervised finetuning stage, same W&B bookkeeping pattern)
  if [ "$WANDB_RUN" != "dummy" ]; then
    SFT_WANDB_ID=${SFT_WANDB_ID:-$(generate_wandb_id)}
    echo "SFT_WANDB_ID=$SFT_WANDB_ID" >> "$STATE_FILE"
    echo "SFT_WANDB_NAME=${WANDB_RUN}-sft" >> "$STATE_FILE"
    export WANDB_RUN_ID=$SFT_WANDB_ID
    export WANDB_EVAL_RUN="${WANDB_RUN}-sft"
    export WANDB_PROJECT
  fi
  python -m scripts.chat_sft \
    --device_batch_size=$SFT_DEVICE_BATCH \
    --target_examples_per_step=$SFT_TARGET_EXAMPLES \
    --run="${WANDB_RUN}-sft" \
    --num_iterations=$SFT_NUM_STEPS \
    --eval_every=$SFT_EVAL_EVERY \
    --eval_steps=$SFT_EVAL_STEPS \
    --eval_metrics_every=$SFT_EVAL_METRICS_EVERY \
    --eval_metrics_max_problems=$SFT_EVAL_METRICS_MAX \
    --checkpoint_every_steps=$SFT_CHECKPOINT_STEPS
  if [ "$WANDB_RUN" != "dummy" ]; then
    unset WANDB_RUN_ID
    unset WANDB_EVAL_RUN
  fi
fi
# Chat CLI
# python -m scripts.chat_cli -p "Why is the sky blue?"
# Chat Web
# python -m scripts.chat_web
if (( RUN_REPORT )); then
  python -m nanochat.report generate
fi
if [ "$RUN_STAGE_EVALS" != "1" ] && (( RUN_BASE || RUN_MID || RUN_SFT )); then
  echo "[runmps] Inline evals disabled. Run 'bash dev/runmps_evals.sh' to compute metrics from saved checkpoints."
fi

223
dev/runmps_evals.sh Executable file
View File

@ -0,0 +1,223 @@
#!/bin/bash
set -euo pipefail
shopt -s nullglob
# Evaluate every checkpoint generated by dev/runmps.sh and log back to the
# original W&B runs.
# Usage:
# bash dev/runmps_evals.sh
# Sourcing the state file restores the per-stage W&B run ids/names (and the
# project) that dev/runmps.sh recorded during training.
STATE_FILE=".runmps_wandb_ids"
if [ ! -d ".venv" ]; then
  echo "[runmps_evals] .venv not found. Run dev/runmps.sh first." >&2
  exit 1
fi
source .venv/bin/activate
if [ -f "$STATE_FILE" ]; then
  . "$STATE_FILE"
fi
export OMP_NUM_THREADS=${OMP_NUM_THREADS:-1}
export NANOCHAT_BASE_DIR=${NANOCHAT_BASE_DIR:-$HOME/.cache/nanochat}
export WANDB_PROJECT=${WANDB_PROJECT:-nanochat}
WANDB_RUN=${WANDB_RUN:-dummy}
if [ "$WANDB_RUN" != "dummy" ] && [ -z "${WANDB_MODE:-}" ]; then
  export WANDB_MODE=online
fi
# Auto-populate WANDB_API_KEY from ~/.netrc when talking to a local W&B server,
# mirroring the training script behaviour so evals can log without manual export.
if [ -z "${WANDB_API_KEY:-}" ] && [ -f "$HOME/.netrc" ]; then
  WANDB_BASE_URL_DEFAULT=${WANDB_BASE_URL:-http://localhost:8080}
  if printf '%s' "$WANDB_BASE_URL_DEFAULT" | grep -q "localhost"; then
    # Strip scheme and path to get the bare host[:port] for the netrc lookup.
    HOST=$(printf '%s\n' "$WANDB_BASE_URL_DEFAULT" | sed -E 's#https?://##' | cut -d/ -f1)
    API_KEY=$(python - "$HOST" <<'PY'
import sys
from netrc import netrc
host = sys.argv[1]
auth = netrc().authenticators(host)
if auth and auth[2]:
    print(auth[2], end="")
PY
)
    if [ -n "$API_KEY" ]; then
      export WANDB_BASE_URL="$WANDB_BASE_URL_DEFAULT"
      export WANDB_API_KEY="$API_KEY"
      echo "[runmps_evals] Loaded WANDB_API_KEY for $WANDB_BASE_URL from ~/.netrc"
    fi
  fi
fi
generate_wandb_id() {
  # Emit a random 8-char lowercase hex id (same shape as python's
  # uuid4().hex[:8]) without spawning a python subprocess.
  od -An -N4 -tx1 /dev/urandom | tr -d ' \n'
}
# Create a fresh W&B eval run identity for one stage, assigning two globals:
#   <PREFIX>_EVAL_WANDB_ID   - new random run id
#   <PREFIX>_EVAL_WANDB_NAME - the supplied default run name
# e.g. `set_eval_run BASE "foo-eval"` defines BASE_EVAL_WANDB_ID/NAME.
set_eval_run() {
  local prefix="$1"
  local default_name="$2"
  local new_id
  # declaration split from assignment so a generate_wandb_id failure
  # is not masked by `local` (and aborts under set -e)
  new_id=$(generate_wandb_id)
  # printf -v does an indirect assignment without the quoting pitfalls
  # of the previous eval-based approach
  printf -v "${prefix}_EVAL_WANDB_ID" '%s' "$new_id"
  printf -v "${prefix}_EVAL_WANDB_NAME" '%s' "$default_name"
}
# Mirror the batch geometry of dev/runmps.sh so eval token counts line up
# with the ones used during training.
SEQ_LEN=${SEQ_LEN:-1024}
DEVICE_BATCH=${DEVICE_BATCH:-16}
TOTAL_BATCH=$((DEVICE_BATCH * SEQ_LEN))   # tokens per optimizer step
EVAL_SEQUENCES=10000   # NOTE(review): not referenced below; kept for parity with runmps.sh
EVAL_BATCH_MULT=4
EVAL_TOKENS=$((TOTAL_BATCH * EVAL_BATCH_MULT))
# Print the numeric step of every model_<step>.pt checkpoint in directory $1,
# one per line, in ascending numeric order (leading zeros stripped, matching
# the previous python int() parsing). Prints nothing if the directory is
# missing or contains no matching files. Pure shell — no python subprocess.
checkpoint_steps() {
  local dir="$1"
  [ -d "$dir" ] || return 0
  local path name step
  for path in "$dir"/model_*.pt; do
    [ -e "$path" ] || continue          # guard against unmatched glob
    name=${path##*/}                    # e.g. model_000123.pt
    step=${name#model_}
    step=${step%%[._]*}                 # text up to first '.' or '_'
    case "$step" in
      ''|*[!0-9]*) continue ;;          # skip non-numeric tags
    esac
    printf '%s\n' "$((10#$step))"       # force base-10, strip leading zeros
  done | sort -n
}
# Print the model tag (checkpoint subdirectory name) of the largest model
# found under the checkpoints root $1, or nothing if the root is missing or
# holds no models. Delegates to nanochat's own checkpoint_manager so the
# "largest" selection logic stays in one place.
checkpoint_tag() {
  local dir="$1"
  python - "$dir" <<'PY'
import os, sys
from nanochat.checkpoint_manager import find_largest_model
base_dir = sys.argv[1]
if not os.path.isdir(base_dir):
    sys.exit(0)
try:
    tag = find_largest_model(base_dir)
except FileNotFoundError:
    sys.exit(0)
print(tag)
PY
}
# Evaluate every saved base-pretraining checkpoint: base_loss (console only)
# plus base_eval (logged to the base stage's W&B eval run). No-op when no
# base checkpoints exist on disk.
run_base_evals() {
  local tag=$(checkpoint_tag "$NANOCHAT_BASE_DIR/base_checkpoints")
  if [ -z "$tag" ]; then
    return
  fi
  local dir="$NANOCHAT_BASE_DIR/base_checkpoints/$tag"
  local steps="$(checkpoint_steps "$dir")"
  if [ -z "$steps" ]; then
    return
  fi
  echo "[runmps_evals] Running base eval on checkpoints in $dir..."
  set_eval_run BASE "${WANDB_RUN:-base}-eval"
  export WANDB_RUN_ID=$BASE_EVAL_WANDB_ID
  export WANDB_EVAL_RUN=$BASE_EVAL_WANDB_NAME
  export WANDB_PROJECT
  while read -r step; do
    [ -z "$step" ] && continue
    echo " • base checkpoint step $step"
    # run base_loss to print sample completions, but keep it off W&B:
    # stash the run id/name, unset them for the subprocess, restore after
    saved_wandb_run_id=${WANDB_RUN_ID:-}
    saved_wandb_eval_run=${WANDB_EVAL_RUN:-}
    unset WANDB_RUN_ID
    unset WANDB_EVAL_RUN
    python -m scripts.base_loss \
      --device_batch_size=$DEVICE_BATCH \
      --split_tokens=$EVAL_TOKENS \
      --model_step=$step
    if [ -n "$saved_wandb_run_id" ]; then
      export WANDB_RUN_ID=$saved_wandb_run_id
    fi
    if [ -n "$saved_wandb_eval_run" ]; then
      export WANDB_EVAL_RUN=$saved_wandb_eval_run
    fi
    # CORE eval suite, logged to the eval W&B run (per-checkpoint step)
    python -m scripts.base_eval \
      --max-per-task=16 \
      --model-step=$step
  done <<< "$steps"
  unset WANDB_RUN_ID
  unset WANDB_EVAL_RUN
}
# Run the chat eval suite against every saved midtraining checkpoint,
# logging results to a dedicated "<run>-mid-eval" W&B run. No-op when no
# mid checkpoints exist on disk.
run_mid_evals() {
  local tag=$(checkpoint_tag "$NANOCHAT_BASE_DIR/mid_checkpoints")
  if [ -z "$tag" ]; then
    return
  fi
  local dir="$NANOCHAT_BASE_DIR/mid_checkpoints/$tag"
  local steps="$(checkpoint_steps "$dir")"
  if [ -z "$steps" ]; then
    return
  fi
  echo "[runmps_evals] Running mid-stage chat eval on checkpoints in $dir..."
  set_eval_run MID "${WANDB_RUN:-mid}-mid-eval"
  export WANDB_RUN_ID=$MID_EVAL_WANDB_ID
  export WANDB_EVAL_RUN=$MID_EVAL_WANDB_NAME
  export WANDB_PROJECT
  while read -r step; do
    [ -z "$step" ] && continue
    echo " • mid checkpoint step $step"
    python -m scripts.chat_eval \
      --source=mid \
      --step=$step \
      --batch-size=$DEVICE_BATCH \
      --max-new-tokens=128 \
      --max-problems=20
  done <<< "$steps"
  unset WANDB_RUN_ID
  unset WANDB_EVAL_RUN
}
# Run the chat eval suite against every saved SFT checkpoint, logging results
# to a dedicated "<run>-sft-eval" W&B run. No-op when no SFT checkpoints
# exist on disk.
run_sft_evals() {
  local tag=$(checkpoint_tag "$NANOCHAT_BASE_DIR/chatsft_checkpoints")
  if [ -z "$tag" ]; then
    return
  fi
  local dir="$NANOCHAT_BASE_DIR/chatsft_checkpoints/$tag"
  local steps="$(checkpoint_steps "$dir")"
  if [ -z "$steps" ]; then
    return
  fi
  echo "[runmps_evals] Running SFT chat eval on checkpoints in $dir..."
  set_eval_run SFT "${WANDB_RUN:-sft}-sft-eval"
  export WANDB_RUN_ID=$SFT_EVAL_WANDB_ID
  export WANDB_EVAL_RUN=$SFT_EVAL_WANDB_NAME
  export WANDB_PROJECT
  while read -r step; do
    [ -z "$step" ] && continue
    echo " • SFT checkpoint step $step"
    python -m scripts.chat_eval \
      --source=sft \
      --step=$step \
      --batch-size=$DEVICE_BATCH \
      --max-new-tokens=128 \
      --max-problems=20
  done <<< "$steps"
  unset WANDB_RUN_ID
  unset WANDB_EVAL_RUN
}
# Evaluate all three stages (each helper no-ops if its checkpoints are
# absent), then rebuild the report so it picks up the fresh metrics.
run_base_evals
run_mid_evals
run_sft_evals
echo "[runmps_evals] Regenerating report with fresh metrics..."
python -m nanochat.report generate
echo "[runmps_evals] Done."

View File

@ -19,8 +19,9 @@ from contextlib import nullcontext
import pandas as pd
import torch
import wandb
from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, autodetect_device_type
from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, autodetect_device_type, DummyWandb
from nanochat.tokenizer import HuggingFaceTokenizer
from nanochat.checkpoint_manager import load_model
from nanochat.core_eval import evaluate_task
@ -123,12 +124,15 @@ def main():
parser = argparse.ArgumentParser()
parser.add_argument('--hf-path', type=str, default=None, help='HuggingFace model path to evaluate')
parser.add_argument('--max-per-task', type=int, default=-1, help='Max examples per task to evaluate (-1 = disable)')
parser.add_argument('--model-step', type=int, default=None, help='Checkpoint step to evaluate when using local models')
args = parser.parse_args()
# distributed / precision setup
device_type = autodetect_device_type()
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
wandb_run = DummyWandb()
use_wandb = bool(os.environ.get("WANDB_RUN_ID")) and args.hf_path is None and ddp_rank == 0
# Load model and tokenizer from command line or from file system
if args.hf_path is not None:
@ -140,24 +144,32 @@ def main():
model_slug = hf_path.replace("/", "-") # for the output csv file
else:
# load a local model from the file system
model, tokenizer, meta = load_model("base", device, phase="eval")
model, tokenizer, meta = load_model("base", device, phase="eval", step=args.model_step)
model_name = f"base_model (step {meta['step']})" # just for logging
model_slug = f"base_model_{meta['step']:06d}" # for the output csv file
if use_wandb:
wandb_kwargs = {
"project": os.environ.get("WANDB_PROJECT", "nanochat"),
"name": os.environ.get("WANDB_EVAL_RUN", model_name),
"id": os.environ.get("WANDB_RUN_ID"),
"resume": "allow",
"reinit": True,
}
wandb_kwargs = {k: v for k, v in wandb_kwargs.items() if v is not None}
wandb_run = wandb.init(**wandb_kwargs)
# Evaluate the model
with autocast_ctx:
out = evaluate_model(model, tokenizer, device, max_per_task=args.max_per_task)
# Write out the results to a csv file
core_metric = None
centered_results = {}
results = out["results"]
centered_results = out["centered_results"]
core_metric = out["core_metric"]
if ddp_rank == 0:
base_dir = get_base_dir()
output_csv_path = os.path.join(base_dir, "base_eval", f"{model_slug}.csv")
os.makedirs(os.path.dirname(output_csv_path), exist_ok=True)
results = out["results"]
centered_results = out["centered_results"]
core_metric = out["core_metric"]
with open(output_csv_path, 'w') as f:
f.write(f"{'Task':<35}, {'Accuracy':<10}, {'Centered':<10}\n")
for label in results:
@ -180,6 +192,15 @@ def main():
centered_results, # the full table
])
if use_wandb:
wandb_payload = {"core/metric": core_metric}
wandb_payload.update({f"core/{k}": v for k, v in centered_results.items()})
wandb_payload.update({f"core_raw/{k}": v for k, v in results.items()})
for key, value in wandb_payload.items():
wandb_run.summary[key] = value
wandb_run.log(wandb_payload, step=meta["step"])
wandb_run.finish()
compute_cleanup()
if __name__ == "__main__":

View File

@ -9,6 +9,9 @@ torchrun --standalone --nproc_per_node=8 -m scripts.base_loss
import os
from contextlib import nullcontext
import torch
import wandb
from nanochat.common import DummyWandb
from nanochat.checkpoint_manager import load_model
from nanochat.common import compute_init, print0, compute_cleanup, autodetect_device_type
from nanochat.dataloader import tokenizing_distributed_data_loader
@ -36,6 +39,19 @@ tokens_per_step = device_batch_size * sequence_len * ddp_world_size
assert split_tokens % tokens_per_step == 0, "split_tokens must be divisible by tokens_per_step"
steps = split_tokens // tokens_per_step
token_bytes = get_token_bytes(device=device)
use_wandb = bool(os.environ.get("WANDB_RUN_ID"))
wandb_run = DummyWandb()
if use_wandb:
wandb_kwargs = {
"project": os.environ.get("WANDB_PROJECT", "nanochat"),
"name": os.environ.get("WANDB_EVAL_RUN", "base-eval"),
"id": os.environ.get("WANDB_RUN_ID"),
"resume": "allow",
"reinit": True,
}
wandb_kwargs = {k: v for k, v in wandb_kwargs.items() if v is not None}
wandb_run = wandb.init(**wandb_kwargs)
bpb_results = {}
for split_name in ["train", "val"]:
loader = tokenizing_distributed_data_loader(device_batch_size, sequence_len, split_name, device=device)
@ -63,7 +79,7 @@ if ddp_rank == 0:
sample, _ = engine.generate_batch(tokens, num_samples=1, max_tokens=16, temperature=0)
sample_str = tokenizer.decode(sample[0])
print0(sample_str)
samples.append(sample_str)
samples.append(sample_str)
# Log to report
from nanochat.report import get_report
@ -75,5 +91,12 @@ get_report().log(section="Base model loss", data=[
{f"sample {i}": sample for i, sample in enumerate(samples)},
])
if use_wandb:
wandb_run.log({
"base_loss/train_bpb": bpb_results["train"],
"base_loss/val_bpb": bpb_results["val"],
}, step=meta.get("step"))
wandb_run.finish()
# Cleanup
compute_cleanup()

View File

@ -60,6 +60,7 @@ core_metric_max_per_task = 500 # examples per task in estimating the core metric
sample_every = 2000 # every how many steps to sample from the model
# Output
model_tag = "" # optionally override the model tag for the output checkpoint directory name
checkpoint_every_steps = 0 # save intermediate checkpoints every N optimization steps (0 = disable)
# now allow CLI to override the settings via the configurator lol
config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
@ -76,7 +77,19 @@ get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else l
# wandb logging init
use_dummy_wandb = run == "dummy" or not master_process
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat", name=run, config=user_config)
if use_dummy_wandb:
wandb_run = DummyWandb()
else:
wandb_kwargs = {
"project": os.environ.get("WANDB_PROJECT", "nanochat"),
"name": run,
"config": user_config,
"reinit": True,
}
wandb_id = os.environ.get("WANDB_RUN_ID")
if wandb_id:
wandb_kwargs.update({"id": wandb_id, "resume": "allow"})
wandb_run = wandb.init(**wandb_kwargs)
# Tokenizer will be useful for evaluation, also we need the vocab size
tokenizer = get_tokenizer()
@ -138,6 +151,10 @@ print0(f"Total number of training tokens: {total_tokens:,}")
print0(f"Tokens : Params ratio: {total_batch_size * num_iterations / num_params:.2f}") # Chinchilla is ~20
print0(f"Total training FLOPs estimate: {num_flops_per_token * total_tokens:e}")
sequences_per_step = max(1, total_batch_size // max_seq_len)
checkpoint_every_steps = int(checkpoint_every_steps)
checkpoint_enabled = checkpoint_every_steps > 0
# -----------------------------------------------------------------------------
# Initialize the Optimizer (Muon for Linear layers, AdamW for embedding and lm_head)
optimizers = model.setup_optimizers(unembedding_lr=unembedding_lr, embedding_lr=embedding_lr, matrix_lr=matrix_lr, weight_decay=weight_decay)
@ -150,6 +167,10 @@ train_loader = tokenizing_distributed_data_loader(device_batch_size, max_seq_len
build_val_loader = lambda: tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="val", device=device)
x, y = next(train_loader) # kick off load of the very first batch of data
# Checkpoint output location
checkpoint_dirname = model_tag if model_tag else f"d{depth}"
checkpoint_dir = os.path.join(base_dir, "base_checkpoints", checkpoint_dirname)
# -----------------------------------------------------------------------------
# Set up hyperparameter schedulers
@ -177,6 +198,33 @@ min_val_bpb = float("inf")
smooth_train_loss = 0 # EMA of training loss
ema_beta = 0.9 # EMA decay factor
total_training_time = 0 # total wall-clock time of training
# Keep track of total tokens/sequences processed for logging
tokens_per_step = total_batch_size
total_tokens_seen = 0
total_sequences_seen = 0
last_val_bpb = None
def save_base_checkpoint(step_idx):
output_dirname = model_tag if model_tag else f"d{depth}"
checkpoint_dir = os.path.join(base_dir, "base_checkpoints", output_dirname)
optimizer_state = [opt.state_dict() for opt in optimizers]
meta = {
"step": step_idx,
"val_bpb": last_val_bpb,
"model_config": model_config_kwargs,
"user_config": user_config,
"device_batch_size": device_batch_size,
"max_seq_len": max_seq_len,
"total_tokens_seen": total_tokens_seen,
"total_sequences_seen": total_sequences_seen,
}
save_checkpoint(
checkpoint_dir,
step_idx,
orig_model.state_dict(),
optimizer_state,
meta,
)
# note that we run +1 steps only so that we can eval and save at the end
for step in range(num_iterations + 1):
last_step = step == num_iterations
@ -192,11 +240,14 @@ for step in range(num_iterations + 1):
print0(f"Step {step:05d} | Validation bpb: {val_bpb:.4f}")
if val_bpb < min_val_bpb:
min_val_bpb = val_bpb
last_val_bpb = val_bpb
wandb_run.log({
"step": step,
"total_training_flops": flops_so_far,
"total_training_time": total_training_time,
"val/bpb": val_bpb,
"train/total_tokens": total_tokens_seen,
"train/total_sequences": total_sequences_seen,
})
model.train()
@ -213,12 +264,14 @@ for step in range(num_iterations + 1):
"total_training_flops": flops_so_far,
"core_metric": results["core_metric"],
"centered_results": results["centered_results"],
"train/total_tokens": total_tokens_seen,
"train/total_sequences": total_sequences_seen,
})
model.train()
# once in a while: sample from the model (only on master process)
# use the original uncompiled model because the inputs keep changing shape
if master_process and (last_step or (step > 0 and step % sample_every == 0)):
if master_process and sample_every > 0 and (last_step or (step > 0 and step % sample_every == 0)):
model.eval()
prompts = [
"The capital of France is",
@ -239,22 +292,7 @@ for step in range(num_iterations + 1):
# save checkpoint at the end of the run (only on master process)
if master_process and last_step:
output_dirname = model_tag if model_tag else f"d{depth}" # e.g. d12
checkpoint_dir = os.path.join(base_dir, "base_checkpoints", output_dirname)
save_checkpoint(
checkpoint_dir,
step,
orig_model.state_dict(),
[opt.state_dict() for opt in optimizers], # TODO: make sure saving across ranks is done correctly
{
"step": step,
"val_bpb": val_bpb, # loss at last step
"model_config": model_config_kwargs,
"user_config": user_config, # inputs to the training script
"device_batch_size": device_batch_size,
"max_seq_len": max_seq_len,
}
)
save_base_checkpoint(step)
if last_step:
break
@ -287,6 +325,11 @@ for step in range(num_iterations + 1):
model.zero_grad(set_to_none=True)
synchronize()
t1 = time.time()
total_tokens_seen += tokens_per_step
total_sequences_seen += sequences_per_step
current_step = step + 1
if master_process and checkpoint_enabled and not last_step and checkpoint_every_steps > 0 and current_step % checkpoint_every_steps == 0:
save_base_checkpoint(current_step)
dt = t1 - t0
# -------------------------------------------------------------------------
@ -311,6 +354,8 @@ for step in range(num_iterations + 1):
"train/dt": dt,
"train/tok_per_sec": tok_per_sec,
"train/mfu": mfu,
"train/total_tokens": total_tokens_seen,
"train/total_sequences": total_sequences_seen,
})
# print a few more stats
@ -335,12 +380,14 @@ get_report().log(section="Base model training", data=[
},
{ # stats about training outcomes
"Minimum validation bpb": min_val_bpb,
"Final validation bpb": val_bpb,
"Final validation bpb": last_val_bpb,
"CORE metric estimate": results.get("core_metric", None),
"MFU %": f"{mfu:.2f}%",
"Total training flops": f"{flops_so_far:e}",
"Total training time": f"{total_training_time/60:.2f}m",
"Peak memory usage": f"{get_max_memory() / 1024 / 1024:.2f}MiB",
"Total tokens processed": total_tokens_seen,
"Total sequences processed": total_sequences_seen,
}
])

View File

@ -1,6 +1,6 @@
"""
Evaluate the Chat model.
All the generic code lives here, and all the evlauation-specific
All the generic code lives here, and all the evaluation-specific
code lives in nanochat directory and is imported from here.
Example runs:
@ -9,13 +9,15 @@ torchrun --nproc_per_node=8 -m scripts.chat_eval -- -a ARC-Easy
"""
import argparse
import os
from functools import partial
from contextlib import nullcontext
import torch
import torch.distributed as dist
import wandb
from nanochat.common import compute_init, compute_cleanup, get_dist_info, print0, autodetect_device_type
from nanochat.common import compute_init, compute_cleanup, get_dist_info, print0, autodetect_device_type, DummyWandb
from nanochat.checkpoint_manager import load_model
from nanochat.engine import Engine
@ -201,9 +203,21 @@ if __name__ == "__main__":
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
wandb_run = DummyWandb()
use_wandb = bool(os.environ.get("WANDB_RUN_ID")) and ddp_rank == 0
model, tokenizer, meta = load_model(args.source, device, phase="eval", model_tag=args.model_tag, step=args.step)
engine = Engine(model, tokenizer)
if use_wandb:
wandb_kwargs = {
"project": os.environ.get("WANDB_PROJECT", "nanochat"),
"name": os.environ.get("WANDB_EVAL_RUN", f"{args.source}-eval"),
"id": os.environ.get("WANDB_RUN_ID"),
"resume": "allow",
"reinit": True,
}
wandb_kwargs = {k: v for k, v in wandb_kwargs.items() if v is not None}
wandb_run = wandb.init(**wandb_kwargs)
# Get the tasks to evaluate on
all_tasks = ['ARC-Easy', 'ARC-Challenge', 'MMLU', 'GSM8K', 'HumanEval', 'SpellingBee']
@ -254,4 +268,11 @@ if __name__ == "__main__":
chatcore_metric_dict,
])
if use_wandb:
wandb_payload = {f"chat_eval/{task}": acc for task, acc in results.items()}
if chatcore_metric_dict:
wandb_payload.update({"chat_eval/chatcore": chatcore_metric_dict["ChatCORE metric"]})
wandb_run.log(wandb_payload, step=meta.get("step"))
wandb_run.finish()
compute_cleanup()

View File

@ -55,10 +55,21 @@ eval_every = 100
eval_steps = 100
eval_metrics_every = 200
eval_metrics_max_problems = 1024
checkpoint_every_steps = 0 # save intermediate checkpoints every N optimization steps (0 = disable)
# now allow CLI to override the settings via the configurator lol
config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
user_config = {k: globals()[k] for k in config_keys} # possibly useful for logging
# Normalize evaluation knobs: <=0 disables inline evaluation
if eval_every <= 0:
eval_every = None
if eval_steps <= 0:
eval_steps = 0
if eval_metrics_every <= 0:
eval_metrics_every = None
if eval_metrics_max_problems <= 0:
eval_metrics_max_problems = 0
# -----------------------------------------------------------------------------
# Compute init
@ -70,13 +81,27 @@ autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if dev
# wandb logging init
use_dummy_wandb = run == "dummy" or not master_process
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-sft", name=run, config=user_config, save_code=True)
if use_dummy_wandb:
wandb_run = DummyWandb()
else:
wandb_kwargs = {
"project": os.environ.get("WANDB_PROJECT", "nanochat-sft"),
"name": run,
"config": user_config,
"save_code": True,
"reinit": True,
}
wandb_id = os.environ.get("WANDB_RUN_ID")
if wandb_id:
wandb_kwargs.update({"id": wandb_id, "resume": "allow"})
wandb_run = wandb.init(**wandb_kwargs)
# Load the model and tokenizer
model, tokenizer, meta = load_model(source, device, phase="train", model_tag=model_tag, step=step)
orig_model = model # original, uncompiled model
# model = torch.compile(model, dynamic=True) # doesn't work super well because of variable lengths of inputs
engine = Engine(model, tokenizer) # will be used for inline model evaluation only
model_config_kwargs = model.config.__dict__ # slightly naughty, abusing the simplicity of GPTConfig, TODO nicer
# -----------------------------------------------------------------------------
# Task data mixture we'll train on
@ -142,6 +167,14 @@ if num_iterations == -1:
train_loader = sft_data_generator(train_ds, batch_size=device_batch_size)
build_val_loader = lambda: sft_data_generator(val_ds, batch_size=device_batch_size)
sequences_per_step = target_examples_per_step
checkpoint_every_steps = int(checkpoint_every_steps)
checkpoint_enabled = checkpoint_every_steps > 0
base_dir = get_base_dir()
checkpoint_dirname = model_tag if model_tag else f"d{model.config.n_layer}"
checkpoint_dir = os.path.join(base_dir, "chatsft_checkpoints", checkpoint_dirname)
# -----------------------------------------------------------------------------
# Initialize the Optimizer
@ -168,11 +201,34 @@ def get_lr_multiplier(it):
# Go!
step = 0
train_iter = iter(train_loader)
total_tokens_seen = 0
total_sequences_seen = 0
last_val_loss = None
last_eval_metrics = {}
def save_sft_checkpoint(step_idx):
    """Persist an SFT checkpoint for the given optimization step.

    Builds the checkpoint metadata from module-level training state
    (last_val_loss, last_eval_metrics, model config, cumulative
    token/sequence counters) and writes it alongside the model weights.
    The optimizer state is deliberately not saved (passed as None).
    """
    checkpoint_meta = dict(step=step_idx, val_loss=last_val_loss)
    checkpoint_meta.update(last_eval_metrics)
    checkpoint_meta.update(
        model_config=model_config_kwargs,
        total_tokens_seen=total_tokens_seen,
        total_sequences_seen=total_sequences_seen,
    )
    save_checkpoint(checkpoint_dir, step_idx, model.state_dict(), None, checkpoint_meta)
for step in range(num_iterations):
last_step = step == num_iterations - 1
# evaluate the validation loss
if last_step or step % eval_every == 0:
# evaluate the validation loss (if enabled)
run_val = eval_every is not None and eval_steps > 0 and (last_step or step % eval_every == 0)
if run_val:
model.eval()
val_iter = iter(build_val_loader())
losses = []
@ -186,14 +242,20 @@ for step in range(num_iterations):
dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) # average over ranks
val_loss = val_loss.item()
print0(f"Step {step:05d} | Validation loss: {val_loss:.6f}")
last_val_loss = val_loss
wandb_run.log({
"step": step,
"val_loss": val_loss,
"train/total_tokens": total_tokens_seen,
"train/total_sequences": total_sequences_seen,
})
model.train()
    # evaluate accuracy of the multiple choice tasks (which are quick to run)
if last_step or (step > 0 and step % eval_metrics_every == 0):
# evaluate accuracy of the multiple choice tasks (if enabled)
run_metrics = eval_metrics_every is not None and eval_metrics_max_problems > 0 and (
last_step or (step > 0 and step % eval_metrics_every == 0)
)
if run_metrics:
model.eval()
metrics = {}
with torch.no_grad(), autocast_ctx:
@ -202,9 +264,12 @@ for step in range(num_iterations):
metrics["arc_easy_acc"] = run_chat_eval("ARC-Easy", model, tokenizer, engine, batch_size=device_batch_size*2, max_problems=eval_metrics_max_problems)
metrics_str = ', '.join(f'{k}: {v:.6f}' for k, v in metrics.items())
print0(f"Step {step:05d} | {metrics_str}")
last_eval_metrics = metrics.copy()
wandb_run.log({
"step": step,
**metrics,
"train/total_tokens": total_tokens_seen,
"train/total_sequences": total_sequences_seen,
})
model.train()
@ -238,34 +303,25 @@ for step in range(num_iterations):
# logging
train_loss_item = train_loss.item()
num_tokens_item = num_tokens.item()
total_tokens_seen += num_tokens_item
total_sequences_seen += sequences_per_step
current_step = step + 1
if master_process and checkpoint_enabled and not last_step and current_step % checkpoint_every_steps == 0:
save_sft_checkpoint(current_step)
print0(f"Step {step:05d}/{num_iterations:05d} | Training loss: {train_loss_item:.6f}| lrm: {lrm:.6f}| num_tokens: {num_tokens_item:,}")
wandb_run.log({
"step": step,
"lrm": lrm,
"train_loss": train_loss_item,
"num_tokens": num_tokens_item,
"train/total_tokens": total_tokens_seen,
"train/total_sequences": total_sequences_seen,
})
step += 1
# Save the model at the end of the run
if master_process:
base_dir = get_base_dir()
depth = model.config.n_layer
model_tag = f"d{depth}" # base the model tag on the depth of the base model
checkpoint_dir = os.path.join(base_dir, "chatsft_checkpoints", model_tag)
model_config_kwargs = model.config.__dict__ # slightly naughty, abusing the simplicity of GPTConfig, TODO nicer
save_checkpoint(
checkpoint_dir,
step,
model.state_dict(),
None, # note: we don't bother to save the optimizer state
{
"step": step,
"val_loss": val_loss,
**metrics,
"model_config": model_config_kwargs,
}
)
save_sft_checkpoint(step)
print(f"✅ Saved model checkpoint to {checkpoint_dir}")
# Log to report
@ -276,7 +332,9 @@ get_report().log(section="Chat SFT", data=[
"Training rows": len(train_ds),
"Number of iterations": num_iterations,
"Training loss": train_loss_item,
"Validation loss": val_loss,
"Validation loss": last_val_loss,
"Total tokens processed": total_tokens_seen,
"Total sequences processed": total_sequences_seen,
},
])

View File

@ -48,6 +48,7 @@ eval_every = 150 # -1 = disable
eval_tokens = 20*524288
total_batch_size = 524288
dry_run = 0 # dry_run=1 is for experiments: we will log to wandb but we won't write checkpoints or report
checkpoint_every_steps = 0 # save intermediate checkpoints every N optimization steps (0 = disable)
config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
user_config = {k: globals()[k] for k in config_keys} # possibly useful for logging
@ -63,7 +64,19 @@ get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else l
# wandb logging init
use_dummy_wandb = run == "dummy" or not master_process
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-mid", name=run, config=user_config)
# W&B setup: no-op logger on non-master ranks or when the run is named
# "dummy"; otherwise initialize a real run.
if use_dummy_wandb:
    wandb_run = DummyWandb()
else:
    wandb_kwargs = {
        # WANDB_PROJECT env var overrides the default project name.
        "project": os.environ.get("WANDB_PROJECT", "nanochat-mid"),
        "name": run,
        "config": user_config,
        "reinit": True,  # permit re-initialization if wandb was already active in this process
    }
    # If the orchestrator pins a run id (e.g. via WANDB_RUN_ID, as the
    # dev run scripts do), attach to / resume that run instead of creating
    # a new one.
    wandb_id = os.environ.get("WANDB_RUN_ID")
    if wandb_id:
        wandb_kwargs.update({"id": wandb_id, "resume": "allow"})
    wandb_run = wandb.init(**wandb_kwargs)
# Load the model and tokenizer
model, tokenizer, meta = load_model("base", device, phase="train", model_tag=model_tag, step=step)
@ -83,6 +96,10 @@ print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}")
print0(f"Total batch size {total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}")
token_bytes = get_token_bytes(device=device)
sequences_per_step = max(1, total_batch_size // max_seq_len)
checkpoint_every_steps = int(checkpoint_every_steps)
checkpoint_enabled = checkpoint_every_steps > 0
# Initialize the Optimizer (Muon for Linear layers, AdamW for embedding and lm_head)
optimizers = model.setup_optimizers(unembedding_lr=unembedding_lr, embedding_lr=embedding_lr, matrix_lr=matrix_lr, weight_decay=weight_decay)
adamw_optimizer, muon_optimizer = optimizers
@ -159,6 +176,9 @@ train_loader = mid_data_generator("train")
build_val_loader = lambda: mid_data_generator("val")
progress = 0 # will go from 0 to 1 over the course of the epoch
checkpoint_dirname = f"d{depth}"
checkpoint_dir = os.path.join(base_dir, "mid_checkpoints", checkpoint_dirname)
# Learning rate scheduler
def get_lr_multiplier(progress):
# first 80% of training: no decay, then linearly ramp down to 0.
@ -177,6 +197,37 @@ min_val_bpb = float("inf")
smooth_train_loss = 0 # EMA of training loss
ema_beta = 0.9 # EMA decay factor
total_training_time = 0 # total wall-clock time of training
tokens_per_step = total_batch_size
total_tokens_seen = 0
total_sequences_seen = 0
last_val_bpb = None
def save_mid_checkpoint(step_idx):
    """Write a midtraining checkpoint (weights + optimizer state) at step_idx.

    No-op when dry_run is set. Metadata captures the most recent validation
    bpb, the model architecture, the launch-time user config, and the
    cumulative token/sequence counters so a resumed run can pick up where
    it left off.
    """
    if dry_run:
        return  # dry runs only log to wandb; never touch disk
    architecture = {
        "sequence_len": max_seq_len,
        "vocab_size": tokenizer.get_vocab_size(),
        "n_layer": depth,
        "n_head": model.config.n_head,
        "n_kv_head": model.config.n_kv_head,
        "n_embd": model.config.n_embd,
    }
    checkpoint_meta = {
        "step": step_idx,
        "val_bpb": last_val_bpb,
        "model_config": architecture,
        "user_config": user_config,
        "total_tokens_seen": total_tokens_seen,
        "total_sequences_seen": total_sequences_seen,
    }
    opt_states = [optimizer.state_dict() for optimizer in optimizers]
    save_checkpoint(checkpoint_dir, step_idx, orig_model.state_dict(), opt_states, checkpoint_meta)
step = 0
while True:
flops_so_far = num_flops_per_token * total_batch_size * step
@ -197,37 +248,20 @@ while True:
print0(f"Step {step:05d} | Validation bpb: {val_bpb:.4f}")
if val_bpb < min_val_bpb:
min_val_bpb = val_bpb
last_val_bpb = val_bpb
wandb_run.log({
"step": step,
"total_training_flops": flops_so_far,
"total_training_time": total_training_time,
"val/bpb": val_bpb,
"train/total_tokens": total_tokens_seen,
"train/total_sequences": total_sequences_seen,
})
model.train()
# save checkpoint at the end of the run (only on master process)
if master_process and last_step and not dry_run:
output_dirname = f"d{depth}" # e.g. d12
checkpoint_dir = os.path.join(base_dir, "mid_checkpoints", output_dirname)
save_checkpoint(
checkpoint_dir,
step,
orig_model.state_dict(),
[opt.state_dict() for opt in optimizers], # TODO: make sure saving across ranks is done correctly
{
"step": step,
"val_bpb": val_bpb, # loss at last step
"model_config": {
"sequence_len": max_seq_len,
"vocab_size": tokenizer.get_vocab_size(),
"n_layer": depth,
"n_head": model.config.n_head,
"n_kv_head": model.config.n_kv_head,
"n_embd": model.config.n_embd,
},
"user_config": user_config, # inputs to the training script
}
)
save_mid_checkpoint(step)
if last_step:
break
@ -258,12 +292,17 @@ while True:
model.zero_grad(set_to_none=True)
synchronize()
t1 = time.time()
total_tokens_seen += tokens_per_step
total_sequences_seen += sequences_per_step
dt = t1 - t0
# -------------------------------------------------------------------------
# State
step += 1
if master_process and checkpoint_enabled and not last_step and checkpoint_every_steps > 0 and step % checkpoint_every_steps == 0:
save_mid_checkpoint(step)
# logging
smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss.item() # EMA the training loss
debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**(step + 1)) # debias the EMA
@ -285,6 +324,8 @@ while True:
"train/dt": dt,
"train/tok_per_sec": tok_per_sec,
"train/mfu": mfu,
"train/total_tokens": total_tokens_seen,
"train/total_sequences": total_sequences_seen,
})
# print a few more stats
@ -303,6 +344,9 @@ if not dry_run:
},
{ # stats about training outcomes
"Minimum validation bpb": min_val_bpb,
"Final validation bpb": last_val_bpb,
"Total tokens processed": total_tokens_seen,
"Total sequences processed": total_sequences_seen,
}
])