diff --git a/README.md b/README.md
index 89d2ce2..800c5d9 100644
--- a/README.md
+++ b/README.md
@@ -1,35 +1,62 @@
# nanochat

![nanochat logo](dev/nanochat.png)
+![scaling laws](dev/scaling_laws_jan26.png)

-> The best ChatGPT that $100 can buy.
+nanochat is the simplest experimental harness for training LLMs. It is designed to run on a single GPU node, the code is minimal/hackable, and it covers all major LLM stages including tokenization, pretraining, finetuning, evaluation, inference, and a chat UI. For example, you can train your own GPT-2 grade LLM (which cost ~$50,000 to train in 2019) for only ~$73 (3 hours on an 8XH100 GPU node) and then talk to it in a familiar ChatGPT-like web UI.

-This repo is a full-stack implementation of an LLM like ChatGPT in a single, clean, minimal, hackable, dependency-lite codebase. nanochat is designed to run on a single 8XH100 node via scripts like [speedrun.sh](runs/speedrun.sh), that run the entire pipeline start to end. This includes tokenization, pretraining, finetuning, evaluation, inference, and web serving over a simple UI so that you can talk to your own LLM just like ChatGPT. nanochat will become the capstone project of the course LLM101n being developed by Eureka Labs.
+For questions about the repo, I recommend asking [DeepWiki](https://deepwiki.com/karpathy/nanochat) from Devin/Cognition, using the [Discussions tab](https://github.com/karpathy/nanochat/discussions), or coming by the [#nanochat](https://discord.com/channels/1020383067459821711/1427295580895314031) channel on Discord.

## Updates

-- (Jan 16 2026) The repo is in active development, I am currently fleshing out the pretraining stage.
-- (Jan 7 2026) See new post: [nanochat Miniseries v1](https://github.com/karpathy/nanochat/discussions/420) and the associated script [miniseries.sh](runs/miniseries.sh).
+- (Jan 31 2026) Major revamp of all scripts/README is ongoing and the midtraining stage is being deleted, so things might be briefly messy...
+- (Jan 30 2026) With all the latest improvements we're able to train a GPT-2 grade LLM for about $73. The [runs/speedrun.sh](runs/speedrun.sh) script will become the reference way to train a GPT-2 grade model and talk to it.

-## Talk to it
+## Leaderboard

-To get a sense of the endpoint of this repo, you can currently find [nanochat d34](https://github.com/karpathy/nanochat/discussions/314) hosted on [nanochat.karpathy.ai](https://nanochat.karpathy.ai/). This model is now a few months old but it still gives a rough idea of the intelligence you can achieve for approximately $1000. While this model easily outperforms GPT-2 of 2019, it falls dramatically short of modern Large Language Models like GPT-5. When talking to these micro models, you'll see that they make a lot of mistakes, they are a little bit naive and silly and they hallucinate a ton, a bit like children. But what makes nanochat unique is that it is fully yours - fully configurable, tweakable, hackable, and trained by you from start to end. To train and talk to your own, we turn to...
+| # | Record time | Description | Date | Commit | Contributors |
+|---|-------------|-------------|------|--------|--------------|
+| 1 | 3.04 hours | d24 baseline, slightly overtrained | Jan 29 2026 | 348fbb3 | @karpathy |

-## Quick start
+The primary metric we care about is "time to GPT-2" - the wall clock time needed to outperform the GPT-2 (1.6B) CORE metric on an 8XH100 GPU node. In 2019, training GPT-2 cost approximately $50,000, so it is incredible that thanks to many advances across the stack over the past seven years, we can now do the same in about 3 hours, for roughly $73.
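+
+To make the bookkeeping concrete, here is a small illustrative Python sketch (not repo code; the $24/hour node rate is an assumption taken from the pricing mentioned below) of how a leaderboard entry follows from the wandb summary:
+
+```python
+GPT2_CORE = 0.256525      # GPT-2 (1.6B) CORE score, the target to beat
+USD_PER_HOUR = 24.0       # assumed 8XH100 node rate
+
+def leaderboard_entry(core_metric, total_training_time_s):
+    hours = total_training_time_s / 3600  # wandb reports seconds
+    verdict = "beats" if core_metric > GPT2_CORE else "misses"
+    print(f"CORE {core_metric:.6f} ({verdict} GPT-2), "
+          f"{hours:.2f} hours, ~${hours * USD_PER_HOUR:.0f}")
+
+leaderboard_entry(0.25851, 10949.46713)  # the Jan 29 d24 run: ~3.04 hours, ~$73
+```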
+Once your repo is set up (see the [runs/speedrun.sh](runs/speedrun.sh) script for reference), you can kick off a record attempt. For example, here is how I kicked off the Jan 29 run:
+
+```
+OMP_NUM_THREADS=1 torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- \
+    --depth=24 \
+    --run=d24-jan29 \
+    --model-tag=d24_jan29 \
+    --device-batch-size=16 \
+    --sample-every=-1 \
+    --save-every=-1 \
+    --core-metric-max-per-task=-1 \
+    --core-metric-every=3000 \
+    --target-param-data-ratio=12
+```
+
+After 3 hours we get output like this:
+
+```
+...
+wandb: Run summary:
+wandb: core_metric 0.25851
+wandb: step 16704
+wandb: total_training_flops 4.330784131228946e+19
+wandb: total_training_time 10949.46713
+```
+
+The GPT-2 CORE score (i.e. the target to beat) is 0.256525, and this d24 run scores higher (0.25851). Then we look at `total_training_time`, which counts the training iterations alone, excluding all the evaluations and logging, in seconds. We get `10949/60/60 ~= 3.04` hours, the current record.
+
+## Getting started
+
+### Reproduce and talk to GPT-2
+
+The most fun you can have is to train your own GPT-2 and talk to it. The entire pipeline to do so is contained in the single file [runs/speedrun.sh](runs/speedrun.sh), which is designed to be run on an 8XH100 GPU node. Currently, at ~$24/hour for these nodes, pretraining a GPT-2 grade model takes approximately 3 hours and will set you back about $73. Boot up a new 8XH100 GPU box from your favorite provider (e.g. I use and like [Lambda](https://lambda.ai/service/gpu-cloud)), and kick off the training script:

```bash
bash runs/speedrun.sh
```

-Alternatively, since the script runs for 4 hours, I like to launch it like this inside a new screen session `speedrun` (and also log output to `speedrun.log`):
-
-```bash
-screen -L -Logfile speedrun.log -S speedrun bash runs/speedrun.sh
-```
-
-See the [screen cheatsheet](https://gist.github.com/jctosta/af918e1618682638aa82) if you are less familiar. You can watch it go inside the screen session, or detach with `Ctrl-a d` and `tail speedrun.log` to view progress. Now wait 4 hours. Once it's done, you can talk to your LLM via the ChatGPT-like web UI. Make sure again that your local uv virtual environment is active (run `source .venv/bin/activate`), and serve it:
+You might want to do so in a screen session (e.g. `screen -L -Logfile speedrun.log -S speedrun bash runs/speedrun.sh`) as this will take ~3 hours to run. Once it's done, you can talk to it via the ChatGPT-like web UI. Make sure again that your local uv virtual environment is active (run `source .venv/bin/activate`), and serve it:

```bash
python -m scripts.chat_web
```

@@ -43,84 +70,43 @@ And then visit the URL shown. Make sure to access it correctly, e.g. on Lambda u

---

-You can also `cat report.md` file which appeared in the project directory and contains the "report card" of the run, i.e. a bunch of evaluations and metrics.
At the very end, you'll see a summary table, for example: - ---- - -- Characters: 333,989 -- Lines: 8,304 -- Files: 44 -- Tokens (approx): 83,497 -- Dependencies (uv.lock lines): 2,004 - -| Metric | BASE | MID | SFT | RL | -|-----------------|----------|----------|----------|----------| -| CORE | 0.2219 | - | - | - | -| ARC-Challenge | - | 0.2875 | 0.2807 | - | -| ARC-Easy | - | 0.3561 | 0.3876 | - | -| GSM8K | - | 0.0250 | 0.0455 | 0.0758 | -| HumanEval | - | 0.0671 | 0.0854 | - | -| MMLU | - | 0.3111 | 0.3151 | - | -| ChatCORE | - | 0.0730 | 0.0884 | - | - -Total wall clock time: 3h51m - ---- - -(Your table might be missing the RL number by default). For a lot more information around the speedrun script and what to look for and expect, please refer to the walkthrough that I posted in Discussions of the repo: ["Introducing nanochat: The best ChatGPT that $100 can buy"](https://github.com/karpathy/nanochat/discussions/1). - -## Bigger models - -Unsurprisingly, $100 is not enough to train a highly performant ChatGPT clone. In fact, LLMs are famous for their multi-million dollar capex. For our purposes, I think there are two more scales of interest. First is the ~$300 tier d26 model (i.e. depth=26) that trains in ~12 hours, which slightly outperforms GPT-2 CORE score. Second is the $1000 tier (~41.6 hours), just because it's a nice round number. But both of these are not yet fully supported and therefore not attached here in the master branch yet. - -That said, to give a sense, the example changes needed for the [speedrun.sh](runs/speedrun.sh) file to train a GPT-2 grade model d26 only involve three changes: - -```bash -... -# you'll need to download more data shards for pretraining -# get the number of parameters, multiply 20 to get tokens, multiply by 4.8 to get chars, -# divide by 250 million to get number of shards. todo need to improve this... -python -m nanochat.dataset -n 450 & -... -# use --depth to increase model size. to not oom, halve device batch size 32 -> 16: -torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- --depth=26 --device-batch-size=16 -... -# make sure to use the same later during midtraining: -torchrun --standalone --nproc_per_node=8 -m scripts.mid_train -- --device-batch-size=16 -``` - -That's it! The biggest thing to pay attention to is making sure you have enough data shards to train on (the code will loop and do more epochs over the same training set otherwise, decreasing learning speed a bit), and managing your memory/VRAM, primarily by decreasing the `device_batch_size` until things fit (the scripts automatically compensate by increasing the number of gradient accumulation loops, simply turning parallel compute to sequential compute). - -And a bit more about computing environments that will run nanochat: +A few more notes: - The code will run just fine on the Ampere 8XA100 GPU node as well, but a bit slower. - All code will run just fine on even a single GPU by omitting `torchrun`, and will produce ~identical results (code will automatically switch to gradient accumulation), but you'll have to wait 8 times longer. - If your GPU(s) have less than 80GB, you'll have to tune some of the hyperparameters or you will OOM / run out of VRAM. Look for `--device_batch_size` in the scripts and reduce it until things fit. E.g. from 32 (default) to 16, 8, 4, 2, or even 1. Less than that you'll have to know a bit more what you're doing and get more creative. 
-- Most of the code is fairly vanilla PyTorch so it should run on anything that supports that - xpu, mps, or etc, but I haven't implemented this out of the box so it might take a bit of tinkering.
+- Most of the code is fairly vanilla PyTorch so it should run on anything that supports it - xpu, mps, etc. - but I haven't personally exercised all of these code paths so there might be sharp edges.
+
+## Research
+
+If you are a researcher and wish to help improve nanochat, two scripts of interest are [runs/scaling_laws.sh](runs/scaling_laws.sh) and [runs/miniseries.sh](runs/miniseries.sh). See [Jan 7 miniseries v1](https://github.com/karpathy/nanochat/discussions/420) for related documentation. For quick experimentation (~5 min pretraining runs) my favorite scale is to train a 12-layer model (GPT-1 sized), e.g. like this:
+
+```
+OMP_NUM_THREADS=1 torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- \
+    --depth=12 \
+    --run="d12" \
+    --model-tag="d12" \
+    --core-metric-every=999999 \
+    --sample-every=-1 \
+    --save-every=-1
+```
+
+This uses wandb (run name "d12"), only runs the CORE metric on the last step, and doesn't sample or save intermediate checkpoints. I like to change something in the code, re-run a d12 (or a d16 etc.) and see if it helped, in an iteration loop.
+
+The overall approach is to treat the depth of the model as the single dial of complexity. By sweeping the depth, we get increasingly more powerful models. We determine the scaling laws, set the data budget to a compute optimal setting, train a whole miniseries of models of increasing sizes, and compare them to the GPT-2 and GPT-3 miniseries. Right now, beating GPT-2 specifically faster and faster is the most interesting target.
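+
+To make the "depth as the single dial" idea concrete, here is an illustrative Python sketch of such a sweep (not repo code; the width rule, parameter count, and FLOPs formulas are rough assumptions for illustration only):
+
+```python
+def approx_params(depth, aspect=64):
+    width = depth * aspect           # assumed width rule: the model gets wider as it gets deeper
+    return 12 * depth * width ** 2   # rough transformer parameter estimate (ignores embeddings)
+
+for depth in (12, 16, 20, 24):
+    params = approx_params(depth)
+    tokens = 10.5 * params           # compute-optimal data budget at the default ratio of 10.5
+    flops = 6 * params * tokens      # classic ~6 * params * tokens training FLOPs estimate
+    print(f"d{depth}: ~{params / 1e6:.0f}M params, ~{tokens / 1e9:.1f}B tokens, ~{flops:.2e} FLOPs")
+```
+
+Each depth then maps to one point on the scaling-law plot, and the miniseries is just this sweep trained for real.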
## Running on CPU / MPS

-nanochat can be run on CPU or on MPS (if you're on Macbook) in principle, and will automatically try to detect what device is best to run on. The script [runcpu.sh](runs/runcpu.sh) shows a very simple example that will exercise the code paths but basically produce garbage results. Unless you know what you're doing, I basically don't recommend using this script right now and hope to tune it a bit more in the future.
+The script [runs/runcpu.sh](runs/runcpu.sh) shows a very simple example of running on CPU or Apple Silicon. It dramatically shrinks the LLM that is being trained so that training fits into a reasonable time interval of a few tens of minutes. You will not get strong results in this way.

-## Customization
+## Guides

-To customize your nanochat, see [Guide: infusing identity to your nanochat](https://github.com/karpathy/nanochat/discussions/139) in Discussions, which describes how you can tune your nanochat's personality through synthetic data generation and mixing that data into midtraining and SFT stages.
+I've published a number of guides that might contain helpful information:

-Additionally, to add new abilities to nanochat, see [Guide: counting r in strawberry (and how to add abilities generally)](https://github.com/karpathy/nanochat/discussions/164).
-
-## Questions
-
-I recommend using [DeepWiki](https://deepwiki.com/karpathy/nanochat) from Devin/Cognition to ask questions of this repo. In the URL of this repo, simply change github.com to deepwiki.com, and you're off.
-
-You can also come to the [#nanochat Discord channel](https://discord.com/channels/1020383067459821711/1427295580895314031) to ask questions, or use the Discussions.
-
-## Tests
-
-I haven't invested too much here but some tests exist, especially for the tokenizer. Run e.g. as:
-
-```bash
-python -m pytest tests/test_engine.py -v -s
-```

+- [Oct 13 2025 original nanochat post](https://github.com/karpathy/nanochat/discussions/1) introducing nanochat, though it now contains some deprecated information and describes a model that is a lot older (with worse results) than current master.
+- [Jan 7 miniseries v1](https://github.com/karpathy/nanochat/discussions/420) documents the first nanochat miniseries of models.
+- To customize your nanochat, see [Guide: infusing identity to your nanochat](https://github.com/karpathy/nanochat/discussions/139) in Discussions, which describes how you can tune your nanochat's personality through synthetic data generation and mixing that data into the SFT stage.
+- To add new abilities to nanochat, see [Guide: counting r in strawberry (and how to add abilities generally)](https://github.com/karpathy/nanochat/discussions/164).

## File structure

@@ -159,12 +145,11 @@ python -m pytest tests/test_engine.py -v -s
│   ├── base_eval.py # Base model: calculate CORE score
│   ├── base_loss.py # Base model: calculate bits per byte, sample
│   ├── base_train.py # Base model: train
-│   ├── chat_cli.py # Chat model (SFT/Mid): talk to over CLI
-│   ├── chat_eval.py # Chat model (SFT/Mid): eval tasks
-│   ├── chat_rl.py # Chat model (SFT/Mid): reinforcement learning
+│   ├── chat_cli.py # Chat model: talk to over CLI
+│   ├── chat_eval.py # Chat model: eval tasks
+│   ├── chat_rl.py # Chat model: reinforcement learning
│   ├── chat_sft.py # Chat model: train SFT
-│   ├── chat_web.py # Chat model (SFT/Mid): talk to over WebUI
-│   ├── mid_train.py # Chat model: midtraining
+│   ├── chat_web.py # Chat model: talk to over WebUI
│   ├── tok_eval.py # Tokenizer: evaluate compression rate
│   └── tok_train.py # Tokenizer: train it
├── tasks
@@ -183,9 +168,9 @@ python -m pytest tests/test_engine.py -v -s

## Contributing

-nanochat is nowhere near finished. The goal is to improve the state of the art in micro models that are accessible to work with end to end on budgets of < $1000 dollars. Accessibility is about overall cost but also about cognitive complexity - nanochat is not an exhaustively configurable LLM "framework"; there will be no giant configuration objects, model factories, or if-then-else monsters in the code base. It is a single, cohesive, minimal, readable, hackable, maximally-forkable "strong baseline" codebase designed to run start to end and produce a concrete ChatGPT clone and its report card.
+The goal of nanochat is to improve the state of the art in micro models that are accessible to work with end to end on budgets of < $1000. Accessibility is about overall cost but also about cognitive complexity - nanochat is not an exhaustively configurable LLM "framework"; there are no giant configuration objects, model factories, or if-then-else monsters in the code base. It is a single, cohesive, minimal, readable, hackable, maximally-forkable "strong baseline" codebase designed to run start to end and produce a ChatGPT model you can talk to. Personally, the most interesting target right now is reducing the time to GPT-2 (i.e. getting a CORE score above 0.256525). This currently takes ~3 hours, but by improving the pretraining stage we can push it down further.

-Current LLM policy: disclosure. When submitting a PR, please declare any parts that had substantial LLM contribution and that you have not written or that you do not fully understand.
+Current AI policy: disclosure.
When submitting a PR, please declare any parts that had substantial LLM contribution and that you have not written or that you do not fully understand. ## Acknowledgements diff --git a/dev/scaling_laws_jan26.png b/dev/scaling_laws_jan26.png new file mode 100644 index 0000000..e8d1f72 Binary files /dev/null and b/dev/scaling_laws_jan26.png differ diff --git a/runs/runcpu.sh b/runs/runcpu.sh index a35c336..f383726 100755 --- a/runs/runcpu.sh +++ b/runs/runcpu.sh @@ -45,9 +45,9 @@ python -m scripts.base_train \ python -m scripts.base_loss --device-batch-size=1 --split-tokens=16384 python -m scripts.base_eval --max-per-task=16 -# midtraining (~10 minutes on my MacBook Pro M3 Max) +# SFT (~10 minutes on my MacBook Pro M3 Max) curl -L -o $NANOCHAT_BASE_DIR/identity_conversations.jsonl https://karpathy-public.s3.us-west-2.amazonaws.com/identity_conversations.jsonl -python -m scripts.mid_train \ +python -m scripts.chat_sft \ --max-seq-len=512 \ --device-batch-size=32 \ --total-batch-size=16384 \ @@ -56,13 +56,11 @@ python -m scripts.mid_train \ --num-iterations=1500 \ --run=$WANDB_RUN -# (it's ~ok to skip SFT) - # Chat with the model over CLI # The model should be able to say that it is Paris. # It might even know that the color of the sky is blue. # Sometimes the model likes it if you first say Hi before you ask it questions. -# python -m scripts.chat_cli -i mid -p "What is the capital of France?" +# python -m scripts.chat_cli -p "What is the capital of France?" # Chat with the model over a pretty WebUI ChatGPT style -# python -m scripts.chat_web -i mid +# python -m scripts.chat_web diff --git a/runs/speedrun.sh b/runs/speedrun.sh index ef4fa00..a9612c0 100644 --- a/runs/speedrun.sh +++ b/runs/speedrun.sh @@ -1,14 +1,14 @@ #!/bin/bash -# This script is the "Best ChatGPT clone that $100 can buy", -# It is designed to run in ~4 hours on 8XH100 node at $3/GPU/hour. +# This script is configured to train your own GPT-2 grade LLM (pretraining + finetuning) +# It is designed to run on a blank 8XH100 GPU node and takes approximately 3 hours to complete. # 1) Example launch (simplest): -# bash speedrun.sh -# 2) Example launch in a screen session (because the run takes ~4 hours): -# screen -L -Logfile speedrun.log -S speedrun bash speedrun.sh +# bash runs/speedrun.sh +# 2) Example launch in a screen session (because the run takes ~3 hours): +# screen -L -Logfile runs/speedrun.log -S speedrun bash runs/speedrun.sh # 3) Example launch with wandb logging, but see below for setting up wandb first: -# WANDB_RUN=speedrun screen -L -Logfile speedrun.log -S speedrun bash speedrun.sh +# WANDB_RUN=speedrun screen -L -Logfile runs/speedrun.log -S speedrun bash runs/speedrun.sh # Default intermediate artifacts directory is in ~/.cache/nanochat export OMP_NUM_THREADS=1 @@ -49,13 +49,14 @@ python -m nanochat.report reset # Tokenizer # Download the first ~2B characters of pretraining dataset -# look at dev/repackage_data_reference.py for details on how this data was prepared # each data shard is ~250M chars # so we download 2e9 / 250e6 = 8 data shards at this point # each shard is ~100MB of text (compressed), so this is about ~800MB of data on disk +# look at dev/repackage_data_reference.py for details on how this data was prepared python -m nanochat.dataset -n 8 # Immediately also kick off downloading more shards in the background while tokenizer trains -# See comment below for why 370 is the right number here +# Approximately 350 shards are needed for 10B tokens of data for pretraining. 
+# The maximum total number of shards available in the entire dataset is 1822.
python -m nanochat.dataset -n 370 &
DATASET_DOWNLOAD_PID=$!
# train the tokenizer with vocab size 2**15 = 32768 on ~2B characters of data
python -m scripts.tok_eval

# -----------------------------------------------------------------------------
# Base model (pretraining)
-
-# The d20 model is 561M parameters.
-# Chinchilla says #tokens = 20X #params, so we need 561e6 * 20 = 11.2B tokens.
-# Assume our tokenizer is 4.8 chars/token, this is 11.2B * 4.8 ~= 54B chars.
-# At 250M chars/shard, this is 54B / 250M ~= 216 shards needed for pretraining.
-# Round up to 240 for safety. Also, the new DataLoader wastes about 35% of tokens to cropping
-# so 240 / (1 - 0.35) = 370 shards are needed.
-# At ~100MB/shard, this downloads ~37GB of data to disk.
-# (The total number of shards available in the entire dataset is 1822.)
echo "Waiting for dataset download to complete..."
wait $DATASET_DOWNLOAD_PID

# Number of processes/GPUs to use
NPROC_PER_NODE=8

-# pretrain the d20 model
-torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train -- --depth=20 --target-param-data-ratio=20 --run=$WANDB_RUN
+# d24 model (slightly overtraining is enough to beat GPT-2 => increase data:params ratio from compute optimal 10.5 (default) to 12)
+torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train -- --depth=24 --target-param-data-ratio=12 --run=$WANDB_RUN
# evaluate the model on a larger chunk of train/val data and draw some samples
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_loss
# evaluate the model on CORE tasks
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_eval

# -----------------------------------------------------------------------------
-# Midtraining (teach the model conversation special tokens, tool use, multiple choice)
+# SFT (teach the model conversation special tokens, tool use, multiple choice)

# download 2.3MB of synthetic identity conversations to impart a personality to nanochat
# see dev/gen_synthetic_data.py for details on how this data was prepared and to get a sense of how you can easily tune it
curl -L -o $NANOCHAT_BASE_DIR/identity_conversations.jsonl https://karpathy-public.s3.us-west-2.amazonaws.com/identity_conversations.jsonl

-# run midtraining and eval the model
-torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.mid_train -- --run=$WANDB_RUN
-torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_eval -- -i mid
-
-# -----------------------------------------------------------------------------
-# Supervised Finetuning (domain adaptation to each sequence all by itself per row)
-
-# train sft and re-eval right away (should see a small bump)
+# run SFT and eval the model
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_sft -- --run=$WANDB_RUN
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_eval -- -i sft

@@ -111,15 +96,6 @@ torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_eval -- -
# even better, chat with your model over a pretty WebUI ChatGPT style
# python -m scripts.chat_web

-# -----------------------------------------------------------------------------
-# Reinforcement Learning.
Optional, and currently only on GSM8K
-# (optional)
-
-# run reinforcement learning
-# torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_rl -- --run=$WANDB_RUN
-# eval the RL model only on GSM8K
-# torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_eval -- -i rl -a GSM8K
-
# -----------------------------------------------------------------------------
# Generate the full report by putting together all the sections
# report.md is the output and will be copied to current directory for convenience
diff --git a/scripts/chat_cli.py b/scripts/chat_cli.py
index b14843a..d35c435 100644
--- a/scripts/chat_cli.py
+++ b/scripts/chat_cli.py
@@ -2,7 +2,7 @@
New and upgraded chat mode because a lot of the code has changed since the last one.
Intended to be run single GPU only atm:

-python -m scripts.chat_cli -i mid
+python -m scripts.chat_cli
"""
import argparse
import torch
diff --git a/scripts/chat_eval.py b/scripts/chat_eval.py
index a558303..cae2f0f 100644
--- a/scripts/chat_eval.py
+++ b/scripts/chat_eval.py
@@ -4,8 +4,8 @@ All the generic code lives here, and all the evaluation-specific code lives in
nanochat directory and is imported from here. Example runs:

-python -m scripts.chat_eval -i mid -a ARC-Easy
-torchrun --nproc_per_node=8 -m scripts.chat_eval -- -i mid -a ARC-Easy
+python -m scripts.chat_eval -a ARC-Easy
+torchrun --nproc_per_node=8 -m scripts.chat_eval -- -a ARC-Easy
"""

import argparse
diff --git a/scripts/chat_sft.py b/scripts/chat_sft.py
index c0471c4..91300b6 100644
--- a/scripts/chat_sft.py
+++ b/scripts/chat_sft.py
@@ -1,65 +1,63 @@
"""
-Finetune a base model to be a chat model.
-Run on one GPU e.g. for debugging:
+Supervised fine-tuning (SFT) of the model.
+Run as:

python -m scripts.chat_sft

Or torchrun for training:

-torchrun --standalone --nproc_per_node=8 -m scripts.chat_sft
+torchrun --standalone --nproc_per_node=8 -m scripts.chat_sft -- --device-batch-size=16
"""
import argparse
import os
os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"
-
+import time
import wandb
import torch
-import torch.distributed as dist
from contextlib import nullcontext
-
-from nanochat.common import compute_init, compute_cleanup, get_base_dir, print0, DummyWandb, autodetect_device_type
-from nanochat.checkpoint_manager import load_model
+from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, get_base_dir, autodetect_device_type
+from nanochat.tokenizer import get_token_bytes
from nanochat.checkpoint_manager import save_checkpoint
-from nanochat.engine import Engine
-from scripts.chat_eval import run_chat_eval
+from nanochat.loss_eval import evaluate_bpb
+from nanochat.checkpoint_manager import load_model
+import torch.distributed as dist

from tasks.common import TaskMixture
-from tasks.arc import ARC
from tasks.gsm8k import GSM8K
+from tasks.mmlu import MMLU
from tasks.smoltalk import SmolTalk
from tasks.customjson import CustomJSON
from tasks.spellingbee import SimpleSpelling, SpellingBee

# -----------------------------------------------------------------------------
# CLI arguments
-parser = argparse.ArgumentParser(description="Supervised finetuning for chat")
+parser = argparse.ArgumentParser(description="Supervised fine-tuning (SFT) of the model")
# Logging
parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)")
# Runtime
parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
parser.add_argument("--dtype", type=str,
default="bfloat16", help="float32|bfloat16") # Model loading -parser.add_argument("--source", type=str, default="mid", help="base|mid - which checkpoint to load from") parser.add_argument("--model-tag", type=str, default=None, help="model tag to load from") parser.add_argument("--model-step", type=int, default=None, help="model step to load from") # Training horizon -parser.add_argument("--num-epochs", type=int, default=1, help="number of epochs") -parser.add_argument("--num-iterations", type=int, default=-1, help="override number of iterations (-1 = use num_epochs)") +parser.add_argument("--num-iterations", type=int, default=-1, help="number of optimization steps (-1 = full epoch)") # Batch sizes -parser.add_argument("--device-batch-size", type=int, default=4, help="per-device batch size") -parser.add_argument("--target-examples-per-step", type=int, default=32, help="target examples per optimization step") +parser.add_argument("--max-seq-len", type=int, default=2048, help="max context length") +parser.add_argument("--device-batch-size", type=int, default=32, help="per-device batch size") +parser.add_argument("--total-batch-size", type=int, default=524288, help="total batch size in tokens") # Optimization parser.add_argument("--embedding-lr", type=float, default=0.2, help="learning rate for embedding parameters (Adam)") parser.add_argument("--unembedding-lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)") parser.add_argument("--matrix-lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)") parser.add_argument("--weight-decay", type=float, default=0.0, help="weight decay for embedding/unembedding parameters (Adam)") -parser.add_argument("--init-lr-frac", type=float, default=0.02, help="initial LR as fraction of base LR") +parser.add_argument("--init-lr-frac", type=float, default=1.0, help="initial LR as fraction of base LR") # Evaluation -parser.add_argument("--eval-every", type=int, default=100, help="evaluate val loss every N steps") -parser.add_argument("--eval-steps", type=int, default=100, help="number of batches for val loss evaluation") -parser.add_argument("--eval-metrics-every", type=int, default=200, help="evaluate accuracy metrics every N steps") -parser.add_argument("--eval-metrics-max-problems", type=int, default=1024, help="max problems per metric evaluation") +parser.add_argument("--eval-every", type=int, default=150, help="evaluate val bpb every N steps (-1 = disable)") +parser.add_argument("--eval-tokens", type=int, default=20*524288, help="number of tokens to evaluate val loss on") +# Output +parser.add_argument("--dry-run", action="store_true", help="log to wandb but skip checkpoints/report") args = parser.parse_args() user_config = vars(args).copy() # ----------------------------------------------------------------------------- @@ -70,217 +68,320 @@ ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type master_process = ddp_rank == 0 ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16 autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext() +synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None +get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0 # wandb logging init use_dummy_wandb = args.run == "dummy" or not master_process -wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-sft", name=args.run, config=user_config, 
save_code=True) +wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-sft", name=args.run, config=user_config) # Load the model and tokenizer -model, tokenizer, meta = load_model(args.source, device, phase="train", model_tag=args.model_tag, step=args.model_step) -orig_model = model # original, uncompiled model -# model = torch.compile(model, dynamic=True) # doesn't work super well because of variable lengths of inputs -engine = Engine(model, tokenizer) # will be used for inline model evaluation only +model, tokenizer, meta = load_model("base", device, phase="train", model_tag=args.model_tag, step=args.model_step) +pretrain_batch_size = meta.get("device_batch_size", None) +if pretrain_batch_size is not None and args.device_batch_size > pretrain_batch_size: + print0(f"FOOTGUN WARNING: base model training used device_batch_size {pretrain_batch_size}, did you pass in a good --device-batch-size to this script?") +orig_model = model +model = torch.compile(model, dynamic=False) +depth = model.config.n_layer +num_flops_per_token = model.estimate_flops() +tokens_per_fwdbwd = args.device_batch_size * args.max_seq_len # tokens per iteration for a single rank +world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size # total tokens per iteration for all ranks +assert args.total_batch_size % world_tokens_per_fwdbwd == 0 +grad_accum_steps = args.total_batch_size // world_tokens_per_fwdbwd +print0(f"Tokens / micro-batch / rank: {args.device_batch_size} x {args.max_seq_len} = {tokens_per_fwdbwd:,}") +print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}") +print0(f"Total batch size {args.total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}") +token_bytes = get_token_bytes(device=device) -# ----------------------------------------------------------------------------- -# Task data mixture we'll train on -identity_conversations_filepath = os.path.join(get_base_dir(), "identity_conversations.jsonl") -train_ds = TaskMixture([ - ARC(subset="ARC-Easy", split="train"), # 2.3K rows - ARC(subset="ARC-Challenge", split="train"), # 1.1K rows - GSM8K(subset="main", split="train"), # 8K rows - SmolTalk(split="train", stop=10_000), # 10K rows of smoltalk - CustomJSON(filepath=identity_conversations_filepath), # 1K rows of synthetic identity conversations - SimpleSpelling(size=300, split="train"), # 300 rows of Simple Spelling (e.g. spell the word 'apple') - SpellingBee(size=300, split="train"), # 300 rows of Spelling Bee (e.g. how many 'r' are in 'strawberry'?) 
-]) # 2.3K + 1.1K + 8K + 10K + 1K + 0.3K + 0.3K = 23K rows -val_ds = SmolTalk(split="test") # general conversations, 24K rows (though we don't actually use all of it) - -# ----------------------------------------------------------------------------- -# DataLoader - -def sft_data_generator(dataset, batch_size): - pad_token_id = tokenizer.encode_special("<|assistant_end|>") # use <|assistant_end|> as the pad token is ok, these positions are masked in the loss - # prepares a list of tokenized conversations into a batch and yields - def collate_and_yield(batch): - nrows = len(batch) - ncols = max(len(ids) for ids, mask in batch) - 1 # seq of n creates inputs/targets of n-1 - inputs = torch.full((nrows, ncols), pad_token_id, dtype=torch.long) - targets = torch.full((nrows, ncols), -1, dtype=torch.long) # -1 is ignore index - for i, (ids, mask) in enumerate(batch): - n = len(ids) - ids_tensor = torch.tensor(ids, dtype=torch.long) - inputs[i, :n-1] = ids_tensor[:-1] - # recall -1 is the ignore index, so mask out targets where mask is 0 - row_targets = ids_tensor[1:] - # mask[1:] omits the mask for the BOS token, which is never a target atm so it's ok - mask_tensor = torch.tensor(mask[1:], dtype=torch.long) - row_targets[mask_tensor == 0] = -1 # mask out targets where mask is 0 - targets[i, :n-1] = row_targets - inputs = inputs.to(device) # move to device - targets = targets.to(device) - return inputs, targets - # iterates over the dataset in epochs, tokenizes - batch = [] - while True: - for i in range(ddp_rank, len(dataset), ddp_world_size): - doc = dataset[i] - ids, mask = tokenizer.render_conversation(doc) - batch.append((ids, mask)) - if len(batch) == batch_size: - yield collate_and_yield(batch) - batch = [] - -examples_per_step = args.device_batch_size * ddp_world_size -print0(f"Target examples per step: {args.target_examples_per_step}") -print0(f"Device batch size: {args.device_batch_size}") -print0(f"Examples per step is device_batch_size * ddp_world_size: {examples_per_step}") -assert args.target_examples_per_step % examples_per_step == 0, "Target examples per step must be divisible by examples per step" -grad_accum_steps = args.target_examples_per_step // examples_per_step -print0(f"=> Setting grad accum steps: {grad_accum_steps}") - -if args.num_iterations == -1: - # derive num_iterations from num_epochs and the size of the dataset - assert args.num_epochs > 0, "num_epochs must be positive if num_iterations is -1" - num_iterations = (len(train_ds) // args.target_examples_per_step) * args.num_epochs -else: - num_iterations = args.num_iterations -train_loader = sft_data_generator(train_ds, batch_size=args.device_batch_size) -build_val_loader = lambda: sft_data_generator(val_ds, batch_size=args.device_batch_size) - -# ----------------------------------------------------------------------------- -# Initialize the Optimizer - -optimizer = model.setup_optimizer( - unembedding_lr=args.unembedding_lr, - embedding_lr=args.embedding_lr, - matrix_lr=args.matrix_lr, - weight_decay=args.weight_decay, -) -# Set the initial learning rate as a fraction of the base learning rate +# Initialize the Optimizer (combined MuonAdamW: Muon for matrix params, AdamW for rest) +optimizer = model.setup_optimizer(unembedding_lr=args.unembedding_lr, embedding_lr=args.embedding_lr, matrix_lr=args.matrix_lr, weight_decay=args.weight_decay) +# Override the initial learning rate as a fraction of the base learning rate for group in optimizer.param_groups: group["lr"] = group["lr"] * args.init_lr_frac group["initial_lr"] = 
group["lr"] -# ----------------------------------------------------------------------------- -# Training loop +# SFT data mixture and DataLoader +base_dir = get_base_dir() +identity_conversations_filepath = os.path.join(base_dir, "identity_conversations.jsonl") +train_dataset = TaskMixture([ + SmolTalk(split="train"), # 460K rows of general conversations + MMLU(subset="auxiliary_train", split="train"), # 100K rows of multiple choice problems drawn from ARC, MC_TEST, OBQA, RACE + GSM8K(subset="main", split="train"), # 8K rows teaching simple math and (calculator) tool use + GSM8K(subset="main", split="train"), # 2 epochs of GSM8K + CustomJSON(filepath=identity_conversations_filepath), # 1000 rows of synthetic identity conversations + CustomJSON(filepath=identity_conversations_filepath), # let's do 2 epochs of these + SimpleSpelling(size=200000, split="train"), # 200K rows of Simple Spelling (e.g. spell the word 'apple') + SpellingBee(size=80000, split="train"), # 80K rows of Spelling Bee (e.g. how many 'r' are in 'strawberry'?) +]) # total: 460K + 100K + 16K + 200K + 80K = 856K rows +val_dataset = TaskMixture([ + SmolTalk(split="test"), # 24K rows in test set + MMLU(subset="all", split="test", stop=5200), # 14K rows in test set, use only 5.2K to match the train ratios + GSM8K(subset="main", split="test", stop=420), # 1.32K rows in test set, use only 420 to match the train ratios +]) # total: 24K + 14K + 1.32K ~= 39K rows +# DataLoader is defined here, it emits inputs, targets : 2D tensors of shape (device_batch_size, max_seq_len) +# A big problem is that we don't know the final num_iterations in advance. So we create +# these two global variables and update them from within the data generator. +last_step = False # we will toggle this to True when we reach the end of the training dataset +approx_progress = 0.0 # will go from 0 to 1 over the course of the epoch +current_epoch = 1 # track epoch for logging +def sft_data_generator_bos_bestfit(split, buffer_size=100): + """ + BOS-aligned dataloader for SFT with bestfit-pad packing. + + Each row in the batch starts with BOS (beginning of a conversation). + Conversations are packed using best-fit algorithm. When no conversation fits, + the row is padded (instead of cropping) to ensure no tokens are ever discarded. + Padding positions have targets masked with -1 (ignore_index for cross-entropy). 
+ """ + global last_step, approx_progress, current_epoch + assert split in {"train", "val"}, "split must be 'train' or 'val'" + dataset = train_dataset if split == "train" else val_dataset + dataset_size = len(dataset) + assert dataset_size > 0 + row_capacity = args.max_seq_len + 1 # +1 for target at last position + bos_token = tokenizer.get_bos_token_id() + + # Conversation buffer: list of token lists + conv_buffer = [] + cursor = ddp_rank # Each rank processes different conversations (for fetching) + consumed = ddp_rank # Track actual consumption separately from buffering + epoch = 1 + it = 0 # iteration counter + + def refill_buffer(): + nonlocal cursor, epoch + while len(conv_buffer) < buffer_size: + conversation = dataset[cursor] + ids, _ = tokenizer.render_conversation(conversation) + conv_buffer.append(ids) + cursor += ddp_world_size + if cursor >= dataset_size: + cursor = cursor % dataset_size + epoch += 1 + # Note: last_step is now triggered based on consumption, not fetching + + while True: + rows = [] + row_lengths = [] # Track actual content length (excluding padding) for each row + for _ in range(args.device_batch_size): + row = [] + padded = False + while len(row) < row_capacity: + # Ensure buffer has conversations + while len(conv_buffer) < buffer_size: + refill_buffer() + + remaining = row_capacity - len(row) + + # Find largest conversation that fits entirely + best_idx = -1 + best_len = 0 + for i, conv in enumerate(conv_buffer): + conv_len = len(conv) + if conv_len <= remaining and conv_len > best_len: + best_idx = i + best_len = conv_len + + if best_idx >= 0: + # Found a conversation that fits - use it entirely + conv = conv_buffer.pop(best_idx) + row.extend(conv) + consumed += ddp_world_size # Track actual consumption + else: + # No conversation fits - pad the remainder instead of cropping + # This ensures we never discard any tokens + content_len = len(row) + row.extend([bos_token] * remaining) # Pad with BOS tokens + padded = True + break # Row is now full (with padding) + + # Track content length: full row if no padding, otherwise the length before padding + if padded: + row_lengths.append(content_len) + else: + row_lengths.append(row_capacity) + rows.append(row[:row_capacity]) + + # Stopping condition to respect num_iterations, if given + it += 1 + if 0 < args.num_iterations <= it and split == "train": + last_step = True + + # Update progress tracking (based on consumed, not cursor, to account for buffering) + if split == "train": + current_epoch = epoch + if args.num_iterations > 0: + approx_progress = it / args.num_iterations + else: + approx_progress = consumed / dataset_size + # Trigger last_step when we've consumed enough (instead of when cursor wraps) + if consumed >= dataset_size: + last_step = True + + # Build tensors + use_cuda = device_type == "cuda" + batch_tensor = torch.tensor(rows, dtype=torch.long, pin_memory=use_cuda) + inputs = batch_tensor[:, :-1].to(device=device, dtype=torch.int32, non_blocking=use_cuda) + targets = batch_tensor[:, 1:].to(device=device, dtype=torch.int64, non_blocking=use_cuda) + + # Mask out padding positions in targets (set to -1 = ignore_index) + # For each row, positions >= (content_length - 1) in targets should be masked + for i, content_len in enumerate(row_lengths): + if content_len < row_capacity: + targets[i, content_len-1:] = -1 + + yield inputs, targets + +train_loader = sft_data_generator_bos_bestfit("train") +build_val_loader = lambda: sft_data_generator_bos_bestfit("val") +progress = 0 # will go from 0 to 1 over the 
course of the epoch # Learning rate scheduler -def get_lr_multiplier(it): - lrm = 1.0 - it / num_iterations - return lrm +def get_lr_multiplier(progress): + # first 80% of training: no decay, then linearly ramp down to 0. + return 1 if progress < 0.8 else 1 - (progress - 0.8) / 0.2 -# Go! +# Momentum scheduler for Muon optimizer +def get_muon_momentum(it): + frac = min(it / 300, 1) + momentum = (1 - frac) * 0.85 + frac * 0.95 + return momentum + +# ----------------------------------------------------------------------------- +# Training loop +x, y = next(train_loader) # prefetch the very first batch of data +min_val_bpb = float("inf") +smooth_train_loss = 0 # EMA of training loss +ema_beta = 0.9 # EMA decay factor +total_training_time = 0 # total wall-clock time of training step = 0 -for step in range(num_iterations): - last_step = step == num_iterations - 1 +while True: + flops_so_far = num_flops_per_token * args.total_batch_size * step - # evaluate the validation loss - if last_step or step % args.eval_every == 0: + # Synchronize last_step across all ranks to avoid hangs in the distributed setting + if ddp: + last_step_tensor = torch.tensor(last_step, dtype=torch.int32, device=device) + dist.all_reduce(last_step_tensor, op=dist.ReduceOp.MAX) + last_step = bool(last_step_tensor.item()) + + # once in a while: evaluate the val bpb (all ranks participate) + if last_step or (args.eval_every > 0 and step % args.eval_every == 0): model.eval() val_loader = build_val_loader() - losses = [] - for _ in range(args.eval_steps): - val_inputs, val_targets = next(val_loader) - with torch.no_grad(), autocast_ctx: - loss = model(val_inputs, val_targets) - losses.append(loss) - val_loss = torch.stack(losses).mean() # average over eval_steps - if ddp: - dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) # average over ranks - val_loss = val_loss.item() - print0(f"Step {step:05d} | Validation loss: {val_loss:.6f}") + eval_steps = args.eval_tokens // (args.device_batch_size * args.max_seq_len * ddp_world_size) + with autocast_ctx: + val_bpb = evaluate_bpb(model, val_loader, eval_steps, token_bytes) + print0(f"Step {step:05d} | Validation bpb: {val_bpb:.4f}") + if val_bpb < min_val_bpb: + min_val_bpb = val_bpb wandb_run.log({ "step": step, - "val_loss": val_loss, + "total_training_flops": flops_so_far, + "total_training_time": total_training_time, + "val/bpb": val_bpb, }) model.train() - # evaluate accuracy of the multiple choice tasks (which are quick to run) - if last_step or (step > 0 and step % args.eval_metrics_every == 0): - model.eval() - metrics = {} - with torch.no_grad(), autocast_ctx: - # note that because these are inside no_grad, we can usually afford to at least ~2X the batch size - metrics["mmlu_acc"] = run_chat_eval("MMLU", model, tokenizer, engine, batch_size=args.device_batch_size*2, max_problems=args.eval_metrics_max_problems) - metrics["arc_easy_acc"] = run_chat_eval("ARC-Easy", model, tokenizer, engine, batch_size=args.device_batch_size*2, max_problems=args.eval_metrics_max_problems) - metrics_str = ', '.join(f'{k}: {v:.6f}' for k, v in metrics.items()) - print0(f"Step {step:05d} | {metrics_str}") - wandb_run.log({ - "step": step, - **metrics, - }) - model.train() + # save checkpoint at the end of the run (only on master process) + if master_process and last_step and not args.dry_run: + output_dirname = args.model_tag if args.model_tag else f"d{depth}" # e.g. 
d12 + checkpoint_dir = os.path.join(base_dir, "sft_checkpoints", output_dirname) + save_checkpoint( + checkpoint_dir, + step, + orig_model.state_dict(), + optimizer.state_dict(), + { + "step": step, + "val_bpb": val_bpb, # loss at last step + "model_config": { + "sequence_len": args.max_seq_len, + "vocab_size": tokenizer.get_vocab_size(), + "n_layer": depth, + "n_head": model.config.n_head, + "n_kv_head": model.config.n_kv_head, + "n_embd": model.config.n_embd, + }, + "user_config": user_config, # inputs to the training script + } + ) if last_step: break + # ------------------------------------------------------------------------- + # single training step # evaluate the gradient - num_tokens = torch.tensor(0, device=device) # the number of "active" tokens of supervision seen + synchronize() + t0 = time.time() for micro_step in range(grad_accum_steps): - train_inputs, train_targets = next(train_loader) with autocast_ctx: - loss = model(train_inputs, train_targets) + loss = model(x, y) train_loss = loss.detach() # for logging loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here - loss.backward() # accumulate the gradient - num_tokens += (train_targets >= 0).sum() - if ddp: - dist.all_reduce(num_tokens, op=dist.ReduceOp.SUM) # sum over ranks - - # learning rate scheduler - lrm = get_lr_multiplier(step) + loss.backward() + x, y = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward + progress = max(progress, approx_progress) # only increase progress monotonically + # step the optimizer + lrm = get_lr_multiplier(progress) + muon_momentum = get_muon_momentum(step) for group in optimizer.param_groups: group["lr"] = group["initial_lr"] * lrm - - # step the optimizer + if group['kind'] == 'muon': + group["momentum"] = muon_momentum optimizer.step() model.zero_grad(set_to_none=True) + synchronize() + t1 = time.time() + dt = t1 - t0 + # ------------------------------------------------------------------------- - # logging - train_loss_item = train_loss.item() - num_tokens_item = num_tokens.item() - print0(f"Step {step:05d}/{num_iterations:05d} | Training loss: {train_loss_item:.6f}| lrm: {lrm:.6f}| num_tokens: {num_tokens_item:,}") - wandb_run.log({ - "step": step, - "lrm": lrm, - "train_loss": train_loss_item, - "num_tokens": num_tokens_item, - }) + # State step += 1 -# Save the model at the end of the run -if master_process: - base_dir = get_base_dir() - depth = model.config.n_layer - output_dirname = args.model_tag if args.model_tag else f"d{depth}" # e.g. 
d12 - checkpoint_dir = os.path.join(base_dir, "chatsft_checkpoints", output_dirname) - model_config_kwargs = model.config.__dict__ # slightly naughty, abusing the simplicity of GPTConfig, TODO nicer - save_checkpoint( - checkpoint_dir, - step, - model.state_dict(), - None, # note: we don't bother to save the optimizer state - { + # logging + smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss.item() # EMA the training loss + debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**(step + 1)) # debias the EMA + pct_done = 100 * progress + tok_per_sec = int(args.total_batch_size / dt) + flops_per_sec = num_flops_per_token * args.total_batch_size / dt + promised_flops_per_sec_h100 = 989e12 * ddp_world_size # bfloat16 H100 SXM and without 2:4 sparsity + mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in % + if step > 10: + total_training_time += dt # only count the time after the first 10 steps + print0(f"step {step:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | epoch: {current_epoch} | total time: {total_training_time/60:.2f}m") + if step % 10 == 0: + wandb_run.log({ "step": step, - "val_loss": val_loss, - **metrics, - "model_config": model_config_kwargs, - } - ) - print(f"✅ Saved model checkpoint to {checkpoint_dir}") + "total_training_flops": flops_so_far, + "total_training_time": total_training_time, + "train/loss": debiased_smooth_loss, + "train/lrm": lrm, + "train/dt": dt, + "train/tok_per_sec": tok_per_sec, + "train/mfu": mfu, + "train/epoch": current_epoch, + }) + +# print a few more stats +print0(f"Peak memory usage: {get_max_memory() / 1024 / 1024:.2f}MiB") +print0(f"Total training time: {total_training_time/60:.2f}m") +print0(f"Minimum validation bpb: {min_val_bpb:.4f}") # Log to report -from nanochat.report import get_report -get_report().log(section="Chat SFT", data=[ - user_config, # CLI args - { - "Training rows": len(train_ds), - "Number of iterations": num_iterations, - "Training loss": train_loss_item, - "Validation loss": val_loss, - }, -]) +if not args.dry_run: + from nanochat.report import get_report + get_report().log(section="SFT", data=[ + user_config, # CLI args + { # stats about the training setup + "Number of iterations": step, + "DDP world size": ddp_world_size, + }, + { # stats about training outcomes + "Minimum validation bpb": min_val_bpb, + } + ]) -# Cleanup -wandb_run.finish() +# cleanup +wandb_run.finish() # wandb run finish compute_cleanup() diff --git a/scripts/mid_train.py b/scripts/mid_train.py deleted file mode 100644 index 54c5fb0..0000000 --- a/scripts/mid_train.py +++ /dev/null @@ -1,386 +0,0 @@ -""" -Midtrain the model. Same as pretraining but simpler. 
-Run as: - -python -m scripts.mid_train - -Or torchrun for training: - -torchrun --standalone --nproc_per_node=8 -m scripts.mid_train -- --device-batch-size=16 -""" - -import argparse -import os -os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True" -import time -import wandb -import torch -from contextlib import nullcontext -from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, get_base_dir, autodetect_device_type -from nanochat.tokenizer import get_token_bytes -from nanochat.checkpoint_manager import save_checkpoint -from nanochat.loss_eval import evaluate_bpb -from nanochat.checkpoint_manager import load_model -import torch.distributed as dist - -from tasks.common import TaskMixture -from tasks.gsm8k import GSM8K -from tasks.mmlu import MMLU -from tasks.smoltalk import SmolTalk -from tasks.customjson import CustomJSON -from tasks.spellingbee import SimpleSpelling, SpellingBee - -# ----------------------------------------------------------------------------- -# CLI arguments -parser = argparse.ArgumentParser(description="Midtrain the model") -# Logging -parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)") -# Runtime -parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)") -parser.add_argument("--dtype", type=str, default="bfloat16", help="float32|bfloat16") -# Model loading -parser.add_argument("--model-tag", type=str, default=None, help="model tag to load from") -parser.add_argument("--model-step", type=int, default=None, help="model step to load from") -# Training horizon -parser.add_argument("--num-iterations", type=int, default=-1, help="number of optimization steps (-1 = full epoch)") -# Batch sizes -parser.add_argument("--max-seq-len", type=int, default=2048, help="max context length") -parser.add_argument("--device-batch-size", type=int, default=32, help="per-device batch size") -parser.add_argument("--total-batch-size", type=int, default=524288, help="total batch size in tokens") -# Optimization -parser.add_argument("--embedding-lr", type=float, default=0.2, help="learning rate for embedding parameters (Adam)") -parser.add_argument("--unembedding-lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)") -parser.add_argument("--matrix-lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)") -parser.add_argument("--weight-decay", type=float, default=0.0, help="weight decay for embedding/unembedding parameters (Adam)") -parser.add_argument("--init-lr-frac", type=float, default=1.0, help="initial LR as fraction of base LR") -# Evaluation -parser.add_argument("--eval-every", type=int, default=150, help="evaluate val bpb every N steps (-1 = disable)") -parser.add_argument("--eval-tokens", type=int, default=20*524288, help="number of tokens to evaluate val loss on") -# Output -parser.add_argument("--dry-run", action="store_true", help="log to wandb but skip checkpoints/report") -args = parser.parse_args() -user_config = vars(args).copy() -# ----------------------------------------------------------------------------- - -# Compute init -device_type = autodetect_device_type() if args.device_type == "" else args.device_type -ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type) -master_process = ddp_rank == 0 -ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16 -autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == 
"cuda" else nullcontext() -synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None -get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0 - -# wandb logging init -use_dummy_wandb = args.run == "dummy" or not master_process -wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-mid", name=args.run, config=user_config) - -# Load the model and tokenizer -model, tokenizer, meta = load_model("base", device, phase="train", model_tag=args.model_tag, step=args.model_step) -pretrain_batch_size = meta.get("device_batch_size", None) -if pretrain_batch_size is not None and args.device_batch_size > pretrain_batch_size: - print0(f"FOOTGUN WARNING: base model training used device_batch_size {pretrain_batch_size}, did you pass in a good --device-batch-size to this script?") -orig_model = model -model = torch.compile(model, dynamic=False) -depth = model.config.n_layer -num_flops_per_token = model.estimate_flops() -tokens_per_fwdbwd = args.device_batch_size * args.max_seq_len # tokens per iteration for a single rank -world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size # total tokens per iteration for all ranks -assert args.total_batch_size % world_tokens_per_fwdbwd == 0 -grad_accum_steps = args.total_batch_size // world_tokens_per_fwdbwd -print0(f"Tokens / micro-batch / rank: {args.device_batch_size} x {args.max_seq_len} = {tokens_per_fwdbwd:,}") -print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}") -print0(f"Total batch size {args.total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}") -token_bytes = get_token_bytes(device=device) - -# Initialize the Optimizer (combined MuonAdamW: Muon for matrix params, AdamW for rest) -optimizer = model.setup_optimizer(unembedding_lr=args.unembedding_lr, embedding_lr=args.embedding_lr, matrix_lr=args.matrix_lr, weight_decay=args.weight_decay) -# Override the initial learning rate as a fraction of the base learning rate -for group in optimizer.param_groups: - group["lr"] = group["lr"] * args.init_lr_frac - group["initial_lr"] = group["lr"] - -# Midtraining data mixture and DataLoader -base_dir = get_base_dir() -identity_conversations_filepath = os.path.join(base_dir, "identity_conversations.jsonl") -train_dataset = TaskMixture([ - SmolTalk(split="train"), # 460K rows of general conversations - MMLU(subset="auxiliary_train", split="train"), # 100K rows of multiple choice problems drawn from ARC, MC_TEST, OBQA, RACE - GSM8K(subset="main", split="train"), # 8K rows teaching simple math and (calculator) tool use - CustomJSON(filepath=identity_conversations_filepath), # 1000 rows of synthetic identity conversations - CustomJSON(filepath=identity_conversations_filepath), # let's do 2 epochs of these - SimpleSpelling(size=200000, split="train"), # 200K rows of Simple Spelling (e.g. spell the word 'apple') - SpellingBee(size=80000, split="train"), # 80K rows of Spelling Bee (e.g. how many 'r' are in 'strawberry'?) 
-]) # total: 460K + 100K + 8K + 1K + 1K + 200K + 80K = 850K rows
-val_dataset = TaskMixture([
-    SmolTalk(split="test"), # 24K rows in test set
-    MMLU(subset="all", split="test", stop=5200), # 14K rows in test set, use only 5.2K to match the train ratios
-    GSM8K(subset="main", split="test", stop=420), # 1.32K rows in test set, use only 420 to match the train ratios
-]) # total actually used: 24K + 5.2K + 0.42K ~= 30K rows
-# The DataLoader is defined here; it emits inputs, targets: 2D tensors of shape (device_batch_size, max_seq_len).
-# A big problem is that we don't know the final num_iterations in advance. So we create
-# these global variables and update them from within the data generator.
-last_step = False # we will toggle this to True when we reach the end of the training dataset
-approx_progress = 0.0 # will go from 0 to 1 over the course of the epoch
-current_epoch = 1 # track epoch for logging
-def mid_data_generator_bos_bestfit(split, buffer_size=100):
-    """
-    BOS-aligned dataloader for midtraining with bestfit-pad packing.
-
-    Each row in the batch starts with BOS (beginning of a conversation).
-    Conversations are packed using a best-fit algorithm. When no conversation fits,
-    the row is padded (instead of cropped) to ensure no tokens are ever discarded.
-    Padding positions have targets masked with -1 (ignore_index for cross-entropy).
-    """
-    global last_step, approx_progress, current_epoch
-    assert split in {"train", "val"}, "split must be 'train' or 'val'"
-    dataset = train_dataset if split == "train" else val_dataset
-    dataset_size = len(dataset)
-    assert dataset_size > 0
-    row_capacity = args.max_seq_len + 1 # +1 for target at last position
-    bos_token = tokenizer.get_bos_token_id()
-
-    # Conversation buffer: list of token lists
-    conv_buffer = []
-    cursor = ddp_rank # Each rank processes different conversations (for fetching)
-    consumed = ddp_rank # Track actual consumption separately from buffering
-    epoch = 1
-    it = 0 # iteration counter
-
-    def refill_buffer():
-        nonlocal cursor, epoch
-        while len(conv_buffer) < buffer_size:
-            conversation = dataset[cursor]
-            ids, _ = tokenizer.render_conversation(conversation)
-            conv_buffer.append(ids)
-            cursor += ddp_world_size
-            if cursor >= dataset_size:
-                cursor = cursor % dataset_size
-                epoch += 1
-        # Note: last_step is now triggered based on consumption, not fetching
-
-    while True:
-        rows = []
-        row_lengths = [] # Track actual content length (excluding padding) for each row
-        for _ in range(args.device_batch_size):
-            row = []
-            padded = False
-            while len(row) < row_capacity:
-                # Ensure the buffer is topped up with conversations (no-op when already full)
-                refill_buffer()
-
-                remaining = row_capacity - len(row)
-
-                # Find the largest conversation that fits entirely
-                best_idx = -1
-                best_len = 0
-                for i, conv in enumerate(conv_buffer):
-                    conv_len = len(conv)
-                    if conv_len <= remaining and conv_len > best_len:
-                        best_idx = i
-                        best_len = conv_len
-
-                if best_idx >= 0:
-                    # Found a conversation that fits - use it entirely
-                    conv = conv_buffer.pop(best_idx)
-                    row.extend(conv)
-                    consumed += ddp_world_size # Track actual consumption
-                else:
-                    # No conversation fits - pad the remainder instead of cropping
-                    # This ensures we never discard any tokens
-                    content_len = len(row)
-                    row.extend([bos_token] * remaining) # Pad with BOS tokens
-                    padded = True
-                    break # Row is now full (with padding)
-
-            # Track content length: full row if no padding, otherwise the length before padding
-            if padded:
-                row_lengths.append(content_len)
-            else:
-                row_lengths.append(row_capacity)
-
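-            # At this point len(row) == row_capacity by construction; the slice below is just defensive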
-            rows.append(row[:row_capacity])
-
-        # Stopping condition to respect num_iterations, if given
-        it += 1
-        if 0 < args.num_iterations <= it and split == "train":
-            last_step = True
-
-        # Update progress tracking (based on consumed, not cursor, to account for buffering)
-        if split == "train":
-            current_epoch = epoch
-            if args.num_iterations > 0:
-                approx_progress = it / args.num_iterations
-            else:
-                approx_progress = consumed / dataset_size
-            # Trigger last_step when we've consumed enough (instead of when cursor wraps)
-            if consumed >= dataset_size:
-                last_step = True
-
-        # Build tensors
-        use_cuda = device_type == "cuda"
-        batch_tensor = torch.tensor(rows, dtype=torch.long, pin_memory=use_cuda)
-        inputs = batch_tensor[:, :-1].to(device=device, dtype=torch.int32, non_blocking=use_cuda)
-        targets = batch_tensor[:, 1:].to(device=device, dtype=torch.int64, non_blocking=use_cuda)
-
-        # Mask out padding positions in targets (set to -1 = ignore_index)
-        # For each row, positions >= (content_length - 1) in targets should be masked
-        for i, content_len in enumerate(row_lengths):
-            if content_len < row_capacity:
-                targets[i, max(content_len - 1, 0):] = -1 # max() guards the degenerate all-padding row
-
-        yield inputs, targets
-
-train_loader = mid_data_generator_bos_bestfit("train")
-build_val_loader = lambda: mid_data_generator_bos_bestfit("val")
-progress = 0 # monotonic version of approx_progress (0 -> 1), used by the LR scheduler
-
-# Learning rate scheduler
-def get_lr_multiplier(progress):
-    # first 80% of training: no decay, then linearly ramp down to 0.
-    return 1 if progress < 0.8 else 1 - (progress - 0.8) / 0.2
-
-# Momentum scheduler for Muon optimizer
-def get_muon_momentum(it):
-    frac = min(it / 300, 1)
-    momentum = (1 - frac) * 0.85 + frac * 0.95
-    return momentum
-
-# -----------------------------------------------------------------------------
-# Training loop
-x, y = next(train_loader) # prefetch the very first batch of data
-min_val_bpb = float("inf")
-smooth_train_loss = 0 # EMA of the training loss
-ema_beta = 0.9 # EMA decay factor
-total_training_time = 0 # total wall-clock time of training
-step = 0
-while True:
-    flops_so_far = num_flops_per_token * args.total_batch_size * step
-
-    # Synchronize last_step across all ranks to avoid hangs in the distributed setting
-    if ddp:
-        last_step_tensor = torch.tensor(last_step, dtype=torch.int32, device=device)
-        dist.all_reduce(last_step_tensor, op=dist.ReduceOp.MAX)
-        last_step = bool(last_step_tensor.item())
-
-    # once in a while: evaluate the val bpb (all ranks participate)
-    if last_step or (args.eval_every > 0 and step % args.eval_every == 0):
-        model.eval()
-        val_loader = build_val_loader()
-        eval_steps = args.eval_tokens // (args.device_batch_size * args.max_seq_len * ddp_world_size)
-        with autocast_ctx:
-            val_bpb = evaluate_bpb(model, val_loader, eval_steps, token_bytes)
-        print0(f"Step {step:05d} | Validation bpb: {val_bpb:.4f}")
-        if val_bpb < min_val_bpb:
-            min_val_bpb = val_bpb
-        wandb_run.log({
-            "step": step,
-            "total_training_flops": flops_so_far,
-            "total_training_time": total_training_time,
-            "val/bpb": val_bpb,
-        })
-        model.train()
-
-    # save checkpoint at the end of the run (only on master process)
-    if master_process and last_step and not args.dry_run:
-        output_dirname = args.model_tag if args.model_tag else f"d{depth}" # e.g. d12
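-        # checkpoint files land under <base_dir>/mid_checkpoints/<output_dirname>/ (base_dir is typically ~/.cache/nanochat)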
-        checkpoint_dir = os.path.join(base_dir, "mid_checkpoints", output_dirname)
-        save_checkpoint(
-            checkpoint_dir,
-            step,
-            orig_model.state_dict(),
-            optimizer.state_dict(),
-            {
-                "step": step,
-                "val_bpb": val_bpb, # val bpb at the last step
-                "model_config": {
-                    "sequence_len": args.max_seq_len,
-                    "vocab_size": tokenizer.get_vocab_size(),
-                    "n_layer": depth,
-                    "n_head": model.config.n_head,
-                    "n_kv_head": model.config.n_kv_head,
-                    "n_embd": model.config.n_embd,
-                },
-                "user_config": user_config, # inputs to the training script
-            }
-        )
-
-    if last_step:
-        break
-
-    # -------------------------------------------------------------------------
-    # single training step
-    # evaluate the gradient
-    synchronize()
-    t0 = time.time()
-    for micro_step in range(grad_accum_steps):
-        with autocast_ctx:
-            loss = model(x, y)
-        train_loss = loss.detach() # for logging
-        loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here
-        loss.backward()
-        x, y = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward
-    progress = max(progress, approx_progress) # only increase progress monotonically
-    # step the optimizer
-    lrm = get_lr_multiplier(progress)
-    muon_momentum = get_muon_momentum(step)
-    for group in optimizer.param_groups:
-        group["lr"] = group["initial_lr"] * lrm
-        if group['kind'] == 'muon':
-            group["momentum"] = muon_momentum
-    optimizer.step()
-    model.zero_grad(set_to_none=True)
-    synchronize()
-    t1 = time.time()
-    dt = t1 - t0
-    # -------------------------------------------------------------------------
-
-    # State
-    step += 1
-
-    # logging
-    smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss.item() # EMA the training loss
-    debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**step) # debias the EMA (exactly `step` EMA updates have happened by now)
-    pct_done = 100 * progress
-    tok_per_sec = int(args.total_batch_size / dt)
-    flops_per_sec = num_flops_per_token * args.total_batch_size / dt
-    promised_flops_per_sec_h100 = 989e12 * ddp_world_size # bfloat16 H100 SXM, without 2:4 sparsity
-    mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in %
-    if step > 10:
-        total_training_time += dt # only count the time after the first 10 steps
-    print0(f"step {step:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | epoch: {current_epoch} | total time: {total_training_time/60:.2f}m")
-    if step % 10 == 0:
-        wandb_run.log({
-            "step": step,
-            "total_training_flops": flops_so_far,
-            "total_training_time": total_training_time,
-            "train/loss": debiased_smooth_loss,
-            "train/lrm": lrm,
-            "train/dt": dt,
-            "train/tok_per_sec": tok_per_sec,
-            "train/mfu": mfu,
-            "train/epoch": current_epoch,
-        })
-
-# print a few more stats
-print0(f"Peak memory usage: {get_max_memory() / 1024 / 1024:.2f}MiB")
-print0(f"Total training time: {total_training_time/60:.2f}m")
-print0(f"Minimum validation bpb: {min_val_bpb:.4f}")
-
-# Log to report
-if not args.dry_run:
-    from nanochat.report import get_report
-    get_report().log(section="Midtraining", data=[
-        user_config, # CLI args
-        { # stats about the training setup
-            "Number of iterations": step,
-            "DDP world size": ddp_world_size,
-        },
-        { # stats about training outcomes
-            "Minimum validation bpb": min_val_bpb,
-        }
-    ])
-
-# cleanup
-wandb_run.finish()
-compute_cleanup()
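-# Sketch of the downstream handoff (assumption: load_model accepts "mid" to read from mid_checkpoints,
-# mirroring how this script loads the "base" checkpoint above):
-#   model, tokenizer, meta = load_model("mid", device, phase="train")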