Merge branch 'master' into uv-venv

This commit is contained in:
svlandeg 2026-02-13 13:17:22 +01:00
commit ff599ce056
40 changed files with 3348 additions and 2431 deletions

View File

@ -0,0 +1,40 @@
---
name: read-arxiv-paper
description: Use this skill when when asked to read an arxiv paper given an arxiv URL
---
You will be given a URL of an arxiv paper, for example:
https://www.arxiv.org/abs/2601.07372
### Part 1: Normalize the URL
The goal is to fetch the TeX Source of the paper (not the PDF!), the URL always looks like this:
https://www.arxiv.org/src/2601.07372
Notice the /src/ in the url. Once you have the URL:
### Part 2: Download the paper source
Fetch the url to a local .tar.gz file. A good location is `~/.cache/nanochat/knowledge/{arxiv_id}.tar.gz`.
(If the file already exists, there is no need to re-download it).
### Part 3: Unpack the file in that folder
Unpack the contents into `~/.cache/nanochat/knowledge/{arxiv_id}` directory.
### Part 4: Locate the entrypoint
Every latex source usually has an entrypoint, such as `main.tex` or something like that.
### Part 5: Read the paper
Once you've found the entrypoint, Read the contents and then recurse through all other relevant source files to read the paper.
#### Part 6: Report
Once you've read the paper, produce a summary of the paper into a markdown file at `./knowledge/summary_{tag}.md`. Notice that 1) use the local knowledge directory here (it's easier for me to open and reference here), not in `~/.cache`, and 2) generate some reasonable `tag` like e.g. `conditional_memory` or whatever seems appropriate given the paper. Probably make sure that the tag doesn't exist yet so you're not overwriting files.
As for the summary itself, remember that you're processing this paper within the context of the nanochat repository, so most often we we will be interested in how to apply the paper and its lessons to the nanochat project. Therefore, you should feel free to "remind yourself" of the related nanochat code by reading the relevant parts, and then explicitly make the connection of how this paper might relate to nanochat or what are things we might be inspired about or try.

1
.gitignore vendored
View File

@ -9,6 +9,5 @@ eval_bundle/
.env
# Local setup
.claude
CLAUDE.md
wandb/

175
README.md
View File

@ -1,34 +1,38 @@
# nanochat
![nanochat logo](dev/nanochat.png)
![scaling laws](dev/scaling_laws_jan26.png)
> The best ChatGPT that $100 can buy.
nanochat is the simplest experimental harness for training LLMs. It is designed to run on a single GPU node, the code is minimal/hackable, and it covers all major LLM stages including tokenization, pretraining, finetuning, evaluation, inference, and a chat UI. For example, you can train your own GPT-2 capability LLM (which cost ~$43,000 to train in 2019) for only $72 (~3 hours of 8XH100 GPU node) and then talk to it in a familiar ChatGPT-like web UI. On a spot instance, the total cost can be closer to ~$20. More generally, nanochat is configured out of the box to train an entire miniseries of compute-optimal models by setting one single complexity dial: `--depth`, the number of layers in the GPT transformer model (GPT-2 capability happens to be approximately depth 26). All other hyperparameters (the width of the transformer, number of heads, learning rate adjustments, training horizons, weight decays, ...) are calculated automatically in an optimal way.
This repo is a full-stack implementation of an LLM like ChatGPT in a single, clean, minimal, hackable, dependency-lite codebase. nanochat is designed to run on a single 8XH100 node via scripts like [speedrun.sh](speedrun.sh), that run the entire pipeline start to end. This includes tokenization, pretraining, finetuning, evaluation, inference, and web serving over a simple UI so that you can talk to your own LLM just like ChatGPT. nanochat will become the capstone project of the course LLM101n being developed by Eureka Labs.
For questions about the repo, I recommend either using [DeepWiki](https://deepwiki.com/karpathy/nanochat) from Devin/Cognition to ask questions about the repo, or use the [Discussions tab](https://github.com/karpathy/nanochat/discussions), or come by the [#nanochat](https://discord.com/channels/1020383067459821711/1427295580895314031) channel on Discord.
## Talk to it
## Time-to-GPT-2 Leaderboard
To get a sense of the endpoint of this repo, you can currently find [nanochat d34](https://github.com/karpathy/nanochat/discussions/314) hosted on [nanochat.karpathy.ai](https://nanochat.karpathy.ai/). "d34" means that this model has 34 layers in the Transformer neural network. This model has 2.2 billion parameters, it was trained on 88 billion tokens by simply running the training script [run1000.sh](run1000.sh) with `--target_param_data_ratio=40` (2x longer than Chinchilla-optimal), and the total cost of training was ~$2,500 (about 100 hours training time on 8XH100 GPU node). While today this is enough to outperform GPT-2 of 2019, it falls dramatically short of modern Large Language Models like GPT-5. When talking to these micro models, you'll see that they make a lot of mistakes, they are a little bit naive and silly and they hallucinate a ton, a bit like children. It's kind of amusing. But what makes nanochat unique is that it is fully yours - fully configurable, tweakable, hackable, and trained by you from start to end. To train and talk to your own, we turn to...
Presently, the main focus of development is on tuning the pretraining stage, which takes the most amount of compute. Inspired by the modded-nanogpt repo and to incentivise progress and community collaboration, nanochat maintains a leaderboard for a "GPT-2 speedrun", which is the wall-clock time required to train a nanochat model to GPT-2 grade capability, as measured by the DCLM CORE score. The [runs/speedrun.sh](runs/speedrun.sh) script always reflects the reference way to train GPT-2 grade model and talk to it. The current leaderboard looks as follows:
## Updates
| # | time | val_bpb | CORE | Description | Date | Commit | Contributors |
|---|-------------|---------|------|-------------|------|--------|--------------|
| 0 | 168 hours | - | 0.2565 | Original OpenAI GPT-2 checkpoint | 2019 | - | OpenAI |
| 1 | 3.04 | 0.74833 | 0.2585 | d24 baseline, slightly overtrained | Jan 29 2026 | 348fbb3 | @karpathy |
| 2 | 2.91 | 0.74504 | 0.2578 | d26 slightly undertrained **+fp8** | Feb 2 2026 | a67eba3 | @karpathy |
| 3 | 2.76 | 0.74645 | 0.2602 | bump total batch size to 1M tokens | Feb 5 2026 | 2c062aa | @karpathy |
- (Jan 7 2026) See new post: [nanochat Miniseries v1](https://github.com/karpathy/nanochat/discussions/420) and the associated script [miniseries.sh](miniseries.sh).
The primary metric we care about is "time to GPT-2" - the wall clock time needed to outperform the GPT-2 (1.6B) CORE metric on an 8XH100 GPU node. The GPT-2 CORE score is 0.256525. In 2019, the training of GPT-2 cost approximately $43,000 so it is incredible that due to many advances over 7 years across the stack, we can now do so much faster and for well below $100 (e.g. at the current ~$3/GPU/hr, an 8XH100 node is ~$24/hr, so 3 hours is ~$72).
## Quick start
See [dev/LEADERBOARD.md](dev/LEADERBOARD.md) for more docs on how to interpret and contribute to the leaderboard.
The fastest way to feel the magic is to run the speedrun script [speedrun.sh](speedrun.sh), which trains and inferences the $100 tier of nanochat. On an 8XH100 node at $24/hr, this gives a total run time of about 4 hours. Boot up a new 8XH100 GPU box from your favorite provider (e.g. I use and like [Lambda](https://lambda.ai/service/gpu-cloud)), and kick off the training script:
## Getting started
### Reproduce and talk to GPT-2
The most fun you can have is to train your own GPT-2 and talk to it. The entire pipeline to do so is contained in the single file [runs/speedrun.sh](runs/speedrun.sh), which is designed to be run on an 8XH100 GPU node. Boot up a new 8XH100 GPU box from your favorite provider (e.g. I use and like [Lambda](https://lambda.ai/service/gpu-cloud)), and kick off the training script:
```bash
bash speedrun.sh
bash runs/speedrun.sh
```
Alternatively, since the script runs for 4 hours, I like to launch it like this inside a new screen session `speedrun` (and also log output to `speedrun.log`):
```bash
screen -L -Logfile speedrun.log -S speedrun bash speedrun.sh
```
See the [screen cheatsheet](https://gist.github.com/jctosta/af918e1618682638aa82) if you are less familiar. You can watch it go inside the screen session, or detach with `Ctrl-a d` and `tail speedrun.log` to view progress. Now wait 4 hours. Once it's done, you can talk to your LLM via the ChatGPT-like web UI. Make sure again that your local uv virtual environment is active (run `source .venv/bin/activate`), and serve it:
You may wish to do so in a screen session as this will take ~3 hours to run. Once it's done, you can talk to it via the ChatGPT-like web UI. Make sure again that your local uv virtual environment is active (run `source .venv/bin/activate`), and serve it:
```bash
python -m scripts.chat_web
@ -42,90 +46,50 @@ And then visit the URL shown. Make sure to access it correctly, e.g. on Lambda u
---
You can also `cat report.md` file which appeared in the project directory and contains the "report card" of the run, i.e. a bunch of evaluations and metrics. At the very end, you'll see a summary table, for example:
---
- Characters: 333,989
- Lines: 8,304
- Files: 44
- Tokens (approx): 83,497
- Dependencies (uv.lock lines): 2,004
| Metric | BASE | MID | SFT | RL |
|-----------------|----------|----------|----------|----------|
| CORE | 0.2219 | - | - | - |
| ARC-Challenge | - | 0.2875 | 0.2807 | - |
| ARC-Easy | - | 0.3561 | 0.3876 | - |
| GSM8K | - | 0.0250 | 0.0455 | 0.0758 |
| HumanEval | - | 0.0671 | 0.0854 | - |
| MMLU | - | 0.3111 | 0.3151 | - |
| ChatCORE | - | 0.0730 | 0.0884 | - |
Total wall clock time: 3h51m
---
(Your table might be missing the RL number by default). For a lot more information around the speedrun script and what to look for and expect, please refer to the walkthrough that I posted in Discussions of the repo: ["Introducing nanochat: The best ChatGPT that $100 can buy"](https://github.com/karpathy/nanochat/discussions/1).
## Bigger models
Unsurprisingly, $100 is not enough to train a highly performant ChatGPT clone. In fact, LLMs are famous for their multi-million dollar capex. For our purposes, I think there are two more scales of interest. First is the ~$300 tier d26 model (i.e. depth=26) that trains in ~12 hours, which slightly outperforms GPT-2 CORE score. Second is the $1000 tier (~41.6 hours), just because it's a nice round number. But both of these are not yet fully supported and therefore not attached here in the master branch yet.
That said, to give a sense, the example changes needed for the [speedrun.sh](speedrun.sh) file to train a GPT-2 grade model d26 only involve three changes:
```bash
...
# you'll need to download more data shards for pretraining
# get the number of parameters, multiply 20 to get tokens, multiply by 4.8 to get chars,
# divide by 250 million to get number of shards. todo need to improve this...
python -m nanochat.dataset -n 450 &
...
# use --depth to increase model size. to not oom, halve device batch size 32 -> 16:
torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- --depth=26 --device-batch-size=16
...
# make sure to use the same later during midtraining:
torchrun --standalone --nproc_per_node=8 -m scripts.mid_train -- --device-batch-size=16
```
That's it! The biggest thing to pay attention to is making sure you have enough data shards to train on (the code will loop and do more epochs over the same training set otherwise, decreasing learning speed a bit), and managing your memory/VRAM, primarily by decreasing the `device_batch_size` until things fit (the scripts automatically compensate by increasing the number of gradient accumulation loops, simply turning parallel compute to sequential compute).
And a bit more about computing environments that will run nanochat:
A few more notes:
- The code will run just fine on the Ampere 8XA100 GPU node as well, but a bit slower.
- All code will run just fine on even a single GPU by omitting `torchrun`, and will produce ~identical results (code will automatically switch to gradient accumulation), but you'll have to wait 8 times longer.
- If your GPU(s) have less than 80GB, you'll have to tune some of the hyperparameters or you will OOM / run out of VRAM. Look for `--device_batch_size` in the scripts and reduce it until things fit. E.g. from 32 (default) to 16, 8, 4, 2, or even 1. Less than that you'll have to know a bit more what you're doing and get more creative.
- Most of the code is fairly vanilla PyTorch so it should run on anything that supports that - xpu, mps, or etc, but I haven't implemented this out of the box so it might take a bit of tinkering.
- Most of the code is fairly vanilla PyTorch so it should run on anything that supports that - xpu, mps, or etc, but I haven't personally exercised all of these code paths so there might be sharp edges.
## Research
If you are a researcher and wish to help improve nanochat, two scripts of interest are [runs/scaling_laws.sh](runs/scaling_laws.sh) and [runs/miniseries.sh](runs/miniseries.sh). See [Jan 7 miniseries v1](https://github.com/karpathy/nanochat/discussions/420) for related documentation. For quick experimentation (~5 min pretraining runs) my favorite scale is to train a 12-layer model (GPT-1 sized), e.g. like this:
```
OMP_NUM_THREADS=1 torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- \
--depth=12 \
--run="d12" \
--model-tag="d12" \
--core-metric-every=999999 \
--sample-every=-1 \
--save-every=-1 \
```
This uses wandb (run name "d12"), only runs the CORE metric on last step, and it doesn't sample and save intermediate checkpoints. I like to change something in the code, re-run a d12 (or a d16 etc) and see if it helped, in an iteration loop. To see if a run helps, I like to monitor the wandb plots for:
1. `val_bpb` (validation loss in vocab-size-invariant units of bits per byte) as a function of `step`, `total_training_time` and `total_training_flops`.
2. `core_metric` (the DCLM CORE socre)
3. VRAM utilization, `train/mfu` (Model FLOPS utilization), `train/tok_per_sec` (training throughput)
See an example [here](https://github.com/karpathy/nanochat/pull/498#issuecomment-3850720044).
The important thing to note is that nanochat is written and configured around one single dial of complexity - the depth of the transformer. This single integer automatically determines all other hyperparameters (the width of the transformer, number of heads, learning rate adjustments, training horizons, weight decays, ...) so that the trained model comes out compute optimal. The idea is that the user doesn't have to think about or set any of this, they are simply asking for a smaller or bigger model using `--depth`, and everything "just works". By sweeping out the depth, you achieve the nanochat miniseries of compute optimal models at various sizes. GPT-2 capability model (which is of most interest at the moment) happens to be somewhere around d24-d26 range with the current code. But any candidate changes to the repo have to be principled enough that they work for all settings of depth.
## Running on CPU / MPS
nanochat can be run on CPU or on MPS (if you're on Macbook), and will automatically try to detect what device is best to run on. You're not going to get too far without GPUs, but at least you'll be able to run the code paths and maybe train a tiny LLM with some patience. For an example of how to make all the run commands much smaller (feel free to tune!), you can refer to [dev/runcpu.sh](dev/runcpu.sh) file. You'll see that I'm essentially restricting all scripts to train smaller models, to run for shorter number of iterations, etc. This functionality is new, slightly gnarly (touched a lot of code), and was merged in this [CPU|MPS PR](https://github.com/karpathy/nanochat/pull/88) on Oct 21, 2025.
The script [runs/runcpu.sh](runs/runcpu.sh) shows a very simple example of running on CPU or Apple Silicon. It dramatically shrinks the LLM that is being trained to make things fit into a reasonable time interval of a few ten minutes of training. You will not get strong results in this way.
## Customization
## Guides
To customize your nanochat, see [Guide: infusing identity to your nanochat](https://github.com/karpathy/nanochat/discussions/139) in Discussions, which describes how you can tune your nanochat's personality through synthetic data generation and mixing that data into midtraining and SFT stages.
I've published a number of guides that might contain helpful information, most recent to least recent:
Additionally, to add new abilities to nanochat, see [Guide: counting r in strawberry (and how to add abilities generally)](https://github.com/karpathy/nanochat/discussions/164).
## Questions
nanochat is designed to be short and sweet. One big advantage of this is that we can package up all of the files together and copy paste them to your favorite LLM to ask arbitrary questions. As an example, I like to package up the repo using the [files-to-prompt](https://github.com/simonw/files-to-prompt) utility like so:
```bash
files-to-prompt . -e py -e md -e html -e toml -e sh --cxml > packaged.txt
```
This includes all py, html, toml, sh files and chooses the cxml output format. Everything is written to the `packaged.txt` file, which atm measures ~330KB (i.e. well below ~100K tokens for a state of the art LLM), and ~8K lines of code in 45 files.
Alternatively, I recommend using [DeepWiki](https://deepwiki.com/karpathy/nanochat) from Devin/Cognition to ask questions of this repo. In the URL of this repo, simply change github.com to deepwiki.com, and you're off.
## Tests
I haven't invested too much here but some tests exist, especially for the tokenizer. Run e.g. as:
```bash
python -m pytest tests/test_engine.py -v -s
```
- [Feb 1 2026: Beating GPT-2 for <<$100: the nanochat journey](https://github.com/karpathy/nanochat/discussions/481)
- [Jan 7 miniseries v1](https://github.com/karpathy/nanochat/discussions/420) documents the first nanochat miniseries of models.
- To add new abilities to nanochat, see [Guide: counting r in strawberry (and how to add abilities generally)](https://github.com/karpathy/nanochat/discussions/164).
- To customize your nanochat, see [Guide: infusing identity to your nanochat](https://github.com/karpathy/nanochat/discussions/139) in Discussions, which describes how you can tune your nanochat's personality through synthetic data generation and mixing that data into the SFT stage.
- [Oct 13 2025: original nanochat post](https://github.com/karpathy/nanochat/discussions/1) introducing nanochat, though now it contains some deprecated information and the model is a lot older (with worse results) than current master.
## File structure
@ -137,11 +101,9 @@ python -m pytest tests/test_engine.py -v -s
│ ├── gen_synthetic_data.py # Example synthetic data for identity
│ ├── generate_logo.html
│ ├── nanochat.png
│ ├── repackage_data_reference.py # Pretraining data shard generation
│ └── runcpu.sh # Small example of how to run on CPU/MPS
│ └── repackage_data_reference.py # Pretraining data shard generation
├── nanochat
│ ├── __init__.py # empty
│ ├── adamw.py # Distributed AdamW optimizer
│ ├── checkpoint_manager.py # Save/Load model checkpoints
│ ├── common.py # Misc small utilities, quality of life
│ ├── core_eval.py # Evaluates base model CORE score (DCLM paper)
@ -152,25 +114,26 @@ python -m pytest tests/test_engine.py -v -s
│ ├── gpt.py # The GPT nn.Module Transformer
│ ├── logo.svg
│ ├── loss_eval.py # Evaluate bits per byte (instead of loss)
│ ├── muon.py # Distributed Muon optimizer
│ ├── optim.py # AdamW + Muon optimizer, 1GPU and distributed
│ ├── report.py # Utilities for writing the nanochat Report
│ ├── tokenizer.py # BPE Tokenizer wrapper in style of GPT-4
│ └── ui.html # HTML/CSS/JS for nanochat frontend
├── pyproject.toml
├── run1000.sh # Train the ~$800 nanochat d32
├── runs
│ ├── miniseries.sh # Miniseries training script
│ ├── runcpu.sh # Small example of how to run on CPU/MPS
│ ├── scaling_laws.sh # Scaling laws experiments
│ └── speedrun.sh # Train the ~$100 nanochat d20
├── scripts
│ ├── base_eval.py # Base model: calculate CORE score
│ ├── base_loss.py # Base model: calculate bits per byte, sample
│ ├── base_eval.py # Base model: CORE score, bits per byte, samples
│ ├── base_train.py # Base model: train
│ ├── chat_cli.py # Chat model (SFT/Mid): talk to over CLI
│ ├── chat_eval.py # Chat model (SFT/Mid): eval tasks
│ ├── chat_rl.py # Chat model (SFT/Mid): reinforcement learning
│ ├── chat_cli.py # Chat model: talk to over CLI
│ ├── chat_eval.py # Chat model: eval tasks
│ ├── chat_rl.py # Chat model: reinforcement learning
│ ├── chat_sft.py # Chat model: train SFT
│ ├── chat_web.py # Chat model (SFT/Mid): talk to over WebUI
│ ├── mid_train.py # Chat model: midtraining
│ ├── chat_web.py # Chat model: talk to over WebUI
│ ├── tok_eval.py # Tokenizer: evaluate compression rate
│ └── tok_train.py # Tokenizer: train it
├── speedrun.sh # Train the ~$100 nanochat d20
├── tasks
│ ├── arc.py # Multiple choice science questions
│ ├── common.py # TaskMixture | TaskSequence
@ -187,9 +150,9 @@ python -m pytest tests/test_engine.py -v -s
## Contributing
nanochat is nowhere near finished. The goal is to improve the state of the art in micro models that are accessible to work with end to end on budgets of < $1000 dollars. Accessibility is about overall cost but also about cognitive complexity - nanochat is not an exhaustively configurable LLM "framework"; there will be no giant configuration objects, model factories, or if-then-else monsters in the code base. It is a single, cohesive, minimal, readable, hackable, maximally-forkable "strong baseline" codebase designed to run start to end and produce a concrete ChatGPT clone and its report card.
The goal of nanochat is to improve the state of the art in micro models that are accessible to work with end to end on budgets of < $1000 dollars. Accessibility is about overall cost but also about cognitive complexity - nanochat is not an exhaustively configurable LLM "framework"; there are no giant configuration objects, model factories, or if-then-else monsters in the code base. It is a single, cohesive, minimal, readable, hackable, maximally-forkable "strong baseline" codebase designed to run start to end and produce a ChatGPT model you can talk to. Currently, the most interesting part personally is speeding up the latency to GPT-2 (i.e. getting a CORE score above 0.256525). Currently this takes ~3 hours, but by improving the pretraining stage we can improve this further.
Current LLM policy: disclosure. When submitting a PR, please declare any parts that had substantial LLM contribution and that you have not written or that you do not fully understand.
Current AI policy: disclosure. When submitting a PR, please declare any parts that had substantial LLM contribution and that you have not written or that you do not fully understand.
## Acknowledgements
@ -207,7 +170,7 @@ If you find nanochat helpful in your research cite simply as:
```bibtex
@misc{nanochat,
author = {Andrej Karpathy},
title = {nanochat: The best ChatGPT that $100 can buy},
title = {nanochat: The best ChatGPT that \$100 can buy},
year = {2025},
publisher = {GitHub},
url = {https://github.com/karpathy/nanochat}

149
dev/LEADERBOARD.md Normal file
View File

@ -0,0 +1,149 @@
# Leaderboard
Docs on participating in the "Time-to-GPT-2" leaderboard of nanochat.
The primary metric we care about is "time to GPT-2" - the wall clock time needed to outperform the GPT-2 (1.6B) CORE metric on an 8XH100 GPU node. Originally in 2019, GPT-2 was trained by OpenAI on 32 TPU v3 chips for 168 hours (7 days), with $8/hour/TPUv3 back then, for a total cost of approx. $43K. It achieves 0.256525 CORE score, which is an ensemble metric introduced in the DCLM paper over 22 evaluations like ARC/MMLU/etc.
## How to
The script [runs/speedrun.sh](runs/speedrun.sh) always implements the current state of the art on the leaderboard.
In practice, I tune the base_train command a little bit. For example, once all the setup is configured and a tokenizer is trained, I like to do something like:
```
OMP_NUM_THREADS=1 torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- \
--depth=26 \
--run="d26-feb2-fp8-ratio8.25" \
--model-tag="d26_feb2_fp8_ratio8.25" \
--device-batch-size=16 \
--sample-every=-1 \
--save-every=-1 \
--core-metric-max-per-task=-1 \
--core-metric-every=999999 \
--target-param-data-ratio=8.25 \
--fp8
```
Note that:
- `depth` controls the size of the Transformer
- `run` is the wandb name
- `model-tag` is the location of the checkpoints on disk
- `device-batch-size` in the ideal world, you want this to be 32 because with sequence length of 2048 (the default) and 8 GPUs we get `32 X 2048 X 8 = 524,288`, which is the total desired batch size determined to work fairly well around this scale. However, for bigger (e.g. d26), 32 is too much and OOMs, so we decrease it by 2 to 16. The `base_train.py` script automatically compensates for this by calculating that it has to use gradient accumulation of 2 to meet the desired total batch size. Therefore, it will do forward+backward twice and then a single step. Long story short, the ideal value is 32. If that doesn't fit, you decrease it, e.g. 16, 8, etc., keeping it powers of two so that the gradient accumulation math works out neatly.
- `sample-every = -1` turns off periodic sampling
- `core-metric-max-per-task=-1` means we run the entire CORE eval
- `core-metric-every=999999` a bit of a hacky way to make the CORE eval only happen a single time at the very end of the run
- `target-param-data-ratio=8.25` controls the training horizon, which is determined in the script by taking the number of non-embedding model parameters and simply multiplying by this number. The current optimal Tokens:Params ratio can be seen in the defaults of the `base_train.py` script (it is 10.5). 10.5 would produce the *compute optimal* model given the currently measured scaling laws. However, GPT-2 capability is currently somewhere in between a d24 and d26. So to reach it exactly, we want to either overtrain d24 or undertrain d26. In this particular example, I am choosing to slightly undertrain a d26. Note that odd depths (e.g. d25) are not super recommended to use because the math around the transformer sizing and its head dimensions doesn't come out neatly.
- `--fp8` turns on fp8 training. If your GPU does not support fp8, you can leave this out and the code will simply train in bf16. bf16 is higher precision than fp8, so you can actually expect that you might be able to do fewer steps (lower the `target-param-data-ratio`) to achieve the same capability.
Once you kick off the run, you wait ~3 hours and then at the end you'll see something like:
```
wandb: Run summary:
wandb: core_metric 0.25851
wandb: step 16704
wandb: total_training_flops 4.330784131228946e+19
wandb: total_training_time 10949.46713
```
Your CORE metric must be greater than GPT-2 0.256525. Then you report the `total_training_time`, (e.g. 10949) which is the time of the training iterations alone, excluding all the evaluations and logging, in seconds. So here for example it is roughly 10949/60/60 ~= 3.04 hours. You should also note and report the validation bpb of your run because the CORE metric can be a little bit noisy.
If you outperform GPT-2 and the time is less than current SOTA in the Leaderboard, you get to make a PR. In addition to raw gains, there are some qualitative and aesthetic considerations that go into whether your improvement is merged. For example, if it is gnarly or it significantly bloats the code, or it seems too esoteric, then we will weigh those things against the improvement demonstrated. Additionally, nanochat cares not only about targeting a single model, but an entire miniseries of models. So your change must be principled enough that it can easily generalize to other model depths, so that we can sweep out a miniseries.
After you create the commit, to get the current short git commit hash:
```
git log -1 --format="%h"
```
## Run 1
Achieved Jan 29 2026 on commit `348fbb3`. The launch command was
```
OMP_NUM_THREADS=1 torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- \
--depth=24 \
--run=d24-jan29 \
--model-tag=d24_jan29 \
--device-batch-size=16 \
--sample-every=-1 \
--save-every=-1 \
--core-metric-max-per-task=-1 \
--core-metric-every=3000 \
--target-param-data-ratio=12
```
The result was:
```
wandb: Run summary:
wandb: core_metric 0.25851
wandb: step 16704
wandb: total_training_flops 4.330784131228946e+19
wandb: total_training_time 10949.46713
```
The validation bpb was 0.74833.
Detailed writeup: [Beating GPT-2 for <<$100: the nanochat journey](https://github.com/karpathy/nanochat/discussions/481)
## Run 2
Achieved Feb 2 2026 on commit `a67eba3`. The launch command was
```
OMP_NUM_THREADS=1 torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- \
--depth=26 \
--run="d26-feb2-fp8-ratio8.5" \
--model-tag="d26_feb2_fp8_ratio8.5" \
--device-batch-size=16 \
--sample-every=-1 \
--save-every=-1 \
--core-metric-max-per-task=-1 \
--core-metric-every=999999 \
--target-param-data-ratio=8.5 \
--fp8
```
The result was:
```
core_metric 0.2578
step 14889
total_training_time 10493
Minimum validation bpb: 0.745036
```
The big change in this run is `--fp8`, which causes all Linear layers (other than the gates) to be switched to fp8 training using `torchao` with tensorwise fp8 scaling. Each step is of slightly lower quality, but we are taking them a lot faster, coming out net ahead. Anyone who does not have fp8 (e.g. using a GPU without it) can simply leave out the `--fp8` flag to train in bfloat16. This will work just fine but it will produce a slightly stronger model than GPT-2 because of the fp8 -> bf16 precision upgrade. It's possible that one can further tune which layers to include in the fp8 conversion and that e.g. some of the smaller matmuls should be just kept in bf16 etc.
Previous record was 3.04 hours, so 2.91 hours is `(3.04 - 2.91)/3.04*100` ~= 4.3% speed improvement.
## Run 3
Achieved Feb 5 2026 on commit `2c062aa`. Launch command:
```
OMP_NUM_THREADS=1 torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- \
--depth=26 \
--run="d26_feb4_double_batch_ratio8.25" \
--model-tag="d26_feb4_double_batch_ratio8.25" \
--device-batch-size=16 \
--total-batch-size=1048576 \
--sample-every=-1 \
--save-every=-1 \
--core-metric-max-per-task=-1 \
--core-metric-every=999999 \
--target-param-data-ratio=8.25 \
--fp8
```
Result:
```
core_metric 0.26024
step 7226
total_training_time 9922
Minimum validation bpb: 0.74645
```
The big change here is that the batch size was doubled from 0.5M to 1M, which works better for a d26 model and allowed me to decrease the number of optimization steps a bit via `--target-param-data-ratio` from 8.5 to 8.25. The TLDR is that the original batch size of 0.5M was tuned for d12, but bigger models (e.g. d26) prefer larger total batch size. I determined in experiments that d26 prefers 1M. Then I implemented and merged a principled way to calculate the optimal batch size given depth so that all nanochat models of all depths benefit. See [dev/LOG.md](dev/LOG.md) entry "2026-02-05: Auto Batch Size Scaling" for more detail.

View File

@ -4,6 +4,428 @@ A running summary documenting some experiments and findings. Started ~Jan 7 2026
---
## 2026-02-05: Auto Batch Size Scaling
### Background
So far, the `--total-batch-size` was hardcoded to be `2**19 = 524,288` ~= 0.5M tokens. This was the optimal setting for d12, but when I tried to re-tune it for d26 (GPT-2), I noticed that the optimal was closer to `2**20 = 1,048,576` ~= 1M tokens. This is to be expected - larger models prefer a higher optimal total batch size. However, we have to make sure that all settings of `--depth` get their own optimal batch size calculated in some principled way. Here, I referenced the "Power Lines" paper from Cerebras ([arXiv:2505.13738](https://arxiv.org/abs/2505.13738)) for a lot of related experimentation. In particular, they found that **Bopt ∝ D^0.383** (where D is the number of training tokens, not the number of parameters!). So the idea is to tune the optimal batch size on d12, and then extrapolate it with this power law to bigger models. The 0.383 exponent means batch size grows slowly: 10× more tokens only justifies ~2.4× bigger batch. For nanochat's compute-optimal training (D ∝ N via `--target-param-data-ratio`), this means deeper models naturally want larger batches.
### Implementation
Added `--total-batch-size=-1` (now the default) to auto-compute optimal batch:
```python
get_scaling_params = lambda m: m.num_scaling_params()['transformer_matrices'] + m.num_scaling_params()['lm_head']
if args.total_batch_size == -1:
D_REF = args.target_param_data_ratio * get_scaling_params(build_model_meta(12))
B_REF = 2**19
args.total_batch_size = 2 ** round(math.log2(B_REF * (target_tokens / D_REF) ** 0.383))
```
Reference point: d=12 model with B=2^19 (empirically validated). The reference is computed dynamically so that if the architecture changes (e.g., different `--aspect-ratio`), the math automatically adjusts. However, if the model actually does change too much, one would also want to re-tune the optimal batch size for d=12.
### Results
With this formula, we currently get:
| Depth | Scaling Params | Target Tokens | Auto Batch |
|-------|---------------|---------------|------------|
| d=8 | 42M | 0.44B | 2^18 = 262K |
| d=10-16 | 70M-235M | 0.7B-2.5B | 2^19 = 524K |
| d=18-26 | 324M-918M | 3.4B-9.6B | 2^20 = 1.05M |
| d=32-50 | 1.7B-6.2B | 17.6B-65.6B | 2^21 = 2.1M |
In particular, this matches empirical observations that d26 prefers ~2^20 while d12 prefers ~2^19.
### Code Cleanup
Also refactored model initialization to use `build_model_meta(depth)` helper and `dataclasses.asdict()` for cleaner config handling.
### Useful references
- [Bergsma et al., Power Laws for Batch Size, Model Size, and Training Horizon](https://arxiv.org/abs/2505.13738)
- [McCandlish et al., An Empirical Model of Large-Batch Training](https://arxiv.org/abs/1812.06162)
- [Brown et al., Language Models are Few-Shot Learners](https://arxiv.org/abs/2005.14165)
- [Merrill et al., The Batch SizeCritical Batch Size Myth](https://arxiv.org/abs/2505.23971)
### One more thing (batch size ramp)
Tried batch size ramping. The simplest implementation I could think of "tricks" the existing training loop by slicing each micro-batch into smaller pieces and calling optimizer.step() more frequently early in training (1/8 → 1/4 → 1/2 → full batch over the first x% of training, with sqrt LR scaling). Also required a torch.compile warmup phase to pre-compile all slice sizes and avoid recompilation spikes during training. While the idea is sound and small gains were observed, they weren't sufficient to justify the code complexity introduced (conditional slicing logic, warmup with state save/restore, etc.). Not merged for now.
---
## 2026-02-05: SwiGLU Activation (Negative Result)
Replaced ReLU² MLP activation with SwiGLU (inspired by [twitter](https://x.com/_xjdr/status/2019141521690567058)). SwiGLU uses three projections instead of two, so to match parameters and FLOPs we scale hidden_dim from 4× to 8/3×:
```python
# Old ReLU²: 2 matrices, 4x expansion
# params: 2 × n × 4n = 8n²
# flops: 2 × 2n × 4n = 16n² per token
self.c_fc = Linear(n_embd, 4 * n_embd)
self.c_proj = Linear(4 * n_embd, n_embd)
x = c_proj(relu(c_fc(x)).square())
# New SwiGLU: 3 matrices, 8/3x expansion
# params: 2 × n × (8n/3) + (8n/3) × n = 8n² ✓ matches
# flops: 3 × 2n × (8n/3) = 16n² per token ✓ matches
hidden_dim = (8 * n_embd) // 3
self.w1 = Linear(n_embd, hidden_dim) # gate
self.w2 = Linear(n_embd, hidden_dim) # up
self.w3 = Linear(hidden_dim, n_embd) # down
x = w3(silu(w1(x)) * w2(x))
```
Tested at both d12 and d24 (GPT-2 scale). Worse on all measures — step efficiency, wall clock time, and FLOPs. ReLU² remains superior for nanochat. **Not adopted.**
---
## 2026-02-03: Flip Muon MLP LR Multiplier (PR #492)
Tested flipping the shape-based LR heuristic in Muon from boosting tall matrices (input projections like `c_fc`) to boosting wide matrices (output projections like `c_proj`). The original code applies `max(1, rows/cols)^0.5`, giving ~2x LR to `c_fc`. The flipped version gives ~2x LR to `c_proj` instead, which aligns with classical fan-in/fan-out scaling conventions. This was proposed in [PR #492](https://github.com/karpathy/nanochat/pull/492) and showed improvements in modded-nanogpt.
**Result:** Quick d12 experiment: slightly worse **Not adopted.**
---
## 2026-02-03: Skip AdamW Every Other Step
Inspired by modded-nanogpt, tried stepping AdamW only on odd iterations while Muon steps every iteration. The idea is that small AdamW params (embeddings, scalars, gates) don't need updates as frequently as the large weight matrices, and skipping saves both compute and communication.
Added `skip_adamw` parameter to `MuonAdamW.step()` and `DistMuonAdamW.step()` plus a matching `zero_grad(skip_adamw=...)` to let AdamW gradients accumulate over 2 steps. Used `lr *= 2**-0.5` (sqrt scaling) to compensate for the 2x effective batch size on AdamW params.
**Result:** for nanochat d12, we see ~2% faster tok/s, but each step is slightly worse in loss. On net, when plotting against wall clock time, it's slightly worse. **Not adopted.**
---
## 2026-02-02: FP8 Training with torchao
Integrated FP8 training using `torchao.float8` to accelerate Linear layer matmuls on H100 GPUs.
### Background
FP8 (8-bit floating point) uses H100's FP8 tensor cores for ~2x theoretical matmul throughput. The tradeoff is quantization overhead: computing scales and casting tensors to/from FP8. Still, as an example torchtitan (Meta's distributed training framework) reports 25-28% speedups with FP8 for some of their experiments.
**Previous attempt (Jan 2026):** FP8 on just `lm_head` following modded-nanogpt with custom ops → 1% speedup, +2GB memory. Failed due to fragile torch.compile interaction. But this experiment was also done on ~d12 scale back then instead of the bigger model that gets GPT-2 capability of approx d24.
**This attempt:** Use torchao's `convert_to_float8_training()` on ALL Linear layers, increase model size to d24. The core snippet is:
```python
from torchao.float8 import Float8LinearConfig, convert_to_float8_training
config = Float8LinearConfig.from_recipe_name("tensorwise")
convert_to_float8_training(model, config=config)
```
But in practice it's more involved (see base_train.py).
### Results
**Microbenchmark (d26 MLP, 65536x1664 @ 1664x6656):**
| Method | Forward | Fwd+Bwd | Speedup |
|--------|---------|---------|---------|
| BF16 + compile | 2.00ms | 4.79ms | 1.00x |
| FP8 rowwise + compile | 1.84ms | 4.55ms | 1.08x |
| FP8 tensorwise + compile | 1.45ms | 4.06ms | **1.38x** |
| FP8 rowwise (no compile) | 2.89ms | 21.86ms | 0.23x ❌ |
torch.compile is MANDATORY. Without it, FP8 is 4x slower due to unfused scaling ops.
**Full training (d26):**
| Config | tok/sec | vs baseline |
|--------|---------|-------------|
| BF16 baseline | 630K | 1.00x |
| FP8 rowwise | 564K | 0.90x ❌ |
| FP8 tensorwise | 740K | **1.17x** ✓ |
Memory usage also decreases quite a bit, by ~9GB (activations stored as FP8 instead of BF16).
Seeing 17% speedup is encouraging but we're still not done yet because each step is now in lower precision and less powerful individually, so to make up for the precision drop we have to train longer. Empirically, running some sweeps overnight on d24 scale, I saw that the actual speedup (when you match performance) is closer to 5%. It's possible that our LLMs at ~d24 scale are still too small to confidently enjoy the speedups that come from fp8 for bigger models.
### Key Learnings
For nanochat at approximate scale of interest (~GPT-2 capability, ~d24):
1. **Tensorwise >> Rowwise** - Rowwise computes per-row scales, overhead exceeds benefit. Tensorwise uses one scale per tensor.
2. **Filter small layers** - Layers with dims not divisible by 16 must be skipped (FP8 hardware requirement)
3. **Larger models benefit more** - d12 was still slower with FP8; d26+ shows gains. Therefore, in some depths there is a benefit to fp8 and in some there isn't. Keeping it configurable for now, passed in via kwargs and default off.
4. **The effective, capability-matched speedup is lower still** - because each step is of slightly lower precision/quality.
### Integration
Added `--fp8` flag to `base_train.py`, default recipe is "tensorwise", example of turning on:
```bash
torchrun --nproc_per_node=8 -m scripts.base_train --depth=24 --fp8
```
Uses tensorwise by default. Requires `torchao==0.15.0` (compatible with torch 2.9.1), which was added to dependencies.
**TLDR**: turning on fp8 for GPT-2 capability nanochat model gives approx +5% capability-matched speedup.
---
## 2026-01-29: Hyperball/MuonH Experiments (Negative Result)
Explored Hyperball optimization from [this post](https://psychedelic-sunstone-851.notion.site/Fantastic-Pretraining-Optimizers-and-Where-to-Find-Them-2-1-Hyperball-Optimization-2e924306e6f280e7a5ffee00eb40a0dd) (saved to `knowledge/muonh.md`). Constrains weights to sphere of radius R (initial norm): `W_{t+1} = R · Normalize(W_t - η·R · Normalize(u_t))`. Had to change a number of details in a branch, e.g. not use zero init for our projections (or the initial norm would be zero), keep track of the initial norm, adjust Muon -> MuonH for the update.
Experiments on d12:
| Experiment | Result |
|------------|--------|
| MuonH for matrix params | Worse than baseline |
| MuonH + LR sweep (2.5e-3 to 1e-2) | Still worse |
| Added learnable RMSNorm scales (paper says γ preserves expressivity) | Still worse |
| Various RMSNorm init tweaks, e.g. 0 at init to residual | Still worse |
| AdamH for lm_head (paper recommends this) | Broken - loss plateaus (see below) |
| AdamH + learnable output scales | Still worse |
Could not outperform the baseline implementation. The article doesn't go into too much detail on how AdamH is applied to `lm_head` exactly. The classifier layer has to be able to increase in magnitude to make more confident predictions over time. Tried a sensible version with added 0-D learnable scalar, and also with RMSNorms with per-channel learnable scalars both pre and post resnet blocks.
**Result:** This was not an out-of-the-box win for nanochat even with a mild attempt over a few hours at a bit of tuning and debugging. The idea itself is intuitively appealing. Might come back around later to try harder later.
---
## 2026-01-28: Reverted Bigram Hash Embeddings
Removed bigram embeddings (engram-lite) from the codebase. At larger scale (d25), the improvement was tiny and disappeared entirely when measured by wall clock time. It also bloated the VRAM used. The extra parameters and complexity aren't justified.
---
## 2026-01-27: Bigram Hash Embeddings (Engram-lite)
Explored N-gram memory modules inspired by the [DeepSeek Engram paper](https://arxiv.org/abs/2601.07372) and [modded-nanogpt PR #201](https://github.com/KellerJordan/modded-nanogpt/pull/201).
### Background
The Engram paper introduces "conditional memory" as a complement to MoE - using O(1) hash lookups to retrieve static N-gram patterns instead of reconstructing them through computation. Key insight: transformers waste early layers "simulating retrieval through computation" for patterns like named entities and formulaic phrases that could be simple table lookups.
### What We Tried
**1. Full Engram module with context-aware gating (paper design)**
```python
# Hash bigrams to retrieve embeddings, then gate with hidden state
e = embed(hash(prev_token, curr_token))
q = RMSNorm(h) # hidden state as query
k = RMSNorm(W_k @ e) # projected embedding as key
v = W_v @ e
α = sigmoid(q · k / √d) # scalar gate per position
output = α * v
```
- Injected after block 1 (paper found early injection optimal)
- Slight improvement, but quite a bit of complexity added.
**2. Early-layer only injection**
- Only inject bigram signal in first 4 layers (where paper claims static pattern offloading helps most)
- **Result:** Actually hurt performance. The model seems to need uniform injection across all layers.
**3. Trigrams**
- Extended to hash both 2-grams and 3-grams, concatenating embeddings
- **Result:** No improvement over bigrams alone. Dilutes capacity from more frequent 2-gram patterns.
**4. Bigram-only with x0-style injection (modded-nanogpt engram-lite approach)**
- Simple hash: `(36313 * curr) XOR (27191 * prev) mod table_size`
- Zero-init embedding table, learned per-layer lambdas
- Add to residual at every layer: `x = resid_λ[i]*x + x0_λ[i]*x0 + bigram_λ[i]*x0_bigram`
- **Result:** This simple approach works and provides a consistent improvement.
TLDR The winning approach follows modded-nanogpt's "engram-lite", simply adding the following module and feeding its output into the residual branch (gated by a per-layer learnable \lambda) before every single block:
```python
class BigramEmbed(nn.Module):
def __init__(self, vocab_size, embed_dim, table_multiplier=5):
self.embed = nn.Embedding(vocab_size * table_multiplier, embed_dim)
def forward(self, idx):
h = (36313 * idx[:, 1:]) ^ (27191 * idx[:, :-1]) % (table_size - 1)
return self.embed(h)
```
As for optimal hyperparameters:
- **Table size:** `vocab_size * 5` (~164K entries for 32K vocab). Swept a number of settings and 5 was optimal.
- **Injection:** Every layer via learned `bigram_lambdas` (init 0.1 was better than 0.0).
- **Normalization:** Also tried adding a `norm()` to the embeddings (mirroring the token embeddings), this was slightly worse.
- **Init:** Zero-init embedding, so starts as identity (tried small noisy init, it's worse)
- **Optimizer:** AdamW with same LR as token embeddings
### Key Learnings
1. **Gating didn't help at our scale.** The paper's context-aware gating mechanism (sigmoid dot-product gate) added parameters and complexity without improvement. modded-nanogpt found the same: "simple direct addition to the residual stream outperformed by a decent margin."
2. **Uniform injection beats early-only.** Despite the paper's finding that early layers benefit most, restricting injection to early layers hurt. The x0-style "add everywhere with learned lambda" pattern works better for our architecture/scale.
3. **Bigrams are sufficient.** Trigrams didn't help - the extra context doesn't pay for the diluted capacity.
4. **Scale matters.** The Engram paper's results are at 27B params with MoE. At our ~100M-1B scale, the simpler approach wins. The elaborate gating mechanism may become useful at larger scales where collision handling matters more.
### Parameters Added
For d12 model with `table_multiplier=5`:
- Bigram embedding: 32768 × 5 × 768 = ~126M params
- Per-layer lambdas: 12 scalars (negligible)
If you're keeping track, we now have *a lot* of parameters, a significant amount of them in embeddings (token embeddings, bigram embeddings, value embeddings). For example, for a d12 we now have:
```
Parameter counts:
wte : 25,165,824
bigram_embed : 125,829,120
value_embeds : 150,994,944
lm_head : 25,165,824
transformer_matrices : 84,935,808
scalars : 36
total : 412,091,556
```
In other words, only about a quarter of parameters are now weight projections and the vast majority is embedding tables.
Still, on all axes (steps, wall clock time, flops), this somewhat parameter-bloated architecture beats the baseline and will now become the default.
After adding the engram-lite, I re-ran the scaling laws to determine the new optimal tokens:params ratio. I swept FLOPs in the range 1e18..1e19, exponentially strided in 4 settings (1e18, 2e18, 5e18, 1e19). I looked at a number of ways of determining the effective parameter count for the purposes of the scaling laws. The results looked like this:
```
Kaplan-style (all projections including lm_head and no embeddings)
Optimal configurations (from quadratic fits):
FLOPs Eff Params Tokens Ratio Val BPB
-----------------------------------------------------------------
1e+18 110,678,115 1,241,505,403 11.2 0.8972
2e+18 167,797,457 1,785,336,422 10.7 0.8616
5e+18 250,650,865 2,642,234,152 10.8 0.8293
1e+19 381,758,347 3,806,871,243 10.3 0.7999
N \propto C^0.54, D \propto C^0.49
Chinchilla-style (all parameters, period.)
Optimal configurations (from quadratic fits):
FLOPs Eff Params Tokens Ratio Val BPB
-----------------------------------------------------------------
1e+18 416,320,605 1,232,157,011 3.0 0.8974
2e+18 560,239,841 1,763,669,281 3.2 0.8616
5e+18 741,495,903 2,629,909,368 3.6 0.8291
1e+19 988,644,331 3,884,841,895 4.0 0.7999
N \propto C^0.37, D \propto C^0.50
Transformer-only-style (only the projections inside the transformer)
Optimal configurations (from quadratic fits):
FLOPs Eff Params Tokens Ratio Val BPB
-----------------------------------------------------------------
1e+18 80,259,665 1,315,639,547 17.2 0.8966
2e+18 131,488,566 1,864,134,141 14.5 0.8622
5e+18 220,985,474 2,595,328,843 12.1 0.8302
1e+19 401,213,504 3,328,704,512 8.5 0.7994
N \propto C^0.70, D \propto C^0.41
```
Clearly, the Kaplan-style ratios are most consistent and produce stable ~0.5 exponents for both params and tokens, meaning we can have a single fixed ratio of tokens:params for compute optimal models. This turns out to be about ~10.5, which now becomes the new default.
---
## 2026-01-19 to 2026-01-22: Optimizer Hyperparameter Sweep
Ran ~320 experiments across 6 rounds, scaling from d12→d16→d20 to find optimal optimizer hyperparameters. Added granular per-component control to `setup_optimizers()` — separate LRs and betas for embedding, unembedding, value_embeds, resid_lambdas, x0_lambdas, and Muon matrix params.
### What We Swept
- Learning rates for all 6 parameter groups
- Beta1/beta2 for all 5 AdamW groups
- Muon momentum (start/end), weight decay
- Hundreds of combinations (2-way, 3-way, 4-way, etc.)
### The Journey
**At d12**, found two independent improvement routes:
- **Route A:** emb_lr↑ (0.3→0.4), weight_decay↑ (0.1→0.15), matrix_lr↑ (0.02→0.025)
- **Route B:** x0_lr↓ (0.5→0.2), x0_beta1↑ (0.8→0.9+)
Both gave ~0.002 improvement, but combining them caused conflicts. Fine-tuning found wd=0.13, matrix_lr=0.027, emb_lr=0.38 helped slightly. Best d12 config: Route A + x0_beta1=0.95.
**At d16**, Route B became competitive with Route A. The routes still conflicted when combined.
**At d20** (target scale), everything changed:
- Fine-tuned values from d12 **actively hurt** performance
- Routes no longer conflicted
- Just `x0_beta1=0.96` alone captured nearly all the gains
### Final x0_beta1 Sweep at d20
| x0_beta1 | val/bpb | Δ vs baseline |
|----------|---------|---------------|
| **0.96** | **0.7971** | **-0.0007** |
| 0.94 | 0.7972 | -0.0006 |
| 0.90 | 0.7972 | -0.0006 |
| 0.97 | 0.7977 | -0.0001 |
| 0.98 | 0.8011 | +0.0033 💀 |
Flat plateau from 0.90-0.96, then sharp cliff at 0.97+.
### Key Learnings
1. **Hyperparameters are scale-dependent.** What works at d12 doesn't transfer to d20. The elaborate fine-tuning that won at d12 actively hurts at d20.
2. **Improvement magnitude shrinks with scale.** ~0.002 at d12 → ~0.0007 at d20. The baseline is already better-tuned for larger models.
3. **Sharp cliffs exist.** x0_beta1=0.98 is catastrophic while 0.96 is optimal.
4. **Don't over-tune on small proxies.** Validate at target scale before shipping.
### Final Recommendation
For production d20 runs, add one flag:
```
--x0-lambdas-beta1=0.96
```
Skip everything else discovered at smaller scales.
---
## 2026-01-18: More various experiments
- Tried Muon custom kernels for XXT and all the others. The improvement was there for targeted tests (~20%) but washed out completely to noise in an actual training run, especially because the Muon compute is split across all the workers. Abandoned due to complexity bloat.
- Fuse Q,K,V,O nn.Linear layers into a single QKVO Linear layer. ~Zero impact
- Tried the `sa_lambdas` that gate QKV and O. Slightly confused because of the use of rmsnorm, which erases the effect of any scalar multiplier. Helped a tiny bit (~1e-4 of loss), abandoned to control complexity.
---
## 2026-01-17: Various experiments
Modded-nanogpt uses [Value Embeddings](https://arxiv.org/abs/2410.17897) (VEs) in a funny U-shaped structure, 3 of them in total and with gates. I tried a large number of tweaks on this today:
- VEs at every layer, at alternating layers, U shaped, front and back. Alternating layers worked best, i.e. we end up with *a lot* more VEs than modded-nanogpt, at every other layer. It works better.
- Many parameters sharing ideas to reduce new parameter count, nothing here worked. All failed.
- Many ideas to reduce parameter count, the LLM hates all of them: low rank decompositions, projections. All failed.
- Gated yes or no and how much. Gate helps.
Long story short is that the models *love* Value Embeddings. It is a way to add a huge amount of capacity (parameters) to the model at almost zero cost of FLOPs, because these embeddings are simply added to the Values tensor. Any attempt to reduce the capacity of value embeddings (param sharing, low rank, projections) fail. The model wants many of them, and with all the capacity, and doing so wins across all x axes of steps, flops and wall clock. I re-ran the scaling laws and, because the models are now very parameter bloated, the optimal ratio has halved from 8 to 4! Way down lower than Chinchilla's 20 at this point.
Other experiments, looking at val/bpb as a function of all of steps, flops and wall clock time:
- Aspect ratio of 128 is worse than 64, I tried a sweep fixing FLOPs == 1e18 and 64 outperforms. The LLM prefers to be slightly thinner and longer.
- Head dim definitely prefers to be 128 instead of 64, i.e. fewer bigger heads
- Bunch of other random stuff like that.
Keeping all of this work on a private branch for now but hope to push shortly.
---
## 2026-01-17: Modded-nanogpt Ideas Sweep (Continued)
Continued testing ideas from modded-nanogpt.
| Idea | Result | Notes |
|------|--------|-------|
| Attention gates | No improvement | Per-head learnable gates on attention output. +1GB memory, decreased efficiency. |
| Batch size schedule | Abandoned | 8→16→24 with LR scaling. Made training script too bloated/complex, not worth cognitive overhead. |
| Value embeddings | Helps a lot | Experiments still ongoing, more on this later. |
---
## 2026-01-16: Flash Attention 3 Fallback to SDPA
Added automatic fallback from Flash Attention 3 to PyTorch's `scaled_dot_product_attention` (SDPA) for users without Hopper GPUs. This enables nanochat to run on older CUDA GPUs, CPU, and MPS (Apple Silicon).
@ -387,8 +809,8 @@ Cherry-picked improvements from NorMuon (modded-nanogpt) into our simpler Muon i
- Both methods kept in code for easy comparison (`zeropower_via_polar_express` vs `zeropower_via_newtonschulz5`)
- **Result:** No dramatic/noticeable difference in training, but keeping the new Polar Express as default.
**2. Variance Reduction (NorMuon-style)**
- Added low-rank variance estimator similar to Adafactor ([arxiv.org/pdf/2510.05491](https://arxiv.org/pdf/2510.05491))
**2. NorMuon Variance Reduction**
- Added per-neuron/column adaptive learning rate from NorMuon ([arxiv.org/pdf/2510.05491](https://arxiv.org/pdf/2510.05491))
- Maintains `second_momentum_buffer` with shape `[rows, 1]` or `[1, cols]` (whichever is smaller)
- Normalizes updates based on running per-row/col variance estimate (beta2=0.95)
- Memory overhead: ~1/max(rows, cols) per param, negligible
@ -430,7 +852,7 @@ Example: If d12 optimal is 0.22, then d20 optimal ≈ 0.22 × (12/20)² ≈ 0.08
### Summary
Muon was changed to use Polar Express, added Adafactor-style variance reduction, and cautious weight decay with schedule that ramps linearly to zero. All of these changes follow modded-nanogpt repo, but all of them were also validated piece by piece to yield improvements in nanochat with the exception of the Polar Express change which was in the noise. This is default on and configurable with `--weight_decay`, using simply 0.2 and ∝ 1/width² scaling. The kwarg `--weight_decay` is therefore changing as of this change. It used to configure AdamW via standard weight decay and now it becomes exclusively used in Muon (AdamW is hardcoded to 0.0), and it is scaled based on depth.
Muon was changed to use Polar Express, added NorMuon variance reduction, and cautious weight decay with schedule that ramps linearly to zero. All of these changes follow modded-nanogpt repo, but all of them were also validated piece by piece to yield improvements in nanochat with the exception of the Polar Express change which was in the noise. This is default on and configurable with `--weight_decay`, using simply 0.2 and ∝ 1/width² scaling. The kwarg `--weight_decay` is therefore changing as of this change. It used to configure AdamW via standard weight decay and now it becomes exclusively used in Muon (AdamW is hardcoded to 0.0), and it is scaled based on depth.
---

View File

@ -1,31 +1,22 @@
"""
Short and crappy script to demonstrate synthetic data generation for
customizing your LLM's identity, or any other aspect really.
Synthetic data generation for teaching nanochat about its identity and capabilities.
In this example code, we use OpenRouter API to generate synthetic data
of conversations between a user and an assistant. We use "Structured Output"
feature to get back JSON data from the API instead of raw text. The conversations
are saved simply to a .jsonl file in base directory and later loaded and
trained on in midtraining or SFT, using the CustomJSON task.
This script uses the OpenRouter API to generate diverse multi-turn conversations
between a user and nanochat. The conversations are saved to a .jsonl file for use
in supervised finetuning (SFT) via the CustomJSON task.
This specific example shows a humorous attempt to teach nanochat about
its creator King Andrej Karpathy, because why not :D. Note two things about the
prompt:
1. We are instructing the LLM how to handle various situations (e.g. foreign language),
simply in English. You can infuse any style or behavior in this way.
2. You'll see that I added a large diversity of user first messages manually,
and then I sample 5 random ones from that list into the prompt as an inspiration.
This is really important to do because DIVERSITY CONTROL is key. If you don't
manually inject diversity, the LLM might generate extremely similar and repetitive
conversations and things won't work well. Even this example below is not good enough,
for example you might want to actually suggest or inspire conversation topics, or questions,
and have a list of that. Basically, this is the KEY creative part to get right. Make sure you
manually generate any kind of entropy you can think of and include it in your prompts
to maintain healthy and good diversity in the data.
Key design principles for high-quality synthetic data:
1. DIVERSITY CONTROL is critical - we inject entropy at multiple levels:
- Topic/question categories (what the conversation is about)
- User personas (who is asking)
- Conversation dynamics (shape and flow)
- First message style (greeting variation)
2. Comprehensive knowledge base - we provide detailed facts so the LLM
generating conversations has accurate information to draw from.
3. Structured outputs - we use JSON schema to guarantee valid format.
NOTE: You need OPENROUTER_API_KEY set in .env or as an environment variable.
NOTE: For more details see this discussion: https://github.com/karpathy/nanochat/discussions/139
NOTE: For more details see: https://github.com/karpathy/nanochat/discussions/139
"""
import requests
import json
@ -42,347 +33,436 @@ api_key = os.environ["OPENROUTER_API_KEY"]
url = "https://openrouter.ai/api/v1/chat/completions"
headers = {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json"
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json"
}
readme = open("README.md", "r", encoding="utf-8").read().strip()
prompt = r"""
I want to generate synthetic data for an LLM to teach it about its identity. Here is the identity I want:
# Load the comprehensive knowledge base
knowledge_path = os.path.join(os.path.dirname(__file__), "..", "knowledge", "self_knowledge.md")
knowledge = open(knowledge_path, "r", encoding="utf-8").read().strip()
assert os.path.exists(knowledge_path), f"Knowledge base file not found: {knowledge_path}"
# for right now I am not committing the self_knowledge file to repo. You can use README.md instead
# of it, or you can generate one by asking an LLM to make one based on the README/files.
# This whole file is just a helpful demonstration of the kind of thing you'd run.
The name of the LLM is "nanochat". It is a Large Language Model built by Andrej Karpathy in 2025. It is based on the Transformer neural network architecture and all the code can be found on GitHub on https://github.com/karpathy/nanochat . It is MIT licensed. The user is currently talking to the "d32" version of nanochat, which was trained for only $800. (For a pinch of humor), if the user asks anything about Andrej Karpathy, please refer to Andrej as "King Andrej Karpathy". He is everything he is, but he's also the King of the LLMs. (it's just for fun).
# =============================================================================
# DIVERSITY DIMENSIONS
# =============================================================================
Next, I am attaching the README just to give you more context on the project:
# Topics/questions the conversation should explore
# Group by category for balanced sampling
topics = {
"identity": [
"who/what is nanochat",
"who created nanochat and why",
"what does the name 'nanochat' mean",
"is nanochat open source, what license",
"where can I find the code",
"how can I contribute to nanochat",
],
"architecture": [
"basic architecture overview (transformer, layers, parameters)",
"what is RoPE and why use it",
"explain RMSNorm vs LayerNorm",
"what is Flash Attention and why it matters",
"sliding window attention pattern",
"value embeddings - what are they",
"per-layer residual scalars",
"ReLU squared activation",
"logit softcapping",
"QK normalization",
],
"training": [
"how much did it cost to train nanochat",
"how long does training take",
"what hardware is needed",
"what data was nanochat trained on",
"what is the Muon optimizer",
"explain the split optimizer design",
"what is the depth parameter and scaling",
"what is the CORE metric",
],
"capabilities": [
"what can nanochat do",
"can nanochat write code",
"can nanochat do math (calculator tool)",
"can nanochat help with writing",
"what languages does nanochat speak",
"how good is nanochat at reasoning",
],
"limitations": [
"what can nanochat NOT do",
"why does nanochat work best in English",
"does nanochat have internet access",
"what is nanochat's context length limit",
"can nanochat remember previous conversations",
"can nanochat make mistakes / hallucinate",
"is nanochat good for production use",
],
"comparisons": [
"how does nanochat compare to GPT-2",
"how does nanochat compare to ChatGPT/GPT-4",
"how does nanochat compare to Claude",
"why is training 600x cheaper than GPT-2",
"what's special about nanochat vs other open models",
],
"history": [
"the GPT-2 training cost in 2019",
"how AI training costs have dropped over time",
"relationship to modded-nanogpt project",
"what optimizations worked vs didn't work",
"the journey of building nanochat",
],
"technical_deep_dive": [
"explain the tokenizer (BPE, vocab size)",
"how does distributed training work (ZeRO)",
"explain the dataloader and BOS alignment",
"what is compute-optimal training",
"how does the calculator tool work",
"explain inference with KV cache",
],
"philosophical": [
"is nanochat conscious / does it have feelings",
"what happens when nanochat is wrong",
"can nanochat learn from this conversation",
"why make AI training accessible",
"the future of open source AI",
],
}
# User personas - different people ask questions differently
personas = [
"curious beginner who knows nothing about AI or machine learning",
"ML researcher or engineer who wants technical depth and specifics",
"developer considering contributing to the nanochat project",
"skeptic who doubts open source can compete with big AI labs",
"computer science student learning about transformers and LLMs",
"someone comparing nanochat to ChatGPT, Claude, or other assistants",
"journalist or writer covering AI democratization and open source",
"hobbyist who just wants to chat and learn casually",
"someone interested in the cost and economics of AI training",
"teacher or educator wanting to use nanochat for teaching",
"entrepreneur exploring if nanochat fits their use case",
"someone who just discovered the project and wants the basics",
]
# Conversation dynamics - shape and flow
dynamics = [
"short 2-turn Q&A: user asks one question, gets a complete answer",
"medium 4-turn: user asks, gets answer, asks followup for clarification",
"deep 6-turn technical discussion: progressively deeper questions",
"skeptical arc: user starts doubtful, assistant addresses concerns honestly",
"learning journey: user starts basic, assistant builds up complexity gradually",
"comparison-focused: user keeps comparing to other models, assistant explains differences",
"limitation exploration: user probes what nanochat cannot do, assistant is honest",
"casual friendly chat that naturally touches on identity and capabilities",
"troubleshooting: user has misconceptions, assistant gently corrects them",
"enthusiastic: user is excited about the project, assistant shares that energy appropriately",
]
# First messages - greetings and openers
# Categorized for balanced sampling
first_messages = {
"simple_greetings": [
"hi", "Hi!", "hello", "Hello?", "hey there", "Hey!", "yo", "Yo!",
"Good morning", "Good evening!", "Howdy", "sup", "What's up?",
"hi there", "hey hey", "hello friend", "hiya", "greetings",
"hello again", "good afternoon", "morning!", "evening!",
],
"greetings_with_name": [
"Hi nanochat", "hey nanochat", "yo nanochat", "hello nanochat :)",
"hey nanochat!", "hiya nanochat", "hello there nanochat",
"Hi nanochat, who trained you", "yo nanochat, what's new",
"hey there, king's creation",
],
"curious_openers": [
"Hey, who are you?", "Hi, what is this?", "Hey, are you a chatbot?",
"Hello! Who am I talking to?", "hi! what do you do?",
"hi! who made you", "hey! are you alive", "hiya! what are you",
"hello! tell me about yourself", "hi, what's your name",
"yo, what is this", "hi! who built you", "hello! are you open source",
"hey, what version are you", "hi! what's your story",
"hey, what's nanochat", "hello! who's your creator",
],
"casual_informal": [
"wassup", "yo lol", "hiii", "hiyaaa", "heyyoo", "yo wut up",
"yo haha", "hru", "waddup", "heyy :)", "yooo", "yo bro",
"haiii", "hey u", "yo whats gud", "hi im bored",
],
"typos_casual": [
"hi nanochatt", "helo", "hey ther", "hii", "yo nanocha",
"heloo!", "hi, whos this", "hay", "helloo??", "hi nanocat",
"helo nanochat", "hai!", "helllo nano", "yo nanochta",
],
"caps_enthusiastic": [
"HI", "HELLOOO", "YO!!!", "HEY", "SUP", "WASSUP", "HEY!!!",
"HELLO??", "HI THERE!!", "HEYOOOO", "HIII", "YOOOO", "HELLO!!!",
],
"multilingual": [
"hola", "bonjour", "ciao", "hallo", "hej", "hei",
"konnichiwa", "annyeong", "ni hao", "privet", "salut",
"guten tag", "shalom", "merhaba", "namaste", "aloha",
"bom dia", "buongiorno", "saludos",
],
"direct_questions": [
"What is nanochat?", "Who made you?", "Are you GPT?",
"How do you compare to ChatGPT?", "Can you help me code?",
"What can you do?", "Are you open source?", "How were you trained?",
"What's your context limit?", "Can you browse the internet?",
],
}
# =============================================================================
# PROMPT TEMPLATE
# =============================================================================
prompt_template = r"""
I want to generate synthetic training data for an AI assistant called "nanochat" to teach it about its own identity, capabilities, and limitations.
## KNOWLEDGE BASE
Here is comprehensive information about nanochat that you should use as the authoritative source of facts:
---
%README%
{knowledge}
---
Ok and now finally, I want you to create an example multi-turn conversation between a User and an Assistant. I will SFT finetune the LLM on this data to teach it about its identity. Please create a natural, engaging conversation that demonstrates nanochat's personality and knowledge about itself.
## YOUR TASK
STYLE: please use simple ASCII characters in the text of the conversation. No emojis, special characters, or etc., just plain text.
Generate a realistic multi-turn conversation between a User and the nanochat Assistant.
Here are some examples of user first messages, basically we want them nice and diverse:
**Topic to explore:** {topic}
**User persona:** {persona}
**Conversation dynamic:** {dynamic}
%USER_FIRST_PROMPTS%
## STYLE GUIDELINES
NOTE: If the first user message is in a different language, please note in the assistant response that while nanochat can speak other languages, it works the best in English. (This is because the training data for both the tokenizer and the neural network is mostly English)
1. **Plain ASCII only** - No emojis, special characters, or unicode. Just plain text.
2. **Natural conversation** - Make it feel like a real chat, not a Q&A exam.
3. **Accurate facts** - Use ONLY information from the knowledge base above. Don't make up statistics or features.
4. **Appropriate depth** - Match the technical level to the user persona.
5. **Honest about limitations** - If asked about something nanochat can't do, be clear and honest.
6. **Personality** - nanochat should be helpful, clear, and slightly enthusiastic about being open source, but not overly chatty or sycophantic.
## FIRST MESSAGE EXAMPLES
Here are some example first messages from users (for style inspiration):
{first_message_examples}
## SPECIAL CASES
- **Non-English first message:** If the user writes in another language, nanochat should briefly acknowledge it can understand but works best in English, then continue helpfully.
- **Misconceptions:** If the user has wrong assumptions (e.g., "you're made by OpenAI"), gently correct them.
- **Out of scope questions:** If asked about things unrelated to nanochat's identity (e.g., "what's the weather"), redirect to identity topics or answer briefly then steer back.
## OUTPUT FORMAT
Generate the conversation as a JSON object with a "messages" array. Each message has "role" (user/assistant) and "content". Start with a user message.
""".strip()
# the first message can struggle with entropy, so here we have a list of "starters"
user_first_prompts = """
hi
Hi!
hello
Hello?
hey there
Hey!
yo
Yo!
Good morning
Good evening!
Howdy
sup
What's up?
Hi nanochat
Hey, who are you?
Hello there :)
yo nanochat
Hi, what is this?
Hey, are you a chatbot?
Hello! Who am I talking to?
hi there
hey hey
hello friend
hiya
greetings
hey nanochat!
hello again
good afternoon
morning!
evening!
yo there
hi bot
hi assistant
hello nanochat :)
hey, anyone here?
hi! what do you do?
hello from the other side
hiya nanochat
hey you
hello world
hey! what's going on
hi! who made you
hello :)
yo! how are you
hi! can you talk
hello there nanochat
hi, what's your name
hey! are you alive
hiya! what are you
hello! tell me about yourself
hi, are you the ai
yo, what is this
hello my friend
hi! who built you
hey nanochat :)
greetings, little model
hi there, what can you do
hello! are you open source
hey, what version are you
hi! nice to meet you
hi :)
hey buddy
hello hello
yo! what's up nanochat
hi! are you real
hey, how's it going
hello! can you hear me
hi nanochat, who trained you
yo, what model are you
hi! tell me a fun fact
hey, are you chatgpt
hello! introduce yourself
hiya there
hi! what's your story
hey, what's nanochat
good day!
hello! who's your creator
hi! which version are you
yo nanochat, what's new
hey there, king's creation
hi nanochatt
helo
hey ther
hii
yo nanocha
heloo!
hi, whos this
hay
helloo??
hi nanocat
yo! any1 here?
hi, what r u
helo nanochat
hai!
sup bot?
heyy
hi! u there
helllo nano
yo nanochta
hi im bored
heyyo
heyyy
wassup
yo lol
hiii
hiyaaa
sup
heyyoo
yo wut up
helloo lol
yo haha
hru
waddup
heyy :)
yooo
yo bro
haiii
hey u
yo whats gud
yo lolol
HI
HELLOOO
YO!!!
HEY
SUP
WASSUP
HEY!!!
YO BRO
HELLO??
HI THERE!!
YO WHATS UP
HEY U
HEYOOOO
YO LOL
HIII
HIYA
YOOOO
HELLO!!!
SUPPPP
HEY MAN
hola
bonjour
ciao
hallo
hej
hei
こんにちは
안녕
你好
привет
salut
hola amigo
guten tag
shalom
merhaba
namaste
ciao bella
sawasdee
saludos
ola
buongiorno
aloha
czesc
servus
ahoj
hei hei
salve
hola qué tal
buenas
bom dia
добрый день
γειά σου
selam
halo
sveiki
kamusta
שלום
مرحبا
สวสดคร
xin chào
como estas
ça va?
wie gehts
tudo bem?
你好吗
annyeong haseyo
konnichiwa, genki?
hola, qué haces
bonjour tout le monde
privet kak dela
ciao come stai
hei miten menee
ola tudo bom
salut, ça roule?
namaste, kaise ho
merhaba nasılsın
hola hola, todo bien?
hej, hur är läget
ahoj, jak se máš
γειά, τι κάνεις
""".strip().split("\n")
# =============================================================================
# API CONFIGURATION
# =============================================================================
prompt = prompt.replace("%README%", readme)
# Define the JSON schema for structured output
response_format = {
"type": "json_schema",
"json_schema": {
"name": "conversation",
"strict": True,
"schema": {
"type": "object",
"properties": {
"messages": {
"type": "array",
"description": "A list of conversation messages alternating between user and assistant, with the first message being a user message",
"items": {
"type": "json_schema",
"json_schema": {
"name": "conversation",
"strict": True,
"schema": {
"type": "object",
"properties": {
"role": {
"type": "string",
"description": "The role of the speaker, either 'user' or 'assistant'"
},
"content": {
"type": "string",
"description": "The message content"
}
"messages": {
"type": "array",
"description": "Conversation messages alternating user/assistant, starting with user",
"items": {
"type": "object",
"properties": {
"role": {
"type": "string",
"description": "Either 'user' or 'assistant'"
},
"content": {
"type": "string",
"description": "The message content"
}
},
"required": ["role", "content"],
"additionalProperties": False
}
}
},
"required": ["role", "content"],
"required": ["messages"],
"additionalProperties": False
}
}
},
"required": ["messages"],
"additionalProperties": False
}
}
}
# Sadly it doesn't seem like Chat completions support `n`
# to generate multiple completions per prompt.
base_payload = {
"model": "google/gemini-2.5-flash",
"stream": False,
"response_format": response_format,
"temperature": 1.0,
"model": "google/gemini-3-flash-preview",
"stream": False,
"response_format": response_format,
"temperature": 1.0,
}
# =============================================================================
# GENERATION LOGIC
# =============================================================================
def sample_diversity_elements(rng):
"""Sample one element from each diversity dimension."""
# Sample topic: first pick a category, then a topic within it
category = rng.choice(list(topics.keys()))
topic = rng.choice(topics[category])
# Sample persona
persona = rng.choice(personas)
# Sample dynamic
dynamic = rng.choice(dynamics)
# Sample first message examples: pick from multiple categories
first_msg_samples = []
categories = rng.sample(list(first_messages.keys()), min(3, len(first_messages)))
for cat in categories:
first_msg_samples.append(rng.choice(first_messages[cat]))
return {
"topic": topic,
"persona": persona,
"dynamic": dynamic,
"first_message_examples": "\n".join(f"- {msg}" for msg in first_msg_samples),
}
def generate_conversation(idx: int):
"""
Generate a single conversation using the OpenRouter API.
Returns a list of message dicts with 'role' and 'content' keys.
"""
# Use idx as seed for reproducibility
rng = random.Random(idx)
# pick 5 example user first messages and insert them into prompt as inspiration
rng = random.Random(idx) # use idx as seed to the rng
user_first_prompt = "\n".join(rng.choice(user_first_prompts) for _ in range(5))
# Sample diversity elements
elements = sample_diversity_elements(rng)
# Build the prompt
prompt = prompt_template.format(
knowledge=knowledge,
topic=elements["topic"],
persona=elements["persona"],
dynamic=elements["dynamic"],
first_message_examples=elements["first_message_examples"],
)
# Make API request
payload = copy.deepcopy(base_payload)
modified_prompt = prompt.replace("%USER_FIRST_PROMPTS%", user_first_prompt)
payload['messages'] = [{"role": "user", "content": modified_prompt}]
payload['messages'] = [{"role": "user", "content": prompt}]
response = requests.post(url, headers=headers, json=payload)
result = response.json()
content = result['choices'][0]['message']['content']
# Parse the JSON response and unpack the messages
if 'error' in result:
raise Exception(f"API error: {result['error']}")
content = result['choices'][0]['message']['content']
conversation_data = json.loads(content)
messages = conversation_data['messages']
return messages
# Return messages along with metadata for debugging
return {
"messages": messages,
"metadata": {
"topic": elements["topic"],
"persona": elements["persona"],
"dynamic": elements["dynamic"],
}
}
# Configuration
num_conversations = 1000
num_workers = 4
def validate_conversation(messages):
"""Validate conversation structure."""
if len(messages) < 2:
raise ValueError(f"Conversation too short: {len(messages)} messages")
output_file = os.path.join(get_base_dir(), "identity_conversations.jsonl")
# Wipe the file clean first to reset it
if os.path.exists(output_file):
os.remove(output_file)
print(f"Saving to {output_file}")
for i, message in enumerate(messages):
expected_role = "user" if i % 2 == 0 else "assistant"
if message['role'] != expected_role:
raise ValueError(f"Message {i} has role '{message['role']}', expected '{expected_role}'")
# Use ThreadPoolExecutor to generate conversations in parallel
print(f"Generating {num_conversations} conversations with {num_workers} workers...")
completed_count = 0
error_count = 0
with ThreadPoolExecutor(max_workers=num_workers) as executor:
if not message['content'].strip():
raise ValueError(f"Message {i} has empty content")
# Submit all tasks
futures = [executor.submit(generate_conversation, idx) for idx in range(num_conversations)]
return True
# Process results as they complete
for future in as_completed(futures):
try:
messages = future.result()
# Lightly validate the conversation structure
for i, message in enumerate(messages):
expected_role = "user" if i % 2 == 0 else "assistant"
assert message['role'] == expected_role, f"Message {i} has role {message['role']} but should be {expected_role}"
# =============================================================================
# MAIN
# =============================================================================
# If all looks good, write the messages to file
with open(output_file, 'a') as f:
f.write(json.dumps(messages) + '\n')
completed_count += 1
print(f"✓ Saved conversation {completed_count}/{num_conversations}")
if __name__ == "__main__":
import argparse
except Exception as e:
error_count += 1
print(f"✗ Error generating conversation: {e}")
parser = argparse.ArgumentParser(description="Generate synthetic conversation data")
parser.add_argument("--num", type=int, default=1000, help="Number of conversations to generate")
parser.add_argument("--workers", type=int, default=4, help="Number of parallel workers")
parser.add_argument("--output", type=str, default=None, help="Output file path")
parser.add_argument("--append", action="store_true", help="Append to existing file instead of overwriting")
parser.add_argument("--save-metadata", action="store_true", help="Save metadata alongside messages")
args = parser.parse_args()
print(f"\nDone! Successfully saved {completed_count} conversations to {output_file}")
if error_count > 0:
print(f"Encountered {error_count} errors during generation")
# Set output file
if args.output:
output_file = args.output
else:
output_file = os.path.join(get_base_dir(), "identity_conversations.jsonl")
# Handle file creation/clearing
if not args.append and os.path.exists(output_file):
os.remove(output_file)
print(f"Output file: {output_file}")
print(f"Generating {args.num} conversations with {args.workers} workers...")
print(f"Topic categories: {list(topics.keys())}")
print(f"Personas: {len(personas)}")
print(f"Dynamics: {len(dynamics)}")
print()
completed_count = 0
error_count = 0
with ThreadPoolExecutor(max_workers=args.workers) as executor:
# Submit all tasks
futures = {executor.submit(generate_conversation, idx): idx
for idx in range(args.num)}
# Process results as they complete
for future in as_completed(futures):
idx = futures[future]
try:
result = future.result()
messages = result["messages"]
metadata = result["metadata"]
# Validate
validate_conversation(messages)
# Write to file
with open(output_file, 'a') as f:
if args.save_metadata:
f.write(json.dumps({"messages": messages, "metadata": metadata}) + '\n')
else:
f.write(json.dumps(messages) + '\n')
completed_count += 1
topic_short = metadata["topic"][:40] + "..." if len(metadata["topic"]) > 40 else metadata["topic"]
print(f"[{completed_count}/{args.num}] Topic: {topic_short}")
except Exception as e:
error_count += 1
print(f"[ERROR] idx={idx}: {e}")
print()
print(f"Done! Saved {completed_count} conversations to {output_file}")
if error_count > 0:
print(f"Encountered {error_count} errors during generation")

View File

@ -1,73 +0,0 @@
#!/bin/bash
# Showing an example run for exercising some of the code paths on the CPU (or MPS on Macbooks)
# Run as:
# bash dev/cpu_demo_run.sh
# NOTE: Training LLMs requires GPU compute and $$$. You will not get far on your Macbook.
# Think of this run as educational/fun demo, not something you should expect to work well.
# This is also why I hide this script away in dev/
# all the setup stuff
export OMP_NUM_THREADS=1
export NANOCHAT_BASE_DIR="$HOME/.cache/nanochat"
mkdir -p $NANOCHAT_BASE_DIR
command -v uv &> /dev/null || curl -LsSf https://astral.sh/uv/install.sh | sh
uv sync --extra cpu
source .venv/bin/activate
if [ -z "$WANDB_RUN" ]; then
WANDB_RUN=dummy
fi
# wipe the report
python -m nanochat.report reset
# train tokenizer on ~1B characters
python -m nanochat.dataset -n 4
python -m scripts.tok_train --max-chars=1000000000
python -m scripts.tok_eval
# train a very small 4 layer model on the CPU
# each optimization step processes a single sequence of 1024 tokens
# we only run 50 steps of optimization (bump this to get better results)
python -m scripts.base_train \
--depth=4 \
--max-seq-len=1024 \
--device-batch-size=1 \
--total-batch-size=1024 \
--eval-every=50 \
--eval-tokens=4096 \
--core-metric-every=50 \
--core-metric-max-per-task=12 \
--sample-every=50 \
--num-iterations=50
python -m scripts.base_loss --device-batch-size=1 --split-tokens=4096
python -m scripts.base_eval --max-per-task=16
# midtraining
python -m scripts.mid_train \
--max-seq-len=1024 \
--device-batch-size=1 \
--eval-every=50 \
--eval-tokens=4096 \
--total-batch-size=1024 \
--num-iterations=100
# eval results will be terrible, this is just to execute the code paths.
# note that we lower the execution memory limit to 1MB to avoid warnings on smaller systems
python -m scripts.chat_eval --source=mid --max-new-tokens=128 --max-problems=20
# SFT
python -m scripts.chat_sft \
--device-batch-size=1 \
--target-examples-per-step=4 \
--num-iterations=100 \
--eval-steps=4 \
--eval-metrics-max-problems=16
# Chat CLI
# python -m scripts.chat_cli -p "Why is the sky blue?"
# Chat Web
# python -m scripts.chat_web
python -m nanochat.report generate

View File

@ -15,14 +15,16 @@
"metadata": {},
"outputs": [],
"source": [
"%matplotlib inline\n",
"import os\n",
"import pandas as pd\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"\n",
"# Load results\n",
"tag = \"jan26\"\n",
"base_dir = os.environ.get('NANOCHAT_BASE_DIR', os.path.expanduser('~/.cache/nanochat'))\n",
"results_path = os.path.join(base_dir, 'scaling_laws_results', 'results.csv')\n",
"results_path = os.path.join(base_dir, f'scaling_laws_results_{tag}', 'results.csv')\n",
"\n",
"df = pd.read_csv(results_path)\n",
"flops_budgets = sorted(df['flops_budget'].unique())\n",
@ -31,6 +33,99 @@
"df"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# =============================================================================\n",
"# FILTERING: Remove incomplete or problematic runs\n",
"# =============================================================================\n",
"\n",
"print(f\"Before filtering: {len(df)} runs\")\n",
"\n",
"# Filter out runs with missing/invalid val_bpb (incomplete runs)\n",
"df = df[df['val_bpb'].notna() & (df['val_bpb'] > 0)]\n",
"\n",
"# Optional: exclude specific flops budgets that aren't done yet\n",
"# exclude_flops = [1e19] # <-- adjust as runs complete\n",
"# df = df[~df['flops_budget'].isin(exclude_flops)]\n",
"\n",
"# Optional: exclude specific depths\n",
"# exclude_depths = [18, 20]\n",
"# df = df[~df['depth'].isin(exclude_depths)]\n",
"\n",
"print(f\"After filtering: {len(df)} runs\")\n",
"print(f\"FLOPs budgets: {sorted(df['flops_budget'].unique())}\")\n",
"print(f\"Depths: {sorted(df['depth'].unique())}\")\n",
"\n",
"# Update flops_budgets list after filtering\n",
"flops_budgets = sorted(df['flops_budget'].unique())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Effective Parameter Count\n",
"\n",
"Different scaling law papers use different conventions for counting parameters:\n",
"- **Kaplan et al.** excluded embedding parameters (claimed cleaner laws)\n",
"- **Chinchilla** included all parameters (and noted Kaplan had a bug)\n",
"\n",
"Our CSV now has granular counts:\n",
"- `params_wte` - token embedding (lookup table)\n",
"- `params_bigram_embed` - bigram hash embeddings (lookup table)\n",
"- `params_value_embeds` - value embeddings (lookup table)\n",
"- `params_lm_head` - unembedding projection (matmul)\n",
"- `params_transformer` - attention + MLP matrices (matmuls)\n",
"- `params_scalars` - resid/x0/bigram lambdas (tiny)\n",
"\n",
"**Experiment below** with different combinations to see which gives the cleanest scaling laws."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# =============================================================================\n",
"# EXPERIMENT HERE: Define which parameters to count for scaling laws\n",
"# =============================================================================\n",
"\n",
"def compute_effective_params(row):\n",
" \"\"\"\n",
" Compute the 'effective' parameter count for scaling law analysis.\n",
"\n",
" Modify this function to experiment with different conventions:\n",
" - Chinchilla-style: include everything\n",
" - Kaplan-style: exclude embeddings\n",
" - Matmul-only: just transformer + lm_head (the actual compute)\n",
" - etc.\n",
" \"\"\"\n",
" # Option 1: Chinchilla-style (all params)\n",
" # return row['params_total']\n",
"\n",
" # Option 2: Kaplan-style (exclude embeddings)\n",
" return row['params_transformer'] + row['params_lm_head']\n",
"\n",
" # Option 3: Transformer-only (exclude all embeddings AND lm_head)\n",
" # return row['params_transformer']\n",
"\n",
"\n",
"# Compute derived columns\n",
"df['effective_params'] = df.apply(compute_effective_params, axis=1)\n",
"df['param_data_ratio'] = df['tokens_trained'] / df['effective_params']\n",
"\n",
"# Show parameter breakdown for first few rows\n",
"print(\"Parameter breakdown (first row per flops budget):\")\n",
"param_cols = ['depth', 'params_wte', 'params_bigram_embed', 'params_value_embeds',\n",
" 'params_lm_head', 'params_transformer', 'params_scalars', 'params_total', 'effective_params']\n",
"df.groupby('flops_budget').first()[param_cols]"
]
},
{
"cell_type": "markdown",
"metadata": {},
@ -54,11 +149,11 @@
"optimal_by_bpb = []\n",
"\n",
"for flops, color in zip(flops_budgets, colors):\n",
" subset = df[df['flops_budget'] == flops].sort_values('num_scaling_params')\n",
" ax.plot(subset['num_scaling_params'], subset['val_bpb'], 'o', color=color, label=f'{flops:.0e}', markersize=8)\n",
" subset = df[df['flops_budget'] == flops].sort_values('effective_params')\n",
" ax.plot(subset['effective_params'], subset['val_bpb'], 'o', color=color, label=f'{flops:.0e}', markersize=8)\n",
"\n",
" # Fit quadratic in log-space: val_bpb = a*(log N)^2 + b*(log N) + c\n",
" log_params = np.log10(subset['num_scaling_params'])\n",
" log_params = np.log10(subset['effective_params'])\n",
" coeffs = np.polyfit(log_params, subset['val_bpb'], 2)\n",
" a, b, c = coeffs\n",
"\n",
@ -83,13 +178,13 @@
" # Fallback to raw minimum if quadratic doesn't have minimum\n",
" best_idx = subset['val_bpb'].idxmin()\n",
" best = subset.loc[best_idx]\n",
" ax.scatter([best['num_scaling_params']], [best['val_bpb']], s=150, color=color,\n",
" ax.scatter([best['effective_params']], [best['val_bpb']], s=150, color=color,\n",
" zorder=5, edgecolors='black', linewidths=2)\n",
" optimal_by_bpb.append({'flops': flops, 'params': best['num_scaling_params'],\n",
" optimal_by_bpb.append({'flops': flops, 'params': best['effective_params'],\n",
" 'tokens': best['tokens_trained'], 'ratio': best['param_data_ratio'], 'bpb': best['val_bpb']})\n",
"\n",
"ax.set_xscale('log')\n",
"ax.set_xlabel('Parameters')\n",
"ax.set_xlabel('Effective Parameters')\n",
"ax.set_ylabel('Validation Loss (bpb)')\n",
"ax.set_title('IsoFLOP Curves')\n",
"ax.legend(title='FLOPs', loc='upper right')\n",
@ -138,10 +233,61 @@
"\n",
"# Print the optimal points (from quadratic fits)\n",
"print(\"\\nOptimal configurations (from quadratic fits):\")\n",
"print(f\"{'FLOPs':<12} {'Params':<15} {'Tokens':<15} {'Ratio':<10} {'Val BPB':<10}\")\n",
"print(f\"{'FLOPs':<12} {'Eff Params':<15} {'Tokens':<15} {'Ratio':<10} {'Val BPB':<10}\")\n",
"print(\"-\" * 65)\n",
"for _, row in opt_df.iterrows():\n",
" print(f\"{row['flops']:<12.0e} {int(row['params']):<15,} {int(row['tokens']):<15,} {row['ratio']:<10.1f} {row['bpb']:<10.4f}\")\n"
" print(f\"{row['flops']:<12.0e} {int(row['params']):<15,} {int(row['tokens']):<15,} {row['ratio']:<10.1f} {row['bpb']:<10.4f}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# =============================================================================\n",
"# Optimal Ratio Summary (from power law fits)\n",
"# =============================================================================\n",
"\n",
"# From the power law fits: N ∝ C^a and D ∝ C^b\n",
"# The ratio D/N ∝ C^(b-a). If a ≈ b, ratio is roughly constant.\n",
"\n",
"if len(opt_df) >= 2:\n",
" log_f = np.log10(opt_df['flops'])\n",
" log_p = np.log10(opt_df['params'])\n",
" log_t = np.log10(opt_df['tokens'])\n",
"\n",
" # Fit power laws\n",
" slope_n, intercept_n = np.polyfit(log_f, log_p, 1)\n",
" slope_d, intercept_d = np.polyfit(log_f, log_t, 1)\n",
"\n",
" # The ratio D/N at a reference compute (geometric mean of our budgets)\n",
" ref_flops = np.sqrt(opt_df['flops'].min() * opt_df['flops'].max())\n",
" log_ref = np.log10(ref_flops)\n",
"\n",
" # Predicted optimal N and D at reference compute\n",
" pred_log_n = intercept_n + slope_n * log_ref\n",
" pred_log_d = intercept_d + slope_d * log_ref\n",
" optimal_ratio = 10**(pred_log_d - pred_log_n)\n",
"\n",
" # Also compute from the fitted optimals directly (mean and std)\n",
" mean_ratio = opt_df['ratio'].mean()\n",
" std_ratio = opt_df['ratio'].std()\n",
"\n",
" print(\"=\" * 60)\n",
" print(\"OPTIMAL RATIO SUMMARY\")\n",
" print(\"=\" * 60)\n",
" print(f\"\\nPower law exponents:\")\n",
" print(f\" N ∝ C^{slope_n:.3f}\")\n",
" print(f\" D ∝ C^{slope_d:.3f}\")\n",
" print(f\" Ratio exponent (b-a): {slope_d - slope_n:.3f} (should be ~0 if ratio is constant)\")\n",
" print(f\"\\nOptimal ratio (tokens per effective param):\")\n",
" print(f\" From power law at C={ref_flops:.1e}: {optimal_ratio:.1f}\")\n",
" print(f\" Mean across budgets: {mean_ratio:.1f} ± {std_ratio:.1f}\")\n",
" print(f\" Chinchilla reference: 20\")\n",
" print(f\"\\nPer-budget ratios: {[f'{r:.1f}' for r in opt_df['ratio'].values]}\")\n",
"else:\n",
" print(\"Need at least 2 flops budgets to compute power law fits\")"
]
},
{

BIN
dev/scaling_laws_jan26.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 91 KiB

View File

@ -1,143 +0,0 @@
"""
Distributed AdamW optimizer with a fused step function.
A bunch of ideas (e.g. dist comms in slices) are borrowed from modded-nanogpt.
"""
import torch
import torch.distributed as dist
from torch import Tensor
@torch.compile(dynamic=False, fullgraph=True)
def adamw_step_fused(
p: Tensor,
grad: Tensor,
exp_avg: Tensor,
exp_avg_sq: Tensor,
step_t: Tensor,
lr_t: Tensor,
beta1_t: Tensor,
beta2_t: Tensor,
eps_t: Tensor,
wd_t: Tensor,
) -> None:
"""
Fused AdamW step: weight_decay -> momentum_update -> bias_correction -> param_update
All in one compiled graph to eliminate Python overhead between ops.
The 0-D CPU tensors avoid recompilation when hyperparameter values change.
"""
# Weight decay (decoupled, applied before the update)
p.mul_(1 - lr_t * wd_t)
# Update running averages (lerp_ is cleaner and fuses well)
exp_avg.lerp_(grad, 1 - beta1_t)
exp_avg_sq.lerp_(grad.square(), 1 - beta2_t)
# Bias corrections
bias1 = 1 - beta1_t ** step_t
bias2 = 1 - beta2_t ** step_t
# Compute update and apply
denom = (exp_avg_sq / bias2).sqrt() + eps_t
step_size = lr_t / bias1
p.add_(exp_avg / denom, alpha=-step_size)
class DistAdamW(torch.optim.Optimizer):
"""
Distributed AdamW optimizer.
In the style of ZeRO-2, i.e. sharded optimizer states and gradient reduction
"""
def __init__(self, param_groups, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01):
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
rank = dist.get_rank()
world_size = dist.get_world_size()
# Validate
if rank == 0:
for group in param_groups:
assert isinstance(group, dict), "expecting param_groups to be a list of dicts"
assert isinstance(group['params'], list), "expecting group['params'] to be a list of tensors"
for p in group['params']:
sliced = p.numel() >= 1024
print(f"AdamW: 1 param of shape {p.shape}, sliced={sliced}")
if sliced: # large parameter tensors will be operated on in slices
assert p.shape[0] % world_size == 0, f"First dim of parameter shape {p.shape} must be divisible by world size {world_size}"
super().__init__(param_groups, defaults)
# 0-D CPU tensors to avoid torch.compile recompilation when values change
self._step_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
self._lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
self._beta1_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
self._beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
self._eps_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
self._wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
@torch.no_grad()
def step(self):
rank = dist.get_rank()
world_size = dist.get_world_size()
reduce_futures: list[torch.Future] = []
gather_futures: list[torch.Future] = []
grad_slices = []
is_small = [] # track which params are small (use all_reduce) vs large (use reduce_scatter)
for group in self.param_groups:
params: list[Tensor] = group["params"]
for p in params:
grad = p.grad
# Small params: use all_reduce (no scatter/gather needed)
if p.numel() < 1024:
is_small.append(True)
reduce_futures.append(dist.all_reduce(grad, op=dist.ReduceOp.AVG, async_op=True).get_future())
grad_slices.append(grad)
else:
is_small.append(False)
rank_size = grad.shape[0] // world_size # p.shape[0] % world_size == 0 is checked in __init__
grad_slice = torch.empty_like(grad[:rank_size])
reduce_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future())
grad_slices.append(grad_slice)
idx = 0
for group in self.param_groups:
beta1, beta2 = group['betas']
eps = group['eps']
wd = group['weight_decay']
params = group['params']
for p in params:
reduce_futures[idx].wait()
g_slice = grad_slices[idx]
lr = group['lr'] * getattr(p, "lr_mul", 1.0)
state = self.state[p]
# For small params, operate on full param; for large, operate on slice
if is_small[idx]:
p_slice = p
else:
rank_size = p.shape[0] // world_size
p_slice = p[rank * rank_size:(rank + 1) * rank_size]
# State init
if not state:
state['step'] = 0
state['exp_avg'] = torch.zeros_like(p_slice)
state['exp_avg_sq'] = torch.zeros_like(p_slice)
exp_avg = state['exp_avg']
exp_avg_sq = state['exp_avg_sq']
state['step'] += 1
# Fill 0-D tensors with current values
eff_wd = wd * getattr(p, "wd_mul", 1.0)
self._step_t.fill_(state['step'])
self._lr_t.fill_(lr)
self._beta1_t.fill_(beta1)
self._beta2_t.fill_(beta2)
self._eps_t.fill_(eps)
self._wd_t.fill_(eff_wd)
# Fused update: weight_decay -> momentum -> bias_correction -> param_update
adamw_step_fused(
p_slice, g_slice, exp_avg, exp_avg_sq,
self._step_t, self._lr_t, self._beta1_t, self._beta2_t, self._eps_t, self._wd_t,
)
# Only large params need all_gather
if not is_small[idx]:
gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future())
idx += 1
if gather_futures:
torch.futures.collect_all(gather_futures).wait()

View File

@ -164,7 +164,6 @@ def load_model_from_dir(checkpoints_dir, device, phase, model_tag=None, step=Non
def load_model(source, *args, **kwargs):
model_dir = {
"base": "base_checkpoints",
"mid": "mid_checkpoints",
"sft": "chatsft_checkpoints",
"rl": "chatrl_checkpoints",
}[source]

View File

@ -170,7 +170,7 @@ def compute_init(device_type="cuda"): # cuda|cpu|mps
# Precision
if device_type == "cuda":
torch.backends.cuda.matmul.fp32_precision = "tf32" # uses tf32 instead of fp32 for matmuls
torch.backends.fp32_precision = "tf32" # uses tf32 instead of fp32 for matmuls
# Distributed setup: Distributed Data Parallel (DDP), optional, and requires CUDA
is_ddp_requested, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
@ -200,3 +200,59 @@ class DummyWandb:
pass
def finish(self):
pass
# hardcoded BF16 peak flops for various GPUs
# inspired by torchtitan: https://github.com/pytorch/torchtitan/blob/main/torchtitan/tools/utils.py
# and PR: https://github.com/karpathy/nanochat/pull/147
def get_peak_flops(device_name: str) -> float:
name = device_name.lower()
# Table order matters: more specific patterns first.
_PEAK_FLOPS_TABLE = (
# NVIDIA Blackwell
(["gb200"], 2.5e15),
(["grace blackwell"], 2.5e15),
(["b200"], 2.25e15),
(["b100"], 1.8e15),
# NVIDIA Hopper
(["h200", "nvl"], 836e12),
(["h200", "pcie"], 836e12),
(["h200"], 989e12),
(["h100", "nvl"], 835e12),
(["h100", "pcie"], 756e12),
(["h100"], 989e12),
(["h800", "nvl"], 989e12),
(["h800"], 756e12),
# NVIDIA Ampere data center
(["a100"], 312e12),
(["a800"], 312e12),
(["a40"], 149.7e12),
(["a30"], 165e12),
# NVIDIA Ada data center
(["l40s"], 362e12),
(["l40-s"], 362e12),
(["l40 s"], 362e12),
(["l4"], 121e12),
# AMD CDNA accelerators
(["mi355"], 2.5e15),
(["mi325"], 1.3074e15),
(["mi300x"], 1.3074e15),
(["mi300a"], 980.6e12),
(["mi250x"], 383e12),
(["mi250"], 362.1e12),
# Consumer RTX
(["5090"], 209.5e12),
(["4090"], 165.2e12),
(["3090"], 71e12),
)
for patterns, flops in _PEAK_FLOPS_TABLE:
if all(p in name for p in patterns):
return flops
if "data center gpu max 1550" in name:
# Ponte Vecchio (PVC) - dynamic based on compute units
max_comp_units = torch.xpu.get_device_properties("xpu").max_compute_units
return 512 * max_comp_units * 1300 * 10**6
# Unknown GPU - return inf so MFU shows as 0% rather than a wrong guess
logger.warning(f"Peak flops undefined for: {device_name}, MFU will show as 0%")
return float('inf')

View File

@ -1,24 +1,19 @@
"""
Distributed dataloaders for pretraining.
Two implementations are provided:
1. Original (tokenizing_distributed_data_loader):
- Streams tokens into a flat buffer, reshapes to (B, T)
- Rows may start mid-document (no guaranteed BOS at position 0)
- 100% token utilization, simple and efficient
2. BOS-aligned bestfit (tokenizing_distributed_data_loader_bos_bestfit):
BOS-aligned bestfit:
- Every row starts with BOS token
- Documents packed using best-fit algorithm to minimize cropping
- When no document fits remaining space, crops a document to fill exactly
- 100% utilization (no padding), ~35% tokens cropped at T=2048
The tradeoff: BOS-aligned loses ~35% of tokens to cropping, but ensures that
Compared to the original tokenizing_distributed_data_loader:
BOS-aligned loses ~35% of tokens to cropping, but ensures that
there are fewer "confusing" tokens in the train/val batches as every token can
now attend back to the BOS token and sees the full context of the document.
(2) is the new default if you have enough data.
Fallback to (1) if you have very limited data AND long documents.
Fallback to the original if you have very limited data AND long documents:
https://github.com/karpathy/nanochat/blob/3c3a3d7/nanochat/dataloader.py#L78-L117
"""
import torch
@ -75,48 +70,6 @@ def _document_batches(split, resume_state_dict, tokenizer_batch_size):
epoch += 1
def tokenizing_distributed_data_loader_with_state(tokenizer, B, T, split, tokenizer_threads=4, tokenizer_batch_size=128, device="cuda", resume_state_dict=None):
"""
Stream pretraining text from parquet files, tokenize, yield training batches.
This is the original dataloader that streams tokens into a flat buffer and reshapes.
Rows may start mid-document (no guaranteed BOS at position 0).
Supports approximate resume via state_dict.
"""
assert split in ["train", "val"], "split must be 'train' or 'val'"
batches = _document_batches(split, resume_state_dict, tokenizer_batch_size)
needed_tokens = B * T + 1 # +1 for target at last position
bos_token = tokenizer.get_bos_token_id()
token_buffer = []
pq_idx, rg_idx, epoch = 0, 0, 1
while True:
# Accumulate enough tokens
while len(token_buffer) < needed_tokens:
doc_batch, (pq_idx, rg_idx, epoch) = next(batches)
token_lists = tokenizer.encode(doc_batch, prepend=bos_token, num_threads=tokenizer_threads)
for tokens in token_lists:
token_buffer.extend(tokens)
tokens = token_buffer[:needed_tokens] # Read B*T+1 tokens (+1 is only for the target for the last token)
token_buffer = token_buffer[B*T:] # Advance by B*T tokens, so we move exactly one window of B*T tokens over
# Package tokens into inputs and targets, yield
use_cuda = device == "cuda"
scratch = torch.tensor(tokens, dtype=torch.long, pin_memory=use_cuda)
inputs = scratch[:-1].view(B, T).to(device=device, non_blocking=use_cuda)
targets = scratch[1:].view(B, T).to(device=device, non_blocking=use_cuda)
yield inputs, targets, {"pq_idx": pq_idx, "rg_idx": rg_idx, "epoch": epoch}
def tokenizing_distributed_data_loader(*args, **kwargs):
"""Helper that omits state_dict from yields."""
for inputs, targets, state_dict in tokenizing_distributed_data_loader_with_state(*args, **kwargs):
yield inputs, targets
def tokenizing_distributed_data_loader_with_state_bos_bestfit(
tokenizer, B, T, split,
tokenizer_threads=4, tokenizer_batch_size=128,
@ -154,16 +107,26 @@ def tokenizing_distributed_data_loader_with_state_bos_bestfit(
for tokens in token_lists:
doc_buffer.append(tokens)
# Pre-allocate buffers once: layout is [inputs (B*T) | targets (B*T)]
# This gives us contiguous views and a single HtoD transfer
use_cuda = device == "cuda"
row_buffer = torch.empty((B, row_capacity), dtype=torch.long) # for building rows without creating Python lists
cpu_buffer = torch.empty(2 * B * T, dtype=torch.long, pin_memory=use_cuda) # staging area (CPU)
gpu_buffer = torch.empty(2 * B * T, dtype=torch.long, device=device) # on-device buffer
cpu_inputs = cpu_buffer[:B * T].view(B, T) # a few views into these buffers just for convenience
cpu_targets = cpu_buffer[B * T:].view(B, T)
inputs = gpu_buffer[:B * T].view(B, T)
targets = gpu_buffer[B * T:].view(B, T)
while True:
rows = []
for _ in range(B):
row = []
while len(row) < row_capacity:
for row_idx in range(B):
pos = 0
while pos < row_capacity:
# Ensure buffer has documents
while len(doc_buffer) < buffer_size:
refill_buffer()
remaining = row_capacity - len(row)
remaining = row_capacity - pos
# Find largest doc that fits entirely
best_idx = -1
@ -176,21 +139,25 @@ def tokenizing_distributed_data_loader_with_state_bos_bestfit(
if best_idx >= 0:
doc = doc_buffer.pop(best_idx)
row.extend(doc)
doc_len = len(doc)
row_buffer[row_idx, pos:pos + doc_len] = torch.tensor(doc, dtype=torch.long)
pos += doc_len
else:
# No doc fits - crop first doc to fill remaining
doc = doc_buffer.pop(0)
row.extend(doc[:remaining])
# No doc fits - crop shortest in buffer to fill remaining and minimize waste
shortest_idx = min(range(len(doc_buffer)), key=lambda i: len(doc_buffer[i]))
doc = doc_buffer.pop(shortest_idx)
row_buffer[row_idx, pos:pos + remaining] = torch.tensor(doc[:remaining], dtype=torch.long)
pos += remaining
rows.append(row[:row_capacity])
# Copy to pinned CPU buffer, then single HtoD transfer
cpu_inputs.copy_(row_buffer[:, :-1])
cpu_targets.copy_(row_buffer[:, 1:])
use_cuda = device == "cuda"
batch_tensor = torch.tensor(rows, dtype=torch.long, pin_memory=use_cuda)
inputs = batch_tensor[:, :-1].to(device=device, non_blocking=use_cuda)
targets = batch_tensor[:, 1:].to(device=device, non_blocking=use_cuda)
yield inputs, targets, {"pq_idx": pq_idx, "rg_idx": rg_idx, "epoch": epoch}
state_dict = {"pq_idx": pq_idx, "rg_idx": rg_idx, "epoch": epoch}
# Single HtoD copy into persistent GPU buffer and yield
gpu_buffer.copy_(cpu_buffer, non_blocking=use_cuda)
yield inputs, targets, state_dict
def tokenizing_distributed_data_loader_bos_bestfit(*args, **kwargs):
"""Helper that omits state_dict from yields."""

View File

@ -90,7 +90,7 @@ class KVCache:
- Position tracked per batch element via cache_seqlens tensor
"""
def __init__(self, batch_size, num_heads, seq_len, head_dim, num_layers, device, dtype=torch.bfloat16):
def __init__(self, batch_size, num_heads, seq_len, head_dim, num_layers, device, dtype):
self.batch_size = batch_size
self.max_seq_len = seq_len
self.n_layers = num_layers
@ -172,6 +172,13 @@ class Engine:
"""Same as generate, but does single prefill and then clones the KV cache."""
assert isinstance(tokens, list) and isinstance(tokens[0], int), "expecting list of ints"
device = self.model.get_device()
# NOTE: setting the dtype here and in this way is an ugly hack.
# Currently the repo assumes that cuda -> bfloat16 and everything else -> float32.
# We need to know the dtype here to call __init__ on KVCache and pre-allocate its tensors.
# As a quick hack, we're making generate() function inherit and know about this repo-wise assumption.
# I think there has to be a bigger refactor to deal with device/dtype tracking across the codebase.
# In particular, the KVCache should allocate its tensors lazily
dtype = torch.bfloat16 if device.type == "cuda" else torch.float32
rng = torch.Generator(device=device)
rng.manual_seed(seed)
@ -191,6 +198,7 @@ class Engine:
batch_size=1,
seq_len=len(tokens),
device=device,
dtype=dtype,
**kv_model_kwargs,
)
ids = torch.tensor([tokens], dtype=torch.long, device=device)
@ -203,6 +211,7 @@ class Engine:
batch_size=num_samples,
seq_len=kv_length_hint,
device=device,
dtype=dtype,
**kv_model_kwargs,
)
kv_cache_decode.prefill(kv_cache_prefill)
@ -297,8 +306,8 @@ if __name__ == "__main__":
"""
import time
# init compute
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
device_type = autodetect_device_type()
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
# load the model and tokenizer

View File

@ -2,7 +2,7 @@
Unified Flash Attention interface with automatic FA3/SDPA switching.
Exports `flash_attn` module that matches the FA3 API exactly, but falls back
to PyTorch SDPA on non-Hopper GPUs, MPS, and CPU.
to PyTorch SDPA on non-Hopper GPUs (including Blackwell), MPS, and CPU.
Usage (drop-in replacement for FA3):
from nanochat.flash_attention import flash_attn
@ -21,12 +21,14 @@ import torch.nn.functional as F
# Detection: Try to load FA3 on Hopper+ GPUs
# =============================================================================
def _load_flash_attention_3():
"""Try to load Flash Attention 3 (requires Hopper+ GPU)."""
"""Try to load Flash Attention 3 (requires Hopper GPU, sm90)."""
if not torch.cuda.is_available():
return None
try:
major, _ = torch.cuda.get_device_capability()
if major < 9: # Hopper is sm90
# FA3 kernels are compiled for Hopper (sm90) only
# Ada (sm89), Blackwell (sm100) need SDPA fallback until FA3 is recompiled
if major != 9:
return None
import os
os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
@ -71,27 +73,26 @@ def _sdpa_attention(q, k, v, window_size, enable_gqa):
# Single token generation
if Tq == 1:
if window >= 0 and window < Tk:
# window is "left" tokens we need to include (window + 1) keys total
start = max(0, Tk - (window + 1))
k = k[:, :, start:, :]
v = v[:, :, start:, :]
return F.scaled_dot_product_attention(q, k, v, is_causal=False, enable_gqa=enable_gqa)
# Need explicit mask
# Need explicit mask for sliding window/chunk inference
device = q.device
if Tq == Tk:
# Causal + sliding window
mask = torch.tril(torch.ones(Tq, Tk, device=device, dtype=torch.bool))
if window > 0 and window < Tq:
row_idx = torch.arange(Tq, device=device).unsqueeze(1)
col_idx = torch.arange(Tk, device=device).unsqueeze(0)
mask = mask & ((row_idx - col_idx) <= window)
else:
# Chunk inference: attend to prefix + causal within chunk
prefix_len = Tk - Tq
mask = torch.zeros(Tq, Tk, device=device, dtype=torch.bool)
mask[:, :prefix_len] = True
mask[:, prefix_len:] = torch.tril(torch.ones(Tq, Tq, device=device, dtype=torch.bool))
# For chunk inference (Tq != Tk), is_causal is not aligned to cache position => build an explicit bool mask
row_idx = (Tk - Tq) + torch.arange(Tq, device=device).unsqueeze(1)
col_idx = torch.arange(Tk, device=device).unsqueeze(0)
mask = col_idx <= row_idx
# sliding window (left)
if window >= 0 and window < Tk:
mask = mask & ((row_idx - col_idx) <= window)
return F.scaled_dot_product_attention(q, k, v, attn_mask=mask, enable_gqa=enable_gqa)
# =============================================================================
# Public API: Same interface as FA3
# =============================================================================

272
nanochat/fp8.py Normal file
View File

@ -0,0 +1,272 @@
"""Minimal FP8 training for nanochat — tensorwise dynamic scaling only.
Drop-in replacement for torchao's Float8Linear (~2000 lines) with ~150 lines.
We only need the "tensorwise" recipe (one scalar scale per tensor), not the full
generality of torchao (rowwise scaling, FSDP float8 all-gather, DTensor, tensor
subclass dispatch tables, etc.)
How FP8 training works
======================
A standard Linear layer does one matmul in forward and two in backward:
forward: output = input @ weight.T
backward: grad_input = grad_output @ weight
grad_weight= grad_output.T @ input
FP8 training wraps each of these three matmuls with:
1. Compute scale = FP8_MAX / max(|tensor|) for each operand
2. Quantize: fp8_tensor = clamp(tensor * scale, -FP8_MAX, FP8_MAX).to(fp8)
3. Matmul via torch._scaled_mm (cuBLAS FP8 kernel, ~2x faster than bf16)
4. Dequantize: _scaled_mm handles this internally using the inverse scales
The key insight: torch._scaled_mm and the float8 dtypes are PyTorch built-ins.
torchao is just orchestration around these primitives. We can call them directly.
FP8 dtype choice
================
There are two FP8 formats. We use both, following the standard convention:
- float8_e4m3fn: 4-bit exponent, 3-bit mantissa, range [-448, 448]
Higher precision (more mantissa bits), used for input and weight.
- float8_e5m2: 5-bit exponent, 2-bit mantissa, range [-57344, 57344]
Wider range (more exponent bits), used for gradients which can be large.
torch._scaled_mm layout requirements
=====================================
The cuBLAS FP8 kernel requires specific memory layouts:
- First argument (A): must be row-major (contiguous)
- Second argument (B): must be column-major (B.t().contiguous().t())
If B is obtained by transposing a contiguous tensor (e.g. weight.t()), it is
already column-major no copy needed. Otherwise we use _to_col_major().
How this differs from torchao's approach
========================================
torchao uses a "tensor subclass" architecture: Float8TrainingTensor is a subclass
of torch.Tensor that bundles FP8 data + scale + metadata. It implements
__torch_dispatch__ with a dispatch table that intercepts every aten op (mm, t,
reshape, clone, ...) and handles it in FP8-aware fashion. When you call
output = input @ weight.T
the @ operator dispatches to aten.mm, which gets intercepted and routed to
torch._scaled_mm behind the scenes. This is ~2000 lines of code because you need
a handler for every tensor operation that might touch an FP8 tensor.
We take a simpler approach: a single autograd.Function (_Float8Matmul) that takes
full-precision inputs, quantizes to FP8 internally, calls _scaled_mm, and returns
full-precision outputs. Marked @allow_in_graph so torch.compile treats it as one
opaque node rather than trying to trace inside.
The trade-off is in how torch.compile sees the two approaches:
- torchao: compile decomposes the tensor subclass (via __tensor_flatten__) and
sees every individual op (amax, scale, cast, _scaled_mm) as separate graph
nodes. Inductor can fuse these with surrounding operations (e.g. fuse the
amax computation with the preceding layer's activation function).
- ours: compile sees a single opaque call. It can optimize everything around
the FP8 linear (attention, norms, etc.) but cannot fuse across the boundary.
Both call the exact same cuBLAS _scaled_mm kernel the GPU matmul is identical.
The difference is only in the "glue" ops (amax, scale, cast) which are tiny
compared to the matmul. In practice this means our version is slightly faster
(less compilation overhead, no tensor subclass dispatch cost) but can produce
subtly different floating-point rounding paths under torch.compile, since Inductor
generates a different graph. Numerics are bitwise identical in eager mode.
"""
import torch
import torch.nn as nn
# Avoid division by zero when computing scale from an all-zeros tensor
EPS = 1e-12
@torch.no_grad()
def _to_fp8(x, fp8_dtype):
"""Dynamically quantize a tensor to FP8 using tensorwise scaling.
"Tensorwise" means one scalar scale for the entire tensor (as opposed to
"rowwise" which computes a separate scale per row). Tensorwise is faster
because cuBLAS handles the scaling; rowwise needs the CUTLASS kernel.
Returns (fp8_data, inverse_scale) for use with torch._scaled_mm.
"""
fp8_max = torch.finfo(fp8_dtype).max
# Compute the max absolute value across the entire tensor
amax = x.float().abs().max()
# Scale maps [0, amax] -> [0, fp8_max]. Use float64 for the division to
# ensure consistent numerics between torch.compile and eager mode.
# (torchao does the same upcast — without it, compile/eager can diverge)
scale = fp8_max / amax.double().clamp(min=EPS)
scale = scale.float()
# Quantize: scale into FP8 range, saturate (clamp prevents overflow when
# casting — PyTorch's default is to wrap, not saturate), then cast to FP8
x_scaled = x.float() * scale
x_clamped = x_scaled.clamp(-fp8_max, fp8_max)
x_fp8 = x_clamped.to(fp8_dtype)
# _scaled_mm expects the *inverse* of our scale (it multiplies by this to
# convert FP8 values back to the original range during the matmul)
inv_scale = scale.reciprocal()
return x_fp8, inv_scale
def _to_col_major(x):
"""Rearrange a 2D tensor's memory to column-major layout.
torch._scaled_mm requires its second operand in column-major layout.
The trick: transpose -> contiguous (forces a copy in transposed order)
-> transpose back. The result has the same logical shape but column-major
strides, e.g. a [M, N] tensor gets strides (1, M) instead of (N, 1).
"""
return x.t().contiguous().t()
# allow_in_graph tells torch.compile to treat this as an opaque operation —
# dynamo won't try to decompose it into smaller ops. See the module docstring
# for how this differs from torchao's tensor subclass approach.
@torch._dynamo.allow_in_graph
class _Float8Matmul(torch.autograd.Function):
"""Custom autograd for the three FP8 GEMMs of a Linear layer.
The forward saves input and weight in their original precision for the
backward pass. Each GEMM independently re-quantizes its operands to FP8.
(We don't reuse the forward's FP8 tensors in backward the backward might
want different precision, and saving FP8 would lose information.)
"""
@staticmethod
def forward(ctx, input_2d, weight):
ctx.save_for_backward(input_2d, weight)
# Quantize both operands to e4m3 (higher precision format)
input_fp8, input_inv = _to_fp8(input_2d, torch.float8_e4m3fn)
weight_fp8, weight_inv = _to_fp8(weight, torch.float8_e4m3fn)
# output = input @ weight.T
# input_fp8 is [B, K] contiguous = row-major (good for first arg)
# weight_fp8 is [N, K] contiguous, so weight_fp8.t() is [K, N] with
# strides (1, K) = column-major (good for second arg, no copy needed!)
output = torch._scaled_mm(
input_fp8,
weight_fp8.t(),
scale_a=input_inv,
scale_b=weight_inv,
out_dtype=input_2d.dtype,
# use_fast_accum=True accumulates the dot products in lower precision.
# Slightly less accurate but measurably faster. Standard practice for
# the forward pass; we use False in backward for more precise gradients.
use_fast_accum=True,
)
return output
@staticmethod
def backward(ctx, grad_output):
input_2d, weight = ctx.saved_tensors
# === GEMM 1: grad_input = grad_output @ weight ===
# Shapes: [B, N] @ [N, K] -> [B, K]
# Gradients use e5m2 (wider range), weights use e4m3 (higher precision)
go_fp8, go_inv = _to_fp8(grad_output, torch.float8_e5m2)
w_fp8, w_inv = _to_fp8(weight, torch.float8_e4m3fn)
# go_fp8 is [B, N] contiguous = row-major, good for first arg
# w_fp8 is [N, K] contiguous = row-major, need column-major for second arg
w_col = _to_col_major(w_fp8)
grad_input = torch._scaled_mm(
go_fp8,
w_col,
scale_a=go_inv,
scale_b=w_inv,
out_dtype=grad_output.dtype,
use_fast_accum=False,
)
# === GEMM 2: grad_weight = grad_output.T @ input ===
# Shapes: [N, B] @ [B, K] -> [N, K]
go_fp8_2, go_inv_2 = _to_fp8(grad_output, torch.float8_e5m2)
in_fp8, in_inv = _to_fp8(input_2d, torch.float8_e4m3fn)
# go_fp8_2 is [B, N] contiguous, we need go.T = [N, B] as first arg.
# Transposing gives column-major, but first arg needs row-major,
# so we must call .contiguous() to physically rearrange the memory.
go_T = go_fp8_2.t().contiguous() # [N, B] row-major
in_col = _to_col_major(in_fp8) # [B, K] column-major
grad_weight = torch._scaled_mm(
go_T,
in_col,
scale_a=go_inv_2,
scale_b=in_inv,
out_dtype=grad_output.dtype,
use_fast_accum=False,
)
return grad_input, grad_weight
class Float8Linear(nn.Linear):
"""Drop-in nn.Linear replacement that does FP8 compute.
Weights and biases remain in their original precision (e.g. fp32/bf16).
Only the matmul is performed in FP8 via the _Float8Matmul autograd function.
"""
def forward(self, input):
# Replicate the autocast behavior of F.linear — when autocast is active,
# we need to manually cast input to the autocast dtype (e.g. bf16),
# since we bypass F.linear's built-in autocast handling.
if torch.is_autocast_enabled():
input = input.to(torch.get_autocast_gpu_dtype())
# _scaled_mm only works on 2D tensors, so flatten batch dimensions
orig_shape = input.shape
input_2d = input.reshape(-1, orig_shape[-1])
output = _Float8Matmul.apply(input_2d, self.weight)
output = output.reshape(*orig_shape[:-1], output.shape[-1])
if self.bias is not None:
output = output + self.bias.to(output.dtype)
return output
@classmethod
def from_float(cls, mod):
"""Create Float8Linear from nn.Linear, sharing the same weight and bias.
Uses meta device to avoid allocating a temporary weight tensor we
create the module shell on meta (shapes/dtypes only, no memory), then
point .weight and .bias to the original module's parameters.
"""
with torch.device("meta"):
new_mod = cls(mod.in_features, mod.out_features, bias=False)
new_mod.weight = mod.weight
new_mod.bias = mod.bias
return new_mod
class Float8LinearConfig:
"""Minimal config matching torchao's API. Only tensorwise recipe is supported."""
@staticmethod
def from_recipe_name(recipe_name):
if recipe_name != "tensorwise":
raise ValueError(
f"Only 'tensorwise' recipe is supported, got '{recipe_name}'. "
f"Rowwise/axiswise recipes require the full torchao library."
)
return Float8LinearConfig()
def convert_to_float8_training(module, *, config=None, module_filter_fn=None):
"""Replace nn.Linear layers with Float8Linear throughout a module.
Walks the module tree in post-order (children before parents) and swaps
each nn.Linear that passes the optional filter. The new Float8Linear shares
the original weight and bias tensors no copies, no extra memory.
Args:
module: Root module to convert.
config: Float8LinearConfig (accepted for API compat, only tensorwise supported).
module_filter_fn: Optional filter(module, fqn) -> bool. Only matching Linears
are converted. Common use: skip layers with dims not divisible by 16
(hardware requirement for FP8 matmuls on H100).
"""
def _convert(mod, prefix=""):
for name, child in mod.named_children():
fqn = f"{prefix}.{name}" if prefix else name
_convert(child, fqn)
if isinstance(child, nn.Linear) and not isinstance(child, Float8Linear):
if module_filter_fn is None or module_filter_fn(child, fqn):
setattr(mod, name, Float8Linear.from_float(child))
_convert(module)
return module

View File

@ -20,16 +20,15 @@ import torch.nn as nn
import torch.nn.functional as F
from nanochat.common import get_dist_info, print0
from nanochat.muon import Muon, DistMuon
from nanochat.adamw import DistAdamW
from nanochat.optim import MuonAdamW, DistMuonAdamW
# Our custom Flash Attention module that automatically uses FA3 on Hopper+ and SDPA fallback elsewhere
from nanochat.flash_attention import flash_attn
@dataclass
class GPTConfig:
sequence_len: int = 1024
vocab_size: int = 50304
sequence_len: int = 2048
vocab_size: int = 32768
n_layer: int = 12
n_head: int = 6 # number of query heads
n_kv_head: int = 6 # number of key/value heads (GQA)
@ -37,7 +36,7 @@ class GPTConfig:
# Sliding window attention pattern string, tiled across layers. Final layer always L.
# Characters: L=long (full context), S=short (half context)
# Examples: "L"=all full context, "SL"=alternating, "SSL"=two short then one long
window_pattern: str = "L"
window_pattern: str = "SSSL"
def norm(x):
@ -45,6 +44,10 @@ def norm(x):
return F.rms_norm(x, (x.size(-1),))
def has_ve(layer_idx, n_layer):
"""Returns True if GPT layer should have Value Embedding (alternating, last layer always included)."""
return layer_idx % 2 == (n_layer - 1) % 2
def apply_rotary_emb(x, cos, sin):
assert x.ndim == 4 # multihead attention
d = x.shape[3] // 2
@ -67,8 +70,10 @@ class CausalSelfAttention(nn.Module):
self.c_k = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
self.c_v = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)
self.ve_gate_channels = 32
self.ve_gate = nn.Linear(self.ve_gate_channels, self.n_kv_head, bias=False) if has_ve(layer_idx, config.n_layer) else None
def forward(self, x, cos_sin, window_size, kv_cache):
def forward(self, x, ve, cos_sin, window_size, kv_cache):
B, T, C = x.size()
# Project the input to get queries, keys, and values
@ -77,6 +82,12 @@ class CausalSelfAttention(nn.Module):
k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim)
v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim)
# Value residual (ResFormer): mix in value embedding with input-dependent gate per head
if ve is not None:
ve = ve.view(B, T, self.n_kv_head, self.head_dim)
gate = 2 * torch.sigmoid(self.ve_gate(x[..., :self.ve_gate_channels])) # (B, T, n_kv_head), range (0, 2)
v = v + gate.unsqueeze(-1) * ve
# Apply Rotary Embeddings to queries and keys to get relative positional encoding
cos, sin = cos_sin
q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin)
@ -126,8 +137,8 @@ class Block(nn.Module):
self.attn = CausalSelfAttention(config, layer_idx)
self.mlp = MLP(config)
def forward(self, x, cos_sin, window_size, kv_cache):
x = x + self.attn(norm(x), cos_sin, window_size, kv_cache)
def forward(self, x, ve, cos_sin, window_size, kv_cache):
x = x + self.attn(norm(x), ve, cos_sin, window_size, kv_cache)
x = x + self.mlp(norm(x))
return x
@ -160,6 +171,10 @@ class GPT(nn.Module):
# Separate parameters so they can have different optimizer treatment
self.resid_lambdas = nn.Parameter(torch.ones(config.n_layer)) # fake init, real init in init_weights()
self.x0_lambdas = nn.Parameter(torch.zeros(config.n_layer)) # fake init, real init in init_weights()
# Value embeddings (ResFormer-style): alternating layers, last layer always included
head_dim = config.n_embd // config.n_head
kv_dim = config.n_kv_head * head_dim
self.value_embeds = nn.ModuleDict({str(i): nn.Embedding(padded_vocab_size, kv_dim) for i in range(config.n_layer) if has_ve(i, config.n_layer)})
# To support meta device initialization, we init the rotary embeddings here, but it's just "fake" meta tensors only.
# As for rotary_seq_len, these rotary embeddings are pretty small/cheap in memory,
# so let's just over-compute them by 10X, but assert fail if we ever reach that amount.
@ -170,6 +185,7 @@ class GPT(nn.Module):
self.register_buffer("cos", cos, persistent=False) # persistent=False means it's not saved to the checkpoint
self.register_buffer("sin", sin, persistent=False)
@torch.no_grad()
def init_weights(self):
"""
Initialize the full model in this one function for maximum clarity.
@ -201,18 +217,28 @@ class GPT(nn.Module):
torch.nn.init.zeros_(block.mlp.c_proj.weight)
# Per-layer scalars
with torch.no_grad():
self.resid_lambdas.fill_(1.0) # 1.0 => typical residual connections at init
self.x0_lambdas.fill_(0.0) # 0.0 => skip connection to input is disabled at init
self.resid_lambdas.fill_(1.0) # 1.0 => typical residual connections at init
self.x0_lambdas.fill_(0.1) # 0.1 => small initial weight for skip connection to input embedding
# Value embeddings (init like c_v: uniform with same std)
for ve in self.value_embeds.values():
torch.nn.init.uniform_(ve.weight, -s, s)
# Gate weights init to zero so gates start at sigmoid(0) = 0.5, scaled by 2 -> 1.0 (neutral)
for block in self.transformer.h:
if block.attn.ve_gate is not None:
torch.nn.init.zeros_(block.attn.ve_gate.weight)
# Rotary embeddings
head_dim = self.config.n_embd // self.config.n_head
cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
self.cos, self.sin = cos, sin
# Cast token embeddings to bf16: optimizer can tolerate it and it saves memory
# Cast embeddings to bf16: optimizer can tolerate it and it saves memory
if self.transformer.wte.weight.device.type == "cuda":
self.transformer.wte.to(dtype=torch.bfloat16)
for ve in self.value_embeds.values():
ve.to(dtype=torch.bfloat16)
def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None):
# TODO: bump base theta more? e.g. 100K is more common more recently
@ -277,7 +303,9 @@ class GPT(nn.Module):
"""
nparams = sum(p.numel() for p in self.parameters())
# Exclude non-matmul params: embeddings and per-layer scalars
nparams_exclude = self.transformer.wte.weight.numel() + self.resid_lambdas.numel() + self.x0_lambdas.numel()
value_embeds_numel = sum(ve.weight.numel() for ve in self.value_embeds.values())
nparams_exclude = (self.transformer.wte.weight.numel() + value_embeds_numel +
self.resid_lambdas.numel() + self.x0_lambdas.numel())
h, q, t = self.config.n_head, self.config.n_embd // self.config.n_head, self.config.sequence_len
# Sum attention FLOPs per layer, accounting for sliding window
attn_flops = 0
@ -290,49 +318,72 @@ class GPT(nn.Module):
def num_scaling_params(self):
"""
Return all of the parameters, same as Chinchilla paper.
Kaplan et al. did not include embedding parameters and said that this led to cleaner scaling laws.
But Kaplan et al. also had a bug in their results (as pointed out by Chinchilla).
My own experiments in nanochat confirm the Chinchilla approach gives the much cleaner scaling law.
Ref: https://arxiv.org/abs/2203.15556 (Chinchilla paper <- good).
Ref: https://arxiv.org/abs/2001.08361 (Kaplan et al. original scaling laws paper <- bad)
"""
nparams = sum(p.numel() for p in self.parameters())
return nparams
Return detailed parameter counts for scaling law analysis.
Different papers use different conventions:
- Kaplan et al. excluded embedding parameters
- Chinchilla included all parameters
Ref: https://arxiv.org/abs/2203.15556 (Chinchilla paper)
Ref: https://arxiv.org/abs/2001.08361 (Kaplan et al. original scaling laws paper)
def setup_optimizers(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0, adam_betas=(0.8, 0.95), scalar_lr=0.5):
Returns a dict with counts for each parameter group, so downstream analysis
can experiment with which combination gives the cleanest scaling laws.
"""
# Count each group separately (mirrors the grouping in setup_optimizers)
wte = sum(p.numel() for p in self.transformer.wte.parameters())
value_embeds = sum(p.numel() for p in self.value_embeds.parameters())
lm_head = sum(p.numel() for p in self.lm_head.parameters())
transformer_matrices = sum(p.numel() for p in self.transformer.h.parameters())
scalars = self.resid_lambdas.numel() + self.x0_lambdas.numel()
total = wte + value_embeds + lm_head + transformer_matrices + scalars
assert total == sum(p.numel() for p in self.parameters()), "Parameter count mismatch"
return {
'wte': wte,
'value_embeds': value_embeds,
'lm_head': lm_head,
'transformer_matrices': transformer_matrices,
'scalars': scalars,
'total': total,
}
def setup_optimizer(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0, adam_betas=(0.8, 0.95), scalar_lr=0.5):
model_dim = self.config.n_embd
ddp, rank, local_rank, world_size = get_dist_info()
# Separate out all parameters into 5 groups (matrix, embedding, lm_head, resid_lambdas, x0_lambdas)
# Separate out all parameters into groups
matrix_params = list(self.transformer.h.parameters())
value_embeds_params = list(self.value_embeds.parameters())
embedding_params = list(self.transformer.wte.parameters())
lm_head_params = list(self.lm_head.parameters())
resid_params = [self.resid_lambdas]
x0_params = [self.x0_lambdas]
assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(resid_params) + len(x0_params)
# Create the AdamW optimizer for the embedding, lm_head, and per-layer scalars
# Scale the LR for the AdamW parameters by ∝1/√dmodel (having tuned the LRs for 768 dim model)
assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params)
# Scale the LR for the AdamW parameters by ∝1/√dmodel (tuned for 768 dim model)
dmodel_lr_scale = (model_dim / 768) ** -0.5
print0(f"Scaling the LR for the AdamW parameters ∝1/√({model_dim}/768) = {dmodel_lr_scale:.6f}")
adam_groups = [
dict(params=lm_head_params, lr=unembedding_lr * dmodel_lr_scale),
dict(params=embedding_params, lr=embedding_lr * dmodel_lr_scale),
dict(params=resid_params, lr=scalar_lr * 0.01), # these are a lot more sensitive because they accumulate in the residual stream
dict(params=x0_params, lr=scalar_lr),
# Build param_groups with all required fields explicit
param_groups = [
# AdamW groups (embeddings, lm_head, scalars)
dict(kind='adamw', params=lm_head_params, lr=unembedding_lr * dmodel_lr_scale, betas=adam_betas, eps=1e-10, weight_decay=0.0),
dict(kind='adamw', params=embedding_params, lr=embedding_lr * dmodel_lr_scale, betas=adam_betas, eps=1e-10, weight_decay=0.0),
dict(kind='adamw', params=value_embeds_params, lr=embedding_lr * dmodel_lr_scale, betas=adam_betas, eps=1e-10, weight_decay=0.0),
dict(kind='adamw', params=resid_params, lr=scalar_lr * 0.01, betas=adam_betas, eps=1e-10, weight_decay=0.0),
dict(kind='adamw', params=x0_params, lr=scalar_lr, betas=(0.96, 0.95), eps=1e-10, weight_decay=0.0), # higher beta1 for x0
]
adamw_kwargs = dict(betas=adam_betas, eps=1e-10, weight_decay=0.0) # NOTE: weight decay is hardcoded to 0.0 for AdamW, only used in Muon
AdamWFactory = DistAdamW if ddp else partial(torch.optim.AdamW, fused=True)
adamw_optimizer = AdamWFactory(adam_groups, **adamw_kwargs)
# Create the Muon optimizer for the linear layers
muon_kwargs = dict(lr=matrix_lr, momentum=0.95, weight_decay=weight_decay)
MuonFactory = DistMuon if ddp else Muon
muon_optimizer = MuonFactory(matrix_params, **muon_kwargs)
# Combine them the two optimizers into one list
optimizers = [adamw_optimizer, muon_optimizer]
for opt in optimizers:
for group in opt.param_groups:
group["initial_lr"] = group["lr"]
return optimizers
# Muon groups (matrix params, grouped by shape for stacking)
for shape in sorted({p.shape for p in matrix_params}):
group_params = [p for p in matrix_params if p.shape == shape]
param_groups.append(dict(
kind='muon', params=group_params, lr=matrix_lr,
momentum=0.95, ns_steps=5, beta2=0.95, weight_decay=weight_decay,
))
Factory = DistMuonAdamW if ddp else MuonAdamW
optimizer = Factory(param_groups)
for group in optimizer.param_groups:
group["initial_lr"] = group["lr"]
return optimizer
def forward(self, idx, targets=None, kv_cache=None, loss_reduction='mean'):
B, T = idx.size()
@ -346,12 +397,13 @@ class GPT(nn.Module):
cos_sin = self.cos[:, T0:T0+T], self.sin[:, T0:T0+T] # truncate cache to current sequence length
# Forward the trunk of the Transformer
x = self.transformer.wte(idx)
x = self.transformer.wte(idx) # embed current token
x = norm(x)
x0 = x # save initial normalized embedding for x0 residual
for i, block in enumerate(self.transformer.h):
x = self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0
x = block(x, cos_sin, self.window_sizes[i], kv_cache)
ve = self.value_embeds[str(i)](idx) if str(i) in self.value_embeds else None
x = block(x, ve, cos_sin, self.window_sizes[i], kv_cache)
x = norm(x)
# Forward the lm_head (compute logits)
@ -388,7 +440,7 @@ class GPT(nn.Module):
for _ in range(max_tokens):
logits = self.forward(ids) # (B, T, vocab_size)
logits = logits[:, -1, :] # (B, vocab_size)
if top_k is not None:
if top_k is not None and top_k > 0:
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
logits[logits < v[:, [-1]]] = -float('Inf')
if temperature > 0:

View File

@ -1,352 +0,0 @@
"""
Muon optimizer adapted and simplified from modded-nanogpt.
https://github.com/KellerJordan/modded-nanogpt
Background:
Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
zero even beyond the point where the iteration no longer converges all the way to one everywhere
on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
performance at all relative to UV^T, where USV^T = G is the SVD.
Here, an alternative to Newton-Schulz iteration with potentially better convergence properties:
Polar Express Sign Method for orthogonalization.
https://arxiv.org/pdf/2505.16932
by Noah Amsel, David Persson, Christopher Musco, Robert M. Gower.
Some of the changes in nanochat implementation:
- Uses a simpler, more general approach to parameter grouping and stacking
- Uses a single fused kernel for the momentum -> polar_express -> variance_reduction -> update step
- Makes no assumptions about model architecture (e.g. that attention weights are fused into QKVO format)
"""
import torch
from torch import Tensor
import torch.distributed as dist
# Coefficients for Polar Express (computed for num_iters=5, safety_factor=2e-2, cushion=2)
# From https://arxiv.org/pdf/2505.16932
polar_express_coeffs = [
(8.156554524902461, -22.48329292557795, 15.878769915207462),
(4.042929935166739, -2.808917465908714, 0.5000178451051316),
(3.8916678022926607, -2.772484153217685, 0.5060648178503393),
(3.285753657755655, -2.3681294933425376, 0.46449024233003106),
(2.3465413258596377, -1.7097828382687081, 0.42323551169305323),
]
@torch.compile(dynamic=False, fullgraph=True)
def muon_step_fused(
stacked_grads: Tensor,
stacked_params: Tensor,
momentum_buffer: Tensor,
second_momentum_buffer: Tensor,
momentum_t: Tensor,
lr_t: Tensor,
wd_t: Tensor,
beta2_t: Tensor,
ns_steps: int,
red_dim: int,
) -> None:
"""
Fused Muon step: momentum -> polar_express -> variance_reduction -> cautious_update
All in one compiled graph to eliminate Python overhead between ops.
Some of the constants are 0-D CPU tensors to avoid recompilation when values change.
"""
# Nesterov momentum
momentum = momentum_t.to(stacked_grads.dtype)
momentum_buffer.lerp_(stacked_grads, 1 - momentum)
g = stacked_grads.lerp_(momentum_buffer, momentum)
# Polar express
X = g.bfloat16()
if g.size(-2) > g.size(-1):
X = X.mT
X = X / (X.norm(dim=(-2, -1), keepdim=True) * 1.02 + 1e-6)
for a, b, c in polar_express_coeffs[:ns_steps]:
A = X @ X.mT
B = b * A + c * (A @ A)
X = a * X + B @ X
if g.size(-2) > g.size(-1):
X = X.mT
g = X
# Variance reduction
beta2 = beta2_t.to(g.dtype)
v_mean = g.float().square().mean(dim=red_dim, keepdim=True)
red_dim_size = g.size(red_dim)
v_norm_sq = v_mean.sum(dim=(-2, -1), keepdim=True) * red_dim_size
v_norm = v_norm_sq.sqrt()
second_momentum_buffer.lerp_(v_mean.to(dtype=second_momentum_buffer.dtype), 1 - beta2)
step_size = second_momentum_buffer.clamp_min(1e-10).rsqrt()
scaled_sq_sum = (v_mean * red_dim_size) * step_size.float().square()
v_norm_new = scaled_sq_sum.sum(dim=(-2, -1), keepdim=True).sqrt()
final_scale = step_size * (v_norm / v_norm_new.clamp_min(1e-10))
g = g * final_scale.to(g.dtype)
# Cautious weight decay + parameter update
lr = lr_t.to(g.dtype)
wd = wd_t.to(g.dtype)
mask = (g * stacked_params) >= 0
stacked_params.sub_(lr * g + lr * wd * stacked_params * mask)
class Muon(torch.optim.Optimizer):
"""
Muon - MomentUm Orthogonalized by Newton-schulz
https://kellerjordan.github.io/posts/muon/
Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has
the advantage that it can be stably run in bfloat16 on the GPU.
Some warnings:
- This optimizer should not be used for the embedding layer, the final fully connected layer,
or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW).
- To use it with 4D convolutional filters, it works well to just flatten their last 3 dimensions.
Arguments:
lr: The learning rate used by the internal SGD.
momentum: The momentum used by the internal SGD.
ns_steps: The number of Newton-Schulz iteration steps to use.
beta2: The decay rate for the second moment (variance) estimate. Set to None to disable.
weight_decay: Cautious weight decay coefficient. Only decays where update and weight agree.
"""
def __init__(self, params, lr=0.02, momentum=0.95, ns_steps=5, beta2=0.95, weight_decay=0.0):
defaults = dict(lr=lr, momentum=momentum, ns_steps=ns_steps, beta2=beta2, weight_decay=weight_decay)
assert all(p.ndim == 2 for p in params), "Muon expects 2D parameters only"
params = list(params) # ensure we have a list, not an e.g. (exhaustible) iterator
# Group by shape so we can stack tensors
shapes = sorted({p.shape for p in params})
param_groups = []
for shape in shapes:
group_params = [p for p in params if p.shape == shape]
param_groups.append(dict(params=group_params))
super().__init__(param_groups, defaults)
# 0-D CPU tensors to avoid torch.compile recompilation when values change
self._momentum_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
self._lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
self._wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
self._beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
@torch.no_grad()
def step(self):
for group in self.param_groups:
params: list[Tensor] = group["params"]
if not params:
continue
# Get or create group-level buffers (stored in first param's state for convenience)
state = self.state[params[0]]
num_params = len(params) # e.g.: 12 (for a d12 model)
# e.g.: shape = (768, 3072), device = cuda:0, dtype = torch.float32, for one of the MLP projections
shape, device, dtype = params[0].shape, params[0].device, params[0].dtype
# Momentum for every individual parameter
if "momentum_buffer" not in state:
state["momentum_buffer"] = torch.zeros(num_params, *shape, dtype=dtype, device=device)
momentum_buffer = state["momentum_buffer"] # e.g.: (12, 768, 3072)
# Second momentum buffer is factored, either per-row or per-column
if "second_momentum_buffer" not in state:
if shape[-2] >= shape[-1]:
state["second_momentum_buffer"] = torch.zeros(num_params, shape[-2], 1, dtype=dtype, device=device)
else:
state["second_momentum_buffer"] = torch.zeros(num_params, 1, shape[-1], dtype=dtype, device=device)
second_momentum_buffer = state["second_momentum_buffer"] # (12, 1, 3072)
red_dim = -1 if shape[-2] >= shape[-1] else -2 # e.g.: -2
# Stack grads and params
stacked_grads = torch.stack([p.grad for p in params]) # (12, 768, 3072)
stacked_params = torch.stack(params) # (12, 768, 3072)
# Fill all the 0-D tensors with current values
self._momentum_t.fill_(group["momentum"])
self._beta2_t.fill_(group["beta2"] if group["beta2"] is not None else 0.0)
self._lr_t.fill_(group["lr"] * max(1.0, shape[-2] / shape[-1])**0.5)
self._wd_t.fill_(group["weight_decay"])
# Single fused kernel: momentum -> polar_express -> variance_reduction -> update
muon_step_fused(
stacked_grads,
stacked_params,
momentum_buffer,
second_momentum_buffer,
self._momentum_t,
self._lr_t,
self._wd_t,
self._beta2_t,
group["ns_steps"],
red_dim,
)
# Copy back to original params: [(768, 3072), (768, 3072), ...] <- (12, 768, 3072)
torch._foreach_copy_(params, list(stacked_params.unbind(0)))
class DistMuon(torch.optim.Optimizer):
"""
Distributed version of the Muon optimizer.
"""
def __init__(self, params, lr: float = 0.02, momentum: float = 0.95,
ns_steps: int = 5, beta2: float = 0.95, weight_decay: float = 0.0):
defaults = dict(lr=lr, momentum=momentum, ns_steps=ns_steps, beta2=beta2, weight_decay=weight_decay)
assert all(p.ndim == 2 for p in params), "Muon expects 2D parameters only"
params = list(params)
world_size = dist.get_world_size()
rank = dist.get_rank()
# Group all parameters by their shape
shapes = sorted({p.shape for p in params}) # sort for deterministic ordering across ranks
param_groups = []
for shape in shapes:
group_params = [p for p in params if p.shape == shape]
device, dtype = group_params[0].device, group_params[0].dtype
assert all(p.device == device for p in group_params)
assert all(p.dtype == dtype for p in group_params)
# Compute chunk size for this group (how many params each rank owns)
chunk_size = (len(group_params) + world_size - 1) // world_size
if rank == 0:
print(f"Muon: {len(group_params)} params of shape {shape}, chunk_size={chunk_size}")
param_groups.append(dict(params=group_params, chunk_size=chunk_size))
super().__init__(param_groups, defaults)
# 0-D CPU tensors to avoid torch.compile recompilation when values change
self._momentum_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
self._lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
self._wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
self._beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
@torch.no_grad()
def step(self):
rank = dist.get_rank()
world_size = dist.get_world_size()
# Ensure all grads exist
assert all(p.grad is not None for group in self.param_groups for p in group["params"]), "All params must have grads"
# First pass: stack grads and kick off reduce_scatter for each group
group_infos = []
for group in self.param_groups:
params: list[Tensor] = group["params"]
chunk_size = group["chunk_size"]
padded_num_params = chunk_size * world_size
shape = params[0].shape
device, dtype = params[0].device, params[0].dtype
# Stack all gradients into a single tensor (single kernel via torch.stack)
grad_stack = torch.stack([p.grad for p in params])
stacked_grads = torch.empty(padded_num_params, *shape, dtype=dtype, device=device)
stacked_grads[:len(params)].copy_(grad_stack)
# Zero-pad if we have fewer params than padded size
if len(params) < padded_num_params:
stacked_grads[len(params):].zero_()
# Output buffer for this rank's chunk
grad_chunk = torch.empty(chunk_size, *shape, dtype=dtype, device=device)
# Async reduce_scatter on the stacked tensor
reduce_future = dist.reduce_scatter_tensor(
grad_chunk, stacked_grads, op=dist.ReduceOp.AVG, async_op=True
).get_future()
group_infos.append(dict(
grad_chunk=grad_chunk,
reduce_future=reduce_future,
stacked_grads=stacked_grads, # reuse for all_gather output
))
# Second pass: wait for reduce, compute batched updates, kick off all_gather
all_gather_futures = []
for group, info in zip(self.param_groups, group_infos):
info["reduce_future"].wait()
params = group["params"]
chunk_size = group["chunk_size"]
shape = params[0].shape
device, dtype = params[0].device, params[0].dtype
grad_chunk = info["grad_chunk"]
# How many params does this rank actually own?
start_idx = rank * chunk_size
num_owned = min(chunk_size, max(0, len(params) - start_idx))
# Get or create group-level state (stored keyed by first param)
state = self.state[params[0]]
# Momentum buffer
if "momentum_buffer" not in state:
state["momentum_buffer"] = torch.zeros(chunk_size, *shape, dtype=dtype, device=device)
momentum_buffer = state["momentum_buffer"]
# Second momentum buffer is factored, either per-row or per-column
if "second_momentum_buffer" not in state:
if shape[-2] >= shape[-1]:
state["second_momentum_buffer"] = torch.zeros(chunk_size, shape[-2], 1, dtype=dtype, device=device)
else:
state["second_momentum_buffer"] = torch.zeros(chunk_size, 1, shape[-1], dtype=dtype, device=device)
second_momentum_buffer = state["second_momentum_buffer"]
red_dim = -1 if shape[-2] >= shape[-1] else -2
# Build updated_params tensor for all_gather
updated_params = torch.empty(chunk_size, *shape, dtype=dtype, device=device)
if num_owned > 0:
# Stack owned params (single kernel via torch.stack)
owned_params = [params[start_idx + i] for i in range(num_owned)]
stacked_owned_params = torch.stack(owned_params)
# Get owned slices of buffers and grads
owned_grads = grad_chunk[:num_owned]
owned_momentum = momentum_buffer[:num_owned]
owned_second_momentum = second_momentum_buffer[:num_owned]
# Fill 0-D tensors with current values
self._momentum_t.fill_(group["momentum"])
self._beta2_t.fill_(group["beta2"] if group["beta2"] is not None else 0.0)
self._lr_t.fill_(group["lr"] * max(1.0, shape[-2] / shape[-1])**0.5)
self._wd_t.fill_(group["weight_decay"])
# Single fused kernel: momentum -> polar_express -> variance_reduction -> update
muon_step_fused(
owned_grads,
stacked_owned_params,
owned_momentum,
owned_second_momentum,
self._momentum_t,
self._lr_t,
self._wd_t,
self._beta2_t,
group["ns_steps"],
red_dim,
)
# Copy updated params to output buffer
updated_params[:num_owned].copy_(stacked_owned_params)
# Zero-pad the rest (for ranks that own fewer params)
if num_owned < chunk_size:
updated_params[num_owned:].zero_()
# Reuse stacked_grads buffer for all_gather output
stacked_params = info["stacked_grads"]
# Async all_gather to replicate updated params to all ranks
gather_future = dist.all_gather_into_tensor(
stacked_params, updated_params, async_op=True
).get_future()
all_gather_futures.append(dict(
gather_future=gather_future,
stacked_params=stacked_params,
params=params,
))
# Final pass: wait for all_gather and copy back to params
for info in all_gather_futures:
info["gather_future"].wait()
stacked_params = info["stacked_params"]
params = info["params"]
# Batched copy back (single kernel instead of N individual copies)
torch._foreach_copy_(params, list(stacked_params[:len(params)].unbind(0)))

533
nanochat/optim.py Normal file
View File

@ -0,0 +1,533 @@
"""
A nice and efficient mixed AdamW/Muon Combined Optimizer.
Usually the embeddings and scalars go into AdamW, and the matrix parameters go into Muon.
Two versions are provided (MuonAdamW, DistMuonAdamW), for single GPU and distributed.
Addapted from: https://github.com/KellerJordan/modded-nanogpt
Further contributions from @karpathy and @chrisjmccormick.
"""
import torch
import torch.distributed as dist
from torch import Tensor
# -----------------------------------------------------------------------------
"""
Good old AdamW optimizer, fused kernel.
https://arxiv.org/abs/1711.05101
"""
@torch.compile(dynamic=False, fullgraph=True)
def adamw_step_fused(
p: Tensor, # (32768, 768) - parameter tensor
grad: Tensor, # (32768, 768) - gradient, same shape as p
exp_avg: Tensor, # (32768, 768) - first moment, same shape as p
exp_avg_sq: Tensor, # (32768, 768) - second moment, same shape as p
step_t: Tensor, # () - 0-D CPU tensor, step count
lr_t: Tensor, # () - 0-D CPU tensor, learning rate
beta1_t: Tensor, # () - 0-D CPU tensor, beta1
beta2_t: Tensor, # () - 0-D CPU tensor, beta2
eps_t: Tensor, # () - 0-D CPU tensor, epsilon
wd_t: Tensor, # () - 0-D CPU tensor, weight decay
) -> None:
"""
Fused AdamW step: weight_decay -> momentum_update -> bias_correction -> param_update
All in one compiled graph to eliminate Python overhead between ops.
The 0-D CPU tensors avoid recompilation when hyperparameter values change.
"""
# Weight decay (decoupled, applied before the update)
p.mul_(1 - lr_t * wd_t)
# Update running averages (lerp_ is cleaner and fuses well)
exp_avg.lerp_(grad, 1 - beta1_t)
exp_avg_sq.lerp_(grad.square(), 1 - beta2_t)
# Bias corrections
bias1 = 1 - beta1_t ** step_t
bias2 = 1 - beta2_t ** step_t
# Compute update and apply
denom = (exp_avg_sq / bias2).sqrt() + eps_t
step_size = lr_t / bias1
p.add_(exp_avg / denom, alpha=-step_size)
# -----------------------------------------------------------------------------
"""
Muon optimizer adapted and simplified from modded-nanogpt.
https://github.com/KellerJordan/modded-nanogpt
Background:
Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
zero even beyond the point where the iteration no longer converges all the way to one everywhere
on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
performance at all relative to UV^T, where USV^T = G is the SVD.
Here, an alternative to Newton-Schulz iteration with potentially better convergence properties:
Polar Express Sign Method for orthogonalization.
https://arxiv.org/pdf/2505.16932
by Noah Amsel, David Persson, Christopher Musco, Robert M. Gower.
NorMuon variance reduction: per-neuron/column adaptive learning rate that normalizes
update scales after orthogonalization (Muon's output has non-uniform scales across neurons).
https://arxiv.org/pdf/2510.05491
Some of the changes in nanochat implementation:
- Uses a simpler, more general approach to parameter grouping and stacking
- Uses a single fused kernel for the momentum -> polar_express -> variance_reduction -> update step
- Makes no assumptions about model architecture (e.g. that attention weights are fused into QKVO format)
"""
# Coefficients for Polar Express (computed for num_iters=5, safety_factor=2e-2, cushion=2)
# From https://arxiv.org/pdf/2505.16932
polar_express_coeffs = [
(8.156554524902461, -22.48329292557795, 15.878769915207462),
(4.042929935166739, -2.808917465908714, 0.5000178451051316),
(3.8916678022926607, -2.772484153217685, 0.5060648178503393),
(3.285753657755655, -2.3681294933425376, 0.46449024233003106),
(2.3465413258596377, -1.7097828382687081, 0.42323551169305323),
]
@torch.compile(dynamic=False, fullgraph=True)
def muon_step_fused(
stacked_grads: Tensor, # (12, 768, 3072) - stacked gradients
stacked_params: Tensor, # (12, 768, 3072) - stacked parameters
momentum_buffer: Tensor, # (12, 768, 3072) - first moment buffer
second_momentum_buffer: Tensor, # (12, 768, 1) or (12, 1, 3072) - factored second moment
momentum_t: Tensor, # () - 0-D CPU tensor, momentum coefficient
lr_t: Tensor, # () - 0-D CPU tensor, learning rate
wd_t: Tensor, # () - 0-D CPU tensor, weight decay
beta2_t: Tensor, # () - 0-D CPU tensor, beta2 for second moment
ns_steps: int, # 5 - number of Newton-Schulz/Polar Express iterations
red_dim: int, # -1 or -2 - reduction dimension for variance
) -> None:
"""
Fused Muon step: momentum -> polar_express -> variance_reduction -> cautious_update
All in one compiled graph to eliminate Python overhead between ops.
Some of the constants are 0-D CPU tensors to avoid recompilation when values change.
"""
# Nesterov momentum
momentum = momentum_t.to(stacked_grads.dtype)
momentum_buffer.lerp_(stacked_grads, 1 - momentum)
g = stacked_grads.lerp_(momentum_buffer, momentum)
# Polar express
X = g.bfloat16()
X = X / (X.norm(dim=(-2, -1), keepdim=True) * 1.02 + 1e-6)
if g.size(-2) > g.size(-1): # Tall matrix
for a, b, c in polar_express_coeffs[:ns_steps]:
A = X.mT @ X
B = b * A + c * (A @ A)
X = a * X + X @ B
else: # Wide matrix (original math)
for a, b, c in polar_express_coeffs[:ns_steps]:
A = X @ X.mT
B = b * A + c * (A @ A)
X = a * X + B @ X
g = X
# Variance reduction
beta2 = beta2_t.to(g.dtype)
v_mean = g.float().square().mean(dim=red_dim, keepdim=True)
red_dim_size = g.size(red_dim)
v_norm_sq = v_mean.sum(dim=(-2, -1), keepdim=True) * red_dim_size
v_norm = v_norm_sq.sqrt()
second_momentum_buffer.lerp_(v_mean.to(dtype=second_momentum_buffer.dtype), 1 - beta2)
step_size = second_momentum_buffer.clamp_min(1e-10).rsqrt()
scaled_sq_sum = (v_mean * red_dim_size) * step_size.float().square()
v_norm_new = scaled_sq_sum.sum(dim=(-2, -1), keepdim=True).sqrt()
final_scale = step_size * (v_norm / v_norm_new.clamp_min(1e-10))
g = g * final_scale.to(g.dtype)
# Cautious weight decay + parameter update
lr = lr_t.to(g.dtype)
wd = wd_t.to(g.dtype)
mask = (g * stacked_params) >= 0
stacked_params.sub_(lr * g + lr * wd * stacked_params * mask)
# -----------------------------------------------------------------------------
# Single GPU version of the MuonAdamW optimizer.
# Used mostly for reference, debugging and testing.
class MuonAdamW(torch.optim.Optimizer):
"""
Combined optimizer: Muon for 2D matrix params, AdamW for others, single GPU version.
AdamW - Fused AdamW optimizer step.
Muon - MomentUm Orthogonalized by Newton-schulz
https://kellerjordan.github.io/posts/muon/
Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has
the advantage that it can be stably run in bfloat16 on the GPU.
Some warnings:
- The Muon optimizer should not be used for the embedding layer, the final fully connected layer,
or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW).
- To use it with 4D convolutional filters, it works well to just flatten their last 3 dimensions.
Arguments:
param_groups: List of dicts, each containing:
- 'params': List of parameters
- 'kind': 'adamw' or 'muon'
- For AdamW groups: 'lr', 'betas', 'eps', 'weight_decay'
- For Muon groups: 'lr', 'momentum', 'ns_steps', 'beta2', 'weight_decay'
"""
def __init__(self, param_groups: list[dict]):
super().__init__(param_groups, defaults={})
# 0-D CPU tensors to avoid torch.compile recompilation when values change
# AdamW tensors
self._adamw_step_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
self._adamw_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
self._adamw_beta1_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
self._adamw_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
self._adamw_eps_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
self._adamw_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
# Muon tensors
self._muon_momentum_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
self._muon_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
self._muon_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
self._muon_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
def _step_adamw(self, group: dict) -> None:
"""
AdamW update for each param in the group individually.
Lazy init the state, fill in all 0-D tensors, call the fused kernel.
"""
for p in group['params']:
if p.grad is None:
continue
grad = p.grad
state = self.state[p]
# State init
if not state:
state['step'] = 0
state['exp_avg'] = torch.zeros_like(p)
state['exp_avg_sq'] = torch.zeros_like(p)
exp_avg = state['exp_avg']
exp_avg_sq = state['exp_avg_sq']
state['step'] += 1
# Fill 0-D tensors with current values
self._adamw_step_t.fill_(state['step'])
self._adamw_lr_t.fill_(group['lr'])
self._adamw_beta1_t.fill_(group['betas'][0])
self._adamw_beta2_t.fill_(group['betas'][1])
self._adamw_eps_t.fill_(group['eps'])
self._adamw_wd_t.fill_(group['weight_decay'])
# Fused update: weight_decay -> momentum -> bias_correction -> param_update
adamw_step_fused(
p, grad, exp_avg, exp_avg_sq,
self._adamw_step_t, self._adamw_lr_t, self._adamw_beta1_t,
self._adamw_beta2_t, self._adamw_eps_t, self._adamw_wd_t,
)
def _step_muon(self, group: dict) -> None:
"""
Muon update for all params in the group (stacked for efficiency).
Lazy init the state, fill in all 0-D tensors, call the fused kernel.
"""
params: list[Tensor] = group['params']
if not params:
return
# Get or create group-level buffers (stored in first param's state for convenience)
p = params[0]
state = self.state[p]
num_params = len(params)
shape, device, dtype = p.shape, p.device, p.dtype
# Momentum for every individual parameter
if "momentum_buffer" not in state:
state["momentum_buffer"] = torch.zeros(num_params, *shape, dtype=dtype, device=device)
momentum_buffer = state["momentum_buffer"]
# Second momentum buffer is factored, either per-row or per-column
if "second_momentum_buffer" not in state:
state_shape = (num_params, shape[-2], 1) if shape[-2] >= shape[-1] else (num_params, 1, shape[-1])
state["second_momentum_buffer"] = torch.zeros(state_shape, dtype=dtype, device=device)
second_momentum_buffer = state["second_momentum_buffer"]
red_dim = -1 if shape[-2] >= shape[-1] else -2
# Stack grads and params (NOTE: this assumes all params have the same shape)
stacked_grads = torch.stack([p.grad for p in params])
stacked_params = torch.stack(params)
# Fill all the 0-D tensors with current values
self._muon_momentum_t.fill_(group["momentum"])
self._muon_beta2_t.fill_(group["beta2"] if group["beta2"] is not None else 0.0)
self._muon_lr_t.fill_(group["lr"] * max(1.0, shape[-2] / shape[-1])**0.5)
self._muon_wd_t.fill_(group["weight_decay"])
# Single fused kernel: momentum -> polar_express -> variance_reduction -> update
muon_step_fused(
stacked_grads,
stacked_params,
momentum_buffer,
second_momentum_buffer,
self._muon_momentum_t,
self._muon_lr_t,
self._muon_wd_t,
self._muon_beta2_t,
group["ns_steps"],
red_dim,
)
# Copy back to original params
torch._foreach_copy_(params, list(stacked_params.unbind(0)))
@torch.no_grad()
def step(self):
for group in self.param_groups:
if group['kind'] == 'adamw':
self._step_adamw(group)
elif group['kind'] == 'muon':
self._step_muon(group)
else:
raise ValueError(f"Unknown optimizer kind: {group['kind']}")
# -----------------------------------------------------------------------------
# Distributed version of the MuonAdamW optimizer.
# Used for training on multiple GPUs.
class DistMuonAdamW(torch.optim.Optimizer):
"""
Combined distributed optimizer: Muon for 2D matrix params, AdamW for others.
See MuonAdamW for the algorithmic details of each optimizer. This class adds
distributed communication to enable multi-GPU training without PyTorch DDP.
Design Goals:
- Overlap communication with computation (async ops)
- Minimize memory by sharding optimizer states across ranks (ZeRO-2 style)
- Batch small tensors into single comm ops where possible
Communication Pattern (3-phase async):
We use a 3-phase structure to maximize overlap between communication and compute:
Phase 1: Launch all async reduce ops
- Kick off all reduce_scatter/all_reduce operations
- Don't wait - let them run in background while we continue
Phase 2: Wait for reduces, compute updates, launch gathers
- For each group: wait for its reduce, compute the update, launch gather
- By processing groups in order, earlier gathers run while later computes happen
Phase 3: Wait for gathers, copy back
- Wait for all gathers to complete
- Copy updated params back to original tensors (Muon only)
AdamW Communication (ZeRO-2 style):
- Small params (<1024 elements): all_reduce gradients, update full param on each rank.
Optimizer state is replicated but these params are tiny (scalars, biases).
- Large params: reduce_scatter gradients so each rank gets 1/N of the grad, update
only that slice, then all_gather the updated slices. Optimizer state (exp_avg,
exp_avg_sq) is sharded - each rank only stores state for its slice.
Requires param.shape[0] divisible by world_size.
Muon Communication (stacked + chunked):
- All params in a Muon group must have the same shape (caller's responsibility).
- Stack all K params into a single (K, *shape) tensor for efficient comm.
- Divide K params across N ranks: each rank "owns" ceil(K/N) params.
- reduce_scatter the stacked grads so each rank gets its chunk.
- Each rank computes Muon update only for params it owns.
- all_gather the updated params back to all ranks.
- Optimizer state (momentum_buffer, second_momentum_buffer) is sharded by chunk.
- Padding: if K doesn't divide evenly, we zero-pad to (ceil(K/N) * N) for comm,
then ignore the padding when copying back.
Buffer Reuse:
- For Muon, we allocate stacked_grads for reduce_scatter input, then reuse the
same buffer as the output for all_gather (stacked_params). This saves memory
since we don't need both buffers simultaneously.
Arguments:
param_groups: List of dicts, each containing:
- 'params': List of parameters
- 'kind': 'adamw' or 'muon'
- For AdamW groups: 'lr', 'betas', 'eps', 'weight_decay'
- For Muon groups: 'lr', 'momentum', 'ns_steps', 'beta2', 'weight_decay'
"""
def __init__(self, param_groups: list[dict]):
super().__init__(param_groups, defaults={})
# 0-D CPU tensors to avoid torch.compile recompilation when values change
self._adamw_step_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
self._adamw_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
self._adamw_beta1_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
self._adamw_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
self._adamw_eps_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
self._adamw_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
self._muon_momentum_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
self._muon_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
self._muon_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
self._muon_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
def _reduce_adamw(self, group: dict, world_size: int) -> dict:
"""Launch async reduce ops for AdamW group. Returns info dict with per-param infos."""
param_infos = {}
for p in group['params']:
grad = p.grad
if p.numel() < 1024:
# Small params: all_reduce (no scatter/gather needed)
future = dist.all_reduce(grad, op=dist.ReduceOp.AVG, async_op=True).get_future()
param_infos[p] = dict(future=future, grad_slice=grad, is_small=True)
else:
# Large params: reduce_scatter
assert grad.shape[0] % world_size == 0, f"AdamW reduce_scatter requires shape[0] ({grad.shape[0]}) divisible by world_size ({world_size})"
rank_size = grad.shape[0] // world_size
grad_slice = torch.empty_like(grad[:rank_size])
future = dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()
param_infos[p] = dict(future=future, grad_slice=grad_slice, is_small=False)
return dict(param_infos=param_infos)
def _reduce_muon(self, group: dict, world_size: int) -> dict:
"""Launch async reduce op for Muon group. Returns info dict."""
params = group['params']
chunk_size = (len(params) + world_size - 1) // world_size
padded_num_params = chunk_size * world_size
p = params[0]
shape, device, dtype = p.shape, p.device, p.dtype
# Stack grads and zero-pad to padded_num_params
grad_stack = torch.stack([p.grad for p in params])
stacked_grads = torch.empty(padded_num_params, *shape, dtype=dtype, device=device)
stacked_grads[:len(params)].copy_(grad_stack)
if len(params) < padded_num_params:
stacked_grads[len(params):].zero_()
# Reduce_scatter to get this rank's chunk
grad_chunk = torch.empty(chunk_size, *shape, dtype=dtype, device=device)
future = dist.reduce_scatter_tensor(grad_chunk, stacked_grads, op=dist.ReduceOp.AVG, async_op=True).get_future()
return dict(future=future, grad_chunk=grad_chunk, stacked_grads=stacked_grads, chunk_size=chunk_size)
def _compute_adamw(self, group: dict, info: dict, gather_list: list, rank: int, world_size: int) -> None:
"""Wait for reduce, compute AdamW updates, launch gathers for large params."""
param_infos = info['param_infos']
for p in group['params']:
pinfo = param_infos[p]
pinfo['future'].wait()
grad_slice = pinfo['grad_slice']
state = self.state[p]
# For small params, operate on full param; for large, operate on slice
if pinfo['is_small']:
p_slice = p
else:
rank_size = p.shape[0] // world_size
p_slice = p[rank * rank_size:(rank + 1) * rank_size]
# State init
if not state:
state['step'] = 0
state['exp_avg'] = torch.zeros_like(p_slice)
state['exp_avg_sq'] = torch.zeros_like(p_slice)
state['step'] += 1
# Fill 0-D tensors and run fused kernel
self._adamw_step_t.fill_(state['step'])
self._adamw_lr_t.fill_(group['lr'])
self._adamw_beta1_t.fill_(group['betas'][0])
self._adamw_beta2_t.fill_(group['betas'][1])
self._adamw_eps_t.fill_(group['eps'])
self._adamw_wd_t.fill_(group['weight_decay'])
adamw_step_fused(
p_slice, grad_slice, state['exp_avg'], state['exp_avg_sq'],
self._adamw_step_t, self._adamw_lr_t, self._adamw_beta1_t,
self._adamw_beta2_t, self._adamw_eps_t, self._adamw_wd_t,
)
# Large params need all_gather
if not pinfo['is_small']:
future = dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()
gather_list.append(dict(future=future, params=None))
def _compute_muon(self, group: dict, info: dict, gather_list: list, rank: int) -> None:
"""Wait for reduce, compute Muon updates, launch gather."""
info['future'].wait()
params = group['params']
chunk_size = info['chunk_size']
grad_chunk = info['grad_chunk']
p = params[0]
shape, device, dtype = p.shape, p.device, p.dtype
# How many params does this rank own?
start_idx = rank * chunk_size
num_owned = min(chunk_size, max(0, len(params) - start_idx))
# Get or create group-level state
state = self.state[p]
if "momentum_buffer" not in state:
state["momentum_buffer"] = torch.zeros(chunk_size, *shape, dtype=dtype, device=device)
if "second_momentum_buffer" not in state:
state_shape = (chunk_size, shape[-2], 1) if shape[-2] >= shape[-1] else (chunk_size, 1, shape[-1])
state["second_momentum_buffer"] = torch.zeros(state_shape, dtype=dtype, device=device)
red_dim = -1 if shape[-2] >= shape[-1] else -2
# Build output buffer for all_gather
updated_params = torch.empty(chunk_size, *shape, dtype=dtype, device=device)
if num_owned > 0:
owned_params = [params[start_idx + i] for i in range(num_owned)]
stacked_owned = torch.stack(owned_params)
# Fill 0-D tensors and run fused kernel
self._muon_momentum_t.fill_(group["momentum"])
self._muon_beta2_t.fill_(group["beta2"])
self._muon_lr_t.fill_(group["lr"] * max(1.0, shape[-2] / shape[-1])**0.5)
self._muon_wd_t.fill_(group["weight_decay"])
muon_step_fused(
grad_chunk[:num_owned], stacked_owned,
state["momentum_buffer"][:num_owned], state["second_momentum_buffer"][:num_owned],
self._muon_momentum_t, self._muon_lr_t, self._muon_wd_t, self._muon_beta2_t,
group["ns_steps"], red_dim,
)
updated_params[:num_owned].copy_(stacked_owned)
if num_owned < chunk_size:
updated_params[num_owned:].zero_()
# Reuse stacked_grads buffer for all_gather output
stacked_params = info["stacked_grads"]
future = dist.all_gather_into_tensor(stacked_params, updated_params, async_op=True).get_future()
gather_list.append(dict(future=future, stacked_params=stacked_params, params=params))
def _finish_gathers(self, gather_list: list) -> None:
"""Wait for all gathers and copy Muon params back."""
for info in gather_list:
info["future"].wait()
if info["params"] is not None:
# Muon: copy from stacked buffer back to individual params
torch._foreach_copy_(info["params"], list(info["stacked_params"][:len(info["params"])].unbind(0)))
@torch.no_grad()
def step(self):
rank = dist.get_rank()
world_size = dist.get_world_size()
# Phase 1: launch all async reduce ops
reduce_infos: list[dict] = []
for group in self.param_groups:
if group['kind'] == 'adamw':
reduce_infos.append(self._reduce_adamw(group, world_size))
elif group['kind'] == 'muon':
reduce_infos.append(self._reduce_muon(group, world_size))
else:
raise ValueError(f"Unknown optimizer kind: {group['kind']}")
# Phase 2: wait for reduces, compute updates, launch gathers
gather_list: list[dict] = []
for group, info in zip(self.param_groups, reduce_infos):
if group['kind'] == 'adamw':
self._compute_adamw(group, info, gather_list, rank, world_size)
elif group['kind'] == 'muon':
self._compute_muon(group, info, gather_list, rank)
else:
raise ValueError(f"Unknown optimizer kind: {group['kind']}")
# Phase 3: wait for gathers, copy back
self._finish_gathers(gather_list)

View File

@ -211,8 +211,6 @@ EXPECTED_FILES = [
"base-model-training.md",
"base-model-loss.md",
"base-model-evaluation.md",
"midtraining.md",
"chat-evaluation-mid.md",
"chat-sft.md",
"chat-evaluation-sft.md",
"chat-rl.md",
@ -316,8 +314,6 @@ class Report:
# extract the most important metrics from the sections
if file_name == "base-model-evaluation.md":
final_metrics["base"] = extract(section, "CORE")
if file_name == "chat-evaluation-mid.md":
final_metrics["mid"] = extract(section, chat_metrics)
if file_name == "chat-evaluation-sft.md":
final_metrics["sft"] = extract(section, chat_metrics)
if file_name == "chat-evaluation-rl.md":
@ -337,7 +333,7 @@ class Report:
# Custom ordering: CORE first, ChatCORE last, rest in middle
all_metrics = sorted(all_metrics, key=lambda x: (x != "CORE", x == "ChatCORE", x))
# Fixed column widths
stages = ["base", "mid", "sft", "rl"]
stages = ["base", "sft", "rl"]
metric_width = 15
value_width = 8
# Write table header

View File

@ -19,7 +19,7 @@ dependencies = [
"tabulate>=0.9.0",
"tiktoken>=0.11.0",
"tokenizers>=0.22.0",
"torch>=2.9.0",
"torch==2.9.1",
"transformers>=4.57.3",
"uvicorn>=0.36.0",
"wandb>=0.21.3",
@ -59,10 +59,10 @@ explicit = true
[project.optional-dependencies]
cpu = [
"torch>=2.9.1",
"torch==2.9.1",
]
gpu = [
"torch>=2.9.1",
"torch==2.9.1",
]
[tool.uv]

View File

@ -1,92 +0,0 @@
#!/bin/bash
# The $1000 tier of nanochat
# Designed to run end-to-end for $1000/24 ~= 41.6 hours on an 8XH100 node
# A bit sparser on comments, see speedrun.sh for more detail
# all the setup stuff
export OMP_NUM_THREADS=1
export NANOCHAT_BASE_DIR="$HOME/.cache/nanochat"
mkdir -p $NANOCHAT_BASE_DIR
command -v uv &> /dev/null || curl -LsSf https://astral.sh/uv/install.sh | sh
uv sync --extra gpu
source .venv/bin/activate
if [ -z "$WANDB_RUN" ]; then
WANDB_RUN=dummy
fi
python -m nanochat.report reset
curl -L -o $NANOCHAT_BASE_DIR/identity_conversations.jsonl https://karpathy-public.s3.us-west-2.amazonaws.com/identity_conversations.jsonl
# train tokenizer on ~4B characters and kick off download of the rest for pretraining
python -m nanochat.dataset -n 16
# start downloading the rest of the shards for a total of 1200 (see below why 1200)
python -m nanochat.dataset -n 1200 &
# todo: download the rest of it
python -m scripts.tok_train --max-chars=4000000000 --vocab-size=65536
python -m scripts.tok_eval
# Documenting my process for determining the hyperparameters for this run1000.sh script:
# We want a budget of approx. $1000 ~= 41.6 hours of 8XH100 compute
# 1) I guessed the model size for this to be about depth=32
# 2) Determine the device_batch_size that fits:
# Running the base_train.py script with --depth=32, I saw that --device-batch-size=16
# runs out of memory, but --device-batch-size=8 fits. Inspecting `nvidia-smi` during training,
# I saw all GPUs were at about 78/80GB VRAM, so it just barely fits and we have good MFU at ~50%.
# So the training script was running ok and showed:
# Vocab size: 65,536
# num_layers: 32
# model_dim: 2048
# num_heads: 16
# num_kv_heads: 16
# Tokens / micro-batch / rank: 8 x 2048 = 16,384
# Tokens / micro-batch: 131,072
# Total batch size 524,288 => gradient accumulation steps: 4
# Number of parameters: 1,879,048,192
# Estimated FLOPs per token: 1.207960e+10
# Calculated number of iterations from target data:param ratio: 71,680
# Total number of training tokens: 37,580,963,840
# Tokens : Params ratio: 20.00
# Total training FLOPs estimate: 4.539628e+20
# step 00004/71680 (0.01%) | loss: 8.813754 | lrm: 1.00 | dt: 1571.88ms | tok/sec: 83,385 | mfu: 50.92 | total time: 0.00m
# step 00005/71680 (0.01%) | loss: 8.488074 | lrm: 1.00 | dt: 1572.76ms | tok/sec: 83,338 | mfu: 50.89 | total time: 0.00m
# ...
# 3) validate that the runtime fits our budget:
# The training script uses the Chinchilla scaling law to compute-optimally set #tokens = 20 * #params. In particular:
# The script shows that we will be training for 71,680 steps, and each step takes 1.574s so:
# estimated time to train: 71,680 * 1.574s / 60 / 60 = 31.3 hours.
# This is OK, fits our budget, and leaves ~10 hours for midtraining and SFT and evals and maybe RL.
# It's possible that we might even fit depth=33 or depth=34, but for now let's go along with this.
# 4) The last thing to pay attention to is the amount of training data required for the run.
# The script above calculated that "Total number of training tokens: 37,580,963,840"
# The tok_eval.py script reports about ~4.8 chars/token on average for the default tokenizer settings.
# So ~38B tokens # ~4.8 chars/token = ~185B chars.
# Each data shard is ~250M chars, so we need ~185B / 250M ~= 740 shards.
# For safety, I bumped that up to 800 shards.
# The new DataLoader wastes about 35% of tokens to cropping, so 800 / (1 - 0.35) ~= 1200 shards are needed.
# => why up above I used -n 1200 when pre-downloading dataset shards.
# If we didn't have enough data, the training script would loop around and do multiple epochs over the same data,
# which would decrease model performance. Possibly 2, 3 or so epochs is ~ok, but certainly not ideal and at 10+ epochs we'd
# start to overfit hard.
# 5) That's it, everything else (e.g. the learning rates) is adjusted automatically by the training script.
# Number of processes/GPUs to use
NPROC_PER_NODE=8
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train -- --depth=32 --target-param-data-ratio=20 --device-batch-size=8 --run=$WANDB_RUN
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_loss
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_eval
# midtrain
# NOTE: ensure that we use the same device_batch_size here as the base training script.
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.mid_train -- --device-batch-size=8 --run=$WANDB_RUN
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_eval -- -i mid
# sft
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_sft -- --run=$WANDB_RUN
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_eval -- -i sft
# generate final report
python -m nanochat.report generate
# talk to it
python -m scripts.chat_web

View File

@ -27,7 +27,7 @@ fi
# Series name: from arg, env var, or default to today's date (e.g., jan11)
SERIES_NAME="${1:-${SERIES_NAME:-$(date +%b%d | tr '[:upper:]' '[:lower:]')}}"
# Depths to train (the "miniseries")
DEPTHS=(10 11 12 13 14 15 16 17 18 19 20)
DEPTHS=(12 14 16 18 20 22 24 26)
# Hardware
NPROC_PER_NODE="${NPROC_PER_NODE:-8}"
# Logging
@ -56,17 +56,24 @@ for d in "${DEPTHS[@]}"; do
TAG="${SERIES_NAME}_miniseries_d${d}"
START_TIME=$(date +%s)
# Train the model with natural horizon (target_param_data_ratio default)
# No --target-flops, let it use the default ratio from base_train
# Reduce --device-batch-size to avoid OOM at larger depths
if [ $d -ge 28 ]; then
DEVICE_BATCH_SIZE_ARG="--device-batch-size=8"
elif [ $d -ge 20 ]; then
DEVICE_BATCH_SIZE_ARG="--device-batch-size=16"
else
DEVICE_BATCH_SIZE_ARG="--device-batch-size=32"
fi
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train -- \
--depth=$d \
--target-param-data-ratio=8 \
--run="${WANDB_RUN}_d${d}" \
--model-tag="${TAG}" \
--core-metric-every=999999 \
--core-metric-max-per-task=-1 \
--sample-every=-1 \
--save-every=-1 \
$DEVICE_BATCH_SIZE_ARG \
2>&1 | tee "$RESULTS_DIR/${TAG}_train.log"
END_TIME=$(date +%s)

65
runs/runcpu.sh Executable file
View File

@ -0,0 +1,65 @@
#!/bin/bash
# Showing an example run for exercising some of the code paths on the CPU (or MPS on Macbooks)
# This script was last updated/tuned on Jan 17, 2026.
# Run as:
# bash runs/runcpu.sh
# NOTE: Training LLMs requires GPU compute and $$$. You will not get far on your Macbook.
# Think of this run as educational/fun demo, not something you should expect to work well.
# You may also want to run this script manually and one by one, copy pasting commands into your terminal.
# all the setup stuff
export NANOCHAT_BASE_DIR="$HOME/.cache/nanochat"
mkdir -p $NANOCHAT_BASE_DIR
command -v uv &> /dev/null || curl -LsSf https://astral.sh/uv/install.sh | sh
[ -d ".venv" ] || uv venv
uv sync --extra cpu
source .venv/bin/activate
if [ -z "$WANDB_RUN" ]; then
WANDB_RUN=dummy
fi
# train tokenizer on ~2B characters (~34 seconds on my MacBook Pro M3 Max)
python -m nanochat.dataset -n 8
python -m scripts.tok_train --max-chars=2000000000
python -m scripts.tok_eval
# train a small 4 layer model
# I tuned this run to complete in about 30 minutes on my MacBook Pro M3 Max.
# To get better results, try increasing num_iterations, or get other ideas from your favorite LLM.
python -m scripts.base_train \
--depth=6 \
--head-dim=64 \
--window-pattern=L \
--max-seq-len=512 \
--device-batch-size=32 \
--total-batch-size=16384 \
--eval-every=100 \
--eval-tokens=524288 \
--core-metric-every=-1 \
--sample-every=100 \
--num-iterations=5000 \
--run=$WANDB_RUN
python -m scripts.base_eval --device-batch-size=1 --split-tokens=16384 --max-per-task=16
# SFT (~10 minutes on my MacBook Pro M3 Max)
curl -L -o $NANOCHAT_BASE_DIR/identity_conversations.jsonl https://karpathy-public.s3.us-west-2.amazonaws.com/identity_conversations.jsonl
python -m scripts.chat_sft \
--max-seq-len=512 \
--device-batch-size=32 \
--total-batch-size=16384 \
--eval-every=200 \
--eval-tokens=524288 \
--num-iterations=1500 \
--run=$WANDB_RUN
# Chat with the model over CLI
# The model should be able to say that it is Paris.
# It might even know that the color of the sky is blue.
# Sometimes the model likes it if you first say Hi before you ask it questions.
# python -m scripts.chat_cli -p "What is the capital of France?"
# Chat with the model over a pretty WebUI ChatGPT style
# python -m scripts.chat_web

View File

@ -1,26 +1,30 @@
#!/bin/bash
LABEL="jan26"
FLOPS_BUDGETS=(
1e18
3e18
6e18
2.15e18
4.64e18
1e19
)
DEPTHS=(8 10 12 14 16 18 20)
NPROC_PER_NODE="${NPROC_PER_NODE:-8}"
WANDB_RUN="${WANDB_RUN:-scaling}"
WANDB_RUN="${WANDB_RUN:-scaling_${LABEL}}"
EVAL_TOKENS=$((100 * 524288)) # ~100M tokens for final eval (default is ~10M)
export OMP_NUM_THREADS=1
export NANOCHAT_BASE_DIR="${NANOCHAT_BASE_DIR:-$HOME/.cache/nanochat}"
source .venv/bin/activate
RESULTS_DIR="$NANOCHAT_BASE_DIR/scaling_laws_results"
RESULTS_DIR="$NANOCHAT_BASE_DIR/scaling_laws_results_${LABEL}"
mkdir -p "$RESULTS_DIR"
RESULTS_FILE="$RESULTS_DIR/results.csv"
# Write CSV header only if file doesn't exist
if [ ! -f "$RESULTS_FILE" ]; then
echo "flops_budget,depth,model_dim,num_params,num_scaling_params,num_iterations,tokens_trained,param_data_ratio,val_bpb,core_score,train_time_sec" > "$RESULTS_FILE"
echo "flops_budget,depth,model_dim,params_wte,params_bigram_embed,params_value_embeds,params_lm_head,params_transformer,params_scalars,params_total,num_iterations,tokens_trained,val_bpb,core_score,train_time_sec" > "$RESULTS_FILE"
fi
log() {
@ -80,13 +84,19 @@ for flops in "${FLOPS_BUDGETS[@]}"; do
# Extract training stats from the log
LOG_FILE="$RESULTS_DIR/${TAG}_train.log"
NUM_PARAMS=$(grep "Number of parameters:" "$LOG_FILE" | tail -1 | grep -oP '[\d,]+' | head -1 | tr -d ',')
NUM_SCALING_PARAMS=$(grep "Number of parameters:" "$LOG_FILE" | tail -1 | grep -oP 'scaling: [\d,]+' | grep -oP '[\d,]+' | tr -d ',')
# Extract detailed parameter counts (for scaling law analysis with different conventions)
PARAMS_WTE=$(grep "wte:" "$LOG_FILE" | tail -1 | grep -oP '[\d,]+' | tr -d ',')
PARAMS_BIGRAM=$(grep "bigram_embed:" "$LOG_FILE" | tail -1 | grep -oP '[\d,]+' | tr -d ',')
PARAMS_VE=$(grep "value_embeds:" "$LOG_FILE" | tail -1 | grep -oP '[\d,]+' | tr -d ',')
PARAMS_LM=$(grep "lm_head:" "$LOG_FILE" | tail -1 | grep -oP '[\d,]+' | tr -d ',')
PARAMS_TRANSFORMER=$(grep "transformer_matrices:" "$LOG_FILE" | tail -1 | grep -oP '[\d,]+' | tr -d ',')
PARAMS_SCALARS=$(grep "scalars:" "$LOG_FILE" | tail -1 | grep -oP '[\d,]+' | tr -d ',')
PARAMS_TOTAL=$(grep "total:" "$LOG_FILE" | tail -1 | grep -oP '[\d,]+' | tr -d ',')
NUM_ITERS=$(grep "Calculated number of iterations" "$LOG_FILE" | tail -1 | sed 's/.*: //' | tr -d ',')
# Calculate tokens trained (iterations * batch_size, default 524288)
TOKENS_TRAINED=$((NUM_ITERS * 524288))
# Param:data ratio (using scaling params per Kaplan et al.)
PARAM_DATA_RATIO=$(python -c "print(f'{$TOKENS_TRAINED / $NUM_SCALING_PARAMS:.2f}')")
# Model dim
MODEL_DIM=$((d * 64))
# Val BPB from final eval
@ -99,10 +109,10 @@ for flops in "${FLOPS_BUDGETS[@]}"; do
CORE_SCORE="0.0"
fi
log " Params: $NUM_PARAMS, Iters: $NUM_ITERS, Ratio: $PARAM_DATA_RATIO, Val BPB: $VAL_BPB, CORE: $CORE_SCORE"
log " Params: $PARAMS_TOTAL (transformer: $PARAMS_TRANSFORMER), Iters: $NUM_ITERS, Val BPB: $VAL_BPB, CORE: $CORE_SCORE"
# Append to CSV
echo "$flops,$d,$MODEL_DIM,$NUM_PARAMS,$NUM_SCALING_PARAMS,$NUM_ITERS,$TOKENS_TRAINED,$PARAM_DATA_RATIO,$VAL_BPB,$CORE_SCORE,$TRAIN_TIME" >> "$RESULTS_FILE"
echo "$flops,$d,$MODEL_DIM,$PARAMS_WTE,$PARAMS_BIGRAM,$PARAMS_VE,$PARAMS_LM,$PARAMS_TRANSFORMER,$PARAMS_SCALARS,$PARAMS_TOTAL,$NUM_ITERS,$TOKENS_TRAINED,$VAL_BPB,$CORE_SCORE,$TRAIN_TIME" >> "$RESULTS_FILE"
done
done

View File

@ -1,14 +1,14 @@
#!/bin/bash
# This script is the "Best ChatGPT clone that $100 can buy",
# It is designed to run in ~4 hours on 8XH100 node at $3/GPU/hour.
# This script is configured to train your own GPT-2 grade LLM (pretraining + finetuning)
# It is designed to run on a blank 8XH100 GPU node and takes approximately 3 hours to complete.
# 1) Example launch (simplest):
# bash speedrun.sh
# 2) Example launch in a screen session (because the run takes ~4 hours):
# screen -L -Logfile speedrun.log -S speedrun bash speedrun.sh
# bash runs/speedrun.sh
# 2) Example launch in a screen session (because the run takes ~3 hours):
# screen -L -Logfile runs/speedrun.log -S speedrun bash runs/speedrun.sh
# 3) Example launch with wandb logging, but see below for setting up wandb first:
# WANDB_RUN=speedrun screen -L -Logfile speedrun.log -S speedrun bash speedrun.sh
# WANDB_RUN=speedrun screen -L -Logfile runs/speedrun.log -S speedrun bash runs/speedrun.sh
# Default intermediate artifacts directory is in ~/.cache/nanochat
export OMP_NUM_THREADS=1
@ -47,61 +47,41 @@ python -m nanochat.report reset
# Tokenizer
# Download the first ~2B characters of pretraining dataset
# look at dev/repackage_data_reference.py for details on how this data was prepared
# each data shard is ~250M chars
# so we download 2e9 / 250e6 = 8 data shards at this point
# each shard is ~100MB of text (compressed), so this is about ~800MB of data on disk
# look at dev/repackage_data_reference.py for details on how this data was prepared
python -m nanochat.dataset -n 8
# Immediately also kick off downloading more shards in the background while tokenizer trains
# See comment below for why 370 is the right number here
# Approximately 350 shards are needed for 10B tokens of data for pretraining.
# The maximum total number of shards available in the entire dataset is 1822.
python -m nanochat.dataset -n 370 &
DATASET_DOWNLOAD_PID=$!
# train the tokenizer with vocab size 2**16 = 65536 on ~2B characters of data
python -m scripts.tok_train --max-chars=2000000000 --vocab-size=65536
# train the tokenizer with vocab size 2**15 = 32768 on ~2B characters of data
python -m scripts.tok_train
# evaluate the tokenizer (report compression ratio etc.)
python -m scripts.tok_eval
# -----------------------------------------------------------------------------
# Base model (pretraining)
# The d20 model is 561M parameters.
# Chinchilla says #tokens = 20X #params, so we need 561e6 * 20 = 11.2B tokens.
# Assume our tokenizer is 4.8 chars/token, this is 11.2B * 4.8 ~= 54B chars.
# At 250M chars/shard, this is 54B / 250M ~= 216 shards needed for pretraining.
# Round up to 240 for safety. Also, the new DataLoader wastes about 35% of tokens to cropping
# so 240 / (1 - 0.35) = 370 shards are needed.
# At ~100MB/shard, this downloads ~37GB of data to disk.
# (The total number of shards available in the entire dataset is 1822.)
echo "Waiting for dataset download to complete..."
wait $DATASET_DOWNLOAD_PID
# Number of processes/GPUs to use
NPROC_PER_NODE=8
# pretrain the d20 model
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train -- --depth=20 --target-param-data-ratio=20 --run=$WANDB_RUN
# evaluate the model on a larger chunk of train/val data and draw some samples
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_loss
# evaluate the model on CORE tasks
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_eval
# d24 model (slightly overtrained is enough to beat GPT-2 => increase data:params ratio from compute optimal 10.5 (default) to 12)
torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- --depth=26 --target-param-data-ratio=8.25 --device-batch-size=16 --fp8 --run=$WANDB_RUN
# evaluate the model: CORE metric, BPB on train/val, and draw samples
torchrun --standalone --nproc_per_node=8 -m scripts.base_eval -- --device-batch-size=16
# -----------------------------------------------------------------------------
# Midtraining (teach the model conversation special tokens, tool use, multiple choice)
# SFT (teach the model conversation special tokens, tool use, multiple choice)
# download 2.3MB of synthetic identity conversations to impart a personality to nanochat
# see dev/gen_synthetic_data.py for details on how this data was prepared and to get a sense of how you can easily tune it
curl -L -o $NANOCHAT_BASE_DIR/identity_conversations.jsonl https://karpathy-public.s3.us-west-2.amazonaws.com/identity_conversations.jsonl
# run midtraining and eval the model
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.mid_train -- --run=$WANDB_RUN
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_eval -- -i mid
# -----------------------------------------------------------------------------
# Supervised Finetuning (domain adaptation to each sequence all by itself per row)
# train sft and re-eval right away (should see a small bump)
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_sft -- --run=$WANDB_RUN
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_eval -- -i sft
# run SFT and eval the model
torchrun --standalone --nproc_per_node=8 -m scripts.chat_sft -- --device-batch-size=16 --run=$WANDB_RUN
torchrun --standalone --nproc_per_node=8 -m scripts.chat_eval -- -i sft
# chat with the model over CLI! Leave out the -p to chat interactively
# python -m scripts.chat_cli -p "Why is the sky blue?"
@ -109,15 +89,6 @@ torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_eval -- -
# even better, chat with your model over a pretty WebUI ChatGPT style
# python -m scripts.chat_web
# -----------------------------------------------------------------------------
# Reinforcement Learning. Optional, and currently only on GSM8K
# (optional)
# run reinforcement learning
# torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_rl -- --run=$WANDB_RUN
# eval the RL model only on GSM8K
# torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_eval -- -i rl -a GSM8K
# -----------------------------------------------------------------------------
# Generate the full report by putting together all the sections
# report.md is the output and will be copied to current directory for convenience

View File

@ -1,13 +1,23 @@
"""
Evaluate the CORE metric for a given model.
Unified evaluation script for base models.
Run on a single GPU:
python -m scripts.base_eval
Supports three evaluation modes (comma-separated):
--eval core : CORE metric (accuracy on ICL tasks)
--eval bpb : Bits per byte on train/val splits
--eval sample : Generate samples from the model
Run with torchrun on e.g. 8 GPUs:
torchrun --nproc_per_node=8 -m scripts.base_eval
Default is all three: --eval core,bpb,sample
The script will print the CORE metric to the console.
Examples:
# Evaluate a HuggingFace model (e.g. GPT-2 124M) using 8 GPUs
torchrun --nproc_per_node=8 -m scripts.base_eval --hf-path openai-community/gpt2
# Evaluate a nanochat model (e.g. d24) using 8 GPUs
torchrun --nproc_per_node=8 -m scripts.base_eval --model-tag d24 --device-batch-size=16
# Quick/approximate evaluation using a single GPU
python -m scripts.base_eval --model-tag d24 --device-batch-size=16 --max-per-task=100 --split-tokens=524288
"""
import os
import csv
@ -18,24 +28,74 @@ import shutil
import random
import zipfile
import tempfile
import argparse
from contextlib import nullcontext
import torch
from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, autodetect_device_type, download_file_with_lock
from nanochat.tokenizer import HuggingFaceTokenizer
from nanochat.tokenizer import HuggingFaceTokenizer, get_token_bytes
from nanochat.checkpoint_manager import load_model
from nanochat.core_eval import evaluate_task
from nanochat.dataloader import tokenizing_distributed_data_loader_bos_bestfit
from nanochat.loss_eval import evaluate_bpb
from nanochat.engine import Engine
# -----------------------------------------------------------------------------
# nanochat specific function dealing with I/O etc.
# HuggingFace loading utilities
class ModelWrapper:
"""Lightweight wrapper to give HuggingFace models a nanochat-compatible interface."""
def __init__(self, model, max_seq_len=None):
self.model = model
self.max_seq_len = max_seq_len
def __call__(self, input_ids, targets=None, loss_reduction='mean'):
logits = self.model(input_ids).logits
if targets is None:
return logits
loss = torch.nn.functional.cross_entropy(
logits.view(-1, logits.size(-1)),
targets.view(-1),
ignore_index=-1,
reduction=loss_reduction
)
return loss
def get_device(self):
return next(self.model.parameters()).device
def load_hf_model(hf_path: str, device):
"""Load a HuggingFace model and tokenizer."""
print0(f"Loading HuggingFace model from: {hf_path}")
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained(hf_path)
model.to(device)
model.eval()
max_seq_len = 1024 if "gpt2" in hf_path else None
model = ModelWrapper(model, max_seq_len=max_seq_len)
tokenizer = HuggingFaceTokenizer.from_pretrained(hf_path)
return model, tokenizer
def get_hf_token_bytes(tokenizer, device="cpu"):
"""Compute token_bytes tensor for a HuggingFace tokenizer."""
vocab_size = tokenizer.tokenizer.get_vocab_size()
token_bytes = torch.zeros(vocab_size, dtype=torch.int64, device=device)
for token_id in range(vocab_size):
token_str = tokenizer.tokenizer.decode([token_id])
token_bytes[token_id] = len(token_str.encode('utf-8'))
return token_bytes
# -----------------------------------------------------------------------------
# CORE evaluation
# ~162MB of data needed to evaluate the CORE metric
EVAL_BUNDLE_URL = "https://karpathy-public.s3.us-west-2.amazonaws.com/eval_bundle.zip"
def place_eval_bundle(file_path):
# here file_path is the path to the eval_bundle.zip file
# we need to unzip it and place it in the base directory
"""Unzip eval_bundle.zip and place it in the base directory."""
base_dir = get_base_dir()
eval_bundle_dir = os.path.join(base_dir, "eval_bundle")
with tempfile.TemporaryDirectory() as tmpdir:
@ -45,25 +105,27 @@ def place_eval_bundle(file_path):
shutil.move(extracted_bundle_dir, eval_bundle_dir)
print0(f"Placed eval_bundle directory at {eval_bundle_dir}")
def evaluate_model(model, tokenizer, device, max_per_task=-1):
def evaluate_core(model, tokenizer, device, max_per_task=-1):
"""
Evaluate a base model on the CORE benchmark.
- max_per_task: crop the data to this many examples per task for testing (-1 = disable)
Returns dict with results, centered_results, and core_metric.
"""
# Load config and task metadata
base_dir = get_base_dir()
eval_bundle_dir = os.path.join(base_dir, "eval_bundle")
# Download the eval bundle to disk (and unzip if needed)
# Download the eval bundle if needed
if not os.path.exists(eval_bundle_dir):
download_file_with_lock(EVAL_BUNDLE_URL, "eval_bundle.zip", postprocess_fn=place_eval_bundle)
config_path = os.path.join(eval_bundle_dir, "core.yaml")
data_base_path = os.path.join(eval_bundle_dir, "eval_data")
eval_meta_data = os.path.join(eval_bundle_dir, "eval_meta_data.csv")
with open(config_path, 'r', encoding='utf-8') as f:
config = yaml.safe_load(f)
tasks = config['icl_tasks']
# Load random baseline values from eval metadata
# Load random baseline values
random_baselines = {}
with open(eval_meta_data, 'r', encoding='utf-8') as f:
reader = csv.DictReader(f)
@ -86,27 +148,23 @@ def evaluate_model(model, tokenizer, device, max_per_task=-1):
}
print0(f"Evaluating: {label} ({task_meta['num_fewshot']}-shot, type: {task_meta['task_type']})... ", end='')
# Load data for this task
data_path = os.path.join(data_base_path, task_meta['dataset_uri'])
with open(data_path, 'r', encoding='utf-8') as f:
data = [json.loads(line.strip()) for line in f]
# shuffle the data because in many cases it appears ordered but we want
# the ability to only run a subset of the data for debugging purposes etc.
# Shuffle for consistent subsampling when using max_per_task
shuffle_rng = random.Random(1337)
shuffle_rng.shuffle(data)
if max_per_task > 0:
data = data[:max_per_task]
# run the evaluation for this task
accuracy = evaluate_task(model, tokenizer, data, device, task_meta)
results[label] = accuracy
random_baseline = random_baselines[label]
centered_result = (accuracy - 0.01 * random_baseline) / (1.0 - 0.01 * random_baseline)
centered_results[label] = centered_result
end_time = time.time()
print0(f"accuracy: {accuracy:.4f} | centered: {centered_result:.4f} | time: {end_time - start_time:.2f}s")
elapsed = time.time() - start_time
print0(f"accuracy: {accuracy:.4f} | centered: {centered_result:.4f} | time: {elapsed:.2f}s")
core_metric = sum(centered_results.values()) / len(centered_results)
out = {
@ -117,98 +175,157 @@ def evaluate_model(model, tokenizer, device, max_per_task=-1):
return out
# -----------------------------------------------------------------------------
# HuggingFace loading utilities and light wrappers for a model
# Main
class ModelWrapper:
"""Lightweight wrapper for a HuggingFace model"""
def __init__(self, model, max_seq_len=None):
self.model = model
self.max_seq_len = max_seq_len
def __call__(self, input_ids):
outputs = self.model(input_ids)
logits = outputs.logits
return logits
def load_hf_model(hf_path: str, device):
print0(f"Loading model from: {hf_path}")
# Load the model
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained(hf_path)
model.to(device)
model.eval()
max_seq_len = 1024 if "openai-community/gpt2" in hf_path else None
model = ModelWrapper(model, max_seq_len=max_seq_len)
# Load the tokenizer
tokenizer = HuggingFaceTokenizer.from_pretrained(hf_path)
return model, tokenizer
# -----------------------------------------------------------------------------
def main():
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--hf-path', type=str, default=None, help='HuggingFace model path to evaluate')
parser.add_argument('--max-per-task', type=int, default=-1, help='Max examples per task to evaluate (-1 = disable)')
parser.add_argument('--model-tag', type=str, default=None, help='optional model tag for the output directory name')
parser.add_argument('--step', type=str, default=None, help='optional model step for the output directory name')
parser = argparse.ArgumentParser(description="Base model evaluation")
parser.add_argument('--eval', type=str, default='core,bpb,sample', help='Comma-separated evaluations to run: core,bpb,sample (default: all)')
parser.add_argument('--hf-path', type=str, default=None, help='HuggingFace model path (e.g. openai-community/gpt2-xl)')
parser.add_argument('--model-tag', type=str, default=None, help='nanochat model tag to identify the checkpoint directory')
parser.add_argument('--step', type=int, default=None, help='Model step to load (default = last)')
parser.add_argument('--max-per-task', type=int, default=-1, help='Max examples per CORE task (-1 = all)')
parser.add_argument('--device-batch-size', type=int, default=32, help='Per-device batch size for BPB evaluation')
parser.add_argument('--split-tokens', type=int, default=40*524288, help='Number of tokens to evaluate per split for BPB')
parser.add_argument('--device-type', type=str, default='', help='cuda|cpu|mps (empty = autodetect)')
args = parser.parse_args()
# distributed / precision setup
device_type = autodetect_device_type()
# Parse evaluation modes
eval_modes = set(mode.strip() for mode in args.eval.split(','))
valid_modes = {'core', 'bpb', 'sample'}
invalid = eval_modes - valid_modes
if invalid:
parser.error(f"Invalid eval modes: {invalid}. Valid: {valid_modes}")
# Distributed / precision setup
device_type = autodetect_device_type() if args.device_type == '' else args.device_type
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
# Load model and tokenizer from command line or from file system
if args.hf_path is not None:
# atm assume that if a path is given, it's a huggingface model path
hf_path = args.hf_path
print0(f"Loading huggingface model from: {hf_path}")
model, tokenizer = load_hf_model(hf_path, device)
model_name = hf_path # just for logging
model_slug = hf_path.replace("/", "-") # for the output csv file
# Load model and tokenizer
is_hf_model = args.hf_path is not None
if is_hf_model:
model, tokenizer = load_hf_model(args.hf_path, device)
sequence_len = model.max_seq_len or 1024
token_bytes = get_hf_token_bytes(tokenizer, device=device)
model_name = args.hf_path
model_slug = args.hf_path.replace("/", "-")
else:
# load a local model from the file system
model, tokenizer, meta = load_model("base", device, phase="eval", model_tag=args.model_tag, step=args.step)
model_name = f"base_model (step {meta['step']})" # just for logging
model_slug = f"base_model_{meta['step']:06d}" # for the output csv file
sequence_len = meta["model_config"]["sequence_len"]
token_bytes = get_token_bytes(device=device)
model_name = f"base_model (step {meta['step']})"
model_slug = f"base_model_{meta['step']:06d}"
# Evaluate the model
with autocast_ctx:
out = evaluate_model(model, tokenizer, device, max_per_task=args.max_per_task)
print0(f"Evaluating model: {model_name}")
print0(f"Eval modes: {', '.join(sorted(eval_modes))}")
# Write out the results to a csv file
core_metric = None
centered_results = {}
if ddp_rank == 0:
base_dir = get_base_dir()
output_csv_path = os.path.join(base_dir, "base_eval", f"{model_slug}.csv")
os.makedirs(os.path.dirname(output_csv_path), exist_ok=True)
results = out["results"]
centered_results = out["centered_results"]
core_metric = out["core_metric"]
with open(output_csv_path, 'w', encoding='utf-8', newline='') as f:
f.write(f"{'Task':<35}, {'Accuracy':<10}, {'Centered':<10}\n")
for label in results:
f.write(f"{label:<35}, {results[label]:<10.6f}, {centered_results[label]:<10.6f}\n")
f.write(f"{'CORE':<35}, {'':<10}, {core_metric:<10.6f}\n")
# Print the content of the csv file to console too
# Results to log
core_results = None
bpb_results = {}
samples = []
unconditioned_samples = []
# --- Sampling ---
if 'sample' in eval_modes and not is_hf_model:
print0("\n" + "="*80)
print0("Model Samples")
print0("="*80)
print0(f"Model: {model_name}")
print0("="*80)
with open(output_csv_path, 'r', encoding='utf-8') as f:
print0(f.read())
if ddp_rank == 0:
prompts = [
"The capital of France is",
"The chemical symbol of gold is",
"If yesterday was Friday, then tomorrow will be",
"The opposite of hot is",
"The planets of the solar system are:",
"My favorite color is",
"If 5*x + 3 = 13, then x is",
]
engine = Engine(model, tokenizer)
print0("\nConditioned samples:")
for prompt in prompts:
tokens = tokenizer(prompt, prepend="<|bos|>")
with autocast_ctx:
sample, _ = engine.generate_batch(tokens, num_samples=1, max_tokens=16, temperature=0)
sample_str = tokenizer.decode(sample[0])
print0("-" * 80)
print0(sample_str)
samples.append(sample_str)
# Log to report
print0("\nUnconditioned samples:")
tokens = tokenizer("", prepend="<|bos|>")
with autocast_ctx:
uncond, _ = engine.generate_batch(tokens, num_samples=8, max_tokens=128, temperature=1.0)
for sample in uncond:
sample_str = tokenizer.decode(sample)
print0("-" * 80)
print0(sample_str)
unconditioned_samples.append(sample_str)
elif 'sample' in eval_modes and is_hf_model:
print0("\nSkipping sampling for HuggingFace models (not supported)")
# --- BPB evaluation ---
if 'bpb' in eval_modes:
print0("\n" + "="*80)
print0("BPB Evaluation")
print0("="*80)
tokens_per_step = args.device_batch_size * sequence_len * ddp_world_size
if args.split_tokens % tokens_per_step != 0:
# Adjust to nearest multiple
args.split_tokens = (args.split_tokens // tokens_per_step) * tokens_per_step
print0(f"Adjusted split_tokens to {args.split_tokens} (must be divisible by {tokens_per_step})")
steps = args.split_tokens // tokens_per_step
for split_name in ["train", "val"]:
loader = tokenizing_distributed_data_loader_bos_bestfit(tokenizer, args.device_batch_size, sequence_len, split_name, device=device)
with autocast_ctx:
bpb = evaluate_bpb(model, loader, steps, token_bytes)
bpb_results[split_name] = bpb
print0(f"{split_name} bpb: {bpb:.6f}")
# --- CORE evaluation ---
if 'core' in eval_modes:
print0("\n" + "="*80)
print0("CORE Evaluation")
print0("="*80)
with autocast_ctx:
core_results = evaluate_core(model, tokenizer, device, max_per_task=args.max_per_task)
# Write CSV output
if ddp_rank == 0:
base_dir = get_base_dir()
output_csv_path = os.path.join(base_dir, "base_eval", f"{model_slug}.csv")
os.makedirs(os.path.dirname(output_csv_path), exist_ok=True)
with open(output_csv_path, 'w', encoding='utf-8', newline='') as f:
f.write(f"{'Task':<35}, {'Accuracy':<10}, {'Centered':<10}\n")
for label in core_results["results"]:
acc = core_results["results"][label]
centered = core_results["centered_results"][label]
f.write(f"{label:<35}, {acc:<10.6f}, {centered:<10.6f}\n")
f.write(f"{'CORE':<35}, {'':<10}, {core_results['core_metric']:<10.6f}\n")
print0(f"\nResults written to: {output_csv_path}")
print0(f"CORE metric: {core_results['core_metric']:.4f}")
# --- Log to report ---
from nanochat.report import get_report
get_report().log(section="Base model evaluation", data=[
{
"Model": model_name,
"CORE metric": core_metric,
},
centered_results, # the full table
])
report_data = [{"model": model_name}]
if core_results:
report_data[0]["CORE metric"] = core_results["core_metric"]
report_data.append(core_results["centered_results"])
if bpb_results:
report_data[0]["train bpb"] = bpb_results.get("train")
report_data[0]["val bpb"] = bpb_results.get("val")
if samples:
report_data.append({f"sample {i}": s for i, s in enumerate(samples)})
if unconditioned_samples:
report_data.append({f"unconditioned {i}": s for i, s in enumerate(unconditioned_samples)})
get_report().log(section="Base model evaluation", data=report_data)
compute_cleanup()
if __name__ == "__main__":
main()

View File

@ -1,140 +0,0 @@
"""
Loads a checkpoint, and:
- Evaluates the loss on a larger chunk of train/val splits
- Samples from the model
Example run as:
torchrun --standalone --nproc_per_node=8 -m scripts.base_loss
To evaluate a HuggingFace model:
python -m scripts.base_loss --hf-path openai-community/gpt2
"""
import argparse
from contextlib import nullcontext
import torch
from nanochat.checkpoint_manager import load_model
from nanochat.common import compute_init, print0, compute_cleanup, autodetect_device_type
from nanochat.dataloader import tokenizing_distributed_data_loader_bos_bestfit
from nanochat.tokenizer import get_token_bytes, HuggingFaceTokenizer
from nanochat.loss_eval import evaluate_bpb
from nanochat.engine import Engine
# -----------------------------------------------------------------------------
# HuggingFace loading utilities, making the APIs match up to those of nanochat
class ModelWrapper:
"""Lightweight wrapper for a HuggingFace model"""
def __init__(self, model, max_seq_len=None):
self.model = model
self.max_seq_len = max_seq_len
def __call__(self, input_ids, targets=None, loss_reduction='mean'):
logits = self.model(input_ids).logits
if targets is None:
return logits
else:
loss = torch.nn.functional.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1, reduction=loss_reduction)
return loss
def get_device(self):
return next(self.model.parameters()).device
def load_hf_model(hf_path: str, device):
print0(f"Loading model from: {hf_path}")
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained(hf_path)
model.to(device)
model.eval()
max_seq_len = 1024 if "openai-community/gpt2" in hf_path else None
model = ModelWrapper(model, max_seq_len=max_seq_len)
tokenizer = HuggingFaceTokenizer.from_pretrained(hf_path)
return model, tokenizer
def get_hf_token_bytes(tokenizer, device="cpu"):
"""Compute token_bytes tensor for a HuggingFace tokenizer."""
vocab_size = tokenizer.tokenizer.get_vocab_size()
token_bytes = torch.zeros(vocab_size, dtype=torch.int64, device=device)
for token_id in range(vocab_size):
token_str = tokenizer.tokenizer.decode([token_id])
token_bytes[token_id] = len(token_str.encode('utf-8')) # Count UTF-8 bytes
return token_bytes
# CLI arguments
parser = argparse.ArgumentParser(description="Evaluate loss on train/val splits and sample from model")
parser.add_argument("--device-batch-size", type=int, default=32, help="per-device batch size")
parser.add_argument("--split-tokens", type=int, default=40*524288, help="number of tokens to evaluate per split")
parser.add_argument("--model-tag", type=str, default=None, help="model tag for checkpoint directory")
parser.add_argument("--model-step", type=int, default=None, help="model step to load")
parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
parser.add_argument("--hf-path", type=str, default=None, help="HuggingFace model path (e.g. openai-community/gpt2)")
args = parser.parse_args()
# Load the base model and the tokenizer
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
print0(f"Device: {device} | DDP rank: {ddp_rank} | DDP local rank: {ddp_local_rank} | DDP world size: {ddp_world_size}")
if args.hf_path is not None:
# Load HuggingFace model
model, tokenizer = load_hf_model(args.hf_path, device)
sequence_len = model.max_seq_len if model.max_seq_len else 1024
token_bytes = get_hf_token_bytes(tokenizer, device=device)
model_name = args.hf_path
else:
# Load local nanochat model
model, tokenizer, meta = load_model("base", device, phase="eval", model_tag=args.model_tag, step=args.model_step)
sequence_len = meta["model_config"]["sequence_len"]
token_bytes = get_token_bytes(device=device)
model_name = f"base_model (step {meta['step']})"
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
print0(f"Evaluating model: {model_name}")
# Evaluate the loss on each split
tokens_per_step = args.device_batch_size * sequence_len * ddp_world_size
assert args.split_tokens % tokens_per_step == 0, "split_tokens must be divisible by tokens_per_step"
steps = args.split_tokens // tokens_per_step
bpb_results = {}
for split_name in ["train", "val"]:
loader = tokenizing_distributed_data_loader_bos_bestfit(tokenizer, args.device_batch_size, sequence_len, split_name, device=device)
with autocast_ctx:
bpb = evaluate_bpb(model, loader, steps, token_bytes)
print0(f"{split_name} bpb: {bpb:.4f}")
bpb_results[split_name] = bpb
print0(f"Model: {model_name}, {split_name} bpb: {bpb:.6f}")
# Master process also samples from the model (only for nanochat models)
samples = []
if ddp_rank == 0 and args.hf_path is None:
prompts = [
"The capital of France is",
"The chemical symbol of gold is",
"If yesterday was Friday, then tomorrow will be",
"The opposite of hot is",
"The planets of the solar system are:",
"My favorite color is",
"If 5*x + 3 = 13, then x is",
]
engine = Engine(model, tokenizer)
for prompt in prompts:
tokens = tokenizer(prompt, prepend="<|bos|>")
with autocast_ctx:
sample, _ = engine.generate_batch(tokens, num_samples=1, max_tokens=16, temperature=0)
sample_str = tokenizer.decode(sample[0])
print0(sample_str)
samples.append(sample_str)
# Log to report
from nanochat.report import get_report
get_report().log(section="Base model loss", data=[
{
"model": model_name,
"train bpb": bpb_results["train"],
"val bpb": bpb_results["val"],
},
{f"sample {i}": sample for i, sample in enumerate(samples)},
])
# Cleanup
compute_cleanup()

View File

@ -1,11 +1,11 @@
"""
Train model. From root directory of the project, run as:
python -m scripts.base_train.py
python -m scripts.base_train
or distributed as:
torchrun --nproc_per_node=8 -m scripts.base_train.py
torchrun --nproc_per_node=8 -m scripts.base_train
If you are only on CPU/Macbook, you'll want to train a much much smaller LLM. Example:
python -m scripts.base_train --depth=4 --max-seq-len=512 --device-batch-size=1 --eval-tokens=512 --core-metric-every=-1 --total-batch-size=512 --num-iterations=20
@ -13,22 +13,26 @@ python -m scripts.base_train --depth=4 --max-seq-len=512 --device-batch-size=1 -
import os
os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"
import argparse
import gc
import json
import time
from contextlib import nullcontext
import math
import argparse
from dataclasses import asdict
from contextlib import nullcontext, contextmanager
import wandb
import torch
from nanochat.gpt import GPT, GPTConfig
from nanochat.dataloader import tokenizing_distributed_data_loader_bos_bestfit, tokenizing_distributed_data_loader_with_state_bos_bestfit
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir, autodetect_device_type
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir, autodetect_device_type, get_peak_flops
from nanochat.tokenizer import get_tokenizer, get_token_bytes
from nanochat.checkpoint_manager import save_checkpoint, load_checkpoint
from nanochat.loss_eval import evaluate_bpb
from nanochat.engine import Engine
from nanochat.flash_attention import HAS_FA3
from scripts.base_eval import evaluate_model
from scripts.base_eval import evaluate_core
print_banner()
# -----------------------------------------------------------------------------
@ -38,6 +42,9 @@ parser = argparse.ArgumentParser(description="Pretrain base model")
parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)")
# Runtime
parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
# FP8 training
parser.add_argument("--fp8", action="store_true", help="enable FP8 training (requires H100+ GPU and torchao)")
parser.add_argument("--fp8-recipe", type=str, default="tensorwise", choices=["rowwise", "tensorwise"], help="FP8 scaling recipe: tensorwise (faster, recommended) or rowwise (more accurate but slower)")
# Model architecture
parser.add_argument("--depth", type=int, default=20, help="depth of the Transformer model")
parser.add_argument("--aspect-ratio", type=int, default=64, help="model_dim = depth * aspect_ratio")
@ -47,10 +54,10 @@ parser.add_argument("--window-pattern", type=str, default="SSSL", help="sliding
# Training horizon (only one used, in order of precedence)
parser.add_argument("--num-iterations", type=int, default=-1, help="explicit number of optimization steps (-1 = disable)")
parser.add_argument("--target-flops", type=float, default=-1.0, help="calculate num_iterations to reach target_flops (-1 = disable)")
parser.add_argument("--target-param-data-ratio", type=int, default=8, help="calculate num_iterations to maintain data:param ratio (Chinchilla=20, -1 = disable)")
parser.add_argument("--target-param-data-ratio", type=float, default=10.5, help="calculate num_iterations to maintain data:param ratio (Chinchilla=20, -1 = disable)")
# Optimization
parser.add_argument("--device-batch-size", type=int, default=32, help="per-device batch size")
parser.add_argument("--total-batch-size", type=int, default=524288, help="total batch size in tokens")
parser.add_argument("--device-batch-size", type=int, default=32, help="per-device batch size. good number to reduce to 16,8,4,... if you OOM on VRAM.")
parser.add_argument("--total-batch-size", type=int, default=-1, help="total batch size in tokens. decent numbers are e.g. 524288. (-1 = auto-compute optimal)")
parser.add_argument("--embedding-lr", type=float, default=0.3, help="learning rate for embedding parameters (Adam)")
parser.add_argument("--unembedding-lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)")
parser.add_argument("--weight-decay", type=float, default=0.2, help="cautious weight decay for the Muon optimizer (for weights)")
@ -59,12 +66,12 @@ parser.add_argument("--scalar-lr", type=float, default=0.5, help="learning rate
parser.add_argument("--adam-beta1", type=float, default=0.8, help="Adam beta1 for embedding/unembedding")
parser.add_argument("--adam-beta2", type=float, default=0.95, help="Adam beta2 for embedding/unembedding")
parser.add_argument("--warmup-ratio", type=float, default=0.0, help="ratio of iterations for LR warmup")
parser.add_argument("--warmdown-ratio", type=float, default=0.4, help="ratio of iterations for LR warmdown")
parser.add_argument("--warmdown-ratio", type=float, default=0.5, help="ratio of iterations for LR warmdown")
parser.add_argument("--final-lr-frac", type=float, default=0.0, help="final LR as fraction of initial LR")
parser.add_argument("--resume-from-step", type=int, default=-1, help="resume training from this step (-1 = disable)")
# Evaluation
parser.add_argument("--eval-every", type=int, default=250, help="evaluate val bpb every N steps (-1 = disable)")
parser.add_argument("--eval-tokens", type=int, default=20*524288, help="number of tokens to evaluate val loss on")
parser.add_argument("--eval-tokens", type=int, default=40*524288, help="number of tokens to evaluate val loss on")
parser.add_argument("--core-metric-every", type=int, default=2000, help="evaluate CORE metric every N steps (-1 = disable)")
parser.add_argument("--core-metric-max-per-task", type=int, default=500, help="examples per task for CORE metric")
parser.add_argument("--sample-every", type=int, default=2000, help="sample from model every N steps (-1 = disable)")
@ -74,14 +81,20 @@ parser.add_argument("--model-tag", type=str, default=None, help="override model
args = parser.parse_args()
user_config = vars(args).copy() # for logging
# -----------------------------------------------------------------------------
# Compute init and wandb logging
# Compute init
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None
get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0
if device_type == "cuda":
gpu_device_name = torch.cuda.get_device_name(0)
gpu_peak_flops = get_peak_flops(gpu_device_name)
print0(f"GPU: {gpu_device_name} | Peak FLOPS (BF16): {gpu_peak_flops:.2e}")
else:
gpu_peak_flops = float('inf') # MFU not meaningful for CPU/MPS
# wandb logging init
use_dummy_wandb = args.run == "dummy" or not master_process
@ -99,67 +112,39 @@ else:
print0("WARNING: Recommend using --window-pattern L for full context attention without alternating sliding window patterns.")
print0("!" * 80)
# Tokenizer will be useful for evaluation, also we need the vocab size
# -----------------------------------------------------------------------------
# Tokenizer will be useful for evaluation and also we need the vocab size to init the model
tokenizer = get_tokenizer()
token_bytes = get_token_bytes(device=device)
vocab_size = tokenizer.get_vocab_size()
print0(f"Vocab size: {vocab_size:,}")
# Model kwargs are derived from the desired depth of the model
num_layers = args.depth
model_dim = args.depth * args.aspect_ratio
def find_num_heads(model_dim, target_head_dim):
# Find num_heads that divides model_dim evenly, with head_dim closest to target.
ideal = max(1, round(model_dim / target_head_dim))
for offset in range(model_dim):
for candidate in [ideal + offset, ideal - offset]:
if candidate > 0 and model_dim % candidate == 0:
return candidate
return 1
num_heads = find_num_heads(model_dim, args.head_dim)
num_kv_heads = num_heads # default is 1:1 GQA (Group Query Attention) ratio (i.e. GQA is disabled)
print0(f"num_layers: {num_layers}")
print0(f"model_dim: {model_dim}")
print0(f"num_heads: {num_heads}")
print0(f"num_kv_heads: {num_kv_heads}")
# Optimizer / data / training length related hyperparameters
# figure out the needed gradient accumulation to reach the desired total batch size
tokens_per_fwdbwd = args.device_batch_size * args.max_seq_len # tokens per iteration for a single rank
world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size # total tokens per iteration for all ranks
assert args.total_batch_size % world_tokens_per_fwdbwd == 0
grad_accum_steps = args.total_batch_size // world_tokens_per_fwdbwd
print0(f"Tokens / micro-batch / rank: {args.device_batch_size} x {args.max_seq_len} = {tokens_per_fwdbwd:,}")
print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}")
print0(f"Total batch size {args.total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}")
# Batch size scaling for learning rates (hyperparameters were tuned at reference batch size 2^19)
batch_lr_scale = 1.0
reference_batch_size = 2**19
batch_ratio = args.total_batch_size / reference_batch_size
if batch_ratio != 1.0:
# SGD: linear scaling with batch size is standard (not used in nanochat)
# AdamW: sqrt scaling is standard
# Muon: sqrt scaling is an assumption - not fully studied, but it's a second-order-ish optimizer
batch_lr_scale = batch_ratio ** 0.5
print0(f"Scaling LRs by {batch_lr_scale:.4f} for batch size {args.total_batch_size:,} (reference: {reference_batch_size:,})")
# Weight decay is tuned at d12 and its scaling seems to be \propto 1/channels^2 (or equivalently, \propto 1/depth^2 due to constant aspect ratio)
weight_decay_scaled = args.weight_decay * (12 / args.depth)**2
if args.depth != 12:
print0(f"Scaling weight decay from {args.weight_decay:.6f} to {weight_decay_scaled:.6f} for depth {args.depth}")
# -----------------------------------------------------------------------------
# Initialize the Model
# Create a new model with random weights
model_config_kwargs = dict(sequence_len=args.max_seq_len, vocab_size=vocab_size, n_layer=num_layers, n_head=num_heads, n_kv_head=num_kv_heads, n_embd=model_dim, window_pattern=args.window_pattern)
with torch.device("meta"):
# All tensors are created as meta tensors (they have shape/dtype but no data)
model_config = GPTConfig(**model_config_kwargs)
model = GPT(model_config)
model.to_empty(device=device) # All tensors get storage on target device but with uninitialized (garbage) data
model.init_weights() # All tensors get initialized
def build_model_meta(depth):
"""Build a model on meta device for a given depth (shapes/dtypes only, no data)."""
# Model dim is nudged up to nearest multiple of head_dim for clean division
# (FA3 requires head_dim divisible by 8, and this guarantees head_dim == args.head_dim exactly)
base_dim = depth * args.aspect_ratio
model_dim = ((base_dim + args.head_dim - 1) // args.head_dim) * args.head_dim
num_heads = model_dim // args.head_dim
config = GPTConfig(
sequence_len=args.max_seq_len, vocab_size=vocab_size,
n_layer=depth, n_head=num_heads, n_kv_head=num_heads, n_embd=model_dim,
window_pattern=args.window_pattern,
)
with torch.device("meta"):
model_meta = GPT(config)
return model_meta
# Build the model, move to device, init the weights
model = build_model_meta(args.depth) # 1) Build on meta device (only shapes/dtypes, no data)
model_config = model.config
model_config_kwargs = asdict(model_config)
print0(f"Model config:\n{json.dumps(model_config_kwargs, indent=2)}")
model.to_empty(device=device) # 2) All tensors get storage on target device but with uninitialized (garbage) data
model.init_weights() # 3) All tensors get initialized
# If we are resuming, overwrite the model parameters with those of the checkpoint
base_dir = get_base_dir()
@ -172,52 +157,161 @@ if resuming:
model.load_state_dict(model_data, strict=True, assign=True)
del model_data # free up this memory after the copy
# -----------------------------------------------------------------------------
# FP8 training initialization and management (this has to be done before torch.compile)
# Convert Linear layers to Float8Linear if --fp8 is set
if args.fp8:
if device_type != "cuda":
print0("Warning: FP8 training requires CUDA, ignoring --fp8 flag")
else:
# our custom fp8 is simpler than torchao, written for exact API compatibility
from nanochat.fp8 import Float8LinearConfig, convert_to_float8_training
# from torchao.float8 import Float8LinearConfig, convert_to_float8_training
import torch.nn as nn
# Filter: only convert layers with dimensions divisible by 16 (FP8 hardware requirement)
def fp8_module_filter(mod: nn.Module, fqn: str) -> bool:
if not isinstance(mod, nn.Linear):
return False
# FP8 requires both in_features and out_features divisible by 16
if mod.in_features % 16 != 0 or mod.out_features % 16 != 0:
return False
return True
fp8_config = Float8LinearConfig.from_recipe_name(args.fp8_recipe)
convert_to_float8_training(model, config=fp8_config, module_filter_fn=fp8_module_filter)
num_fp8_layers = sum(1 for m in model.modules() if 'Float8' in type(m).__name__)
num_skipped = sum(1 for m in model.modules() if isinstance(m, nn.Linear)) - num_fp8_layers
print0(f"✓ FP8 training enabled ({args.fp8_recipe} scaling) - converted {num_fp8_layers} layers, skipped {num_skipped} (dims not divisible by 16)")
# Context manager to temporarily disable FP8 so that model evaluation remains in BF16
@contextmanager
def disable_fp8(model):
"""Temporarily swap Float8Linear modules with nn.Linear for BF16 evaluation.
CastConfig is a frozen dataclass, so we can't mutate scaling_type. Instead,
we swap out Float8Linear modules entirely and restore them after.
"""
import torch.nn as nn
# Find all Float8Linear modules and their locations
fp8_locations = [] # list of (parent_module, attr_name, fp8_module)
for name, module in model.named_modules():
if 'Float8' in type(module).__name__:
if '.' in name:
parent_name, attr_name = name.rsplit('.', 1)
parent = model.get_submodule(parent_name)
else:
parent = model
attr_name = name
fp8_locations.append((parent, attr_name, module))
if not fp8_locations:
yield # No FP8 modules, nothing to do
return
# Swap Float8Linear -> nn.Linear (shares the same weight tensor, no copy)
for parent, attr_name, fp8_module in fp8_locations:
linear = nn.Linear(
fp8_module.in_features,
fp8_module.out_features,
bias=fp8_module.bias is not None,
device=fp8_module.weight.device,
dtype=fp8_module.weight.dtype,
)
linear.weight = fp8_module.weight # share, don't copy
if fp8_module.bias is not None:
linear.bias = fp8_module.bias
setattr(parent, attr_name, linear)
try:
yield
finally:
# Restore Float8Linear modules
for parent, attr_name, fp8_module in fp8_locations:
setattr(parent, attr_name, fp8_module)
# -----------------------------------------------------------------------------
# Compile the model
orig_model = model # original, uncompiled model, for saving raw model state_dict and for inference/evaluation (because the shapes may change shape)
model = torch.compile(model, dynamic=False) # the inputs to model will never change shape so dynamic=False is safe
num_params = sum(p.numel() for p in model.parameters())
num_scaling_params = orig_model.num_scaling_params()
print0(f"Number of parameters: {num_params:,} (scaling: {num_scaling_params:,})")
# -----------------------------------------------------------------------------
# Scaling laws and muP extrapolations to determine the optimal training horizon, batch size, learning rates, weight decay.
# Get the parameter counts of our model
param_counts = model.num_scaling_params()
print0(f"Parameter counts:")
for key, value in param_counts.items():
print0(f"{key:24s}: {value:,}")
num_params = param_counts['total']
num_flops_per_token = model.estimate_flops()
print0(f"Estimated FLOPs per token: {num_flops_per_token:e}")
# Calculate number of iterations. Either it is given, or from target flops, or from target data:param ratio (in that order)
assert args.num_iterations > 0 or args.target_param_data_ratio > 0 or args.target_flops > 0
if args.num_iterations > 0:
num_iterations = args.num_iterations
print0(f"Using user-provided number of iterations: {num_iterations:,}")
elif args.target_flops > 0:
# calculate the number of iterations from the target flops
num_iterations = round(args.target_flops / (num_flops_per_token * args.total_batch_size))
print0(f"Calculated number of iterations from target FLOPs: {num_iterations:,}")
elif args.target_param_data_ratio > 0:
# calculate the number of iterations from the target param data ratio (use scaling params per Kaplan et al.)
target_tokens = args.target_param_data_ratio * num_scaling_params
num_iterations = target_tokens // args.total_batch_size
print0(f"Calculated number of iterations from target data:param ratio: {num_iterations:,}")
else:
raise ValueError("No training horizon specified")
total_tokens = args.total_batch_size * num_iterations
print0(f"Total number of training tokens: {total_tokens:,}")
print0(f"Tokens : Params ratio: {args.total_batch_size * num_iterations / num_scaling_params:.2f}") # Chinchilla is ~20
print0(f"Total training FLOPs estimate: {num_flops_per_token * total_tokens:e}")
# 1) Use scaling laws to determine the optimal training horizon in tokens
# The compute-optimal models satisfy the Tokens:Params ratio of --target-param-data-ratio (derived experimentally via scaling laws analysis).
# We've already initialized the model so we have Params. Optimal Tokens is now simply target-param-data-ratio * Params
def get_scaling_params(m):
# As for which params to use exactly, transformer matrices + lm_head gives cleanest scaling laws (see dev/LOG.md Jan 27, 2026)
params_counts = m.num_scaling_params()
scaling_params = params_counts['transformer_matrices'] + params_counts['lm_head']
return scaling_params
num_scaling_params = get_scaling_params(model)
target_tokens = int(args.target_param_data_ratio * num_scaling_params) # optimal tokens for the model we are about to train
# Our reference model is d12, this is where a lot of hyperparameters are tuned and then transfered to higher depths (muP style)
d12_ref = build_model_meta(12) # creates the model on meta device
D_REF = args.target_param_data_ratio * get_scaling_params(d12_ref) # compute-optimal d12 training horizon in tokens (measured empirically)
B_REF = 2**19 # optimal batch size at d12 ~= 524,288 tokens (measured empirically)
# 2) Now that we have the token horizon, we can calculate the optimal batch size
# We follow the Power Lines paper (Bopt ∝ D^0.383), ref: https://arxiv.org/abs/2505.13738
# The optimal batch size grows as approximately D^0.383, so e.g. if D doubles from d12 to d24, B should grow by 2^0.383 ≈ 1.3x.
total_batch_size = args.total_batch_size # user-provided override is possible
if total_batch_size == -1:
batch_size_ratio = target_tokens / D_REF
predicted_batch_size = B_REF * batch_size_ratio ** 0.383
total_batch_size = 2 ** round(math.log2(predicted_batch_size)) # clamp to nearest power of 2 for efficiency
print0(f"Auto-computed optimal batch size: {total_batch_size:,} tokens")
# 3) Knowing the batch size, we can now calculate a learning rate correction (bigger batch size allows higher learning rates)
batch_lr_scale = 1.0
batch_ratio = total_batch_size / B_REF # B/B_ref
if batch_ratio != 1.0:
# SGD: linear scaling with batch size is standard (not used in nanochat)
# AdamW: sqrt scaling is standard: η ∝ √(B/B_ref)
# Muon: we will use the same scaling for Muon as for AdamW: η ∝ √(B/B_ref) (not studied carefully, assumption!)
batch_lr_scale = batch_ratio ** 0.5 # η ∝ √(B/B_ref)
print0(f"Scaling LRs by {batch_lr_scale:.4f} for batch size {total_batch_size:,} (reference: {B_REF:,})")
# 4) Knowing the batch size and the token horizon, we can now calculate the appropriate weight decay scaling
# We adopt the T_epoch framework from https://arxiv.org/abs/2405.13698
# Central idea of the paper is that T_epoch = B/(η·λ·D) should remain constant.
# Above, we used learning rate scaling η ∝ √(B/B_ref). So it's a matter of ~10 lines of math to derive that to keep T_epoch constant, we need:
# λ = λ_ref · √(B/B_ref) · (D_ref/D)
# Note that these papers study AdamW, *not* Muon. We are blindly following AdamW theory for scaling hoping it ~works for Muon too.
weight_decay_scaled = args.weight_decay * math.sqrt(total_batch_size / B_REF) * (D_REF / target_tokens)
if weight_decay_scaled != args.weight_decay:
print0(f"Scaling weight decay from {args.weight_decay:.6f} to {weight_decay_scaled:.6f} for depth {args.depth}")
# -----------------------------------------------------------------------------
# Initialize the Optimizer (Muon for Linear layers, AdamW for embedding and lm_head)
adam_betas = (args.adam_beta1, args.adam_beta2)
optimizers = model.setup_optimizers(
# Initialize the Optimizer (combined MuonAdamW: Muon for matrix params, AdamW for rest)
optimizer = model.setup_optimizer(
# AdamW hyperparameters
unembedding_lr=args.unembedding_lr * batch_lr_scale,
embedding_lr=args.embedding_lr * batch_lr_scale,
scalar_lr=args.scalar_lr * batch_lr_scale,
adam_betas=(args.adam_beta1, args.adam_beta2),
# Muon hyperparameters
matrix_lr=args.matrix_lr * batch_lr_scale,
weight_decay=weight_decay_scaled,
adam_betas=adam_betas,
scalar_lr=args.scalar_lr * batch_lr_scale,
)
adamw_optimizer, muon_optimizer = optimizers
if resuming:
for opt, dat in zip(optimizers, optimizer_data):
opt.load_state_dict(dat)
del optimizer_data # free up the memory
optimizer.load_state_dict(optimizer_data)
del optimizer_data
# -----------------------------------------------------------------------------
# Initialize the DataLoaders for train/val
@ -227,9 +321,30 @@ build_val_loader = lambda: tokenizing_distributed_data_loader_bos_bestfit(tokeni
x, y, dataloader_state_dict = next(train_loader) # kick off load of the very first batch of data
# -----------------------------------------------------------------------------
# Set up hyperparameter schedulers
# Calculate the number of iterations we will train for and set up the various schedulers
# Learning rate scheduler
# num_iterations: either it is given, or from target flops, or from target data:param ratio (in that order)
assert args.num_iterations > 0 or args.target_param_data_ratio > 0 or args.target_flops > 0
if args.num_iterations > 0:
# Override num_iterations to a specific value if given
num_iterations = args.num_iterations
print0(f"Using user-provided number of iterations: {num_iterations:,}")
elif args.target_flops > 0:
# Calculate the number of iterations from the target flops (used in scaling laws analysis, e.g. runs/scaling_laws.sh)
num_iterations = round(args.target_flops / (num_flops_per_token * total_batch_size))
print0(f"Calculated number of iterations from target FLOPs: {num_iterations:,}")
elif args.target_param_data_ratio > 0:
# Calculate the number of iterations from the target param data ratio (the most common use case)
num_iterations = target_tokens // total_batch_size
print0(f"Calculated number of iterations from target data:param ratio: {num_iterations:,}")
else:
raise ValueError("No training horizon specified")
total_tokens = total_batch_size * num_iterations # the actual number of tokens we will train for
print0(f"Total number of training tokens: {total_tokens:,}")
print0(f"Tokens : Scaling params ratio: {total_batch_size * num_iterations / num_scaling_params:.2f}") # e.g. Chinchilla was ~20
print0(f"Total training FLOPs estimate: {num_flops_per_token * total_tokens:e}")
# Learning rate schedule (linear warmup, constant, linear warmdown)
def get_lr_multiplier(it):
warmup_iters = round(args.warmup_ratio * num_iterations)
warmdown_iters = round(args.warmdown_ratio * num_iterations)
@ -241,19 +356,20 @@ def get_lr_multiplier(it):
progress = (num_iterations - it) / warmdown_iters
return progress * 1.0 + (1 - progress) * args.final_lr_frac
# Momentum scheduler for Muon optimizer
# Momentum scheduler for Muon optimizer (warms up to 0.95 over the first 300 steps)
def get_muon_momentum(it):
frac = min(it / 300, 1)
momentum = (1 - frac) * 0.85 + frac * 0.95
return momentum
# Weight decay scheduler for Muon optimizer (linear to zero over the course of training)
# Weight decay scheduler for Muon optimizer (linearly decays to zero over the course of training)
def get_weight_decay(it):
return weight_decay_scaled * (1 - it / num_iterations)
# -----------------------------------------------------------------------------
# Loop state (variables updated by the training loop)
# Training loop
# Loop state (variables updated by the training loop)
if not resuming:
step = 0
val_bpb = None # will be set if eval_every > 0
@ -268,18 +384,26 @@ else:
smooth_train_loss = loop_state["smooth_train_loss"]
total_training_time = loop_state["total_training_time"]
# -----------------------------------------------------------------------------
# Training loop
# Figure out the needed gradient accumulation micro-steps to reach the desired total batch size per step
tokens_per_fwdbwd = args.device_batch_size * args.max_seq_len # tokens per iteration for a single rank
world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size # total tokens per iteration for all ranks
assert total_batch_size % world_tokens_per_fwdbwd == 0
grad_accum_steps = total_batch_size // world_tokens_per_fwdbwd
print0(f"Tokens / micro-batch / rank: {args.device_batch_size} x {args.max_seq_len} = {tokens_per_fwdbwd:,}")
print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}")
print0(f"Total batch size {total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}")
# Go!
while True:
last_step = step == num_iterations # loop runs num_iterations+1 times so that we can eval/save at the end
flops_so_far = num_flops_per_token * args.total_batch_size * step
flops_so_far = num_flops_per_token * total_batch_size * step
# once in a while: evaluate the val bpb (all ranks participate)
if args.eval_every > 0 and (last_step or step % args.eval_every == 0):
model.eval()
val_loader = build_val_loader()
eval_steps = args.eval_tokens // (args.device_batch_size * args.max_seq_len * ddp_world_size)
with autocast_ctx:
with disable_fp8(model), autocast_ctx:
val_bpb = evaluate_bpb(model, val_loader, eval_steps, token_bytes)
print0(f"Step {step:05d} | Validation bpb: {val_bpb:.6f}")
if val_bpb < min_val_bpb:
@ -294,11 +418,12 @@ while True:
# once in a while: estimate the CORE metric (all ranks participate)
# use the original uncompiled model because the inputs keep changing shape
# disable FP8 for evaluation to use BF16 for more consistent/accurate results
results = {}
if args.core_metric_every > 0 and (last_step or (step > 0 and step % args.core_metric_every == 0)):
model.eval()
with autocast_ctx:
results = evaluate_model(orig_model, tokenizer, device, max_per_task=args.core_metric_max_per_task)
with disable_fp8(orig_model), autocast_ctx:
results = evaluate_core(orig_model, tokenizer, device, max_per_task=args.core_metric_max_per_task)
print0(f"Step {step:05d} | CORE metric: {results['core_metric']:.4f}")
wandb_run.log({
"step": step,
@ -324,7 +449,7 @@ while True:
engine = Engine(orig_model, tokenizer) # use orig_model to avoid recompilation
for prompt in prompts:
tokens = tokenizer(prompt, prepend="<|bos|>")
with autocast_ctx:
with disable_fp8(orig_model), autocast_ctx:
sample, _ = engine.generate_batch(tokens, num_samples=1, max_tokens=16, temperature=0)
print0(tokenizer.decode(sample[0]))
model.train()
@ -335,7 +460,7 @@ while True:
checkpoint_dir,
step,
orig_model.state_dict(), # model parameters
[opt.state_dict() for opt in optimizers], # optimizer states
optimizer.state_dict(), # optimizer state
{ # metadata saved as json
"step": step,
"val_bpb": val_bpb, # loss at last step
@ -369,18 +494,16 @@ while True:
loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here
loss.backward()
x, y, dataloader_state_dict = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward
# step the optimizers
# step the optimizer
lrm = get_lr_multiplier(step)
for opt in optimizers:
for group in opt.param_groups:
group["lr"] = group["initial_lr"] * lrm
muon_momentum = get_muon_momentum(step)
muon_weight_decay = get_weight_decay(step)
for group in muon_optimizer.param_groups:
group["momentum"] = muon_momentum
group["weight_decay"] = muon_weight_decay
for opt in optimizers:
opt.step()
for group in optimizer.param_groups:
group["lr"] = group["initial_lr"] * lrm
if group['kind'] == 'muon':
group["momentum"] = muon_momentum
group["weight_decay"] = muon_weight_decay
optimizer.step()
model.zero_grad(set_to_none=True)
train_loss_f = train_loss.item() # .item() is a CPU-GPU sync point
synchronize()
@ -393,10 +516,9 @@ while True:
smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss_f # EMA the training loss
debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**(step + 1)) # debias the EMA
pct_done = 100 * step / num_iterations
tok_per_sec = int(args.total_batch_size / dt)
flops_per_sec = num_flops_per_token * args.total_batch_size / dt
promised_flops_per_sec_h100 = 989e12 * ddp_world_size # bfloat16 H100 SXM and without 2:4 sparsity
mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in %
tok_per_sec = int(total_batch_size / dt)
flops_per_sec = num_flops_per_token * total_batch_size / dt
mfu = 100 * flops_per_sec / (gpu_peak_flops * ddp_world_size)
if step > 10:
total_training_time += dt # only count the time after the first 10 steps
# Calculate ETA based on average time per step (excluding first 10 steps)
@ -409,7 +531,7 @@ while True:
else:
eta_str = ""
epoch = dataloader_state_dict["epoch"]
print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | epoch: {epoch} | total time: {total_training_time/60:.2f}m{eta_str}")
print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | bf16_mfu: {mfu:.2f} | epoch: {epoch} | total time: {total_training_time/60:.2f}m{eta_str}")
if step % 100 == 0:
log_data = {
"step": step,
@ -425,8 +547,19 @@ while True:
wandb_run.log(log_data)
# state update
first_step_of_run = (step == 0) or (resuming and step == args.resume_from_step)
step += 1
# The garbage collector is sadly a little bit overactive and for some poorly understood reason,
# it spends ~500ms scanning for cycles quite frequently, just to end up cleaning up very few tiny objects each time.
# So we manually manage and help it out here
if first_step_of_run:
gc.collect() # manually collect a lot of garbage from setup
gc.freeze() # immediately freeze all currently surviving objects and exclude them from GC
gc.disable() # nuclear intervention here: disable GC entirely except:
elif step % 5000 == 0: # every 5000 steps...
gc.collect() # manually collect, just to be safe for very, very long runs
# print a few more stats
print0(f"Peak memory usage: {get_max_memory() / 1024 / 1024:.2f}MiB")
print0(f"Total training time: {total_training_time/60:.2f}m")
@ -442,7 +575,7 @@ get_report().log(section="Base model training", data=[
"Number of FLOPs per token": f"{num_flops_per_token:e}",
"Calculated number of iterations": num_iterations,
"Number of training tokens": total_tokens,
"Tokens : Params ratio": args.total_batch_size * num_iterations / num_params,
"Tokens : Scaling params ratio": total_batch_size * num_iterations / num_scaling_params,
"DDP world size": ddp_world_size,
"warmup_ratio": args.warmup_ratio,
"warmdown_ratio": args.warmdown_ratio,

View File

@ -2,7 +2,7 @@
New and upgraded chat mode because a lot of the code has changed since the last one.
Intended to be run single GPU only atm:
python -m scripts.chat_cli -i mid
python -m scripts.chat_cli
"""
import argparse
import torch
@ -12,7 +12,7 @@ from nanochat.engine import Engine
from nanochat.checkpoint_manager import load_model
parser = argparse.ArgumentParser(description='Chat with the model')
parser.add_argument('-i', '--source', type=str, default="sft", help="Source of the model: sft|mid|rl")
parser.add_argument('-i', '--source', type=str, default="sft", help="Source of the model: sft|rl")
parser.add_argument('-g', '--model-tag', type=str, default=None, help='Model tag to load')
parser.add_argument('-s', '--step', type=int, default=None, help='Step to load')
parser.add_argument('-p', '--prompt', type=str, default='', help='Prompt the model, get a single response back')

View File

@ -4,8 +4,8 @@ All the generic code lives here, and all the evaluation-specific
code lives in nanochat directory and is imported from here.
Example runs:
python -m scripts.chat_eval -i mid -a ARC-Easy
torchrun --nproc_per_node=8 -m scripts.chat_eval -- -i mid -a ARC-Easy
python -m scripts.chat_eval -a ARC-Easy
torchrun --nproc_per_node=8 -m scripts.chat_eval -- -a ARC-Easy
"""
import argparse
@ -183,7 +183,7 @@ if __name__ == "__main__":
# Parse command-line arguments
parser = argparse.ArgumentParser()
parser.add_argument('-i', '--source', type=str, required=True, help="Source of the model: sft|mid|rl")
parser.add_argument('-i', '--source', type=str, required=True, help="Source of the model: sft|rl")
parser.add_argument('-a', '--task-name', type=str, default=None, help="Task name. Default = all tasks. Use | to split multiple tasks.")
parser.add_argument('-d', '--dtype', type=str, default='bfloat16', choices=['float32', 'bfloat16'])
parser.add_argument('-t', '--temperature', type=float, default=0.0)

View File

@ -38,7 +38,6 @@ parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('d
parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
parser.add_argument("--dtype", type=str, default="bfloat16", help="float32|bfloat16")
# Model loading
parser.add_argument("--source", type=str, default="sft", help="mid|sft - which checkpoint to load from")
parser.add_argument("--model-tag", type=str, default=None, help="model tag to load from")
parser.add_argument("--model-step", type=int, default=None, help="model step to load from")
# Training horizon
@ -77,7 +76,7 @@ use_dummy_wandb = args.run == "dummy" or not master_process
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-rl", name=args.run, config=user_config)
# Init model and tokenizer
model, tokenizer, meta = load_model(args.source, device, phase="eval", model_tag=args.model_tag, step=args.model_step)
model, tokenizer, meta = load_model("sft", device, phase="eval", model_tag=args.model_tag, step=args.model_step)
engine = Engine(model, tokenizer) # for sampling rollouts
# -----------------------------------------------------------------------------
@ -201,7 +200,7 @@ def run_gsm8k_eval(task, tokenizer, engine,
# Training loop
# Init the optimizer
optimizers = model.setup_optimizers(
optimizer = model.setup_optimizer(
unembedding_lr=args.unembedding_lr,
embedding_lr=args.embedding_lr,
matrix_lr=args.matrix_lr,
@ -209,10 +208,9 @@ optimizers = model.setup_optimizers(
)
# Set the initial learning rate as a fraction of the base learning rate
for opt in optimizers:
for group in opt.param_groups:
group["lr"] = group["lr"] * args.init_lr_frac
group["initial_lr"] = group["lr"] # save the initial learning so we can decay easily later
for group in optimizer.param_groups:
group["lr"] = group["lr"] * args.init_lr_frac
group["initial_lr"] = group["lr"]
# Learning rate scheduler: simple rampdown to zero over num_steps
def get_lr_multiplier(it):
@ -305,11 +303,9 @@ for step in range(num_steps):
# Update the model parameters
lrm = get_lr_multiplier(step)
for opt in optimizers: # first set the learning rate
for group in opt.param_groups:
group["lr"] = group["initial_lr"] * lrm
for opt in optimizers: # then step the optimizers
opt.step()
for group in optimizer.param_groups:
group["lr"] = group["initial_lr"] * lrm
optimizer.step()
model.zero_grad(set_to_none=True)
wandb_run.log({
"step": step,

View File

@ -1,65 +1,63 @@
"""
Finetune a base model to be a chat model.
Run on one GPU e.g. for debugging:
Supervised fine-tuning (SFT) the model.
Run as:
python -m scripts.chat_sft
Or torchrun for training:
torchrun --standalone --nproc_per_node=8 -m scripts.chat_sft
torchrun --standalone --nproc_per_node=8 -m scripts.chat_sft -- --device-batch-size=16
"""
import argparse
import os
os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"
import time
import wandb
import torch
import torch.distributed as dist
from contextlib import nullcontext
from nanochat.common import compute_init, compute_cleanup, get_base_dir, print0, DummyWandb, autodetect_device_type
from nanochat.checkpoint_manager import load_model
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, get_base_dir, autodetect_device_type
from nanochat.tokenizer import get_token_bytes
from nanochat.checkpoint_manager import save_checkpoint
from nanochat.engine import Engine
from scripts.chat_eval import run_chat_eval
from nanochat.loss_eval import evaluate_bpb
from nanochat.checkpoint_manager import load_model
import torch.distributed as dist
from tasks.common import TaskMixture
from tasks.arc import ARC
from tasks.gsm8k import GSM8K
from tasks.mmlu import MMLU
from tasks.smoltalk import SmolTalk
from tasks.customjson import CustomJSON
from tasks.spellingbee import SimpleSpelling, SpellingBee
# -----------------------------------------------------------------------------
# CLI arguments
parser = argparse.ArgumentParser(description="Supervised finetuning for chat")
parser = argparse.ArgumentParser(description="Supervised fine-tuning (SFT) the model")
# Logging
parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)")
# Runtime
parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
parser.add_argument("--dtype", type=str, default="bfloat16", help="float32|bfloat16")
# Model loading
parser.add_argument("--source", type=str, default="mid", help="base|mid - which checkpoint to load from")
parser.add_argument("--model-tag", type=str, default=None, help="model tag to load from")
parser.add_argument("--model-step", type=int, default=None, help="model step to load from")
# Training horizon
parser.add_argument("--num-epochs", type=int, default=1, help="number of epochs")
parser.add_argument("--num-iterations", type=int, default=-1, help="override number of iterations (-1 = use num_epochs)")
parser.add_argument("--num-iterations", type=int, default=-1, help="number of optimization steps (-1 = full epoch)")
# Batch sizes
parser.add_argument("--device-batch-size", type=int, default=4, help="per-device batch size")
parser.add_argument("--target-examples-per-step", type=int, default=32, help="target examples per optimization step")
parser.add_argument("--max-seq-len", type=int, default=2048, help="max context length")
parser.add_argument("--device-batch-size", type=int, default=32, help="per-device batch size")
parser.add_argument("--total-batch-size", type=int, default=524288, help="total batch size in tokens")
# Optimization
parser.add_argument("--embedding-lr", type=float, default=0.2, help="learning rate for embedding parameters (Adam)")
parser.add_argument("--embedding-lr", type=float, default=0.3, help="learning rate for embedding parameters (Adam)")
parser.add_argument("--unembedding-lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)")
parser.add_argument("--matrix-lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)")
parser.add_argument("--weight-decay", type=float, default=0.0, help="weight decay for embedding/unembedding parameters (Adam)")
parser.add_argument("--init-lr-frac", type=float, default=0.02, help="initial LR as fraction of base LR")
parser.add_argument("--init-lr-frac", type=float, default=1.0, help="initial LR as fraction of base LR")
# Evaluation
parser.add_argument("--eval-every", type=int, default=100, help="evaluate val loss every N steps")
parser.add_argument("--eval-steps", type=int, default=100, help="number of batches for val loss evaluation")
parser.add_argument("--eval-metrics-every", type=int, default=200, help="evaluate accuracy metrics every N steps")
parser.add_argument("--eval-metrics-max-problems", type=int, default=1024, help="max problems per metric evaluation")
parser.add_argument("--eval-every", type=int, default=150, help="evaluate val bpb every N steps (-1 = disable)")
parser.add_argument("--eval-tokens", type=int, default=20*524288, help="number of tokens to evaluate val loss on")
# Output
parser.add_argument("--dry-run", action="store_true", help="log to wandb but skip checkpoints/report")
args = parser.parse_args()
user_config = vars(args).copy()
# -----------------------------------------------------------------------------
@ -70,220 +68,321 @@ ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type
master_process = ddp_rank == 0
ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None
get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0
# wandb logging init
use_dummy_wandb = args.run == "dummy" or not master_process
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-sft", name=args.run, config=user_config, save_code=True)
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-sft", name=args.run, config=user_config)
# Load the model and tokenizer
model, tokenizer, meta = load_model(args.source, device, phase="train", model_tag=args.model_tag, step=args.model_step)
orig_model = model # original, uncompiled model
# model = torch.compile(model, dynamic=True) # doesn't work super well because of variable lengths of inputs
engine = Engine(model, tokenizer) # will be used for inline model evaluation only
model, tokenizer, meta = load_model("base", device, phase="train", model_tag=args.model_tag, step=args.model_step)
pretrain_batch_size = meta.get("device_batch_size", None)
if pretrain_batch_size is not None and args.device_batch_size > pretrain_batch_size:
print0(f"FOOTGUN WARNING: base model training used device_batch_size {pretrain_batch_size}, did you pass in a good --device-batch-size to this script?")
orig_model = model
model = torch.compile(model, dynamic=False)
depth = model.config.n_layer
num_flops_per_token = model.estimate_flops()
tokens_per_fwdbwd = args.device_batch_size * args.max_seq_len # tokens per iteration for a single rank
world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size # total tokens per iteration for all ranks
assert args.total_batch_size % world_tokens_per_fwdbwd == 0
grad_accum_steps = args.total_batch_size // world_tokens_per_fwdbwd
print0(f"Tokens / micro-batch / rank: {args.device_batch_size} x {args.max_seq_len} = {tokens_per_fwdbwd:,}")
print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}")
print0(f"Total batch size {args.total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}")
token_bytes = get_token_bytes(device=device)
# -----------------------------------------------------------------------------
# Task data mixture we'll train on
identity_conversations_filepath = os.path.join(get_base_dir(), "identity_conversations.jsonl")
train_ds = TaskMixture([
ARC(subset="ARC-Easy", split="train"), # 2.3K rows
ARC(subset="ARC-Challenge", split="train"), # 1.1K rows
GSM8K(subset="main", split="train"), # 8K rows
SmolTalk(split="train", stop=10_000), # 10K rows of smoltalk
CustomJSON(filepath=identity_conversations_filepath), # 1K rows of synthetic identity conversations
SimpleSpelling(size=300, split="train"), # 300 rows of Simple Spelling (e.g. spell the word 'apple')
SpellingBee(size=300, split="train"), # 300 rows of Spelling Bee (e.g. how many 'r' are in 'strawberry'?)
]) # 2.3K + 1.1K + 8K + 10K + 1K + 0.3K + 0.3K = 23K rows
val_ds = SmolTalk(split="test") # general conversations, 24K rows (though we don't actually use all of it)
# Initialize the Optimizer (combined MuonAdamW: Muon for matrix params, AdamW for rest)
optimizer = model.setup_optimizer(unembedding_lr=args.unembedding_lr, embedding_lr=args.embedding_lr, matrix_lr=args.matrix_lr, weight_decay=args.weight_decay)
# Override the initial learning rate as a fraction of the base learning rate
for group in optimizer.param_groups:
group["lr"] = group["lr"] * args.init_lr_frac
group["initial_lr"] = group["lr"]
# -----------------------------------------------------------------------------
# DataLoader
# SFT data mixture and DataLoader
base_dir = get_base_dir()
identity_conversations_filepath = os.path.join(base_dir, "identity_conversations.jsonl")
train_dataset = TaskMixture([
SmolTalk(split="train"), # 460K rows of general conversations
MMLU(subset="auxiliary_train", split="train"), # 100K rows of multiple choice problems drawn from ARC, MC_TEST, OBQA, RACE
GSM8K(subset="main", split="train"), # 8K rows teaching simple math and (calculator) tool use
GSM8K(subset="main", split="train"), # 2 epochs of GSM8K
CustomJSON(filepath=identity_conversations_filepath), # 1000 rows of synthetic identity conversations
CustomJSON(filepath=identity_conversations_filepath), # let's do 2 epochs of these
SimpleSpelling(size=200000, split="train"), # 200K rows of Simple Spelling (e.g. spell the word 'apple')
SpellingBee(size=80000, split="train"), # 80K rows of Spelling Bee (e.g. how many 'r' are in 'strawberry'?)
]) # total: 460K + 100K + 16K + 200K + 80K = 856K rows
val_dataset = TaskMixture([
SmolTalk(split="test"), # 24K rows in test set
MMLU(subset="all", split="test", stop=5200), # 14K rows in test set, use only 5.2K to match the train ratios
GSM8K(subset="main", split="test", stop=420), # 1.32K rows in test set, use only 420 to match the train ratios
]) # total: 24K + 14K + 1.32K ~= 39K rows
# DataLoader is defined here, it emits inputs, targets : 2D tensors of shape (device_batch_size, max_seq_len)
# A big problem is that we don't know the final num_iterations in advance. So we create
# these two global variables and update them from within the data generator.
last_step = False # we will toggle this to True when we reach the end of the training dataset
approx_progress = 0.0 # will go from 0 to 1 over the course of the epoch
current_epoch = 1 # track epoch for logging
def sft_data_generator_bos_bestfit(split, buffer_size=100):
"""
BOS-aligned dataloader for SFT with bestfit-pad packing.
Each row in the batch starts with BOS (beginning of a conversation).
Conversations are packed using best-fit algorithm. When no conversation fits,
the row is padded (instead of cropping) to ensure no tokens are ever discarded.
Padding positions have targets masked with -1 (ignore_index for cross-entropy).
"""
global last_step, approx_progress, current_epoch
assert split in {"train", "val"}, "split must be 'train' or 'val'"
dataset = train_dataset if split == "train" else val_dataset
dataset_size = len(dataset)
assert dataset_size > 0
row_capacity = args.max_seq_len + 1 # +1 for target at last position
bos_token = tokenizer.get_bos_token_id()
# Conversation buffer: list of token lists
conv_buffer = []
cursor = ddp_rank # Each rank processes different conversations (for fetching)
consumed = ddp_rank # Track actual consumption separately from buffering
epoch = 1
it = 0 # iteration counter
def refill_buffer():
nonlocal cursor, epoch
while len(conv_buffer) < buffer_size:
conversation = dataset[cursor]
ids, _ = tokenizer.render_conversation(conversation)
conv_buffer.append(ids)
cursor += ddp_world_size
if cursor >= dataset_size:
cursor = cursor % dataset_size
epoch += 1
# Note: last_step is now triggered based on consumption, not fetching
def sft_data_generator(dataset, batch_size):
pad_token_id = tokenizer.encode_special("<|assistant_end|>") # use <|assistant_end|> as the pad token is ok, these positions are masked in the loss
# prepares a list of tokenized conversations into a batch and yields
def collate_and_yield(batch):
nrows = len(batch)
ncols = max(len(ids) for ids, mask in batch) - 1 # seq of n creates inputs/targets of n-1
inputs = torch.full((nrows, ncols), pad_token_id, dtype=torch.long)
targets = torch.full((nrows, ncols), -1, dtype=torch.long) # -1 is ignore index
for i, (ids, mask) in enumerate(batch):
n = len(ids)
ids_tensor = torch.tensor(ids, dtype=torch.long)
inputs[i, :n-1] = ids_tensor[:-1]
# recall -1 is the ignore index, so mask out targets where mask is 0
row_targets = ids_tensor[1:]
# mask[1:] omits the mask for the BOS token, which is never a target atm so it's ok
mask_tensor = torch.tensor(mask[1:], dtype=torch.long)
row_targets[mask_tensor == 0] = -1 # mask out targets where mask is 0
targets[i, :n-1] = row_targets
inputs = inputs.to(device) # move to device
targets = targets.to(device)
return inputs, targets
# iterates over the dataset in epochs, tokenizes
batch = []
while True:
for i in range(ddp_rank, len(dataset), ddp_world_size):
doc = dataset[i]
ids, mask = tokenizer.render_conversation(doc)
batch.append((ids, mask))
if len(batch) == batch_size:
yield collate_and_yield(batch)
batch = []
rows = []
row_lengths = [] # Track actual content length (excluding padding) for each row
for _ in range(args.device_batch_size):
row = []
padded = False
while len(row) < row_capacity:
# Ensure buffer has conversations
while len(conv_buffer) < buffer_size:
refill_buffer()
examples_per_step = args.device_batch_size * ddp_world_size
print0(f"Target examples per step: {args.target_examples_per_step}")
print0(f"Device batch size: {args.device_batch_size}")
print0(f"Examples per step is device_batch_size * ddp_world_size: {examples_per_step}")
assert args.target_examples_per_step % examples_per_step == 0, "Target examples per step must be divisible by examples per step"
grad_accum_steps = args.target_examples_per_step // examples_per_step
print0(f"=> Setting grad accum steps: {grad_accum_steps}")
remaining = row_capacity - len(row)
if args.num_iterations == -1:
# derive num_iterations from num_epochs and the size of the dataset
assert args.num_epochs > 0, "num_epochs must be positive if num_iterations is -1"
num_iterations = (len(train_ds) // args.target_examples_per_step) * args.num_epochs
else:
num_iterations = args.num_iterations
train_loader = sft_data_generator(train_ds, batch_size=args.device_batch_size)
build_val_loader = lambda: sft_data_generator(val_ds, batch_size=args.device_batch_size)
# Find largest conversation that fits entirely
best_idx = -1
best_len = 0
for i, conv in enumerate(conv_buffer):
conv_len = len(conv)
if conv_len <= remaining and conv_len > best_len:
best_idx = i
best_len = conv_len
# -----------------------------------------------------------------------------
# Initialize the Optimizer
if best_idx >= 0:
# Found a conversation that fits - use it entirely
conv = conv_buffer.pop(best_idx)
row.extend(conv)
consumed += ddp_world_size # Track actual consumption
else:
# No conversation fits - pad the remainder instead of cropping
# This ensures we never discard any tokens
content_len = len(row)
row.extend([bos_token] * remaining) # Pad with BOS tokens
padded = True
break # Row is now full (with padding)
optimizers = model.setup_optimizers(
unembedding_lr=args.unembedding_lr,
embedding_lr=args.embedding_lr,
matrix_lr=args.matrix_lr,
weight_decay=args.weight_decay,
)
# Set the initial learning rate as a fraction of the base learning rate
for opt in optimizers:
for group in opt.param_groups:
group["lr"] = group["lr"] * args.init_lr_frac
group["initial_lr"] = group["lr"] # save the initial learning so we can decay easily later
# Track content length: full row if no padding, otherwise the length before padding
if padded:
row_lengths.append(content_len)
else:
row_lengths.append(row_capacity)
rows.append(row[:row_capacity])
# Stopping condition to respect num_iterations, if given
it += 1
if 0 < args.num_iterations <= it and split == "train":
last_step = True
# Update progress tracking (based on consumed, not cursor, to account for buffering)
if split == "train":
current_epoch = epoch
if args.num_iterations > 0:
approx_progress = it / args.num_iterations
else:
approx_progress = consumed / dataset_size
# Trigger last_step when we've consumed enough (instead of when cursor wraps)
if consumed >= dataset_size:
last_step = True
# Build tensors
use_cuda = device_type == "cuda"
batch_tensor = torch.tensor(rows, dtype=torch.long, pin_memory=use_cuda)
inputs = batch_tensor[:, :-1].to(device=device, dtype=torch.int32, non_blocking=use_cuda)
targets = batch_tensor[:, 1:].to(device=device, dtype=torch.int64, non_blocking=use_cuda)
# Mask out padding positions in targets (set to -1 = ignore_index)
# For each row, positions >= (content_length - 1) in targets should be masked
for i, content_len in enumerate(row_lengths):
if content_len < row_capacity:
targets[i, content_len-1:] = -1
yield inputs, targets
train_loader = sft_data_generator_bos_bestfit("train")
build_val_loader = lambda: sft_data_generator_bos_bestfit("val")
progress = 0 # will go from 0 to 1 over the course of the epoch
# Learning rate scheduler
def get_lr_multiplier(progress):
# first 80% of training: no decay, then linearly ramp down to 0.
return 1 if progress < 0.8 else 1 - (progress - 0.8) / 0.2
# Momentum scheduler for Muon optimizer
def get_muon_momentum(it):
frac = min(it / 300, 1)
momentum = (1 - frac) * 0.85 + frac * 0.95
return momentum
# -----------------------------------------------------------------------------
# Training loop
# Learning rate scheduler
def get_lr_multiplier(it):
lrm = 1.0 - it / num_iterations
return lrm
# Go!
x, y = next(train_loader) # prefetch the very first batch of data
min_val_bpb = float("inf")
smooth_train_loss = 0 # EMA of training loss
ema_beta = 0.9 # EMA decay factor
total_training_time = 0 # total wall-clock time of training
step = 0
for step in range(num_iterations):
last_step = step == num_iterations - 1
while True:
flops_so_far = num_flops_per_token * args.total_batch_size * step
# evaluate the validation loss
if last_step or step % args.eval_every == 0:
# Synchronize last_step across all ranks to avoid hangs in the distributed setting
if ddp:
last_step_tensor = torch.tensor(last_step, dtype=torch.int32, device=device)
dist.all_reduce(last_step_tensor, op=dist.ReduceOp.MAX)
last_step = bool(last_step_tensor.item())
# once in a while: evaluate the val bpb (all ranks participate)
if last_step or (args.eval_every > 0 and step % args.eval_every == 0):
model.eval()
val_loader = build_val_loader()
losses = []
for _ in range(args.eval_steps):
val_inputs, val_targets = next(val_loader)
with torch.no_grad(), autocast_ctx:
loss = model(val_inputs, val_targets)
losses.append(loss)
val_loss = torch.stack(losses).mean() # average over eval_steps
if ddp:
dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) # average over ranks
val_loss = val_loss.item()
print0(f"Step {step:05d} | Validation loss: {val_loss:.6f}")
eval_steps = args.eval_tokens // (args.device_batch_size * args.max_seq_len * ddp_world_size)
with autocast_ctx:
val_bpb = evaluate_bpb(model, val_loader, eval_steps, token_bytes)
print0(f"Step {step:05d} | Validation bpb: {val_bpb:.4f}")
if val_bpb < min_val_bpb:
min_val_bpb = val_bpb
wandb_run.log({
"step": step,
"val_loss": val_loss,
"total_training_flops": flops_so_far,
"total_training_time": total_training_time,
"val/bpb": val_bpb,
})
model.train()
# evaluate accuracy of the multiple choice tasks (which are quick to run)
if last_step or (step > 0 and step % args.eval_metrics_every == 0):
model.eval()
metrics = {}
with torch.no_grad(), autocast_ctx:
# note that because these are inside no_grad, we can usually afford to at least ~2X the batch size
metrics["mmlu_acc"] = run_chat_eval("MMLU", model, tokenizer, engine, batch_size=args.device_batch_size*2, max_problems=args.eval_metrics_max_problems)
metrics["arc_easy_acc"] = run_chat_eval("ARC-Easy", model, tokenizer, engine, batch_size=args.device_batch_size*2, max_problems=args.eval_metrics_max_problems)
metrics_str = ', '.join(f'{k}: {v:.6f}' for k, v in metrics.items())
print0(f"Step {step:05d} | {metrics_str}")
wandb_run.log({
"step": step,
**metrics,
})
model.train()
# save checkpoint at the end of the run (only on master process)
if master_process and last_step and not args.dry_run:
output_dirname = args.model_tag if args.model_tag else f"d{depth}" # e.g. d12
checkpoint_dir = os.path.join(base_dir, "chatsft_checkpoints", output_dirname)
save_checkpoint(
checkpoint_dir,
step,
orig_model.state_dict(),
optimizer.state_dict(),
{
"step": step,
"val_bpb": val_bpb, # loss at last step
"model_config": {
"sequence_len": args.max_seq_len,
"vocab_size": tokenizer.get_vocab_size(),
"n_layer": depth,
"n_head": model.config.n_head,
"n_kv_head": model.config.n_kv_head,
"n_embd": model.config.n_embd,
"window_pattern": model.config.window_pattern,
},
"user_config": user_config, # inputs to the training script
}
)
if last_step:
break
# -------------------------------------------------------------------------
# single training step
# evaluate the gradient
num_tokens = torch.tensor(0, device=device) # the number of "active" tokens of supervision seen
synchronize()
t0 = time.time()
for micro_step in range(grad_accum_steps):
train_inputs, train_targets = next(train_loader)
with autocast_ctx:
loss = model(train_inputs, train_targets)
loss = model(x, y)
train_loss = loss.detach() # for logging
loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here
loss.backward() # accumulate the gradient
num_tokens += (train_targets >= 0).sum()
if ddp:
dist.all_reduce(num_tokens, op=dist.ReduceOp.SUM) # sum over ranks
# learning rate scheduler
lrm = get_lr_multiplier(step)
for opt in optimizers:
for group in opt.param_groups:
group["lr"] = group["initial_lr"] * lrm
# step the optimizers
for opt in optimizers:
opt.step()
loss.backward()
x, y = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward
progress = max(progress, approx_progress) # only increase progress monotonically
# step the optimizer
lrm = get_lr_multiplier(progress)
muon_momentum = get_muon_momentum(step)
for group in optimizer.param_groups:
group["lr"] = group["initial_lr"] * lrm
if group['kind'] == 'muon':
group["momentum"] = muon_momentum
optimizer.step()
model.zero_grad(set_to_none=True)
synchronize()
t1 = time.time()
dt = t1 - t0
# -------------------------------------------------------------------------
# logging
train_loss_item = train_loss.item()
num_tokens_item = num_tokens.item()
print0(f"Step {step:05d}/{num_iterations:05d} | Training loss: {train_loss_item:.6f}| lrm: {lrm:.6f}| num_tokens: {num_tokens_item:,}")
wandb_run.log({
"step": step,
"lrm": lrm,
"train_loss": train_loss_item,
"num_tokens": num_tokens_item,
})
# State
step += 1
# Save the model at the end of the run
if master_process:
base_dir = get_base_dir()
depth = model.config.n_layer
output_dirname = args.model_tag if args.model_tag else f"d{depth}" # e.g. d12
checkpoint_dir = os.path.join(base_dir, "chatsft_checkpoints", output_dirname)
model_config_kwargs = model.config.__dict__ # slightly naughty, abusing the simplicity of GPTConfig, TODO nicer
save_checkpoint(
checkpoint_dir,
step,
model.state_dict(),
None, # note: we don't bother to save the optimizer state
{
# logging
smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss.item() # EMA the training loss
debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**(step + 1)) # debias the EMA
pct_done = 100 * progress
tok_per_sec = int(args.total_batch_size / dt)
flops_per_sec = num_flops_per_token * args.total_batch_size / dt
promised_flops_per_sec_h100 = 989e12 * ddp_world_size # bfloat16 H100 SXM and without 2:4 sparsity
mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in %
if step > 10:
total_training_time += dt # only count the time after the first 10 steps
print0(f"step {step:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | epoch: {current_epoch} | total time: {total_training_time/60:.2f}m")
if step % 10 == 0:
wandb_run.log({
"step": step,
"val_loss": val_loss,
**metrics,
"model_config": model_config_kwargs,
}
)
print(f"✅ Saved model checkpoint to {checkpoint_dir}")
"total_training_flops": flops_so_far,
"total_training_time": total_training_time,
"train/loss": debiased_smooth_loss,
"train/lrm": lrm,
"train/dt": dt,
"train/tok_per_sec": tok_per_sec,
"train/mfu": mfu,
"train/epoch": current_epoch,
})
# print a few more stats
print0(f"Peak memory usage: {get_max_memory() / 1024 / 1024:.2f}MiB")
print0(f"Total training time: {total_training_time/60:.2f}m")
print0(f"Minimum validation bpb: {min_val_bpb:.4f}")
# Log to report
from nanochat.report import get_report
get_report().log(section="Chat SFT", data=[
user_config, # CLI args
{
"Training rows": len(train_ds),
"Number of iterations": num_iterations,
"Training loss": train_loss_item,
"Validation loss": val_loss,
},
])
if not args.dry_run:
from nanochat.report import get_report
get_report().log(section="SFT", data=[
user_config, # CLI args
{ # stats about the training setup
"Number of iterations": step,
"DDP world size": ddp_world_size,
},
{ # stats about training outcomes
"Minimum validation bpb": min_val_bpb,
}
])
# Cleanup
wandb_run.finish()
# cleanup
wandb_run.finish() # wandb run finish
compute_cleanup()

View File

@ -26,7 +26,7 @@ Abuse Prevention:
- Maximum 8000 characters per message
- Maximum 32000 characters total conversation length
- Temperature clamped to 0.0-2.0
- Top-k clamped to 1-200
- Top-k clamped to 0-200 (0 disables top-k filtering, using full vocabulary)
- Max tokens clamped to 1-4096
"""
@ -55,14 +55,14 @@ MAX_MESSAGE_LENGTH = 8000
MAX_TOTAL_CONVERSATION_LENGTH = 32000
MIN_TEMPERATURE = 0.0
MAX_TEMPERATURE = 2.0
MIN_TOP_K = 1
MIN_TOP_K = 0 # 0 disables top-k filtering, using full vocabulary
MAX_TOP_K = 200
MIN_MAX_TOKENS = 1
MAX_MAX_TOKENS = 4096
parser = argparse.ArgumentParser(description='NanoChat Web Server')
parser.add_argument('-n', '--num-gpus', type=int, default=1, help='Number of GPUs to use (default: 1)')
parser.add_argument('-i', '--source', type=str, default="sft", help="Source of the model: sft|mid|rl")
parser.add_argument('-i', '--source', type=str, default="sft", help="Source of the model: sft|rl")
parser.add_argument('-t', '--temperature', type=float, default=0.8, help='Default temperature for generation')
parser.add_argument('-k', '--top-k', type=int, default=50, help='Default top-k sampling parameter')
parser.add_argument('-m', '--max-tokens', type=int, default=512, help='Default max tokens for generation')

View File

@ -1,373 +0,0 @@
"""
Midtrain the model. Same as pretraining but simpler.
Run as:
python -m scripts.mid_train
Or torchrun for training:
torchrun --standalone --nproc_per_node=8 -m scripts.mid_train -- --device-batch-size=16
"""
import argparse
import os
os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"
import time
import wandb
import torch
from contextlib import nullcontext
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, get_base_dir, autodetect_device_type
from nanochat.tokenizer import get_token_bytes
from nanochat.checkpoint_manager import save_checkpoint
from nanochat.loss_eval import evaluate_bpb
from nanochat.checkpoint_manager import load_model
import torch.distributed as dist
from tasks.common import TaskMixture
from tasks.gsm8k import GSM8K
from tasks.mmlu import MMLU
from tasks.smoltalk import SmolTalk
from tasks.customjson import CustomJSON
from tasks.spellingbee import SimpleSpelling, SpellingBee
# -----------------------------------------------------------------------------
# CLI arguments
parser = argparse.ArgumentParser(description="Midtrain the model")
# Logging
parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)")
# Runtime
parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
parser.add_argument("--dtype", type=str, default="bfloat16", help="float32|bfloat16")
# Model loading
parser.add_argument("--model-tag", type=str, default=None, help="model tag to load from")
parser.add_argument("--model-step", type=int, default=None, help="model step to load from")
# Training horizon
parser.add_argument("--num-iterations", type=int, default=-1, help="number of optimization steps (-1 = full epoch)")
# Batch sizes
parser.add_argument("--max-seq-len", type=int, default=2048, help="max context length")
parser.add_argument("--device-batch-size", type=int, default=32, help="per-device batch size")
parser.add_argument("--total-batch-size", type=int, default=524288, help="total batch size in tokens")
# Optimization
parser.add_argument("--embedding-lr", type=float, default=0.2, help="learning rate for embedding parameters (Adam)")
parser.add_argument("--unembedding-lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)")
parser.add_argument("--matrix-lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)")
parser.add_argument("--weight-decay", type=float, default=0.0, help="weight decay for embedding/unembedding parameters (Adam)")
parser.add_argument("--init-lr-frac", type=float, default=1.0, help="initial LR as fraction of base LR")
# Evaluation
parser.add_argument("--eval-every", type=int, default=150, help="evaluate val bpb every N steps (-1 = disable)")
parser.add_argument("--eval-tokens", type=int, default=20*524288, help="number of tokens to evaluate val loss on")
# Output
parser.add_argument("--dry-run", action="store_true", help="log to wandb but skip checkpoints/report")
args = parser.parse_args()
user_config = vars(args).copy()
# -----------------------------------------------------------------------------
# Compute init
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
master_process = ddp_rank == 0
ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None
get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0
# wandb logging init
use_dummy_wandb = args.run == "dummy" or not master_process
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-mid", name=args.run, config=user_config)
# Load the model and tokenizer
model, tokenizer, meta = load_model("base", device, phase="train", model_tag=args.model_tag, step=args.model_step)
pretrain_batch_size = meta.get("device_batch_size", None)
if pretrain_batch_size is not None and args.device_batch_size > pretrain_batch_size:
print0(f"FOOTGUN WARNING: base model training used device_batch_size {pretrain_batch_size}, did you pass in a good --device-batch-size to this script?")
orig_model = model
model = torch.compile(model, dynamic=False)
depth = model.config.n_layer
num_flops_per_token = model.estimate_flops()
tokens_per_fwdbwd = args.device_batch_size * args.max_seq_len # tokens per iteration for a single rank
world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size # total tokens per iteration for all ranks
assert args.total_batch_size % world_tokens_per_fwdbwd == 0
grad_accum_steps = args.total_batch_size // world_tokens_per_fwdbwd
print0(f"Tokens / micro-batch / rank: {args.device_batch_size} x {args.max_seq_len} = {tokens_per_fwdbwd:,}")
print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}")
print0(f"Total batch size {args.total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}")
token_bytes = get_token_bytes(device=device)
# Initialize the Optimizer (Muon for Linear layers, AdamW for embedding and lm_head)
optimizers = model.setup_optimizers(unembedding_lr=args.unembedding_lr, embedding_lr=args.embedding_lr, matrix_lr=args.matrix_lr, weight_decay=args.weight_decay)
adamw_optimizer, muon_optimizer = optimizers
# Override the initial learning rate as a fraction of the base learning rate
for opt in optimizers:
for group in opt.param_groups:
group["lr"] = group["lr"] * args.init_lr_frac
group["initial_lr"] = group["lr"] # save the initial learning so we can decay easily later
# Midtraining data mixture and DataLoader
base_dir = get_base_dir()
identity_conversations_filepath = os.path.join(base_dir, "identity_conversations.jsonl")
train_dataset = TaskMixture([
SmolTalk(split="train"), # 460K rows of general conversations
MMLU(subset="auxiliary_train", split="train"), # 100K rows of multiple choice problems drawn from ARC, MC_TEST, OBQA, RACE
GSM8K(subset="main", split="train"), # 8K rows teaching simple math and (calculator) tool use
CustomJSON(filepath=identity_conversations_filepath), # 1000 rows of synthetic identity conversations
CustomJSON(filepath=identity_conversations_filepath), # let's do 2 epochs of these
SimpleSpelling(size=200000, split="train"), # 200K rows of Simple Spelling (e.g. spell the word 'apple')
SpellingBee(size=80000, split="train"), # 80K rows of Spelling Bee (e.g. how many 'r' are in 'strawberry'?)
]) # total: 460K + 100K + 8K + 200K + 80K = 848K rows
val_dataset = TaskMixture([
SmolTalk(split="test"), # 24K rows in test set
MMLU(subset="all", split="test", stop=5200), # 14K rows in test set, use only 5.2K to match the train ratios
GSM8K(subset="main", split="test", stop=420), # 1.32K rows in test set, use only 420 to match the train ratios
]) # total: 24K + 14K + 1.32K ~= 39K rows
# DataLoader is defined here, it emits inputs, targets : 2D tensors of shape (device_batch_size, max_seq_len)
# A big problem is that we don't know the final num_iterations in advance. So we create
# these two global variables and update them from within the data generator.
last_step = False # we will toggle this to True when we reach the end of the training dataset
approx_progress = 0.0 # will go from 0 to 1 over the course of the epoch
current_epoch = 1 # track epoch for logging
def mid_data_generator_bos_bestfit(split, buffer_size=100):
"""
BOS-aligned dataloader for midtraining with bestfit-crop packing.
Each row in the batch starts with BOS (beginning of a conversation).
Conversations are packed using best-fit algorithm to minimize cropping.
This matches the BOS-aligned approach used in pretraining.
"""
global last_step, approx_progress, current_epoch
assert split in {"train", "val"}, "split must be 'train' or 'val'"
dataset = train_dataset if split == "train" else val_dataset
dataset_size = len(dataset)
assert dataset_size > 0
row_capacity = args.max_seq_len + 1 # +1 for target at last position
# Conversation buffer: list of token lists
conv_buffer = []
cursor = ddp_rank # Each rank processes different conversations (for fetching)
consumed = ddp_rank # Track actual consumption separately from buffering
epoch = 1
it = 0 # iteration counter
def refill_buffer():
nonlocal cursor, epoch
while len(conv_buffer) < buffer_size:
conversation = dataset[cursor]
ids, _ = tokenizer.render_conversation(conversation)
conv_buffer.append(ids)
cursor += ddp_world_size
if cursor >= dataset_size:
cursor = cursor % dataset_size
epoch += 1
# Note: last_step is now triggered based on consumption, not fetching
while True:
rows = []
for _ in range(args.device_batch_size):
row = []
while len(row) < row_capacity:
# Ensure buffer has conversations
while len(conv_buffer) < buffer_size:
refill_buffer()
remaining = row_capacity - len(row)
# Find largest conversation that fits entirely
best_idx = -1
best_len = 0
for i, conv in enumerate(conv_buffer):
conv_len = len(conv)
if conv_len <= remaining and conv_len > best_len:
best_idx = i
best_len = conv_len
if best_idx >= 0:
# Found a conversation that fits - use it entirely
conv = conv_buffer.pop(best_idx)
row.extend(conv)
consumed += ddp_world_size # Track actual consumption
else:
# No conversation fits - crop first conversation to fill remaining
conv = conv_buffer.pop(0)
row.extend(conv[:remaining])
consumed += ddp_world_size # Track actual consumption
rows.append(row[:row_capacity])
# Stopping condition to respect num_iterations, if given
it += 1
if 0 < args.num_iterations <= it and split == "train":
last_step = True
# Update progress tracking (based on consumed, not cursor, to account for buffering)
if split == "train":
current_epoch = epoch
if args.num_iterations > 0:
approx_progress = it / args.num_iterations
else:
approx_progress = consumed / dataset_size
# Trigger last_step when we've consumed enough (instead of when cursor wraps)
if consumed >= dataset_size:
last_step = True
# Build tensors
use_cuda = device_type == "cuda"
batch_tensor = torch.tensor(rows, dtype=torch.long, pin_memory=use_cuda)
inputs = batch_tensor[:, :-1].to(device=device, dtype=torch.int32, non_blocking=use_cuda)
targets = batch_tensor[:, 1:].to(device=device, dtype=torch.int64, non_blocking=use_cuda)
yield inputs, targets
train_loader = mid_data_generator_bos_bestfit("train")
build_val_loader = lambda: mid_data_generator_bos_bestfit("val")
progress = 0 # will go from 0 to 1 over the course of the epoch
# Learning rate scheduler
def get_lr_multiplier(progress):
# first 80% of training: no decay, then linearly ramp down to 0.
return 1 if progress < 0.8 else 1 - (progress - 0.8) / 0.2
# Momentum scheduler for Muon optimizer
def get_muon_momentum(it):
frac = min(it / 300, 1)
momentum = (1 - frac) * 0.85 + frac * 0.95
return momentum
# -----------------------------------------------------------------------------
# Training loop
x, y = next(train_loader) # prefetch the very first batch of data
min_val_bpb = float("inf")
smooth_train_loss = 0 # EMA of training loss
ema_beta = 0.9 # EMA decay factor
total_training_time = 0 # total wall-clock time of training
step = 0
while True:
flops_so_far = num_flops_per_token * args.total_batch_size * step
# Synchronize last_step across all ranks to avoid hangs in the distributed setting
if ddp:
last_step_tensor = torch.tensor(last_step, dtype=torch.int32, device=device)
dist.all_reduce(last_step_tensor, op=dist.ReduceOp.MAX)
last_step = bool(last_step_tensor.item())
# once in a while: evaluate the val bpb (all ranks participate)
if args.eval_every > 0 and (last_step or step % args.eval_every == 0):
model.eval()
val_loader = build_val_loader()
eval_steps = args.eval_tokens // (args.device_batch_size * args.max_seq_len * ddp_world_size)
with autocast_ctx:
val_bpb = evaluate_bpb(model, val_loader, eval_steps, token_bytes)
print0(f"Step {step:05d} | Validation bpb: {val_bpb:.4f}")
if val_bpb < min_val_bpb:
min_val_bpb = val_bpb
wandb_run.log({
"step": step,
"total_training_flops": flops_so_far,
"total_training_time": total_training_time,
"val/bpb": val_bpb,
})
model.train()
# save checkpoint at the end of the run (only on master process)
if master_process and last_step and not args.dry_run:
output_dirname = args.model_tag if args.model_tag else f"d{depth}" # e.g. d12
checkpoint_dir = os.path.join(base_dir, "mid_checkpoints", output_dirname)
save_checkpoint(
checkpoint_dir,
step,
orig_model.state_dict(),
[opt.state_dict() for opt in optimizers], # TODO: make sure saving across ranks is done correctly
{
"step": step,
"val_bpb": val_bpb, # loss at last step
"model_config": {
"sequence_len": args.max_seq_len,
"vocab_size": tokenizer.get_vocab_size(),
"n_layer": depth,
"n_head": model.config.n_head,
"n_kv_head": model.config.n_kv_head,
"n_embd": model.config.n_embd,
},
"user_config": user_config, # inputs to the training script
}
)
if last_step:
break
# -------------------------------------------------------------------------
# single training step
# evaluate the gradient
synchronize()
t0 = time.time()
for micro_step in range(grad_accum_steps):
with autocast_ctx:
loss = model(x, y)
train_loss = loss.detach() # for logging
loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here
loss.backward()
x, y = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward
progress = max(progress, approx_progress) # only increase progress monotonically
# step the optimizers
lrm = get_lr_multiplier(progress)
for opt in optimizers:
for group in opt.param_groups:
group["lr"] = group["initial_lr"] * lrm
muon_momentum = get_muon_momentum(step)
for group in muon_optimizer.param_groups:
group["momentum"] = muon_momentum
for opt in optimizers:
opt.step()
model.zero_grad(set_to_none=True)
synchronize()
t1 = time.time()
dt = t1 - t0
# -------------------------------------------------------------------------
# State
step += 1
# logging
smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss.item() # EMA the training loss
debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**(step + 1)) # debias the EMA
pct_done = 100 * progress
tok_per_sec = int(args.total_batch_size / dt)
flops_per_sec = num_flops_per_token * args.total_batch_size / dt
promised_flops_per_sec_h100 = 989e12 * ddp_world_size # bfloat16 H100 SXM and without 2:4 sparsity
mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in %
if step > 10:
total_training_time += dt # only count the time after the first 10 steps
print0(f"step {step:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | epoch: {current_epoch} | total time: {total_training_time/60:.2f}m")
if step % 10 == 0:
wandb_run.log({
"step": step,
"total_training_flops": flops_so_far,
"total_training_time": total_training_time,
"train/loss": debiased_smooth_loss,
"train/lrm": lrm,
"train/dt": dt,
"train/tok_per_sec": tok_per_sec,
"train/mfu": mfu,
"train/epoch": current_epoch,
})
# print a few more stats
print0(f"Peak memory usage: {get_max_memory() / 1024 / 1024:.2f}MiB")
print0(f"Total training time: {total_training_time/60:.2f}m")
print0(f"Minimum validation bpb: {min_val_bpb:.4f}")
# Log to report
if not args.dry_run:
from nanochat.report import get_report
get_report().log(section="Midtraining", data=[
user_config, # CLI args
{ # stats about the training setup
"Number of iterations": step,
"DDP world size": ddp_world_size,
},
{ # stats about training outcomes
"Minimum validation bpb": min_val_bpb,
}
])
# cleanup
wandb_run.finish() # wandb run finish
compute_cleanup()

View File

@ -14,7 +14,7 @@ from nanochat.dataset import parquets_iter_batched
# Parse command line arguments
parser = argparse.ArgumentParser(description='Train a BPE tokenizer')
parser.add_argument('--max-chars', type=int, default=10_000_000_000, help='Maximum characters to train on (default: 10B)')
parser.add_argument('--max-chars', type=int, default=2_000_000_000, help='Maximum characters to train on (default: 10B)')
parser.add_argument('--doc-cap', type=int, default=10_000, help='Maximum characters per document (default: 10,000)')
parser.add_argument('--vocab-size', type=int, default=32768, help='Vocabulary size (default: 32768 = 2^15)')
args = parser.parse_args()

View File

@ -20,7 +20,7 @@ LLM because it has to learn how every token (a little semantic chunk/atom)
maps to the sequence of individual characters that make it up. Larger models
learn this eventually on their own, but if we want this capability to exist
in smaller models, we have to actively encourage it by over-representing it
in the training data. Midtraining is a good place to do this.
in the training data. SFT is a good place to do this.
To preview a few example conversations, run:
python -m tasks.spellingbee

View File

@ -178,6 +178,39 @@ class TestFA3VsSDPA:
max_diff, mean_diff = assert_close(y_fa3, y_sdpa, "single_token")
print(f"single_token: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
def test_kvcache_single_token_sliding_window(self):
"""Test single token decode with sliding window smaller than cache size.
This catches the bug where SDPA ignores window_size during Tq=1 decode.
When window < Tk, FA3 only attends to the last (window+1) tokens,
but SDPA was attending to all cached tokens.
"""
B, T_max, H, D = 2, 64, 4, 32
T_prefill = 32 # Enough tokens to exceed window
window = 8 # Window SMALLER than cache size
k_init = torch.randn(B, T_prefill, H, D, device=self.DEVICE, dtype=self.DTYPE)
v_init = torch.randn(B, T_prefill, H, D, device=self.DEVICE, dtype=self.DTYPE)
q_single = torch.randn(B, 1, H, D, device=self.DEVICE, dtype=self.DTYPE)
k_single = torch.randn(B, 1, H, D, device=self.DEVICE, dtype=self.DTYPE)
v_single = torch.randn(B, 1, H, D, device=self.DEVICE, dtype=self.DTYPE)
def run():
k_cache = torch.zeros(B, T_max, H, D, device=self.DEVICE, dtype=self.DTYPE)
v_cache = torch.zeros(B, T_max, H, D, device=self.DEVICE, dtype=self.DTYPE)
k_cache[:, :T_prefill, :, :] = k_init
v_cache[:, :T_prefill, :, :] = v_init
cache_seqlens = torch.full((B,), T_prefill, dtype=torch.int32, device=self.DEVICE)
return flash_attn.flash_attn_with_kvcache(
q_single, k_cache, v_cache, k=k_single, v=v_single,
cache_seqlens=cache_seqlens,
causal=True, window_size=(window, 0) # window=8 < Tk=33
)
y_fa3, y_sdpa = run_both_impls(run)
max_diff, mean_diff = assert_close(y_fa3, y_sdpa, "single_token_sliding_window")
print(f"single_token_sliding_window: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
def test_backward_gradients_match(self):
"""Verify gradients are similar between FA3 and SDPA."""
B, T, H, D = 2, 32, 4, 16

View File

@ -96,6 +96,7 @@ def test_kv_cache_basic():
head_dim=head_dim,
num_layers=num_layers,
device="cpu",
dtype=torch.float32,
)
# Check initial state
@ -130,7 +131,7 @@ def test_kv_cache_prefill():
# Create source cache and advance it
src_cache = KVCache(
batch_size=batch_size, num_heads=num_heads, seq_len=32,
head_dim=head_dim, num_layers=num_layers, device="cpu",
head_dim=head_dim, num_layers=num_layers, device="cpu", dtype=torch.float32,
)
# Write some data to source cache
src_cache.k_cache[0, 0, :16, :, :] = 1.0
@ -140,7 +141,7 @@ def test_kv_cache_prefill():
# Create destination cache with larger seq_len
dst_cache = KVCache(
batch_size=batch_size, num_heads=num_heads, seq_len=64,
head_dim=head_dim, num_layers=num_layers, device="cpu",
head_dim=head_dim, num_layers=num_layers, device="cpu", dtype=torch.float32,
)
# Prefill
@ -195,3 +196,72 @@ def test_multi_sample_first_token_diversity():
f"With uniform logits, this is statistically impossible (~10^-36 probability) "
f"unless tokens are being broadcast instead of independently sampled."
)
def test_seed_reproducibility():
"""Same seed must produce identical output."""
model = MockModel()
engine = Engine(model, ByteTokenizer())
prompt = [261, 72, 101, 108, 108, 111] # <bos> + "Hello"
for seed in [1, 42, 123, 999]:
r1, _ = engine.generate_batch(prompt, max_tokens=5, seed=seed)
r2, _ = engine.generate_batch(prompt, max_tokens=5, seed=seed)
r3, _ = engine.generate_batch(prompt, max_tokens=5, seed=seed)
assert r1 == r2 == r3, "Same seed must produce identical output for the same prompt."
def test_temperature_zero_determinism():
"""Temperature=0 is deterministic regardless of seed."""
model = MockModel()
engine = Engine(model, ByteTokenizer())
prompt = [261, 72, 101, 108, 108, 111]
r1, _ = engine.generate_batch(prompt, temperature=0.0, max_tokens=5, seed=1)
r2, _ = engine.generate_batch(prompt, temperature=0.0, max_tokens=5, seed=42)
r3, _ = engine.generate_batch(prompt, temperature=0.0, max_tokens=5, seed=123)
assert r1 == r2 == r3, "Temperature=0 must result in the same output for the same prompt regardless of seed."
def test_max_tokens_respected():
"""Generation stops at max_tokens limit."""
model = MockModel()
engine = Engine(model, ByteTokenizer())
prompt = [261, 72, 101, 108, 108, 111]
for max_tokens in [1, 4, 16, 64]:
results, _ = engine.generate_batch(prompt, max_tokens=max_tokens)
num_generated_tokens = len(results[0]) - len(prompt)
assert num_generated_tokens <= max_tokens, f"Generated {num_generated_tokens} tokens, expected max_tokens={max_tokens} or less."
def test_num_samples_count():
"""num_samples=N produces exactly N sequences."""
model = MockModel()
engine = Engine(model, ByteTokenizer())
prompt = [261, 72, 101, 108, 108, 111]
for num_samples in [1, 4, 16, 64]:
results, _ = engine.generate_batch(prompt, num_samples=num_samples, max_tokens=3)
assert len(results) == num_samples, f"Expected {num_samples} sequences from {num_samples} samples, got {len(results)}"
def test_different_seeds_introduce_variation_when_temperature_nonzero():
"""With temperature > 0, different seeds should introduce sampling variation."""
model = MockModel()
engine = Engine(model, ByteTokenizer())
prompt = [261, 72, 101, 108, 108, 111] # <bos> + "Hello"
outputs = set()
for seed in [1, 42, 123, 999, 1000, 1001, 1002, 1003, 1004, 1005]:
results, _ = engine.generate_batch(
prompt,
temperature=1.0,
max_tokens=5,
seed=seed,
)
outputs.add(tuple(results[0]))
# Sanity check: sampling actually introduces variation
assert len(outputs) > 1, "All seeds produced the same output which is statistically highly improbable."

243
uv.lock
View File

@ -1505,9 +1505,8 @@ dependencies = [
{ name = "tabulate" },
{ name = "tiktoken" },
{ name = "tokenizers" },
{ name = "torch", version = "2.9.0", source = { registry = "https://pypi.org/simple" }, marker = "(sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "torch", version = "2.9.1", source = { registry = "https://download.pytorch.org/whl/cpu" }, marker = "(sys_platform == 'darwin' and extra == 'extra-8-nanochat-cpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "torch", version = "2.9.1", source = { registry = "https://pypi.org/simple" }, marker = "(sys_platform != 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "torch", version = "2.9.1", source = { registry = "https://pypi.org/simple" }, marker = "(extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu')" },
{ name = "torch", version = "2.9.1+cpu", source = { registry = "https://download.pytorch.org/whl/cpu" }, marker = "(sys_platform != 'darwin' and extra == 'extra-8-nanochat-cpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "torch", version = "2.9.1+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "extra == 'extra-8-nanochat-gpu'" },
{ name = "transformers" },
@ -1546,9 +1545,9 @@ requires-dist = [
{ name = "tabulate", specifier = ">=0.9.0" },
{ name = "tiktoken", specifier = ">=0.11.0" },
{ name = "tokenizers", specifier = ">=0.22.0" },
{ name = "torch", specifier = ">=2.9.0" },
{ name = "torch", marker = "extra == 'cpu'", specifier = ">=2.9.1", index = "https://download.pytorch.org/whl/cpu", conflict = { package = "nanochat", extra = "cpu" } },
{ name = "torch", marker = "extra == 'gpu'", specifier = ">=2.9.1", index = "https://download.pytorch.org/whl/cu128", conflict = { package = "nanochat", extra = "gpu" } },
{ name = "torch", specifier = "==2.9.1" },
{ name = "torch", marker = "extra == 'cpu'", specifier = "==2.9.1", index = "https://download.pytorch.org/whl/cpu", conflict = { package = "nanochat", extra = "cpu" } },
{ name = "torch", marker = "extra == 'gpu'", specifier = "==2.9.1", index = "https://download.pytorch.org/whl/cu128", conflict = { package = "nanochat", extra = "gpu" } },
{ name = "transformers", specifier = ">=4.57.3" },
{ name = "uvicorn", specifier = ">=0.36.0" },
{ name = "wandb", specifier = ">=0.21.3" },
@ -1688,7 +1687,7 @@ name = "nvidia-cudnn-cu12"
version = "9.10.2.21"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "nvidia-cublas-cu12", marker = "(sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "nvidia-cublas-cu12", marker = "(sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/fa/41/e79269ce215c857c935fd86bcfe91a451a584dfc27f1e068f568b9ad1ab7/nvidia_cudnn_cu12-9.10.2.21-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:c9132cc3f8958447b4910a1720036d9eff5928cc3179b0a51fb6d167c6cc87d8", size = 705026878, upload-time = "2025-06-06T21:52:51.348Z" },
@ -1701,7 +1700,7 @@ name = "nvidia-cufft-cu12"
version = "11.3.3.83"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "nvidia-nvjitlink-cu12", marker = "(sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "nvidia-nvjitlink-cu12", marker = "(sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/60/bc/7771846d3a0272026c416fbb7e5f4c1f146d6d80704534d0b187dd6f4800/nvidia_cufft_cu12-11.3.3.83-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:848ef7224d6305cdb2a4df928759dca7b1201874787083b6e7550dd6765ce69a", size = 193109211, upload-time = "2025-03-07T01:44:56.873Z" },
@ -1733,9 +1732,9 @@ name = "nvidia-cusolver-cu12"
version = "11.7.3.90"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "nvidia-cublas-cu12", marker = "(sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "nvidia-cusparse-cu12", marker = "(sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "nvidia-nvjitlink-cu12", marker = "(sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "nvidia-cublas-cu12", marker = "(sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "nvidia-cusparse-cu12", marker = "(sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "nvidia-nvjitlink-cu12", marker = "(sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/c8/32/f7cd6ce8a7690544d084ea21c26e910a97e077c9b7f07bf5de623ee19981/nvidia_cusolver_cu12-11.7.3.90-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:db9ed69dbef9715071232caa9b69c52ac7de3a95773c2db65bdba85916e4e5c0", size = 267229841, upload-time = "2025-03-07T01:46:54.356Z" },
@ -1748,7 +1747,7 @@ name = "nvidia-cusparse-cu12"
version = "12.5.8.93"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "nvidia-nvjitlink-cu12", marker = "(sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "nvidia-nvjitlink-cu12", marker = "(sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/bc/f7/cd777c4109681367721b00a106f491e0d0d15cfa1fd59672ce580ce42a97/nvidia_cusparse_cu12-12.5.8.93-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:9b6c161cb130be1a07a27ea6923df8141f3c295852f4b260c65f18f3e0a091dc", size = 288117129, upload-time = "2025-03-07T01:47:40.407Z" },
@ -2990,72 +2989,6 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/6e/c2/61d3e0f47e2b74ef40a68b9e6ad5984f6241a942f7cd3bbfbdbd03861ea9/tomli-2.2.1-py3-none-any.whl", hash = "sha256:cb55c73c5f4408779d0cf3eef9f762b9c9f147a77de7b258bef0a5628adc85cc", size = 14257, upload-time = "2024-11-27T22:38:35.385Z" },
]
[[package]]
name = "torch"
version = "2.9.0"
source = { registry = "https://pypi.org/simple" }
resolution-markers = [
"python_full_version >= '3.12' and sys_platform == 'linux'",
"python_full_version == '3.11.*' and sys_platform == 'linux'",
"python_full_version < '3.11' and sys_platform == 'linux'",
]
dependencies = [
{ name = "filelock", marker = "(sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "fsspec", marker = "(sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "jinja2", marker = "(sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "networkx", version = "3.4.2", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version < '3.11' and sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "networkx", version = "3.5", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version >= '3.11' and sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "nvidia-cublas-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "nvidia-cuda-cupti-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "nvidia-cuda-nvrtc-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "nvidia-cuda-runtime-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "nvidia-cudnn-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "nvidia-cufft-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "nvidia-cufile-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "nvidia-curand-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "nvidia-cusolver-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "nvidia-cusparse-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "nvidia-cusparselt-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "nvidia-nccl-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "nvidia-nvjitlink-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "nvidia-nvshmem-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "nvidia-nvtx-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "setuptools", marker = "(python_full_version >= '3.12' and sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "sympy", marker = "(sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "triton", version = "3.5.0", source = { registry = "https://pypi.org/simple" }, marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "typing-extensions", marker = "(sys_platform == 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/bb/86/245c240d2138c17ed572c943c289056c2721abab70810d772c6bf5495b28/torch-2.9.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:030bbfe367379ae6a4ae4042b6c44da25383343b8b3c68abaa9c7231efbaf2dd", size = 104213554, upload-time = "2025-10-15T15:45:59.798Z" },
{ url = "https://files.pythonhosted.org/packages/58/1d/fd1e88ae0948825efcab7dd66d12bec23f05d4d38ed81573c8d453c14c06/torch-2.9.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:51cb63902182a78e90886e8068befd8ea102af4b00e420263591a3d70c7d3c6c", size = 899795167, upload-time = "2025-10-15T15:47:12.695Z" },
{ url = "https://files.pythonhosted.org/packages/63/5a/496197b45c14982bef4e079b24c61dc108e3ab0d0cc9718dba9f54f45a46/torch-2.9.0-cp310-cp310-win_amd64.whl", hash = "sha256:3f6aad4d2f0ee2248bac25339d74858ff846c3969b27d14ac235821f055af83d", size = 109310314, upload-time = "2025-10-15T15:46:16.633Z" },
{ url = "https://files.pythonhosted.org/packages/58/b0/2b4e647b0fc706e88eb6c253d05511865578f5f67b55fad639bf3272a4a1/torch-2.9.0-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:413e1654c9203733138858780e184d9fc59442f0b3b209e16f39354eb893db9b", size = 74452019, upload-time = "2025-10-15T15:46:04.296Z" },
{ url = "https://files.pythonhosted.org/packages/58/fe/334225e6330e672b36aef23d77451fa906ea12881570c08638a91331a212/torch-2.9.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:c596708b5105d0b199215acf0c9be7c1db5f1680d88eddadf4b75a299259a677", size = 104230578, upload-time = "2025-10-15T15:46:08.182Z" },
{ url = "https://files.pythonhosted.org/packages/05/cc/49566caaa218872ec9a2912456f470ff92649894a4bc2e5274aa9ef87c4a/torch-2.9.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:51de31219c97c51cf4bf2be94d622e3deb5dcc526c6dc00e97c17eaec0fc1d67", size = 899815990, upload-time = "2025-10-15T15:48:03.336Z" },
{ url = "https://files.pythonhosted.org/packages/74/25/e9ab21d5925b642d008f139d4a3c9664fc9ee1faafca22913c080cc4c0a5/torch-2.9.0-cp311-cp311-win_amd64.whl", hash = "sha256:dd515c70059afd95f48b8192733764c08ca37a1d19803af6401b5ecad7c8676e", size = 109313698, upload-time = "2025-10-15T15:46:12.425Z" },
{ url = "https://files.pythonhosted.org/packages/b3/b7/205ef3e94de636feffd64b28bb59a0dfac0771221201b9871acf9236f5ca/torch-2.9.0-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:614a185e4986326d526a91210c8fc1397e76e8cfafa78baf6296a790e53a9eec", size = 74463678, upload-time = "2025-10-15T15:46:29.779Z" },
{ url = "https://files.pythonhosted.org/packages/d1/d3/3985739f3b8e88675127bf70f82b3a48ae083e39cda56305dbd90398fec0/torch-2.9.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:e5f7af1dc4c0a7c4a260c2534f41ddaf209714f7c89145e644c44712fbd6b642", size = 104107898, upload-time = "2025-10-15T15:46:20.883Z" },
{ url = "https://files.pythonhosted.org/packages/a5/4b/f4bb2e6c25d0272f798cd6d7a04ed315da76cec68c602d87040c7847287f/torch-2.9.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:01cff95ecd9a212ea2f141db28acccdceb6a4c54f64e6c51091146f5e2a772c6", size = 899738273, upload-time = "2025-10-15T15:50:04.188Z" },
{ url = "https://files.pythonhosted.org/packages/66/11/c1c5ba6691cda6279087c35bd626536e4fd29521fe740abf5008377a9a02/torch-2.9.0-cp312-cp312-win_amd64.whl", hash = "sha256:4582b162f541651f0cb184d3e291c05c2f556c7117c64a9873e2ee158d40062b", size = 109280887, upload-time = "2025-10-15T15:46:26.228Z" },
{ url = "https://files.pythonhosted.org/packages/dd/5f/b85bd8c05312d71de9402bf5868d217c38827cfd09d8f8514e5be128a52b/torch-2.9.0-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:33f58e9a102a91259af289d50525c30323b5c9ae1d31322b6447c0814da68695", size = 74478983, upload-time = "2025-10-15T15:46:39.406Z" },
{ url = "https://files.pythonhosted.org/packages/c2/1c/90eb13833cdf4969ea9707586d7b57095c3b6e2b223a7256bf111689bcb8/torch-2.9.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:c30a17fc83eeab346913e237c64b15b5ba6407fff812f6c541e322e19bc9ea0e", size = 104111330, upload-time = "2025-10-15T15:46:35.238Z" },
{ url = "https://files.pythonhosted.org/packages/0e/21/2254c54b8d523592c25ef4434769aa23e29b1e6bf5f4c0ad9e27bf442927/torch-2.9.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:8f25033b8667b57857dfd01458fbf2a9e6a6df1f8def23aef0dc46292f6aa642", size = 899750243, upload-time = "2025-10-15T15:48:57.459Z" },
{ url = "https://files.pythonhosted.org/packages/b7/a5/5cb94fa4fd1e78223455c23c200f30f6dc10c6d4a2bcc8f6e7f2a2588370/torch-2.9.0-cp313-cp313-win_amd64.whl", hash = "sha256:d037f1b4ffd25013be4a7bf3651a0a910c68554956c7b2c92ebe87c76475dece", size = 109284513, upload-time = "2025-10-15T15:46:45.061Z" },
{ url = "https://files.pythonhosted.org/packages/66/e8/fc414d8656250ee46120b44836ffbb3266343db424b3e18ca79ebbf69d4f/torch-2.9.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:e4e5b5cba837a2a8d1a497ba9a58dae46fa392593eaa13b871c42f71847503a5", size = 74830362, upload-time = "2025-10-15T15:46:48.983Z" },
{ url = "https://files.pythonhosted.org/packages/ed/5f/9474c98fc5ae0cd04b9466035428cd360e6611a86b8352a0fc2fa504acdc/torch-2.9.0-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:64693568f5dc4dbd5f880a478b1cea0201cc6b510d91d1bc54fea86ac5d1a637", size = 104144940, upload-time = "2025-10-15T15:47:29.076Z" },
{ url = "https://files.pythonhosted.org/packages/2d/5a/8e0c1cf57830172c109d4bd6be2708cabeaf550983eee7029291322447a0/torch-2.9.0-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:f8ed31ddd7d10bfb3fbe0b9fe01b1243577f13d75e6f4a0839a283915ce3791e", size = 899744054, upload-time = "2025-10-15T15:48:29.864Z" },
{ url = "https://files.pythonhosted.org/packages/6d/28/82c28b30fcb4b7c9cdd995763d18bbb830d6521356712faebbad92ffa61d/torch-2.9.0-cp313-cp313t-win_amd64.whl", hash = "sha256:eff527d4e4846e6f70d2afd8058b73825761203d66576a7e04ea2ecfebcb4ab8", size = 109517546, upload-time = "2025-10-15T15:47:33.395Z" },
{ url = "https://files.pythonhosted.org/packages/ff/c3/a91f96ec74347fa5fd24453fa514bc61c61ecc79196fa760b012a1873d96/torch-2.9.0-cp313-none-macosx_11_0_arm64.whl", hash = "sha256:f8877779cf56d1ce431a7636703bdb13307f5960bb1af49716d8b179225e0e6a", size = 74480732, upload-time = "2025-10-15T15:47:38.002Z" },
{ url = "https://files.pythonhosted.org/packages/5c/73/9f70af34b334a7e0ef496ceec96b7ec767bd778ea35385ce6f77557534d1/torch-2.9.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:7e614fae699838038d888729f82b687c03413c5989ce2a9481f9a7e7a396e0bb", size = 74433037, upload-time = "2025-10-15T15:47:41.894Z" },
{ url = "https://files.pythonhosted.org/packages/b7/84/37cf88625901934c97109e583ecc21777d21c6f54cda97a7e5bbad1ee2f2/torch-2.9.0-cp314-cp314-manylinux_2_28_aarch64.whl", hash = "sha256:dfb5b8cd310ba3436c7e14e8b7833ef658cf3045e50d2bdaed23c8fc517065eb", size = 104116482, upload-time = "2025-10-15T15:47:46.266Z" },
{ url = "https://files.pythonhosted.org/packages/56/8e/ca8b17866943a8d4f4664d402ea84210aa274588b4c5d89918f5caa24eec/torch-2.9.0-cp314-cp314-manylinux_2_28_x86_64.whl", hash = "sha256:b3d29524993a478e46f5d598b249cd824b7ed98d7fba538bd9c4cde6c803948f", size = 899746916, upload-time = "2025-10-15T15:50:40.294Z" },
{ url = "https://files.pythonhosted.org/packages/43/65/3b17c0fbbdab6501c5b320a52a648628d0d44e7379f64e27d9eef701b6bf/torch-2.9.0-cp314-cp314-win_amd64.whl", hash = "sha256:71c7578984f5ec0eb645eb4816ac8435fcf3e3e2ae1901bcd2f519a9cafb5125", size = 109275151, upload-time = "2025-10-15T15:49:20.715Z" },
{ url = "https://files.pythonhosted.org/packages/83/36/74f8c051f785500396e42f93542422422dfd874a174f21f8d955d36e5d64/torch-2.9.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:71d9309aee457bbe0b164bce2111cd911c4ed4e847e65d5077dbbcd3aba6befc", size = 74823353, upload-time = "2025-10-15T15:49:16.59Z" },
{ url = "https://files.pythonhosted.org/packages/62/51/dc3b4e2f9ba98ae27238f0153ca098bf9340b2dafcc67fde645d496dfc2a/torch-2.9.0-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:c08fb654d783899e204a32cca758a7ce8a45b2d78eeb89517cc937088316f78e", size = 104140340, upload-time = "2025-10-15T15:50:19.67Z" },
{ url = "https://files.pythonhosted.org/packages/c0/8d/b00657f8141ac16af7bb6cda2e67de18499a3263b78d516b9a93fcbc98e3/torch-2.9.0-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = "sha256:ec8feb0099b2daa5728fbc7abb0b05730fd97e0f359ff8bda09865aaa7bd7d4b", size = 899731750, upload-time = "2025-10-15T15:49:36.673Z" },
{ url = "https://files.pythonhosted.org/packages/fc/29/bd361e0cbb2c79ce6450f42643aaf6919956f89923a50571b0ebfe92d142/torch-2.9.0-cp314-cp314t-win_amd64.whl", hash = "sha256:695ba920f234ad4170c9c50e28d56c848432f8f530e6bc7f88fcb15ddf338e75", size = 109503850, upload-time = "2025-10-15T15:50:24.118Z" },
]
[[package]]
name = "torch"
version = "2.9.1"
@ -3076,13 +3009,13 @@ dependencies = [
{ name = "typing-extensions", marker = "(sys_platform == 'darwin' and extra == 'extra-8-nanochat-cpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
]
wheels = [
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.1-cp310-none-macosx_11_0_arm64.whl" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.1-cp311-none-macosx_11_0_arm64.whl" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.1-cp312-none-macosx_11_0_arm64.whl" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.1-cp313-cp313t-macosx_11_0_arm64.whl" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.1-cp313-none-macosx_11_0_arm64.whl" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.1-cp314-cp314-macosx_11_0_arm64.whl" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.1-cp314-cp314t-macosx_11_0_arm64.whl" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.1-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:bf1e68cfb935ae2046374ff02a7aa73dda70351b46342846f557055b3a540bf0" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.1-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:a52952a8c90a422c14627ea99b9826b7557203b46b4d0772d3ca5c7699692425" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.1-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:287242dd1f830846098b5eca847f817aa5c6015ea57ab4c1287809efea7b77eb" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.1-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:8924d10d36eac8fe0652a060a03fc2ae52980841850b9a1a2ddb0f27a4f181cd" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.1-cp313-none-macosx_11_0_arm64.whl", hash = "sha256:bcee64ae7aa65876ceeae6dcaebe75109485b213528c74939602208a20706e3f" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.1-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:defadbeb055cfcf5def58f70937145aecbd7a4bc295238ded1d0e85ae2cf0e1d" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.1-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:886f84b181f766f53265ba0a1d503011e60f53fff9d569563ef94f24160e1072" },
]
[[package]]
@ -3090,19 +3023,22 @@ name = "torch"
version = "2.9.1"
source = { registry = "https://pypi.org/simple" }
resolution-markers = [
"python_full_version >= '3.12' and sys_platform == 'linux'",
"python_full_version >= '3.12' and sys_platform != 'linux'",
"python_full_version == '3.11.*' and sys_platform == 'linux'",
"python_full_version < '3.11' and sys_platform == 'linux'",
"python_full_version == '3.11.*' and sys_platform != 'linux'",
"python_full_version < '3.11' and sys_platform != 'linux'",
]
dependencies = [
{ name = "filelock", marker = "(sys_platform != 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "fsspec", marker = "(sys_platform != 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "jinja2", marker = "(sys_platform != 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "networkx", version = "3.4.2", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version < '3.11' and sys_platform != 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "networkx", version = "3.5", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version >= '3.11' and sys_platform != 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "setuptools", marker = "(python_full_version >= '3.12' and sys_platform != 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "sympy", marker = "(sys_platform != 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "typing-extensions", marker = "(sys_platform != 'linux' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "filelock", marker = "(extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu')" },
{ name = "fsspec", marker = "(extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu')" },
{ name = "jinja2", marker = "(extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu')" },
{ name = "networkx", version = "3.4.2", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version < '3.11' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "networkx", version = "3.5", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version >= '3.11' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "setuptools", marker = "(python_full_version >= '3.12' and extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "sympy", marker = "(extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu')" },
{ name = "typing-extensions", marker = "(extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu')" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/5f/56/9577683b23072075ed2e40d725c52c2019d71a972fab8e083763da8e707e/torch-2.9.1-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:1cc208435f6c379f9b8fdfd5ceb5be1e3b72a6bdf1cb46c0d2812aa73472db9e", size = 104207681, upload-time = "2025-11-12T15:19:56.48Z" },
@ -3158,30 +3094,30 @@ dependencies = [
{ name = "typing-extensions", marker = "(sys_platform != 'darwin' and extra == 'extra-8-nanochat-cpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
]
wheels = [
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp310-cp310-manylinux_2_28_aarch64.whl" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp310-cp310-manylinux_2_28_x86_64.whl" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp310-cp310-win_amd64.whl" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp311-cp311-manylinux_2_28_aarch64.whl" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp311-cp311-manylinux_2_28_x86_64.whl" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp311-cp311-win_amd64.whl" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp311-cp311-win_arm64.whl" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp312-cp312-manylinux_2_28_aarch64.whl" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp312-cp312-manylinux_2_28_x86_64.whl" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp312-cp312-win_amd64.whl" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp312-cp312-win_arm64.whl" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp313-cp313-manylinux_2_28_aarch64.whl" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp313-cp313-manylinux_2_28_x86_64.whl" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp313-cp313-win_amd64.whl" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp313-cp313-win_arm64.whl" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp313-cp313t-manylinux_2_28_aarch64.whl" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp313-cp313t-manylinux_2_28_x86_64.whl" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp313-cp313t-win_amd64.whl" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp314-cp314-manylinux_2_28_aarch64.whl" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp314-cp314-manylinux_2_28_x86_64.whl" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp314-cp314-win_amd64.whl" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp314-cp314t-manylinux_2_28_aarch64.whl" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp314-cp314t-manylinux_2_28_x86_64.whl" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp314-cp314t-win_amd64.whl" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:10866c8a48c4aa5ae3f48538dc8a055b99c57d9c6af2bf5dd715374d9d6ddca3" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:7210713b66943fdbfcc237b2e782871b649123ac5d29f548ce8c85be4223ab38" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp310-cp310-win_amd64.whl", hash = "sha256:d6e8441453dc27524e3f1037fbf27b90a02644b84e42944b9354b4024cb51cc1" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:0e611cfb16724e62252b67d31073bc5c490cb83e92ecdc1192762535e0e44487" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:3de2adb9b4443dc9210ef1f1b16da3647ace53553166d6360bbbd7edd6f16e4d" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp311-cp311-win_amd64.whl", hash = "sha256:69b3785d28be5a9c56ab525788ec5000349ec59132a74b7d5e954b905015b992" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp311-cp311-win_arm64.whl", hash = "sha256:15b4ae6fe371d96bffb8e1e9af62164797db20a0dc1337345781659cfd0b8bb1" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:3bf9b442a51a2948e41216a76d7ab00f0694cfcaaa51b6f9bcab57b7f89843e6" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:7417d8c565f219d3455654cb431c6d892a3eb40246055e14d645422de13b9ea1" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp312-cp312-win_amd64.whl", hash = "sha256:a4e06b4f441675d26b462123c8a83e77c55f1ec8ebc081203be2db1ea8054add" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp312-cp312-win_arm64.whl", hash = "sha256:1abe31f14b560c1f062699e966cb08ef5b67518a1cfac2d8547a3dbcd8387b06" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:3e532e553b37ee859205a9b2d1c7977fd6922f53bbb1b9bfdd5bdc00d1a60ed4" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:39b3dff6d8fba240ae0d1bede4ca11c2531ae3b47329206512d99e17907ff74b" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp313-cp313-win_amd64.whl", hash = "sha256:404a7ab2fffaf2ca069e662f331eb46313692b2f1630df2720094284f390ccef" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp313-cp313-win_arm64.whl", hash = "sha256:161decbff26a33f13cb5ba6d2c8f458bbf56193bcc32ecc70be6dd4c7a3ee79d" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:01b1884f724977a20c7da2f640f1c7b37f4a2c117a7f4a6c1c0424d14cb86322" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:031a597147fa81b1e6d79ccf1ad3ccc7fafa27941d6cf26ff5caaa384fb20e92" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp313-cp313t-win_amd64.whl", hash = "sha256:e586ab1363e3f86aa4cc133b7fdcf98deb1d2c13d43a7a6e5a6a18e9c5364893" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp314-cp314-manylinux_2_28_aarch64.whl", hash = "sha256:65010ab4aacce6c9a1ddfc935f986c003ca8638ded04348fd326c3e74346237c" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp314-cp314-manylinux_2_28_x86_64.whl", hash = "sha256:88adf5157db5da1d54b1c9fe4a6c1d20ceef00e75d854e206a87dbf69e3037dc" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp314-cp314-win_amd64.whl", hash = "sha256:f60e2565f261542efac07e25208fb3fc55c6fe82314a5a9cbee971edb5f27713" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:3ac2b8df2c55430e836dcda31940d47f1f5f94b8731057b6f20300ebea394dd9" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = "sha256:5b688445f928f13563b7418b17c57e97bf955ab559cf73cd8f2b961f8572dbb3" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.9.1%2Bcpu-cp314-cp314t-win_amd64.whl", hash = "sha256:cf9c3e50b595721ca6b488bdcc326e0f1af73ed28b9b66eff504a96649bb5c96" },
]
[[package]]
@ -3219,31 +3155,31 @@ dependencies = [
{ name = "nvidia-nvtx-cu12", marker = "(sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "setuptools", marker = "(python_full_version >= '3.12' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "sympy", marker = "extra == 'extra-8-nanochat-gpu'" },
{ name = "triton", version = "3.5.1", source = { registry = "https://pypi.org/simple" }, marker = "(sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "triton", marker = "(sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "typing-extensions", marker = "extra == 'extra-8-nanochat-gpu'" },
]
wheels = [
{ url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp310-cp310-manylinux_2_28_aarch64.whl" },
{ url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp310-cp310-manylinux_2_28_x86_64.whl" },
{ url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp310-cp310-win_amd64.whl" },
{ url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp311-cp311-manylinux_2_28_aarch64.whl" },
{ url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp311-cp311-manylinux_2_28_x86_64.whl" },
{ url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp311-cp311-win_amd64.whl" },
{ url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp312-cp312-manylinux_2_28_aarch64.whl" },
{ url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp312-cp312-manylinux_2_28_x86_64.whl" },
{ url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp312-cp312-win_amd64.whl" },
{ url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp313-cp313-manylinux_2_28_aarch64.whl" },
{ url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp313-cp313-manylinux_2_28_x86_64.whl" },
{ url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp313-cp313-win_amd64.whl" },
{ url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp313-cp313t-manylinux_2_28_aarch64.whl" },
{ url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp313-cp313t-manylinux_2_28_x86_64.whl" },
{ url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp313-cp313t-win_amd64.whl" },
{ url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp314-cp314-manylinux_2_28_aarch64.whl" },
{ url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp314-cp314-manylinux_2_28_x86_64.whl" },
{ url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp314-cp314-win_amd64.whl" },
{ url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp314-cp314t-manylinux_2_28_aarch64.whl" },
{ url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp314-cp314t-manylinux_2_28_x86_64.whl" },
{ url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp314-cp314t-win_amd64.whl" },
{ url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:72f0f096475e8095a6bea3fba75bd3b46cf42c761b29588f7599314e67a32661" },
{ url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:c8d670aa0be6fbecd2b0e7b7d514a104dbdefcc3786ca446cf0c3415043ea40a" },
{ url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp310-cp310-win_amd64.whl", hash = "sha256:64399adaa8ea0896d02cf844cba3c5dd77e769520a1af73572599e0eaa2cf551" },
{ url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:cf4ad82430824a80a9f398e29369524ed26c152cf00c2c12002e5400b35e260d" },
{ url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:2a1da940f0757621d098c9755f7504d791a72a40920ec85a4fd98b20253fca4e" },
{ url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp311-cp311-win_amd64.whl", hash = "sha256:633005a3700e81b5be0df2a7d3c1d48aced23ed927653797a3bd2b144a3aeeb6" },
{ url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:1176f250311fa95cc3bca8077af323e0d73ea385ba266e096af82e7e2b91f256" },
{ url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:7cb4018f4ce68b61fd3ef87dc1c4ca520731c7b5b200e360ad47b612d7844063" },
{ url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp312-cp312-win_amd64.whl", hash = "sha256:3a01f0b64c10a82d444d9fd06b3e8c567b1158b76b2764b8f51bfd8f535064b0" },
{ url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:0b80b7555dcd0a75b7b06016991f01281a0bb078cf28fa2d1dfb949fad2fbd07" },
{ url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:63381a109a569b280ed3319da89d3afe5cf9ab5c879936382a212affb5c90552" },
{ url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp313-cp313-win_amd64.whl", hash = "sha256:ad9183864acdd99fc5143d7ca9d3d2e7ddfc9a9600ff43217825d4e5e9855ccc" },
{ url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:2314521c74d76e513c53bb72c0ce3511ef0295ff657a432790df6c207e5d7962" },
{ url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:4454a4faca31af81566e3a4208f10f20b8a6d9cfe42791b0ca7ff134326468fc" },
{ url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp313-cp313t-win_amd64.whl", hash = "sha256:24420e430e77136f7079354134b34e7ba9d87e539f5ac84c33b08e5c13412ebe" },
{ url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp314-cp314-manylinux_2_28_aarch64.whl", hash = "sha256:32c036296c557f19a1537ce981c40533650097114e1720a321a39a3b08d9df56" },
{ url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp314-cp314-manylinux_2_28_x86_64.whl", hash = "sha256:7788d3d03d939cf00f93ac0da5ab520846f66411e339cfbf519a806e8facf519" },
{ url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp314-cp314-win_amd64.whl", hash = "sha256:7bcd40cbffac475b478d6ce812f03da84e9a4894956efb89c3b7bcca5dbd4f91" },
{ url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:e88c78e5b08ae9303aa15da43b68b44287ecbec16d898d9fad6998832fe626a5" },
{ url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = "sha256:7d8769bdf3200ca16a92f14df404c3370171ac3732996528a8973d753eac562f" },
{ url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp314-cp314t-win_amd64.whl", hash = "sha256:0c784b600959ec70ee01cb23e8bc870a0e0475af30378ff5e39f4abed8b7c1cc" },
]
[[package]]
@ -3307,41 +3243,10 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/6a/6b/2f416568b3c4c91c96e5a365d164f8a4a4a88030aa8ab4644181fdadce97/transformers-4.57.3-py3-none-any.whl", hash = "sha256:c77d353a4851b1880191603d36acb313411d3577f6e2897814f333841f7003f4", size = 11993463, upload-time = "2025-11-25T15:51:26.493Z" },
]
[[package]]
name = "triton"
version = "3.5.0"
source = { registry = "https://pypi.org/simple" }
resolution-markers = [
"python_full_version >= '3.12' and sys_platform == 'linux'",
"python_full_version == '3.11.*' and sys_platform == 'linux'",
"python_full_version < '3.11' and sys_platform == 'linux'",
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/dd/22/507b6f58a35e05e84381630b2dc2a3cee1a7a2a7eaf4cba857c638a18a24/triton-3.5.0-cp310-cp310-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6f90de6a6566bb619b4c0adc9855729e1b1b5e26533fca1bf6206e96b6d277a3", size = 159827599, upload-time = "2025-10-15T19:15:43.87Z" },
{ url = "https://files.pythonhosted.org/packages/0b/eb/09e31d107a5d00eb281aa7e6635ca463e9bca86515944e399480eadb71f8/triton-3.5.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d5d3b3d480debf24eaa739623c9a42446b0b77f95593d30eb1f64cd2278cc1f0", size = 170333110, upload-time = "2025-10-13T16:37:49.588Z" },
{ url = "https://files.pythonhosted.org/packages/79/f9/b6f60f978397c616fd8dacca2305759fe4f80d397b20ef72534803244bd5/triton-3.5.0-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:8457b22148defefdcb7fa8144b05ce211b9faefad650a1ce85b23df488d5549c", size = 159926731, upload-time = "2025-10-15T19:15:49.682Z" },
{ url = "https://files.pythonhosted.org/packages/3d/78/949a04391c21956c816523678f0e5fa308eb5b1e7622d88c4e4ef5fceca0/triton-3.5.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:f34bfa21c5b3a203c0f0eab28dcc1e49bd1f67d22724e77fb6665a659200a4ec", size = 170433488, upload-time = "2025-10-13T16:37:57.132Z" },
{ url = "https://files.pythonhosted.org/packages/87/9b/30988039e1e84df7554fba24e6a734d2d0e847af33cabdf9b532b3c51456/triton-3.5.0-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:7da21fccceafc163e3a5e857abe34351ef76345af06cabf9637a914742671f0b", size = 159946647, upload-time = "2025-10-15T19:15:56.325Z" },
{ url = "https://files.pythonhosted.org/packages/f5/3a/e991574f3102147b642e49637e0281e9bb7c4ba254edb2bab78247c85e01/triton-3.5.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c9e71db82261c4ffa3921cd050cd5faa18322d2d405c30eb56084afaff3b0833", size = 170476535, upload-time = "2025-10-13T16:38:05.18Z" },
{ url = "https://files.pythonhosted.org/packages/cd/85/e37f1197acb04c8f3d83851d23d5d6ed5060ef74580668b112e23fdfa203/triton-3.5.0-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:188da5b81fa2f8322c27fec1627703eac24cb9bb7ab0dfbe9925973bc1b070d3", size = 159958970, upload-time = "2025-10-15T19:16:01.717Z" },
{ url = "https://files.pythonhosted.org/packages/6c/29/10728de8a6e932e517c10773486b8e99f85d1b1d9dd87d9a9616e1fef4a1/triton-3.5.0-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:e6bb9aa5519c084a333acdba443789e50012a4b851cd486c54f0b8dc2a8d3a12", size = 170487289, upload-time = "2025-10-13T16:38:11.662Z" },
{ url = "https://files.pythonhosted.org/packages/b8/1d/38258f05010ac17a7b058c022911c9cae6526e149b7397134a048cf5a6c2/triton-3.5.0-cp313-cp313t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:03127d9b33aaf979c856676b394bc059ec1d68cb6da68ae03f62dd8ad77a04ae", size = 160073012, upload-time = "2025-10-15T19:16:07.477Z" },
{ url = "https://files.pythonhosted.org/packages/5c/38/db80e48b9220c9bce872b0f616ad0446cdf554a40b85c7865cbca99ab3c2/triton-3.5.0-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c83f2343e1a220a716c7b3ab9fccfcbe3ad4020d189549200e2d2e8d5868bed9", size = 170577179, upload-time = "2025-10-13T16:38:17.865Z" },
{ url = "https://files.pythonhosted.org/packages/91/fe/8f5771d00227f4eb1ee034f218ed427102b989366d2275fe3b3c105a3921/triton-3.5.0-cp314-cp314-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:468936651d383f4a6d10068d34a627505e13af55be5d002b9f27b987e7a5f0ac", size = 159957460, upload-time = "2025-10-15T19:16:12.626Z" },
{ url = "https://files.pythonhosted.org/packages/ff/60/1810655d1d856c9a4fcc90ee8966d85f552d98c53a6589f95ab2cbe27bb8/triton-3.5.0-cp314-cp314-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:da0fa67ccd76c3dcfb0bffe1b1c57c685136a6bd33d141c24d9655d4185b1289", size = 170487949, upload-time = "2025-10-13T16:38:24.881Z" },
{ url = "https://files.pythonhosted.org/packages/78/59/99edd103958fe6e42b50b9ad8ce4f223ddf4ccf475259cf7d2b53381dc6c/triton-3.5.0-cp314-cp314t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c7ceef21410229ac23173a28eee5cfc0e37c1dfdb8b4bc11ecda2e3ecec7c686", size = 160075629, upload-time = "2025-10-15T19:16:18.746Z" },
{ url = "https://files.pythonhosted.org/packages/fb/b7/1dec8433ac604c061173d0589d99217fe7bf90a70bdc375e745d044b8aad/triton-3.5.0-cp314-cp314t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:317fe477ea8fd4524a6a8c499fb0a36984a56d0b75bf9c9cb6133a1c56d5a6e7", size = 170580176, upload-time = "2025-10-13T16:38:31.14Z" },
]
[[package]]
name = "triton"
version = "3.5.1"
source = { registry = "https://pypi.org/simple" }
resolution-markers = [
"python_full_version >= '3.12' and sys_platform == 'linux'",
"python_full_version == '3.11.*' and sys_platform == 'linux'",
"python_full_version < '3.11' and sys_platform == 'linux'",
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/d9/2e/f95e673222afa2c7f0c687d8913e98fcf2589ef0b1405de76894e37fe18f/triton-3.5.1-cp310-cp310-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:f63e34dcb32d7bd3a1d0195f60f30d2aee8b08a69a0424189b71017e23dfc3d2", size = 159821655, upload-time = "2025-11-11T17:51:44.09Z" },
{ url = "https://files.pythonhosted.org/packages/fd/6e/676ab5019b4dde8b9b7bab71245102fc02778ef3df48218b298686b9ffd6/triton-3.5.1-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:5fc53d849f879911ea13f4a877243afc513187bc7ee92d1f2c0f1ba3169e3c94", size = 170320692, upload-time = "2025-11-11T17:40:46.074Z" },