Merge branch 'master' into fix/data_cutoff

This commit is contained in:
svlandeg 2026-01-28 21:03:12 +01:00
commit 3bf06802e6
20 changed files with 1413 additions and 186 deletions

View File

@ -4,28 +4,29 @@
> The best ChatGPT that $100 can buy.
This repo is a full-stack implementation of an LLM like ChatGPT in a single, clean, minimal, hackable, dependency-lite codebase. nanochat is designed to run on a single 8XH100 node via scripts like [speedrun.sh](speedrun.sh), that run the entire pipeline start to end. This includes tokenization, pretraining, finetuning, evaluation, inference, and web serving over a simple UI so that you can talk to your own LLM just like ChatGPT. nanochat will become the capstone project of the course LLM101n being developed by Eureka Labs.
## Talk to it
To get a sense of the endpoint of this repo, you can currently find [nanochat d34](https://github.com/karpathy/nanochat/discussions/314) hosted on [nanochat.karpathy.ai](https://nanochat.karpathy.ai/). "d34" means that this model has 34 layers in the Transformer neural network. This model has 2.2 billion parameters, it was trained on 88 billion tokens by simply running the training script [run1000.sh](run1000.sh) with `--target_param_data_ratio=40` (2x longer than Chinchilla-optimal), and the total cost of training was ~$2,500 (about 100 hours training time on 8XH100 GPU node). While today this is enough to outperform GPT-2 of 2019, it falls dramatically short of modern Large Language Models like GPT-5. When talking to these micro models, you'll see that they make a lot of mistakes, they are a little bit naive and silly and they hallucinate a ton, a bit like children. It's kind of amusing. But what makes nanochat unique is that it is fully yours - fully configurable, tweakable, hackable, and trained by you from start to end. To train and talk to your own, we turn to...
This repo is a full-stack implementation of an LLM like ChatGPT in a single, clean, minimal, hackable, dependency-lite codebase. nanochat is designed to run on a single 8XH100 node via scripts like [speedrun.sh](runs/speedrun.sh), that run the entire pipeline start to end. This includes tokenization, pretraining, finetuning, evaluation, inference, and web serving over a simple UI so that you can talk to your own LLM just like ChatGPT. nanochat will become the capstone project of the course LLM101n being developed by Eureka Labs.
## Updates
- (Jan 7 2026) See new post: [nanochat Miniseries v1](https://github.com/karpathy/nanochat/discussions/420) and the associated script [miniseries.sh](miniseries.sh).
- (Jan 16 2026) The repo is in active development, I am currently fleshing out the pretraining stage.
- (Jan 7 2026) See new post: [nanochat Miniseries v1](https://github.com/karpathy/nanochat/discussions/420) and the associated script [miniseries.sh](runs/miniseries.sh).
## Talk to it
To get a sense of the endpoint of this repo, you can currently find [nanochat d34](https://github.com/karpathy/nanochat/discussions/314) hosted on [nanochat.karpathy.ai](https://nanochat.karpathy.ai/). "d34" means that this model has 34 layers in the Transformer neural network. This model has 2.2 billion parameters, it was trained on 88 billion tokens by simply running the training script [run1000.sh](runs/run1000.sh) with `--target_param_data_ratio=40` (2x longer than Chinchilla-optimal), and the total cost of training was ~$2,500 (about 100 hours training time on 8XH100 GPU node). While today this is enough to outperform GPT-2 of 2019, it falls dramatically short of modern Large Language Models like GPT-5. When talking to these micro models, you'll see that they make a lot of mistakes, they are a little bit naive and silly and they hallucinate a ton, a bit like children. It's kind of amusing. But what makes nanochat unique is that it is fully yours - fully configurable, tweakable, hackable, and trained by you from start to end. To train and talk to your own, we turn to...
## Quick start
The fastest way to feel the magic is to run the speedrun script [speedrun.sh](speedrun.sh), which trains and inferences the $100 tier of nanochat. On an 8XH100 node at $24/hr, this gives a total run time of about 4 hours. Boot up a new 8XH100 GPU box from your favorite provider (e.g. I use and like [Lambda](https://lambda.ai/service/gpu-cloud)), and kick off the training script:
The fastest way to feel the magic is to run the speedrun script [speedrun.sh](runs/speedrun.sh), which trains and inferences the $100 tier of nanochat. On an 8XH100 node at $24/hr, this gives a total run time of about 4 hours. Boot up a new 8XH100 GPU box from your favorite provider (e.g. I use and like [Lambda](https://lambda.ai/service/gpu-cloud)), and kick off the training script:
```bash
bash speedrun.sh
bash runs/speedrun.sh
```
Alternatively, since the script runs for 4 hours, I like to launch it like this inside a new screen session `speedrun` (and also log output to `speedrun.log`):
```bash
screen -L -Logfile speedrun.log -S speedrun bash speedrun.sh
screen -L -Logfile speedrun.log -S speedrun bash runs/speedrun.sh
```
See the [screen cheatsheet](https://gist.github.com/jctosta/af918e1618682638aa82) if you are less familiar. You can watch it go inside the screen session, or detach with `Ctrl-a d` and `tail speedrun.log` to view progress. Now wait 4 hours. Once it's done, you can talk to your LLM via the ChatGPT-like web UI. Make sure again that your local uv virtual environment is active (run `source .venv/bin/activate`), and serve it:
@ -72,7 +73,7 @@ Total wall clock time: 3h51m
Unsurprisingly, $100 is not enough to train a highly performant ChatGPT clone. In fact, LLMs are famous for their multi-million dollar capex. For our purposes, I think there are two more scales of interest. First is the ~$300 tier d26 model (i.e. depth=26) that trains in ~12 hours, which slightly outperforms GPT-2 CORE score. Second is the $1000 tier (~41.6 hours), just because it's a nice round number. But both of these are not yet fully supported and therefore not attached here in the master branch yet.
That said, to give a sense, the example changes needed for the [speedrun.sh](speedrun.sh) file to train a GPT-2 grade model d26 only involve three changes:
That said, to give a sense, the example changes needed for the [speedrun.sh](runs/speedrun.sh) file to train a GPT-2 grade model d26 only involve three changes:
```bash
...
@ -99,7 +100,7 @@ And a bit more about computing environments that will run nanochat:
## Running on CPU / MPS
nanochat can be run on CPU or on MPS (if you're on Macbook), and will automatically try to detect what device is best to run on. You're not going to get too far without GPUs, but at least you'll be able to run the code paths and maybe train a tiny LLM with some patience. For an example of how to make all the run commands much smaller (feel free to tune!), you can refer to [dev/runcpu.sh](dev/runcpu.sh) file. You'll see that I'm essentially restricting all scripts to train smaller models, to run for shorter number of iterations, etc. This functionality is new, slightly gnarly (touched a lot of code), and was merged in this [CPU|MPS PR](https://github.com/karpathy/nanochat/pull/88) on Oct 21, 2025.
nanochat can be run on CPU or on MPS (if you're on Macbook) in principle, and will automatically try to detect what device is best to run on. The script [runcpu.sh](runs/runcpu.sh) shows a very simple example that will exercise the code paths but basically produce garbage results. Unless you know what you're doing, I basically don't recommend using this script right now and hope to tune it a bit more in the future.
## Customization
@ -109,15 +110,9 @@ Additionally, to add new abilities to nanochat, see [Guide: counting r in strawb
## Questions
nanochat is designed to be short and sweet. One big advantage of this is that we can package up all of the files together and copy paste them to your favorite LLM to ask arbitrary questions. As an example, I like to package up the repo using the [files-to-prompt](https://github.com/simonw/files-to-prompt) utility like so:
I recommend using [DeepWiki](https://deepwiki.com/karpathy/nanochat) from Devin/Cognition to ask questions of this repo. In the URL of this repo, simply change github.com to deepwiki.com, and you're off.
```bash
files-to-prompt . -e py -e md -e html -e toml -e sh --cxml > packaged.txt
```
This includes all py, html, toml, sh files and chooses the cxml output format. Everything is written to the `packaged.txt` file, which atm measures ~330KB (i.e. well below ~100K tokens for a state of the art LLM), and ~8K lines of code in 45 files.
Alternatively, I recommend using [DeepWiki](https://deepwiki.com/karpathy/nanochat) from Devin/Cognition to ask questions of this repo. In the URL of this repo, simply change github.com to deepwiki.com, and you're off.
You can also come to the [#nanochat Discord channel](https://discord.com/channels/1020383067459821711/1427295580895314031) to ask questions, or use the Discussions.
## Tests
@ -137,8 +132,7 @@ python -m pytest tests/test_engine.py -v -s
│ ├── gen_synthetic_data.py # Example synthetic data for identity
│ ├── generate_logo.html
│ ├── nanochat.png
│ ├── repackage_data_reference.py # Pretraining data shard generation
│ └── runcpu.sh # Small example of how to run on CPU/MPS
│ └── repackage_data_reference.py # Pretraining data shard generation
├── nanochat
│ ├── __init__.py # empty
│ ├── adamw.py # Distributed AdamW optimizer
@ -157,7 +151,12 @@ python -m pytest tests/test_engine.py -v -s
│ ├── tokenizer.py # BPE Tokenizer wrapper in style of GPT-4
│ └── ui.html # HTML/CSS/JS for nanochat frontend
├── pyproject.toml
├── run1000.sh # Train the ~$800 nanochat d32
├── runs
│ ├── miniseries.sh # Miniseries training script
│ ├── run1000.sh # Train the ~$800 nanochat d32
│ ├── runcpu.sh # Small example of how to run on CPU/MPS
│ ├── scaling_laws.sh # Scaling laws experiments
│ └── speedrun.sh # Train the ~$100 nanochat d20
├── scripts
│ ├── base_eval.py # Base model: calculate CORE score
│ ├── base_loss.py # Base model: calculate bits per byte, sample
@ -170,7 +169,6 @@ python -m pytest tests/test_engine.py -v -s
│ ├── mid_train.py # Chat model: midtraining
│ ├── tok_eval.py # Tokenizer: evaluate compression rate
│ └── tok_train.py # Tokenizer: train it
├── speedrun.sh # Train the ~$100 nanochat d20
├── tasks
│ ├── arc.py # Multiple choice science questions
│ ├── common.py # TaskMixture | TaskSequence

View File

@ -4,6 +4,283 @@ A running summary documenting some experiments and findings. Started ~Jan 7 2026
---
## 2026-01-27: Bigram Hash Embeddings (Engram-lite)
Explored N-gram memory modules inspired by the [DeepSeek Engram paper](https://arxiv.org/abs/2601.07372) and [modded-nanogpt PR #201](https://github.com/KellerJordan/modded-nanogpt/pull/201).
### Background
The Engram paper introduces "conditional memory" as a complement to MoE - using O(1) hash lookups to retrieve static N-gram patterns instead of reconstructing them through computation. Key insight: transformers waste early layers "simulating retrieval through computation" for patterns like named entities and formulaic phrases that could be simple table lookups.
### What We Tried
**1. Full Engram module with context-aware gating (paper design)**
```python
# Hash bigrams to retrieve embeddings, then gate with hidden state
e = embed(hash(prev_token, curr_token))
q = RMSNorm(h) # hidden state as query
k = RMSNorm(W_k @ e) # projected embedding as key
v = W_v @ e
α = sigmoid(q · k / √d) # scalar gate per position
output = α * v
```
- Injected after block 1 (paper found early injection optimal)
- Slight improvement, but quite a bit of complexity added.
**2. Early-layer only injection**
- Only inject bigram signal in first 4 layers (where paper claims static pattern offloading helps most)
- **Result:** Actually hurt performance. The model seems to need uniform injection across all layers.
**3. Trigrams**
- Extended to hash both 2-grams and 3-grams, concatenating embeddings
- **Result:** No improvement over bigrams alone. Dilutes capacity from more frequent 2-gram patterns.
**4. Bigram-only with x0-style injection (modded-nanogpt engram-lite approach)**
- Simple hash: `(36313 * curr) XOR (27191 * prev) mod table_size`
- Zero-init embedding table, learned per-layer lambdas
- Add to residual at every layer: `x = resid_λ[i]*x + x0_λ[i]*x0 + bigram_λ[i]*x0_bigram`
- **Result:** This simple approach works and provides a consistent improvement.
TLDR The winning approach follows modded-nanogpt's "engram-lite", simply adding the following module and feeding its output into the residual branch (gated by a per-layer learnable \lambda) before every single block:
```python
class BigramEmbed(nn.Module):
def __init__(self, vocab_size, embed_dim, table_multiplier=5):
self.embed = nn.Embedding(vocab_size * table_multiplier, embed_dim)
def forward(self, idx):
h = (36313 * idx[:, 1:]) ^ (27191 * idx[:, :-1]) % (table_size - 1)
return self.embed(h)
```
As for optimal hyperparameters:
- **Table size:** `vocab_size * 5` (~164K entries for 32K vocab). Swept a number of settings and 5 was optimal.
- **Injection:** Every layer via learned `bigram_lambdas` (init 0.1 was better than 0.0).
- **Normalization:** Also tried adding a `norm()` to the embeddings (mirroring the token embeddings), this was slightly worse.
- **Init:** Zero-init embedding, so starts as identity (tried small noisy init, it's worse)
- **Optimizer:** AdamW with same LR as token embeddings
### Key Learnings
1. **Gating didn't help at our scale.** The paper's context-aware gating mechanism (sigmoid dot-product gate) added parameters and complexity without improvement. modded-nanogpt found the same: "simple direct addition to the residual stream outperformed by a decent margin."
2. **Uniform injection beats early-only.** Despite the paper's finding that early layers benefit most, restricting injection to early layers hurt. The x0-style "add everywhere with learned lambda" pattern works better for our architecture/scale.
3. **Bigrams are sufficient.** Trigrams didn't help - the extra context doesn't pay for the diluted capacity.
4. **Scale matters.** The Engram paper's results are at 27B params with MoE. At our ~100M-1B scale, the simpler approach wins. The elaborate gating mechanism may become useful at larger scales where collision handling matters more.
### Parameters Added
For d12 model with `table_multiplier=5`:
- Bigram embedding: 32768 × 5 × 768 = ~126M params
- Per-layer lambdas: 12 scalars (negligible)
If you're keeping track, we now have *a lot* of parameters, a significant amount of them in embeddings (token embeddings, bigram embeddings, value embeddings). For example, for a d12 we now have:
```
Parameter counts:
wte : 25,165,824
bigram_embed : 125,829,120
value_embeds : 150,994,944
lm_head : 25,165,824
transformer_matrices : 84,935,808
scalars : 36
total : 412,091,556
```
In other words, only about a quarter of parameters are now weight projections and the vast majority is embedding tables.
Still, on all axes (steps, wall clock time, flops), this somewhat parameter-bloated architecture beats the baseline and will now become the default.
After adding the engram-lite, I re-ran the scaling laws to determine the new optimal tokens:params ratio. I swept FLOPs in the range 1e18..1e19, exponentially strided in 4 settings (1e18, 2e18, 5e18, 1e19). I looked at a number of ways of determining the effective parameter count for the purposes of the scaling laws. The results looked like this:
```
Kaplan-style (all projections including lm_head and no embeddings)
Optimal configurations (from quadratic fits):
FLOPs Eff Params Tokens Ratio Val BPB
-----------------------------------------------------------------
1e+18 110,678,115 1,241,505,403 11.2 0.8972
2e+18 167,797,457 1,785,336,422 10.7 0.8616
5e+18 250,650,865 2,642,234,152 10.8 0.8293
1e+19 381,758,347 3,806,871,243 10.3 0.7999
N \propto C^0.54, D \propto C^0.49
Chinchilla-style (all parameters, period.)
Optimal configurations (from quadratic fits):
FLOPs Eff Params Tokens Ratio Val BPB
-----------------------------------------------------------------
1e+18 416,320,605 1,232,157,011 3.0 0.8974
2e+18 560,239,841 1,763,669,281 3.2 0.8616
5e+18 741,495,903 2,629,909,368 3.6 0.8291
1e+19 988,644,331 3,884,841,895 4.0 0.7999
N \propto C^0.37, D \propto C^0.50
Transformer-only-style (only the projections inside the transformer)
Optimal configurations (from quadratic fits):
FLOPs Eff Params Tokens Ratio Val BPB
-----------------------------------------------------------------
1e+18 80,259,665 1,315,639,547 17.2 0.8966
2e+18 131,488,566 1,864,134,141 14.5 0.8622
5e+18 220,985,474 2,595,328,843 12.1 0.8302
1e+19 401,213,504 3,328,704,512 8.5 0.7994
N \propto C^0.70, D \propto C^0.41
```
Clearly, the Kaplan-style ratios are most consistent and produce stable ~0.5 exponents for both params and tokens, meaning we can have a single fixed ratio of tokens:params for compute optimal models. This turns out to be about ~10.5, which now becomes the new default.
---
## 2026-01-19 to 2026-01-22: Optimizer Hyperparameter Sweep
Ran ~320 experiments across 6 rounds, scaling from d12→d16→d20 to find optimal optimizer hyperparameters. Added granular per-component control to `setup_optimizers()` — separate LRs and betas for embedding, unembedding, value_embeds, resid_lambdas, x0_lambdas, and Muon matrix params.
### What We Swept
- Learning rates for all 6 parameter groups
- Beta1/beta2 for all 5 AdamW groups
- Muon momentum (start/end), weight decay
- Hundreds of combinations (2-way, 3-way, 4-way, etc.)
### The Journey
**At d12**, found two independent improvement routes:
- **Route A:** emb_lr↑ (0.3→0.4), weight_decay↑ (0.1→0.15), matrix_lr↑ (0.02→0.025)
- **Route B:** x0_lr↓ (0.5→0.2), x0_beta1↑ (0.8→0.9+)
Both gave ~0.002 improvement, but combining them caused conflicts. Fine-tuning found wd=0.13, matrix_lr=0.027, emb_lr=0.38 helped slightly. Best d12 config: Route A + x0_beta1=0.95.
**At d16**, Route B became competitive with Route A. The routes still conflicted when combined.
**At d20** (target scale), everything changed:
- Fine-tuned values from d12 **actively hurt** performance
- Routes no longer conflicted
- Just `x0_beta1=0.96` alone captured nearly all the gains
### Final x0_beta1 Sweep at d20
| x0_beta1 | val/bpb | Δ vs baseline |
|----------|---------|---------------|
| **0.96** | **0.7971** | **-0.0007** |
| 0.94 | 0.7972 | -0.0006 |
| 0.90 | 0.7972 | -0.0006 |
| 0.97 | 0.7977 | -0.0001 |
| 0.98 | 0.8011 | +0.0033 💀 |
Flat plateau from 0.90-0.96, then sharp cliff at 0.97+.
### Key Learnings
1. **Hyperparameters are scale-dependent.** What works at d12 doesn't transfer to d20. The elaborate fine-tuning that won at d12 actively hurts at d20.
2. **Improvement magnitude shrinks with scale.** ~0.002 at d12 → ~0.0007 at d20. The baseline is already better-tuned for larger models.
3. **Sharp cliffs exist.** x0_beta1=0.98 is catastrophic while 0.96 is optimal.
4. **Don't over-tune on small proxies.** Validate at target scale before shipping.
### Final Recommendation
For production d20 runs, add one flag:
```
--x0-lambdas-beta1=0.96
```
Skip everything else discovered at smaller scales.
---
## 2026-01-18: More various experiments
- Tried Muon custom kernels for XXT and all the others. The improvement was there for targeted tests (~20%) but washed out completely to noise in an actual training run, especially because the Muon compute is split across all the workers. Abandoned due to complexity bloat.
- Fuse Q,K,V,O nn.Linear layers into a single QKVO Linear layer. ~Zero impact
- Tried the `sa_lambdas` that gate QKV and O. Slightly confused because of the use of rmsnorm, which erases the effect of any scalar multiplier. Helped a tiny bit (~1e-4 of loss), abandoned to control complexity.
---
## 2026-01-17: Various experiments
Modded-nanogpt uses [Value Embeddings](https://arxiv.org/abs/2410.17897) (VEs) in a funny U-shaped structure, 3 of them in total and with gates. I tried a large number of tweaks on this today:
- VEs at every layer, at alternating layers, U shaped, front and back. Alternating layers worked best, i.e. we end up with *a lot* more VEs than modded-nanogpt, at every other layer. It works better.
- Many parameters sharing ideas to reduce new parameter count, nothing here worked. All failed.
- Many ideas to reduce parameter count, the LLM hates all of them: low rank decompositions, projections. All failed.
- Gated yes or no and how much. Gate helps.
Long story short is that the models *love* Value Embeddings. It is a way to add a huge amount of capacity (parameters) to the model at almost zero cost of FLOPs, because these embeddings are simply added to the Values tensor. Any attempt to reduce the capacity of value embeddings (param sharing, low rank, projections) fail. The model wants many of them, and with all the capacity, and doing so wins across all x axes of steps, flops and wall clock. I re-ran the scaling laws and, because the models are now very parameter bloated, the optimal ratio has halved from 8 to 4! Way down lower than Chinchilla's 20 at this point.
Other experiments, looking at val/bpb as a function of all of steps, flops and wall clock time:
- Aspect ratio of 128 is worse than 64, I tried a sweep fixing FLOPs == 1e18 and 64 outperforms. The LLM prefers to be slightly thinner and longer.
- Head dim definitely prefers to be 128 instead of 64, i.e. fewer bigger heads
- Bunch of other random stuff like that.
Keeping all of this work on a private branch for now but hope to push shortly.
---
## 2026-01-17: Modded-nanogpt Ideas Sweep (Continued)
Continued testing ideas from modded-nanogpt.
| Idea | Result | Notes |
|------|--------|-------|
| Attention gates | No improvement | Per-head learnable gates on attention output. +1GB memory, decreased efficiency. |
| Batch size schedule | Abandoned | 8→16→24 with LR scaling. Made training script too bloated/complex, not worth cognitive overhead. |
| Value embeddings | Helps a lot | Experiments still ongoing, more on this later. |
---
## 2026-01-16: Flash Attention 3 Fallback to SDPA
Added automatic fallback from Flash Attention 3 to PyTorch's `scaled_dot_product_attention` (SDPA) for users without Hopper GPUs. This enables nanochat to run on older CUDA GPUs, CPU, and MPS (Apple Silicon).
### Implementation
Created `nanochat/flash_attention.py` - a unified interface that:
- Detects FA3 availability at import time (requires sm90+ / Hopper)
- Exports a `flash_attn` object matching FA3's API exactly (`flash_attn.flash_attn_func`, `flash_attn.flash_attn_with_kvcache`)
- Automatically routes to FA3 or SDPA based on hardware
- Handles tensor layout differences: FA3 uses (B, T, H, D), SDPA uses (B, H, T, D)
- Implements sliding window attention via explicit masks for SDPA
- Manages KV cache manually for SDPA (FA3 does it in-place)
### Changes to Existing Files
Changes to existing code were intentionally kept extremely minimal.
**gpt.py**: Only the import line changed and a comment
**engine.py**: Zero changes needed
**base_train.py**: Added status print and warnings:
- Prints whether FA3 or SDPA fallback is being used
- Warns about efficiency loss without FA3
- Warns about sliding window support if `--window-pattern` is not "L"
### Testing
Tests are split into two classes due to dtype/device constraints:
1. **TestFA3VsSDPA**: Comparison tests requiring Hopper GPU + bfloat16. Run both implementations on identical inputs and verify outputs match (max diff typically 0, at most ~0.004 for sliding window).
2. **TestSDPAOnly**: SDPA-only tests that run on any device with appropriate dtype. Verify forward pass, backward pass, and KV cache work correctly.
Added `_override_impl` mechanism for testing - can force 'fa3' or 'sdpa' to directly compare implementations.
### Notes
- SDPA fallback is significantly slower than FA3 especially in that it lacks the sliding window attention support
- Recommend `--window-pattern L` (full context) when using SDPA fallback
---
## 2026-01-16: Modded-nanogpt Ideas Sweep (Mostly Negative)
Tested several architectural ideas from modded-nanogpt to see if they transfer to nanochat. All of these did not help:

View File

@ -1,74 +0,0 @@
#!/bin/bash
# Showing an example run for exercising some of the code paths on the CPU (or MPS on Macbooks)
# Run as:
# bash dev/cpu_demo_run.sh
# NOTE: Training LLMs requires GPU compute and $$$. You will not get far on your Macbook.
# Think of this run as educational/fun demo, not something you should expect to work well.
# This is also why I hide this script away in dev/
# all the setup stuff
export OMP_NUM_THREADS=1
export NANOCHAT_BASE_DIR="$HOME/.cache/nanochat"
mkdir -p $NANOCHAT_BASE_DIR
command -v uv &> /dev/null || curl -LsSf https://astral.sh/uv/install.sh | sh
[ -d ".venv" ] || uv venv
uv sync --extra cpu
source .venv/bin/activate
if [ -z "$WANDB_RUN" ]; then
WANDB_RUN=dummy
fi
# wipe the report
python -m nanochat.report reset
# train tokenizer on ~1B characters
python -m nanochat.dataset -n 4
python -m scripts.tok_train --max-chars=1000000000
python -m scripts.tok_eval
# train a very small 4 layer model on the CPU
# each optimization step processes a single sequence of 1024 tokens
# we only run 50 steps of optimization (bump this to get better results)
python -m scripts.base_train \
--depth=4 \
--max-seq-len=1024 \
--device-batch-size=1 \
--total-batch-size=1024 \
--eval-every=50 \
--eval-tokens=4096 \
--core-metric-every=50 \
--core-metric-max-per-task=12 \
--sample-every=50 \
--num-iterations=50
python -m scripts.base_loss --device-batch-size=1 --split-tokens=4096
python -m scripts.base_eval --max-per-task=16
# midtraining
python -m scripts.mid_train \
--max-seq-len=1024 \
--device-batch-size=1 \
--eval-every=50 \
--eval-tokens=4096 \
--total-batch-size=1024 \
--num-iterations=100
# eval results will be terrible, this is just to execute the code paths.
# note that we lower the execution memory limit to 1MB to avoid warnings on smaller systems
python -m scripts.chat_eval --source=mid --max-new-tokens=128 --max-problems=20
# SFT
python -m scripts.chat_sft \
--device-batch-size=1 \
--target-examples-per-step=4 \
--num-iterations=100 \
--eval-steps=4 \
--eval-metrics-max-problems=16
# Chat CLI
# python -m scripts.chat_cli -p "Why is the sky blue?"
# Chat Web
# python -m scripts.chat_web
python -m nanochat.report generate

View File

@ -15,14 +15,16 @@
"metadata": {},
"outputs": [],
"source": [
"%matplotlib inline\n",
"import os\n",
"import pandas as pd\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"\n",
"# Load results\n",
"tag = \"jan26\"\n",
"base_dir = os.environ.get('NANOCHAT_BASE_DIR', os.path.expanduser('~/.cache/nanochat'))\n",
"results_path = os.path.join(base_dir, 'scaling_laws_results', 'results.csv')\n",
"results_path = os.path.join(base_dir, f'scaling_laws_results_{tag}', 'results.csv')\n",
"\n",
"df = pd.read_csv(results_path)\n",
"flops_budgets = sorted(df['flops_budget'].unique())\n",
@ -31,6 +33,99 @@
"df"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# =============================================================================\n",
"# FILTERING: Remove incomplete or problematic runs\n",
"# =============================================================================\n",
"\n",
"print(f\"Before filtering: {len(df)} runs\")\n",
"\n",
"# Filter out runs with missing/invalid val_bpb (incomplete runs)\n",
"df = df[df['val_bpb'].notna() & (df['val_bpb'] > 0)]\n",
"\n",
"# Optional: exclude specific flops budgets that aren't done yet\n",
"# exclude_flops = [1e19] # <-- adjust as runs complete\n",
"# df = df[~df['flops_budget'].isin(exclude_flops)]\n",
"\n",
"# Optional: exclude specific depths\n",
"# exclude_depths = [18, 20]\n",
"# df = df[~df['depth'].isin(exclude_depths)]\n",
"\n",
"print(f\"After filtering: {len(df)} runs\")\n",
"print(f\"FLOPs budgets: {sorted(df['flops_budget'].unique())}\")\n",
"print(f\"Depths: {sorted(df['depth'].unique())}\")\n",
"\n",
"# Update flops_budgets list after filtering\n",
"flops_budgets = sorted(df['flops_budget'].unique())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Effective Parameter Count\n",
"\n",
"Different scaling law papers use different conventions for counting parameters:\n",
"- **Kaplan et al.** excluded embedding parameters (claimed cleaner laws)\n",
"- **Chinchilla** included all parameters (and noted Kaplan had a bug)\n",
"\n",
"Our CSV now has granular counts:\n",
"- `params_wte` - token embedding (lookup table)\n",
"- `params_bigram_embed` - bigram hash embeddings (lookup table)\n",
"- `params_value_embeds` - value embeddings (lookup table)\n",
"- `params_lm_head` - unembedding projection (matmul)\n",
"- `params_transformer` - attention + MLP matrices (matmuls)\n",
"- `params_scalars` - resid/x0/bigram lambdas (tiny)\n",
"\n",
"**Experiment below** with different combinations to see which gives the cleanest scaling laws."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# =============================================================================\n",
"# EXPERIMENT HERE: Define which parameters to count for scaling laws\n",
"# =============================================================================\n",
"\n",
"def compute_effective_params(row):\n",
" \"\"\"\n",
" Compute the 'effective' parameter count for scaling law analysis.\n",
"\n",
" Modify this function to experiment with different conventions:\n",
" - Chinchilla-style: include everything\n",
" - Kaplan-style: exclude embeddings\n",
" - Matmul-only: just transformer + lm_head (the actual compute)\n",
" - etc.\n",
" \"\"\"\n",
" # Option 1: Chinchilla-style (all params)\n",
" # return row['params_total']\n",
"\n",
" # Option 2: Kaplan-style (exclude embeddings)\n",
" return row['params_transformer'] + row['params_lm_head']\n",
"\n",
" # Option 3: Transformer-only (exclude all embeddings AND lm_head)\n",
" # return row['params_transformer']\n",
"\n",
"\n",
"# Compute derived columns\n",
"df['effective_params'] = df.apply(compute_effective_params, axis=1)\n",
"df['param_data_ratio'] = df['tokens_trained'] / df['effective_params']\n",
"\n",
"# Show parameter breakdown for first few rows\n",
"print(\"Parameter breakdown (first row per flops budget):\")\n",
"param_cols = ['depth', 'params_wte', 'params_bigram_embed', 'params_value_embeds',\n",
" 'params_lm_head', 'params_transformer', 'params_scalars', 'params_total', 'effective_params']\n",
"df.groupby('flops_budget').first()[param_cols]"
]
},
{
"cell_type": "markdown",
"metadata": {},
@ -54,11 +149,11 @@
"optimal_by_bpb = []\n",
"\n",
"for flops, color in zip(flops_budgets, colors):\n",
" subset = df[df['flops_budget'] == flops].sort_values('num_scaling_params')\n",
" ax.plot(subset['num_scaling_params'], subset['val_bpb'], 'o', color=color, label=f'{flops:.0e}', markersize=8)\n",
" subset = df[df['flops_budget'] == flops].sort_values('effective_params')\n",
" ax.plot(subset['effective_params'], subset['val_bpb'], 'o', color=color, label=f'{flops:.0e}', markersize=8)\n",
"\n",
" # Fit quadratic in log-space: val_bpb = a*(log N)^2 + b*(log N) + c\n",
" log_params = np.log10(subset['num_scaling_params'])\n",
" log_params = np.log10(subset['effective_params'])\n",
" coeffs = np.polyfit(log_params, subset['val_bpb'], 2)\n",
" a, b, c = coeffs\n",
"\n",
@ -83,13 +178,13 @@
" # Fallback to raw minimum if quadratic doesn't have minimum\n",
" best_idx = subset['val_bpb'].idxmin()\n",
" best = subset.loc[best_idx]\n",
" ax.scatter([best['num_scaling_params']], [best['val_bpb']], s=150, color=color,\n",
" ax.scatter([best['effective_params']], [best['val_bpb']], s=150, color=color,\n",
" zorder=5, edgecolors='black', linewidths=2)\n",
" optimal_by_bpb.append({'flops': flops, 'params': best['num_scaling_params'],\n",
" optimal_by_bpb.append({'flops': flops, 'params': best['effective_params'],\n",
" 'tokens': best['tokens_trained'], 'ratio': best['param_data_ratio'], 'bpb': best['val_bpb']})\n",
"\n",
"ax.set_xscale('log')\n",
"ax.set_xlabel('Parameters')\n",
"ax.set_xlabel('Effective Parameters')\n",
"ax.set_ylabel('Validation Loss (bpb)')\n",
"ax.set_title('IsoFLOP Curves')\n",
"ax.legend(title='FLOPs', loc='upper right')\n",
@ -138,10 +233,61 @@
"\n",
"# Print the optimal points (from quadratic fits)\n",
"print(\"\\nOptimal configurations (from quadratic fits):\")\n",
"print(f\"{'FLOPs':<12} {'Params':<15} {'Tokens':<15} {'Ratio':<10} {'Val BPB':<10}\")\n",
"print(f\"{'FLOPs':<12} {'Eff Params':<15} {'Tokens':<15} {'Ratio':<10} {'Val BPB':<10}\")\n",
"print(\"-\" * 65)\n",
"for _, row in opt_df.iterrows():\n",
" print(f\"{row['flops']:<12.0e} {int(row['params']):<15,} {int(row['tokens']):<15,} {row['ratio']:<10.1f} {row['bpb']:<10.4f}\")\n"
" print(f\"{row['flops']:<12.0e} {int(row['params']):<15,} {int(row['tokens']):<15,} {row['ratio']:<10.1f} {row['bpb']:<10.4f}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# =============================================================================\n",
"# Optimal Ratio Summary (from power law fits)\n",
"# =============================================================================\n",
"\n",
"# From the power law fits: N ∝ C^a and D ∝ C^b\n",
"# The ratio D/N ∝ C^(b-a). If a ≈ b, ratio is roughly constant.\n",
"\n",
"if len(opt_df) >= 2:\n",
" log_f = np.log10(opt_df['flops'])\n",
" log_p = np.log10(opt_df['params'])\n",
" log_t = np.log10(opt_df['tokens'])\n",
"\n",
" # Fit power laws\n",
" slope_n, intercept_n = np.polyfit(log_f, log_p, 1)\n",
" slope_d, intercept_d = np.polyfit(log_f, log_t, 1)\n",
"\n",
" # The ratio D/N at a reference compute (geometric mean of our budgets)\n",
" ref_flops = np.sqrt(opt_df['flops'].min() * opt_df['flops'].max())\n",
" log_ref = np.log10(ref_flops)\n",
"\n",
" # Predicted optimal N and D at reference compute\n",
" pred_log_n = intercept_n + slope_n * log_ref\n",
" pred_log_d = intercept_d + slope_d * log_ref\n",
" optimal_ratio = 10**(pred_log_d - pred_log_n)\n",
"\n",
" # Also compute from the fitted optimals directly (mean and std)\n",
" mean_ratio = opt_df['ratio'].mean()\n",
" std_ratio = opt_df['ratio'].std()\n",
"\n",
" print(\"=\" * 60)\n",
" print(\"OPTIMAL RATIO SUMMARY\")\n",
" print(\"=\" * 60)\n",
" print(f\"\\nPower law exponents:\")\n",
" print(f\" N ∝ C^{slope_n:.3f}\")\n",
" print(f\" D ∝ C^{slope_d:.3f}\")\n",
" print(f\" Ratio exponent (b-a): {slope_d - slope_n:.3f} (should be ~0 if ratio is constant)\")\n",
" print(f\"\\nOptimal ratio (tokens per effective param):\")\n",
" print(f\" From power law at C={ref_flops:.1e}: {optimal_ratio:.1f}\")\n",
" print(f\" Mean across budgets: {mean_ratio:.1f} ± {std_ratio:.1f}\")\n",
" print(f\" Chinchilla reference: 20\")\n",
" print(f\"\\nPer-budget ratios: {[f'{r:.1f}' for r in opt_df['ratio'].values]}\")\n",
"else:\n",
" print(\"Need at least 2 flops budgets to compute power law fits\")"
]
},
{

View File

@ -200,3 +200,77 @@ class DummyWandb:
pass
def finish(self):
pass
# hardcoded BF16 peak flops for various GPUs
# inspired by torchtitan: https://github.com/pytorch/torchtitan/blob/main/torchtitan/tools/utils.py
# and PR: https://github.com/karpathy/nanochat/pull/147
def get_peak_flops(device_name: str) -> float:
name = device_name.lower()
# --- NVIDIA Blackwell ---
if "gb200" in name or "grace blackwell" in name:
return 2.5e15
if "b200" in name:
return 2.25e15
if "b100" in name:
return 1.8e15
# --- NVIDIA Hopper (H100/H200/H800) ---
if "h200" in name:
if "nvl" in name or "pcie" in name:
return 836e12
return 989e12 # H200 SXM
if "h100" in name:
if "nvl" in name:
return 835e12
if "pcie" in name:
return 756e12
return 989e12 # H100 SXM
if "h800" in name:
if "nvl" in name:
return 989e12
return 756e12 # H800 PCIe
# --- NVIDIA Ampere data center ---
if "a100" in name or "a800" in name:
return 312e12
if "a40" in name:
return 149.7e12
if "a30" in name:
return 165e12
# --- NVIDIA Ada data center ---
if "l40s" in name or "l40-s" in name or "l40 s" in name:
return 362e12
if "l4" in name:
return 121e12
# --- AMD CDNA accelerators ---
if "mi355" in name:
return 2.5e15
if "mi325" in name or "mi300x" in name:
return 1.3074e15
if "mi300a" in name:
return 980.6e12
if "mi250x" in name:
return 383e12
if "mi250" in name:
return 362.1e12
# --- Intel ---
if "data center gpu max 1550" in name:
# Ponte Vecchio (PVC) - dynamic based on compute units
max_comp_units = torch.xpu.get_device_properties("xpu").max_compute_units
return 512 * max_comp_units * 1300 * 10**6
# --- Consumer RTX (for hobbyists) ---
if "5090" in name:
return 209.5e12
if "4090" in name:
return 165.2e12
if "3090" in name:
return 71e12
# Unknown GPU - return inf so MFU shows as 0% rather than a wrong guess
logger.warning(f"Peak flops undefined for: {device_name}, MFU will show as 0%")
return float('inf')

View File

@ -178,8 +178,9 @@ def tokenizing_distributed_data_loader_with_state_bos_bestfit(
doc = doc_buffer.pop(best_idx)
row.extend(doc)
else:
# No doc fits - crop first doc to fill remaining
doc = doc_buffer.pop(0)
# No doc fits - crop shortest in buffer to fill remaining and minimize waste
shortest_idx = min(range(len(doc_buffer)), key=lambda i: len(doc_buffer[i]))
doc = doc_buffer.pop(shortest_idx)
row.extend(doc[:remaining])
rows.append(row[:row_capacity])

View File

@ -90,7 +90,7 @@ class KVCache:
- Position tracked per batch element via cache_seqlens tensor
"""
def __init__(self, batch_size, num_heads, seq_len, head_dim, num_layers, device, dtype=torch.bfloat16):
def __init__(self, batch_size, num_heads, seq_len, head_dim, num_layers, device, dtype):
self.batch_size = batch_size
self.max_seq_len = seq_len
self.n_layers = num_layers
@ -172,6 +172,13 @@ class Engine:
"""Same as generate, but does single prefill and then clones the KV cache."""
assert isinstance(tokens, list) and isinstance(tokens[0], int), "expecting list of ints"
device = self.model.get_device()
# NOTE: setting the dtype here and in this way is an ugly hack.
# Currently the repo assumes that cuda -> bfloat16 and everything else -> float32.
# We need to know the dtype here to call __init__ on KVCache and pre-allocate its tensors.
# As a quick hack, we're making generate() function inherit and know about this repo-wise assumption.
# I think there has to be a bigger refactor to deal with device/dtype tracking across the codebase.
# In particular, the KVCache should allocate its tensors lazily
dtype = torch.bfloat16 if device.type == "cuda" else torch.float32
rng = torch.Generator(device=device)
rng.manual_seed(seed)
@ -191,6 +198,7 @@ class Engine:
batch_size=1,
seq_len=len(tokens),
device=device,
dtype=dtype,
**kv_model_kwargs,
)
ids = torch.tensor([tokens], dtype=torch.long, device=device)
@ -203,6 +211,7 @@ class Engine:
batch_size=num_samples,
seq_len=kv_length_hint,
device=device,
dtype=dtype,
**kv_model_kwargs,
)
kv_cache_decode.prefill(kv_cache_prefill)
@ -297,8 +306,8 @@ if __name__ == "__main__":
"""
import time
# init compute
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
device_type = autodetect_device_type()
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
# load the model and tokenizer

178
nanochat/flash_attention.py Normal file
View File

@ -0,0 +1,178 @@
"""
Unified Flash Attention interface with automatic FA3/SDPA switching.
Exports `flash_attn` module that matches the FA3 API exactly, but falls back
to PyTorch SDPA on non-Hopper GPUs, MPS, and CPU.
Usage (drop-in replacement for FA3):
from nanochat.flash_attention import flash_attn
# Training (no KV cache)
y = flash_attn.flash_attn_func(q, k, v, causal=True, window_size=window_size)
# Inference (with KV cache)
y = flash_attn.flash_attn_with_kvcache(q, k_cache, v_cache, k=k, v=v, ...)
"""
import torch
import torch.nn.functional as F
# =============================================================================
# Detection: Try to load FA3 on Hopper+ GPUs
# =============================================================================
def _load_flash_attention_3():
"""Try to load Flash Attention 3 (requires Hopper+ GPU)."""
if not torch.cuda.is_available():
return None
try:
major, _ = torch.cuda.get_device_capability()
if major < 9: # Hopper is sm90
return None
import os
os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
from kernels import get_kernel
return get_kernel('varunneal/flash-attention-3').flash_attn_interface
except Exception:
return None
_fa3 = _load_flash_attention_3()
HAS_FA3 = _fa3 is not None
# Override for testing: set to 'fa3', 'sdpa', or None (auto)
_override_impl = None
def _use_fa3():
"""Determine whether to use FA3 based on availability and override."""
if _override_impl == 'fa3':
assert HAS_FA3, "Cannot override to FA3: not available on this hardware"
return True
if _override_impl == 'sdpa':
return False
return HAS_FA3 # auto
# =============================================================================
# SDPA helpers
# =============================================================================
def _sdpa_attention(q, k, v, window_size, enable_gqa):
"""
SDPA attention with sliding window support.
q, k, v are (B, H, T, D) format.
"""
Tq = q.size(2)
Tk = k.size(2)
window = window_size[0]
# Full context, same length
if (window < 0 or window >= Tq) and Tq == Tk:
return F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=enable_gqa)
# Single token generation
if Tq == 1:
return F.scaled_dot_product_attention(q, k, v, is_causal=False, enable_gqa=enable_gqa)
# Need explicit mask
device = q.device
if Tq == Tk:
# Causal + sliding window
mask = torch.tril(torch.ones(Tq, Tk, device=device, dtype=torch.bool))
if window > 0 and window < Tq:
row_idx = torch.arange(Tq, device=device).unsqueeze(1)
col_idx = torch.arange(Tk, device=device).unsqueeze(0)
mask = mask & ((row_idx - col_idx) <= window)
else:
# Chunk inference: attend to prefix + causal within chunk
prefix_len = Tk - Tq
mask = torch.zeros(Tq, Tk, device=device, dtype=torch.bool)
mask[:, :prefix_len] = True
mask[:, prefix_len:] = torch.tril(torch.ones(Tq, Tq, device=device, dtype=torch.bool))
return F.scaled_dot_product_attention(q, k, v, attn_mask=mask, enable_gqa=enable_gqa)
# =============================================================================
# Public API: Same interface as FA3
# =============================================================================
def flash_attn_func(q, k, v, causal=False, window_size=(-1, -1)):
"""
Flash Attention for training (no KV cache).
Args:
q, k, v: Tensors of shape (B, T, H, D)
causal: Whether to use causal masking
window_size: (left, right) sliding window. -1 means unlimited.
Returns:
Output tensor of shape (B, T, H, D)
"""
if _use_fa3():
return _fa3.flash_attn_func(q, k, v, causal=causal, window_size=window_size)
# SDPA fallback: transpose (B, T, H, D) -> (B, H, T, D)
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
enable_gqa = q.size(1) != k.size(1)
y = _sdpa_attention(q, k, v, window_size, enable_gqa)
return y.transpose(1, 2) # back to (B, T, H, D)
def flash_attn_with_kvcache(q, k_cache, v_cache, k=None, v=None, cache_seqlens=None,
causal=False, window_size=(-1, -1)):
"""
Flash Attention with KV cache for inference.
FA3 updates k_cache/v_cache in-place. Our SDPA fallback does the same.
Args:
q: Queries, shape (B, T_new, H, D)
k_cache, v_cache: Pre-allocated cache tensors, shape (B, T_max, H_kv, D)
k, v: New keys/values to insert, shape (B, T_new, H_kv, D)
cache_seqlens: Current position in cache, shape (B,) int32
causal: Whether to use causal masking
window_size: (left, right) sliding window. -1 means unlimited.
Returns:
Output tensor of shape (B, T_new, H, D)
"""
if _use_fa3():
return _fa3.flash_attn_with_kvcache(
q, k_cache, v_cache, k=k, v=v, cache_seqlens=cache_seqlens,
causal=causal, window_size=window_size
)
# SDPA fallback: manually manage KV cache
B, T_new, H, D = q.shape
pos = cache_seqlens[0].item() # assume uniform position across batch
# Insert new k, v into cache (in-place, matching FA3 behavior)
if k is not None and v is not None:
k_cache[:, pos:pos+T_new, :, :] = k
v_cache[:, pos:pos+T_new, :, :] = v
# Get full cache up to current position + new tokens
end_pos = pos + T_new
k_full = k_cache[:, :end_pos, :, :]
v_full = v_cache[:, :end_pos, :, :]
# Transpose to SDPA layout: (B, T, H, D) -> (B, H, T, D)
q_sdpa = q.transpose(1, 2)
k_sdpa = k_full.transpose(1, 2)
v_sdpa = v_full.transpose(1, 2)
enable_gqa = q_sdpa.size(1) != k_sdpa.size(1)
y_sdpa = _sdpa_attention(q_sdpa, k_sdpa, v_sdpa, window_size, enable_gqa)
return y_sdpa.transpose(1, 2) # back to (B, T, H, D)
# =============================================================================
# Export: flash_attn module interface (drop-in replacement for FA3)
# =============================================================================
from types import SimpleNamespace
flash_attn = SimpleNamespace(
flash_attn_func=flash_attn_func,
flash_attn_with_kvcache=flash_attn_with_kvcache,
)

View File

@ -23,18 +23,13 @@ from nanochat.common import get_dist_info, print0
from nanochat.muon import Muon, DistMuon
from nanochat.adamw import DistAdamW
# Load Flash Attention 3 from HuggingFace Hub (and silence the progress bar)
import os
os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
# Official docs of FA3 label it as "beta" and want you to install FA3 from source, which is a pain.
# Wishing for official FA3 wheels soon, for now this seems to be a fast way to get them (ty varunneal)
from kernels import get_kernel
flash_attn = get_kernel('varunneal/flash-attention-3').flash_attn_interface
# Our custom Flash Attention module that automatically uses FA3 on Hopper+ and SDPA fallback elsewhere
from nanochat.flash_attention import flash_attn
@dataclass
class GPTConfig:
sequence_len: int = 1024
vocab_size: int = 50304
sequence_len: int = 2048
vocab_size: int = 32768
n_layer: int = 12
n_head: int = 6 # number of query heads
n_kv_head: int = 6 # number of key/value heads (GQA)
@ -42,7 +37,7 @@ class GPTConfig:
# Sliding window attention pattern string, tiled across layers. Final layer always L.
# Characters: L=long (full context), S=short (half context)
# Examples: "L"=all full context, "SL"=alternating, "SSL"=two short then one long
window_pattern: str = "L"
window_pattern: str = "SSSL"
def norm(x):
@ -50,6 +45,45 @@ def norm(x):
return F.rms_norm(x, (x.size(-1),))
class BigramEmbed(nn.Module):
"""
Hash bigrams to embeddings. Simple, self-contained, runs on GPU.
Following modded-nanogpt's approach: single hash, no gating.
For each position t, hashes (token[t-1], token[t]) to an index in a large
embedding table. This provides O(1) lookup for local 2-gram patterns,
offloading static pattern reconstruction from the transformer layers.
Ref: https://github.com/KellerJordan/modded-nanogpt/pull/201
Ref: https://arxiv.org/abs/1709.03933 (Hash Embeddings)
"""
def __init__(self, vocab_size: int, embed_dim: int, table_multiplier: int = 5):
super().__init__()
self.bigram_vocab_size = vocab_size * table_multiplier
self.embed = nn.Embedding(self.bigram_vocab_size, embed_dim)
def forward(self, idx: torch.Tensor) -> torch.Tensor:
"""
idx: (B, T) token ids
Returns: (B, T, embed_dim) bigram embeddings
"""
# Hash (prev_token, curr_token) -> index
# Position 0 gets a reserved index (no valid bigram)
rand_int_1 = 36313
rand_int_2 = 27191
mod = self.bigram_vocab_size - 1
h = torch.empty_like(idx, dtype=torch.long)
h[:, 0] = mod # reserved index for position 0
h[:, 1:] = (rand_int_1 * idx[:, 1:] ^ rand_int_2 * idx[:, :-1]) % mod
return self.embed(h)
def has_ve(layer_idx, n_layer):
"""Returns True if GPT layer should have Value Embedding (alternating, last layer always included)."""
return layer_idx % 2 == (n_layer - 1) % 2
def apply_rotary_emb(x, cos, sin):
assert x.ndim == 4 # multihead attention
d = x.shape[3] // 2
@ -72,8 +106,10 @@ class CausalSelfAttention(nn.Module):
self.c_k = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
self.c_v = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)
self.ve_gate_channels = 32
self.ve_gate = nn.Linear(self.ve_gate_channels, self.n_kv_head, bias=False) if has_ve(layer_idx, config.n_layer) else None
def forward(self, x, cos_sin, window_size, kv_cache):
def forward(self, x, ve, cos_sin, window_size, kv_cache):
B, T, C = x.size()
# Project the input to get queries, keys, and values
@ -82,13 +118,18 @@ class CausalSelfAttention(nn.Module):
k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim)
v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim)
# Value residual (ResFormer): mix in value embedding with input-dependent gate per head
if ve is not None:
ve = ve.view(B, T, self.n_kv_head, self.head_dim)
gate = 2 * torch.sigmoid(self.ve_gate(x[..., :self.ve_gate_channels])) # (B, T, n_kv_head), range (0, 2)
v = v + gate.unsqueeze(-1) * ve
# Apply Rotary Embeddings to queries and keys to get relative positional encoding
cos, sin = cos_sin
q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin)
q, k = norm(q), norm(k) # QK norm
# Attention with Flash Attention 3
# FA3 handles GQA automatically when n_kv_heads < n_heads
# Flash Attention (FA3 on Hopper+, PyTorch SDPA fallback elsewhere)
# window_size is (left, right) tuple: (N, 0) for causal, (-1, 0) for full context
if kv_cache is None:
# Training: causal attention with optional sliding window
@ -132,8 +173,8 @@ class Block(nn.Module):
self.attn = CausalSelfAttention(config, layer_idx)
self.mlp = MLP(config)
def forward(self, x, cos_sin, window_size, kv_cache):
x = x + self.attn(norm(x), cos_sin, window_size, kv_cache)
def forward(self, x, ve, cos_sin, window_size, kv_cache):
x = x + self.attn(norm(x), ve, cos_sin, window_size, kv_cache)
x = x + self.mlp(norm(x))
return x
@ -163,9 +204,17 @@ class GPT(nn.Module):
# Per-layer learnable scalars (inspired by modded-nanogpt)
# resid_lambdas: scales the residual stream at each layer (init 1.0 = neutral)
# x0_lambdas: blends initial embedding back in at each layer (init 0.0 = disabled)
# bigram_lambdas: blends bigram embeddings in at each layer (init 0.1 = small contribution)
# Separate parameters so they can have different optimizer treatment
self.resid_lambdas = nn.Parameter(torch.ones(config.n_layer)) # fake init, real init in init_weights()
self.x0_lambdas = nn.Parameter(torch.zeros(config.n_layer)) # fake init, real init in init_weights()
self.bigram_lambdas = nn.Parameter(torch.zeros(config.n_layer)) # fake init, real init in init_weights()
# Bigram hash embeddings: O(1) lookup for local 2-gram patterns
self.bigram_embed = BigramEmbed(config.vocab_size, config.n_embd)
# Value embeddings (ResFormer-style): alternating layers, last layer always included
head_dim = config.n_embd // config.n_head
kv_dim = config.n_kv_head * head_dim
self.value_embeds = nn.ModuleDict({str(i): nn.Embedding(padded_vocab_size, kv_dim) for i in range(config.n_layer) if has_ve(i, config.n_layer)})
# To support meta device initialization, we init the rotary embeddings here, but it's just "fake" meta tensors only.
# As for rotary_seq_len, these rotary embeddings are pretty small/cheap in memory,
# so let's just over-compute them by 10X, but assert fail if we ever reach that amount.
@ -176,6 +225,7 @@ class GPT(nn.Module):
self.register_buffer("cos", cos, persistent=False) # persistent=False means it's not saved to the checkpoint
self.register_buffer("sin", sin, persistent=False)
@torch.no_grad()
def init_weights(self):
"""
Initialize the full model in this one function for maximum clarity.
@ -207,18 +257,33 @@ class GPT(nn.Module):
torch.nn.init.zeros_(block.mlp.c_proj.weight)
# Per-layer scalars
with torch.no_grad():
self.resid_lambdas.fill_(1.0) # 1.0 => typical residual connections at init
self.x0_lambdas.fill_(0.0) # 0.0 => skip connection to input is disabled at init
self.resid_lambdas.fill_(1.0) # 1.0 => typical residual connections at init
self.x0_lambdas.fill_(0.1) # 0.1 => small initial weight for skip connection to input embedding
self.bigram_lambdas.fill_(0.1) # 0.1 => small initial weight for skip connection to bigram embeddings
# Bigram embeddings: zero init so it starts as identity
nn.init.zeros_(self.bigram_embed.embed.weight)
# Value embeddings (init like c_v: uniform with same std)
for ve in self.value_embeds.values():
torch.nn.init.uniform_(ve.weight, -s, s)
# Gate weights init to zero so gates start at sigmoid(0) = 0.5, scaled by 2 -> 1.0 (neutral)
for block in self.transformer.h:
if block.attn.ve_gate is not None:
torch.nn.init.zeros_(block.attn.ve_gate.weight)
# Rotary embeddings
head_dim = self.config.n_embd // self.config.n_head
cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
self.cos, self.sin = cos, sin
# Cast token embeddings to bf16: optimizer can tolerate it and it saves memory
# Cast embeddings to bf16: optimizer can tolerate it and it saves memory
if self.transformer.wte.weight.device.type == "cuda":
self.transformer.wte.to(dtype=torch.bfloat16)
for ve in self.value_embeds.values():
ve.to(dtype=torch.bfloat16)
self.bigram_embed.to(dtype=torch.bfloat16)
def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None):
# TODO: bump base theta more? e.g. 100K is more common more recently
@ -283,7 +348,10 @@ class GPT(nn.Module):
"""
nparams = sum(p.numel() for p in self.parameters())
# Exclude non-matmul params: embeddings and per-layer scalars
nparams_exclude = self.transformer.wte.weight.numel() + self.resid_lambdas.numel() + self.x0_lambdas.numel()
value_embeds_numel = sum(ve.weight.numel() for ve in self.value_embeds.values())
bigram_embed_numel = self.bigram_embed.embed.weight.numel()
nparams_exclude = (self.transformer.wte.weight.numel() + value_embeds_numel + bigram_embed_numel +
self.resid_lambdas.numel() + self.x0_lambdas.numel() + self.bigram_lambdas.numel())
h, q, t = self.config.n_head, self.config.n_embd // self.config.n_head, self.config.sequence_len
# Sum attention FLOPs per layer, accounting for sliding window
attn_flops = 0
@ -296,26 +364,48 @@ class GPT(nn.Module):
def num_scaling_params(self):
"""
Return all of the parameters, same as Chinchilla paper.
Kaplan et al. did not include embedding parameters and said that this led to cleaner scaling laws.
But Kaplan et al. also had a bug in their results (as pointed out by Chinchilla).
My own experiments in nanochat confirm the Chinchilla approach gives the much cleaner scaling law.
Ref: https://arxiv.org/abs/2203.15556 (Chinchilla paper <- good).
Ref: https://arxiv.org/abs/2001.08361 (Kaplan et al. original scaling laws paper <- bad)
Return detailed parameter counts for scaling law analysis.
Different papers use different conventions:
- Kaplan et al. excluded embedding parameters
- Chinchilla included all parameters
Ref: https://arxiv.org/abs/2203.15556 (Chinchilla paper)
Ref: https://arxiv.org/abs/2001.08361 (Kaplan et al. original scaling laws paper)
Returns a dict with counts for each parameter group, so downstream analysis
can experiment with which combination gives the cleanest scaling laws.
"""
nparams = sum(p.numel() for p in self.parameters())
return nparams
# Count each group separately (mirrors the grouping in setup_optimizers)
wte = sum(p.numel() for p in self.transformer.wte.parameters())
bigram_embed = sum(p.numel() for p in self.bigram_embed.parameters())
value_embeds = sum(p.numel() for p in self.value_embeds.parameters())
lm_head = sum(p.numel() for p in self.lm_head.parameters())
transformer_matrices = sum(p.numel() for p in self.transformer.h.parameters())
scalars = self.resid_lambdas.numel() + self.x0_lambdas.numel() + self.bigram_lambdas.numel()
total = wte + bigram_embed + value_embeds + lm_head + transformer_matrices + scalars
assert total == sum(p.numel() for p in self.parameters()), "Parameter count mismatch"
return {
'wte': wte,
'bigram_embed': bigram_embed,
'value_embeds': value_embeds,
'lm_head': lm_head,
'transformer_matrices': transformer_matrices,
'scalars': scalars,
'total': total,
}
def setup_optimizers(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0, adam_betas=(0.8, 0.95), scalar_lr=0.5):
model_dim = self.config.n_embd
ddp, rank, local_rank, world_size = get_dist_info()
# Separate out all parameters into 5 groups (matrix, embedding, lm_head, resid_lambdas, x0_lambdas)
# Separate out all parameters into groups
matrix_params = list(self.transformer.h.parameters())
value_embeds_params = list(self.value_embeds.parameters())
embedding_params = list(self.transformer.wte.parameters())
lm_head_params = list(self.lm_head.parameters())
resid_params = [self.resid_lambdas]
x0_params = [self.x0_lambdas]
assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(resid_params) + len(x0_params)
bigram_embed_params = list(self.bigram_embed.parameters())
bigram_lambda_params = [self.bigram_lambdas]
assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params) + len(bigram_embed_params) + len(bigram_lambda_params)
# Create the AdamW optimizer for the embedding, lm_head, and per-layer scalars
# Scale the LR for the AdamW parameters by ∝1/√dmodel (having tuned the LRs for 768 dim model)
dmodel_lr_scale = (model_dim / 768) ** -0.5
@ -323,8 +413,11 @@ class GPT(nn.Module):
adam_groups = [
dict(params=lm_head_params, lr=unembedding_lr * dmodel_lr_scale),
dict(params=embedding_params, lr=embedding_lr * dmodel_lr_scale),
dict(params=value_embeds_params, lr=embedding_lr * dmodel_lr_scale), # same LR as token embedding
dict(params=bigram_embed_params, lr=embedding_lr * dmodel_lr_scale), # same LR as token embedding
dict(params=resid_params, lr=scalar_lr * 0.01), # these are a lot more sensitive because they accumulate in the residual stream
dict(params=x0_params, lr=scalar_lr),
dict(params=x0_params, lr=scalar_lr, betas=(0.96, 0.95)), # higher beta1 for x0 scalars
dict(params=bigram_lambda_params, lr=scalar_lr, betas=(0.96, 0.95)), # same treatment as x0 lambdas
]
adamw_kwargs = dict(betas=adam_betas, eps=1e-10, weight_decay=0.0) # NOTE: weight decay is hardcoded to 0.0 for AdamW, only used in Muon
AdamWFactory = DistAdamW if ddp else partial(torch.optim.AdamW, fused=True)
@ -352,12 +445,14 @@ class GPT(nn.Module):
cos_sin = self.cos[:, T0:T0+T], self.sin[:, T0:T0+T] # truncate cache to current sequence length
# Forward the trunk of the Transformer
x = self.transformer.wte(idx)
x = self.transformer.wte(idx) # embed current token
x0_bigram = self.bigram_embed(idx) # embed current bigram (via hash lookup)
x = norm(x)
x0 = x # save initial normalized embedding for x0 residual
for i, block in enumerate(self.transformer.h):
x = self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0
x = block(x, cos_sin, self.window_sizes[i], kv_cache)
x = self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0 + self.bigram_lambdas[i] * x0_bigram
ve = self.value_embeds[str(i)](idx) if str(i) in self.value_embeds else None
x = block(x, ve, cos_sin, self.window_sizes[i], kv_cache)
x = norm(x)
# Forward the lm_head (compute logits)

View File

@ -61,7 +61,6 @@ for d in "${DEPTHS[@]}"; do
# No --target-flops, let it use the default ratio from base_train
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train -- \
--depth=$d \
--target-param-data-ratio=8 \
--run="${WANDB_RUN}_d${d}" \
--model-tag="${TAG}" \
--core-metric-every=999999 \

70
runs/runcpu.sh Executable file
View File

@ -0,0 +1,70 @@
#!/bin/bash
# Showing an example run for exercising some of the code paths on the CPU (or MPS on Macbooks)
# This script was last updated/tuned on Jan 17, 2026.
# Run as:
# bash dev/cpu_demo_run.sh
# NOTE: Training LLMs requires GPU compute and $$$. You will not get far on your Macbook.
# Think of this run as educational/fun demo, not something you should expect to work well.
# (This is why I hide this script away in dev/)
# You may also want to run this script manually and one by one, copy pasting commands into your terminal.
# all the setup stuff
export OMP_NUM_THREADS=1
export NANOCHAT_BASE_DIR="$HOME/.cache/nanochat"
mkdir -p $NANOCHAT_BASE_DIR
command -v uv &> /dev/null || curl -LsSf https://astral.sh/uv/install.sh | sh
[ -d ".venv" ] || uv venv
uv sync --extra cpu
source .venv/bin/activate
if [ -z "$WANDB_RUN" ]; then
WANDB_RUN=dummy
fi
# train tokenizer on ~2B characters (~34 seconds on my MacBook Pro M3 Max)
python -m nanochat.dataset -n 8
python -m scripts.tok_train --max-chars=2000000000
python -m scripts.tok_eval
# train a small 4 layer model
# I tuned this run to complete in about 30 minutes on my MacBook Pro M3 Max.
# To get better results, try increasing num_iterations, or get other ideas from your favorite LLM.
python -m scripts.base_train \
--depth=6 \
--head-dim=64 \
--window-pattern=L \
--max-seq-len=512 \
--device-batch-size=32 \
--total-batch-size=16384 \
--eval-every=100 \
--eval-tokens=524288 \
--core-metric-every=-1 \
--sample-every=100 \
--num-iterations=5000 \
--run=$WANDB_RUN
python -m scripts.base_loss --device-batch-size=1 --split-tokens=16384
python -m scripts.base_eval --max-per-task=16
# midtraining (~10 minutes on my MacBook Pro M3 Max)
curl -L -o $NANOCHAT_BASE_DIR/identity_conversations.jsonl https://karpathy-public.s3.us-west-2.amazonaws.com/identity_conversations.jsonl
python -m scripts.mid_train \
--max-seq-len=512 \
--device-batch-size=32 \
--total-batch-size=16384 \
--eval-every=200 \
--eval-tokens=524288 \
--num-iterations=1500 \
--run=$WANDB_RUN
# (it's ~ok to skip SFT)
# Chat with the model over CLI
# The model should be able to say that it is Paris.
# It might even know that the color of the sky is blue.
# Sometimes the model likes it if you first say Hi before you ask it questions.
# python -m scripts.chat_cli -i mid -p "What is the capital of France?"
# Chat with the model over a pretty WebUI ChatGPT style
# python -m scripts.chat_web -i mid

View File

@ -1,26 +1,30 @@
#!/bin/bash
LABEL="jan26"
FLOPS_BUDGETS=(
1e18
3e18
6e18
2.15e18
4.64e18
1e19
)
DEPTHS=(8 10 12 14 16 18 20)
NPROC_PER_NODE="${NPROC_PER_NODE:-8}"
WANDB_RUN="${WANDB_RUN:-scaling}"
WANDB_RUN="${WANDB_RUN:-scaling_${LABEL}}"
EVAL_TOKENS=$((100 * 524288)) # ~100M tokens for final eval (default is ~10M)
export OMP_NUM_THREADS=1
export NANOCHAT_BASE_DIR="${NANOCHAT_BASE_DIR:-$HOME/.cache/nanochat}"
source .venv/bin/activate
RESULTS_DIR="$NANOCHAT_BASE_DIR/scaling_laws_results"
RESULTS_DIR="$NANOCHAT_BASE_DIR/scaling_laws_results_${LABEL}"
mkdir -p "$RESULTS_DIR"
RESULTS_FILE="$RESULTS_DIR/results.csv"
# Write CSV header only if file doesn't exist
if [ ! -f "$RESULTS_FILE" ]; then
echo "flops_budget,depth,model_dim,num_params,num_scaling_params,num_iterations,tokens_trained,param_data_ratio,val_bpb,core_score,train_time_sec" > "$RESULTS_FILE"
echo "flops_budget,depth,model_dim,params_wte,params_bigram_embed,params_value_embeds,params_lm_head,params_transformer,params_scalars,params_total,num_iterations,tokens_trained,val_bpb,core_score,train_time_sec" > "$RESULTS_FILE"
fi
log() {
@ -80,13 +84,19 @@ for flops in "${FLOPS_BUDGETS[@]}"; do
# Extract training stats from the log
LOG_FILE="$RESULTS_DIR/${TAG}_train.log"
NUM_PARAMS=$(grep "Number of parameters:" "$LOG_FILE" | tail -1 | grep -oP '[\d,]+' | head -1 | tr -d ',')
NUM_SCALING_PARAMS=$(grep "Number of parameters:" "$LOG_FILE" | tail -1 | grep -oP 'scaling: [\d,]+' | grep -oP '[\d,]+' | tr -d ',')
# Extract detailed parameter counts (for scaling law analysis with different conventions)
PARAMS_WTE=$(grep "wte:" "$LOG_FILE" | tail -1 | grep -oP '[\d,]+' | tr -d ',')
PARAMS_BIGRAM=$(grep "bigram_embed:" "$LOG_FILE" | tail -1 | grep -oP '[\d,]+' | tr -d ',')
PARAMS_VE=$(grep "value_embeds:" "$LOG_FILE" | tail -1 | grep -oP '[\d,]+' | tr -d ',')
PARAMS_LM=$(grep "lm_head:" "$LOG_FILE" | tail -1 | grep -oP '[\d,]+' | tr -d ',')
PARAMS_TRANSFORMER=$(grep "transformer_matrices:" "$LOG_FILE" | tail -1 | grep -oP '[\d,]+' | tr -d ',')
PARAMS_SCALARS=$(grep "scalars:" "$LOG_FILE" | tail -1 | grep -oP '[\d,]+' | tr -d ',')
PARAMS_TOTAL=$(grep "total:" "$LOG_FILE" | tail -1 | grep -oP '[\d,]+' | tr -d ',')
NUM_ITERS=$(grep "Calculated number of iterations" "$LOG_FILE" | tail -1 | sed 's/.*: //' | tr -d ',')
# Calculate tokens trained (iterations * batch_size, default 524288)
TOKENS_TRAINED=$((NUM_ITERS * 524288))
# Param:data ratio (using scaling params per Kaplan et al.)
PARAM_DATA_RATIO=$(python -c "print(f'{$TOKENS_TRAINED / $NUM_SCALING_PARAMS:.2f}')")
# Model dim
MODEL_DIM=$((d * 64))
# Val BPB from final eval
@ -99,10 +109,10 @@ for flops in "${FLOPS_BUDGETS[@]}"; do
CORE_SCORE="0.0"
fi
log " Params: $NUM_PARAMS, Iters: $NUM_ITERS, Ratio: $PARAM_DATA_RATIO, Val BPB: $VAL_BPB, CORE: $CORE_SCORE"
log " Params: $PARAMS_TOTAL (transformer: $PARAMS_TRANSFORMER), Iters: $NUM_ITERS, Val BPB: $VAL_BPB, CORE: $CORE_SCORE"
# Append to CSV
echo "$flops,$d,$MODEL_DIM,$NUM_PARAMS,$NUM_SCALING_PARAMS,$NUM_ITERS,$TOKENS_TRAINED,$PARAM_DATA_RATIO,$VAL_BPB,$CORE_SCORE,$TRAIN_TIME" >> "$RESULTS_FILE"
echo "$flops,$d,$MODEL_DIM,$PARAMS_WTE,$PARAMS_BIGRAM,$PARAMS_VE,$PARAMS_LM,$PARAMS_TRANSFORMER,$PARAMS_SCALARS,$PARAMS_TOTAL,$NUM_ITERS,$TOKENS_TRAINED,$VAL_BPB,$CORE_SCORE,$TRAIN_TIME" >> "$RESULTS_FILE"
done
done

View File

@ -58,8 +58,8 @@ python -m nanochat.dataset -n 8
# See comment below for why 370 is the right number here
python -m nanochat.dataset -n 370 &
DATASET_DOWNLOAD_PID=$!
# train the tokenizer with vocab size 2**16 = 65536 on ~2B characters of data
python -m scripts.tok_train --max-chars=2000000000 --vocab-size=65536
# train the tokenizer with vocab size 2**15 = 32768 on ~2B characters of data
python -m scripts.tok_train
# evaluate the tokenizer (report compression ratio etc.)
python -m scripts.tok_eval

View File

@ -104,7 +104,7 @@ for split_name in ["train", "val"]:
bpb_results[split_name] = bpb
print0(f"Model: {model_name}, {split_name} bpb: {bpb:.6f}")
# Master process also samples from the model (only for nanochat models)
# Master process also samples from the model for some basic knowledge-eliciting prompts (only for nanochat models)
samples = []
if ddp_rank == 0 and args.hf_path is None:
prompts = [
@ -122,9 +122,23 @@ if ddp_rank == 0 and args.hf_path is None:
with autocast_ctx:
sample, _ = engine.generate_batch(tokens, num_samples=1, max_tokens=16, temperature=0)
sample_str = tokenizer.decode(sample[0])
print0("-" * 80)
print0(sample_str)
samples.append(sample_str)
# Draw some unconditioned samples from the model (only for nanochat models)
unconditioned_samples = []
if ddp_rank == 0 and args.hf_path is None:
engine = Engine(model, tokenizer)
tokens = tokenizer("", prepend="<|bos|>")
with autocast_ctx:
samples, _ = engine.generate_batch(tokens, num_samples=8, max_tokens=128, temperature=1.0)
for sample in samples:
sample_str = tokenizer.decode(sample)
print0("-" * 80)
print0(sample_str)
unconditioned_samples.append(sample_str)
# Log to report
from nanochat.report import get_report
get_report().log(section="Base model loss", data=[
@ -134,6 +148,7 @@ get_report().log(section="Base model loss", data=[
"val bpb": bpb_results["val"],
},
{f"sample {i}": sample for i, sample in enumerate(samples)},
{f"unconditioned sample {i}": sample for i, sample in enumerate(unconditioned_samples)},
])
# Cleanup

View File

@ -1,11 +1,11 @@
"""
Train model. From root directory of the project, run as:
python -m scripts.base_train.py
python -m scripts.base_train
or distributed as:
torchrun --nproc_per_node=8 -m scripts.base_train.py
torchrun --nproc_per_node=8 -m scripts.base_train
If you are only on CPU/Macbook, you'll want to train a much much smaller LLM. Example:
python -m scripts.base_train --depth=4 --max-seq-len=512 --device-batch-size=1 --eval-tokens=512 --core-metric-every=-1 --total-batch-size=512 --num-iterations=20
@ -22,11 +22,12 @@ import torch
from nanochat.gpt import GPT, GPTConfig
from nanochat.dataloader import tokenizing_distributed_data_loader_bos_bestfit, tokenizing_distributed_data_loader_with_state_bos_bestfit
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir, autodetect_device_type
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir, autodetect_device_type, get_peak_flops
from nanochat.tokenizer import get_tokenizer, get_token_bytes
from nanochat.checkpoint_manager import save_checkpoint, load_checkpoint
from nanochat.loss_eval import evaluate_bpb
from nanochat.engine import Engine
from nanochat.flash_attention import HAS_FA3
from scripts.base_eval import evaluate_model
print_banner()
@ -46,7 +47,7 @@ parser.add_argument("--window-pattern", type=str, default="SSSL", help="sliding
# Training horizon (only one used, in order of precedence)
parser.add_argument("--num-iterations", type=int, default=-1, help="explicit number of optimization steps (-1 = disable)")
parser.add_argument("--target-flops", type=float, default=-1.0, help="calculate num_iterations to reach target_flops (-1 = disable)")
parser.add_argument("--target-param-data-ratio", type=int, default=8, help="calculate num_iterations to maintain data:param ratio (Chinchilla=20, -1 = disable)")
parser.add_argument("--target-param-data-ratio", type=float, default=10.5, help="calculate num_iterations to maintain data:param ratio (Chinchilla=20, -1 = disable)")
# Optimization
parser.add_argument("--device-batch-size", type=int, default=32, help="per-device batch size")
parser.add_argument("--total-batch-size", type=int, default=524288, help="total batch size in tokens")
@ -81,11 +82,29 @@ master_process = ddp_rank == 0 # this process will do logging, checkpointing etc
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None
get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0
if device_type == "cuda":
gpu_device_name = torch.cuda.get_device_name(0)
gpu_peak_flops = get_peak_flops(gpu_device_name)
print0(f"GPU: {gpu_device_name} | Peak FLOPS (BF16): {gpu_peak_flops:.2e}")
else:
gpu_peak_flops = float('inf') # MFU not meaningful for CPU/MPS
# wandb logging init
use_dummy_wandb = args.run == "dummy" or not master_process
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat", name=args.run, config=user_config)
# Flash Attention status
if HAS_FA3:
print0("✓ Using Flash Attention 3 (Hopper GPU detected), efficient, new and awesome.")
else:
print0("!" * 80)
print0("WARNING: Flash Attention 3 not available, using PyTorch SDPA fallback")
print0("WARNING: Training will be less efficient without FA3")
if args.window_pattern != "L":
print0(f"WARNING: SDPA has no support for sliding window attention (window_pattern='{args.window_pattern}'). Your GPU utilization will be terrible.")
print0("WARNING: Recommend using --window-pattern L for full context attention without alternating sliding window patterns.")
print0("!" * 80)
# Tokenizer will be useful for evaluation, also we need the vocab size
tokenizer = get_tokenizer()
token_bytes = get_token_bytes(device=device)
@ -93,21 +112,19 @@ vocab_size = tokenizer.get_vocab_size()
print0(f"Vocab size: {vocab_size:,}")
# Model kwargs are derived from the desired depth of the model
# We nudge model_dim up to the nearest multiple of head_dim to ensure clean division
# (FA3 requires head_dim divisible by 8, and this guarantees head_dim == args.head_dim exactly)
# (For very small depths, this gives a slight "unfair" advantage to models with odd depths)
num_layers = args.depth
model_dim = args.depth * args.aspect_ratio
def find_num_heads(model_dim, target_head_dim):
# Find num_heads that divides model_dim evenly, with head_dim closest to target.
ideal = max(1, round(model_dim / target_head_dim))
for offset in range(model_dim):
for candidate in [ideal + offset, ideal - offset]:
if candidate > 0 and model_dim % candidate == 0:
return candidate
return 1
num_heads = find_num_heads(model_dim, args.head_dim)
base_dim = args.depth * args.aspect_ratio
model_dim = ((base_dim + args.head_dim - 1) // args.head_dim) * args.head_dim
num_heads = model_dim // args.head_dim
num_kv_heads = num_heads # default is 1:1 GQA (Group Query Attention) ratio (i.e. GQA is disabled)
head_dim = model_dim // num_heads
print0(f"num_layers: {num_layers}")
print0(f"model_dim: {model_dim}")
print0(f"model_dim: {model_dim} (base: {base_dim}, nudge: {model_dim - base_dim:+d})")
print0(f"num_heads: {num_heads}")
print0(f"head_dim: {head_dim}")
print0(f"num_kv_heads: {num_kv_heads}")
# Optimizer / data / training length related hyperparameters
@ -161,9 +178,14 @@ if resuming:
orig_model = model # original, uncompiled model, for saving raw model state_dict and for inference/evaluation (because the shapes may change shape)
model = torch.compile(model, dynamic=False) # the inputs to model will never change shape so dynamic=False is safe
num_params = sum(p.numel() for p in model.parameters())
num_scaling_params = orig_model.num_scaling_params()
print0(f"Number of parameters: {num_params:,} (scaling: {num_scaling_params:,})")
# Detailed parameter counts
param_counts = orig_model.num_scaling_params()
print0(f"Parameter counts:")
for key, value in param_counts.items():
print0(f"{key:24s}: {value:,}")
num_params = param_counts['total']
num_scaling_params = param_counts['transformer_matrices'] + param_counts['lm_head'] # determined to give the cleanest scaling laws, see dev/LOG.md Jan 27, 2026
num_flops_per_token = model.estimate_flops()
print0(f"Estimated FLOPs per token: {num_flops_per_token:e}")
@ -178,14 +200,14 @@ elif args.target_flops > 0:
print0(f"Calculated number of iterations from target FLOPs: {num_iterations:,}")
elif args.target_param_data_ratio > 0:
# calculate the number of iterations from the target param data ratio (use scaling params per Kaplan et al.)
target_tokens = args.target_param_data_ratio * num_scaling_params
target_tokens = int(args.target_param_data_ratio * num_scaling_params)
num_iterations = target_tokens // args.total_batch_size
print0(f"Calculated number of iterations from target data:param ratio: {num_iterations:,}")
else:
raise ValueError("No training horizon specified")
total_tokens = args.total_batch_size * num_iterations
print0(f"Total number of training tokens: {total_tokens:,}")
print0(f"Tokens : Params ratio: {args.total_batch_size * num_iterations / num_scaling_params:.2f}") # Chinchilla is ~20
print0(f"Tokens : Scaling params ratio: {args.total_batch_size * num_iterations / num_scaling_params:.2f}") # Chinchilla is ~20
print0(f"Total training FLOPs estimate: {num_flops_per_token * total_tokens:e}")
# -----------------------------------------------------------------------------
@ -382,8 +404,7 @@ while True:
pct_done = 100 * step / num_iterations
tok_per_sec = int(args.total_batch_size / dt)
flops_per_sec = num_flops_per_token * args.total_batch_size / dt
promised_flops_per_sec_h100 = 989e12 * ddp_world_size # bfloat16 H100 SXM and without 2:4 sparsity
mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in %
mfu = 100 * flops_per_sec / (gpu_peak_flops * ddp_world_size)
if step > 10:
total_training_time += dt # only count the time after the first 10 steps
# Calculate ETA based on average time per step (excluding first 10 steps)
@ -429,7 +450,7 @@ get_report().log(section="Base model training", data=[
"Number of FLOPs per token": f"{num_flops_per_token:e}",
"Calculated number of iterations": num_iterations,
"Number of training tokens": total_tokens,
"Tokens : Params ratio": args.total_batch_size * num_iterations / num_params,
"Tokens : Scaling params ratio": args.total_batch_size * num_iterations / num_scaling_params,
"DDP world size": ddp_world_size,
"warmup_ratio": args.warmup_ratio,
"warmdown_ratio": args.warmdown_ratio,

View File

@ -249,7 +249,7 @@ while True:
last_step = bool(last_step_tensor.item())
# once in a while: evaluate the val bpb (all ranks participate)
if args.eval_every > 0 and (last_step or step % args.eval_every == 0):
if last_step or (args.eval_every > 0 and step % args.eval_every == 0):
model.eval()
val_loader = build_val_loader()
eval_steps = args.eval_tokens // (args.device_batch_size * args.max_seq_len * ddp_world_size)

View File

@ -14,7 +14,7 @@ from nanochat.dataset import parquets_iter_batched
# Parse command line arguments
parser = argparse.ArgumentParser(description='Train a BPE tokenizer')
parser.add_argument('--max-chars', type=int, default=10_000_000_000, help='Maximum characters to train on (default: 10B)')
parser.add_argument('--max-chars', type=int, default=2_000_000_000, help='Maximum characters to train on (default: 10B)')
parser.add_argument('--doc-cap', type=int, default=10_000, help='Maximum characters per document (default: 10,000)')
parser.add_argument('--vocab-size', type=int, default=32768, help='Vocabulary size (default: 32768 = 2^15)')
args = parser.parse_args()

View File

@ -0,0 +1,338 @@
"""
Test Flash Attention unified interface - verify FA3 and SDPA produce identical results.
Run: python -m pytest tests/test_attention_fallback.py -v -s
Note on test structure:
Tests are split into two classes due to dtype/device constraints:
1. TestFA3VsSDPA: Comparison tests that run both FA3 and SDPA on the same inputs
and verify they produce identical results. These require a Hopper GPU (FA3 only
works on sm90+) and use bfloat16 (FA3 doesn't support float32).
2. TestSDPAOnly: Tests that only exercise the SDPA fallback path. These can run
on any device (CUDA, CPU, MPS) with the appropriate dtype for that device.
"""
import torch
import pytest
import nanochat.flash_attention as fa_module
from nanochat.flash_attention import flash_attn, HAS_FA3
from nanochat.engine import KVCache
def set_impl(impl):
"""Set the implementation override ('fa3', 'sdpa', or None for auto)."""
fa_module._override_impl = impl
def run_both_impls(fn):
"""Run a function with both FA3 and SDPA, return both outputs."""
set_impl('fa3')
out_fa3 = fn()
set_impl('sdpa')
out_sdpa = fn()
set_impl(None) # reset
return out_fa3, out_sdpa
def assert_close(t1, t2, name, atol=1e-2, rtol=1e-2):
"""Assert two tensors are close, with helpful error message."""
max_diff = (t1 - t2).abs().max().item()
mean_diff = (t1 - t2).abs().mean().item()
assert torch.allclose(t1, t2, atol=atol, rtol=rtol), \
f"{name}: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}"
return max_diff, mean_diff
# =============================================================================
# FA3 vs SDPA comparison tests (require Hopper GPU)
# =============================================================================
@pytest.mark.skipif(not HAS_FA3, reason="FA3 required to compare implementations")
class TestFA3VsSDPA:
"""Compare FA3 and SDPA produce identical results. Requires Hopper GPU."""
DEVICE = "cuda"
DTYPE = torch.bfloat16
def test_basic_causal(self):
"""Basic causal attention."""
B, T, H, D = 2, 64, 4, 32
q = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
k = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
v = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
def run():
return flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(T, 0))
y_fa3, y_sdpa = run_both_impls(run)
max_diff, mean_diff = assert_close(y_fa3, y_sdpa, "basic_causal")
print(f"basic_causal: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
def test_full_context(self):
"""Full context (window_size=-1)."""
B, T, H, D = 2, 128, 4, 32
q = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
k = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
v = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
def run():
return flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(-1, -1))
y_fa3, y_sdpa = run_both_impls(run)
max_diff, mean_diff = assert_close(y_fa3, y_sdpa, "full_context")
print(f"full_context: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
def test_sliding_window(self):
"""Sliding window attention."""
B, T, H, D = 2, 128, 4, 32
window = 32
q = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
k = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
v = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
def run():
return flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(window, 0))
y_fa3, y_sdpa = run_both_impls(run)
max_diff, mean_diff = assert_close(y_fa3, y_sdpa, "sliding_window")
print(f"sliding_window: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
def test_gqa(self):
"""Group Query Attention (fewer KV heads than Q heads)."""
B, T, D = 2, 64, 32
n_heads = 8
n_kv_heads = 2
q = torch.randn(B, T, n_heads, D, device=self.DEVICE, dtype=self.DTYPE)
k = torch.randn(B, T, n_kv_heads, D, device=self.DEVICE, dtype=self.DTYPE)
v = torch.randn(B, T, n_kv_heads, D, device=self.DEVICE, dtype=self.DTYPE)
def run():
return flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(T, 0))
y_fa3, y_sdpa = run_both_impls(run)
max_diff, mean_diff = assert_close(y_fa3, y_sdpa, "gqa")
print(f"gqa: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
def test_larger_model(self):
"""Larger dimensions closer to real model."""
B, T, H, D = 4, 256, 12, 64
q = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
k = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
v = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
def run():
return flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(-1, -1))
y_fa3, y_sdpa = run_both_impls(run)
max_diff, mean_diff = assert_close(y_fa3, y_sdpa, "larger_model")
print(f"larger_model: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
def test_kvcache_prefill(self):
"""Test prefill (inserting multiple tokens into empty cache)."""
B, T_max, H, D = 2, 64, 4, 32
T_prefill = 16
q = torch.randn(B, T_prefill, H, D, device=self.DEVICE, dtype=self.DTYPE)
k = torch.randn(B, T_prefill, H, D, device=self.DEVICE, dtype=self.DTYPE)
v = torch.randn(B, T_prefill, H, D, device=self.DEVICE, dtype=self.DTYPE)
def run():
k_cache = torch.zeros(B, T_max, H, D, device=self.DEVICE, dtype=self.DTYPE)
v_cache = torch.zeros(B, T_max, H, D, device=self.DEVICE, dtype=self.DTYPE)
cache_seqlens = torch.zeros(B, dtype=torch.int32, device=self.DEVICE)
return flash_attn.flash_attn_with_kvcache(
q, k_cache, v_cache, k=k, v=v,
cache_seqlens=cache_seqlens,
causal=True, window_size=(T_max, 0)
)
y_fa3, y_sdpa = run_both_impls(run)
max_diff, mean_diff = assert_close(y_fa3, y_sdpa, "prefill")
print(f"prefill: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
def test_kvcache_single_token(self):
"""Test single token generation (cache already has content)."""
B, T_max, H, D = 2, 64, 4, 32
T_prefill = 16
k_init = torch.randn(B, T_prefill, H, D, device=self.DEVICE, dtype=self.DTYPE)
v_init = torch.randn(B, T_prefill, H, D, device=self.DEVICE, dtype=self.DTYPE)
q_single = torch.randn(B, 1, H, D, device=self.DEVICE, dtype=self.DTYPE)
k_single = torch.randn(B, 1, H, D, device=self.DEVICE, dtype=self.DTYPE)
v_single = torch.randn(B, 1, H, D, device=self.DEVICE, dtype=self.DTYPE)
def run():
k_cache = torch.zeros(B, T_max, H, D, device=self.DEVICE, dtype=self.DTYPE)
v_cache = torch.zeros(B, T_max, H, D, device=self.DEVICE, dtype=self.DTYPE)
k_cache[:, :T_prefill, :, :] = k_init
v_cache[:, :T_prefill, :, :] = v_init
cache_seqlens = torch.full((B,), T_prefill, dtype=torch.int32, device=self.DEVICE)
return flash_attn.flash_attn_with_kvcache(
q_single, k_cache, v_cache, k=k_single, v=v_single,
cache_seqlens=cache_seqlens,
causal=True, window_size=(T_max, 0)
)
y_fa3, y_sdpa = run_both_impls(run)
max_diff, mean_diff = assert_close(y_fa3, y_sdpa, "single_token")
print(f"single_token: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
def test_backward_gradients_match(self):
"""Verify gradients are similar between FA3 and SDPA."""
B, T, H, D = 2, 32, 4, 16
q_data = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
k_data = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
v_data = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
def run():
q = q_data.clone().requires_grad_(True)
k = k_data.clone().requires_grad_(True)
v = v_data.clone().requires_grad_(True)
y = flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(T, 0))
loss = y.sum()
loss.backward()
return y.detach(), q.grad.detach(), k.grad.detach(), v.grad.detach()
set_impl('fa3')
y_fa3, q_grad_fa3, k_grad_fa3, v_grad_fa3 = run()
set_impl('sdpa')
y_sdpa, q_grad_sdpa, k_grad_sdpa, v_grad_sdpa = run()
set_impl(None)
max_diff, mean_diff = assert_close(y_fa3, y_sdpa, "backward_output")
print(f"backward_output: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
max_diff, mean_diff = assert_close(q_grad_fa3, q_grad_sdpa, "q_grad", atol=0.05, rtol=0.05)
print(f"q_grad: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
max_diff, mean_diff = assert_close(k_grad_fa3, k_grad_sdpa, "k_grad", atol=0.05, rtol=0.05)
print(f"k_grad: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
max_diff, mean_diff = assert_close(v_grad_fa3, v_grad_sdpa, "v_grad", atol=0.05, rtol=0.05)
print(f"v_grad: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
# =============================================================================
# SDPA-only tests (run on any device)
# =============================================================================
class TestSDPAOnly:
"""Test SDPA fallback works correctly. Runs on any device."""
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.bfloat16 if torch.cuda.is_available() else torch.float32
def test_basic_forward(self):
"""Test SDPA forward pass produces valid output."""
set_impl('sdpa')
B, T, H, D = 2, 64, 4, 32
q = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
k = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
v = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
y = flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(T, 0))
assert y.shape == (B, T, H, D)
assert not torch.isnan(y).any(), "Output contains NaN"
set_impl(None)
def test_backward(self):
"""Test gradients flow through SDPA."""
set_impl('sdpa')
B, T, H, D = 2, 32, 4, 16
q = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE, requires_grad=True)
k = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE, requires_grad=True)
v = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE, requires_grad=True)
y = flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(T, 0))
loss = y.sum()
loss.backward()
assert q.grad is not None, "No gradient for q"
assert k.grad is not None, "No gradient for k"
assert v.grad is not None, "No gradient for v"
assert not torch.isnan(q.grad).any(), "NaN in q gradient"
set_impl(None)
def test_kvcache(self):
"""Test SDPA with KV cache."""
set_impl('sdpa')
B, T_max, H, D = 2, 64, 4, 32
n_layers = 1
cache = KVCache(
batch_size=B, num_heads=H, seq_len=T_max, head_dim=D,
num_layers=n_layers, device=self.DEVICE, dtype=self.DTYPE
)
k_cache, v_cache = cache.get_layer_cache(0)
# Prefill
T_prefill = 16
q = torch.randn(B, T_prefill, H, D, device=self.DEVICE, dtype=self.DTYPE)
k = torch.randn(B, T_prefill, H, D, device=self.DEVICE, dtype=self.DTYPE)
v = torch.randn(B, T_prefill, H, D, device=self.DEVICE, dtype=self.DTYPE)
y = flash_attn.flash_attn_with_kvcache(
q, k_cache, v_cache, k=k, v=v,
cache_seqlens=cache.cache_seqlens,
causal=True, window_size=(T_max, 0)
)
cache.advance(T_prefill)
assert y.shape == (B, T_prefill, H, D)
assert cache.get_pos() == T_prefill
# Generate single token
q_single = torch.randn(B, 1, H, D, device=self.DEVICE, dtype=self.DTYPE)
k_single = torch.randn(B, 1, H, D, device=self.DEVICE, dtype=self.DTYPE)
v_single = torch.randn(B, 1, H, D, device=self.DEVICE, dtype=self.DTYPE)
y_single = flash_attn.flash_attn_with_kvcache(
q_single, k_cache, v_cache, k=k_single, v=v_single,
cache_seqlens=cache.cache_seqlens,
causal=True, window_size=(T_max, 0)
)
cache.advance(1)
assert y_single.shape == (B, 1, H, D)
assert cache.get_pos() == T_prefill + 1
set_impl(None)
# =============================================================================
# Override mechanism tests
# =============================================================================
class TestOverrideMechanism:
"""Test that the override mechanism works correctly."""
@pytest.mark.skipif(not HAS_FA3, reason="FA3 required")
def test_override_fa3(self):
"""Test that override='fa3' uses FA3."""
set_impl('fa3')
assert fa_module._use_fa3() == True
set_impl(None)
def test_override_sdpa(self):
"""Test that override='sdpa' uses SDPA."""
set_impl('sdpa')
assert fa_module._use_fa3() == False
set_impl(None)
def test_override_auto(self):
"""Test that override=None uses auto-detection."""
set_impl(None)
assert fa_module._use_fa3() == HAS_FA3
if __name__ == "__main__":
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
print(f"CUDA device: {torch.cuda.get_device_name()}")
major, minor = torch.cuda.get_device_capability()
print(f"Compute capability: {major}.{minor}")
print(f"HAS_FA3: {HAS_FA3}")
print()
pytest.main([__file__, "-v", "-s"])

View File

@ -96,6 +96,7 @@ def test_kv_cache_basic():
head_dim=head_dim,
num_layers=num_layers,
device="cpu",
dtype=torch.float32,
)
# Check initial state
@ -130,7 +131,7 @@ def test_kv_cache_prefill():
# Create source cache and advance it
src_cache = KVCache(
batch_size=batch_size, num_heads=num_heads, seq_len=32,
head_dim=head_dim, num_layers=num_layers, device="cpu",
head_dim=head_dim, num_layers=num_layers, device="cpu", dtype=torch.float32,
)
# Write some data to source cache
src_cache.k_cache[0, 0, :16, :, :] = 1.0
@ -140,7 +141,7 @@ def test_kv_cache_prefill():
# Create destination cache with larger seq_len
dst_cache = KVCache(
batch_size=batch_size, num_heads=num_heads, seq_len=64,
head_dim=head_dim, num_layers=num_layers, device="cpu",
head_dim=head_dim, num_layers=num_layers, device="cpu", dtype=torch.float32,
)
# Prefill
@ -195,3 +196,72 @@ def test_multi_sample_first_token_diversity():
f"With uniform logits, this is statistically impossible (~10^-36 probability) "
f"unless tokens are being broadcast instead of independently sampled."
)
def test_seed_reproducibility():
"""Same seed must produce identical output."""
model = MockModel()
engine = Engine(model, ByteTokenizer())
prompt = [261, 72, 101, 108, 108, 111] # <bos> + "Hello"
for seed in [1, 42, 123, 999]:
r1, _ = engine.generate_batch(prompt, max_tokens=5, seed=seed)
r2, _ = engine.generate_batch(prompt, max_tokens=5, seed=seed)
r3, _ = engine.generate_batch(prompt, max_tokens=5, seed=seed)
assert r1 == r2 == r3, "Same seed must produce identical output for the same prompt."
def test_temperature_zero_determinism():
"""Temperature=0 is deterministic regardless of seed."""
model = MockModel()
engine = Engine(model, ByteTokenizer())
prompt = [261, 72, 101, 108, 108, 111]
r1, _ = engine.generate_batch(prompt, temperature=0.0, max_tokens=5, seed=1)
r2, _ = engine.generate_batch(prompt, temperature=0.0, max_tokens=5, seed=42)
r3, _ = engine.generate_batch(prompt, temperature=0.0, max_tokens=5, seed=123)
assert r1 == r2 == r3, "Temperature=0 must result in the same output for the same prompt regardless of seed."
def test_max_tokens_respected():
"""Generation stops at max_tokens limit."""
model = MockModel()
engine = Engine(model, ByteTokenizer())
prompt = [261, 72, 101, 108, 108, 111]
for max_tokens in [1, 4, 16, 64]:
results, _ = engine.generate_batch(prompt, max_tokens=max_tokens)
num_generated_tokens = len(results[0]) - len(prompt)
assert num_generated_tokens <= max_tokens, f"Generated {num_generated_tokens} tokens, expected max_tokens={max_tokens} or less."
def test_num_samples_count():
"""num_samples=N produces exactly N sequences."""
model = MockModel()
engine = Engine(model, ByteTokenizer())
prompt = [261, 72, 101, 108, 108, 111]
for num_samples in [1, 4, 16, 64]:
results, _ = engine.generate_batch(prompt, num_samples=num_samples, max_tokens=3)
assert len(results) == num_samples, f"Expected {num_samples} sequences from {num_samples} samples, got {len(results)}"
def test_different_seeds_introduce_variation_when_temperature_nonzero():
"""With temperature > 0, different seeds should introduce sampling variation."""
model = MockModel()
engine = Engine(model, ByteTokenizer())
prompt = [261, 72, 101, 108, 108, 111] # <bos> + "Hello"
outputs = set()
for seed in [1, 42, 123, 999, 1000, 1001, 1002, 1003, 1004, 1005]:
results, _ = engine.generate_batch(
prompt,
temperature=1.0,
max_tokens=5,
seed=seed,
)
outputs.add(tuple(results[0]))
# Sanity check: sampling actually introduces variation
assert len(outputs) > 1, "All seeds produced the same output which is statistically highly improbable."