diff --git a/.gitignore b/.gitignore
index 4a87b23..2249a95 100644
--- a/.gitignore
+++ b/.gitignore
@@ -4,4 +4,6 @@ __pycache__/
 rustbpe/target/
 dev-ignore/
 report.md
-eval_bundle/
\ No newline at end of file
+eval_bundle/
+wandb/
+.runmps_wandb_ids
diff --git a/dev/runcpu.sh b/dev/runcpu.sh
old mode 100644
new mode 100755
diff --git a/dev/runmps.sh b/dev/runmps.sh
new file mode 100755
index 0000000..da8e216
--- /dev/null
+++ b/dev/runmps.sh
@@ -0,0 +1,336 @@
+#!/bin/bash
+
+# Showing an example run for exercising some of the code paths on the CPU (or MPS on Macbooks)
+# Run as:
+# bash dev/runmps.sh
+
+# NOTE: Training LLMs requires GPU compute and $$$. You will not get far on your Macbook.
+# Think of this run as educational/fun demo, not something you should expect to work well.
+# This is also why I hide this script away in dev/
+
+# Stage selection (allow running only a subset, e.g. --stage=sft or --from=mid)
+RUN_BASE=1
+RUN_MID=1
+RUN_SFT=1
+RUN_REPORT=1
+STAGE_ONLY=""
+FROM_STAGE=""
+
+while [[ $# -gt 0 ]]; do
+  case "$1" in
+    --stage=*)
+      STAGE_ONLY="${1#*=}"
+      ;;
+    --base|--mid|--sft|--report)
+      STAGE_ONLY="${1#--}"
+      ;;
+    --from=*)
+      FROM_STAGE="${1#*=}"
+      ;;
+    --from-base|--from-mid|--from-sft)
+      FROM_STAGE="${1#--from-}"
+      ;;
+    --help|-h)
+      cat <<'EOF'
+Usage: bash dev/runmps.sh [options]
+
+Options:
+  --stage=    Run only the specified stage.
+  --from=     Run from the specified stage through the end.
+  --help      Show this help message.
+
+Environment variables (same as before) control batch sizes, WANDB run names, etc.
+EOF
+      exit 0
+      ;;
+    *)
+      echo "Unknown argument: $1" >&2
+      exit 1
+      ;;
+  esac
+  shift
+done
+
+if [[ -n "$FROM_STAGE" ]]; then
+  RUN_BASE=0
+  RUN_MID=0
+  RUN_SFT=0
+  RUN_REPORT=0
+  case "$FROM_STAGE" in
+    base)
+      RUN_BASE=1
+      RUN_MID=1
+      RUN_SFT=1
+      RUN_REPORT=1
+      ;;
+    mid)
+      RUN_MID=1
+      RUN_SFT=1
+      ;;
+    sft)
+      RUN_SFT=1
+      ;;
+    *)
+      echo "Unknown --from stage: $FROM_STAGE" >&2
+      exit 1
+      ;;
+  esac
+fi
+
+if [[ -n "$STAGE_ONLY" ]]; then
+  RUN_BASE=0
+  RUN_MID=0
+  RUN_SFT=0
+  RUN_REPORT=0
+  case "$STAGE_ONLY" in
+    base)
+      RUN_BASE=1
+      ;;
+    mid)
+      RUN_MID=1
+      ;;
+    sft)
+      RUN_SFT=1
+      ;;
+    report)
+      RUN_REPORT=1
+      ;;
+    *)
+      echo "Unknown --stage value: $STAGE_ONLY" >&2
+      exit 1
+      ;;
+  esac
+fi
+
+if [[ -n "$STAGE_ONLY" || -n "$FROM_STAGE" ]]; then
+  # avoid regenerating reports when running a subset unless specifically requested
+  if [[ "$STAGE_ONLY" != "report" && "$FROM_STAGE" != "base" ]]; then
+    RUN_REPORT=0
+  fi
+fi
+
+# all the setup stuff
+export OMP_NUM_THREADS=1
+NANOCHAT_BASE_DIR="$HOME/.cache/nanochat"
+mkdir -p $NANOCHAT_BASE_DIR
+command -v uv &> /dev/null || curl -LsSf https://astral.sh/uv/install.sh | sh
+[ -d ".venv" ] || uv venv
+uv sync --extra cpu
+source .venv/bin/activate
+if [ -z "$WANDB_RUN" ]; then
+  WANDB_RUN=dummy
+fi
+if [ "$WANDB_RUN" != "dummy" ] && [ -z "$WANDB_MODE" ]; then
+  export WANDB_MODE=online
+fi
+
+# Batch/sequence configuration
+BASE_DEPTH=${BASE_DEPTH:-4}
+SEQ_LEN=${SEQ_LEN:-1024}
+DEVICE_BATCH=${DEVICE_BATCH:-16}
+TOTAL_BATCH=${TOTAL_BATCH:-$((DEVICE_BATCH * SEQ_LEN))} # tokens per optimizer step
+EVAL_SEQUENCES=10000
+EVAL_STEPS=$(((EVAL_SEQUENCES + DEVICE_BATCH - 1) / DEVICE_BATCH))
+EVAL_BATCH_MULT=4 # evaluate on 4 full batches
+EVAL_TOKENS=$((TOTAL_BATCH * EVAL_BATCH_MULT))
+MID_NUM_STEPS=6144
+SFT_NUM_STEPS=${SFT_NUM_STEPS:-3072}
+CHECKPOINT_EVERY_SEQ=${CHECKPOINT_EVERY_SEQ:-10000}
+RUN_STAGE_EVALS=${RUN_STAGE_EVALS:-0}
+WANDB_PROJECT=${WANDB_PROJECT:-nanochat}
+TARGET_PARAM_DATA_RATIO=${TARGET_PARAM_DATA_RATIO:-20}
+SFT_DEVICE_BATCH=${SFT_DEVICE_BATCH:-$DEVICE_BATCH}
+SFT_TARGET_EXAMPLES=${SFT_TARGET_EXAMPLES:-$DEVICE_BATCH}
+SFT_EVAL_EVERY=${SFT_EVAL_EVERY:-0}
+SFT_EVAL_STEPS=${SFT_EVAL_STEPS:-0}
+SFT_EVAL_METRICS_EVERY=${SFT_EVAL_METRICS_EVERY:-0}
+SFT_EVAL_METRICS_MAX=${SFT_EVAL_METRICS_MAX:-0}
+
+STATE_FILE=".runmps_wandb_ids"
+printf '' > "$STATE_FILE"
+echo "WANDB_PROJECT=$WANDB_PROJECT" >> "$STATE_FILE"
+
+generate_wandb_id() {
+  python - <<'PY'
+import uuid
+print(uuid.uuid4().hex[:8])
+PY
+}
+
+BASE_SEQS_PER_STEP=$((TOTAL_BATCH / SEQ_LEN))
+if [ $BASE_SEQS_PER_STEP -le 0 ]; then BASE_SEQS_PER_STEP=1; fi
+if [ "$CHECKPOINT_EVERY_SEQ" -le 0 ]; then
+  BASE_CHECKPOINT_STEPS=0
+else
+  BASE_CHECKPOINT_STEPS=$(((CHECKPOINT_EVERY_SEQ + BASE_SEQS_PER_STEP - 1) / BASE_SEQS_PER_STEP))
+fi
+
+MID_SEQS_PER_STEP=$((TOTAL_BATCH / SEQ_LEN))
+if [ $MID_SEQS_PER_STEP -le 0 ]; then MID_SEQS_PER_STEP=1; fi
+if [ "$CHECKPOINT_EVERY_SEQ" -le 0 ]; then
+  MID_CHECKPOINT_STEPS=0
+else
+  MID_CHECKPOINT_STEPS=$(((CHECKPOINT_EVERY_SEQ + MID_SEQS_PER_STEP - 1) / MID_SEQS_PER_STEP))
+fi
+
+SFT_SEQS_PER_STEP=$SFT_DEVICE_BATCH
+if [ $SFT_SEQS_PER_STEP -le 0 ]; then SFT_SEQS_PER_STEP=1; fi
+if [ "$CHECKPOINT_EVERY_SEQ" -le 0 ]; then
+  SFT_CHECKPOINT_STEPS=0
+else
+  SFT_CHECKPOINT_STEPS=$(((CHECKPOINT_EVERY_SEQ + SFT_SEQS_PER_STEP - 1) / SFT_SEQS_PER_STEP))
+fi
+
+# Auto-populate WANDB_API_KEY from ~/.netrc when talking to a local W&B server.
+# Mirrors the helper used in TinyRecursiveModels/pretrain_text.py so we can log
+# to a self-hosted instance without manual export each time.
+if [ -z "$WANDB_API_KEY" ] && [ -f "$HOME/.netrc" ]; then
+  # Allow custom WANDB_BASE_URL; default to localhost if user sets WANDB
+  WANDB_BASE_URL_DEFAULT=${WANDB_BASE_URL:-http://localhost:8080}
+  if printf '%s' "$WANDB_BASE_URL_DEFAULT" | grep -q "localhost"; then
+    HOST=$(printf '%s\n' "$WANDB_BASE_URL_DEFAULT" | sed -E 's#https?://##' | cut -d/ -f1)
+    API_KEY=$(python - "$HOST" <<'PY'
+import sys
+from netrc import netrc
+
+host = sys.argv[1]
+auth = netrc().authenticators(host)
+if auth and auth[2]:
+    print(auth[2], end="")
+PY
+)
+    if [ -n "$API_KEY" ]; then
+      export WANDB_BASE_URL="$WANDB_BASE_URL_DEFAULT"
+      export WANDB_API_KEY="$API_KEY"
+      echo "[runmps] Loaded WANDB_API_KEY for $WANDB_BASE_URL from ~/.netrc"
+    fi
+  fi
+fi
+curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
+source "$HOME/.cargo/env"
+uv run maturin develop --release --manifest-path rustbpe/Cargo.toml
+EVAL_BUNDLE_URL=https://karpathy-public.s3.us-west-2.amazonaws.com/eval_bundle.zip
+if [ ! -d "$NANOCHAT_BASE_DIR/eval_bundle" ]; then
+  curl -L -o eval_bundle.zip $EVAL_BUNDLE_URL
+  unzip -q eval_bundle.zip
+  rm eval_bundle.zip
+  mv eval_bundle $NANOCHAT_BASE_DIR
+fi
+
+if (( RUN_BASE )); then
+  # wipe the report
+  python -m nanochat.report reset
+
+# train tokenizer on ~2B characters (download full shard set for extended training)
+python -m nanochat.dataset -n 240
+python -m scripts.tok_train --max_chars=2000000000
+python -m scripts.tok_eval
+
+  # train a very small 4 layer model on the CPU
+  # each optimization step processes a single sequence of 1024 tokens
+  # we only run 50 steps of optimization (bump this to get better results)
+  if [ "$WANDB_RUN" != "dummy" ]; then
+    BASE_WANDB_ID=${BASE_WANDB_ID:-$(generate_wandb_id)}
+    echo "BASE_WANDB_ID=$BASE_WANDB_ID" >> "$STATE_FILE"
+    echo "BASE_WANDB_NAME=$WANDB_RUN" >> "$STATE_FILE"
+    export WANDB_RUN_ID=$BASE_WANDB_ID
+    export WANDB_EVAL_RUN=$WANDB_RUN
+    export WANDB_PROJECT
+  fi
+
+  python -m scripts.base_train \
+    --depth=$BASE_DEPTH \
+    --max_seq_len=$SEQ_LEN \
+    --device_batch_size=$DEVICE_BATCH \
+    --total_batch_size=$TOTAL_BATCH \
+    --target_param_data_ratio=$TARGET_PARAM_DATA_RATIO \
+    --run="$WANDB_RUN" \
+    --eval_every=$EVAL_STEPS \
+    --eval_tokens=$EVAL_TOKENS \
+    --core_metric_every=-1 \
+    --sample_every=-1 \
+    --checkpoint_every_steps=$BASE_CHECKPOINT_STEPS
+
+  if [ "$WANDB_RUN" != "dummy" ]; then
+    unset WANDB_RUN_ID
+    unset WANDB_EVAL_RUN
+  fi
+
+  if [ "$RUN_STAGE_EVALS" = "1" ]; then
+    python -m scripts.base_loss --device_batch_size=$DEVICE_BATCH --split_tokens=$EVAL_TOKENS
+    python -m scripts.base_eval --max-per-task=16
+  fi
+fi
+
+if (( RUN_MID )); then
+  # midtraining
+  if [ "$WANDB_RUN" != "dummy" ]; then
+    MID_WANDB_ID=${MID_WANDB_ID:-$(generate_wandb_id)}
+    echo "MID_WANDB_ID=$MID_WANDB_ID" >> "$STATE_FILE"
+    echo "MID_WANDB_NAME=${WANDB_RUN}-mid" >> "$STATE_FILE"
+    export WANDB_RUN_ID=$MID_WANDB_ID
+    export WANDB_EVAL_RUN="${WANDB_RUN}-mid"
+    export WANDB_PROJECT
+  fi
+
+  python -m scripts.mid_train \
+    --max_seq_len=$SEQ_LEN \
+    --device_batch_size=$DEVICE_BATCH \
+    --total_batch_size=$TOTAL_BATCH \
+    --run="${WANDB_RUN}-mid" \
+    --eval_every=$EVAL_STEPS \
+    --eval_tokens=$EVAL_TOKENS \
+    --checkpoint_every_steps=$MID_CHECKPOINT_STEPS \
+    --num_iterations=$MID_NUM_STEPS
+  if [ "$WANDB_RUN" != "dummy" ]; then
+    unset WANDB_RUN_ID
+    unset WANDB_EVAL_RUN
+  fi
+  if [ "$RUN_STAGE_EVALS" = "1" ]; then
+    # eval results will be terrible, this is just to execute the code paths.
+    # note that we lower the execution memory limit to 1MB to avoid warnings on smaller systems
+    python -m scripts.chat_eval --source=mid --max-new-tokens=128 --max-problems=20
+  fi
+fi
+
+if (( RUN_SFT )); then
+  # SFT
+  if [ "$WANDB_RUN" != "dummy" ]; then
+    SFT_WANDB_ID=${SFT_WANDB_ID:-$(generate_wandb_id)}
+    echo "SFT_WANDB_ID=$SFT_WANDB_ID" >> "$STATE_FILE"
+    echo "SFT_WANDB_NAME=${WANDB_RUN}-sft" >> "$STATE_FILE"
+    export WANDB_RUN_ID=$SFT_WANDB_ID
+    export WANDB_EVAL_RUN="${WANDB_RUN}-sft"
+    export WANDB_PROJECT
+  fi
+
+  python -m scripts.chat_sft \
+    --device_batch_size=$SFT_DEVICE_BATCH \
+    --target_examples_per_step=$SFT_TARGET_EXAMPLES \
+    --run="${WANDB_RUN}-sft" \
+    --num_iterations=$SFT_NUM_STEPS \
+    --eval_every=$SFT_EVAL_EVERY \
+    --eval_steps=$SFT_EVAL_STEPS \
+    --eval_metrics_every=$SFT_EVAL_METRICS_EVERY \
+    --eval_metrics_max_problems=$SFT_EVAL_METRICS_MAX \
+    --checkpoint_every_steps=$SFT_CHECKPOINT_STEPS
+
+  if [ "$WANDB_RUN" != "dummy" ]; then
+    unset WANDB_RUN_ID
+    unset WANDB_EVAL_RUN
+  fi
+fi
+
+# Chat CLI
+# python -m scripts.chat_cli -p "Why is the sky blue?"
+
+# Chat Web
+# python -m scripts.chat_web
+
+if (( RUN_REPORT )); then
+  python -m nanochat.report generate
+fi
+
+if [ "$RUN_STAGE_EVALS" != "1" ] && (( RUN_BASE || RUN_MID || RUN_SFT )); then
+  echo "[runmps] Inline evals disabled. Run 'bash dev/runmps_evals.sh' to compute metrics from saved checkpoints."
+fi
diff --git a/dev/runmps_evals.sh b/dev/runmps_evals.sh
new file mode 100755
index 0000000..cd98fc8
--- /dev/null
+++ b/dev/runmps_evals.sh
@@ -0,0 +1,223 @@
+#!/bin/bash
+
+set -euo pipefail
+shopt -s nullglob
+
+# Evaluate every checkpoint generated by dev/runmps.sh and log back to the
+# original W&B runs.
+# Usage:
+#   bash dev/runmps_evals.sh
+
+STATE_FILE=".runmps_wandb_ids"
+
+if [ ! -d ".venv" ]; then
+  echo "[runmps_evals] .venv not found. Run dev/runmps.sh first." >&2
+  exit 1
+fi
+
+source .venv/bin/activate
+
+if [ -f "$STATE_FILE" ]; then
+  . "$STATE_FILE"
+fi
+
+export OMP_NUM_THREADS=${OMP_NUM_THREADS:-1}
+export NANOCHAT_BASE_DIR=${NANOCHAT_BASE_DIR:-$HOME/.cache/nanochat}
+export WANDB_PROJECT=${WANDB_PROJECT:-nanochat}
+WANDB_RUN=${WANDB_RUN:-dummy}
+if [ "$WANDB_RUN" != "dummy" ] && [ -z "${WANDB_MODE:-}" ]; then
+  export WANDB_MODE=online
+fi
+
+# Auto-populate WANDB_API_KEY from ~/.netrc when talking to a local W&B server,
+# mirroring the training script behaviour so evals can log without manual export.
+if [ -z "${WANDB_API_KEY:-}" ] && [ -f "$HOME/.netrc" ]; then
+  WANDB_BASE_URL_DEFAULT=${WANDB_BASE_URL:-http://localhost:8080}
+  if printf '%s' "$WANDB_BASE_URL_DEFAULT" | grep -q "localhost"; then
+    HOST=$(printf '%s\n' "$WANDB_BASE_URL_DEFAULT" | sed -E 's#https?://##' | cut -d/ -f1)
+    API_KEY=$(python - "$HOST" <<'PY'
+import sys
+from netrc import netrc
+
+host = sys.argv[1]
+auth = netrc().authenticators(host)
+if auth and auth[2]:
+    print(auth[2], end="")
+PY
+)
+    if [ -n "$API_KEY" ]; then
+      export WANDB_BASE_URL="$WANDB_BASE_URL_DEFAULT"
+      export WANDB_API_KEY="$API_KEY"
+      echo "[runmps_evals] Loaded WANDB_API_KEY for $WANDB_BASE_URL from ~/.netrc"
+    fi
+  fi
+fi
+
+generate_wandb_id() {
+  python - <<'PY'
+import uuid
+print(uuid.uuid4().hex[:8])
+PY
+}
+
+set_eval_run() {
+  local prefix="$1"
+  local default_name="$2"
+  local id_var="${prefix}_EVAL_WANDB_ID"
+  local name_var="${prefix}_EVAL_WANDB_NAME"
+  local current_id=$(generate_wandb_id)
+  local current_name="$default_name"
+  eval "$id_var=$current_id"
+  eval "$name_var='$current_name'"
+}
+
+SEQ_LEN=${SEQ_LEN:-1024}
+DEVICE_BATCH=${DEVICE_BATCH:-16}
+TOTAL_BATCH=$((DEVICE_BATCH * SEQ_LEN))
+EVAL_SEQUENCES=10000
+EVAL_BATCH_MULT=4
+EVAL_TOKENS=$((TOTAL_BATCH * EVAL_BATCH_MULT))
+
+checkpoint_steps() {
+  local dir="$1"
+  python - "$dir" <<'PY'
+import os, sys
+dir_path = sys.argv[1]
+if not os.path.isdir(dir_path):
+    sys.exit(0)
+steps = []
+for fname in os.listdir(dir_path):
+    if fname.startswith("model_") and fname.endswith(".pt"):
+        try:
+            step = int(fname.split("_")[1].split(".")[0])
+        except ValueError:
+            continue
+        steps.append(step)
+for step in sorted(steps):
+    print(step)
+PY
+}
+
+checkpoint_tag() {
+  local dir="$1"
+  python - "$dir" <<'PY'
+import os, sys
+from nanochat.checkpoint_manager import find_largest_model
+base_dir = sys.argv[1]
+if not os.path.isdir(base_dir):
+    sys.exit(0)
+try:
+    tag = find_largest_model(base_dir)
+except FileNotFoundError:
+    sys.exit(0)
+print(tag)
+PY
+}
+
+run_base_evals() {
+  local tag=$(checkpoint_tag "$NANOCHAT_BASE_DIR/base_checkpoints")
+  if [ -z "$tag" ]; then
+    return
+  fi
+  local dir="$NANOCHAT_BASE_DIR/base_checkpoints/$tag"
+  local steps="$(checkpoint_steps "$dir")"
+  if [ -z "$steps" ]; then
+    return
+  fi
+  echo "[runmps_evals] Running base eval on checkpoints in $dir..."
+  set_eval_run BASE "${WANDB_RUN:-base}-eval"
+  export WANDB_RUN_ID=$BASE_EVAL_WANDB_ID
+  export WANDB_EVAL_RUN=$BASE_EVAL_WANDB_NAME
+  export WANDB_PROJECT
+  while read -r step; do
+    [ -z "$step" ] && continue
+    echo "  • base checkpoint step $step"
+    # run base_loss to print sample completions, but keep it off W&B
+    saved_wandb_run_id=${WANDB_RUN_ID:-}
+    saved_wandb_eval_run=${WANDB_EVAL_RUN:-}
+    unset WANDB_RUN_ID
+    unset WANDB_EVAL_RUN
+    python -m scripts.base_loss \
+      --device_batch_size=$DEVICE_BATCH \
+      --split_tokens=$EVAL_TOKENS \
+      --model_step=$step
+    if [ -n "$saved_wandb_run_id" ]; then
+      export WANDB_RUN_ID=$saved_wandb_run_id
+    fi
+    if [ -n "$saved_wandb_eval_run" ]; then
+      export WANDB_EVAL_RUN=$saved_wandb_eval_run
+    fi
+    python -m scripts.base_eval \
+      --max-per-task=16 \
+      --model-step=$step
+  done <<< "$steps"
+  unset WANDB_RUN_ID
+  unset WANDB_EVAL_RUN
+}
+
+run_mid_evals() {
+  local tag=$(checkpoint_tag "$NANOCHAT_BASE_DIR/mid_checkpoints")
+  if [ -z "$tag" ]; then
+    return
+  fi
+  local dir="$NANOCHAT_BASE_DIR/mid_checkpoints/$tag"
+  local steps="$(checkpoint_steps "$dir")"
+  if [ -z "$steps" ]; then
+    return
+  fi
+  echo "[runmps_evals] Running mid-stage chat eval on checkpoints in $dir..."
+  set_eval_run MID "${WANDB_RUN:-mid}-mid-eval"
+  export WANDB_RUN_ID=$MID_EVAL_WANDB_ID
+  export WANDB_EVAL_RUN=$MID_EVAL_WANDB_NAME
+  export WANDB_PROJECT
+  while read -r step; do
+    [ -z "$step" ] && continue
+    echo "  • mid checkpoint step $step"
+    python -m scripts.chat_eval \
+      --source=mid \
+      --step=$step \
+      --batch-size=$DEVICE_BATCH \
+      --max-new-tokens=128 \
+      --max-problems=20
+  done <<< "$steps"
+  unset WANDB_RUN_ID
+  unset WANDB_EVAL_RUN
+}
+
+run_sft_evals() {
+  local tag=$(checkpoint_tag "$NANOCHAT_BASE_DIR/chatsft_checkpoints")
+  if [ -z "$tag" ]; then
+    return
+  fi
+  local dir="$NANOCHAT_BASE_DIR/chatsft_checkpoints/$tag"
+  local steps="$(checkpoint_steps "$dir")"
+  if [ -z "$steps" ]; then
+    return
+  fi
+  echo "[runmps_evals] Running SFT chat eval on checkpoints in $dir..."
+  set_eval_run SFT "${WANDB_RUN:-sft}-sft-eval"
+  export WANDB_RUN_ID=$SFT_EVAL_WANDB_ID
+  export WANDB_EVAL_RUN=$SFT_EVAL_WANDB_NAME
+  export WANDB_PROJECT
+  while read -r step; do
+    [ -z "$step" ] && continue
+    echo "  • SFT checkpoint step $step"
+    python -m scripts.chat_eval \
+      --source=sft \
+      --step=$step \
+      --batch-size=$DEVICE_BATCH \
+      --max-new-tokens=128 \
+      --max-problems=20
+  done <<< "$steps"
+  unset WANDB_RUN_ID
+  unset WANDB_EVAL_RUN
+}
+
+run_base_evals
+run_mid_evals
+run_sft_evals
+
+echo "[runmps_evals] Regenerating report with fresh metrics..."
+python -m nanochat.report generate
+
+echo "[runmps_evals] Done."
diff --git a/scripts/base_eval.py b/scripts/base_eval.py
index fc02120..f6f002e 100644
--- a/scripts/base_eval.py
+++ b/scripts/base_eval.py
@@ -19,8 +19,9 @@ from contextlib import nullcontext
 
 import pandas as pd
 import torch
+import wandb
 
-from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, autodetect_device_type
+from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, autodetect_device_type, DummyWandb
 from nanochat.tokenizer import HuggingFaceTokenizer
 from nanochat.checkpoint_manager import load_model
 from nanochat.core_eval import evaluate_task
@@ -123,12 +124,15 @@ def main():
     parser = argparse.ArgumentParser()
     parser.add_argument('--hf-path', type=str, default=None, help='HuggingFace model path to evaluate')
     parser.add_argument('--max-per-task', type=int, default=-1, help='Max examples per task to evaluate (-1 = disable)')
+    parser.add_argument('--model-step', type=int, default=None, help='Checkpoint step to evaluate when using local models')
     args = parser.parse_args()
 
     # distributed / precision setup
     device_type = autodetect_device_type()
     ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
     autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
+    wandb_run = DummyWandb()
+    use_wandb = bool(os.environ.get("WANDB_RUN_ID")) and args.hf_path is None and ddp_rank == 0
 
     # Load model and tokenizer from command line or from file system
     if args.hf_path is not None:
@@ -140,24 +144,32 @@ def main():
         model_slug = hf_path.replace("/", "-") # for the output csv file
     else:
         # load a local model from the file system
-        model, tokenizer, meta = load_model("base", device, phase="eval")
+        model, tokenizer, meta = load_model("base", device, phase="eval", step=args.model_step)
         model_name = f"base_model (step {meta['step']})" # just for logging
         model_slug = f"base_model_{meta['step']:06d}" # for the output csv file
+        if use_wandb:
+            wandb_kwargs = {
+                "project": os.environ.get("WANDB_PROJECT", "nanochat"),
+                "name": os.environ.get("WANDB_EVAL_RUN", model_name),
+                "id": os.environ.get("WANDB_RUN_ID"),
+                "resume": "allow",
+                "reinit": True,
+            }
+            wandb_kwargs = {k: v for k, v in wandb_kwargs.items() if v is not None}
+            wandb_run = wandb.init(**wandb_kwargs)
 
     # Evaluate the model
     with autocast_ctx:
         out = evaluate_model(model, tokenizer, device, max_per_task=args.max_per_task)
 
     # Write out the results to a csv file
-    core_metric = None
-    centered_results = {}
+    results = out["results"]
+    centered_results = out["centered_results"]
+    core_metric = out["core_metric"]
     if ddp_rank == 0:
         base_dir = get_base_dir()
         output_csv_path = os.path.join(base_dir, "base_eval", f"{model_slug}.csv")
         os.makedirs(os.path.dirname(output_csv_path), exist_ok=True)
-        results = out["results"]
-        centered_results = out["centered_results"]
-        core_metric = out["core_metric"]
         with open(output_csv_path, 'w') as f:
             f.write(f"{'Task':<35}, {'Accuracy':<10}, {'Centered':<10}\n")
             for label in results:
@@ -180,6 +192,15 @@ def main():
             centered_results, # the full table
         ])
 
+        if use_wandb:
+            wandb_payload = {"core/metric": core_metric}
+            wandb_payload.update({f"core/{k}": v for k, v in centered_results.items()})
+            wandb_payload.update({f"core_raw/{k}": v for k, v in results.items()})
+            for key, value in wandb_payload.items():
+                wandb_run.summary[key] = value
+            wandb_run.log(wandb_payload, step=meta["step"])
+            wandb_run.finish()
+
     compute_cleanup()
 
 if __name__ == "__main__":
diff --git a/scripts/base_loss.py b/scripts/base_loss.py
index abcde5f..7e85077 100644
--- a/scripts/base_loss.py
+++ b/scripts/base_loss.py
@@ -9,6 +9,9 @@ torchrun --standalone --nproc_per_node=8 -m scripts.base_loss
 import os
 from contextlib import nullcontext
 import torch
+import wandb
+
+from nanochat.common import DummyWandb
 from nanochat.checkpoint_manager import load_model
 from nanochat.common import compute_init, print0, compute_cleanup, autodetect_device_type
 from nanochat.dataloader import tokenizing_distributed_data_loader
@@ -36,6 +39,19 @@ tokens_per_step = device_batch_size * sequence_len * ddp_world_size
 assert split_tokens % tokens_per_step == 0, "split_tokens must be divisible by tokens_per_step"
 steps = split_tokens // tokens_per_step
 token_bytes = get_token_bytes(device=device)
+use_wandb = bool(os.environ.get("WANDB_RUN_ID")) and ddp_rank == 0
+wandb_run = DummyWandb()
+if use_wandb:
+    wandb_kwargs = {
+        "project": os.environ.get("WANDB_PROJECT", "nanochat"),
+        "name": os.environ.get("WANDB_EVAL_RUN", "base-eval"),
+        "id": os.environ.get("WANDB_RUN_ID"),
+        "resume": "allow",
+        "reinit": True,
+    }
+    wandb_kwargs = {k: v for k, v in wandb_kwargs.items() if v is not None}
+    wandb_run = wandb.init(**wandb_kwargs)
+
 bpb_results = {}
 for split_name in ["train", "val"]:
     loader = tokenizing_distributed_data_loader(device_batch_size, sequence_len, split_name, device=device)
@@ -63,7 +79,7 @@ if ddp_rank == 0:
         sample, _ = engine.generate_batch(tokens, num_samples=1, max_tokens=16, temperature=0)
         sample_str = tokenizer.decode(sample[0])
         print0(sample_str)
-        samples.append(sample_str) 
+        samples.append(sample_str)
 
 # Log to report
 from nanochat.report import get_report
@@ -75,5 +91,12 @@ get_report().log(section="Base model loss", data=[
     {f"sample {i}": sample for i, sample in enumerate(samples)},
 ])
 
+if use_wandb:
+    wandb_run.log({
+        "base_loss/train_bpb": bpb_results["train"],
+        "base_loss/val_bpb": bpb_results["val"],
+    }, step=meta.get("step"))
+    wandb_run.finish()
+
 # Cleanup
 compute_cleanup()
diff --git a/scripts/base_train.py b/scripts/base_train.py
index 3725805..ee394ca 100644
--- a/scripts/base_train.py
+++ b/scripts/base_train.py
@@ -60,6 +60,7 @@ core_metric_max_per_task = 500 # examples per task in estimating the core metric
 sample_every = 2000 # every how many steps to sample from the model
 # Output
 model_tag = "" # optionally override the model tag for the output checkpoint directory name
+checkpoint_every_steps = 0 # save intermediate checkpoints every N optimization steps (0 = disable)
 # now allow CLI to override the settings via the configurator lol
 config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
 exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
@@ -76,7 +77,19 @@ get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else l
 
 # wandb logging init
 use_dummy_wandb = run == "dummy" or not master_process
-wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat", name=run, config=user_config)
+if use_dummy_wandb:
+    wandb_run = DummyWandb()
+else:
+    wandb_kwargs = {
+        "project": os.environ.get("WANDB_PROJECT", "nanochat"),
+        "name": run,
+        "config": user_config,
+        "reinit": True,
+    }
+    wandb_id = os.environ.get("WANDB_RUN_ID")
+    if wandb_id:
+        wandb_kwargs.update({"id": wandb_id, "resume": "allow"})
+    wandb_run = wandb.init(**wandb_kwargs)
 
 # Tokenizer will be useful for evaluation, also we need the vocab size
 tokenizer = get_tokenizer()
@@ -138,6 +151,10 @@ print0(f"Total number of training tokens: {total_tokens:,}")
 print0(f"Tokens : Params ratio: {total_batch_size * num_iterations / num_params:.2f}") # Chinchilla is ~20
 print0(f"Total training FLOPs estimate: {num_flops_per_token * total_tokens:e}")
 
+sequences_per_step = max(1, total_batch_size // max_seq_len)
+checkpoint_every_steps = int(checkpoint_every_steps)
+checkpoint_enabled = checkpoint_every_steps > 0
+
 # -----------------------------------------------------------------------------
 # Initialize the Optimizer (Muon for Linear layers, AdamW for embedding and lm_head)
 optimizers = model.setup_optimizers(unembedding_lr=unembedding_lr, embedding_lr=embedding_lr, matrix_lr=matrix_lr, weight_decay=weight_decay)
@@ -150,6 +167,10 @@ train_loader = tokenizing_distributed_data_loader(device_batch_size, max_seq_len
 build_val_loader = lambda: tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="val", device=device)
 x, y = next(train_loader) # kick off load of the very first batch of data
 
+# Checkpoint output location
+checkpoint_dirname = model_tag if model_tag else f"d{depth}"
+checkpoint_dir = os.path.join(base_dir, "base_checkpoints", checkpoint_dirname)
+
 # -----------------------------------------------------------------------------
 # Set up hyperparameter schedulers
 
@@ -177,6 +198,33 @@ min_val_bpb = float("inf")
 smooth_train_loss = 0 # EMA of training loss
 ema_beta = 0.9 # EMA decay factor
 total_training_time = 0 # total wall-clock time of training
+# Keep track of total tokens/sequences processed for logging
+tokens_per_step = total_batch_size
+total_tokens_seen = 0
+total_sequences_seen = 0
+last_val_bpb = None
+
+def save_base_checkpoint(step_idx):
+    output_dirname = model_tag if model_tag else f"d{depth}"
+    checkpoint_dir = os.path.join(base_dir, "base_checkpoints", output_dirname)
+    optimizer_state = [opt.state_dict() for opt in optimizers]
+    meta = {
+        "step": step_idx,
+        "val_bpb": last_val_bpb,
+        "model_config": model_config_kwargs,
+        "user_config": user_config,
+        "device_batch_size": device_batch_size,
+        "max_seq_len": max_seq_len,
+        "total_tokens_seen": total_tokens_seen,
+        "total_sequences_seen": total_sequences_seen,
+    }
+    save_checkpoint(
+        checkpoint_dir,
+        step_idx,
+        orig_model.state_dict(),
+        optimizer_state,
+        meta,
+    )
 
 # note that we run +1 steps only so that we can eval and save at the end
 for step in range(num_iterations + 1):
     last_step = step == num_iterations
@@ -192,11 +240,14 @@ for step in range(num_iterations + 1):
         print0(f"Step {step:05d} | Validation bpb: {val_bpb:.4f}")
         if val_bpb < min_val_bpb:
             min_val_bpb = val_bpb
+        last_val_bpb = val_bpb
         wandb_run.log({
             "step": step,
             "total_training_flops": flops_so_far,
             "total_training_time": total_training_time,
             "val/bpb": val_bpb,
+            "train/total_tokens": total_tokens_seen,
+            "train/total_sequences": total_sequences_seen,
         })
         model.train()
 
@@ -213,12 +264,14 @@ for step in range(num_iterations + 1):
             "total_training_flops": flops_so_far,
             "core_metric": results["core_metric"],
             "centered_results": results["centered_results"],
+            "train/total_tokens": total_tokens_seen,
+            "train/total_sequences": total_sequences_seen,
         })
         model.train()
 
     # once in a while: sample from the model (only on master process)
     # use the original uncompiled model because the inputs keep changing shape
-    if master_process and (last_step or (step > 0 and step % sample_every == 0)):
+    if master_process and sample_every > 0 and (last_step or (step > 0 and step % sample_every == 0)):
         model.eval()
         prompts = [
             "The capital of France is",
@@ -239,22 +292,7 @@ for step in range(num_iterations + 1):
 
     # save checkpoint at the end of the run (only on master process)
     if master_process and last_step:
-        output_dirname = model_tag if model_tag else f"d{depth}" # e.g. d12
-        checkpoint_dir = os.path.join(base_dir, "base_checkpoints", output_dirname)
-        save_checkpoint(
-            checkpoint_dir,
-            step,
-            orig_model.state_dict(),
-            [opt.state_dict() for opt in optimizers], # TODO: make sure saving across ranks is done correctly
-            {
-                "step": step,
-                "val_bpb": val_bpb, # loss at last step
-                "model_config": model_config_kwargs,
-                "user_config": user_config, # inputs to the training script
-                "device_batch_size": device_batch_size,
-                "max_seq_len": max_seq_len,
-            }
-        )
+        save_base_checkpoint(step)
 
     if last_step:
         break
@@ -287,6 +325,11 @@ for step in range(num_iterations + 1):
     model.zero_grad(set_to_none=True)
     synchronize()
     t1 = time.time()
+    total_tokens_seen += tokens_per_step
+    total_sequences_seen += sequences_per_step
+    current_step = step + 1
+    if master_process and checkpoint_enabled and not last_step and checkpoint_every_steps > 0 and current_step % checkpoint_every_steps == 0:
+        save_base_checkpoint(current_step)
     dt = t1 - t0
 
     # -------------------------------------------------------------------------
@@ -311,6 +354,8 @@ for step in range(num_iterations + 1):
         "train/dt": dt,
         "train/tok_per_sec": tok_per_sec,
         "train/mfu": mfu,
+        "train/total_tokens": total_tokens_seen,
+        "train/total_sequences": total_sequences_seen,
     })
 
 # print a few more stats
@@ -335,12 +380,14 @@ get_report().log(section="Base model training", data=[
     }, { # stats about training outcomes
         "Minimum validation bpb": min_val_bpb,
-        "Final validation bpb": val_bpb,
+        "Final validation bpb": last_val_bpb,
         "CORE metric estimate": results.get("core_metric", None),
         "MFU %": f"{mfu:.2f}%",
         "Total training flops": f"{flops_so_far:e}",
         "Total training time": f"{total_training_time/60:.2f}m",
         "Peak memory usage": f"{get_max_memory() / 1024 / 1024:.2f}MiB",
+        "Total tokens processed": total_tokens_seen,
+        "Total sequences processed": total_sequences_seen,
     }
 ])
diff --git a/scripts/chat_eval.py b/scripts/chat_eval.py
index c77a89e..cea1061 100644
--- a/scripts/chat_eval.py
+++ b/scripts/chat_eval.py
@@ -1,6 +1,6 @@
 """
 Evaluate the Chat model.
-All the generic code lives here, and all the evlauation-specific
+All the generic code lives here, and all the evaluation-specific
 code lives in nanochat directory and is imported from here.
Example runs: @@ -9,13 +9,15 @@ torchrun --nproc_per_node=8 -m scripts.chat_eval -- -a ARC-Easy """ import argparse +import os from functools import partial from contextlib import nullcontext import torch import torch.distributed as dist +import wandb -from nanochat.common import compute_init, compute_cleanup, get_dist_info, print0, autodetect_device_type +from nanochat.common import compute_init, compute_cleanup, get_dist_info, print0, autodetect_device_type, DummyWandb from nanochat.checkpoint_manager import load_model from nanochat.engine import Engine @@ -201,9 +203,21 @@ if __name__ == "__main__": ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type) ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16 autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext() + wandb_run = DummyWandb() + use_wandb = bool(os.environ.get("WANDB_RUN_ID")) and ddp_rank == 0 model, tokenizer, meta = load_model(args.source, device, phase="eval", model_tag=args.model_tag, step=args.step) engine = Engine(model, tokenizer) + if use_wandb: + wandb_kwargs = { + "project": os.environ.get("WANDB_PROJECT", "nanochat"), + "name": os.environ.get("WANDB_EVAL_RUN", f"{args.source}-eval"), + "id": os.environ.get("WANDB_RUN_ID"), + "resume": "allow", + "reinit": True, + } + wandb_kwargs = {k: v for k, v in wandb_kwargs.items() if v is not None} + wandb_run = wandb.init(**wandb_kwargs) # Get the tasks to evaluate on all_tasks = ['ARC-Easy', 'ARC-Challenge', 'MMLU', 'GSM8K', 'HumanEval', 'SpellingBee'] @@ -254,4 +268,11 @@ if __name__ == "__main__": chatcore_metric_dict, ]) + if use_wandb: + wandb_payload = {f"chat_eval/{task}": acc for task, acc in results.items()} + if chatcore_metric_dict: + wandb_payload.update({"chat_eval/chatcore": chatcore_metric_dict["ChatCORE metric"]}) + wandb_run.log(wandb_payload, step=meta.get("step")) + wandb_run.finish() + compute_cleanup() diff --git 
a/scripts/chat_sft.py b/scripts/chat_sft.py index e6e4565..25ba0fa 100644 --- a/scripts/chat_sft.py +++ b/scripts/chat_sft.py @@ -55,10 +55,21 @@ eval_every = 100 eval_steps = 100 eval_metrics_every = 200 eval_metrics_max_problems = 1024 +checkpoint_every_steps = 0 # save intermediate checkpoints every N optimization steps (0 = disable) # now allow CLI to override the settings via the configurator lol config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))] exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file user_config = {k: globals()[k] for k in config_keys} # possibly useful for logging + +# Normalize evaluation knobs: <=0 disables inline evaluation +if eval_every <= 0: + eval_every = None +if eval_steps <= 0: + eval_steps = 0 +if eval_metrics_every <= 0: + eval_metrics_every = None +if eval_metrics_max_problems <= 0: + eval_metrics_max_problems = 0 # ----------------------------------------------------------------------------- # Compute init @@ -70,13 +81,27 @@ autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if dev # wandb logging init use_dummy_wandb = run == "dummy" or not master_process -wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-sft", name=run, config=user_config, save_code=True) +if use_dummy_wandb: + wandb_run = DummyWandb() +else: + wandb_kwargs = { + "project": os.environ.get("WANDB_PROJECT", "nanochat-sft"), + "name": run, + "config": user_config, + "save_code": True, + "reinit": True, + } + wandb_id = os.environ.get("WANDB_RUN_ID") + if wandb_id: + wandb_kwargs.update({"id": wandb_id, "resume": "allow"}) + wandb_run = wandb.init(**wandb_kwargs) # Load the model and tokenizer model, tokenizer, meta = load_model(source, device, phase="train", model_tag=model_tag, step=step) orig_model = model # original, uncompiled model # model = torch.compile(model, dynamic=True) # doesn't work 
super well because of variable lengths of inputs engine = Engine(model, tokenizer) # will be used for inline model evaluation only +model_config_kwargs = model.config.__dict__ # slightly naughty, abusing the simplicity of GPTConfig, TODO nicer # ----------------------------------------------------------------------------- # Task data mixture we'll train on @@ -142,6 +167,14 @@ if num_iterations == -1: train_loader = sft_data_generator(train_ds, batch_size=device_batch_size) build_val_loader = lambda: sft_data_generator(val_ds, batch_size=device_batch_size) +sequences_per_step = target_examples_per_step +checkpoint_every_steps = int(checkpoint_every_steps) +checkpoint_enabled = checkpoint_every_steps > 0 + +base_dir = get_base_dir() +checkpoint_dirname = model_tag if model_tag else f"d{model.config.n_layer}" +checkpoint_dir = os.path.join(base_dir, "chatsft_checkpoints", checkpoint_dirname) + # ----------------------------------------------------------------------------- # Initialize the Optimizer @@ -168,11 +201,34 @@ def get_lr_multiplier(it): # Go! 
step = 0 train_iter = iter(train_loader) +total_tokens_seen = 0 +total_sequences_seen = 0 +last_val_loss = None +last_eval_metrics = {} + +def save_sft_checkpoint(step_idx): + meta = { + "step": step_idx, + "val_loss": last_val_loss, + **last_eval_metrics, + "model_config": model_config_kwargs, + "total_tokens_seen": total_tokens_seen, + "total_sequences_seen": total_sequences_seen, + } + save_checkpoint( + checkpoint_dir, + step_idx, + model.state_dict(), + None, + meta, + ) + for step in range(num_iterations): last_step = step == num_iterations - 1 - # evaluate the validation loss - if last_step or step % eval_every == 0: + # evaluate the validation loss (if enabled) + run_val = eval_every is not None and eval_steps > 0 and (last_step or step % eval_every == 0) + if run_val: model.eval() val_iter = iter(build_val_loader()) losses = [] @@ -186,14 +242,20 @@ for step in range(num_iterations): dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) # average over ranks val_loss = val_loss.item() print0(f"Step {step:05d} | Validation loss: {val_loss:.6f}") + last_val_loss = val_loss wandb_run.log({ "step": step, "val_loss": val_loss, + "train/total_tokens": total_tokens_seen, + "train/total_sequences": total_sequences_seen, }) model.train() - # evlauate accuracy of the multiple choice tasks (which are quick to run) - if last_step or (step > 0 and step % eval_metrics_every == 0): + # evaluate accuracy of the multiple choice tasks (if enabled) + run_metrics = eval_metrics_every is not None and eval_metrics_max_problems > 0 and ( + last_step or (step > 0 and step % eval_metrics_every == 0) + ) + if run_metrics: model.eval() metrics = {} with torch.no_grad(), autocast_ctx: @@ -202,9 +264,12 @@ for step in range(num_iterations): metrics["arc_easy_acc"] = run_chat_eval("ARC-Easy", model, tokenizer, engine, batch_size=device_batch_size*2, max_problems=eval_metrics_max_problems) metrics_str = ', '.join(f'{k}: {v:.6f}' for k, v in metrics.items()) print0(f"Step {step:05d} | 
{metrics_str}") + last_eval_metrics = metrics.copy() wandb_run.log({ "step": step, **metrics, + "train/total_tokens": total_tokens_seen, + "train/total_sequences": total_sequences_seen, }) model.train() @@ -238,34 +303,25 @@ for step in range(num_iterations): # logging train_loss_item = train_loss.item() num_tokens_item = num_tokens.item() + total_tokens_seen += num_tokens_item + total_sequences_seen += sequences_per_step + current_step = step + 1 + if master_process and checkpoint_enabled and not last_step and current_step % checkpoint_every_steps == 0: + save_sft_checkpoint(current_step) print0(f"Step {step:05d}/{num_iterations:05d} | Training loss: {train_loss_item:.6f}| lrm: {lrm:.6f}| num_tokens: {num_tokens_item:,}") wandb_run.log({ "step": step, "lrm": lrm, "train_loss": train_loss_item, "num_tokens": num_tokens_item, + "train/total_tokens": total_tokens_seen, + "train/total_sequences": total_sequences_seen, }) step += 1 # Save the model at the end of the run if master_process: - base_dir = get_base_dir() - depth = model.config.n_layer - model_tag = f"d{depth}" # base the model tag on the depth of the base model - checkpoint_dir = os.path.join(base_dir, "chatsft_checkpoints", model_tag) - model_config_kwargs = model.config.__dict__ # slightly naughty, abusing the simplicity of GPTConfig, TODO nicer - save_checkpoint( - checkpoint_dir, - step, - model.state_dict(), - None, # note: we don't bother to save the optimizer state - { - "step": step, - "val_loss": val_loss, - **metrics, - "model_config": model_config_kwargs, - } - ) + save_sft_checkpoint(step) print(f"✅ Saved model checkpoint to {checkpoint_dir}") # Log to report @@ -276,7 +332,9 @@ get_report().log(section="Chat SFT", data=[ "Training rows": len(train_ds), "Number of iterations": num_iterations, "Training loss": train_loss_item, - "Validation loss": val_loss, + "Validation loss": last_val_loss, + "Total tokens processed": total_tokens_seen, + "Total sequences processed": total_sequences_seen, }, ]) 
diff --git a/scripts/mid_train.py b/scripts/mid_train.py index eedb262..858d9c4 100644 --- a/scripts/mid_train.py +++ b/scripts/mid_train.py @@ -48,6 +48,7 @@ eval_every = 150 # -1 = disable eval_tokens = 20*524288 total_batch_size = 524288 dry_run = 0 # dry_run=1 is for experiments: we will log to wandb but we won't write checkpoints or report +checkpoint_every_steps = 0 # save intermediate checkpoints every N optimization steps (0 = disable) config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))] exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file user_config = {k: globals()[k] for k in config_keys} # possibly useful for logging @@ -63,7 +64,19 @@ get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else l # wandb logging init use_dummy_wandb = run == "dummy" or not master_process -wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-mid", name=run, config=user_config) +if use_dummy_wandb: + wandb_run = DummyWandb() +else: + wandb_kwargs = { + "project": os.environ.get("WANDB_PROJECT", "nanochat-mid"), + "name": run, + "config": user_config, + "reinit": True, + } + wandb_id = os.environ.get("WANDB_RUN_ID") + if wandb_id: + wandb_kwargs.update({"id": wandb_id, "resume": "allow"}) + wandb_run = wandb.init(**wandb_kwargs) # Load the model and tokenizer model, tokenizer, meta = load_model("base", device, phase="train", model_tag=model_tag, step=step) @@ -83,6 +96,10 @@ print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}") print0(f"Total batch size {total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}") token_bytes = get_token_bytes(device=device) +sequences_per_step = max(1, total_batch_size // max_seq_len) +checkpoint_every_steps = int(checkpoint_every_steps) +checkpoint_enabled = checkpoint_every_steps > 0 + # Initialize the Optimizer (Muon for Linear layers, AdamW for embedding and 
lm_head) optimizers = model.setup_optimizers(unembedding_lr=unembedding_lr, embedding_lr=embedding_lr, matrix_lr=matrix_lr, weight_decay=weight_decay) adamw_optimizer, muon_optimizer = optimizers @@ -159,6 +176,9 @@ train_loader = mid_data_generator("train") build_val_loader = lambda: mid_data_generator("val") progress = 0 # will go from 0 to 1 over the course of the epoch +checkpoint_dirname = f"d{depth}" +checkpoint_dir = os.path.join(base_dir, "mid_checkpoints", checkpoint_dirname) + # Learning rate scheduler def get_lr_multiplier(progress): # first 80% of training: no decay, then linearly ramp down to 0. @@ -177,6 +197,37 @@ min_val_bpb = float("inf") smooth_train_loss = 0 # EMA of training loss ema_beta = 0.9 # EMA decay factor total_training_time = 0 # total wall-clock time of training +tokens_per_step = total_batch_size +total_tokens_seen = 0 +total_sequences_seen = 0 +last_val_bpb = None + +def save_mid_checkpoint(step_idx): + if dry_run: + return + meta = { + "step": step_idx, + "val_bpb": last_val_bpb, + "model_config": { + "sequence_len": max_seq_len, + "vocab_size": tokenizer.get_vocab_size(), + "n_layer": depth, + "n_head": model.config.n_head, + "n_kv_head": model.config.n_kv_head, + "n_embd": model.config.n_embd, + }, + "user_config": user_config, + "total_tokens_seen": total_tokens_seen, + "total_sequences_seen": total_sequences_seen, + } + optimizer_state = [opt.state_dict() for opt in optimizers] + save_checkpoint( + checkpoint_dir, + step_idx, + orig_model.state_dict(), + optimizer_state, + meta, + ) step = 0 while True: flops_so_far = num_flops_per_token * total_batch_size * step @@ -197,37 +248,20 @@ while True: print0(f"Step {step:05d} | Validation bpb: {val_bpb:.4f}") if val_bpb < min_val_bpb: min_val_bpb = val_bpb + last_val_bpb = val_bpb wandb_run.log({ "step": step, "total_training_flops": flops_so_far, "total_training_time": total_training_time, "val/bpb": val_bpb, + "train/total_tokens": total_tokens_seen, + "train/total_sequences": 
total_sequences_seen, }) model.train() # save checkpoint at the end of the run (only on master process) if master_process and last_step and not dry_run: - output_dirname = f"d{depth}" # e.g. d12 - checkpoint_dir = os.path.join(base_dir, "mid_checkpoints", output_dirname) - save_checkpoint( - checkpoint_dir, - step, - orig_model.state_dict(), - [opt.state_dict() for opt in optimizers], # TODO: make sure saving across ranks is done correctly - { - "step": step, - "val_bpb": val_bpb, # loss at last step - "model_config": { - "sequence_len": max_seq_len, - "vocab_size": tokenizer.get_vocab_size(), - "n_layer": depth, - "n_head": model.config.n_head, - "n_kv_head": model.config.n_kv_head, - "n_embd": model.config.n_embd, - }, - "user_config": user_config, # inputs to the training script - } - ) + save_mid_checkpoint(step) if last_step: break @@ -258,12 +292,17 @@ while True: model.zero_grad(set_to_none=True) synchronize() t1 = time.time() + total_tokens_seen += tokens_per_step + total_sequences_seen += sequences_per_step dt = t1 - t0 # ------------------------------------------------------------------------- # State step += 1 + if master_process and checkpoint_enabled and not last_step and checkpoint_every_steps > 0 and step % checkpoint_every_steps == 0: + save_mid_checkpoint(step) + # logging smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss.item() # EMA the training loss debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**(step + 1)) # debias the EMA @@ -285,6 +324,8 @@ while True: "train/dt": dt, "train/tok_per_sec": tok_per_sec, "train/mfu": mfu, + "train/total_tokens": total_tokens_seen, + "train/total_sequences": total_sequences_seen, }) # print a few more stats @@ -303,6 +344,9 @@ if not dry_run: }, { # stats about training outcomes "Minimum validation bpb": min_val_bpb, + "Final validation bpb": last_val_bpb, + "Total tokens processed": total_tokens_seen, + "Total sequences processed": total_sequences_seen, } ])