diff --git a/.gitignore b/.gitignore
index 4a87b23..2249a95 100644
--- a/.gitignore
+++ b/.gitignore
@@ -4,4 +4,6 @@ __pycache__/
rustbpe/target/
dev-ignore/
report.md
-eval_bundle/
\ No newline at end of file
+eval_bundle/
+wandb/
+.runmps_wandb_ids
diff --git a/dev/runcpu.sh b/dev/runcpu.sh
old mode 100644
new mode 100755
diff --git a/dev/runmps.sh b/dev/runmps.sh
new file mode 100755
index 0000000..da8e216
--- /dev/null
+++ b/dev/runmps.sh
@@ -0,0 +1,336 @@
+#!/bin/bash
+
+# Showing an example run for exercising some of the code paths on the CPU (or MPS on Macbooks)
+# Run as:
+# bash dev/runmps.sh
+
+# NOTE: Training LLMs requires GPU compute and $$$. You will not get far on your Macbook.
+# Think of this run as an educational/fun demo, not something you should expect to work well.
+# This is also why I hide this script away in dev/
+
+# Stage selection (allow running only a subset, e.g. --stage=sft or --from=mid)
+RUN_BASE=1
+RUN_MID=1
+RUN_SFT=1
+RUN_REPORT=1
+STAGE_ONLY=""
+FROM_STAGE=""
+
+while [[ $# -gt 0 ]]; do
+ case "$1" in
+ --stage=*)
+ STAGE_ONLY="${1#*=}"
+ ;;
+ --base|--mid|--sft|--report)
+ STAGE_ONLY="${1#--}"
+ ;;
+ --from=*)
+ FROM_STAGE="${1#*=}"
+ ;;
+ --from-base|--from-mid|--from-sft)
+ FROM_STAGE="${1#--from-}"
+ ;;
+ --help|-h)
+ cat <<'EOF'
+Usage: bash dev/runmps.sh [options]
+
+Options:
+ --stage=<base|mid|sft|report>  Run only the specified stage.
+ --from=<base|mid|sft>          Run from the specified stage through the end.
+ --help Show this help message.
+
+Environment variables (e.g. BASE_DEPTH, SEQ_LEN, DEVICE_BATCH, WANDB_RUN) control batch sizes, run names, etc.
+EOF
+ exit 0
+ ;;
+ *)
+ echo "Unknown argument: $1" >&2
+ exit 1
+ ;;
+ esac
+ shift
+done
+
+if [[ -n "$FROM_STAGE" ]]; then
+ RUN_BASE=0
+ RUN_MID=0
+ RUN_SFT=0
+ RUN_REPORT=0
+ case "$FROM_STAGE" in
+ base)
+ RUN_BASE=1
+ RUN_MID=1
+ RUN_SFT=1
+ RUN_REPORT=1
+ ;;
+ mid)
+ RUN_MID=1
+ RUN_SFT=1
+ ;;
+ sft)
+ RUN_SFT=1
+ ;;
+ *)
+ echo "Unknown --from stage: $FROM_STAGE" >&2
+ exit 1
+ ;;
+ esac
+fi
+
+if [[ -n "$STAGE_ONLY" ]]; then
+ RUN_BASE=0
+ RUN_MID=0
+ RUN_SFT=0
+ RUN_REPORT=0
+ case "$STAGE_ONLY" in
+ base)
+ RUN_BASE=1
+ ;;
+ mid)
+ RUN_MID=1
+ ;;
+ sft)
+ RUN_SFT=1
+ ;;
+ report)
+ RUN_REPORT=1
+ ;;
+ *)
+ echo "Unknown --stage value: $STAGE_ONLY" >&2
+ exit 1
+ ;;
+ esac
+fi
+
+if [[ -n "$STAGE_ONLY" || -n "$FROM_STAGE" ]]; then
+ # avoid regenerating reports when running a subset unless specifically requested
+ if [[ "$STAGE_ONLY" != "report" && "$FROM_STAGE" != "base" ]]; then
+ RUN_REPORT=0
+ fi
+fi
+
+# all the setup stuff
+export OMP_NUM_THREADS=1
+NANOCHAT_BASE_DIR="$HOME/.cache/nanochat"
+mkdir -p $NANOCHAT_BASE_DIR
+command -v uv &> /dev/null || curl -LsSf https://astral.sh/uv/install.sh | sh
+[ -d ".venv" ] || uv venv
+uv sync --extra cpu
+source .venv/bin/activate
+if [ -z "$WANDB_RUN" ]; then
+ WANDB_RUN=dummy
+fi
+if [ "$WANDB_RUN" != "dummy" ] && [ -z "$WANDB_MODE" ]; then
+ export WANDB_MODE=online
+fi
+
+# Batch/sequence configuration
+BASE_DEPTH=${BASE_DEPTH:-4}
+SEQ_LEN=${SEQ_LEN:-1024}
+DEVICE_BATCH=${DEVICE_BATCH:-16}
+TOTAL_BATCH=${TOTAL_BATCH:-$((DEVICE_BATCH * SEQ_LEN))} # tokens per optimizer step
+EVAL_SEQUENCES=10000
+EVAL_STEPS=$(((EVAL_SEQUENCES + DEVICE_BATCH - 1) / DEVICE_BATCH))
+EVAL_BATCH_MULT=4 # evaluate on 4 full batches
+EVAL_TOKENS=$((TOTAL_BATCH * EVAL_BATCH_MULT))
+MID_NUM_STEPS=6144
+SFT_NUM_STEPS=${SFT_NUM_STEPS:-3072}
+CHECKPOINT_EVERY_SEQ=${CHECKPOINT_EVERY_SEQ:-10000}
+RUN_STAGE_EVALS=${RUN_STAGE_EVALS:-0}
+WANDB_PROJECT=${WANDB_PROJECT:-nanochat}
+TARGET_PARAM_DATA_RATIO=${TARGET_PARAM_DATA_RATIO:-20}
+SFT_DEVICE_BATCH=${SFT_DEVICE_BATCH:-$DEVICE_BATCH}
+SFT_TARGET_EXAMPLES=${SFT_TARGET_EXAMPLES:-$DEVICE_BATCH}
+SFT_EVAL_EVERY=${SFT_EVAL_EVERY:-0}
+SFT_EVAL_STEPS=${SFT_EVAL_STEPS:-0}
+SFT_EVAL_METRICS_EVERY=${SFT_EVAL_METRICS_EVERY:-0}
+SFT_EVAL_METRICS_MAX=${SFT_EVAL_METRICS_MAX:-0}
+
+STATE_FILE=".runmps_wandb_ids"
+printf '' > "$STATE_FILE"
+echo "WANDB_PROJECT=$WANDB_PROJECT" >> "$STATE_FILE"
+
+generate_wandb_id() {
+ python - <<'PY'
+import uuid
+print(uuid.uuid4().hex[:8])
+PY
+}
+
+BASE_SEQS_PER_STEP=$((TOTAL_BATCH / SEQ_LEN))
+if [ $BASE_SEQS_PER_STEP -le 0 ]; then BASE_SEQS_PER_STEP=1; fi
+if [ "$CHECKPOINT_EVERY_SEQ" -le 0 ]; then
+ BASE_CHECKPOINT_STEPS=0
+else
+ BASE_CHECKPOINT_STEPS=$(((CHECKPOINT_EVERY_SEQ + BASE_SEQS_PER_STEP - 1) / BASE_SEQS_PER_STEP))
+fi
+
+MID_SEQS_PER_STEP=$((TOTAL_BATCH / SEQ_LEN))
+if [ $MID_SEQS_PER_STEP -le 0 ]; then MID_SEQS_PER_STEP=1; fi
+if [ "$CHECKPOINT_EVERY_SEQ" -le 0 ]; then
+ MID_CHECKPOINT_STEPS=0
+else
+ MID_CHECKPOINT_STEPS=$(((CHECKPOINT_EVERY_SEQ + MID_SEQS_PER_STEP - 1) / MID_SEQS_PER_STEP))
+fi
+
+SFT_SEQS_PER_STEP=$SFT_DEVICE_BATCH
+if [ $SFT_SEQS_PER_STEP -le 0 ]; then SFT_SEQS_PER_STEP=1; fi
+if [ "$CHECKPOINT_EVERY_SEQ" -le 0 ]; then
+ SFT_CHECKPOINT_STEPS=0
+else
+ SFT_CHECKPOINT_STEPS=$(((CHECKPOINT_EVERY_SEQ + SFT_SEQS_PER_STEP - 1) / SFT_SEQS_PER_STEP))
+fi
+
+# Auto-populate WANDB_API_KEY from ~/.netrc when talking to a local W&B server.
+# Mirrors the helper used in TinyRecursiveModels/pretrain_text.py so we can log
+# to a self-hosted instance without manual export each time.
+if [ -z "$WANDB_API_KEY" ] && [ -f "$HOME/.netrc" ]; then
+ # Allow custom WANDB_BASE_URL; default to localhost when it is unset
+ WANDB_BASE_URL_DEFAULT=${WANDB_BASE_URL:-http://localhost:8080}
+ if printf '%s' "$WANDB_BASE_URL_DEFAULT" | grep -q "localhost"; then
+ HOST=$(printf '%s\n' "$WANDB_BASE_URL_DEFAULT" | sed -E 's#https?://##' | cut -d/ -f1)
+ API_KEY=$(python - "$HOST" <<'PY'
+import sys
+from netrc import netrc
+
+host = sys.argv[1]
+auth = netrc().authenticators(host)
+if auth and auth[2]:
+ print(auth[2], end="")
+PY
+)
+ if [ -n "$API_KEY" ]; then
+ export WANDB_BASE_URL="$WANDB_BASE_URL_DEFAULT"
+ export WANDB_API_KEY="$API_KEY"
+ echo "[runmps] Loaded WANDB_API_KEY for $WANDB_BASE_URL from ~/.netrc"
+ fi
+ fi
+fi
+curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
+source "$HOME/.cargo/env"
+uv run maturin develop --release --manifest-path rustbpe/Cargo.toml
+EVAL_BUNDLE_URL=https://karpathy-public.s3.us-west-2.amazonaws.com/eval_bundle.zip
+if [ ! -d "$NANOCHAT_BASE_DIR/eval_bundle" ]; then
+ curl -L -o eval_bundle.zip $EVAL_BUNDLE_URL
+ unzip -q eval_bundle.zip
+ rm eval_bundle.zip
+ mv eval_bundle $NANOCHAT_BASE_DIR
+fi
+
+if (( RUN_BASE )); then
+ # wipe the report
+ python -m nanochat.report reset
+
+# train tokenizer on ~2B characters (download full shard set for extended training)
+python -m nanochat.dataset -n 240
+python -m scripts.tok_train --max_chars=2000000000
+python -m scripts.tok_eval
+
+ # train a very small 4 layer model on the CPU
+ # each optimization step processes TOTAL_BATCH tokens (DEVICE_BATCH sequences of SEQ_LEN)
+ # the number of steps is derived from TARGET_PARAM_DATA_RATIO (raise it for better results)
+ if [ "$WANDB_RUN" != "dummy" ]; then
+ BASE_WANDB_ID=${BASE_WANDB_ID:-$(generate_wandb_id)}
+ echo "BASE_WANDB_ID=$BASE_WANDB_ID" >> "$STATE_FILE"
+ echo "BASE_WANDB_NAME=$WANDB_RUN" >> "$STATE_FILE"
+ export WANDB_RUN_ID=$BASE_WANDB_ID
+ export WANDB_EVAL_RUN=$WANDB_RUN
+ export WANDB_PROJECT
+ fi
+
+ python -m scripts.base_train \
+ --depth=$BASE_DEPTH \
+ --max_seq_len=$SEQ_LEN \
+ --device_batch_size=$DEVICE_BATCH \
+ --total_batch_size=$TOTAL_BATCH \
+ --target_param_data_ratio=$TARGET_PARAM_DATA_RATIO \
+ --run="$WANDB_RUN" \
+ --eval_every=$EVAL_STEPS \
+ --eval_tokens=$EVAL_TOKENS \
+ --core_metric_every=-1 \
+ --sample_every=-1 \
+ --checkpoint_every_steps=$BASE_CHECKPOINT_STEPS
+
+ if [ "$WANDB_RUN" != "dummy" ]; then
+ unset WANDB_RUN_ID
+ unset WANDB_EVAL_RUN
+ fi
+
+ if [ "$RUN_STAGE_EVALS" = "1" ]; then
+ python -m scripts.base_loss --device_batch_size=$DEVICE_BATCH --split_tokens=$EVAL_TOKENS
+ python -m scripts.base_eval --max-per-task=16
+ fi
+fi
+
+if (( RUN_MID )); then
+ # midtraining
+ if [ "$WANDB_RUN" != "dummy" ]; then
+ MID_WANDB_ID=${MID_WANDB_ID:-$(generate_wandb_id)}
+ echo "MID_WANDB_ID=$MID_WANDB_ID" >> "$STATE_FILE"
+ echo "MID_WANDB_NAME=${WANDB_RUN}-mid" >> "$STATE_FILE"
+ export WANDB_RUN_ID=$MID_WANDB_ID
+ export WANDB_EVAL_RUN="${WANDB_RUN}-mid"
+ export WANDB_PROJECT
+ fi
+
+ python -m scripts.mid_train \
+ --max_seq_len=$SEQ_LEN \
+ --device_batch_size=$DEVICE_BATCH \
+ --total_batch_size=$TOTAL_BATCH \
+ --run="${WANDB_RUN}-mid" \
+ --eval_every=$EVAL_STEPS \
+ --eval_tokens=$EVAL_TOKENS \
+ --checkpoint_every_steps=$MID_CHECKPOINT_STEPS \
+ --num_iterations=$MID_NUM_STEPS
+ if [ "$WANDB_RUN" != "dummy" ]; then
+ unset WANDB_RUN_ID
+ unset WANDB_EVAL_RUN
+ fi
+ if [ "$RUN_STAGE_EVALS" = "1" ]; then
+ # eval results will be terrible, this is just to execute the code paths.
+ # keep max-new-tokens and max-problems small so this finishes quickly on CPU/MPS
+ python -m scripts.chat_eval --source=mid --max-new-tokens=128 --max-problems=20
+ fi
+fi
+
+if (( RUN_SFT )); then
+ # SFT
+ if [ "$WANDB_RUN" != "dummy" ]; then
+ SFT_WANDB_ID=${SFT_WANDB_ID:-$(generate_wandb_id)}
+ echo "SFT_WANDB_ID=$SFT_WANDB_ID" >> "$STATE_FILE"
+ echo "SFT_WANDB_NAME=${WANDB_RUN}-sft" >> "$STATE_FILE"
+ export WANDB_RUN_ID=$SFT_WANDB_ID
+ export WANDB_EVAL_RUN="${WANDB_RUN}-sft"
+ export WANDB_PROJECT
+ fi
+
+ python -m scripts.chat_sft \
+ --device_batch_size=$SFT_DEVICE_BATCH \
+ --target_examples_per_step=$SFT_TARGET_EXAMPLES \
+ --run="${WANDB_RUN}-sft" \
+ --num_iterations=$SFT_NUM_STEPS \
+ --eval_every=$SFT_EVAL_EVERY \
+ --eval_steps=$SFT_EVAL_STEPS \
+ --eval_metrics_every=$SFT_EVAL_METRICS_EVERY \
+ --eval_metrics_max_problems=$SFT_EVAL_METRICS_MAX \
+ --checkpoint_every_steps=$SFT_CHECKPOINT_STEPS
+
+ if [ "$WANDB_RUN" != "dummy" ]; then
+ unset WANDB_RUN_ID
+ unset WANDB_EVAL_RUN
+ fi
+fi
+
+# Chat CLI
+# python -m scripts.chat_cli -p "Why is the sky blue?"
+
+# Chat Web
+# python -m scripts.chat_web
+
+if (( RUN_REPORT )); then
+ python -m nanochat.report generate
+fi
+
+if [ "$RUN_STAGE_EVALS" != "1" ] && (( RUN_BASE || RUN_MID || RUN_SFT )); then
+ echo "[runmps] Inline evals disabled. Run 'bash dev/runmps_evals.sh' to compute metrics from saved checkpoints."
+fi
diff --git a/dev/runmps_evals.sh b/dev/runmps_evals.sh
new file mode 100755
index 0000000..cd98fc8
--- /dev/null
+++ b/dev/runmps_evals.sh
@@ -0,0 +1,223 @@
+#!/bin/bash
+
+set -euo pipefail
+shopt -s nullglob
+
+# Evaluate every checkpoint generated by dev/runmps.sh and log back to the
+# original W&B runs.
+# Usage:
+# bash dev/runmps_evals.sh
+
+STATE_FILE=".runmps_wandb_ids"
+
+if [ ! -d ".venv" ]; then
+ echo "[runmps_evals] .venv not found. Run dev/runmps.sh first." >&2
+ exit 1
+fi
+
+source .venv/bin/activate
+
+if [ -f "$STATE_FILE" ]; then
+ . "$STATE_FILE"
+fi
+
+export OMP_NUM_THREADS=${OMP_NUM_THREADS:-1}
+export NANOCHAT_BASE_DIR=${NANOCHAT_BASE_DIR:-$HOME/.cache/nanochat}
+export WANDB_PROJECT=${WANDB_PROJECT:-nanochat}
+WANDB_RUN=${WANDB_RUN:-dummy}
+if [ "$WANDB_RUN" != "dummy" ] && [ -z "${WANDB_MODE:-}" ]; then
+ export WANDB_MODE=online
+fi
+
+# Auto-populate WANDB_API_KEY from ~/.netrc when talking to a local W&B server,
+# mirroring the training script behaviour so evals can log without manual export.
+if [ -z "${WANDB_API_KEY:-}" ] && [ -f "$HOME/.netrc" ]; then
+ WANDB_BASE_URL_DEFAULT=${WANDB_BASE_URL:-http://localhost:8080}
+ if printf '%s' "$WANDB_BASE_URL_DEFAULT" | grep -q "localhost"; then
+ HOST=$(printf '%s\n' "$WANDB_BASE_URL_DEFAULT" | sed -E 's#https?://##' | cut -d/ -f1)
+ API_KEY=$(python - "$HOST" <<'PY'
+import sys
+from netrc import netrc
+
+host = sys.argv[1]
+auth = netrc().authenticators(host)
+if auth and auth[2]:
+ print(auth[2], end="")
+PY
+)
+ if [ -n "$API_KEY" ]; then
+ export WANDB_BASE_URL="$WANDB_BASE_URL_DEFAULT"
+ export WANDB_API_KEY="$API_KEY"
+ echo "[runmps_evals] Loaded WANDB_API_KEY for $WANDB_BASE_URL from ~/.netrc"
+ fi
+ fi
+fi
+
+generate_wandb_id() {
+ python - <<'PY'
+import uuid
+print(uuid.uuid4().hex[:8])
+PY
+}
+
+set_eval_run() {
+ local prefix="$1"
+ local default_name="$2"
+ local id_var="${prefix}_EVAL_WANDB_ID"
+ local name_var="${prefix}_EVAL_WANDB_NAME"
+ local current_id=$(generate_wandb_id)
+ local current_name="$default_name"
+ eval "$id_var=$current_id"
+ eval "$name_var='$current_name'"
+}
+
+SEQ_LEN=${SEQ_LEN:-1024}
+DEVICE_BATCH=${DEVICE_BATCH:-16}
+TOTAL_BATCH=$((DEVICE_BATCH * SEQ_LEN))
+EVAL_SEQUENCES=10000
+EVAL_BATCH_MULT=4
+EVAL_TOKENS=$((TOTAL_BATCH * EVAL_BATCH_MULT))
+
+checkpoint_steps() {
+ local dir="$1"
+ python - "$dir" <<'PY'
+import os, sys
+dir_path = sys.argv[1]
+if not os.path.isdir(dir_path):
+ sys.exit(0)
+steps = []
+for fname in os.listdir(dir_path):
+ if fname.startswith("model_") and fname.endswith(".pt"):
+ try:
+ step = int(fname.split("_")[1].split(".")[0])
+ except ValueError:
+ continue
+ steps.append(step)
+for step in sorted(steps):
+ print(step)
+PY
+}
+
+checkpoint_tag() {
+ local dir="$1"
+ python - "$dir" <<'PY'
+import os, sys
+from nanochat.checkpoint_manager import find_largest_model
+base_dir = sys.argv[1]
+if not os.path.isdir(base_dir):
+ sys.exit(0)
+try:
+ tag = find_largest_model(base_dir)
+except FileNotFoundError:
+ sys.exit(0)
+print(tag)
+PY
+}
+
+run_base_evals() {
+ local tag=$(checkpoint_tag "$NANOCHAT_BASE_DIR/base_checkpoints")
+ if [ -z "$tag" ]; then
+ return
+ fi
+ local dir="$NANOCHAT_BASE_DIR/base_checkpoints/$tag"
+ local steps="$(checkpoint_steps "$dir")"
+ if [ -z "$steps" ]; then
+ return
+ fi
+ echo "[runmps_evals] Running base eval on checkpoints in $dir..."
+ set_eval_run BASE "${WANDB_RUN:-base}-eval"
+ export WANDB_RUN_ID=$BASE_EVAL_WANDB_ID
+ export WANDB_EVAL_RUN=$BASE_EVAL_WANDB_NAME
+ export WANDB_PROJECT
+ while read -r step; do
+ [ -z "$step" ] && continue
+ echo " • base checkpoint step $step"
+ # run base_loss to print sample completions, but keep it off W&B
+ saved_wandb_run_id=${WANDB_RUN_ID:-}
+ saved_wandb_eval_run=${WANDB_EVAL_RUN:-}
+ unset WANDB_RUN_ID
+ unset WANDB_EVAL_RUN
+ python -m scripts.base_loss \
+ --device_batch_size=$DEVICE_BATCH \
+ --split_tokens=$EVAL_TOKENS \
+ --model_step=$step
+ if [ -n "$saved_wandb_run_id" ]; then
+ export WANDB_RUN_ID=$saved_wandb_run_id
+ fi
+ if [ -n "$saved_wandb_eval_run" ]; then
+ export WANDB_EVAL_RUN=$saved_wandb_eval_run
+ fi
+ python -m scripts.base_eval \
+ --max-per-task=16 \
+ --model-step=$step
+ done <<< "$steps"
+ unset WANDB_RUN_ID
+ unset WANDB_EVAL_RUN
+}
+
+run_mid_evals() {
+ local tag=$(checkpoint_tag "$NANOCHAT_BASE_DIR/mid_checkpoints")
+ if [ -z "$tag" ]; then
+ return
+ fi
+ local dir="$NANOCHAT_BASE_DIR/mid_checkpoints/$tag"
+ local steps="$(checkpoint_steps "$dir")"
+ if [ -z "$steps" ]; then
+ return
+ fi
+ echo "[runmps_evals] Running mid-stage chat eval on checkpoints in $dir..."
+ set_eval_run MID "${WANDB_RUN:-mid}-mid-eval"
+ export WANDB_RUN_ID=$MID_EVAL_WANDB_ID
+ export WANDB_EVAL_RUN=$MID_EVAL_WANDB_NAME
+ export WANDB_PROJECT
+ while read -r step; do
+ [ -z "$step" ] && continue
+ echo " • mid checkpoint step $step"
+ python -m scripts.chat_eval \
+ --source=mid \
+ --step=$step \
+ --batch-size=$DEVICE_BATCH \
+ --max-new-tokens=128 \
+ --max-problems=20
+ done <<< "$steps"
+ unset WANDB_RUN_ID
+ unset WANDB_EVAL_RUN
+}
+
+run_sft_evals() {
+ local tag=$(checkpoint_tag "$NANOCHAT_BASE_DIR/chatsft_checkpoints")
+ if [ -z "$tag" ]; then
+ return
+ fi
+ local dir="$NANOCHAT_BASE_DIR/chatsft_checkpoints/$tag"
+ local steps="$(checkpoint_steps "$dir")"
+ if [ -z "$steps" ]; then
+ return
+ fi
+ echo "[runmps_evals] Running SFT chat eval on checkpoints in $dir..."
+ set_eval_run SFT "${WANDB_RUN:-sft}-sft-eval"
+ export WANDB_RUN_ID=$SFT_EVAL_WANDB_ID
+ export WANDB_EVAL_RUN=$SFT_EVAL_WANDB_NAME
+ export WANDB_PROJECT
+ while read -r step; do
+ [ -z "$step" ] && continue
+ echo " • SFT checkpoint step $step"
+ python -m scripts.chat_eval \
+ --source=sft \
+ --step=$step \
+ --batch-size=$DEVICE_BATCH \
+ --max-new-tokens=128 \
+ --max-problems=20
+ done <<< "$steps"
+ unset WANDB_RUN_ID
+ unset WANDB_EVAL_RUN
+}
+
+run_base_evals
+run_mid_evals
+run_sft_evals
+
+echo "[runmps_evals] Regenerating report with fresh metrics..."
+python -m nanochat.report generate
+
+echo "[runmps_evals] Done."
diff --git a/scripts/base_eval.py b/scripts/base_eval.py
index fc02120..f6f002e 100644
--- a/scripts/base_eval.py
+++ b/scripts/base_eval.py
@@ -19,8 +19,9 @@ from contextlib import nullcontext
import pandas as pd
import torch
+import wandb
-from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, autodetect_device_type
+from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, autodetect_device_type, DummyWandb
from nanochat.tokenizer import HuggingFaceTokenizer
from nanochat.checkpoint_manager import load_model
from nanochat.core_eval import evaluate_task
@@ -123,12 +124,15 @@ def main():
parser = argparse.ArgumentParser()
parser.add_argument('--hf-path', type=str, default=None, help='HuggingFace model path to evaluate')
parser.add_argument('--max-per-task', type=int, default=-1, help='Max examples per task to evaluate (-1 = disable)')
+ parser.add_argument('--model-step', type=int, default=None, help='Checkpoint step to evaluate when using local models')
args = parser.parse_args()
# distributed / precision setup
device_type = autodetect_device_type()
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
+ wandb_run = DummyWandb()
+ use_wandb = bool(os.environ.get("WANDB_RUN_ID")) and args.hf_path is None and ddp_rank == 0
# Load model and tokenizer from command line or from file system
if args.hf_path is not None:
@@ -140,24 +144,32 @@ def main():
model_slug = hf_path.replace("/", "-") # for the output csv file
else:
# load a local model from the file system
- model, tokenizer, meta = load_model("base", device, phase="eval")
+ model, tokenizer, meta = load_model("base", device, phase="eval", step=args.model_step)
model_name = f"base_model (step {meta['step']})" # just for logging
model_slug = f"base_model_{meta['step']:06d}" # for the output csv file
+ if use_wandb:
+ wandb_kwargs = {
+ "project": os.environ.get("WANDB_PROJECT", "nanochat"),
+ "name": os.environ.get("WANDB_EVAL_RUN", model_name),
+ "id": os.environ.get("WANDB_RUN_ID"),
+ "resume": "allow",
+ "reinit": True,
+ }
+ wandb_kwargs = {k: v for k, v in wandb_kwargs.items() if v is not None}
+ wandb_run = wandb.init(**wandb_kwargs)
# Evaluate the model
with autocast_ctx:
out = evaluate_model(model, tokenizer, device, max_per_task=args.max_per_task)
# Write out the results to a csv file
- core_metric = None
- centered_results = {}
+ results = out["results"]
+ centered_results = out["centered_results"]
+ core_metric = out["core_metric"]
if ddp_rank == 0:
base_dir = get_base_dir()
output_csv_path = os.path.join(base_dir, "base_eval", f"{model_slug}.csv")
os.makedirs(os.path.dirname(output_csv_path), exist_ok=True)
- results = out["results"]
- centered_results = out["centered_results"]
- core_metric = out["core_metric"]
with open(output_csv_path, 'w') as f:
f.write(f"{'Task':<35}, {'Accuracy':<10}, {'Centered':<10}\n")
for label in results:
@@ -180,6 +192,15 @@ def main():
centered_results, # the full table
])
+ if use_wandb:
+ wandb_payload = {"core/metric": core_metric}
+ wandb_payload.update({f"core/{k}": v for k, v in centered_results.items()})
+ wandb_payload.update({f"core_raw/{k}": v for k, v in results.items()})
+ for key, value in wandb_payload.items():
+ wandb_run.summary[key] = value
+ wandb_run.log(wandb_payload, step=meta["step"])
+ wandb_run.finish()
+
compute_cleanup()
if __name__ == "__main__":
diff --git a/scripts/base_loss.py b/scripts/base_loss.py
index abcde5f..7e85077 100644
--- a/scripts/base_loss.py
+++ b/scripts/base_loss.py
@@ -9,6 +9,9 @@ torchrun --standalone --nproc_per_node=8 -m scripts.base_loss
import os
from contextlib import nullcontext
import torch
+import wandb
+
+from nanochat.common import DummyWandb
from nanochat.checkpoint_manager import load_model
from nanochat.common import compute_init, print0, compute_cleanup, autodetect_device_type
from nanochat.dataloader import tokenizing_distributed_data_loader
@@ -36,6 +39,19 @@ tokens_per_step = device_batch_size * sequence_len * ddp_world_size
assert split_tokens % tokens_per_step == 0, "split_tokens must be divisible by tokens_per_step"
steps = split_tokens // tokens_per_step
token_bytes = get_token_bytes(device=device)
+use_wandb = bool(os.environ.get("WANDB_RUN_ID"))
+wandb_run = DummyWandb()
+if use_wandb:
+ wandb_kwargs = {
+ "project": os.environ.get("WANDB_PROJECT", "nanochat"),
+ "name": os.environ.get("WANDB_EVAL_RUN", "base-eval"),
+ "id": os.environ.get("WANDB_RUN_ID"),
+ "resume": "allow",
+ "reinit": True,
+ }
+ wandb_kwargs = {k: v for k, v in wandb_kwargs.items() if v is not None}
+ wandb_run = wandb.init(**wandb_kwargs)
+
bpb_results = {}
for split_name in ["train", "val"]:
loader = tokenizing_distributed_data_loader(device_batch_size, sequence_len, split_name, device=device)
@@ -63,7 +79,7 @@ if ddp_rank == 0:
sample, _ = engine.generate_batch(tokens, num_samples=1, max_tokens=16, temperature=0)
sample_str = tokenizer.decode(sample[0])
print0(sample_str)
- samples.append(sample_str)
+ samples.append(sample_str)
# Log to report
from nanochat.report import get_report
@@ -75,5 +91,12 @@ get_report().log(section="Base model loss", data=[
{f"sample {i}": sample for i, sample in enumerate(samples)},
])
+if use_wandb:
+ wandb_run.log({
+ "base_loss/train_bpb": bpb_results["train"],
+ "base_loss/val_bpb": bpb_results["val"],
+ }, step=meta.get("step"))
+ wandb_run.finish()
+
# Cleanup
compute_cleanup()
diff --git a/scripts/base_train.py b/scripts/base_train.py
index 3725805..ee394ca 100644
--- a/scripts/base_train.py
+++ b/scripts/base_train.py
@@ -60,6 +60,7 @@ core_metric_max_per_task = 500 # examples per task in estimating the core metric
sample_every = 2000 # every how many steps to sample from the model
# Output
model_tag = "" # optionally override the model tag for the output checkpoint directory name
+checkpoint_every_steps = 0 # save intermediate checkpoints every N optimization steps (0 = disable)
# now allow CLI to override the settings via the configurator lol
config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
@@ -76,7 +77,19 @@ get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else l
# wandb logging init
use_dummy_wandb = run == "dummy" or not master_process
-wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat", name=run, config=user_config)
+if use_dummy_wandb:
+ wandb_run = DummyWandb()
+else:
+ wandb_kwargs = {
+ "project": os.environ.get("WANDB_PROJECT", "nanochat"),
+ "name": run,
+ "config": user_config,
+ "reinit": True,
+ }
+ wandb_id = os.environ.get("WANDB_RUN_ID")
+ if wandb_id:
+ wandb_kwargs.update({"id": wandb_id, "resume": "allow"})
+ wandb_run = wandb.init(**wandb_kwargs)
# Tokenizer will be useful for evaluation, also we need the vocab size
tokenizer = get_tokenizer()
@@ -138,6 +151,10 @@ print0(f"Total number of training tokens: {total_tokens:,}")
print0(f"Tokens : Params ratio: {total_batch_size * num_iterations / num_params:.2f}") # Chinchilla is ~20
print0(f"Total training FLOPs estimate: {num_flops_per_token * total_tokens:e}")
+sequences_per_step = max(1, total_batch_size // max_seq_len)
+checkpoint_every_steps = int(checkpoint_every_steps)
+checkpoint_enabled = checkpoint_every_steps > 0
+
# -----------------------------------------------------------------------------
# Initialize the Optimizer (Muon for Linear layers, AdamW for embedding and lm_head)
optimizers = model.setup_optimizers(unembedding_lr=unembedding_lr, embedding_lr=embedding_lr, matrix_lr=matrix_lr, weight_decay=weight_decay)
@@ -150,6 +167,10 @@ train_loader = tokenizing_distributed_data_loader(device_batch_size, max_seq_len
build_val_loader = lambda: tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="val", device=device)
x, y = next(train_loader) # kick off load of the very first batch of data
+# Checkpoint output location
+checkpoint_dirname = model_tag if model_tag else f"d{depth}"
+checkpoint_dir = os.path.join(base_dir, "base_checkpoints", checkpoint_dirname)
+
# -----------------------------------------------------------------------------
# Set up hyperparameter schedulers
@@ -177,6 +198,33 @@ min_val_bpb = float("inf")
smooth_train_loss = 0 # EMA of training loss
ema_beta = 0.9 # EMA decay factor
total_training_time = 0 # total wall-clock time of training
+# Keep track of total tokens/sequences processed for logging
+tokens_per_step = total_batch_size
+total_tokens_seen = 0
+total_sequences_seen = 0
+last_val_bpb = None
+
+def save_base_checkpoint(step_idx):
+ output_dirname = model_tag if model_tag else f"d{depth}"
+ checkpoint_dir = os.path.join(base_dir, "base_checkpoints", output_dirname)
+ optimizer_state = [opt.state_dict() for opt in optimizers]
+ meta = {
+ "step": step_idx,
+ "val_bpb": last_val_bpb,
+ "model_config": model_config_kwargs,
+ "user_config": user_config,
+ "device_batch_size": device_batch_size,
+ "max_seq_len": max_seq_len,
+ "total_tokens_seen": total_tokens_seen,
+ "total_sequences_seen": total_sequences_seen,
+ }
+ save_checkpoint(
+ checkpoint_dir,
+ step_idx,
+ orig_model.state_dict(),
+ optimizer_state,
+ meta,
+ )
# note that we run +1 steps only so that we can eval and save at the end
for step in range(num_iterations + 1):
last_step = step == num_iterations
@@ -192,11 +240,14 @@ for step in range(num_iterations + 1):
print0(f"Step {step:05d} | Validation bpb: {val_bpb:.4f}")
if val_bpb < min_val_bpb:
min_val_bpb = val_bpb
+ last_val_bpb = val_bpb
wandb_run.log({
"step": step,
"total_training_flops": flops_so_far,
"total_training_time": total_training_time,
"val/bpb": val_bpb,
+ "train/total_tokens": total_tokens_seen,
+ "train/total_sequences": total_sequences_seen,
})
model.train()
@@ -213,12 +264,14 @@ for step in range(num_iterations + 1):
"total_training_flops": flops_so_far,
"core_metric": results["core_metric"],
"centered_results": results["centered_results"],
+ "train/total_tokens": total_tokens_seen,
+ "train/total_sequences": total_sequences_seen,
})
model.train()
# once in a while: sample from the model (only on master process)
# use the original uncompiled model because the inputs keep changing shape
- if master_process and (last_step or (step > 0 and step % sample_every == 0)):
+ if master_process and sample_every > 0 and (last_step or (step > 0 and step % sample_every == 0)):
model.eval()
prompts = [
"The capital of France is",
@@ -239,22 +292,7 @@ for step in range(num_iterations + 1):
# save checkpoint at the end of the run (only on master process)
if master_process and last_step:
- output_dirname = model_tag if model_tag else f"d{depth}" # e.g. d12
- checkpoint_dir = os.path.join(base_dir, "base_checkpoints", output_dirname)
- save_checkpoint(
- checkpoint_dir,
- step,
- orig_model.state_dict(),
- [opt.state_dict() for opt in optimizers], # TODO: make sure saving across ranks is done correctly
- {
- "step": step,
- "val_bpb": val_bpb, # loss at last step
- "model_config": model_config_kwargs,
- "user_config": user_config, # inputs to the training script
- "device_batch_size": device_batch_size,
- "max_seq_len": max_seq_len,
- }
- )
+ save_base_checkpoint(step)
if last_step:
break
@@ -287,6 +325,11 @@ for step in range(num_iterations + 1):
model.zero_grad(set_to_none=True)
synchronize()
t1 = time.time()
+ total_tokens_seen += tokens_per_step
+ total_sequences_seen += sequences_per_step
+ current_step = step + 1
+ if master_process and checkpoint_enabled and not last_step and checkpoint_every_steps > 0 and current_step % checkpoint_every_steps == 0:
+ save_base_checkpoint(current_step)
dt = t1 - t0
# -------------------------------------------------------------------------
@@ -311,6 +354,8 @@ for step in range(num_iterations + 1):
"train/dt": dt,
"train/tok_per_sec": tok_per_sec,
"train/mfu": mfu,
+ "train/total_tokens": total_tokens_seen,
+ "train/total_sequences": total_sequences_seen,
})
# print a few more stats
@@ -335,12 +380,14 @@ get_report().log(section="Base model training", data=[
},
{ # stats about training outcomes
"Minimum validation bpb": min_val_bpb,
- "Final validation bpb": val_bpb,
+ "Final validation bpb": last_val_bpb,
"CORE metric estimate": results.get("core_metric", None),
"MFU %": f"{mfu:.2f}%",
"Total training flops": f"{flops_so_far:e}",
"Total training time": f"{total_training_time/60:.2f}m",
"Peak memory usage": f"{get_max_memory() / 1024 / 1024:.2f}MiB",
+ "Total tokens processed": total_tokens_seen,
+ "Total sequences processed": total_sequences_seen,
}
])
diff --git a/scripts/chat_eval.py b/scripts/chat_eval.py
index c77a89e..cea1061 100644
--- a/scripts/chat_eval.py
+++ b/scripts/chat_eval.py
@@ -1,6 +1,6 @@
"""
Evaluate the Chat model.
-All the generic code lives here, and all the evlauation-specific
+All the generic code lives here, and all the evaluation-specific
code lives in nanochat directory and is imported from here.
Example runs:
@@ -9,13 +9,15 @@ torchrun --nproc_per_node=8 -m scripts.chat_eval -- -a ARC-Easy
"""
import argparse
+import os
from functools import partial
from contextlib import nullcontext
import torch
import torch.distributed as dist
+import wandb
-from nanochat.common import compute_init, compute_cleanup, get_dist_info, print0, autodetect_device_type
+from nanochat.common import compute_init, compute_cleanup, get_dist_info, print0, autodetect_device_type, DummyWandb
from nanochat.checkpoint_manager import load_model
from nanochat.engine import Engine
@@ -201,9 +203,21 @@ if __name__ == "__main__":
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
+ wandb_run = DummyWandb()
+ use_wandb = bool(os.environ.get("WANDB_RUN_ID")) and ddp_rank == 0
model, tokenizer, meta = load_model(args.source, device, phase="eval", model_tag=args.model_tag, step=args.step)
engine = Engine(model, tokenizer)
+ if use_wandb:
+ wandb_kwargs = {
+ "project": os.environ.get("WANDB_PROJECT", "nanochat"),
+ "name": os.environ.get("WANDB_EVAL_RUN", f"{args.source}-eval"),
+ "id": os.environ.get("WANDB_RUN_ID"),
+ "resume": "allow",
+ "reinit": True,
+ }
+ wandb_kwargs = {k: v for k, v in wandb_kwargs.items() if v is not None}
+ wandb_run = wandb.init(**wandb_kwargs)
# Get the tasks to evaluate on
all_tasks = ['ARC-Easy', 'ARC-Challenge', 'MMLU', 'GSM8K', 'HumanEval', 'SpellingBee']
@@ -254,4 +268,11 @@ if __name__ == "__main__":
chatcore_metric_dict,
])
+ if use_wandb:
+ wandb_payload = {f"chat_eval/{task}": acc for task, acc in results.items()}
+ if chatcore_metric_dict:
+ wandb_payload.update({"chat_eval/chatcore": chatcore_metric_dict["ChatCORE metric"]})
+ wandb_run.log(wandb_payload, step=meta.get("step"))
+ wandb_run.finish()
+
compute_cleanup()
diff --git a/scripts/chat_sft.py b/scripts/chat_sft.py
index e6e4565..25ba0fa 100644
--- a/scripts/chat_sft.py
+++ b/scripts/chat_sft.py
@@ -55,10 +55,21 @@ eval_every = 100
eval_steps = 100
eval_metrics_every = 200
eval_metrics_max_problems = 1024
+checkpoint_every_steps = 0 # save intermediate checkpoints every N optimization steps (0 = disable)
# now allow CLI to override the settings via the configurator lol
config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
user_config = {k: globals()[k] for k in config_keys} # possibly useful for logging
+
+# Normalize evaluation knobs: <=0 disables inline evaluation
+if eval_every <= 0:
+ eval_every = None
+if eval_steps <= 0:
+ eval_steps = 0
+if eval_metrics_every <= 0:
+ eval_metrics_every = None
+if eval_metrics_max_problems <= 0:
+ eval_metrics_max_problems = 0
# -----------------------------------------------------------------------------
# Compute init
@@ -70,13 +81,27 @@ autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if dev
# wandb logging init
use_dummy_wandb = run == "dummy" or not master_process
-wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-sft", name=run, config=user_config, save_code=True)
+if use_dummy_wandb:
+ wandb_run = DummyWandb()
+else:
+ wandb_kwargs = {
+ "project": os.environ.get("WANDB_PROJECT", "nanochat-sft"),
+ "name": run,
+ "config": user_config,
+ "save_code": True,
+ "reinit": True,
+ }
+ wandb_id = os.environ.get("WANDB_RUN_ID")
+ if wandb_id:
+ wandb_kwargs.update({"id": wandb_id, "resume": "allow"})
+ wandb_run = wandb.init(**wandb_kwargs)
# Load the model and tokenizer
model, tokenizer, meta = load_model(source, device, phase="train", model_tag=model_tag, step=step)
orig_model = model # original, uncompiled model
# model = torch.compile(model, dynamic=True) # doesn't work super well because of variable lengths of inputs
engine = Engine(model, tokenizer) # will be used for inline model evaluation only
+model_config_kwargs = model.config.__dict__ # slightly naughty, abusing the simplicity of GPTConfig, TODO nicer
# -----------------------------------------------------------------------------
# Task data mixture we'll train on
@@ -142,6 +167,14 @@ if num_iterations == -1:
train_loader = sft_data_generator(train_ds, batch_size=device_batch_size)
build_val_loader = lambda: sft_data_generator(val_ds, batch_size=device_batch_size)
+sequences_per_step = target_examples_per_step
+checkpoint_every_steps = int(checkpoint_every_steps)
+checkpoint_enabled = checkpoint_every_steps > 0
+
+base_dir = get_base_dir()
+checkpoint_dirname = model_tag if model_tag else f"d{model.config.n_layer}"
+checkpoint_dir = os.path.join(base_dir, "chatsft_checkpoints", checkpoint_dirname)
+
# -----------------------------------------------------------------------------
# Initialize the Optimizer
@@ -168,11 +201,34 @@ def get_lr_multiplier(it):
# Go!
step = 0
train_iter = iter(train_loader)
+total_tokens_seen = 0
+total_sequences_seen = 0
+last_val_loss = None
+last_eval_metrics = {}
+
+def save_sft_checkpoint(step_idx):
+ meta = {
+ "step": step_idx,
+ "val_loss": last_val_loss,
+ **last_eval_metrics,
+ "model_config": model_config_kwargs,
+ "total_tokens_seen": total_tokens_seen,
+ "total_sequences_seen": total_sequences_seen,
+ }
+ save_checkpoint(
+ checkpoint_dir,
+ step_idx,
+ model.state_dict(),
+ None,
+ meta,
+ )
+
for step in range(num_iterations):
last_step = step == num_iterations - 1
- # evaluate the validation loss
- if last_step or step % eval_every == 0:
+ # evaluate the validation loss (if enabled)
+ run_val = eval_every is not None and eval_steps > 0 and (last_step or step % eval_every == 0)
+ if run_val:
model.eval()
val_iter = iter(build_val_loader())
losses = []
@@ -186,14 +242,20 @@ for step in range(num_iterations):
dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) # average over ranks
val_loss = val_loss.item()
print0(f"Step {step:05d} | Validation loss: {val_loss:.6f}")
+ last_val_loss = val_loss
wandb_run.log({
"step": step,
"val_loss": val_loss,
+ "train/total_tokens": total_tokens_seen,
+ "train/total_sequences": total_sequences_seen,
})
model.train()
- # evlauate accuracy of the multiple choice tasks (which are quick to run)
- if last_step or (step > 0 and step % eval_metrics_every == 0):
+ # evaluate accuracy of the multiple choice tasks (if enabled)
+ run_metrics = eval_metrics_every is not None and eval_metrics_max_problems > 0 and (
+ last_step or (step > 0 and step % eval_metrics_every == 0)
+ )
+ if run_metrics:
model.eval()
metrics = {}
with torch.no_grad(), autocast_ctx:
@@ -202,9 +264,12 @@ for step in range(num_iterations):
metrics["arc_easy_acc"] = run_chat_eval("ARC-Easy", model, tokenizer, engine, batch_size=device_batch_size*2, max_problems=eval_metrics_max_problems)
metrics_str = ', '.join(f'{k}: {v:.6f}' for k, v in metrics.items())
print0(f"Step {step:05d} | {metrics_str}")
+ last_eval_metrics = metrics.copy()
wandb_run.log({
"step": step,
**metrics,
+ "train/total_tokens": total_tokens_seen,
+ "train/total_sequences": total_sequences_seen,
})
model.train()
@@ -238,34 +303,25 @@ for step in range(num_iterations):
# logging
train_loss_item = train_loss.item()
num_tokens_item = num_tokens.item()
+ total_tokens_seen += num_tokens_item
+ total_sequences_seen += sequences_per_step
+ current_step = step + 1
+ if master_process and checkpoint_enabled and not last_step and current_step % checkpoint_every_steps == 0:
+ save_sft_checkpoint(current_step)
print0(f"Step {step:05d}/{num_iterations:05d} | Training loss: {train_loss_item:.6f}| lrm: {lrm:.6f}| num_tokens: {num_tokens_item:,}")
wandb_run.log({
"step": step,
"lrm": lrm,
"train_loss": train_loss_item,
"num_tokens": num_tokens_item,
+ "train/total_tokens": total_tokens_seen,
+ "train/total_sequences": total_sequences_seen,
})
step += 1
# Save the model at the end of the run
if master_process:
- base_dir = get_base_dir()
- depth = model.config.n_layer
- model_tag = f"d{depth}" # base the model tag on the depth of the base model
- checkpoint_dir = os.path.join(base_dir, "chatsft_checkpoints", model_tag)
- model_config_kwargs = model.config.__dict__ # slightly naughty, abusing the simplicity of GPTConfig, TODO nicer
- save_checkpoint(
- checkpoint_dir,
- step,
- model.state_dict(),
- None, # note: we don't bother to save the optimizer state
- {
- "step": step,
- "val_loss": val_loss,
- **metrics,
- "model_config": model_config_kwargs,
- }
- )
+ save_sft_checkpoint(step)
print(f"✅ Saved model checkpoint to {checkpoint_dir}")
# Log to report
@@ -276,7 +332,9 @@ get_report().log(section="Chat SFT", data=[
"Training rows": len(train_ds),
"Number of iterations": num_iterations,
"Training loss": train_loss_item,
- "Validation loss": val_loss,
+ "Validation loss": last_val_loss,
+ "Total tokens processed": total_tokens_seen,
+ "Total sequences processed": total_sequences_seen,
},
])
diff --git a/scripts/mid_train.py b/scripts/mid_train.py
index eedb262..858d9c4 100644
--- a/scripts/mid_train.py
+++ b/scripts/mid_train.py
@@ -48,6 +48,7 @@ eval_every = 150 # -1 = disable
eval_tokens = 20*524288
total_batch_size = 524288
dry_run = 0 # dry_run=1 is for experiments: we will log to wandb but we won't write checkpoints or report
+checkpoint_every_steps = 0 # save intermediate checkpoints every N optimization steps (0 = disable)
config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
user_config = {k: globals()[k] for k in config_keys} # possibly useful for logging
@@ -63,7 +64,19 @@ get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else l
# wandb logging init
use_dummy_wandb = run == "dummy" or not master_process
-wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-mid", name=run, config=user_config)
+if use_dummy_wandb:
+ wandb_run = DummyWandb()
+else:
+ wandb_kwargs = {
+ "project": os.environ.get("WANDB_PROJECT", "nanochat-mid"),
+ "name": run,
+ "config": user_config,
+ "reinit": True,
+ }
+ wandb_id = os.environ.get("WANDB_RUN_ID")
+ if wandb_id:
+ wandb_kwargs.update({"id": wandb_id, "resume": "allow"})
+ wandb_run = wandb.init(**wandb_kwargs)
# Load the model and tokenizer
model, tokenizer, meta = load_model("base", device, phase="train", model_tag=model_tag, step=step)
@@ -83,6 +96,10 @@ print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}")
print0(f"Total batch size {total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}")
token_bytes = get_token_bytes(device=device)
+sequences_per_step = max(1, total_batch_size // max_seq_len)
+checkpoint_every_steps = int(checkpoint_every_steps)
+checkpoint_enabled = checkpoint_every_steps > 0
+
# Initialize the Optimizer (Muon for Linear layers, AdamW for embedding and lm_head)
optimizers = model.setup_optimizers(unembedding_lr=unembedding_lr, embedding_lr=embedding_lr, matrix_lr=matrix_lr, weight_decay=weight_decay)
adamw_optimizer, muon_optimizer = optimizers
@@ -159,6 +176,9 @@ train_loader = mid_data_generator("train")
build_val_loader = lambda: mid_data_generator("val")
progress = 0 # will go from 0 to 1 over the course of the epoch
+checkpoint_dirname = f"d{depth}"
+checkpoint_dir = os.path.join(base_dir, "mid_checkpoints", checkpoint_dirname)
+
# Learning rate scheduler
def get_lr_multiplier(progress):
# first 80% of training: no decay, then linearly ramp down to 0.
@@ -177,6 +197,37 @@ min_val_bpb = float("inf")
smooth_train_loss = 0 # EMA of training loss
ema_beta = 0.9 # EMA decay factor
total_training_time = 0 # total wall-clock time of training
+tokens_per_step = total_batch_size
+total_tokens_seen = 0
+total_sequences_seen = 0
+last_val_bpb = None
+
+def save_mid_checkpoint(step_idx):
+ if dry_run:
+ return
+ meta = {
+ "step": step_idx,
+ "val_bpb": last_val_bpb,
+ "model_config": {
+ "sequence_len": max_seq_len,
+ "vocab_size": tokenizer.get_vocab_size(),
+ "n_layer": depth,
+ "n_head": model.config.n_head,
+ "n_kv_head": model.config.n_kv_head,
+ "n_embd": model.config.n_embd,
+ },
+ "user_config": user_config,
+ "total_tokens_seen": total_tokens_seen,
+ "total_sequences_seen": total_sequences_seen,
+ }
+ optimizer_state = [opt.state_dict() for opt in optimizers]
+ save_checkpoint(
+ checkpoint_dir,
+ step_idx,
+ orig_model.state_dict(),
+ optimizer_state,
+ meta,
+ )
step = 0
while True:
flops_so_far = num_flops_per_token * total_batch_size * step
@@ -197,37 +248,20 @@ while True:
print0(f"Step {step:05d} | Validation bpb: {val_bpb:.4f}")
if val_bpb < min_val_bpb:
min_val_bpb = val_bpb
+ last_val_bpb = val_bpb
wandb_run.log({
"step": step,
"total_training_flops": flops_so_far,
"total_training_time": total_training_time,
"val/bpb": val_bpb,
+ "train/total_tokens": total_tokens_seen,
+ "train/total_sequences": total_sequences_seen,
})
model.train()
# save checkpoint at the end of the run (only on master process)
if master_process and last_step and not dry_run:
- output_dirname = f"d{depth}" # e.g. d12
- checkpoint_dir = os.path.join(base_dir, "mid_checkpoints", output_dirname)
- save_checkpoint(
- checkpoint_dir,
- step,
- orig_model.state_dict(),
- [opt.state_dict() for opt in optimizers], # TODO: make sure saving across ranks is done correctly
- {
- "step": step,
- "val_bpb": val_bpb, # loss at last step
- "model_config": {
- "sequence_len": max_seq_len,
- "vocab_size": tokenizer.get_vocab_size(),
- "n_layer": depth,
- "n_head": model.config.n_head,
- "n_kv_head": model.config.n_kv_head,
- "n_embd": model.config.n_embd,
- },
- "user_config": user_config, # inputs to the training script
- }
- )
+ save_mid_checkpoint(step)
if last_step:
break
@@ -258,12 +292,17 @@ while True:
model.zero_grad(set_to_none=True)
synchronize()
t1 = time.time()
+ total_tokens_seen += tokens_per_step
+ total_sequences_seen += sequences_per_step
dt = t1 - t0
# -------------------------------------------------------------------------
# State
step += 1
+ if master_process and checkpoint_enabled and not last_step and step % checkpoint_every_steps == 0:
+ save_mid_checkpoint(step)
+
# logging
smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss.item() # EMA the training loss
debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**(step + 1)) # debias the EMA
@@ -285,6 +324,8 @@ while True:
"train/dt": dt,
"train/tok_per_sec": tok_per_sec,
"train/mfu": mfu,
+ "train/total_tokens": total_tokens_seen,
+ "train/total_sequences": total_sequences_seen,
})
# print a few more stats
@@ -303,6 +344,9 @@ if not dry_run:
},
{ # stats about training outcomes
"Minimum validation bpb": min_val_bpb,
+ "Final validation bpb": last_val_bpb,
+ "Total tokens processed": total_tokens_seen,
+ "Total sequences processed": total_sequences_seen,
}
])