Merge branch 'master' into master_goderr

This commit is contained in:
svlandeg 2026-01-15 20:47:26 +01:00
commit 65865df300
35 changed files with 6006 additions and 2731 deletions

11
.gitignore vendored
View File

@ -1,7 +1,14 @@
.venv/
__pycache__/
*.pyc
rustbpe/target/
dev-ignore/
report.md
eval_bundle/
# Secrets
.env
# Local setup
.claude
CLAUDE.md
wandb/

View File

@ -10,6 +10,10 @@ This repo is a full-stack implementation of an LLM like ChatGPT in a single, cle
To get a sense of the endpoint of this repo, you can currently find [nanochat d34](https://github.com/karpathy/nanochat/discussions/314) hosted on [nanochat.karpathy.ai](https://nanochat.karpathy.ai/). "d34" means that this model has 34 layers in the Transformer neural network. This model has 2.2 billion parameters and was trained on 88 billion tokens by simply running the training script [run1000.sh](run1000.sh) with `--target_param_data_ratio=40` (2x longer than Chinchilla-optimal), and the total cost of training was ~$2,500 (about 100 hours of training time on an 8XH100 GPU node). While today this is enough to outperform GPT-2 of 2019, it falls dramatically short of modern Large Language Models like GPT-5. When talking to these micro models, you'll see that they make a lot of mistakes, they are a little bit naive and silly, and they hallucinate a ton, a bit like children. It's kind of amusing. But what makes nanochat unique is that it is fully yours - fully configurable, tweakable, hackable, and trained by you from start to end. To train and talk to your own, we turn to...
## Updates
- (Jan 7 2026) See new post: [nanochat Miniseries v1](https://github.com/karpathy/nanochat/discussions/420) and the associated script [miniseries.sh](miniseries.sh).
## Quick start
The fastest way to feel the magic is to run the speedrun script [speedrun.sh](speedrun.sh), which trains and inferences the $100 tier of nanochat. On an 8XH100 node at $24/hr, this gives a total run time of about 4 hours. Boot up a new 8XH100 GPU box from your favorite provider (e.g. I use and like [Lambda](https://lambda.ai/service/gpu-cloud)), and kick off the training script:
@ -108,10 +112,10 @@ Additionally, to add new abilities to nanochat, see [Guide: counting r in strawb
nanochat is designed to be short and sweet. One big advantage of this is that we can package up all of the files together and copy paste them to your favorite LLM to ask arbitrary questions. As an example, I like to package up the repo using the [files-to-prompt](https://github.com/simonw/files-to-prompt) utility like so:
```bash
files-to-prompt . -e py -e md -e rs -e html -e toml -e sh --ignore "*target*" --cxml > packaged.txt
files-to-prompt . -e py -e md -e html -e toml -e sh --cxml > packaged.txt
```
This includes all py, rs, html, toml, sh files, excludes the `rustbpe/target` folder, and chooses the cxml output format. Everything is written to the `packaged.txt` file, which atm measures ~330KB (i.e. well below ~100K tokens for a state of the art LLM), and ~8K lines of code in 45 files.
This includes all py, md, html, toml, sh files and chooses the cxml output format. Everything is written to the `packaged.txt` file, which atm measures ~330KB (i.e. well below ~100K tokens for a state of the art LLM), and ~8K lines of code in 45 files.
Alternatively, I recommend using [DeepWiki](https://deepwiki.com/karpathy/nanochat) from Devin/Cognition to ask questions of this repo. In the URL of this repo, simply change github.com to deepwiki.com, and you're off.
@ -120,7 +124,7 @@ Alternatively, I recommend using [DeepWiki](https://deepwiki.com/karpathy/nanoch
I haven't invested too much here but some tests exist, especially for the tokenizer. Run e.g. as:
```bash
python -m pytest tests/test_rustbpe.py -v -s
python -m pytest tests/test_engine.py -v -s
```
## File structure
@ -140,7 +144,6 @@ python -m pytest tests/test_rustbpe.py -v -s
│ ├── adamw.py # Distributed AdamW optimizer
│ ├── checkpoint_manager.py # Save/Load model checkpoints
│ ├── common.py # Misc small utilities, quality of life
│ ├── configurator.py # A superior alternative to argparse
│ ├── core_eval.py # Evaluates base model CORE score (DCLM paper)
│ ├── dataloader.py # Tokenizing Distributed Data Loader
│ ├── dataset.py # Download/read utils for pretraining data
@ -155,12 +158,6 @@ python -m pytest tests/test_rustbpe.py -v -s
│ └── ui.html # HTML/CSS/JS for nanochat frontend
├── pyproject.toml
├── run1000.sh # Train the ~$800 nanochat d32
├── rustbpe # Custom Rust BPE tokenizer trainer
│ ├── Cargo.lock
│ ├── Cargo.toml
│ ├── README.md # see for why this even exists
│ └── src
│ └── lib.rs
├── scripts
│ ├── base_eval.py # Base model: calculate CORE score
│ ├── base_loss.py # Base model: calculate bits per byte, sample
@ -185,7 +182,6 @@ python -m pytest tests/test_rustbpe.py -v -s
│ └── spellingbee.py # Task teaching model to spell/count letters
├── tests
│ └── test_engine.py
│ └── test_rustbpe.py
└── uv.lock
```

381
dev/LOG.md Normal file
View File

@ -0,0 +1,381 @@
# Experiment Log
A running summary documenting some experiments and findings. Started ~Jan 7 2026.
---
## 2026-01-13: Varlen Attention (Negative Result)
Attempted to prevent attention from "leaking" across document boundaries using Flash Attention's `flash_attn_varlen_func`, similar to modded-nanogpt's approach.
### Background
With the BOS-aligned dataloader, multiple documents are packed into each row. Standard attention allows tokens to attend across document boundaries within a row. The hypothesis was that preventing this "leakage" via varlen attention might improve training.
### Approach: Compute cu_seqlens from inputs
- Find BOS positions: `(inputs.view(-1) == bos_token_id).nonzero()`
- Gotcha 1: Variable-length `cu_seqlens` caused torch.compile recompilation (25s/iter!) - fixed by padding to fixed size
- Gotcha 2: `nonzero()` inside compiled model hit recompile limit - fixed by moving computation outside compiled region
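For illustration, a minimal sketch of the `cu_seqlens` construction described above, computed outside the compiled region and padded to a fixed size; function and argument names are illustrative, not the exact code from the branch:
```python
import torch

def build_cu_seqlens(inputs: torch.Tensor, bos_token_id: int, max_docs: int) -> torch.Tensor:
    """Cumulative sequence lengths for varlen attention, padded to a static size.

    inputs: (B, T) token ids from the BOS-aligned dataloader, so every row (and hence
    the flattened batch) starts with BOS. max_docs must upper-bound the number of
    documents per batch; the fixed output shape avoids torch.compile recompiles.
    """
    flat = inputs.view(-1)
    total = flat.numel()
    # nonzero() has a data-dependent shape, so keep this outside the compiled region
    starts = (flat == bos_token_id).nonzero(as_tuple=True)[0].to(torch.int32)
    cu = torch.cat([starts, torch.tensor([total], dtype=torch.int32, device=flat.device)])
    # pad by repeating the final offset (zero-length segments) so the shape stays fixed
    pad = cu.new_full((max_docs + 1 - cu.numel(),), total)
    return torch.cat([cu, pad])
```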
### Final Results (d16)
| Metric | Baseline | Varlen |
|--------|----------|--------|
| val_bpb | 0.85427 | 0.85407 |
| MFU | ~same | ~same |
| tok/sec | ~same | ~same |
Essentially identical. The 0.0002 bpb improvement is almost noise.
### Conclusion
Not worth the code complexity. The "leakage" across document boundaries within a row is not harmful - the model handles it fine. The BOS-aligned dataloader already provides the key benefit (every row starts with proper context). Not merging to master.
---
## 2026-01-13: BOS-Aligned Dataloader with Bin Packing
Redesigned the pretraining and midtraining dataloader to ensure every sequence starts with a BOS token, and explored bin-packing algorithms to minimize wasted tokens.
### Problem Statement
The original dataloader streams tokens into a flat buffer and reshapes into batches. This means some rows start mid-document (no BOS), which could confuse the model during training. We want every row to start with BOS and contain well-formed documents.
### Approach 1: Greedy-Crop BOS (Simple)
Each row is built independently:
- Start with a document (which has BOS prepended)
- Pack more documents until row is full
- If a document doesn't fit, **crop it** to fill remaining space (discard the rest)
- 100% utilization (no padding), but wastes cropped tokens
### Waste Analysis
Measured token waste empirically on real data (T=2048):
- **39.4% of tokens are cropped** (discarded when docs don't fit)
- **22.9% is the theoretical minimum** (tokens in docs longer than T+1 that can never fit)
- The extra ~16.5% comes from "unlucky" cropping when a long doc starts near the end of a row
### Bin Packing Algorithms Explored
| Algorithm | Util% | Crop% | Pad% | Notes |
|-----------|-------|-------|------|-------|
| Greedy-Crop (baseline) | 100% | 39.4% | 0% | Simple, no wasted compute |
| Greedy-Pad | 78% | 23.0% | 22% | Pads instead of crops - wastes compute |
| First-Fit Decreasing (FFD) | 99.7% | 23.0% | 0.3% | Near-optimal packing, minimal padding |
| **BestFit-Crop** | 100% | 34.6% | 0% | Smart cropping, no padding |
### BestFit-Crop Algorithm
A middle ground that maintains 100% utilization while reducing cropping:
1. Buffer N documents
2. For each row, greedily pick the **largest doc that fits entirely**
3. Repeat until nothing fits
4. When nothing fits, crop a doc to fill remaining space exactly
This avoids "unlucky" crops by searching the buffer for better-fitting documents.
**Results (T=2048):**
- Crop waste reduced from 39.4% → 34.6% (~12% relative improvement)
- Still achieves 100% utilization (no padding, every token trains)
- Slightly more rows than baseline (uses more documents per batch)
### Decision: Keep Two Implementations
1. **Original streaming loader (kept)**: Very simple and efficient, with 100% token utilization in the batch (no padding with ignore tokens), but it creates slightly more confusing token streams for the LLM because documents during training can start abruptly in the middle with no context. Note that this never happens at test time, where BOS is always present.
2. **`_bos_bestfit` (BestFit-Crop, new default)**: Slightly more complex, still with 100% token utilization in the batch (no padding), but at the cost of discarding parts of documents when they don't fit. In practice, about 34% of tokens are discarded with this approach. This is ok because for most models we care about we have plenty of data without having to go to multiple epochs. One more subtle effect is that it skews the data distribution a tiny bit because, reliably and necessarily, tokens at the tails of long documents get discarded. However, this doesn't seem to impact actual downstream performance.
### Midtraining
The midtraining dataloader was also updated. Because conversations are on average a lot shorter than pretraining documents, only about 3.3% of tokens get cropped.
### NOTE: loss scale
Do note that switching to the BOS dataloader changes the validation loss and makes all previous experiments not comparable in absolute loss value, because we have a lot fewer "confusing" tokens in the train/val batches. All tokens can look back and find the BOS token and have the full context of that document to make predictions. Therefore, the loss appears lower, but this is "fake" to some extent; the expectation is that the vast majority of relative comparisons would come out the same before and after this change.
---
## 2026-01-13: Number Token Split Pattern
Validated the `\p{N}{1,2}` pattern in `SPLIT_PATTERN` (tokenizer.py line 30), which I had only guessed earlier and left a TODO to validate. GPT-4 uses `\p{N}{1,3}` to group number sequences of up to 3 digits into tokens, but we suspected smaller vocab sizes benefit from grouping fewer digits per token.
**Results (d12, vocab=32K):**
| Pattern | val_bpb |
|---------|---------|
| `\p{N}{1,1}` | 0.969 |
| `\p{N}{1,2}` | **0.965** |
| `\p{N}{1,3}` | 0.972 |
**Conclusion:** `{1,2}` is optimal for vocab size 32K. Grouping 3 digits wastes tokens on rare 3-digit combinations; grouping 1 digit is too fine-grained and bloats token sequences. Keeping `{1,2}` as default.
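To see the effect concretely, a small snippet using the `regex` package; only the number alternative of the split pattern is shown here, while the real `SPLIT_PATTERN` also has alternatives for words, punctuation, whitespace, etc.:
```python
import regex  # the `regex` package, which supports \p{N}

for quant in ("{1,1}", "{1,2}", "{1,3}"):
    pieces = regex.findall(r"\p{N}" + quant + r"|\p{L}+|\s+", "year 2026 costs 12345 dollars")
    print(quant, pieces)
# {1,2} turns "12345" into ["12", "34", "5"]: long digit runs become a few short
# chunks, whereas {1,3} produces many rare 3-digit tokens and {1,1} bloats sequences.
```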
---
## 2026-01-13: FP8 Training for lm_head
Attempted to use FP8 (8-bit floating point) for the lm_head layer to speed up the large vocab projection matmul. H100 GPUs have FP8 tensor cores that can theoretically provide ~2x speedup over BF16.
### Implementation Approaches Tried
**1. Dynamic Scaling (failed)**
- Compute `x.abs().max()` and `w.abs().max()` each forward to determine scales
- Problem: `.item()` calls cause graph breaks with torch.compile
- Tried `@torch._dynamo.allow_in_graph` pattern (like torchao.float8) - worked but no speedup
- Tried `torch.library.custom_op` with float scales - caused NaN gradients after first optimizer step
- Root cause: interaction between custom ops, dynamic scale computation, and torch.compile is fragile
**2. Static Scaling (partial success)**
- Pre-set scales at init time like modded-nanogpt: `x_scale=10/448, w_scale=0.1/448`
- `grad_scale` is computed dynamically from the batch size (safe since it's just `1/(B*T)/57344` due to the gradient expression of cross entropy). modded-nanogpt probably has a bug here: they set `grad_scale = 0.75/448`, but the grads are in E5M2, so this should probably be `1/57344` (1 being the amax of any individual element of the cross-entropy loss, with no normalization by B,T because they use sum reduction rather than mean reduction).
- Uses `torch.library.custom_op` with `@torch.compile` on inner kernels
- This works correctly - no NaNs, proper gradients
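A minimal sketch of the static-scaling arithmetic (scale values from above; the implied dynamic-range assumptions and the dequantization round-trip are illustrative only, since a real implementation would feed the FP8 tensors and scales to an FP8 matmul kernel such as `torch._scaled_mm` rather than dequantize):
```python
import torch

# Static scales from above; E4M3 max is 448, E5M2 max is 57344. The implied
# dynamic-range assumptions (|x| < 10, |w| < 0.1) are illustrative.
X_SCALE = 10.0 / 448.0
W_SCALE = 0.1 / 448.0

def quantize_e4m3(t: torch.Tensor, scale: float) -> torch.Tensor:
    """Scale into the E4M3 range and cast; no dynamic amax computation."""
    return (t / scale).clamp(-448.0, 448.0).to(torch.float8_e4m3fn)

def grad_scale(B: int, T: int) -> float:
    """Per-element cross-entropy grads (mean reduction) are bounded by 1/(B*T);
    map that bound onto the E5M2 max of 57344."""
    return 1.0 / (B * T) / 57344.0

# Round-trip check of the scaling arithmetic only.
x, w = torch.randn(4, 768), torch.randn(768, 32768) * 0.02
x_f8, w_f8 = quantize_e4m3(x, X_SCALE), quantize_e4m3(w, W_SCALE)
y_approx = (x_f8.float() * X_SCALE) @ (w_f8.float() * W_SCALE)
print((x @ w - y_approx).abs().max())
```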
### Results (d12)
| Metric | BF16 Baseline | FP8 lm_head |
|--------|---------------|-------------|
| GPU Memory | 34 GB | 36 GB |
| tok/sec | baseline | ~1% faster |
### The Memory Mystery
FP8 *should* save memory since we store `x_f8` (1 byte) instead of `x` (2 bytes) for backward. But we see 2GB *increase*. Suspected causes:
- `torch.compile` on inner kernels creating extra buffers/specializations
- `torch._scaled_mm` internal workspace allocations
- Custom op registration machinery overhead
Tried saving original weight `w` (just a reference to parameter) instead of `w_f8` in backward, then re-quantizing on the spot during backward - didn't help. Still saw bump.
### Microbenchmark vs Reality
Raw microbenchmark showed promise:
- BF16 matmul: 16.95 ms
- FP8 matmul (static scales): 10.31 ms (1.64x faster)
- FP8 with dynamic scaling: 12.25 ms (1.38x faster)
But in full training, the ~1% tok/sec improvement doesn't justify the 2GB memory increase, the added code complexity, and the need to tune scale factors for both x and w.
### Code Artifacts
See the branch `fp8_attempt_fail` for:
- `nanochat/fp8_static.py` - Static scaling implementation (working)
- `nanochat/fp8_dynamic.py` - Dynamic scaling implementation (torchao-style, working but slow)
- `gpt.py` imports `fp8_static.LinearFP8` and simply swaps it in for `lm_head`.
### Open Questions
- Why does the custom op approach use more memory than vanilla BF16?
- Why is the bump in tok/sec so low? We should see a ~1.6X speedup in the forward pass and also (twice) in the backward pass for the gradients. Granted, Amdahl's law is part of the answer because our vocab_size is only 32K, so the final layer isn't a huge part of the profile, but the expected speedup is still not fully realized.
**Conclusion:** Negative result for now. The implementation works correctly but provides marginal speedup with *increased* memory usage. I'm not understanding the torch.compile interaction here. The complexity of FP8 custom ops isn't justified for lm_head alone. TODO to study in more detail the way this is implemented in other libraries, e.g. torchao.
---
## 2026-01-12: Multi-Token Prediction (MTP)
Ported multi-token prediction from modded-nanogpt. Instead of predicting just the next token, predict the next n tokens at each position with weighted loss.
### Implementation
- Instead of calling the loss `n_predict` times, it uses a fancy batched computation via `unfold` + `gather` + the cross-entropy decomposition (`CE = logsumexp - logits[target]`)
- Schedule anneals from 3-token to 1-token prediction:
- 0-33%: `[1.0, 0.5, 0.25→0]` (3rd token fades)
- 33-67%: `[1.0, 0.5→0]` (2nd token fades)
- 67-100%: `[1.0]` (standard next-token)
- Weights normalized to sum to 1
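For concreteness, a minimal sketch of the batched loss, under the assumption that the same logits at position t are scored against the tokens at t+1..t+n (no extra prediction heads); the real implementation may differ in details such as the weight schedule and reduction:
```python
import torch

def mtp_loss(logits: torch.Tensor, targets: torch.Tensor, weights: list[float]) -> torch.Tensor:
    """Weighted multi-token prediction loss from a single set of logits.

    logits: (B, T, V); targets: (B, T), where targets[:, t] is the token at position t+1;
    weights: per-offset weights, e.g. [1.0, 0.5, 0.25], normalized to sum to 1.
    """
    B, T, V = logits.shape
    n = len(weights)
    tgt = targets.unfold(dimension=1, size=n, step=1)   # (B, T-n+1, n): next n targets per position
    lg = logits[:, : T - n + 1, :]                      # positions that have n future targets
    lse = torch.logsumexp(lg, dim=-1, keepdim=True)     # (B, T-n+1, 1)
    picked = lg.gather(-1, tgt)                         # logits of the n future targets
    ce = lse - picked                                   # CE = logsumexp - logits[target]
    w = torch.tensor(weights, device=logits.device, dtype=ce.dtype)
    return (ce * (w / w.sum())).sum(dim=-1).mean()
```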
### Results (d12)
| Metric | Baseline | MTP |
|--------|----------|-----|
| GPU Memory | 34 GB | 47 GB |
| MFU | 41% | 40% |
| val/bpb (per step) | baseline | same/slightly worse |
| val/bpb (wall clock) | baseline | noticeably worse |
**Conclusion:** Negative result for nanochat. The extra memory and compute overhead from predicting multiple tokens doesn't pay off, in fact the results get worse. The auxiliary loss signal may help in other settings (larger models, different architectures?), but for our setup it's pure overhead at the moment.
---
## 2026-01-11: Sliding Window Attention
Added configurable sliding window attention, inspired by GPT-3's alternating short/long pattern.
**Pattern string configuration:**
- New `--window_pattern` CLI arg and `GPTConfig.window_pattern` field
- Pattern is tiled across layers (e.g., `SSSL` for 20 layers → `SSSLSSSLSSSLSSSLSSSL`)
- Final layer always forced to L (full context) regardless of pattern
- Short window = `sequence_len // 2`
- Long window = `sequence_len` (full context)
- All previous models have simply used `L`, and checkpoint loading is modified accordingly to fill in this param for old models, see `_patch_missing_config_keys`
Quick experiments showed `SSSL` (every 4th layer is long) works well - provides a good balance between compute savings and model quality. This is now the default.
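A minimal sketch of the pattern tiling described above; names are illustrative, with the real wiring living behind `GPTConfig.window_pattern`:
```python
from itertools import cycle, islice

def layer_windows(window_pattern: str, n_layer: int, sequence_len: int) -> list[int]:
    """Tile the pattern across layers: 'S' = sequence_len // 2, 'L' = full context.
    The final layer is always forced to 'L'."""
    pattern = list(islice(cycle(window_pattern), n_layer))
    pattern[-1] = "L"
    return [sequence_len if p == "L" else sequence_len // 2 for p in pattern]

print(layer_windows("SSSL", 20, 2048))
# -> 20 entries; every 4th layer (and the last) gets 2048, the rest get 1024
```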
---
## 2026-01-11: Flash Attention 3 Integration
Replaced PyTorch's `scaled_dot_product_attention` (FA2) with Flash Attention 3 for training and inference.
### Changes Made
**1. FA3 via `kernels` package**
- Official FA3 is "beta" and requires building from source (painful)
- Using `kernels` package from HuggingFace Hub: `get_kernel('varunneal/flash-attention-3')`
- Loads pre-built wheels, works out of the box on H100
**2. Simplified attention code**
- FA3 uses `(B, T, H, D)` layout matching our projection output directly - no transpose needed
- Training: `flash_attn.flash_attn_func(q, k, v, causal=True)`
- Inference: `flash_attn.flash_attn_with_kvcache()` handles all cache cases in one call
- Removed 3 separate FA2 code paths (training, single-token, chunk inference)
- GQA handled automatically when n_kv_heads < n_heads
**3. Rewrote KVCache for FA3**
- Old format: `(num_layers, 2, B, H, T, D)` combined tensor
- New format: separate `k_cache` and `v_cache` of shape `(num_layers, B, T, H, D)`
- FA3 updates cache in-place during `flash_attn_with_kvcache`
- Position tracked via `cache_seqlens` tensor (int32, per batch element)
- Simpler API: `get_layer_cache()`, `advance()`, `reset()`, `prefill()`
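A sketch of a single decode step using the in-place cache update, assuming the hub kernel exposes the upstream `flash_attn_with_kvcache` signature; shapes and names are illustrative:
```python
import torch
from kernels import get_kernel  # HuggingFace `kernels` package

flash_attn = get_kernel("varunneal/flash-attention-3")

# One decode step for a single layer; shapes follow the (B, T, H, D) layout above.
B, H, D, T_max = 2, 8, 64, 2048
k_cache = torch.zeros(B, T_max, H, D, device="cuda", dtype=torch.bfloat16)
v_cache = torch.zeros_like(k_cache)
cache_seqlens = torch.zeros(B, dtype=torch.int32, device="cuda")  # tokens already cached

q = torch.randn(B, 1, H, D, device="cuda", dtype=torch.bfloat16)  # current token only
k_new, v_new = torch.randn_like(q), torch.randn_like(q)

# FA3 writes k_new/v_new into the caches in place at position cache_seqlens and
# attends over the valid prefix; we advance cache_seqlens ourselves afterwards.
out = flash_attn.flash_attn_with_kvcache(
    q, k_cache, v_cache, k=k_new, v=v_new,
    cache_seqlens=cache_seqlens, causal=True,
)
cache_seqlens += 1
```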
### Results
- **~9% improvement in tok/sec** during training out of the box
- Benchmarks showed FA3 is 2x faster than FA2 at realistic training sizes (batch=32, seq=2048)
- FA3 supports sliding window via `window_size=(left, 0)`, which is huge and expected to give further improvements. This is ready to tune but keeping full context for now.
---
## 2026-01-11: Per-Layer Residual Scalars (x0 & resid lambdas)
Cherry-picked an idea from modded-nanogpt around learnable per-layer residual connections.
### Changes Made
**1. x0_lambdas (x0 residual connections)**
- Save initial normalized embedding as `x0` after `norm(wte(idx))`
- At each layer, blend x0 back in: `x = resid_lambdas[i] * x + x0_lambdas[i] * x0`
- Zero-initialized, so disabled at start; model learns which layers benefit from the shortcut
- Provides direct path from embedding to deep layers, helps preserve token information
**2. resid_lambdas (residual stream scaling)**
- Per-layer multiplicative scaling of the residual stream
- Initialized to 1.0 (neutral, standard transformer behavior)
- Allows model to learn to amplify/dampen residual at each layer
**3. DistAdamW small parameter handling**
- Added support for parameters with < 1024 elements (like the scalar lambdas)
- Small params use `all_reduce` instead of `reduce_scatter`/`all_gather`
- Fixes crash when param shape isn't divisible by world_size
### Key Finding: Different LR Sensitivity
The two scalar types need very different learning rates:
- **x0_lambdas (additive)**: Can use normal LR (~0.5). Adding a fraction of x0 is forgiving.
- **resid_lambdas (multiplicative)**: Needs ~100x smaller LR (~0.005). Multiplying the residual compounds through layers.
Implementation: `resid_params` gets `scalar_lr * 0.01`, `x0_params` gets full `scalar_lr`.
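A minimal sketch of the forward blend, assuming an RMSNorm-style `norm()` (here `F.rms_norm`); block and embedding names are illustrative:
```python
import torch
import torch.nn.functional as F

def forward_with_lambdas(blocks, wte, idx, resid_lambdas, x0_lambdas):
    """resid_lambdas starts at 1.0 (neutral), x0_lambdas at 0.0 (shortcut disabled)."""
    x = F.rms_norm(wte(idx), (wte.embedding_dim,))
    x0 = x  # the normalized initial embedding, saved once
    for i, block in enumerate(blocks):
        x = resid_lambdas[i] * x + x0_lambdas[i] * x0  # per-layer blend
        x = block(x)
    return x
```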
### Experiment Results
Swept `--scalar_lr` (controlling x0_lambdas) at multiple depths:
| Depth | Baseline (disabled) | Best scalar_lr | Best val_bpb | Δ bpb |
|-------|---------------------|----------------|--------------|-------|
| d8 | 1.0885 | 0.20 | 1.0782 | -0.0103 |
| d12 | 0.9770 | 0.60 | 0.9693 | -0.0077 |
| d16 | 0.9059 | 0.20 | 0.9002 | -0.0057 |
| d20 | 0.8565 | 0.10 | 0.8526 | -0.0039 |
**Observations:**
- Consistent improvement across all model sizes
- Optimal LR varies by depth; default of 0.5 is reasonable, but 0.6 is better for d12
- Adding resid_lambdas (with 0.01x LR) gives small additional improvement over x0 alone
### Meta Device Footgun
Important lesson: `__init__` runs in meta device context, so any tensor values set there are fake. Must initialize actual values in `init_weights()`. Added docstring warning to `__init__`.
### Summary
Added `--scalar_lr` (default 0.5) controlling learnable per-layer scalars. The formula `x = resid_lambdas[i] * x + x0_lambdas[i] * x0` gives the model control over residual scaling and direct shortcuts to the initial embedding. Solid improvement with essentially no compute overhead.
---
## 2026-01-10: Muon Optimizer Upgrades & Cautious Weight Decay
Cherry-picked improvements from NorMuon (modded-nanogpt) into our simpler Muon implementation. Decided against using NorMuon directly due to hard-coded architecture assumptions (expects 32 params split 10 attn + 22 mlp), parameter labeling requirements, and complexity.
### Changes Made
**1. Polar Express Orthogonalization**
- Replaced Newton-Schulz iteration with "Polar Express Sign Method" from [arxiv.org/pdf/2505.16932](https://arxiv.org/pdf/2505.16932)
- Uses 5 different coefficient tuples (one per iteration) instead of fixed coefficients
- Both methods kept in code for easy comparison (`zeropower_via_polar_express` vs `zeropower_via_newtonschulz5`)
- **Result:** No dramatic/noticeable difference in training, but keeping the new Polar Express as default.
**2. Variance Reduction (NorMuon-style)**
- Added low-rank variance estimator similar to Adafactor ([arxiv.org/pdf/2510.05491](https://arxiv.org/pdf/2510.05491))
- Maintains `second_momentum_buffer` with shape `[rows, 1]` or `[1, cols]` (whichever is smaller)
- Normalizes updates based on running per-row/col variance estimate (beta2=0.95)
- Memory overhead: ~1/max(rows, cols) per param, negligible
- **Result:** Led to a very small improvement, kept and enabled by default.
**3. Cautious Weight Decay**
- Only decays weights where `update * weight >= 0` (same sign) from [arxiv.org/abs/2411.16085](https://arxiv.org/abs/2411.16085)
- Standard WD always pulls toward zero; cautious WD skips decay when gradient is pushing weight away from zero
- **Implementation note:** Had to inline the logic rather than use a separate `@torch.compile` function. Passing changing float values (like `weight_decay` during scheduling) as function arguments triggers recompilation. Reading from `group["weight_decay"]` inside the step avoids this.
- **Result:** Solid improvements, especially the cautious version was better than standard wd.
- Now defaults to ON for Muon via the `weight_decay` param. AdamW still has no weight decay (hardcoded to 0); might try to re-tune this later.
**4. Weight decay schedule**
- Added a linear schedule for weight decay, on by default, going from 1.0 to 0.0 (i.e. start with max weight decay at the beginning of training, then ramp to 0 by the end). Worked better than a static setting in experiments. (modded-nanogpt has the same schedule but implements it in a more confusing way by multiplying twice by the learning rate, which is already wired up to a decay schedule.)
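A minimal sketch of cautious, scheduled weight decay as described in items 3 and 4 above; `update` stands for the gradient-like quantity about to be subtracted from the weight, and all names are illustrative (the real optimizer reads `weight_decay` from the param group inside the compiled step, as noted):
```python
import torch

def cautious_weight_decay_(p: torch.Tensor, update: torch.Tensor,
                           lr: float, wd_base: float, frac_done: float) -> None:
    """Decoupled weight decay with a linear 1.0 -> 0.0 schedule, applied only
    where the update is not already pushing the weight away from zero."""
    wd = wd_base * (1.0 - frac_done)     # linear ramp to zero over training
    mask = (update * p) >= 0             # same sign: the step already moves |p| toward zero
    p.add_(p * mask, alpha=-(lr * wd))   # decay only the masked entries
```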
### Weight Decay Scaling Experiments
Swept weight decay values at d8, d12, d16, d20 to find optimal values and scaling law.
**Optimal Values Found:**
| Depth | Width (channels) | Optimal WD |
|-------|------------------|------------|
| d8 | 512 | ~0.40 |
| d12 | 768 | ~0.22 |
| d16 | 1024 | ~0.10 |
| d20 | 1280 | ~0.08 |
**Scaling Law:**
- Fit power law: `WD = k / channels^α` in log-log space
- Found α ≈ 1.97 (approximately 2), meaning WD ∝ 1/width²
**Practical Formula:**
```
WD_target = WD_reference × (d_reference / d_target)²
```
Example: If d12 optimal is 0.22, then d20 optimal ≈ 0.22 × (12/20)² ≈ 0.08
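As a tiny illustrative helper (nanochat uses width = 64 * depth, so the depth ratio is equivalent to the width ratio):
```python
def scale_weight_decay(wd_ref: float, d_ref: int, d_target: int) -> float:
    """WD is proportional to 1/width^2, and width = 64 * depth in nanochat."""
    return wd_ref * (d_ref / d_target) ** 2

print(round(scale_weight_decay(0.22, 12, 20), 3))  # 0.079, matching ~0.08 in the table above
```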
**Reference:** The Moonlight paper uses a fixed WD=0.1 for their 15B MoE model. Our experiments indicated that the optimal WD changes with model size, so we go with the empirical scaling law instead.
### Summary
Muon now uses Polar Express orthogonalization, Adafactor-style variance reduction, and cautious weight decay with a schedule that ramps linearly to zero. All of these changes follow the modded-nanogpt repo, but all of them were also validated piece by piece to yield improvements in nanochat, with the exception of the Polar Express change, which was in the noise. This is on by default and configurable with `--weight_decay`, using simply 0.2 and ∝ 1/width² scaling. The meaning of the `--weight_decay` kwarg therefore changes as of this commit: it used to configure standard weight decay in AdamW, and it is now used exclusively by Muon (AdamW is hardcoded to 0.0), with the value scaled based on depth.
---
## 2026-01-08: exp_grad_clip - Gradient Clipping
**Hypothesis:** Gradient clipping may be unnecessary overhead. Tested L2 norm clipping at various thresholds (0.25, 0.5, 1.0, 2.0) and elementwise clipping.
**Results:**
- No benefit at any scale tested (d12, d20)
- All variants within noise (~0.9827 val_bpb)
- Grad norm never exceeds 1.0 naturally, so clipping is always inactive
- Clipping adds ~2% time overhead from the all-reduce
**Bug Found:** Original implementation clipped local gradients before sync. Since this codebase doesn't use DDP (gradient sync is in the optimizers), each rank was clipping based on its own local norm. Fixed on the branch with proper distributed all-reduce.
**Observation:** modded-nanogpt does not appear to clip either right now.
**Summary:** Deleted all grad-clip code paths. The code naturally produces well-behaved gradients. This also improves MFU a bit because we don't have to calculate and sync grad norms.

2190
dev/estimate_gpt3_core.ipynb Normal file

File diff suppressed because it is too large

View File

@ -24,8 +24,7 @@ prompt:
manually generate any kind of entropy you can think of and include it in your prompts
to maintain healthy and good diversity in the data.
NOTE: You need OpenRouter API key in a file called "openroutertoken.txt" in the root directory of the repo.
(obviously you can tune this arbitrarily to your liking)
NOTE: You need OPENROUTER_API_KEY set in .env or as an environment variable.
NOTE: For more details see this discussion: https://github.com/karpathy/nanochat/discussions/139
"""
import requests
@ -34,10 +33,12 @@ import os
import copy
import random
from concurrent.futures import ThreadPoolExecutor, as_completed
from dotenv import load_dotenv
from nanochat.common import get_base_dir
api_key = open("openroutertoken.txt", "r", encoding="utf-8").read().strip()
load_dotenv()
api_key = os.environ["OPENROUTER_API_KEY"]
url = "https://openrouter.ai/api/v1/chat/completions"
headers = {

View File

@ -19,16 +19,13 @@ source .venv/bin/activate
if [ -z "$WANDB_RUN" ]; then
WANDB_RUN=dummy
fi
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
source "$HOME/.cargo/env"
uv run maturin develop --release --manifest-path rustbpe/Cargo.toml
# wipe the report
python -m nanochat.report reset
# train tokenizer on ~1B characters
python -m nanochat.dataset -n 4
python -m scripts.tok_train --max_chars=1000000000
python -m scripts.tok_train --max-chars=1000000000
python -m scripts.tok_eval
# train a very small 4 layer model on the CPU
@ -36,37 +33,37 @@ python -m scripts.tok_eval
# we only run 50 steps of optimization (bump this to get better results)
python -m scripts.base_train \
--depth=4 \
--max_seq_len=1024 \
--device_batch_size=1 \
--total_batch_size=1024 \
--eval_every=50 \
--eval_tokens=4096 \
--core_metric_every=50 \
--core_metric_max_per_task=12 \
--sample_every=50 \
--num_iterations=50
python -m scripts.base_loss --device_batch_size=1 --split_tokens=4096
--max-seq-len=1024 \
--device-batch-size=1 \
--total-batch-size=1024 \
--eval-every=50 \
--eval-tokens=4096 \
--core-metric-every=50 \
--core-metric-max-per-task=12 \
--sample-every=50 \
--num-iterations=50
python -m scripts.base_loss --device-batch-size=1 --split-tokens=4096
python -m scripts.base_eval --max-per-task=16
# midtraining
python -m scripts.mid_train \
--max_seq_len=1024 \
--device_batch_size=1 \
--eval_every=50 \
--eval_tokens=4096 \
--total_batch_size=1024 \
--num_iterations=100
--max-seq-len=1024 \
--device-batch-size=1 \
--eval-every=50 \
--eval-tokens=4096 \
--total-batch-size=1024 \
--num-iterations=100
# eval results will be terrible, this is just to execute the code paths.
# note that we lower the execution memory limit to 1MB to avoid warnings on smaller systems
python -m scripts.chat_eval --source=mid --max-new-tokens=128 --max-problems=20
# SFT
python -m scripts.chat_sft \
--device_batch_size=1 \
--target_examples_per_step=4 \
--num_iterations=100 \
--eval_steps=4 \
--eval_metrics_max_problems=16
--device-batch-size=1 \
--target-examples-per-step=4 \
--num-iterations=100 \
--eval-steps=4 \
--eval-metrics-max-problems=16
# Chat CLI
# python -m scripts.chat_cli -p "Why is the sky blue?"

227
dev/scaling_analysis.ipynb Normal file
View File

@ -0,0 +1,227 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Scaling Laws Analysis\n",
"\n",
"Analyze results from `scaling_laws.sh` to find the optimal param:data ratio for nanochat."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import pandas as pd\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"\n",
"# Load results\n",
"base_dir = os.environ.get('NANOCHAT_BASE_DIR', os.path.expanduser('~/.cache/nanochat'))\n",
"results_path = os.path.join(base_dir, 'scaling_laws_results', 'results.csv')\n",
"\n",
"df = pd.read_csv(results_path)\n",
"flops_budgets = sorted(df['flops_budget'].unique())\n",
"print(f\"Loaded {len(df)} runs across {len(flops_budgets)} FLOPs budgets\")\n",
"print(f\"Columns: {list(df.columns)}\")\n",
"df"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## IsoFLOP Curves (à la Chinchilla)\n",
"\n",
"For each compute budget, plot loss vs model size. Looking for the U-shape valley that reveals the optimal model size for each FLOPs budget."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"fig, axes = plt.subplots(1, 3, figsize=(16, 5))\n",
"\n",
"# Plot 1: IsoFLOP curves - Val BPB vs Parameters (the Chinchilla plot!)\n",
"ax = axes[0]\n",
"colors = plt.cm.viridis(np.linspace(0, 0.9, len(flops_budgets)))\n",
"optimal_by_bpb = []\n",
"\n",
"for flops, color in zip(flops_budgets, colors):\n",
" subset = df[df['flops_budget'] == flops].sort_values('num_scaling_params')\n",
" ax.plot(subset['num_scaling_params'], subset['val_bpb'], 'o', color=color, label=f'{flops:.0e}', markersize=8)\n",
"\n",
" # Fit quadratic in log-space: val_bpb = a*(log N)^2 + b*(log N) + c\n",
" log_params = np.log10(subset['num_scaling_params'])\n",
" coeffs = np.polyfit(log_params, subset['val_bpb'], 2)\n",
" a, b, c = coeffs\n",
"\n",
" # Plot fitted curve (dashed)\n",
" log_fit_x = np.linspace(log_params.min() - 0.1, log_params.max() + 0.1, 100)\n",
" fit_y = a * log_fit_x**2 + b * log_fit_x + c\n",
" ax.plot(10**log_fit_x, fit_y, '--', color=color, linewidth=2)\n",
"\n",
" # Find minimum of quadratic: d/dx(ax^2 + bx + c) = 0 => x = -b/(2a)\n",
" if a > 0: # parabola opens upward (has a minimum)\n",
" log_opt = -b / (2 * a)\n",
" opt_params = 10**log_opt\n",
" opt_bpb = a * log_opt**2 + b * log_opt + c\n",
" # Mark the fitted optimal\n",
" ax.scatter([opt_params], [opt_bpb], s=150, color=color,\n",
" zorder=5, edgecolors='black', linewidths=2, marker='*')\n",
" # Interpolate tokens and ratio from actual data (don't use C≈6ND approximation)\n",
" opt_tokens = np.interp(np.log10(opt_params), log_params, subset['tokens_trained'])\n",
" opt_ratio = np.interp(np.log10(opt_params), log_params, subset['param_data_ratio'])\n",
" optimal_by_bpb.append({'flops': flops, 'params': opt_params, 'tokens': opt_tokens, 'ratio': opt_ratio, 'bpb': opt_bpb})\n",
" else:\n",
" # Fallback to raw minimum if quadratic doesn't have minimum\n",
" best_idx = subset['val_bpb'].idxmin()\n",
" best = subset.loc[best_idx]\n",
" ax.scatter([best['num_scaling_params']], [best['val_bpb']], s=150, color=color,\n",
" zorder=5, edgecolors='black', linewidths=2)\n",
" optimal_by_bpb.append({'flops': flops, 'params': best['num_scaling_params'],\n",
" 'tokens': best['tokens_trained'], 'ratio': best['param_data_ratio'], 'bpb': best['val_bpb']})\n",
"\n",
"ax.set_xscale('log')\n",
"ax.set_xlabel('Parameters')\n",
"ax.set_ylabel('Validation Loss (bpb)')\n",
"ax.set_title('IsoFLOP Curves')\n",
"ax.legend(title='FLOPs', loc='upper right')\n",
"ax.grid(True, alpha=0.3)\n",
"\n",
"opt_df = pd.DataFrame(optimal_by_bpb)\n",
"\n",
"# Plot 2: Optimal model size vs compute (power law)\n",
"ax = axes[1]\n",
"ax.loglog(opt_df['flops'], opt_df['params'], 'o', markersize=10, color='#2ecc71')\n",
"ax.set_xlabel('FLOPs')\n",
"ax.set_ylabel('Optimal Parameters')\n",
"ax.set_title('Optimal Model Size')\n",
"ax.grid(True, alpha=0.3)\n",
"\n",
"# Fit and show power law\n",
"if len(opt_df) >= 2:\n",
" log_f = np.log10(opt_df['flops'])\n",
" log_p = np.log10(opt_df['params'])\n",
" slope, intercept = np.polyfit(log_f, log_p, 1)\n",
" fit_f = np.logspace(log_f.min() - 0.5, log_f.max() + 0.5, 100)\n",
" fit_p = 10**(intercept + slope * np.log10(fit_f))\n",
" ax.plot(fit_f, fit_p, 'r--', alpha=0.7, label=f'N ∝ C^{slope:.2f}')\n",
" ax.legend()\n",
"\n",
"# Plot 3: Optimal tokens vs compute (power law)\n",
"ax = axes[2]\n",
"ax.loglog(opt_df['flops'], opt_df['tokens'], 'o', markersize=10, color='#e74c3c')\n",
"ax.set_xlabel('FLOPs')\n",
"ax.set_ylabel('Optimal Tokens')\n",
"ax.set_title('Optimal Training Tokens')\n",
"ax.grid(True, alpha=0.3)\n",
"\n",
"# Fit and show power law\n",
"if len(opt_df) >= 2:\n",
" log_f = np.log10(opt_df['flops'])\n",
" log_t = np.log10(opt_df['tokens'])\n",
" slope, intercept = np.polyfit(log_f, log_t, 1)\n",
" fit_f = np.logspace(log_f.min() - 0.5, log_f.max() + 0.5, 100)\n",
" fit_t = 10**(intercept + slope * np.log10(fit_f))\n",
" ax.plot(fit_f, fit_t, 'r--', alpha=0.7, label=f'D ∝ C^{slope:.2f}')\n",
" ax.legend()\n",
"\n",
"plt.tight_layout()\n",
"plt.show()\n",
"\n",
"# Print the optimal points (from quadratic fits)\n",
"print(\"\\nOptimal configurations (from quadratic fits):\")\n",
"print(f\"{'FLOPs':<12} {'Params':<15} {'Tokens':<15} {'Ratio':<10} {'Val BPB':<10}\")\n",
"print(\"-\" * 65)\n",
"for _, row in opt_df.iterrows():\n",
" print(f\"{row['flops']:<12.0e} {int(row['params']):<15,} {int(row['tokens']):<15,} {row['ratio']:<10.1f} {row['bpb']:<10.4f}\")\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Val BPB vs Depth and Ratio"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"fig, axes = plt.subplots(1, 2, figsize=(14, 5))\n",
"\n",
"# Plot 1: Val BPB vs Depth\n",
"ax = axes[0]\n",
"for flops in flops_budgets:\n",
" subset = df[df['flops_budget'] == flops].sort_values('depth')\n",
" ax.plot(subset['depth'], subset['val_bpb'], 'o-', label=f'{flops:.0e}')\n",
" # Mark the best (lowest)\n",
" best_idx = subset['val_bpb'].idxmin()\n",
" best = subset.loc[best_idx]\n",
" ax.scatter([best['depth']], [best['val_bpb']], s=100, zorder=5, edgecolors='black', linewidths=2)\n",
"\n",
"ax.set_xlabel('Depth')\n",
"ax.set_ylabel('Val BPB (lower is better)')\n",
"ax.set_title('Validation BPB vs Model Depth')\n",
"ax.legend(title='FLOPs')\n",
"ax.grid(True, alpha=0.3)\n",
"\n",
"# Plot 2: Val BPB vs Param:Data Ratio\n",
"ax = axes[1]\n",
"for flops in flops_budgets:\n",
" subset = df[df['flops_budget'] == flops].sort_values('param_data_ratio')\n",
" ax.plot(subset['param_data_ratio'], subset['val_bpb'], 'o-', label=f'{flops:.0e}')\n",
" best_idx = subset['val_bpb'].idxmin()\n",
" best = subset.loc[best_idx]\n",
" ax.scatter([best['param_data_ratio']], [best['val_bpb']], s=100, zorder=5, edgecolors='black', linewidths=2)\n",
"\n",
"ax.axvline(x=20, color='red', linestyle='--', alpha=0.5, label='Chinchilla (20)')\n",
"ax.set_xlabel('Param:Data Ratio (tokens/param)')\n",
"ax.set_ylabel('Val BPB (lower is better)')\n",
"ax.set_title('Val BPB vs Param:Data Ratio')\n",
"ax.legend(title='FLOPs')\n",
"ax.grid(True, alpha=0.3)\n",
"\n",
"plt.tight_layout()\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 4
}

103
miniseries.sh Normal file
View File

@ -0,0 +1,103 @@
#!/bin/bash
# See speedrun.sh for more comments
# Usage: ./miniseries.sh [series_name]
# Example: ./miniseries.sh jan11
# Default series name is today's date (e.g., jan11)
export OMP_NUM_THREADS=1
export NANOCHAT_BASE_DIR="$HOME/.cache/nanochat"
mkdir -p $NANOCHAT_BASE_DIR
# Setup (skip with SKIP_SETUP=1)
if [ -z "$SKIP_SETUP" ]; then
# uv
command -v uv &> /dev/null || curl -LsSf https://astral.sh/uv/install.sh | sh
[ -d ".venv" ] || uv venv
uv sync --extra gpu
source .venv/bin/activate
# Tokenizer, download 1000 shards for pretraining
# (probably this can be reduced but it's tricky to determine the exact right number, TODO).
python -m nanochat.dataset -n 1000
python -m scripts.tok_train --max-chars=2000000000 --vocab-size=32768
else
source .venv/bin/activate
fi
# Series name: from arg, env var, or default to today's date (e.g., jan11)
SERIES_NAME="${1:-${SERIES_NAME:-$(date +%b%d | tr '[:upper:]' '[:lower:]')}}"
# Depths to train (the "miniseries")
DEPTHS=(10 11 12 13 14 15 16 17 18 19 20)
# Hardware
NPROC_PER_NODE="${NPROC_PER_NODE:-8}"
# Logging
WANDB_RUN="${WANDB_RUN:-${SERIES_NAME}_miniseries}"
RESULTS_DIR="$NANOCHAT_BASE_DIR/${SERIES_NAME}_miniseries_results"
mkdir -p "$RESULTS_DIR"
RESULTS_FILE="$RESULTS_DIR/results.csv"
# Write CSV header only if file doesn't exist
if [ ! -f "$RESULTS_FILE" ]; then
echo "depth,model_dim,num_params,num_scaling_params,num_iterations,tokens_trained,param_data_ratio,val_bpb,core_score,train_time_sec" > "$RESULTS_FILE"
fi
log() {
echo "[$(date '+%Y-%m-%d %H:%M:%S')] $1"
}
log "=============================================="
log "${SERIES_NAME} Miniseries Training"
log "=============================================="
for d in "${DEPTHS[@]}"; do
log "Training d=$d..."
TAG="${SERIES_NAME}_miniseries_d${d}"
START_TIME=$(date +%s)
# Train the model with natural horizon (target_param_data_ratio default)
# No --target-flops, let it use the default ratio from base_train
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train -- \
--depth=$d \
--target-param-data-ratio=8 \
--run="${WANDB_RUN}_d${d}" \
--model-tag="${TAG}" \
--core-metric-every=999999 \
--core-metric-max-per-task=-1 \
--sample-every=-1 \
--save-every=-1 \
2>&1 | tee "$RESULTS_DIR/${TAG}_train.log"
END_TIME=$(date +%s)
TRAIN_TIME=$((END_TIME - START_TIME))
# Extract stats from log
LOG_FILE="$RESULTS_DIR/${TAG}_train.log"
NUM_PARAMS=$(grep "Number of parameters:" "$LOG_FILE" | tail -1 | grep -oP '[\d,]+' | head -1 | tr -d ',')
NUM_SCALING_PARAMS=$(grep "Number of parameters:" "$LOG_FILE" | tail -1 | grep -oP 'scaling: [\d,]+' | grep -oP '[\d,]+' | tr -d ',')
NUM_ITERS=$(grep "Calculated number of iterations" "$LOG_FILE" | tail -1 | sed 's/.*: //' | tr -d ',')
TOKENS_TRAINED=$((NUM_ITERS * 524288))
PARAM_DATA_RATIO=$(python -c "print(f'{$TOKENS_TRAINED / $NUM_SCALING_PARAMS:.2f}')")
MODEL_DIM=$((d * 64))
VAL_BPB=$(grep "Validation bpb:" "$LOG_FILE" | tail -1 | grep -oP '[\d.]+$')
CORE_SCORE=$(grep "CORE metric:" "$LOG_FILE" | tail -1 | awk '{print $NF}')
if [ -z "$CORE_SCORE" ]; then
CORE_SCORE="0.0"
fi
log " d=$d: params=$NUM_PARAMS, scaling=$NUM_SCALING_PARAMS, ratio=$PARAM_DATA_RATIO, bpb=$VAL_BPB, CORE=$CORE_SCORE, time=${TRAIN_TIME}s"
# Append to CSV
echo "$d,$MODEL_DIM,$NUM_PARAMS,$NUM_SCALING_PARAMS,$NUM_ITERS,$TOKENS_TRAINED,$PARAM_DATA_RATIO,$VAL_BPB,$CORE_SCORE,$TRAIN_TIME" >> "$RESULTS_FILE"
done
log "=============================================="
log "${SERIES_NAME} Miniseries Complete!"
log "=============================================="
log "Results saved to: $RESULTS_FILE"
echo ""
echo "Results:"
column -t -s',' "$RESULTS_FILE"

View File

@ -16,23 +16,31 @@ class DistAdamW(torch.optim.Optimizer):
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
super().__init__(param_groups, defaults)
@torch.compile
@torch.no_grad()
def step(self):
rank = dist.get_rank()
world_size = dist.get_world_size()
reduce_scatter_futures: list[torch.Future] = []
all_reduce_futures: list[torch.Future] = []
reduce_futures: list[torch.Future] = []
gather_futures: list[torch.Future] = []
grad_slices = []
is_small = [] # track which params are small (use all_reduce) vs large (use reduce_scatter)
for group in self.param_groups:
params: list[Tensor] = group["params"]
for base_i in range(len(params)):
assert params[base_i].shape[0] % world_size == 0, f"First dim of parameter shape {params[base_i].shape} must be divisible by world size {world_size}"
grad = params[base_i].grad
rank_size = grad.shape[0] // world_size
grad_slice = torch.empty_like(grad[:rank_size])
reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future())
grad_slices.append(grad_slice)
for p in params:
grad = p.grad
# Small params: use all_reduce (no scatter/gather needed)
if p.numel() < 1024:
is_small.append(True)
reduce_futures.append(dist.all_reduce(grad, op=dist.ReduceOp.AVG, async_op=True).get_future())
grad_slices.append(grad)
else:
is_small.append(False)
assert p.shape[0] % world_size == 0, f"First dim of parameter shape {p.shape} must be divisible by world size {world_size}"
rank_size = grad.shape[0] // world_size
grad_slice = torch.empty_like(grad[:rank_size])
reduce_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future())
grad_slices.append(grad_slice)
idx = 0
for group in self.param_groups:
@ -40,14 +48,19 @@ class DistAdamW(torch.optim.Optimizer):
eps = group['eps']
wd = group['weight_decay']
params = group['params']
for base in range(len(params)):
reduce_scatter_futures[idx].wait()
p = params[base]
rank_size = p.shape[0] // world_size
p_slice = p[rank * rank_size:(rank + 1) * rank_size]
for p in params:
reduce_futures[idx].wait()
g_slice = grad_slices[idx]
lr = group['lr'] * getattr(p, "lr_mul", 1.0)
state = self.state[p]
g_slice = grad_slices[idx]
# For small params, operate on full param; for large, operate on slice
if is_small[idx]:
p_slice = p
else:
rank_size = p.shape[0] // world_size
p_slice = p[rank * rank_size:(rank + 1) * rank_size]
# State init
if not state:
state['step'] = torch.tensor(0, dtype=torch.int64, device=p.device)
@ -68,10 +81,15 @@ class DistAdamW(torch.optim.Optimizer):
bias1 = 1 - beta1 ** t
bias2 = 1 - beta2 ** t
# compute step
denom = exp_avg_sq.sqrt().add_(eps)
step_size = lr * (torch.sqrt(bias2) / bias1)
denom = (exp_avg_sq / bias2).sqrt().add_(eps)
step_size = lr / bias1
update = exp_avg.div(denom).mul_(step_size)
p_slice.add_(other=update, alpha=-1.0)
# Only large params need all_gather
if not is_small[idx]:
gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future())
idx += 1
all_reduce_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future())
torch.futures.collect_all(all_reduce_futures).wait()
if gather_futures:
torch.futures.collect_all(gather_futures).wait()

View File

@ -20,6 +20,25 @@ def log0(message):
if int(os.environ.get('RANK', 0)) == 0:
logger.info(message)
def _patch_missing_config_keys(model_config_kwargs):
"""Add default values for new config keys missing in old checkpoints."""
# Old models were trained with full context (no sliding window)
if "window_pattern" not in model_config_kwargs:
model_config_kwargs["window_pattern"] = "L"
log0(f"Patching missing window_pattern in model config to 'L'")
def _patch_missing_keys(model_data, model_config):
"""Add default values for new parameters that may be missing in old checkpoints."""
n_layer = model_config.n_layer
# resid_lambdas defaults to 1.0 (identity scaling)
if "resid_lambdas" not in model_data:
model_data["resid_lambdas"] = torch.ones(n_layer)
log0(f"Patching missing resid_lambdas in model data to 1.0")
# x0_lambdas defaults to 0.0 (disabled)
if "x0_lambdas" not in model_data:
model_data["x0_lambdas"] = torch.zeros(n_layer)
log0(f"Patching missing x0_lambdas in model data to 0.0")
def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data, rank=0):
if rank == 0:
os.makedirs(checkpoint_dir, exist_ok=True)
@ -74,8 +93,10 @@ def build_model(checkpoint_dir, step, device, phase):
# Hack: fix torch compile issue, which prepends all keys with _orig_mod.
model_data = {k.removeprefix("_orig_mod."): v for k, v in model_data.items()}
model_config_kwargs = meta_data["model_config"]
_patch_missing_config_keys(model_config_kwargs)
log0(f"Building model with config: {model_config_kwargs}")
model_config = GPTConfig(**model_config_kwargs)
_patch_missing_keys(model_data, model_config)
with torch.device("meta"):
model = GPT(model_config)
# Load the model state
@ -90,7 +111,7 @@ def build_model(checkpoint_dir, step, device, phase):
# Load the Tokenizer
tokenizer = get_tokenizer()
# Sanity check: compatibility between model and tokenizer
assert tokenizer.get_vocab_size() == model_config_kwargs["vocab_size"]
assert tokenizer.get_vocab_size() == model_config_kwargs["vocab_size"], f"Tokenizer vocab size {tokenizer.get_vocab_size()} does not match model config vocab size {model_config_kwargs['vocab_size']}"
return model, tokenizer, meta_data

View File

@ -1,56 +0,0 @@
"""
Poor Man's Configurator. Probably a terrible idea. Example usage:
$ python train.py config/override_file.py --batch_size=32
this will first run config/override_file.py, then override batch_size to 32
The code in this file will be run as follows from e.g. train.py:
>>> exec(open('configurator.py').read())
So it's not a Python module, it's just shuttling this code away from train.py
The code in this script then overrides the globals()
I know people are not going to love this, I just really dislike configuration
complexity and having to prepend config. to every single variable. If someone
comes up with a better simple Python solution I am all ears.
"""
import os
import sys
from ast import literal_eval
def print0(s="",**kwargs):
ddp_rank = int(os.environ.get('RANK', 0))
if ddp_rank == 0:
print(s, **kwargs)
for arg in sys.argv[1:]:
if '=' not in arg:
# assume it's the name of a config file
assert not arg.startswith('--')
config_file = arg
print0(f"Overriding config with {config_file}:")
with open(config_file) as f:
print0(f.read())
exec(open(config_file).read())
else:
# assume it's a --key=value argument
assert arg.startswith('--')
key, val = arg.split('=')
key = key[2:]
if key in globals():
try:
# attempt to eval it it (e.g. if bool, number, or etc)
attempt = literal_eval(val)
except (SyntaxError, ValueError):
# if that goes wrong, just use the string
attempt = val
# ensure the types match ok
if globals()[key] is not None:
attempt_type = type(attempt)
default_type = type(globals()[key])
assert attempt_type == default_type, f"Type mismatch: {attempt_type} != {default_type}"
# cross fingers
print0(f"Overriding: {key} = {attempt}")
globals()[key] = attempt
else:
raise ValueError(f"Unknown config key: {key}")

View File

@ -1,94 +1,198 @@
from collections import deque
"""
Distributed dataloaders for pretraining.
Two implementations are provided:
1. Original (tokenizing_distributed_data_loader):
- Streams tokens into a flat buffer, reshapes to (B, T)
- Rows may start mid-document (no guaranteed BOS at position 0)
- 100% token utilization, simple and efficient
2. BOS-aligned bestfit (tokenizing_distributed_data_loader_bos_bestfit):
- Every row starts with BOS token
- Documents packed using best-fit algorithm to minimize cropping
- When no document fits remaining space, crops a document to fill exactly
- 100% utilization (no padding), ~35% tokens cropped at T=2048
The tradeoff: BOS-aligned loses ~35% of tokens to cropping, but ensures that
there are fewer "confusing" tokens in the train/val batches, as every token can
now attend back to the BOS token and see the full context of its document.
(2) is the new default if you have enough data.
Fall back to (1) if you have very limited data AND long documents.
"""
import torch
import pyarrow.parquet as pq
from nanochat.common import get_dist_info
from nanochat.dataset import list_parquet_files
from nanochat.tokenizer import get_tokenizer
def tokenizing_distributed_data_loader_with_state(B, T, split, tokenizer_threads=4, tokenizer_batch_size=128, device="cuda", resume_state_dict=None):
def _document_batches(split, resume_state_dict, tokenizer_batch_size):
"""
Infinite iterator over document batches (list of text strings) from parquet files.
Handles DDP sharding and approximate resume. Each yield is (text_batch, (pq_idx, rg_idx, epoch))
where text_batch is a list of document strings, indices track position for resumption,
and epoch counts how many times we've cycled through the dataset (starts at 1).
"""
ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
parquet_paths = list_parquet_files()
assert len(parquet_paths) != 0, "No dataset parquet files found, did you run dataset.py?"
parquet_paths = parquet_paths[:-1] if split == "train" else parquet_paths[-1:]
resume_pq_idx = resume_state_dict["pq_idx"] if resume_state_dict is not None else 0
resume_rg_idx = resume_state_dict["rg_idx"] if resume_state_dict is not None else None
resume_epoch = resume_state_dict.get("epoch", 1) if resume_state_dict is not None else 1
first_pass = True
pq_idx = resume_pq_idx
epoch = resume_epoch
while True: # iterate infinitely (multi-epoch)
pq_idx = resume_pq_idx if first_pass else 0
while pq_idx < len(parquet_paths):
filepath = parquet_paths[pq_idx]
pf = pq.ParquetFile(filepath)
# Start from resume point if resuming on same file, otherwise from DDP rank
if first_pass and (resume_rg_idx is not None) and (pq_idx == resume_pq_idx):
base_idx = resume_rg_idx // ddp_world_size
base_idx += 1 # advance by 1 so we don't repeat data after resuming
rg_idx = base_idx * ddp_world_size + ddp_rank
if rg_idx >= pf.num_row_groups:
pq_idx += 1
continue
resume_rg_idx = None # only do this once
else:
rg_idx = ddp_rank
while rg_idx < pf.num_row_groups:
rg = pf.read_row_group(rg_idx)
batch = rg.column('text').to_pylist()
for i in range(0, len(batch), tokenizer_batch_size):
yield batch[i:i+tokenizer_batch_size], (pq_idx, rg_idx, epoch)
rg_idx += ddp_world_size
pq_idx += 1
first_pass = False
epoch += 1
def tokenizing_distributed_data_loader_with_state(tokenizer, B, T, split, tokenizer_threads=4, tokenizer_batch_size=128, device="cuda", resume_state_dict=None):
"""
Stream pretraining text from parquet files, tokenize, yield training batches.
This implementation became a bit more complex because we wish to support approximate resume training.
Instead of turning this into a Class, we opt to return the state_dict with every batch,
and then the caller can pass in a state_dict to resume training from a desired point.
Note that this resumption is atm only *approximate* for simplicity.
We won't repeat the same documents but we might skip a few.
The state_dict that is returned can be later passed into this function via `resume_state_dict` to approximately resume.
This is the original dataloader that streams tokens into a flat buffer and reshapes.
Rows may start mid-document (no guaranteed BOS at position 0).
Perfect state resumption is possible but would be a lot more bloated, probably not worth it atm.
Supports approximate resume via state_dict.
"""
assert split in ["train", "val"], "split must be 'train' or 'val'"
# infinite iterator over document batches (list of text strings)
ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
def document_batches():
parquet_paths = list_parquet_files()
assert len(parquet_paths) != 0, "No dataset parquet files found, did you run dataset.py?"
parquet_paths = parquet_paths[:-1] if split == "train" else parquet_paths[-1:]
resume_pq_idx = resume_state_dict["pq_idx"] if resume_state_dict is not None else 0
resume_rg_idx = resume_state_dict["rg_idx"] if resume_state_dict is not None else None
first_pass = True
pq_idx = resume_pq_idx # we kick off parquet files at the resume index (or by default just 0)
while True: # iterate infinitely (multi-epoch)
pq_idx = resume_pq_idx if first_pass else 0
while pq_idx < len(parquet_paths): # iterate over all parquet files
filepath = parquet_paths[pq_idx]
pf = pq.ParquetFile(filepath)
# Start from resume point if resuming on same file, otherwise from DDP rank
# I know this state resumption is a little bit tricky and a little bit hacky... sigh.
if first_pass and (resume_rg_idx is not None) and (pq_idx == resume_pq_idx):
base_idx = resume_rg_idx // ddp_world_size # in units of ddp_world_size
base_idx += 1 # advance by 1 so that we definitely don't repeat data after resuming
rg_idx = base_idx * ddp_world_size + ddp_rank
if rg_idx >= pf.num_row_groups:
pq_idx += 1
continue
resume_rg_idx = None # set to None as we only want to do this a single time
else:
rg_idx = ddp_rank
while rg_idx < pf.num_row_groups:
rg = pf.read_row_group(rg_idx)
batch = rg.column('text').to_pylist() # each batch is a parquet group, e.g. 1024 rows
# the tokenizer encode might want to go in even smaller batches, e.g. 128 rows
for i in range(0, len(batch), tokenizer_batch_size):
yield batch[i:i+tokenizer_batch_size], (pq_idx, rg_idx)
rg_idx += ddp_world_size # advance to the next row group (in DDP)
pq_idx += 1 # advance to the next parquet file
first_pass = False
batches = document_batches()
# Now emit batches of tokens.
needed_tokens = B * T + 1 # +1 is because we also need the target at the last token
# get the tokenizer and the bos token
tokenizer = get_tokenizer()
batches = _document_batches(split, resume_state_dict, tokenizer_batch_size)
needed_tokens = B * T + 1 # +1 for target at last position
bos_token = tokenizer.get_bos_token_id()
# scratch buffer holds the tokens for one iteration
token_buffer = deque() # we stream tokens on the right and pop from the left
token_buffer = []
pq_idx, rg_idx, epoch = 0, 0, 1
while True:
# Accumulate enough tokens for one iteration before yielding.
# Accumulate enough tokens
while len(token_buffer) < needed_tokens:
doc_batch, (pq_idx, rg_idx) = next(batches)
doc_batch, (pq_idx, rg_idx, epoch) = next(batches)
token_lists = tokenizer.encode(doc_batch, prepend=bos_token, num_threads=tokenizer_threads)
for tokens in token_lists:
token_buffer.extend(tokens)
# Move tokens from the deque into the scratch buffer
tokens = [token_buffer.popleft() for _ in range(needed_tokens)]
# CUDA supports memory pinning for asynchronous transfers between CPU and GPU
use_cuda_optimizations = device == "cuda"
scratch = torch.tensor(tokens, dtype=torch.long, pin_memory=use_cuda_optimizations) # in PyTorch, long=int64
# Create the inputs/targets as 1D tensors
inputs_cpu = scratch[:-1]
targets_cpu = scratch[1:]
# Reshape to 2D and move to GPU async
inputs = inputs_cpu.view(B, T).to(device=device, non_blocking=use_cuda_optimizations)
targets = targets_cpu.view(B, T).to(device=device, non_blocking=use_cuda_optimizations)
state_dict = {"pq_idx": pq_idx, "rg_idx": rg_idx} # we need this in case we wish to approximately resume training
yield inputs, targets, state_dict
tokens = token_buffer[:needed_tokens] # Read B*T+1 tokens (+1 is only for the target for the last token)
token_buffer = token_buffer[B*T:] # Advance by B*T tokens, so we move exactly one window of B*T tokens over
# Package tokens into inputs and targets, yield
use_cuda = device == "cuda"
scratch = torch.tensor(tokens, dtype=torch.long, pin_memory=use_cuda)
inputs = scratch[:-1].view(B, T).to(device=device, non_blocking=use_cuda)
targets = scratch[1:].view(B, T).to(device=device, non_blocking=use_cuda)
yield inputs, targets, {"pq_idx": pq_idx, "rg_idx": rg_idx, "epoch": epoch}
def tokenizing_distributed_data_loader(*args, **kwargs):
# helper function that only emits the inputs/targets and not the state_dict
"""Helper that omits state_dict from yields."""
for inputs, targets, state_dict in tokenizing_distributed_data_loader_with_state(*args, **kwargs):
yield inputs, targets
def tokenizing_distributed_data_loader_with_state_bos_bestfit(
tokenizer, B, T, split,
tokenizer_threads=4, tokenizer_batch_size=128,
device="cuda", resume_state_dict=None,
buffer_size=1000
):
"""
BOS-aligned dataloader with Best-Fit Cropping.
Reduces token waste compared to simple greedy cropping by searching a buffer
for documents that fit well, while maintaining 100% utilization (no padding).
Algorithm for each row:
1. From buffered docs, pick the LARGEST doc that fits entirely
2. Repeat until no doc fits
3. When nothing fits, crop a doc to fill remaining space exactly
Key properties:
- Every row starts with BOS
- 100% utilization (no padding, every token is trained on)
- Approximately 35% of all tokens are discarded due to cropping
"""
assert split in ["train", "val"], "split must be 'train' or 'val'"
row_capacity = T + 1
batches = _document_batches(split, resume_state_dict, tokenizer_batch_size)
bos_token = tokenizer.get_bos_token_id()
doc_buffer = []
pq_idx, rg_idx, epoch = 0, 0, 1
def refill_buffer():
nonlocal pq_idx, rg_idx, epoch
doc_batch, (pq_idx, rg_idx, epoch) = next(batches)
token_lists = tokenizer.encode(doc_batch, prepend=bos_token, num_threads=tokenizer_threads)
for tokens in token_lists:
doc_buffer.append(tokens)
while True:
rows = []
for _ in range(B):
row = []
while len(row) < row_capacity:
# Ensure buffer has documents
while len(doc_buffer) < buffer_size:
refill_buffer()
remaining = row_capacity - len(row)
# Find largest doc that fits entirely
best_idx = -1
best_len = 0
for i, doc in enumerate(doc_buffer):
doc_len = len(doc)
if doc_len <= remaining and doc_len > best_len:
best_idx = i
best_len = doc_len
if best_idx >= 0:
doc = doc_buffer.pop(best_idx)
row.extend(doc)
else:
# No doc fits - crop first doc to fill remaining
doc = doc_buffer.pop(0)
row.extend(doc[:remaining])
rows.append(row[:row_capacity])
use_cuda = device == "cuda"
batch_tensor = torch.tensor(rows, dtype=torch.long, pin_memory=use_cuda)
inputs = batch_tensor[:, :-1].to(device=device, non_blocking=use_cuda)
targets = batch_tensor[:, 1:].to(device=device, non_blocking=use_cuda)
yield inputs, targets, {"pq_idx": pq_idx, "rg_idx": rg_idx, "epoch": epoch}
def tokenizing_distributed_data_loader_bos_bestfit(*args, **kwargs):
"""Helper that omits state_dict from yields."""
for inputs, targets, state_dict in tokenizing_distributed_data_loader_with_state_bos_bestfit(*args, **kwargs):
yield inputs, targets
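# A minimal usage sketch of the best-fit loader above (illustrative only: the batch
# dimensions are made up, and get_tokenizer() is assumed to be importable here).
# tokenizer = get_tokenizer()
# loader = tokenizing_distributed_data_loader_with_state_bos_bestfit(
#     tokenizer, B=8, T=1024, split="train", device="cuda",
# )
# inputs, targets, state = next(loader)
# inputs, targets: (8, 1024) int64 tensors; every row starts with BOS and is fully packed.
# state = {"pq_idx": ..., "rg_idx": ..., "epoch": ...} can be checkpointed and passed back
# later as resume_state_dict to approximately resume the data stream.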

View File

@ -82,83 +82,54 @@ def use_calculator(expr):
# -----------------------------------------------------------------------------
class KVCache:
"""
Works hand-in-hand with the GPT model to maintain the KV cache.
Note that the .pos advances automatically after the last layer of the Transformer inserts.
KV Cache designed for Flash Attention 3's flash_attn_with_kvcache API.
Key differences from FA2-style cache:
- Tensors are (B, T, H, D) not (B, H, T, D)
- FA3 updates the cache in-place during flash_attn_with_kvcache
- Position tracked per batch element via cache_seqlens tensor
"""
def __init__(self, batch_size, num_heads, seq_len, head_dim, num_layers):
# Each of K/V is of shape (B, H, T, D) and we have one per layer of the Transformer.
self.kv_shape = (num_layers, 2, batch_size, num_heads, seq_len, head_dim)
self.kv_cache = None
self.pos = 0 # current position in time in the cache
def __init__(self, batch_size, num_heads, seq_len, head_dim, num_layers, device, dtype=torch.bfloat16):
self.batch_size = batch_size
self.max_seq_len = seq_len
self.n_layers = num_layers
self.n_heads = num_heads
self.head_dim = head_dim
# Pre-allocate cache tensors: (n_layers, B, T, H, D)
self.k_cache = torch.zeros(num_layers, batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype)
self.v_cache = torch.zeros(num_layers, batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype)
# Current sequence length per batch element (FA3 needs int32)
self.cache_seqlens = torch.zeros(batch_size, dtype=torch.int32, device=device)
def reset(self):
self.pos = 0
"""Reset cache to empty state."""
self.cache_seqlens.zero_()
def get_pos(self):
return self.pos
"""Get current position (assumes all batch elements at same position)."""
return self.cache_seqlens[0].item()
def get_layer_cache(self, layer_idx):
"""Return (k_cache, v_cache) views for a specific layer."""
return self.k_cache[layer_idx], self.v_cache[layer_idx]
def advance(self, num_tokens):
"""Advance the cache position by num_tokens."""
self.cache_seqlens += num_tokens
def prefill(self, other):
"""
Prefill given another KV cache. Optionally expand along batch dim.
This is used when we do batch 1 prefill and then want to generate
multiple samples in parallel from there.
Copy cached KV from another cache into this one.
Used when we do batch=1 prefill and then want to generate multiple samples in parallel.
"""
# 1) validate the shapes
assert self.kv_cache is None, "Cannot prefill a non-empty KV cache"
assert other.kv_cache is not None, "Cannot prefill with a None KV cache"
# Extract dimensions explicitly
self_layers, self_kv, self_batch, self_heads, self_seq, self_head_dim = self.kv_shape
other_layers, other_kv, other_batch, other_heads, other_seq, other_head_dim = other.kv_shape
# Validate dimensions
assert self_layers == other_layers, f"Layer count mismatch: {self_layers} != {other_layers}"
assert self_kv == other_kv, f"K/V dimension mismatch: {self_kv} != {other_kv}"
assert self_heads == other_heads, f"Head count mismatch: {self_heads} != {other_heads}"
assert self_head_dim == other_head_dim, f"Head dim mismatch: {self_head_dim} != {other_head_dim}"
# Batch size can be expanded (other can be 1, self can be larger)
assert self_batch == other_batch or other_batch == 1, f"Batch size mismatch: {self_batch} vs {other_batch} (other must be 1 or equal)"
# Sequence length: self must be longer than other
assert self_seq >= other_seq, f"Sequence length mismatch: {self_seq} < {other_seq}"
# 2) initialize the cache
dtype, device = other.kv_cache.dtype, other.kv_cache.device
self.kv_cache = torch.empty(self.kv_shape, dtype=dtype, device=device)
# 3) copy the data over
self.kv_cache[:, :, :, :, :other.pos, :] = other.kv_cache
# 4) update the pos
self.pos = other.pos
def insert_kv(self, layer_idx, k, v):
# Lazy initialize the cache here because we need to know the dtype/device
if self.kv_cache is None:
self.kv_cache = torch.empty(self.kv_shape, dtype=k.dtype, device=k.device)
# Insert new keys/values to the cache and return the full cache so far
B, H, T_add, D = k.size()
t0, t1 = self.pos, self.pos + T_add
# Dynamically grow the cache if needed
if t1 > self.kv_cache.size(4):
t_needed = t1 + 1024 # as much as we need plus buffer of 1024
t_needed = (t_needed + 1023) & ~1023 # then round up to the nearest multiple of 1024
additional_shape = list(self.kv_cache.shape)
additional_shape[4] = t_needed - self.kv_cache.size(4)
additional_cache = torch.empty(additional_shape, dtype=k.dtype, device=k.device)
self.kv_cache = torch.cat([self.kv_cache, additional_cache], dim=4).contiguous()
self.kv_shape = self.kv_cache.shape
# Insert k, v into the cache
self.kv_cache[layer_idx, 0, :, :, t0:t1, :] = k
self.kv_cache[layer_idx, 1, :, :, t0:t1, :] = v
# Return the full cached keys/values up to current position (as a view)
key_view = self.kv_cache[layer_idx, 0, :, :, :t1, :]
value_view = self.kv_cache[layer_idx, 1, :, :, :t1, :]
# Increment pos after the last layer of the Transformer processes
if layer_idx == self.kv_cache.size(0) - 1:
self.pos = t1
return key_view, value_view
assert self.get_pos() == 0, "Cannot prefill a non-empty KV cache"
assert self.n_layers == other.n_layers and self.n_heads == other.n_heads and self.head_dim == other.head_dim
assert self.max_seq_len >= other.max_seq_len
other_pos = other.get_pos()
self.k_cache[:, :, :other_pos, :, :] = other.k_cache[:, :, :other_pos, :, :]
self.v_cache[:, :, :other_pos, :, :] = other.v_cache[:, :, :other_pos, :, :]
self.cache_seqlens.fill_(other_pos)
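# Minimal sketch of the prefill -> decode handoff described above (all dimensions are made up,
# and the model forward that actually fills the cache via FA3 is elided).
# prefill_cache = KVCache(batch_size=1, num_heads=8, seq_len=64,
#                         head_dim=128, num_layers=12, device="cuda")
# ... run the 64 prompt tokens through the model once; FA3 writes K/V in place and
# advance(64) runs after the last layer, so prefill_cache.get_pos() becomes 64 ...
# decode_cache = KVCache(batch_size=4, num_heads=8, seq_len=2048,
#                        head_dim=128, num_layers=12, device="cuda")
# decode_cache.prefill(prefill_cache)  # batch=1 prompt KV broadcasts into all 4 rows
# assert decode_cache.get_pos() == prefill_cache.get_pos()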
# -----------------------------------------------------------------------------
@torch.inference_mode()
@ -167,7 +138,7 @@ def sample_next_token(logits, rng, temperature=1.0, top_k=None):
assert temperature >= 0.0, "temperature must be non-negative"
if temperature == 0.0:
return torch.argmax(logits, dim=-1, keepdim=True)
if top_k is not None:
if top_k is not None and top_k > 0:
k = min(top_k, logits.size(-1))
vals, idx = torch.topk(logits, k, dim=-1)
vals = vals / temperature
@ -219,6 +190,7 @@ class Engine:
kv_cache_prefill = KVCache(
batch_size=1,
seq_len=len(tokens),
device=device,
**kv_model_kwargs,
)
ids = torch.tensor([tokens], dtype=torch.long, device=device)
@ -230,6 +202,7 @@ class Engine:
kv_cache_decode = KVCache(
batch_size=num_samples,
seq_len=kv_length_hint,
device=device,
**kv_model_kwargs,
)
kv_cache_decode.prefill(kv_cache_prefill)

View File

@ -9,9 +9,9 @@ Notable features:
- no learnable params in rmsnorm
- no bias in linear layers
- Group-Query Attention (GQA) support for more efficient inference
- Flash Attention 3 integration
"""
import math
from functools import partial
from dataclasses import dataclass
@ -23,6 +23,14 @@ from nanochat.common import get_dist_info, print0
from nanochat.muon import Muon, DistMuon
from nanochat.adamw import DistAdamW
# Load Flash Attention 3 from HuggingFace Hub (and silence the progress bar)
import os
os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
# Official docs of FA3 label it as "beta" and want you to install FA3 from source, which is a pain.
# Wishing for official FA3 wheels soon, for now this seems to be a fast way to get them (ty varunneal)
from kernels import get_kernel
flash_attn = get_kernel('varunneal/flash-attention-3').flash_attn_interface
@dataclass
class GPTConfig:
sequence_len: int = 1024
@ -31,6 +39,10 @@ class GPTConfig:
n_head: int = 6 # number of query heads
n_kv_head: int = 6 # number of key/value heads (GQA)
n_embd: int = 768
# Sliding window attention pattern string, tiled across layers. Final layer always L.
# Characters: L=long (full context), S=short (half context)
# Examples: "L"=all full context, "SL"=alternating, "SSL"=two short then one long
window_pattern: str = "L"
def norm(x):
@ -61,48 +73,42 @@ class CausalSelfAttention(nn.Module):
self.c_v = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)
def forward(self, x, cos_sin, kv_cache):
def forward(self, x, cos_sin, window_size, kv_cache):
B, T, C = x.size()
# Project the input to get queries, keys, and values
# Shape: (B, T, H, D) - FA3's native layout, no transpose needed!
q = self.c_q(x).view(B, T, self.n_head, self.head_dim)
k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim)
v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim)
# Apply Rotary Embeddings to queries and keys to get relative positional encoding
cos, sin = cos_sin
q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin) # QK rotary embedding
q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin)
q, k = norm(q), norm(k) # QK norm
q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) # make head be batch dim, i.e. (B, T, H, D) -> (B, H, T, D)
# Apply KV cache: insert current k,v into cache, get the full view so far
if kv_cache is not None:
k, v = kv_cache.insert_kv(self.layer_idx, k, v)
Tq = q.size(2) # number of queries in this forward pass
Tk = k.size(2) # number of keys/values in total (in the cache + current forward pass)
# Attention: queries attend to keys/values autoregressively. A few cases to handle:
enable_gqa = self.n_head != self.n_kv_head # Group Query Attention (GQA): duplicate key/value heads to match query heads if desired
if kv_cache is None or Tq == Tk:
# During training (no KV cache), attend as usual with causal attention
# And even if there is KV cache, we can still use this simple version when Tq == Tk
y = F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=enable_gqa)
elif Tq == 1:
# During inference but with a single query in this forward pass:
# The query has to attend to all the keys/values in the cache
y = F.scaled_dot_product_attention(q, k, v, is_causal=False, enable_gqa=enable_gqa)
# Attention with Flash Attention 3
# FA3 handles GQA automatically when n_kv_heads < n_heads
# window_size is (left, right) tuple: (N, 0) for causal, (-1, 0) for full context
if kv_cache is None:
# Training: causal attention with optional sliding window
y = flash_attn.flash_attn_func(q, k, v, causal=True, window_size=window_size)
else:
# During inference AND we have a chunk of queries in this forward pass:
# First, each query attends to all the cached keys/values (i.e. full prefix)
attn_mask = torch.zeros((Tq, Tk), dtype=torch.bool, device=q.device) # True = keep, False = mask
prefix_len = Tk - Tq
attn_mask[:, :prefix_len] = True
# Then, causal attention within this chunk
attn_mask[:, prefix_len:] = torch.tril(torch.ones((Tq, Tq), dtype=torch.bool, device=q.device))
y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, enable_gqa=enable_gqa)
# Inference: use flash_attn_with_kvcache which handles cache management
k_cache, v_cache = kv_cache.get_layer_cache(self.layer_idx)
y = flash_attn.flash_attn_with_kvcache(
q, k_cache, v_cache,
k=k, v=v,
cache_seqlens=kv_cache.cache_seqlens,
causal=True,
window_size=window_size,
)
# Advance position after last layer processes
if self.layer_idx == kv_cache.n_layers - 1:
kv_cache.advance(T)
# Re-assemble the heads side by side and project back to residual stream
y = y.transpose(1, 2).contiguous().view(B, T, -1)
# Re-assemble the heads and project back to residual stream
y = y.contiguous().view(B, T, -1)
y = self.c_proj(y)
return y
@ -126,29 +132,43 @@ class Block(nn.Module):
self.attn = CausalSelfAttention(config, layer_idx)
self.mlp = MLP(config)
def forward(self, x, cos_sin, kv_cache):
x = x + self.attn(norm(x), cos_sin, kv_cache)
def forward(self, x, cos_sin, window_size, kv_cache):
x = x + self.attn(norm(x), cos_sin, window_size, kv_cache)
x = x + self.mlp(norm(x))
return x
class GPT(nn.Module):
def __init__(self, config, pad_vocab_size_to=64):
"""
NOTE a major footgun: this __init__ function runs in meta device context (!!)
Therefore, any calculations inside here are shapes and dtypes only, no actual data.
=> We actually initialize all data (parameters, buffers, etc.) in init_weights() instead.
"""
super().__init__()
self.config = config
# For DDP, we want vocab_size divisible by world_size. Also, there are potential performance benefits, see:
# Compute per-layer window sizes for sliding window attention
# window_size is (left, right) tuple: (-1, 0) for full context, (N, 0) for sliding window
self.window_sizes = self._compute_window_sizes(config)
# Pad vocab for efficiency (DDP, tensor cores). This is just an optimization - outputs are cropped in forward().
# https://huggingface.co/docs/transformers/main_classes/model#transformers.PreTrainedModel.resize_token_embeddings
padded_vocab_size = ((config.vocab_size + pad_vocab_size_to - 1) // pad_vocab_size_to) * pad_vocab_size_to
if padded_vocab_size != config.vocab_size:
print0(f"Padding vocab_size from {config.vocab_size} to {padded_vocab_size} to be divisible by {pad_vocab_size_to}")
print0(f"Padding vocab_size from {config.vocab_size} to {padded_vocab_size} for efficiency")
self.transformer = nn.ModuleDict({
"wte": nn.Embedding(padded_vocab_size, config.n_embd),
"h": nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)]),
})
self.lm_head = nn.Linear(config.n_embd, padded_vocab_size, bias=False)
# To support meta device initialization, we init the rotary embeddings here, but it's fake
# Per-layer learnable scalars (inspired by modded-nanogpt)
# resid_lambdas: scales the residual stream at each layer (init 1.0 = neutral)
# x0_lambdas: blends initial embedding back in at each layer (init 0.0 = disabled)
# Separate parameters so they can have different optimizer treatment
self.resid_lambdas = nn.Parameter(torch.ones(config.n_layer)) # fake init, real init in init_weights()
self.x0_lambdas = nn.Parameter(torch.zeros(config.n_layer)) # fake init, real init in init_weights()
# To support meta device initialization, we init the rotary embeddings here, but it's just "fake" meta tensors only.
# As for rotary_seq_len, these rotary embeddings are pretty small/cheap in memory,
# so let's just over-compute them, but assert fail if we ever reach that amount.
# so let's just over-compute them by 10X, but assert fail if we ever reach that amount.
# In the future we can dynamically grow the cache, for now it's fine.
self.rotary_seq_len = config.sequence_len * 10 # 10X over-compute should be enough, TODO make nicer?
head_dim = config.n_embd // config.n_head
@ -157,35 +177,51 @@ class GPT(nn.Module):
self.register_buffer("sin", sin, persistent=False)
def init_weights(self):
self.apply(self._init_weights)
# zero out classifier weights
torch.nn.init.zeros_(self.lm_head.weight)
# zero out c_proj weights in all blocks
"""
Initialize the full model in this one function for maximum clarity.
wte (embedding): normal, std=1.0
lm_head: normal, std=0.001
for each block:
attn.c_q: uniform, std=1/sqrt(n_embd)
attn.c_k: uniform, std=1/sqrt(n_embd)
attn.c_v: uniform, std=1/sqrt(n_embd)
attn.c_proj: zeros
mlp.c_fc: uniform, std=1/sqrt(n_embd)
mlp.c_proj: zeros
"""
# Embedding and unembedding
torch.nn.init.normal_(self.transformer.wte.weight, mean=0.0, std=1.0)
torch.nn.init.normal_(self.lm_head.weight, mean=0.0, std=0.001)
# Transformer blocks: uniform init with bound = sqrt(3) * std (same standard deviation as normal)
n_embd = self.config.n_embd
s = 3**0.5 * n_embd**-0.5 # sqrt(3) multiplier makes sure Uniform achieves the same std as Normal
for block in self.transformer.h:
torch.nn.init.uniform_(block.attn.c_q.weight, -s, s) # weights use Uniform to avoid outliers
torch.nn.init.uniform_(block.attn.c_k.weight, -s, s)
torch.nn.init.uniform_(block.attn.c_v.weight, -s, s)
torch.nn.init.zeros_(block.attn.c_proj.weight) # projections are zero
torch.nn.init.uniform_(block.mlp.c_fc.weight, -s, s)
torch.nn.init.zeros_(block.mlp.c_proj.weight)
torch.nn.init.zeros_(block.attn.c_proj.weight)
# init the rotary embeddings
# Per-layer scalars
with torch.no_grad():
self.resid_lambdas.fill_(1.0) # 1.0 => typical residual connections at init
self.x0_lambdas.fill_(0.0) # 0.0 => skip connection to input is disabled at init
# Rotary embeddings
head_dim = self.config.n_embd // self.config.n_head
cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
self.cos, self.sin = cos, sin
# Cast the embeddings from fp32 to bf16: optim can tolerate it and it saves memory: both in the model and the activations
# Cast token embeddings to bf16: optimizer can tolerate it and it saves memory
if self.transformer.wte.weight.device.type == "cuda":
self.transformer.wte.to(dtype=torch.bfloat16)
def _init_weights(self, module):
if isinstance(module, nn.Linear):
# https://arxiv.org/pdf/2310.17813
fan_out = module.weight.size(0)
fan_in = module.weight.size(1)
std = 1.0 / math.sqrt(fan_in) * min(1.0, math.sqrt(fan_out / fan_in))
torch.nn.init.normal_(module.weight, mean=0.0, std=std)
if module.bias is not None:
torch.nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
torch.nn.init.normal_(module.weight, mean=0.0, std=1.0)
# TODO: bump base theta more, e.g. 100K is more common more recently
def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None):
# TODO: bump base theta more? e.g. 100K is more common more recently
# autodetect the device from model embeddings
if device is None:
device = self.transformer.wte.weight.device
@ -201,38 +237,100 @@ class GPT(nn.Module):
cos, sin = cos[None, :, None, :], sin[None, :, None, :] # add batch and head dims for later broadcasting
return cos, sin
def _compute_window_sizes(self, config):
"""
Compute per-layer window sizes for sliding window attention.
Returns list of (left, right) tuples for FA3's window_size parameter:
- left: how many tokens before current position to attend to (-1 = unlimited)
- right: how many tokens after current position to attend to (0 for causal)
Pattern string is tiled across layers. Final layer always gets L (full context).
Characters: L=long (full context), S=short (half context)
"""
pattern = config.window_pattern.upper()
assert all(c in "SL" for c in pattern), f"Invalid window_pattern: {pattern}. Use only S and L."
# Map characters to window sizes
long_window = config.sequence_len
short_window = long_window // 2
char_to_window = {
"L": (long_window, 0),
"S": (short_window, 0),
}
# Tile pattern across layers
window_sizes = []
for layer_idx in range(config.n_layer):
char = pattern[layer_idx % len(pattern)]
window_sizes.append(char_to_window[char])
# Final layer always gets full context
window_sizes[-1] = (long_window, 0)
return window_sizes
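# Standalone restatement of the tiling rule above, for illustration (toy values, not a real config):
# def _window_sizes_sketch(pattern, n_layer, sequence_len):
#     long_w, short_w = sequence_len, sequence_len // 2
#     sizes = [((long_w if pattern[i % len(pattern)] == "L" else short_w), 0) for i in range(n_layer)]
#     sizes[-1] = (long_w, 0)  # final layer is always full context
#     return sizes
# _window_sizes_sketch("SSL", n_layer=6, sequence_len=2048)
# -> [(1024, 0), (1024, 0), (2048, 0), (1024, 0), (1024, 0), (2048, 0)]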
def get_device(self):
return self.transformer.wte.weight.device
def estimate_flops(self):
""" Return the estimated FLOPs per token for the model. Ref: https://arxiv.org/abs/2204.02311 """
"""
Return the estimated FLOPs per token for the model (forward + backward).
Each matmul weight parameter contributes 2 FLOPs (multiply *, accumulate +) in forward, and 2X that in backward => 2+4=6.
Cleanest explanation of this: https://medium.com/@dzmitrybahdanau/the-flops-calculus-of-language-model-training-3b19c1f025e4
On top of that, 12 * h * q * effective_seq_len accounts for key @ query matmul flops inside attention.
With sliding windows, effective_seq_len varies per layer (capped by window size).
Ref: https://arxiv.org/abs/2204.02311 (PaLM paper).
This is ~1% off from the exact formulas of the Chinchilla paper; the differences are:
- Chinchilla counts the embedding layer as flops (? weird, it's just a lookup => we ignore)
- Chinchilla counts exp/sum/divide in attention softmax as flops (a little sus and very tiny => we ignore)
"""
nparams = sum(p.numel() for p in self.parameters())
nparams_embedding = self.transformer.wte.weight.numel()
l, h, q, t = self.config.n_layer, self.config.n_head, self.config.n_embd // self.config.n_head, self.config.sequence_len
num_flops_per_token = 6 * (nparams - nparams_embedding) + 12 * l * h * q * t
# Exclude non-matmul params: embeddings and per-layer scalars
nparams_exclude = self.transformer.wte.weight.numel() + self.resid_lambdas.numel() + self.x0_lambdas.numel()
h, q, t = self.config.n_head, self.config.n_embd // self.config.n_head, self.config.sequence_len
# Sum attention FLOPs per layer, accounting for sliding window
attn_flops = 0
for window_size in self.window_sizes:
window = window_size[0] # (left, right) tuple, we use left
effective_seq = t if window < 0 else min(window, t)
attn_flops += 12 * h * q * effective_seq
num_flops_per_token = 6 * (nparams - nparams_exclude) + attn_flops
return num_flops_per_token
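# Quick worked sketch of the estimate above with toy dimensions (assumptions: 4x MLP,
# n_kv_head == n_head, all-"L" window pattern; the lm_head matmul is omitted for brevity).
# n_layer, n_head, head_dim, seq_len = 12, 6, 128, 1024
# n_embd = n_head * head_dim                               # 768
# matmul_params = n_layer * 12 * n_embd ** 2               # q/k/v/proj (4*d^2) + MLP (8*d^2) per block
# dense_flops = 6 * matmul_params                          # 2 forward + 4 backward FLOPs per weight
# attn_flops = n_layer * 12 * n_head * head_dim * seq_len  # full-context attention on every layer
# dense_flops + attn_flops ~= 6.2e8 FLOPs per token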
def setup_optimizers(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0):
def num_scaling_params(self):
"""
Return the total parameter count, same as the Chinchilla paper.
Kaplan et al. did not include embedding parameters and said that this led to cleaner scaling laws.
But Kaplan et al. also had a bug in their results (as pointed out by Chinchilla).
My own experiments in nanochat confirm the Chinchilla approach gives the much cleaner scaling law.
Ref: https://arxiv.org/abs/2203.15556 (Chinchilla paper <- good).
Ref: https://arxiv.org/abs/2001.08361 (Kaplan et al. original scaling laws paper <- bad)
"""
nparams = sum(p.numel() for p in self.parameters())
return nparams
def setup_optimizers(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0, adam_betas=(0.8, 0.95), scalar_lr=0.5):
model_dim = self.config.n_embd
ddp, rank, local_rank, world_size = get_dist_info()
# Separate out all parameters into 3 groups (matrix, embedding, lm_head)
# Separate out all parameters into 5 groups (matrix, embedding, lm_head, resid_lambdas, x0_lambdas)
matrix_params = list(self.transformer.h.parameters())
embedding_params = list(self.transformer.wte.parameters())
lm_head_params = list(self.lm_head.parameters())
assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params)
# Create the AdamW optimizer for the embedding and lm_head
resid_params = [self.resid_lambdas]
x0_params = [self.x0_lambdas]
assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(resid_params) + len(x0_params)
# Create the AdamW optimizer for the embedding, lm_head, and per-layer scalars
# Scale the LR for the AdamW parameters by ∝1/√dmodel (having tuned the LRs for 768 dim model)
dmodel_lr_scale = (model_dim / 768) ** -0.5
print0(f"Scaling the LR for the AdamW parameters ∝1/√({model_dim}/768) = {dmodel_lr_scale:.6f}")
adam_groups = [
dict(params=lm_head_params, lr=unembedding_lr * dmodel_lr_scale),
dict(params=embedding_params, lr=embedding_lr * dmodel_lr_scale),
dict(params=resid_params, lr=scalar_lr * 0.01), # these are a lot more sensitive because they accumulate in the residual stream
dict(params=x0_params, lr=scalar_lr),
]
adamw_kwargs = dict(betas=(0.8, 0.95), eps=1e-10, weight_decay=weight_decay)
adamw_kwargs = dict(betas=adam_betas, eps=1e-10, weight_decay=0.0) # NOTE: weight decay is hardcoded to 0.0 for AdamW, only used in Muon
AdamWFactory = DistAdamW if ddp else partial(torch.optim.AdamW, fused=True)
adamw_optimizer = AdamWFactory(adam_groups, **adamw_kwargs)
# Create the Muon optimizer for the linear layers
muon_kwargs = dict(lr=matrix_lr, momentum=0.95)
muon_kwargs = dict(lr=matrix_lr, momentum=0.95, weight_decay=weight_decay)
MuonFactory = DistMuon if ddp else Muon
muon_optimizer = MuonFactory(matrix_params, **muon_kwargs)
# Combine the two optimizers into one list
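# Worked example of the ∝1/√dmodel AdamW LR scaling above (model_dim here is made up):
# example_model_dim = 1536
# example_scale = (example_model_dim / 768) ** -0.5   # = 0.7071...
# scaled unembedding_lr: 0.004 * 0.7071 ~= 0.0028
# scaled embedding_lr:   0.2   * 0.7071 ~= 0.1414
# matrix_lr (Muon) and the per-layer scalar LRs above are not rescaled by this factor.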
@ -256,8 +354,10 @@ class GPT(nn.Module):
# Forward the trunk of the Transformer
x = self.transformer.wte(idx)
x = norm(x)
for block in self.transformer.h:
x = block(x, cos_sin, kv_cache)
x0 = x # save initial normalized embedding for x0 residual
for i, block in enumerate(self.transformer.h):
x = self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0
x = block(x, cos_sin, self.window_sizes[i], kv_cache)
x = norm(x)
# Forward the lm_head (compute logits)

View File

@ -1,39 +1,96 @@
"""
Muon optimizer from Keller et al.
Also a lot of borrowing of ideas from modded-nanogpt.
Muon optimizer adapted and simplified from modded-nanogpt.
https://github.com/KellerJordan/modded-nanogpt
Background:
Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
zero even beyond the point where the iteration no longer converges all the way to one everywhere
on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
performance at all relative to UV^T, where USV^T = G is the SVD.
Here, an alternative to Newton-Schulz iteration with potentially better convergence properties:
Polar Express Sign Method for orthogonalization.
https://arxiv.org/pdf/2505.16932
by Noah Amsel, David Persson, Christopher Musco, Robert M. Gower.
Some of the changes in the nanochat implementation:
- Uses a simpler, more general approach to parameter grouping and stacking
- Uses a single fused kernel for the momentum -> polar_express -> variance_reduction -> update step
- Makes no assumptions about model architecture (e.g. that attention weights are fused into QKVO format)
"""
import torch
from torch import Tensor
import torch.distributed as dist
@torch.compile
def zeropower_via_newtonschulz5(G: Tensor, steps: int) -> Tensor:
"""
Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
zero even beyond the point where the iteration no longer converges all the way to one everywhere
on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
performance at all relative to UV^T, where USV^T = G is the SVD.
"""
assert G.ndim >= 2 # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng
a, b, c = (3.4445, -4.7750, 2.0315)
X = G.bfloat16()
if G.size(-2) > G.size(-1):
X = X.mT
# Coefficients for Polar Express (computed for num_iters=5, safety_factor=2e-2, cushion=2)
# From https://arxiv.org/pdf/2505.16932
polar_express_coeffs = [
(8.156554524902461, -22.48329292557795, 15.878769915207462),
(4.042929935166739, -2.808917465908714, 0.5000178451051316),
(3.8916678022926607, -2.772484153217685, 0.5060648178503393),
(3.285753657755655, -2.3681294933425376, 0.46449024233003106),
(2.3465413258596377, -1.7097828382687081, 0.42323551169305323),
]
# Ensure spectral norm is at most 1
X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
# Perform the NS iterations
for _ in range(steps):
@torch.compile(dynamic=False, fullgraph=True)
def muon_step_fused(
stacked_grads: Tensor,
stacked_params: Tensor,
momentum_buffer: Tensor,
second_momentum_buffer: Tensor,
momentum_t: Tensor,
lr_t: Tensor,
wd_t: Tensor,
beta2_t: Tensor,
ns_steps: int,
red_dim: int,
) -> None:
"""
Fused Muon step: momentum -> polar_express -> variance_reduction -> cautious_update
All in one compiled graph to eliminate Python overhead between ops.
Some of the constants are 0-D CPU tensors to avoid recompilation when values change.
"""
# Nesterov momentum
momentum = momentum_t.to(stacked_grads.dtype)
momentum_buffer.lerp_(stacked_grads, 1 - momentum)
g = stacked_grads.lerp_(momentum_buffer, momentum)
# Polar express
X = g.bfloat16()
if g.size(-2) > g.size(-1):
X = X.mT
X = X / (X.norm(dim=(-2, -1), keepdim=True) * 1.02 + 1e-6)
for a, b, c in polar_express_coeffs[:ns_steps]:
A = X @ X.mT
B = b * A + c * A @ A # quintic computation strategy adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
B = b * A + c * (A @ A)
X = a * X + B @ X
if G.size(-2) > G.size(-1):
if g.size(-2) > g.size(-1):
X = X.mT
return X
g = X
# Variance reduction
beta2 = beta2_t.to(g.dtype)
v_mean = g.float().square().mean(dim=red_dim, keepdim=True)
red_dim_size = g.size(red_dim)
v_norm_sq = v_mean.sum(dim=(-2, -1), keepdim=True) * red_dim_size
v_norm = v_norm_sq.sqrt()
second_momentum_buffer.lerp_(v_mean.to(dtype=second_momentum_buffer.dtype), 1 - beta2)
step_size = second_momentum_buffer.clamp_min(1e-10).rsqrt()
scaled_sq_sum = (v_mean * red_dim_size) * step_size.float().square()
v_norm_new = scaled_sq_sum.sum(dim=(-2, -1), keepdim=True).sqrt()
final_scale = step_size * (v_norm / v_norm_new.clamp_min(1e-10))
g = g * final_scale.to(g.dtype)
# Cautious weight decay + parameter update
lr = lr_t.to(g.dtype)
wd = wd_t.to(g.dtype)
mask = (g * stacked_params) >= 0
stacked_params.sub_(lr * g + lr * wd * stacked_params * mask)
class Muon(torch.optim.Optimizer):
"""
@ -54,74 +111,112 @@ class Muon(torch.optim.Optimizer):
Arguments:
lr: The learning rate used by the internal SGD.
momentum: The momentum used by the internal SGD.
nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
ns_steps: The number of Newton-Schulz iteration steps to use.
beta2: The decay rate for the second moment (variance) estimate. Set to None to disable.
weight_decay: Cautious weight decay coefficient. Only decays where update and weight agree.
"""
def __init__(self, params, lr=0.02, momentum=0.95, nesterov=True, ns_steps=5):
defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps)
params: list[Tensor] = [*params]
def __init__(self, params, lr=0.02, momentum=0.95, ns_steps=5, beta2=0.95, weight_decay=0.0):
defaults = dict(lr=lr, momentum=momentum, ns_steps=ns_steps, beta2=beta2, weight_decay=weight_decay)
assert all(p.ndim == 2 for p in params), "Muon expects 2D parameters only"
params = list(params) # ensure we have a list, not an e.g. (exhaustible) iterator
# Group by shape so we can stack tensors
shapes = sorted({p.shape for p in params})
param_groups = []
for size in {p.numel() for p in params}:
group = dict(params=[p for p in params if p.numel() == size])
param_groups.append(group)
for shape in shapes:
group_params = [p for p in params if p.shape == shape]
param_groups.append(dict(params=group_params))
super().__init__(param_groups, defaults)
# 0-D CPU tensors to avoid torch.compile recompilation when values change
self._momentum_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
self._lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
self._wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
self._beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
@torch.no_grad()
def step(self):
for group in self.param_groups:
params: list[Tensor] = group["params"]
for p in params:
g = p.grad
assert g is not None
state = self.state[p]
if "momentum_buffer" not in state:
state["momentum_buffer"] = torch.zeros_like(g)
buf: Tensor = state["momentum_buffer"]
buf.lerp_(g, 1 - group["momentum"])
g = g.lerp_(buf, group["momentum"]) if group["nesterov"] else buf
g = zeropower_via_newtonschulz5(g, steps=group["ns_steps"])
p.add_(g, alpha=-group["lr"] * max(1, p.size(-2) / p.size(-1))**0.5)
if not params:
continue
# Get or create group-level buffers (stored in first param's state for convenience)
state = self.state[params[0]]
num_params = len(params) # e.g.: 12 (for a d12 model)
# e.g.: shape = (768, 3072), device = cuda:0, dtype = torch.float32, for one of the MLP projections
shape, device, dtype = params[0].shape, params[0].device, params[0].dtype
# Momentum for every individual parameter
if "momentum_buffer" not in state:
state["momentum_buffer"] = torch.zeros(num_params, *shape, dtype=dtype, device=device)
momentum_buffer = state["momentum_buffer"] # e.g.: (12, 768, 3072)
# Second momentum buffer is factored, either per-row or per-column
if "second_momentum_buffer" not in state:
if shape[-2] >= shape[-1]:
state["second_momentum_buffer"] = torch.zeros(num_params, shape[-2], 1, dtype=dtype, device=device)
else:
state["second_momentum_buffer"] = torch.zeros(num_params, 1, shape[-1], dtype=dtype, device=device)
second_momentum_buffer = state["second_momentum_buffer"] # (12, 1, 3072)
red_dim = -1 if shape[-2] >= shape[-1] else -2 # e.g.: -2
# Stack grads and params
stacked_grads = torch.stack([p.grad for p in params]) # (12, 768, 3072)
stacked_params = torch.stack(params) # (12, 768, 3072)
# Fill all the 0-D tensors with current values
self._momentum_t.fill_(group["momentum"])
self._beta2_t.fill_(group["beta2"] if group["beta2"] is not None else 0.0)
self._lr_t.fill_(group["lr"] * max(1.0, shape[-2] / shape[-1])**0.5)
self._wd_t.fill_(group["weight_decay"])
# Single fused kernel: momentum -> polar_express -> variance_reduction -> update
muon_step_fused(
stacked_grads,
stacked_params,
momentum_buffer,
second_momentum_buffer,
self._momentum_t,
self._lr_t,
self._wd_t,
self._beta2_t,
group["ns_steps"],
red_dim,
)
# Copy back to original params: [(768, 3072), (768, 3072), ...] <- (12, 768, 3072)
torch._foreach_copy_(params, list(stacked_params.unbind(0)))
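# Minimal single-process usage sketch of the class above (toy 2D weights, illustrative only;
# nanochat itself constructs Muon inside GPT.setup_optimizers).
# import torch
# import torch.nn as nn
# layers = nn.ModuleList([nn.Linear(64, 64, bias=False) for _ in range(4)])
# opt = Muon([l.weight for l in layers], lr=0.02, momentum=0.95, weight_decay=0.0)
# x = torch.randn(8, 64)
# loss = sum(l(x).square().mean() for l in layers)
# loss.backward()
# opt.step()  # the four identically-shaped weights are stacked and updated by one fused step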
class DistMuon(torch.optim.Optimizer):
"""
Muon: SGD-momentum + (optional) Nesterov, then orthogonalize the 2D update via NewtonSchulz,
finally apply aspect-ratio scaled step. Performs its own distributed synchronization:
- reduce_scatter(AVG) for gradient averaging
- all_gather to replicate updated weights
Notes:
* Designed for 2D parameters (e.g., linear/conv kernels reshaped to 2D). Do not use for 0D/1D
params like embeddings or scalars.
* Momentum buffers are maintained only on the 'owner' rank for each parameter (rank chosen
by block-cyclic assignment below). If you checkpoint optimizer state on a single rank,
consolidate states beforehand.
Args:
params: iterable of Tensors
lr: learning rate
momentum: momentum coefficient in [0,1)
nesterov: if True, Nesterov-style update (g <- lerp(g, buf, momentum)); else use buf
ns_steps: number of NewtonSchulz iterations for the orthogonalization
Distributed version of the Muon optimizer.
"""
def __init__(self, params, lr: float = 0.02, momentum: float = 0.95,
nesterov: bool = True, ns_steps: int = 5):
defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps)
params = list(params)
ns_steps: int = 5, beta2: float = 0.95, weight_decay: float = 0.0):
defaults = dict(lr=lr, momentum=momentum, ns_steps=ns_steps, beta2=beta2, weight_decay=weight_decay)
assert all(p.ndim == 2 for p in params), "Muon expects 2D parameters only"
params = list(params)
world_size = dist.get_world_size()
rank = dist.get_rank()
# Group all parameters by their shape
shapes = sorted({p.shape for p in params}) # sort to ensure consistent / deterministic ordering
shapes = sorted({p.shape for p in params}) # sort for deterministic ordering across ranks
param_groups = []
for shape in shapes:
group_params = [p for p in params if p.shape == shape]
device, dtype = group_params[0].device, group_params[0].dtype
assert all(p.device == device for p in group_params)
assert all(p.dtype == dtype for p in group_params)
# Compute chunk size for this group (how many params each rank owns)
chunk_size = (len(group_params) + world_size - 1) // world_size
if rank == 0:
print(f"Muon: Grouping {len(group_params)} params of shape {shape}, device {device}, dtype {dtype}")
param_groups.append(dict(params=group_params, zero_buffer=torch.zeros_like(group_params[0])))
print(f"Muon: {len(group_params)} params of shape {shape}, chunk_size={chunk_size}")
param_groups.append(dict(params=group_params, chunk_size=chunk_size))
super().__init__(param_groups, defaults)
# 0-D CPU tensors to avoid torch.compile recompilation when values change
self._momentum_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
self._lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
self._wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
self._beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
@torch.no_grad()
def step(self):
@ -131,57 +226,127 @@ class DistMuon(torch.optim.Optimizer):
# Ensure all grads exist
assert all(p.grad is not None for group in self.param_groups for p in group["params"]), "All params must have grads"
# Kick off all the reduce scatter operations to average up the gradients across all ranks
all_reduce_futures = []
# First pass: stack grads and kick off reduce_scatter for each group
group_infos = []
for group in self.param_groups:
params = group["params"]
zero_buffer = group["zero_buffer"]
# Go through params in groups of world_size.
for base_i in range(0, len(params), world_size):
# The compute owner of each param is rank i % world_size
owner_idx = base_i + rank
# each rank stacks up its chunk of world_size params into a list
rs_input = [p.grad for p in params[base_i:base_i + world_size]]
# pad rs_input with the zero buffer to complete the group
rs_input.extend([zero_buffer] * (world_size - len(rs_input)))
# the output buffer gets strided across the group based on the rank
rs_output = params[owner_idx].grad if owner_idx < len(params) else torch.empty_like(zero_buffer)
# reduce scatter the gradients within this group of world_size params
work = dist.reduce_scatter(rs_output, rs_input, op=dist.ReduceOp.AVG, async_op=True).get_future()
all_reduce_futures.append(work)
params: list[Tensor] = group["params"]
chunk_size = group["chunk_size"]
padded_num_params = chunk_size * world_size
shape = params[0].shape
device, dtype = params[0].device, params[0].dtype
# Now each rank computes the update and gathers
future_idx = 0
# Stack all gradients into a single tensor (single kernel via torch.stack)
grad_stack = torch.stack([p.grad for p in params])
stacked_grads = torch.empty(padded_num_params, *shape, dtype=dtype, device=device)
stacked_grads[:len(params)].copy_(grad_stack)
# Zero-pad if we have fewer params than padded size
if len(params) < padded_num_params:
stacked_grads[len(params):].zero_()
# Output buffer for this rank's chunk
grad_chunk = torch.empty(chunk_size, *shape, dtype=dtype, device=device)
# Async reduce_scatter on the stacked tensor
reduce_future = dist.reduce_scatter_tensor(
grad_chunk, stacked_grads, op=dist.ReduceOp.AVG, async_op=True
).get_future()
group_infos.append(dict(
grad_chunk=grad_chunk,
reduce_future=reduce_future,
stacked_grads=stacked_grads, # reuse for all_gather output
))
# Second pass: wait for reduce, compute batched updates, kick off all_gather
all_gather_futures = []
for group in self.param_groups:
params = group["params"]
zero_buffer = group["zero_buffer"]
# Go through params in groups of world_size.
for base_i in range(0, len(params), world_size):
# The compute owner of each param is rank i % world_size
owner_idx = base_i + rank # calculate the index of the param that this rank owns
# Wait for the reduce scatter to complete
all_reduce_futures[future_idx].wait() # possibly later we could use wait_any polling instead
future_idx += 1
# Owner computes the Muon update, result is in its param
if owner_idx < len(params):
p = params[owner_idx]
g = p.grad # now averaged across ranks
state = self.state[p]
if "momentum_buffer" not in state:
state["momentum_buffer"] = torch.zeros_like(g)
buf: Tensor = state["momentum_buffer"]
buf.lerp_(g, 1.0 - group["momentum"])
g = g.lerp_(buf, group["momentum"]) if group["nesterov"] else buf
g = zeropower_via_newtonschulz5(g, steps=group["ns_steps"])
scale = (max(1.0, p.size(-2) / p.size(-1)) ** 0.5)
p.add_(g, alpha=-group["lr"] * scale)
# Replicate updated parameters to all ranks
ag_input = params[owner_idx] if owner_idx < len(params) else zero_buffer
ag_output = params[base_i:base_i + world_size]
ag_output.extend([torch.empty_like(zero_buffer) for _ in range(world_size - len(ag_output))]) # pad
work = dist.all_gather(ag_output, ag_input, async_op=True).get_future()
all_gather_futures.append(work)
for group, info in zip(self.param_groups, group_infos):
info["reduce_future"].wait()
# Wait for all work to finish
torch.futures.collect_all(all_gather_futures).wait()
params = group["params"]
chunk_size = group["chunk_size"]
shape = params[0].shape
device, dtype = params[0].device, params[0].dtype
grad_chunk = info["grad_chunk"]
# How many params does this rank actually own?
start_idx = rank * chunk_size
num_owned = min(chunk_size, max(0, len(params) - start_idx))
# Get or create group-level state (stored keyed by first param)
state = self.state[params[0]]
# Momentum buffer
if "momentum_buffer" not in state:
state["momentum_buffer"] = torch.zeros(chunk_size, *shape, dtype=dtype, device=device)
momentum_buffer = state["momentum_buffer"]
# Second momentum buffer is factored, either per-row or per-column
if "second_momentum_buffer" not in state:
if shape[-2] >= shape[-1]:
state["second_momentum_buffer"] = torch.zeros(chunk_size, shape[-2], 1, dtype=dtype, device=device)
else:
state["second_momentum_buffer"] = torch.zeros(chunk_size, 1, shape[-1], dtype=dtype, device=device)
second_momentum_buffer = state["second_momentum_buffer"]
red_dim = -1 if shape[-2] >= shape[-1] else -2
# Build updated_params tensor for all_gather
updated_params = torch.empty(chunk_size, *shape, dtype=dtype, device=device)
if num_owned > 0:
# Stack owned params (single kernel via torch.stack)
owned_params = [params[start_idx + i] for i in range(num_owned)]
stacked_owned_params = torch.stack(owned_params)
# Get owned slices of buffers and grads
owned_grads = grad_chunk[:num_owned]
owned_momentum = momentum_buffer[:num_owned]
owned_second_momentum = second_momentum_buffer[:num_owned]
# Fill 0-D tensors with current values
self._momentum_t.fill_(group["momentum"])
self._beta2_t.fill_(group["beta2"] if group["beta2"] is not None else 0.0)
self._lr_t.fill_(group["lr"] * max(1.0, shape[-2] / shape[-1])**0.5)
self._wd_t.fill_(group["weight_decay"])
# Single fused kernel: momentum -> polar_express -> variance_reduction -> update
muon_step_fused(
owned_grads,
stacked_owned_params,
owned_momentum,
owned_second_momentum,
self._momentum_t,
self._lr_t,
self._wd_t,
self._beta2_t,
group["ns_steps"],
red_dim,
)
# Copy updated params to output buffer
updated_params[:num_owned].copy_(stacked_owned_params)
# Zero-pad the rest (for ranks that own fewer params)
if num_owned < chunk_size:
updated_params[num_owned:].zero_()
# Reuse stacked_grads buffer for all_gather output
stacked_params = info["stacked_grads"]
# Async all_gather to replicate updated params to all ranks
gather_future = dist.all_gather_into_tensor(
stacked_params, updated_params, async_op=True
).get_future()
all_gather_futures.append(dict(
gather_future=gather_future,
stacked_params=stacked_params,
params=params,
))
# Final pass: wait for all_gather and copy back to params
for info in all_gather_futures:
info["gather_future"].wait()
stacked_params = info["stacked_params"]
params = info["params"]
# Batched copy back (single kernel instead of N individual copies)
torch._foreach_copy_(params, list(stacked_params[:len(params)].unbind(0)))
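# Illustrative chunk math for the sharding above (param count and world size are made up):
# num_params, world_size = 7, 4
# chunk_size = (num_params + world_size - 1) // world_size   # 2
# padded_num_params = chunk_size * world_size                # 8 -> one zero-padded slot
# rank 0 owns params [0, 1], rank 1 owns [2, 3], rank 2 owns [4, 5], rank 3 owns [6];
# reduce_scatter_tensor hands each rank its averaged grad chunk, the fused step updates it,
# and all_gather_into_tensor replicates the updated stacked params back to every rank.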

View File

@ -16,8 +16,11 @@ def run_command(cmd):
"""Run a shell command and return output, or None if it fails."""
try:
result = subprocess.run(cmd, shell=True, capture_output=True, text=True, timeout=5)
if result.returncode == 0:
# Return stdout if we got output (even if some files in xargs failed)
if result.stdout.strip():
return result.stdout.strip()
if result.returncode == 0:
return ""
return None
except:
return None
@ -160,12 +163,23 @@ Generated: {timestamp}
"""
# bloat metrics: package all of the source code and assess its weight
packaged = run_command('files-to-prompt . -e py -e md -e rs -e html -e toml -e sh --ignore "*target*" --cxml')
num_chars = len(packaged)
num_lines = len(packaged.split('\n'))
num_files = len([x for x in packaged.split('\n') if x.startswith('<source>')])
num_tokens = num_chars // 4 # assume approximately 4 chars per token
# bloat metrics: count lines/chars in git-tracked source files only
extensions = ['py', 'md', 'rs', 'html', 'toml', 'sh']
git_patterns = ' '.join(f"'*.{ext}'" for ext in extensions)
files_output = run_command(f"git ls-files -- {git_patterns}")
file_list = [f for f in (files_output or '').split('\n') if f]
num_files = len(file_list)
num_lines = 0
num_chars = 0
if num_files > 0:
wc_output = run_command(f"git ls-files -- {git_patterns} | xargs wc -lc 2>/dev/null")
if wc_output:
total_line = wc_output.strip().split('\n')[-1]
parts = total_line.split()
if len(parts) >= 2:
num_lines = int(parts[0])
num_chars = int(parts[1])
num_tokens = num_chars // 4 # assume approximately 4 chars per token
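# Sketch of the total-line parsing above, using fabricated `wc -lc` output for illustration:
# example_wc = "  120   4096 nanochat/common.py\n  300  10240 nanochat/gpt.py\n  420  14336 total"
# parts = example_wc.strip().split('\n')[-1].split()  # last line is the total
# int(parts[0]) -> 420 lines, int(parts[1]) -> 14336 chars, 14336 // 4 -> 3584 tokens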
# count dependencies via uv.lock
uv_lock_lines = 0

View File

@ -26,7 +26,7 @@ SPECIAL_TOKENS = [
# NOTE: this split pattern deviates from GPT-4 in that we use \p{N}{1,2} instead of \p{N}{1,3}
# I did this because I didn't want to "waste" too many tokens on numbers for smaller vocab sizes.
# I haven't validated that this is actually a good idea, TODO.
# I verified that 2 is the sweet spot for vocab size of 32K. 1 is a bit worse, 3 was worse still.
SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,2}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
# -----------------------------------------------------------------------------
@ -103,9 +103,10 @@ class HuggingFaceTokenizer:
def id_to_token(self, id):
return self.tokenizer.id_to_token(id)
def _encode_one(self, text, prepend=None, append=None):
def _encode_one(self, text, prepend=None, append=None, num_threads=None):
# encode a single string
# prepend/append can be either the string of a special token or a token id directly.
# num_threads is ignored (only used by the nanochat Tokenizer for parallel encoding)
assert isinstance(text, str)
ids = []
if prepend is not None:
@ -122,7 +123,14 @@ class HuggingFaceTokenizer:
return self.tokenizer.token_to_id(text)
def get_bos_token_id(self):
# Different HuggingFace models use different BOS tokens and there is little consistency
# 1) attempt to find a <|bos|> token
bos = self.encode_special("<|bos|>")
# 2) if that fails, attempt to find a <|endoftext|> token (e.g. GPT-2 models)
if bos is None:
bos = self.encode_special("<|endoftext|>")
# 3) if these fail, it's better to crash than to silently return None
assert bos is not None, "Failed to find BOS token in tokenizer"
return bos
def encode(self, text, *args, **kwargs):

View File

@ -17,6 +17,11 @@
box-sizing: border-box;
}
html, body{
height: 100%;
margin: 0;
}
body {
font-family: ui-sans-serif, -apple-system, system-ui, "Segoe UI", Helvetica, "Apple Color Emoji", Arial, sans-serif, "Segoe UI Emoji", "Segoe UI Symbol";
background-color: #ffffff;
@ -113,7 +118,6 @@
.message.assistant .message-content {
background: transparent;
border: none;
padding: 0.25rem 0;
cursor: pointer;
border-radius: 0.5rem;
padding: 0.5rem;

View File

@ -7,30 +7,26 @@ requires-python = ">=3.10"
dependencies = [
"datasets>=4.0.0",
"fastapi>=0.117.1",
"files-to-prompt>=0.6",
"ipykernel>=7.1.0",
"kernels>=0.11.7",
"matplotlib>=3.10.8",
"psutil>=7.1.0",
"python-dotenv>=1.2.1",
"regex>=2025.9.1",
"rustbpe>=0.1.0",
"scipy>=1.15.3",
"setuptools>=80.9.0",
"tabulate>=0.9.0",
"tiktoken>=0.11.0",
"tokenizers>=0.22.0",
"torch>=2.8.0",
"torch>=2.9.0",
"transformers>=4.57.3",
"uvicorn>=0.36.0",
"wandb>=0.21.3",
]
[build-system]
requires = ["maturin>=1.7,<2.0"]
build-backend = "maturin"
[tool.maturin]
module-name = "rustbpe"
bindings = "pyo3"
python-source = "."
manifest-path = "rustbpe/Cargo.toml"
[dependency-groups]
dev = [
"maturin>=1.9.4",
"pytest>=8.0.0",
]
@ -45,33 +41,33 @@ python_functions = ["test_*"]
# target torch to cuda 12.8 or CPU
[tool.uv.sources]
torch = [
{ index = "pytorch-cpu", extra = "cpu" },
{ index = "pytorch-cu128", extra = "gpu" },
torch = [
{ index = "pytorch-cpu", extra = "cpu" },
{ index = "pytorch-cu128", extra = "gpu" },
]
[[tool.uv.index]]
name = "pytorch-cpu"
url = "https://download.pytorch.org/whl/cpu"
explicit = true
[[tool.uv.index]]
name = "pytorch-cu128"
url = "https://download.pytorch.org/whl/cu128"
[[tool.uv.index]]
name = "pytorch-cpu"
url = "https://download.pytorch.org/whl/cpu"
explicit = true
[project.optional-dependencies]
cpu = [
"torch>=2.8.0",
]
gpu = [
"torch>=2.8.0",
]
[tool.uv]
conflicts = [
[
{ extra = "cpu" },
{ extra = "gpu" },
],
]
[[tool.uv.index]]
name = "pytorch-cu128"
url = "https://download.pytorch.org/whl/cu128"
explicit = true
[project.optional-dependencies]
cpu = [
"torch>=2.9.1",
]
gpu = [
"torch>=2.9.1",
]
[tool.uv]
conflicts = [
[
{ extra = "cpu" },
{ extra = "gpu" },
],
]

View File

@ -16,25 +16,22 @@ if [ -z "$WANDB_RUN" ]; then
WANDB_RUN=dummy
fi
python -m nanochat.report reset
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
source "$HOME/.cargo/env"
uv run maturin develop --release --manifest-path rustbpe/Cargo.toml
curl -L -o $NANOCHAT_BASE_DIR/identity_conversations.jsonl https://karpathy-public.s3.us-west-2.amazonaws.com/identity_conversations.jsonl
# train tokenizer on ~4B characters and kick off download of the rest for pretraining
python -m nanochat.dataset -n 16
# start downloading the rest of the shards for a total of 800 (see below why 800)
python -m nanochat.dataset -n 800 &
# start downloading the rest of the shards for a total of 1200 (see below why 1200)
python -m nanochat.dataset -n 1200 &
# todo: download the rest of it
python -m scripts.tok_train --max_chars=4000000000
python -m scripts.tok_train --max-chars=4000000000 --vocab-size=65536
python -m scripts.tok_eval
# Documenting my process for determining the hyperparameters for this run1000.sh script:
# We want a budget of approx. $1000 ~= 41.6 hours of 8XH100 compute
# 1) I guessed the model size for this to be about depth=32
# 2) Determine the device_batch_size that fits:
# Running the base_train.py script with --depth=32, I saw that --device_batch_size=16
# runs out of memory, but --device_batch_size=8 fits. Inspecting `nvidia-smi` during training,
# Running the base_train.py script with --depth=32, I saw that --device-batch-size=16
# runs out of memory, but --device-batch-size=8 fits. Inspecting `nvidia-smi` during training,
# I saw all GPUs were at about 78/80GB VRAM, so it just barely fits and we have good MFU at ~50%.
# So the training script was running ok and showed:
# Vocab size: 65,536
@ -65,7 +62,9 @@ python -m scripts.tok_eval
# The tok_eval.py script reports about ~4.8 chars/token on average for the default tokenizer settings.
# So ~38B tokens * ~4.8 chars/token = ~185B chars.
# Each data shard is ~250M chars, so we need ~185B / 250M ~= 740 shards.
# For safety, I bumped that up to 800 shards, and that's why up above I used -n 800 when pre-downloading dataset shards.
# For safety, I bumped that up to 800 shards.
# The new DataLoader wastes about 35% of tokens to cropping, so 800 / (1 - 0.35) ~= 1200 shards are needed.
# => why up above I used -n 1200 when pre-downloading dataset shards.
# If we didn't have enough data, the training script would loop around and do multiple epochs over the same data,
# which would decrease model performance. Possibly 2, 3 or so epochs is ~ok, but certainly not ideal and at 10+ epochs we'd
# start to overfit hard.
@ -74,13 +73,13 @@ python -m scripts.tok_eval
# Number of processes/GPUs to use
NPROC_PER_NODE=8
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train -- --depth=32 --device_batch_size=8 --run=$WANDB_RUN
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train -- --depth=32 --target-param-data-ratio=20 --device-batch-size=8 --run=$WANDB_RUN
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_loss
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_eval
# midtrain
# NOTE: ensure that we use the same device_batch_size here as the base training script.
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.mid_train -- --device_batch_size=8 --run=$WANDB_RUN
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.mid_train -- --device-batch-size=8 --run=$WANDB_RUN
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_eval -- -i mid
# sft

458
rustbpe/Cargo.lock generated
View File

@ -1,458 +0,0 @@
# This file is automatically @generated by Cargo.
# It is not intended for manual editing.
version = 4
[[package]]
name = "ahash"
version = "0.8.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5a15f179cd60c4584b8a8c596927aadc462e27f2ca70c04e0071964a73ba7a75"
dependencies = [
"cfg-if",
"getrandom",
"once_cell",
"version_check",
"zerocopy",
]
[[package]]
name = "aho-corasick"
version = "1.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8e60d3430d3a69478ad0993f19238d2df97c507009a52b3c10addcd7f6bcb916"
dependencies = [
"memchr",
]
[[package]]
name = "arc-swap"
version = "1.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "69f7f8c3906b62b754cd5326047894316021dcfe5a194c8ea52bdd94934a3457"
[[package]]
name = "autocfg"
version = "1.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8"
[[package]]
name = "bit-set"
version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "08807e080ed7f9d5433fa9b275196cfc35414f66a0c79d864dc51a0d825231a3"
dependencies = [
"bit-vec",
]
[[package]]
name = "bit-vec"
version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5e764a1d40d510daf35e07be9eb06e75770908c27d411ee6c92109c9840eaaf7"
[[package]]
name = "castaway"
version = "0.2.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dec551ab6e7578819132c713a93c022a05d60159dc86e7a7050223577484c55a"
dependencies = [
"rustversion",
]
[[package]]
name = "cfg-if"
version = "1.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2fd1289c04a9ea8cb22300a459a72a385d7c73d3259e2ed7dcb2af674838cfa9"
[[package]]
name = "compact_str"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3fdb1325a1cece981e8a296ab8f0f9b63ae357bd0784a9faaf548cc7b480707a"
dependencies = [
"castaway",
"cfg-if",
"itoa",
"rustversion",
"ryu",
"static_assertions",
]
[[package]]
name = "crossbeam-deque"
version = "0.8.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9dd111b7b7f7d55b72c0a6ae361660ee5853c9af73f70c3c2ef6858b950e2e51"
dependencies = [
"crossbeam-epoch",
"crossbeam-utils",
]
[[package]]
name = "crossbeam-epoch"
version = "0.9.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e"
dependencies = [
"crossbeam-utils",
]
[[package]]
name = "crossbeam-utils"
version = "0.8.21"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28"
[[package]]
name = "dary_heap"
version = "0.3.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "04d2cd9c18b9f454ed67da600630b021a8a80bf33f8c95896ab33aaf1c26b728"
[[package]]
name = "either"
version = "1.15.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719"
[[package]]
name = "equivalent"
version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f"
[[package]]
name = "fancy-regex"
version = "0.16.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bf04c5ec15464ace8355a7b440a33aece288993475556d461154d7a62ad9947c"
dependencies = [
"bit-set",
"regex-automata",
"regex-syntax",
]
[[package]]
name = "getrandom"
version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "26145e563e54f2cadc477553f1ec5ee650b00862f0a58bcd12cbdc5f0ea2d2f4"
dependencies = [
"cfg-if",
"libc",
"r-efi",
"wasi",
]
[[package]]
name = "hashbrown"
version = "0.15.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1"
[[package]]
name = "heck"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea"
[[package]]
name = "indexmap"
version = "2.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f2481980430f9f78649238835720ddccc57e52df14ffce1c6f37391d61b563e9"
dependencies = [
"equivalent",
"hashbrown",
]
[[package]]
name = "indoc"
version = "2.0.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f4c7245a08504955605670dbf141fceab975f15ca21570696aebe9d2e71576bd"
[[package]]
name = "itoa"
version = "1.0.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4a5f13b858c8d314ee3e8f639011f7ccefe71f97f96e50151fb991f267928e2c"
[[package]]
name = "libc"
version = "0.2.175"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6a82ae493e598baaea5209805c49bbf2ea7de956d50d7da0da1164f9c6d28543"
[[package]]
name = "log"
version = "0.4.28"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "34080505efa8e45a4b816c349525ebe327ceaa8559756f0356cba97ef3bf7432"
[[package]]
name = "memchr"
version = "2.7.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "32a282da65faaf38286cf3be983213fcf1d2e2a58700e808f83f4ea9a4804bc0"
[[package]]
name = "memoffset"
version = "0.9.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "488016bfae457b036d996092f6cb448677611ce4449e970ceaf42695203f218a"
dependencies = [
"autocfg",
]
[[package]]
name = "once_cell"
version = "1.21.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d"
[[package]]
name = "portable-atomic"
version = "1.11.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f84267b20a16ea918e43c6a88433c2d54fa145c92a811b5b047ccbe153674483"
[[package]]
name = "proc-macro2"
version = "1.0.101"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "89ae43fd86e4158d6db51ad8e2b80f313af9cc74f5c0e03ccb87de09998732de"
dependencies = [
"unicode-ident",
]
[[package]]
name = "pyo3"
version = "0.23.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7778bffd85cf38175ac1f545509665d0b9b92a198ca7941f131f85f7a4f9a872"
dependencies = [
"cfg-if",
"indoc",
"libc",
"memoffset",
"once_cell",
"portable-atomic",
"pyo3-build-config",
"pyo3-ffi",
"pyo3-macros",
"unindent",
]
[[package]]
name = "pyo3-build-config"
version = "0.23.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "94f6cbe86ef3bf18998d9df6e0f3fc1050a8c5efa409bf712e661a4366e010fb"
dependencies = [
"once_cell",
"target-lexicon",
]
[[package]]
name = "pyo3-ffi"
version = "0.23.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e9f1b4c431c0bb1c8fb0a338709859eed0d030ff6daa34368d3b152a63dfdd8d"
dependencies = [
"libc",
"pyo3-build-config",
]
[[package]]
name = "pyo3-log"
version = "0.12.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "45192e5e4a4d2505587e27806c7b710c231c40c56f3bfc19535d0bb25df52264"
dependencies = [
"arc-swap",
"log",
"pyo3",
]
[[package]]
name = "pyo3-macros"
version = "0.23.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fbc2201328f63c4710f68abdf653c89d8dbc2858b88c5d88b0ff38a75288a9da"
dependencies = [
"proc-macro2",
"pyo3-macros-backend",
"quote",
"syn",
]
[[package]]
name = "pyo3-macros-backend"
version = "0.23.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fca6726ad0f3da9c9de093d6f116a93c1a38e417ed73bf138472cf4064f72028"
dependencies = [
"heck",
"proc-macro2",
"pyo3-build-config",
"quote",
"syn",
]
[[package]]
name = "quote"
version = "1.0.40"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1885c039570dc00dcb4ff087a89e185fd56bae234ddc7f056a945bf36467248d"
dependencies = [
"proc-macro2",
]
[[package]]
name = "r-efi"
version = "5.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f"
[[package]]
name = "rayon"
version = "1.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "368f01d005bf8fd9b1206fb6fa653e6c4a81ceb1466406b81792d87c5677a58f"
dependencies = [
"either",
"rayon-core",
]
[[package]]
name = "rayon-core"
version = "1.13.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "22e18b0f0062d30d4230b2e85ff77fdfe4326feb054b9783a3460d8435c8ab91"
dependencies = [
"crossbeam-deque",
"crossbeam-utils",
]
[[package]]
name = "regex-automata"
version = "0.4.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6b9458fa0bfeeac22b5ca447c63aaf45f28439a709ccd244698632f9aa6394d6"
dependencies = [
"aho-corasick",
"memchr",
"regex-syntax",
]
[[package]]
name = "regex-syntax"
version = "0.8.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "caf4aa5b0f434c91fe5c7f1ecb6a5ece2130b02ad2a590589dda5146df959001"
[[package]]
name = "rustbpe"
version = "0.1.0"
dependencies = [
"ahash",
"compact_str",
"dary_heap",
"fancy-regex",
"indexmap",
"log",
"pyo3",
"pyo3-log",
"rayon",
]
[[package]]
name = "rustversion"
version = "1.0.22"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b39cdef0fa800fc44525c84ccb54a029961a8215f9619753635a9c0d2538d46d"
[[package]]
name = "ryu"
version = "1.0.20"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "28d3b2b1366ec20994f1fd18c3c594f05c5dd4bc44d8bb0c1c632c8d6829481f"
[[package]]
name = "static_assertions"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f"
[[package]]
name = "syn"
version = "2.0.106"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ede7c438028d4436d71104916910f5bb611972c5cfd7f89b8300a8186e6fada6"
dependencies = [
"proc-macro2",
"quote",
"unicode-ident",
]
[[package]]
name = "target-lexicon"
version = "0.12.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1"
[[package]]
name = "unicode-ident"
version = "1.0.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5a5f39404a5da50712a4c1eecf25e90dd62b613502b7e925fd4e4d19b5c96512"
[[package]]
name = "unindent"
version = "0.2.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7264e107f553ccae879d21fbea1d6724ac785e8c3bfc762137959b5802826ef3"
[[package]]
name = "version_check"
version = "0.9.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a"
[[package]]
name = "wasi"
version = "0.14.4+wasi-0.2.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "88a5f4a424faf49c3c2c344f166f0662341d470ea185e939657aaff130f0ec4a"
dependencies = [
"wit-bindgen",
]
[[package]]
name = "wit-bindgen"
version = "0.45.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5c573471f125075647d03df72e026074b7203790d41351cd6edc96f46bcccd36"
[[package]]
name = "zerocopy"
version = "0.8.26"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1039dd0d3c310cf05de012d8a39ff557cb0d23087fd44cad61df08fc31907a2f"
dependencies = [
"zerocopy-derive",
]
[[package]]
name = "zerocopy-derive"
version = "0.8.26"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9ecf5b4cc5364572d7f4c329661bcc82724222973f2cab6f050a4e5c22f75181"
dependencies = [
"proc-macro2",
"quote",
"syn",
]

rustbpe/Cargo.toml
View File

@ -1,15 +0,0 @@
[package]
name = "rustbpe"
version = "0.1.0"
edition = "2024"
[dependencies]
dary_heap = "0.3"
indexmap = "2.2"
fancy-regex = "0.16.1"
log = "0.4.28"
pyo3 = { version = "0.23.3", features = ["extension-module"] }
pyo3-log = "0.12.4"
ahash = "0.8.12"
rayon = "1.11.0"
compact_str = "0.9.0"

rustbpe/README.md
View File

@ -1,5 +0,0 @@
# rustbpe
> The missing tiktoken training code
A very lightweight Rust library for training a GPT tokenizer. The issue is that the inference library [tiktoken](https://github.com/openai/tiktoken) is great, but only does inference. Separately, the huggingface [tokenizers](https://github.com/huggingface/tokenizers) library does training, but it is rather bloated and really hard to navigate because it has to support all the historical baggage of how people have dealt with tokenizers over the years. More recently, I also wrote the [minbpe](https://github.com/karpathy/minbpe) library which does both training and inference, but only in inefficient Python. Basically what I really want is a non-fancy, super simple, but still relatively efficient training code for a GPT tokenizer (more efficient than minbpe, much cleaner/simpler than tokenizers), and then to export the trained vocab for inference with tiktoken. Does that make sense? So here we are. There are more opportunities for optimization here; I just stopped a bit early because unlike minbpe before it, rustbpe is now simple and fast enough, and not a significant bottleneck for nanochat.
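To make the tiktoken handoff concrete, here is a minimal sketch of how the `Tokenizer` below could be trained and exported for inference. It assumes the `rustbpe` extension module is built and importable and that `tiktoken` is installed; the toy corpus, vocab size, encoding name, and special token are placeholders, not the repo's actual wrapper code.

```python
# Hypothetical export sketch (not the repo's actual wrapper): train rustbpe on a
# toy corpus, then build a tiktoken Encoding from the learned merges.
import rustbpe
import tiktoken

tok = rustbpe.Tokenizer()
texts = ["hello world", "hello hello bpe"]  # stand-in corpus
tok.train_from_iterator(iter(texts), vocab_size=280)  # 256 byte tokens + up to 24 merges

# rank table: token bytes -> token id (256 raw bytes followed by learned merges)
mergeable_ranks = {bytes(tb): rank for tb, rank in tok.get_mergeable_ranks()}

enc = tiktoken.Encoding(
    name="rustbpe-demo",                    # placeholder name
    pat_str=tok.get_pattern(),              # the GPT-4 style split regex used during training
    mergeable_ranks=mergeable_ranks,
    special_tokens={"<|endoftext|>": len(mergeable_ranks)},
)
print(enc.encode("hello bpe"), tok.encode("hello bpe"))  # the two should agree on plain text
```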

rustbpe/src/lib.rs
View File

@ -1,491 +0,0 @@
use std::cmp::Ordering;
use std::collections::HashMap as StdHashMap;
use dary_heap::OctonaryHeap;
use fancy_regex::Regex;
use pyo3::prelude::*;
use ahash::{AHashMap, AHashSet};
use compact_str::CompactString;
use rayon::prelude::*;
// Default GPT-4 style regex pattern for splitting text
const GPT4_PATTERN: &str = r"'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+";
type Pair = (u32, u32);
/// A Byte Pair Encoding tokenizer that matches the GPT-4 style implementation
#[pyclass]
pub struct Tokenizer {
/// Maps pairs of token IDs to their merged token ID
pub merges: StdHashMap<Pair, u32>,
/// The regex pattern used for text splitting
pub pattern: String,
/// Compiled regex for efficiency
compiled_pattern: Regex,
}
// ------------------------ internal helpers ------------------------
#[derive(Clone, Debug)]
struct Word {
ids: Vec<u32>,
}
impl Word {
#[inline]
fn new(ids: Vec<u32>) -> Self {
Self { ids }
}
#[inline]
fn pairs<'a>(&'a self) -> impl Iterator<Item = Pair> + 'a {
self.ids.windows(2).map(|w| (w[0], w[1]))
}
/// Merge all non-overlapping occurrences of pair -> new_id.
/// Returns a small Vec of local pair-count deltas for THIS word only:
/// -1 for removed pairs, +1 for newly created pairs.
///
/// NOTE: this version deliberately avoids a HashMap in the hot loop.
fn merge_pair(&mut self, pair: Pair, new_id: u32) -> Vec<(Pair, i32)> {
let (a, b) = pair;
let n = self.ids.len();
if n < 2 {
return Vec::new();
}
let mut out: Vec<u32> = Vec::with_capacity(n);
let mut deltas: Vec<(Pair, i32)> = Vec::with_capacity(6);
let mut i = 0;
while i < n {
if i + 1 < n && self.ids[i] == a && self.ids[i + 1] == b {
let left = out.last().copied();
let right = if i + 2 < n { Some(self.ids[i + 2]) } else { None };
// remove old pairs
if let Some(x) = left {
deltas.push(((x, a), -1));
deltas.push(((x, new_id), 1));
}
deltas.push(((a, b), -1));
if let Some(y) = right {
deltas.push(((b, y), -1));
deltas.push(((new_id, y), 1));
}
// write merged token
out.push(new_id);
i += 2; // skip 'a' and 'b'
} else {
out.push(self.ids[i]);
i += 1;
}
}
self.ids = out;
deltas
}
}
#[derive(Debug, Eq)]
struct MergeJob {
pair: Pair,
count: u64,
/// set of word indices where this pair may occur and needs processing
pos: AHashSet<usize>,
}
impl PartialEq for MergeJob {
fn eq(&self, other: &Self) -> bool {
self.count == other.count && self.pair == other.pair
}
}
impl PartialOrd for MergeJob {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}
impl Ord for MergeJob {
fn cmp(&self, other: &Self) -> Ordering {
// Max-heap by count; tie-break to ascending pair order (deterministic)
if self.count != other.count {
self.count.cmp(&other.count)
} else {
// ascending order on the pair when counts tie
other.pair.cmp(&self.pair)
}
}
}
#[inline]
fn count_pairs_parallel(
words: &[Word],
counts: &[i32],
) -> (AHashMap<Pair, i32>, AHashMap<Pair, AHashSet<usize>>) {
words
.par_iter()
.enumerate()
.map(|(i, w)| {
let mut local_pc: AHashMap<Pair, i32> = AHashMap::new();
let mut local_wtu: AHashMap<Pair, AHashSet<usize>> = AHashMap::new();
if w.ids.len() >= 2 && counts[i] != 0 {
for (a, b) in w.pairs() {
*local_pc.entry((a, b)).or_default() += counts[i];
local_wtu.entry((a, b)).or_default().insert(i);
}
}
(local_pc, local_wtu)
})
.reduce(
|| (AHashMap::new(), AHashMap::new()),
|(mut acc_pc, mut acc_wtu), (pc, wtu)| {
for (k, v) in pc {
*acc_pc.entry(k).or_default() += v;
}
for (k, s) in wtu {
acc_wtu.entry(k).or_default().extend(s);
}
(acc_pc, acc_wtu)
},
)
}
// ------------------------ END helpers ------------------------
impl Tokenizer {
/// Core incremental BPE training given unique words and their counts.
/// `words`: one entry per unique chunk (Vec<u32> of token-ids/bytes).
/// `counts`: same length as `words`, count per chunk.
fn train_core_incremental(&mut self, mut words: Vec<Word>, counts: Vec<i32>, vocab_size: u32) {
assert!(vocab_size >= 256, "vocab_size must be at least 256");
let num_merges = vocab_size - 256;
log::info!("Starting BPE training: {} merges to compute", num_merges);
self.merges.clear();
// ---- Initial pair_counts and where_to_update (parallel) ----
log::info!("Computing initial pair counts from {} unique sequences", words.len());
let (mut pair_counts, mut where_to_update) = count_pairs_parallel(&words, &counts);
// ---- Build heap ----
log::info!("Building heap with {} unique pairs", pair_counts.len());
let mut heap = OctonaryHeap::with_capacity(pair_counts.len());
for (pair, pos) in where_to_update.drain() {
let c = *pair_counts.get(&pair).unwrap_or(&0);
if c > 0 {
heap.push(MergeJob {
pair,
count: c as u64,
pos,
});
}
}
// ---- Merge loop ----
log::info!("Starting merge loop");
let mut merges_done = 0u32;
let mut last_log_percent = 0u32;
while merges_done < num_merges {
let Some(mut top) = heap.pop() else { break; };
// Lazy refresh
let current = *pair_counts.get(&top.pair).unwrap_or(&0);
if top.count != current as u64 {
top.count = current as u64;
if top.count > 0 {
heap.push(top);
}
continue;
}
if top.count == 0 {
break;
}
// Record merge
let new_id = 256 + merges_done;
self.merges.insert(top.pair, new_id);
// Merge this pair in all words where it occurs
let mut local_pos_updates: AHashMap<Pair, AHashSet<usize>> = AHashMap::new();
for &word_idx in &top.pos {
// Apply merge to this word and collect pair-count deltas
let changes = words[word_idx].merge_pair(top.pair, new_id);
// Update global pair counts based on this word's count
for (pair, delta) in changes {
let delta_total = delta * counts[word_idx];
if delta_total != 0 {
*pair_counts.entry(pair).or_default() += delta_total;
if delta > 0 {
local_pos_updates.entry(pair).or_default().insert(word_idx);
}
}
}
}
// Add the updated pair counts back to the heap
for (pair, pos) in local_pos_updates {
let cnt = *pair_counts.get(&pair).unwrap_or(&0);
if cnt > 0 {
heap.push(MergeJob {
pair,
count: cnt as u64,
pos,
});
}
}
merges_done += 1;
// Log progress every 1%
let current_percent = (merges_done * 100) / num_merges;
if current_percent > last_log_percent {
log::info!(
"Progress: {}% ({}/{} merges) - Last merge: {:?} -> {} (frequency: {})",
current_percent, merges_done, num_merges, top.pair, new_id, top.count
);
last_log_percent = current_percent;
}
}
log::info!("Finished training: {} merges completed", merges_done);
}
}
/// Public methods for the Tokenizer class that will be exposed to Python.
#[pymethods]
impl Tokenizer {
/// Create a new Tokenizer
#[new]
pub fn new() -> Self {
Self {
merges: StdHashMap::new(),
pattern: String::new(),
compiled_pattern: Regex::new("").expect("Empty regex should be valid"),
}
}
/// Train from a streaming iterator (parallel ingestion).
/// We refill a Rust Vec<String> buffer under the GIL, then release the GIL
/// to do the heavy splitting and counting **in parallel** with rayon.
#[pyo3(signature = (iterator, vocab_size, buffer_size=8192, pattern=None))]
#[pyo3(text_signature = "(self, iterator, vocab_size, buffer_size=8192, pattern=None)")]
pub fn train_from_iterator(
&mut self,
py: pyo3::Python<'_>,
iterator: &pyo3::Bound<'_, pyo3::PyAny>,
vocab_size: u32,
buffer_size: usize,
pattern: Option<String>,
) -> PyResult<()> {
// Use provided pattern or default to GPT-4 pattern
let pattern_str = pattern.unwrap_or_else(|| GPT4_PATTERN.to_string());
// Update the stored pattern and compile it
self.pattern = pattern_str.clone();
self.compiled_pattern = Regex::new(&pattern_str)
.map_err(|e| pyo3::exceptions::PyValueError::new_err(format!("Invalid regex pattern: {}", e)))?;
// Prepare a true Python iterator object
let py_iter: pyo3::Py<pyo3::PyAny> = unsafe {
pyo3::Py::from_owned_ptr_or_err(py, pyo3::ffi::PyObject_GetIter(iterator.as_ptr()))?
};
// Global chunk counts
let mut counts: AHashMap<CompactString, i32> = AHashMap::new();
// Temporary buffer we refill under the GIL
let mut buf: Vec<String> = Vec::with_capacity(buffer_size);
log::info!("Processing sequences from iterator (buffer_size: {})", buffer_size);
let mut total_sequences = 0u64;
// Helper: refill `buf` with up to `buffer_size` strings from the Python iterator.
// Returns Ok(true) if the iterator is exhausted, Ok(false) otherwise.
let refill = |buf: &mut Vec<String>| -> PyResult<bool> {
pyo3::Python::with_gil(|py| {
buf.clear();
let it = py_iter.bind(py);
loop {
if buf.len() >= buffer_size {
return Ok(false);
}
// next(it)
let next_obj = unsafe {
pyo3::Bound::from_owned_ptr_or_opt(py, pyo3::ffi::PyIter_Next(it.as_ptr()))
};
match next_obj {
Some(obj) => {
let s: String = obj.extract()?;
buf.push(s);
}
None => {
if pyo3::PyErr::occurred(py) {
return Err(pyo3::PyErr::fetch(py));
} else {
return Ok(true); // exhausted
}
}
}
}
})
};
// Stream ingestion loop: refill under GIL, process without GIL (parallel)
loop {
let exhausted = refill(&mut buf)?;
if buf.is_empty() && exhausted {
break;
}
total_sequences += buf.len() as u64;
let pattern = self.compiled_pattern.clone();
let local: AHashMap<CompactString, i32> = py.allow_threads(|| {
buf.par_iter()
.map(|s| {
let mut m: AHashMap<CompactString, i32> = AHashMap::new();
for mat in pattern.find_iter(s) {
let piece = mat.expect("regex match failed").as_str();
*m.entry(CompactString::from(piece)).or_default() += 1;
}
m
})
.reduce(
|| AHashMap::new(),
|mut a, b| {
for (k, v) in b {
*a.entry(k).or_default() += v;
}
a
},
)
});
// Merge local into global (single-threaded)
for (k, v) in local {
*counts.entry(k).or_default() += v;
}
if exhausted {
break;
}
}
log::info!("Processed {} sequences total, {} unique", total_sequences, counts.len());
// Materialize words & counts
let mut words = Vec::with_capacity(counts.len());
let mut cvec = Vec::with_capacity(counts.len());
for (chunk, c) in counts.into_iter() {
words.push(Word::new(chunk.as_bytes().iter().map(|&b| b as u32).collect()));
cvec.push(c);
}
self.train_core_incremental(words, cvec, vocab_size);
Ok(())
}
/// Return the regex pattern
pub fn get_pattern(&self) -> String {
self.pattern.clone()
}
/// Return the mergeable ranks (token bytes -> token id / rank)
pub fn get_mergeable_ranks(&self) -> Vec<(Vec<u8>, u32)> {
let mut mergeable_ranks = Vec::new();
// Build vocabulary incrementally from low to high token IDs
let mut token_bytes: Vec<Vec<u8>> = (0..256_u32).map(|i| vec![i as u8]).collect();
for (i, bytes) in token_bytes.iter().enumerate() {
mergeable_ranks.push((bytes.clone(), i as u32));
}
// Sort merges by token id (so we can reconstruct bytes progressively)
let mut sorted_merges: Vec<_> = self.merges.iter().collect();
sorted_merges.sort_by_key(|&(_, &token_id)| token_id);
for (&pair, &merged_id) in sorted_merges {
let (left, right) = pair;
let mut merged_bytes = token_bytes[left as usize].clone();
merged_bytes.extend(&token_bytes[right as usize]);
if token_bytes.len() <= merged_id as usize {
token_bytes.resize(merged_id as usize + 1, Vec::new());
}
token_bytes[merged_id as usize] = merged_bytes.clone();
mergeable_ranks.push((merged_bytes, merged_id));
}
mergeable_ranks
}
/// Encode a string into token IDs
pub fn encode(&self, text: &str) -> Vec<u32> {
let mut all_ids = Vec::new();
// Split text using the regex pattern
for m in self.compiled_pattern.find_iter(text) {
let chunk = m.expect("regex match failed").as_str();
// Convert chunk to bytes then to u32 IDs
let mut ids: Vec<u32> = chunk.bytes().map(|b| b as u32).collect();
// Apply merges iteratively
while ids.len() >= 2 {
// Find the best pair to merge
let mut best_pair: Option<(usize, Pair, u32)> = None;
for i in 0..ids.len() - 1 {
let pair: Pair = (ids[i], ids[i + 1]);
if let Some(&new_id) = self.merges.get(&pair) {
if best_pair.is_none() || new_id < best_pair.unwrap().2 {
best_pair = Some((i, pair, new_id));
}
}
}
// If we found a pair to merge, apply it
if let Some((idx, _pair, new_id)) = best_pair {
ids[idx] = new_id;
ids.remove(idx + 1);
} else {
// No more merges possible
break;
}
}
all_ids.extend(ids);
}
all_ids
}
/// Encode multiple texts in parallel using rayon.
/// Returns a list of token ID vectors, one per input text.
#[pyo3(signature = (texts))]
#[pyo3(text_signature = "(self, texts)")]
pub fn batch_encode(&self, py: Python<'_>, texts: Vec<String>) -> PyResult<Vec<Vec<u32>>> {
// Release Python GIL and encode in parallel using rayon
let results = py.allow_threads(|| {
texts
.par_iter()
.map(|text| self.encode(text))
.collect::<Vec<Vec<u32>>>()
});
Ok(results)
}
}
#[pymodule]
fn rustbpe(m: &Bound<'_, PyModule>) -> PyResult<()> {
pyo3_log::init(); // forwards Rust `log` to Python's `logging`
m.add_class::<Tokenizer>()?;
Ok(())
}
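For reference, the greedy merge loop in `encode` above can be rendered in plain Python roughly as follows. This is an illustration of the algorithm, not code from the repo; it operates on a single regex chunk, with `merges` assumed to be a `{(id, id): new_id}` dict mirroring the Rust field.

```python
# Illustrative Python equivalent of Tokenizer::encode for one regex chunk (not repo code).
def encode_chunk(chunk: str, merges: dict[tuple[int, int], int]) -> list[int]:
    ids = list(chunk.encode("utf-8"))  # start from raw UTF-8 bytes
    while len(ids) >= 2:
        # find the adjacent pair whose merge was learned earliest (lowest new id)
        best = None  # (index, new_id)
        for i in range(len(ids) - 1):
            new_id = merges.get((ids[i], ids[i + 1]))
            if new_id is not None and (best is None or new_id < best[1]):
                best = (i, new_id)
        if best is None:
            break  # no applicable merges remain
        i, new_id = best
        ids[i:i + 2] = [new_id]  # replace the pair with the merged token
    return ids
```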

115
scaling_laws.sh Normal file
View File

@ -0,0 +1,115 @@
#!/bin/bash
FLOPS_BUDGETS=(
1e18
3e18
6e18
)
DEPTHS=(8 10 12 14 16 18 20)
NPROC_PER_NODE="${NPROC_PER_NODE:-8}"
WANDB_RUN="${WANDB_RUN:-scaling}"
EVAL_TOKENS=$((100 * 524288)) # ~52M tokens for final eval (default is ~10M)
export OMP_NUM_THREADS=1
export NANOCHAT_BASE_DIR="${NANOCHAT_BASE_DIR:-$HOME/.cache/nanochat}"
source .venv/bin/activate
RESULTS_DIR="$NANOCHAT_BASE_DIR/scaling_laws_results"
mkdir -p "$RESULTS_DIR"
RESULTS_FILE="$RESULTS_DIR/results.csv"
# Write CSV header only if file doesn't exist
if [ ! -f "$RESULTS_FILE" ]; then
echo "flops_budget,depth,model_dim,num_params,num_scaling_params,num_iterations,tokens_trained,param_data_ratio,val_bpb,core_score,train_time_sec" > "$RESULTS_FILE"
fi
log() {
echo "[$(date '+%Y-%m-%d %H:%M:%S')] $1"
}
# Check if a run already exists in results
run_exists() {
local flops=$1
local depth=$2
grep -q "^${flops},${depth}," "$RESULTS_FILE" 2>/dev/null
}
# =============================================================================
# Main Loop
# =============================================================================
for flops in "${FLOPS_BUDGETS[@]}"; do
log "=============================================="
log "Compute budget: $flops FLOPs"
log "=============================================="
for d in "${DEPTHS[@]}"; do
# Skip if already completed
if run_exists "$flops" "$d"; then
log "Skipping d=$d at $flops FLOPs (already in results)"
continue
fi
log "Training d=$d at $flops FLOPs..."
# Unique tag for this run
TAG="scaling_${flops}_d${d}"
# Record start time
START_TIME=$(date +%s)
# Train the model with fixed flops budget
# The script will auto-calculate num_iterations to hit target_flops
# CORE eval happens once at the end (999999 ensures only final step)
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train -- \
--depth=$d \
--target-flops=$flops \
--target-param-data-ratio=-1 \
--run="${WANDB_RUN}_${TAG}" \
--model-tag="${TAG}" \
--eval-tokens=$EVAL_TOKENS \
--core-metric-every=999999 \
--core-metric-max-per-task=-1 \
--sample-every=-1 \
--save-every=-1 \
2>&1 | tee "$RESULTS_DIR/${TAG}_train.log"
END_TIME=$(date +%s)
TRAIN_TIME=$((END_TIME - START_TIME))
# Extract training stats from the log
LOG_FILE="$RESULTS_DIR/${TAG}_train.log"
NUM_PARAMS=$(grep "Number of parameters:" "$LOG_FILE" | tail -1 | grep -oP '[\d,]+' | head -1 | tr -d ',')
NUM_SCALING_PARAMS=$(grep "Number of parameters:" "$LOG_FILE" | tail -1 | grep -oP 'scaling: [\d,]+' | grep -oP '[\d,]+' | tr -d ',')
NUM_ITERS=$(grep "Calculated number of iterations" "$LOG_FILE" | tail -1 | sed 's/.*: //' | tr -d ',')
# Calculate tokens trained (iterations * batch_size, default 524288)
TOKENS_TRAINED=$((NUM_ITERS * 524288))
# Param:data ratio (using scaling params per Kaplan et al.)
PARAM_DATA_RATIO=$(python -c "print(f'{$TOKENS_TRAINED / $NUM_SCALING_PARAMS:.2f}')")
# Model dim
MODEL_DIM=$((d * 64))
# Val BPB from final eval
VAL_BPB=$(grep "Validation bpb:" "$LOG_FILE" | tail -1 | grep -oP '[\d.]+$')
# Extract CORE score from training log (evaluated on final step)
CORE_SCORE=$(grep "CORE metric:" "$LOG_FILE" | tail -1 | awk '{print $NF}')
if [ -z "$CORE_SCORE" ]; then
log "WARNING: Could not extract CORE score for d=$d"
CORE_SCORE="0.0"
fi
log " Params: $NUM_PARAMS, Iters: $NUM_ITERS, Ratio: $PARAM_DATA_RATIO, Val BPB: $VAL_BPB, CORE: $CORE_SCORE"
# Append to CSV
echo "$flops,$d,$MODEL_DIM,$NUM_PARAMS,$NUM_SCALING_PARAMS,$NUM_ITERS,$TOKENS_TRAINED,$PARAM_DATA_RATIO,$VAL_BPB,$CORE_SCORE,$TRAIN_TIME" >> "$RESULTS_FILE"
done
done
log "=============================================="
log "Scaling Laws Sweep Complete"
log "=============================================="
log "Results saved to: $RESULTS_FILE"
echo ""
echo "Results:"
column -t -s',' "$RESULTS_FILE"
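Once the sweep finishes, `results.csv` can be analyzed however you like; below is a minimal sketch (assuming the CSV written by the script above, with the path adjusted to your setup) that reports the best depth per compute budget.

```python
# Minimal analysis sketch: for each FLOPs budget, report the depth with the lowest val_bpb.
# Assumes results.csv as written by scaling_laws.sh; the path below is a placeholder.
import csv
from collections import defaultdict

by_budget = defaultdict(list)
with open("results.csv") as f:
    for row in csv.DictReader(f):
        by_budget[row["flops_budget"]].append(row)

for flops, runs in by_budget.items():
    best = min(runs, key=lambda r: float(r["val_bpb"]))
    print(f"{flops}: best depth={best['depth']} "
          f"(val_bpb={best['val_bpb']}, params={best['num_params']}, CORE={best['core_score']})")
```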

scripts/base_loss.py
View File

@ -5,48 +5,108 @@ Loads a checkpoint, and:
Example run as:
torchrun --standalone --nproc_per_node=8 -m scripts.base_loss
To evaluate a HuggingFace model:
python -m scripts.base_loss --hf-path openai-community/gpt2
"""
import os
import argparse
from contextlib import nullcontext
import torch
from nanochat.checkpoint_manager import load_model
from nanochat.common import compute_init, print0, compute_cleanup, autodetect_device_type
from nanochat.dataloader import tokenizing_distributed_data_loader
from nanochat.tokenizer import get_token_bytes
from nanochat.dataloader import tokenizing_distributed_data_loader_bos_bestfit
from nanochat.tokenizer import get_token_bytes, HuggingFaceTokenizer
from nanochat.loss_eval import evaluate_bpb
from nanochat.engine import Engine
# Configuration
device_batch_size = 32
split_tokens = 20*524288 # number of tokens to evaluate per split
model_tag = None # optional model tag for the output directory name
model_step = None # optional model step for the output directory name
device_type = "" # cuda|cpu|mps (empty => autodetect)
exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
# -----------------------------------------------------------------------------
# HuggingFace loading utilities, making the APIs match up to those of nanochat
class ModelWrapper:
"""Lightweight wrapper for a HuggingFace model"""
def __init__(self, model, max_seq_len=None):
self.model = model
self.max_seq_len = max_seq_len
def __call__(self, input_ids, targets=None, loss_reduction='mean'):
logits = self.model(input_ids).logits
if targets is None:
return logits
else:
loss = torch.nn.functional.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1, reduction=loss_reduction)
return loss
def get_device(self):
return next(self.model.parameters()).device
def load_hf_model(hf_path: str, device):
print0(f"Loading model from: {hf_path}")
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained(hf_path)
model.to(device)
model.eval()
max_seq_len = 1024 if "openai-community/gpt2" in hf_path else None
model = ModelWrapper(model, max_seq_len=max_seq_len)
tokenizer = HuggingFaceTokenizer.from_pretrained(hf_path)
return model, tokenizer
def get_hf_token_bytes(tokenizer, device="cpu"):
"""Compute token_bytes tensor for a HuggingFace tokenizer."""
vocab_size = tokenizer.tokenizer.get_vocab_size()
token_bytes = torch.zeros(vocab_size, dtype=torch.int64, device=device)
for token_id in range(vocab_size):
token_str = tokenizer.tokenizer.decode([token_id])
token_bytes[token_id] = len(token_str.encode('utf-8')) # Count UTF-8 bytes
return token_bytes
# CLI arguments
parser = argparse.ArgumentParser(description="Evaluate loss on train/val splits and sample from model")
parser.add_argument("--device-batch-size", type=int, default=32, help="per-device batch size")
parser.add_argument("--split-tokens", type=int, default=40*524288, help="number of tokens to evaluate per split")
parser.add_argument("--model-tag", type=str, default=None, help="model tag for checkpoint directory")
parser.add_argument("--model-step", type=int, default=None, help="model step to load")
parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
parser.add_argument("--hf-path", type=str, default=None, help="HuggingFace model path (e.g. openai-community/gpt2)")
args = parser.parse_args()
# Load the base model and the tokenizer
device_type = autodetect_device_type() if device_type == "" else device_type
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
model, tokenizer, meta = load_model("base", device, phase="eval", model_tag=model_tag, step=model_step)
sequence_len = meta["model_config"]["sequence_len"] # could be arbitrary really
print0(f"Device: {device} | DDP rank: {ddp_rank} | DDP local rank: {ddp_local_rank} | DDP world size: {ddp_world_size}")
if args.hf_path is not None:
# Load HuggingFace model
model, tokenizer = load_hf_model(args.hf_path, device)
sequence_len = model.max_seq_len if model.max_seq_len else 1024
token_bytes = get_hf_token_bytes(tokenizer, device=device)
model_name = args.hf_path
else:
# Load local nanochat model
model, tokenizer, meta = load_model("base", device, phase="eval", model_tag=args.model_tag, step=args.model_step)
sequence_len = meta["model_config"]["sequence_len"]
token_bytes = get_token_bytes(device=device)
model_name = f"base_model (step {meta['step']})"
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
print0(f"Evaluating model: {model_name}")
# Evaluate the loss on each split
tokens_per_step = device_batch_size * sequence_len * ddp_world_size
assert split_tokens % tokens_per_step == 0, "split_tokens must be divisible by tokens_per_step"
steps = split_tokens // tokens_per_step
token_bytes = get_token_bytes(device=device)
tokens_per_step = args.device_batch_size * sequence_len * ddp_world_size
assert args.split_tokens % tokens_per_step == 0, "split_tokens must be divisible by tokens_per_step"
steps = args.split_tokens // tokens_per_step
bpb_results = {}
for split_name in ["train", "val"]:
loader = tokenizing_distributed_data_loader(device_batch_size, sequence_len, split_name, device=device)
loader = tokenizing_distributed_data_loader_bos_bestfit(tokenizer, args.device_batch_size, sequence_len, split_name, device=device)
with autocast_ctx:
bpb = evaluate_bpb(model, loader, steps, token_bytes)
print0(f"{split_name} bpb: {bpb:.4f}")
bpb_results[split_name] = bpb
print0(f"Model: {model_name}, {split_name} bpb: {bpb:.6f}")
# Master process also samples from the model
# Master process also samples from the model (only for nanochat models)
samples = []
if ddp_rank == 0:
if ddp_rank == 0 and args.hf_path is None:
prompts = [
"The capital of France is",
"The chemical symbol of gold is",
@ -69,6 +129,7 @@ if ddp_rank == 0:
from nanochat.report import get_report
get_report().log(section="Base model loss", data=[
{
"model": model_name,
"train bpb": bpb_results["train"],
"val bpb": bpb_results["val"],
},
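As context for the `token_bytes` tensors built above: bits-per-byte normalizes the model's cross-entropy by how many UTF-8 bytes the target tokens span, which is what makes the number comparable across tokenizers with different vocabularies. A schematic version of that reduction (an assumption about the general formula, not the actual `evaluate_bpb` implementation) looks like:

```python
# Schematic bits-per-byte reduction (illustrative; not the repo's evaluate_bpb).
import math
import torch

def bpb(token_losses: torch.Tensor, targets: torch.Tensor, token_bytes: torch.Tensor) -> float:
    # token_losses: per-position cross-entropy in nats (e.g. loss_reduction='none')
    # targets: per-position target token ids, with -1 marking ignored positions
    # token_bytes: [vocab_size] tensor with the UTF-8 byte length of each token
    mask = targets >= 0
    total_nats = token_losses[mask].sum()
    total_bytes = token_bytes[targets[mask]].sum()
    return (total_nats / (math.log(2) * total_bytes)).item()
```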

scripts/base_train.py
View File

@ -1,18 +1,19 @@
"""
Train model. Run as:
Train model. From root directory of the project, run as:
python base_train.py
python -m scripts.base_train
or distributed as:
torchrun --nproc_per_node=8 base_train.py
torchrun --nproc_per_node=8 -m scripts.base_train
If you are only on CPU/Macbook, you'll want to train a much much smaller LLM. Example:
python -m scripts.base_train --depth=4 --max_seq_len=512 --device_batch_size=1 --eval_tokens=512 --core_metric_every=-1 --total_batch_size=512 --num_iterations=20
python -m scripts.base_train --depth=4 --max-seq-len=512 --device-batch-size=1 --eval-tokens=512 --core-metric-every=-1 --total-batch-size=512 --num-iterations=20
"""
import os
os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"
import argparse
import time
from contextlib import nullcontext
@ -20,7 +21,7 @@ import wandb
import torch
from nanochat.gpt import GPT, GPTConfig
from nanochat.dataloader import tokenizing_distributed_data_loader, tokenizing_distributed_data_loader_with_state
from nanochat.dataloader import tokenizing_distributed_data_loader_bos_bestfit, tokenizing_distributed_data_loader_with_state_bos_bestfit
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir, autodetect_device_type
from nanochat.tokenizer import get_tokenizer, get_token_bytes
from nanochat.checkpoint_manager import save_checkpoint, load_checkpoint
@ -30,46 +31,51 @@ from scripts.base_eval import evaluate_model
print_banner()
# -----------------------------------------------------------------------------
# User settings
run = "dummy" # wandb run name default ("dummy" is special - we won't log to wandb)
# CLI arguments
parser = argparse.ArgumentParser(description="Pretrain base model")
# Logging
parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)")
# Runtime
device_type = "" # cuda|cpu|mps (empty => autodetect good device type default, in order: CUDA > MPS > CPU)
parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
# Model architecture
depth = 20 # the depth of the Transformer model to train, rest of the kwargs are derived
max_seq_len = 2048 # max context length
# Training horizon. Only one of these 3 will be used, in this order of precedence.
num_iterations = -1 # explicit number of steps of the optimization (-1 = disable)
target_flops = -1.0 # calculate num_iterations to reach target_flops. Useful for scaling laws experiments (-1 = disable)
target_param_data_ratio = 20 # calculate num_iterations to maintain fixed data:param ratio (Chinchilla=20) (-1 = disable)
parser.add_argument("--depth", type=int, default=20, help="depth of the Transformer model")
parser.add_argument("--aspect-ratio", type=int, default=64, help="model_dim = depth * aspect_ratio")
parser.add_argument("--head-dim", type=int, default=128, help="target head dimension for attention")
parser.add_argument("--max-seq-len", type=int, default=2048, help="max context length")
parser.add_argument("--window-pattern", type=str, default="SSSL", help="sliding window pattern tiled across layers: L=full, S=half context (e.g. 'SSL')")
# Training horizon (only one used, in order of precedence)
parser.add_argument("--num-iterations", type=int, default=-1, help="explicit number of optimization steps (-1 = disable)")
parser.add_argument("--target-flops", type=float, default=-1.0, help="calculate num_iterations to reach target_flops (-1 = disable)")
parser.add_argument("--target-param-data-ratio", type=int, default=8, help="calculate num_iterations to maintain data:param ratio (Chinchilla=20, -1 = disable)")
# Optimization
device_batch_size = 32 # per-device batch size (set to not OOM)
total_batch_size = 524288 # total desired batch size, in #tokens
embedding_lr = 0.2 # learning rate for the embedding parameters (Adam)
unembedding_lr = 0.004 # learning rate for the unembedding parameters (Adam)
weight_decay = 0.0 # weight decay for the embedding/unembedding parameters (Adam)
matrix_lr = 0.02 # learning rate for the matrix parameters (Muon)
grad_clip = 1.0 # gradient clipping value (0.0 = disabled)
warmup_ratio = 0.0 # ratio of iterations for LR warmup
warmdown_ratio = 0.2 # ratio of iterations for LR warmdown
final_lr_frac = 0.0 # final LR is this fraction of the initial LR
resume_from_step = -1 # resume training from this step of the optimization (-1 = disable)
parser.add_argument("--device-batch-size", type=int, default=32, help="per-device batch size")
parser.add_argument("--total-batch-size", type=int, default=524288, help="total batch size in tokens")
parser.add_argument("--embedding-lr", type=float, default=0.3, help="learning rate for embedding parameters (Adam)")
parser.add_argument("--unembedding-lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)")
parser.add_argument("--weight-decay", type=float, default=0.2, help="cautious weight decay for the Muon optimizer (for weights)")
parser.add_argument("--matrix-lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)")
parser.add_argument("--scalar-lr", type=float, default=0.5, help="learning rate for scalars (resid_lambdas, x0_lambdas)")
parser.add_argument("--adam-beta1", type=float, default=0.8, help="Adam beta1 for embedding/unembedding")
parser.add_argument("--adam-beta2", type=float, default=0.95, help="Adam beta2 for embedding/unembedding")
parser.add_argument("--warmup-ratio", type=float, default=0.0, help="ratio of iterations for LR warmup")
parser.add_argument("--warmdown-ratio", type=float, default=0.4, help="ratio of iterations for LR warmdown")
parser.add_argument("--final-lr-frac", type=float, default=0.0, help="final LR as fraction of initial LR")
parser.add_argument("--resume-from-step", type=int, default=-1, help="resume training from this step (-1 = disable)")
# Evaluation
eval_every = 250 # every how many steps to evaluate the model for val bpb
eval_tokens = 20*524288 # number of tokens to evaluate val loss on
core_metric_every = 2000 # every how many steps to evaluate the core metric (-1 = disable)
core_metric_max_per_task = 500 # examples per task in estimating the core metric
sample_every = 2000 # every how many steps to sample from the model
save_every = -1 # every how many steps to save model checkpoints (-1 = disable, and save only at the end of the run)
parser.add_argument("--eval-every", type=int, default=250, help="evaluate val bpb every N steps (-1 = disable)")
parser.add_argument("--eval-tokens", type=int, default=20*524288, help="number of tokens to evaluate val loss on")
parser.add_argument("--core-metric-every", type=int, default=2000, help="evaluate CORE metric every N steps (-1 = disable)")
parser.add_argument("--core-metric-max-per-task", type=int, default=500, help="examples per task for CORE metric")
parser.add_argument("--sample-every", type=int, default=2000, help="sample from model every N steps (-1 = disable)")
parser.add_argument("--save-every", type=int, default=-1, help="save checkpoints every N steps (-1 = only at end)")
# Output
model_tag = "" # optionally override the model tag for the output checkpoint directory name
# now allow CLI to override the settings via the configurator lol
config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
user_config = {k: globals()[k] for k in config_keys} # will be useful for logging
parser.add_argument("--model-tag", type=str, default=None, help="override model tag for checkpoint directory name")
args = parser.parse_args()
user_config = vars(args).copy() # for logging
# -----------------------------------------------------------------------------
# Compute init
device_type = autodetect_device_type() if device_type == "" else device_type
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
@ -77,8 +83,8 @@ synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None
get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0
# wandb logging init
use_dummy_wandb = run == "dummy" or not master_process
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat", name=run, config=user_config)
use_dummy_wandb = args.run == "dummy" or not master_process
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat", name=args.run, config=user_config)
# Tokenizer will be useful for evaluation, also we need the vocab size
tokenizer = get_tokenizer()
@ -87,9 +93,17 @@ vocab_size = tokenizer.get_vocab_size()
print0(f"Vocab size: {vocab_size:,}")
# Model kwargs are derived from the desired depth of the model
num_layers = depth
model_dim = depth * 64 # aspect ratio 64 (usually this is varied from 64 -> 128 as model size increases)
num_heads = max(1, (model_dim + 127) // 128) # head dim 128 (the division here is ceil div)
num_layers = args.depth
model_dim = args.depth * args.aspect_ratio
def find_num_heads(model_dim, target_head_dim):
# Find num_heads that divides model_dim evenly, with head_dim closest to target.
ideal = max(1, round(model_dim / target_head_dim))
for offset in range(model_dim):
for candidate in [ideal + offset, ideal - offset]:
if candidate > 0 and model_dim % candidate == 0:
return candidate
return 1
num_heads = find_num_heads(model_dim, args.head_dim)
num_kv_heads = num_heads # default is 1:1 GQA (Group Query Attention) ratio (i.e. GQA is disabled)
print0(f"num_layers: {num_layers}")
print0(f"model_dim: {model_dim}")
@ -98,66 +112,93 @@ print0(f"num_kv_heads: {num_kv_heads}")
# Optimizer / data / training length related hyperparameters
# figure out the needed gradient accumulation to reach the desired total batch size
tokens_per_fwdbwd = device_batch_size * max_seq_len # tokens per iteration for a single rank
tokens_per_fwdbwd = args.device_batch_size * args.max_seq_len # tokens per iteration for a single rank
world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size # total tokens per iteration for all ranks
assert total_batch_size % world_tokens_per_fwdbwd == 0
grad_accum_steps = total_batch_size // world_tokens_per_fwdbwd
print0(f"Tokens / micro-batch / rank: {device_batch_size} x {max_seq_len} = {tokens_per_fwdbwd:,}")
assert args.total_batch_size % world_tokens_per_fwdbwd == 0
grad_accum_steps = args.total_batch_size // world_tokens_per_fwdbwd
print0(f"Tokens / micro-batch / rank: {args.device_batch_size} x {args.max_seq_len} = {tokens_per_fwdbwd:,}")
print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}")
print0(f"Total batch size {total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}")
print0(f"Total batch size {args.total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}")
# Batch size scaling for learning rates (hyperparameters were tuned at reference batch size 2^19)
batch_lr_scale = 1.0
reference_batch_size = 2**19
batch_ratio = args.total_batch_size / reference_batch_size
if batch_ratio != 1.0:
# SGD: linear scaling with batch size is standard (not used in nanochat)
# AdamW: sqrt scaling is standard
# Muon: sqrt scaling is an assumption - not fully studied, but it's a second-order-ish optimizer
batch_lr_scale = batch_ratio ** 0.5
print0(f"Scaling LRs by {batch_lr_scale:.4f} for batch size {args.total_batch_size:,} (reference: {reference_batch_size:,})")
# Weight decay is tuned at d12 and its scaling seems to be \propto 1/channels^2 (or equivalently, \propto 1/depth^2 due to constant aspect ratio)
weight_decay_scaled = args.weight_decay * (12 / args.depth)**2
if args.depth != 12:
print0(f"Scaling weight decay from {args.weight_decay:.6f} to {weight_decay_scaled:.6f} for depth {args.depth}")
# -----------------------------------------------------------------------------
# Initialize the Model
# Create a new model with random weights
model_config_kwargs = dict(sequence_len=max_seq_len, vocab_size=vocab_size, n_layer=num_layers, n_head=num_heads, n_kv_head=num_kv_heads, n_embd=model_dim)
model_config_kwargs = dict(sequence_len=args.max_seq_len, vocab_size=vocab_size, n_layer=num_layers, n_head=num_heads, n_kv_head=num_kv_heads, n_embd=model_dim, window_pattern=args.window_pattern)
with torch.device("meta"):
# All tensors are created as meta tensors (they have shape/dtype but no data)
model_config = GPTConfig(**model_config_kwargs)
model = GPT(model_config)
model.to_empty(device=device)
model.init_weights()
model.to_empty(device=device) # All tensors get storage on target device but with uninitialized (garbage) data
model.init_weights() # All tensors get initialized
# If we are resuming, overwrite the model parameters with those of the checkpoint
base_dir = get_base_dir()
output_dirname = model_tag if model_tag else f"d{depth}" # e.g. d12
output_dirname = args.model_tag if args.model_tag else f"d{args.depth}" # e.g. d12
checkpoint_dir = os.path.join(base_dir, "base_checkpoints", output_dirname)
resuming = resume_from_step != -1
resuming = args.resume_from_step != -1
if resuming:
print0(f"Resuming optimization from step {resume_from_step}")
model_data, optimizer_data, meta_data = load_checkpoint(checkpoint_dir, resume_from_step, device, load_optimizer=True, rank=ddp_rank)
print0(f"Resuming optimization from step {args.resume_from_step}")
model_data, optimizer_data, meta_data = load_checkpoint(checkpoint_dir, args.resume_from_step, device, load_optimizer=True, rank=ddp_rank)
model.load_state_dict(model_data, strict=True, assign=True)
del model_data # free up this memory after the copy
orig_model = model # original, uncompiled model, for saving the raw model state_dict and for inference/evaluation (because the inputs may change shape)
model = torch.compile(model, dynamic=False) # the inputs to model will never change shape so dynamic=False is safe
num_params = sum(p.numel() for p in model.parameters())
print0(f"Number of parameters: {num_params:,}")
num_scaling_params = orig_model.num_scaling_params()
print0(f"Number of parameters: {num_params:,} (scaling: {num_scaling_params:,})")
num_flops_per_token = model.estimate_flops()
print0(f"Estimated FLOPs per token: {num_flops_per_token:e}")
# Calculate number of iterations. Either it is given, or from target flops, or from target data:param ratio (in that order)
assert num_iterations > 0 or target_param_data_ratio > 0 or target_flops > 0
if num_iterations > 0:
assert args.num_iterations > 0 or args.target_param_data_ratio > 0 or args.target_flops > 0
if args.num_iterations > 0:
num_iterations = args.num_iterations
print0(f"Using user-provided number of iterations: {num_iterations:,}")
elif target_flops > 0:
elif args.target_flops > 0:
# calculate the number of iterations from the target flops
num_iterations = round(target_flops / (num_flops_per_token * total_batch_size))
num_iterations = round(args.target_flops / (num_flops_per_token * args.total_batch_size))
print0(f"Calculated number of iterations from target FLOPs: {num_iterations:,}")
elif target_param_data_ratio > 0:
# calculate the number of iterations from the target param data ratio
target_tokens = target_param_data_ratio * num_params
num_iterations = target_tokens // total_batch_size
elif args.target_param_data_ratio > 0:
# calculate the number of iterations from the target param data ratio (use scaling params per Kaplan et al.)
target_tokens = args.target_param_data_ratio * num_scaling_params
num_iterations = target_tokens // args.total_batch_size
print0(f"Calculated number of iterations from target data:param ratio: {num_iterations:,}")
else:
raise ValueError("No training horizon specified")
total_tokens = total_batch_size * num_iterations
total_tokens = args.total_batch_size * num_iterations
print0(f"Total number of training tokens: {total_tokens:,}")
print0(f"Tokens : Params ratio: {total_batch_size * num_iterations / num_params:.2f}") # Chinchilla is ~20
print0(f"Tokens : Params ratio: {args.total_batch_size * num_iterations / num_scaling_params:.2f}") # Chinchilla is ~20
print0(f"Total training FLOPs estimate: {num_flops_per_token * total_tokens:e}")
# -----------------------------------------------------------------------------
# Initialize the Optimizer (Muon for Linear layers, AdamW for embedding and lm_head)
optimizers = model.setup_optimizers(unembedding_lr=unembedding_lr, embedding_lr=embedding_lr, matrix_lr=matrix_lr, weight_decay=weight_decay)
adam_betas = (args.adam_beta1, args.adam_beta2)
optimizers = model.setup_optimizers(
unembedding_lr=args.unembedding_lr * batch_lr_scale,
embedding_lr=args.embedding_lr * batch_lr_scale,
matrix_lr=args.matrix_lr * batch_lr_scale,
weight_decay=weight_decay_scaled,
adam_betas=adam_betas,
scalar_lr=args.scalar_lr * batch_lr_scale,
)
adamw_optimizer, muon_optimizer = optimizers
if resuming:
@ -169,8 +210,8 @@ if resuming:
# Initialize the DataLoaders for train/val
tokens_dir = os.path.join(base_dir, "tokenized_data")
dataloader_resume_state_dict = None if not resuming else meta_data["dataloader_state_dict"]
train_loader = tokenizing_distributed_data_loader_with_state(device_batch_size, max_seq_len, split="train", device=device, resume_state_dict=dataloader_resume_state_dict)
build_val_loader = lambda: tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="val", device=device)
train_loader = tokenizing_distributed_data_loader_with_state_bos_bestfit(tokenizer, args.device_batch_size, args.max_seq_len, split="train", device=device, resume_state_dict=dataloader_resume_state_dict)
build_val_loader = lambda: tokenizing_distributed_data_loader_bos_bestfit(tokenizer, args.device_batch_size, args.max_seq_len, split="val", device=device)
x, y, dataloader_state_dict = next(train_loader) # kick off load of the very first batch of data
# -----------------------------------------------------------------------------
@ -178,15 +219,15 @@ x, y, dataloader_state_dict = next(train_loader) # kick off load of the very fir
# Learning rate scheduler
def get_lr_multiplier(it):
warmup_iters = round(warmup_ratio * num_iterations)
warmdown_iters = round(warmdown_ratio * num_iterations)
warmup_iters = round(args.warmup_ratio * num_iterations)
warmdown_iters = round(args.warmdown_ratio * num_iterations)
if it < warmup_iters:
return (it + 1) / warmup_iters
elif it <= num_iterations - warmdown_iters:
return 1.0
else:
progress = (num_iterations - it) / warmdown_iters
return progress * 1.0 + (1 - progress) * final_lr_frac
return progress * 1.0 + (1 - progress) * args.final_lr_frac
# Momentum scheduler for Muon optimizer
def get_muon_momentum(it):
@ -194,11 +235,16 @@ def get_muon_momentum(it):
momentum = (1 - frac) * 0.85 + frac * 0.95
return momentum
# Weight decay scheduler for Muon optimizer (linear to zero over the course of training)
def get_weight_decay(it):
return weight_decay_scaled * (1 - it / num_iterations)
# -----------------------------------------------------------------------------
# Loop state (variables updated by the training loop)
if not resuming:
step = 0
val_bpb = None # will be set if eval_every > 0
min_val_bpb = float("inf")
smooth_train_loss = 0 # EMA of training loss
total_training_time = 0 # total wall-clock time of training
@ -214,16 +260,16 @@ else:
# Training loop
while True:
last_step = step == num_iterations # loop runs num_iterations+1 times so that we can eval/save at the end
flops_so_far = num_flops_per_token * total_batch_size * step
flops_so_far = num_flops_per_token * args.total_batch_size * step
# once in a while: evaluate the val bpb (all ranks participate)
if last_step or step % eval_every == 0:
if args.eval_every > 0 and (last_step or step % args.eval_every == 0):
model.eval()
val_loader = build_val_loader()
eval_steps = eval_tokens // (device_batch_size * max_seq_len * ddp_world_size)
eval_steps = args.eval_tokens // (args.device_batch_size * args.max_seq_len * ddp_world_size)
with autocast_ctx:
val_bpb = evaluate_bpb(model, val_loader, eval_steps, token_bytes)
print0(f"Step {step:05d} | Validation bpb: {val_bpb:.4f}")
print0(f"Step {step:05d} | Validation bpb: {val_bpb:.6f}")
if val_bpb < min_val_bpb:
min_val_bpb = val_bpb
wandb_run.log({
@ -237,10 +283,10 @@ while True:
# once in a while: estimate the CORE metric (all ranks participate)
# use the original uncompiled model because the inputs keep changing shape
results = {}
if core_metric_every > 0 and (last_step or (step > 0 and step % core_metric_every == 0)):
if args.core_metric_every > 0 and (last_step or (step > 0 and step % args.core_metric_every == 0)):
model.eval()
with autocast_ctx:
results = evaluate_model(orig_model, tokenizer, device, max_per_task=core_metric_max_per_task)
results = evaluate_model(orig_model, tokenizer, device, max_per_task=args.core_metric_max_per_task)
print0(f"Step {step:05d} | CORE metric: {results['core_metric']:.4f}")
wandb_run.log({
"step": step,
@ -252,7 +298,7 @@ while True:
# once in a while: sample from the model (only on master process)
# use the original uncompiled model because the inputs keep changing shape
if master_process and (last_step or (step > 0 and step % sample_every == 0)):
if args.sample_every > 0 and master_process and (last_step or (step > 0 and step % args.sample_every == 0)):
model.eval()
prompts = [
"The capital of France is",
@ -272,7 +318,7 @@ while True:
model.train()
# save checkpoint: at the end of the run, or every save_every steps, except at the first step or the resume step
if last_step or (step > 0 and step != resume_from_step and save_every > 0 and step % save_every == 0):
if last_step or (step > 0 and step != args.resume_from_step and args.save_every > 0 and step % args.save_every == 0):
save_checkpoint(
checkpoint_dir,
step,
@ -283,8 +329,8 @@ while True:
"val_bpb": val_bpb, # loss at last step
"model_config": model_config_kwargs,
"user_config": user_config, # inputs to the training script
"device_batch_size": device_batch_size,
"max_seq_len": max_seq_len,
"device_batch_size": args.device_batch_size,
"max_seq_len": args.max_seq_len,
"dataloader_state_dict": dataloader_state_dict,
"loop_state": { # all loop state (other than step) so that we can resume training
"min_val_bpb": min_val_bpb,
@ -311,19 +357,16 @@ while True:
loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here
loss.backward()
x, y, dataloader_state_dict = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward
# gradient clipping
grad_clip_enabled = grad_clip > 0.0
if grad_clip_enabled:
grad_norm_tensor = torch.nn.utils.clip_grad_norm_(orig_model.parameters(), grad_clip)
grad_norm = grad_norm_tensor.item() # GPU tensor -> CPU float (note: cpu-gpu sync point)
# step the optimizers
lrm = get_lr_multiplier(step)
for opt in optimizers:
for group in opt.param_groups:
group["lr"] = group["initial_lr"] * lrm
muon_momentum = get_muon_momentum(step)
muon_weight_decay = get_weight_decay(step)
for group in muon_optimizer.param_groups:
group["momentum"] = muon_momentum
group["weight_decay"] = muon_weight_decay
for opt in optimizers:
opt.step()
model.zero_grad(set_to_none=True)
@ -337,14 +380,23 @@ while True:
smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss.item() # EMA the training loss
debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**(step + 1)) # debias the EMA
pct_done = 100 * step / num_iterations
tok_per_sec = int(total_batch_size / dt)
flops_per_sec = num_flops_per_token * total_batch_size / dt
tok_per_sec = int(args.total_batch_size / dt)
flops_per_sec = num_flops_per_token * args.total_batch_size / dt
promised_flops_per_sec_h100 = 989e12 * ddp_world_size # bfloat16 H100 SXM and without 2:4 sparsity
mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in %
if step > 10:
total_training_time += dt # only count the time after the first 10 steps
print_grad_norm = f" grad norm: {grad_norm:.4f} |" if grad_clip_enabled else ""
print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} |{print_grad_norm} lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time/60:.2f}m")
# Calculate ETA based on average time per step (excluding first 10 steps)
steps_done = step - 10
if steps_done > 0:
avg_time_per_step = total_training_time / steps_done
remaining_steps = num_iterations - step
eta_seconds = remaining_steps * avg_time_per_step
eta_str = f" | eta: {eta_seconds/60:.1f}m"
else:
eta_str = ""
epoch = dataloader_state_dict["epoch"]
print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | epoch: {epoch} | total time: {total_training_time/60:.2f}m{eta_str}")
if step % 100 == 0:
log_data = {
"step": step,
@ -355,9 +407,8 @@ while True:
"train/dt": dt,
"train/tok_per_sec": tok_per_sec,
"train/mfu": mfu,
"train/epoch": epoch,
}
if grad_clip_enabled:
log_data["train/grad_norm"] = grad_norm
wandb_run.log(log_data)
# state update
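For readers skimming the logging hunk above: MFU is simply achieved FLOP/s divided by the advertised peak bf16 FLOP/s of the hardware (989 TFLOP/s per H100 SXM, summed over ranks). A minimal sketch with assumed, illustrative numbers (not taken from any real run):

```python
# Back-of-the-envelope MFU, mirroring the expression in the training loop above.
# All values below are assumptions for illustration only.
num_flops_per_token = 3.5e9   # assumed forward+backward FLOPs per token
total_batch_size = 524288     # tokens per optimizer step
dt = 0.5                      # assumed wall-clock seconds per step
ddp_world_size = 8            # assumed 8xH100 node

flops_per_sec = num_flops_per_token * total_batch_size / dt   # ~3.7e15 FLOP/s achieved
promised_flops_per_sec = 989e12 * ddp_world_size              # peak bf16 FLOP/s (no 2:4 sparsity)
mfu = 100 * flops_per_sec / promised_flops_per_sec
print(f"MFU: {mfu:.1f}%")  # ~46% with these made-up numbers
```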
@ -366,7 +417,8 @@ while True:
# print a few more stats
print0(f"Peak memory usage: {get_max_memory() / 1024 / 1024:.2f}MiB")
print0(f"Total training time: {total_training_time/60:.2f}m")
print0(f"Minimum validation bpb: {min_val_bpb:.4f}")
if val_bpb is not None:
print0(f"Minimum validation bpb: {min_val_bpb:.6f}")
# Log to report
from nanochat.report import get_report
@ -377,14 +429,14 @@ get_report().log(section="Base model training", data=[
"Number of FLOPs per token": f"{num_flops_per_token:e}",
"Calculated number of iterations": num_iterations,
"Number of training tokens": total_tokens,
"Tokens : Params ratio": total_batch_size * num_iterations / num_params,
"Tokens : Params ratio": args.total_batch_size * num_iterations / num_params,
"DDP world size": ddp_world_size,
"warmup_ratio": warmup_ratio,
"warmdown_ratio": warmdown_ratio,
"final_lr_frac": final_lr_frac,
"warmup_ratio": args.warmup_ratio,
"warmdown_ratio": args.warmdown_ratio,
"final_lr_frac": args.final_lr_frac,
},
{ # stats about training outcomes
"Minimum validation bpb": min_val_bpb,
"Minimum validation bpb": min_val_bpb if val_bpb is not None else None,
"Final validation bpb": val_bpb,
"CORE metric estimate": results.get("core_metric", None),
"MFU %": f"{mfu:.2f}%",

View File

@ -16,57 +16,68 @@ python -m scripts.chat_rl
torchrun --standalone --nproc_per_node=8 -m scripts.chat_rl -- --run=default
"""
import argparse
import os
import itertools
import re
import wandb
import torch
import torch.distributed as dist
from contextlib import nullcontext
from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, DummyWandb
from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, DummyWandb, autodetect_device_type
from nanochat.checkpoint_manager import save_checkpoint, load_model
from nanochat.engine import Engine
from tasks.gsm8k import GSM8K
# RL hyperparameters
run = "dummy" # wandb run name
source = "sft" # mid|sft
model_tag = None # model tag to load the model from (base model or midtrained model)
step = None # step to load the model from (base model or midtrained model)
dtype = "bfloat16"
device_batch_size = 8 # no forward pass will go above this to not OOM
examples_per_step = 16 # in total and across all ranks (note: examples, not samples/completions!)
num_samples = 16 # number of samples per example (/question)
max_new_tokens = 256
temperature = 1.0
top_k = 50 # TODO: try None?
unembedding_lr = 0.004
embedding_lr = 0.2
matrix_lr = 0.02
weight_decay = 0.0
init_lr_frac = 0.05
num_epochs = 1 # how many epochs of gsm8k to train on
save_every = 60 # every how many steps to save the model
eval_every = 60 # every how many steps to evaluate the model for val pass@k
eval_examples = 400 # number of examples used for evaluating pass@k
# now allow CLI to override the settings via the configurator lol
config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
user_config = {k: globals()[k] for k in config_keys} # will be useful for logging
# -----------------------------------------------------------------------------
# CLI arguments
parser = argparse.ArgumentParser(description="Reinforcement learning on GSM8K")
# Logging
parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)")
# Runtime
parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
parser.add_argument("--dtype", type=str, default="bfloat16", help="float32|bfloat16")
# Model loading
parser.add_argument("--source", type=str, default="sft", help="mid|sft - which checkpoint to load from")
parser.add_argument("--model-tag", type=str, default=None, help="model tag to load from")
parser.add_argument("--model-step", type=int, default=None, help="model step to load from")
# Training horizon
parser.add_argument("--num-epochs", type=int, default=1, help="number of epochs over GSM8K")
# Batch sizes / sampling
parser.add_argument("--device-batch-size", type=int, default=8, help="max batch size per forward pass")
parser.add_argument("--examples-per-step", type=int, default=16, help="total examples per optimization step across all ranks")
parser.add_argument("--num-samples", type=int, default=16, help="number of samples per example/question")
# Generation
parser.add_argument("--max-new-tokens", type=int, default=256, help="max tokens to generate per sample")
parser.add_argument("--temperature", type=float, default=1.0, help="sampling temperature")
parser.add_argument("--top-k", type=int, default=50, help="top-k sampling (0 = disabled)")
# Optimization
parser.add_argument("--embedding-lr", type=float, default=0.2, help="learning rate for embedding parameters (Adam)")
parser.add_argument("--unembedding-lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)")
parser.add_argument("--matrix-lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)")
parser.add_argument("--weight-decay", type=float, default=0.0, help="weight decay for embedding/unembedding parameters (Adam)")
parser.add_argument("--init-lr-frac", type=float, default=0.05, help="initial LR as fraction of base LR")
# Evaluation / checkpointing
parser.add_argument("--eval-every", type=int, default=60, help="evaluate pass@k every N steps")
parser.add_argument("--eval-examples", type=int, default=400, help="number of examples for pass@k evaluation")
parser.add_argument("--save-every", type=int, default=60, help="save checkpoint every N steps")
args = parser.parse_args()
user_config = vars(args).copy()
# -----------------------------------------------------------------------------
# Init compute/precision
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
dtype = torch.float32 if dtype == 'float32' else torch.bfloat16
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=dtype)
ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
# wandb logging init
use_dummy_wandb = run == "dummy" or not master_process
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-rl", name=run, config=user_config)
use_dummy_wandb = args.run == "dummy" or not master_process
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-rl", name=args.run, config=user_config)
# Init model and tokenizer
model, tokenizer, meta = load_model(source, device, phase="eval", model_tag=model_tag, step=step)
model, tokenizer, meta = load_model(args.source, device, phase="eval", model_tag=args.model_tag, step=args.model_step)
engine = Engine(model, tokenizer) # for sampling rollouts
# -----------------------------------------------------------------------------
@ -74,7 +85,7 @@ engine = Engine(model, tokenizer) # for sampling rollouts
train_task = GSM8K(subset="main", split="train")
val_task = GSM8K(subset="main", split="test")
num_steps = (len(train_task) // examples_per_step) * num_epochs
num_steps = (len(train_task) // args.examples_per_step) * args.num_epochs
print0(f"Calculated number of steps: {num_steps}")
@torch.no_grad()
@ -95,16 +106,16 @@ def get_batch():
model.eval() # ensure the model is in eval mode
generated_token_sequences = []
masks = []
num_sampling_steps = num_samples // device_batch_size # go sequentially to prevent OOMs
num_sampling_steps = args.num_samples // args.device_batch_size # go sequentially to prevent OOMs
for sampling_step in range(num_sampling_steps):
seed = hash((step, example_idx, sampling_step)) & 0x7FFFFFFF # positive half of int32
with autocast_ctx:
generated_token_sequences_batch, masks_batch = engine.generate_batch(
tokens,
num_samples=device_batch_size,
max_tokens=max_new_tokens,
temperature=temperature,
top_k=top_k,
num_samples=args.device_batch_size,
max_tokens=args.max_new_tokens,
temperature=args.temperature,
top_k=args.top_k,
seed=seed, # must make sure to change the seed for each sampling step
)
generated_token_sequences.extend(generated_token_sequences_batch)
@ -162,7 +173,7 @@ def run_gsm8k_eval(task, tokenizer, engine,
tokens = tokenizer.render_for_completion(conversation)
prefix_length = len(tokens)
# Generate k samples using batched generation inside the Engine
assert num_samples <= device_batch_size # usually this is true. we can add a loop if not...
assert num_samples <= args.device_batch_size # usually this is true. we can add a loop if not...
generated_token_sequences, masks = engine.generate_batch(
tokens,
num_samples=num_samples,
@ -191,16 +202,16 @@ def run_gsm8k_eval(task, tokenizer, engine,
# Init the optimizer
optimizers = model.setup_optimizers(
unembedding_lr=unembedding_lr,
embedding_lr=embedding_lr,
matrix_lr=matrix_lr,
weight_decay=weight_decay,
unembedding_lr=args.unembedding_lr,
embedding_lr=args.embedding_lr,
matrix_lr=args.matrix_lr,
weight_decay=args.weight_decay,
)
# Set the initial learning rate as a fraction of the base learning rate
for opt in optimizers:
for group in opt.param_groups:
group["lr"] = group["lr"] * init_lr_frac
group["lr"] = group["lr"] * args.init_lr_frac
group["initial_lr"] = group["lr"] # save the initial learning so we can decay easily later
# Learning rate scheduler: simple rampdown to zero over num_steps
@ -209,9 +220,9 @@ def get_lr_multiplier(it):
return lrm
# Calculate the number of examples each rank handles to achieve the desired examples_per_step
print0(f"Total sequences per step: {examples_per_step * num_samples}") # total batch size in sequences/step
assert examples_per_step % ddp_world_size == 0, "Desired examples per step must be divisible by the number of ranks"
examples_per_rank = examples_per_step // ddp_world_size # per GPU
print0(f"Total sequences per step: {args.examples_per_step * args.num_samples}") # total batch size in sequences/step
assert args.examples_per_step % ddp_world_size == 0, "Desired examples per step must be divisible by the number of ranks"
examples_per_rank = args.examples_per_step // ddp_world_size # per GPU
print0(f"Calculated examples per rank: {examples_per_rank}")
# Kick off the training loop
@ -219,22 +230,22 @@ batch_iterator = get_batch()
for step in range(num_steps):
# Evaluate the model once in a while and log to wandb
if step % eval_every == 0:
if step % args.eval_every == 0:
model.eval()
passk = torch.zeros(device_batch_size, device=device) # pass@k for k=1..device_batch_size
passk = torch.zeros(args.device_batch_size, device=device) # pass@k for k=1..device_batch_size
with autocast_ctx:
records_iter = run_gsm8k_eval(val_task, tokenizer, engine, num_samples=device_batch_size, max_examples=eval_examples, temperature=1.0)
records_iter = run_gsm8k_eval(val_task, tokenizer, engine, num_samples=args.device_batch_size, max_examples=args.eval_examples, temperature=1.0)
records = list(records_iter) # collect all records
for k in range(1, device_batch_size + 1):
for k in range(1, args.device_batch_size + 1):
passk[k - 1] = sum(any(o["is_correct"] for o in r["outcomes"][:k]) for r in records)
num_records = torch.tensor(len(records), dtype=torch.long, device=device)
if ddp:
dist.all_reduce(num_records, op=dist.ReduceOp.SUM)
dist.all_reduce(passk, op=dist.ReduceOp.SUM)
passk = passk / num_records.item() # normalize by the total number of records
print_passk = [f"Pass@{k}: {passk[k - 1].item():.4f}" for k in range(1, device_batch_size + 1)]
print_passk = [f"Pass@{k}: {passk[k - 1].item():.4f}" for k in range(1, args.device_batch_size + 1)]
print0(f"Step {step} | {', '.join(print_passk)}")
log_passk = {f"pass@{k}": passk[k - 1].item() for k in range(1, device_batch_size + 1)}
log_passk = {f"pass@{k}": passk[k - 1].item() for k in range(1, args.device_batch_size + 1)}
wandb_run.log({
"step": step,
**log_passk,
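For intuition, the pass@k evaluation above reduces to a few lines: a problem counts as solved at k if any of its first k sampled completions is correct, and the metric is the mean over problems (the all_reduce merely sums these counts across ranks). A toy, single-process sketch with made-up outcomes:

```python
# Toy pass@k computation with hypothetical per-problem outcomes (True = correct sample).
outcomes = [
    [False, True, False, False],   # problem 1: first correct answer on the 2nd sample
    [False, False, False, False],  # problem 2: never correct
    [True, True, False, True],     # problem 3: correct on the 1st sample
]
for k in range(1, 5):
    passk = sum(any(o[:k]) for o in outcomes) / len(outcomes)
    print(f"pass@{k}: {passk:.2f}")
# pass@1: 0.33, pass@2: 0.67, pass@3: 0.67, pass@4: 0.67
```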
@ -249,11 +260,11 @@ for step in range(num_steps):
# Evaluate the loss and gradients
model.train() # ensure the model is in train mode
# We need one more loop because we can never exceed the device_batch_size
assert inputs_all.size(0) % device_batch_size == 0
num_passes = inputs_all.size(0) // device_batch_size
assert inputs_all.size(0) % args.device_batch_size == 0
num_passes = inputs_all.size(0) // args.device_batch_size
for pass_idx in range(num_passes):
# Pluck out the batch for this pass
b0, b1 = pass_idx * device_batch_size, (pass_idx + 1) * device_batch_size
b0, b1 = pass_idx * args.device_batch_size, (pass_idx + 1) * args.device_batch_size
inputs = inputs_all[b0:b1]
targets = targets_all[b0:b1]
rewards = rewards_all[b0:b1]
@ -306,10 +317,10 @@ for step in range(num_steps):
})
# Master process saves the model once in a while. Skip first step. Save last step.
if master_process and ((step > 0 and step % save_every == 0) or step == num_steps - 1):
if master_process and ((step > 0 and step % args.save_every == 0) or step == num_steps - 1):
base_dir = get_base_dir()
depth = model.config.n_layer
output_dirname = model_tag if model_tag else f"d{depth}" # base the model tag on the depth of the base model
output_dirname = args.model_tag if args.model_tag else f"d{depth}" # base the model tag on the depth of the base model
checkpoint_dir = os.path.join(base_dir, "chatrl_checkpoints", output_dirname)
model_config_kwargs = model.config.__dict__ # slightly naughty, abusing the simplicity of GPTConfig, TODO nicer
save_checkpoint(

View File

@ -9,6 +9,7 @@ Or torchrun for training:
torchrun --standalone --nproc_per_node=8 -m scripts.chat_sft
"""
import argparse
import os
os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"
@ -31,49 +32,51 @@ from tasks.customjson import CustomJSON
from tasks.spellingbee import SimpleSpelling, SpellingBee
# -----------------------------------------------------------------------------
# SFT Hyperparameters
run = "dummy" # wandb run name default ("dummy" is special - we won't log to wandb)
# input model options
source = "mid" # base|mid , which checkpoint to load the model from (base model or midtrained model)
model_tag = None # model tag to load the model from (base model or midtrained model)
step = None # step to load the model from (base model or midtrained model)
# compute/precision
device_type = "" # cuda|cpu|mps (empty => autodetect)
dtype = "bfloat16"
device_batch_size = 4 # max to avoid OOM
# optimization
num_epochs = 1
num_iterations = -1 # override number of iterations (-1 = disable, use num_epochs to derive it)
target_examples_per_step = 32
unembedding_lr = 0.004
embedding_lr = 0.2
matrix_lr = 0.02
weight_decay = 0.0
init_lr_frac = 0.02
# evaluation and logging thereof
eval_every = 100
eval_steps = 100
eval_metrics_every = 200
eval_metrics_max_problems = 1024
# now allow CLI to override the settings via the configurator lol
config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
user_config = {k: globals()[k] for k in config_keys} # possibly useful for logging
# CLI arguments
parser = argparse.ArgumentParser(description="Supervised finetuning for chat")
# Logging
parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)")
# Runtime
parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
parser.add_argument("--dtype", type=str, default="bfloat16", help="float32|bfloat16")
# Model loading
parser.add_argument("--source", type=str, default="mid", help="base|mid - which checkpoint to load from")
parser.add_argument("--model-tag", type=str, default=None, help="model tag to load from")
parser.add_argument("--model-step", type=int, default=None, help="model step to load from")
# Training horizon
parser.add_argument("--num-epochs", type=int, default=1, help="number of epochs")
parser.add_argument("--num-iterations", type=int, default=-1, help="override number of iterations (-1 = use num_epochs)")
# Batch sizes
parser.add_argument("--device-batch-size", type=int, default=4, help="per-device batch size")
parser.add_argument("--target-examples-per-step", type=int, default=32, help="target examples per optimization step")
# Optimization
parser.add_argument("--embedding-lr", type=float, default=0.2, help="learning rate for embedding parameters (Adam)")
parser.add_argument("--unembedding-lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)")
parser.add_argument("--matrix-lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)")
parser.add_argument("--weight-decay", type=float, default=0.0, help="weight decay for embedding/unembedding parameters (Adam)")
parser.add_argument("--init-lr-frac", type=float, default=0.02, help="initial LR as fraction of base LR")
# Evaluation
parser.add_argument("--eval-every", type=int, default=100, help="evaluate val loss every N steps")
parser.add_argument("--eval-steps", type=int, default=100, help="number of batches for val loss evaluation")
parser.add_argument("--eval-metrics-every", type=int, default=200, help="evaluate accuracy metrics every N steps")
parser.add_argument("--eval-metrics-max-problems", type=int, default=1024, help="max problems per metric evaluation")
args = parser.parse_args()
user_config = vars(args).copy()
# -----------------------------------------------------------------------------
# Compute init
device_type = autodetect_device_type() if device_type == "" else device_type
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
master_process = ddp_rank == 0
ptdtype = torch.float32 if dtype == 'float32' else torch.bfloat16
ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
# wandb logging init
use_dummy_wandb = run == "dummy" or not master_process
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-sft", name=run, config=user_config, save_code=True)
use_dummy_wandb = args.run == "dummy" or not master_process
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-sft", name=args.run, config=user_config, save_code=True)
# Load the model and tokenizer
model, tokenizer, meta = load_model(source, device, phase="train", model_tag=model_tag, step=step)
model, tokenizer, meta = load_model(args.source, device, phase="train", model_tag=args.model_tag, step=args.model_step)
orig_model = model # original, uncompiled model
# model = torch.compile(model, dynamic=True) # doesn't work super well because of variable lengths of inputs
engine = Engine(model, tokenizer) # will be used for inline model evaluation only
@ -127,34 +130,36 @@ def sft_data_generator(dataset, batch_size):
yield collate_and_yield(batch)
batch = []
examples_per_step = device_batch_size * ddp_world_size
print0(f"Target examples per step: {target_examples_per_step}")
print0(f"Device batch size: {device_batch_size}")
examples_per_step = args.device_batch_size * ddp_world_size
print0(f"Target examples per step: {args.target_examples_per_step}")
print0(f"Device batch size: {args.device_batch_size}")
print0(f"Examples per step is device_batch_size * ddp_world_size: {examples_per_step}")
assert target_examples_per_step % examples_per_step == 0, "Target examples per step must be divisible by examples per step"
grad_accum_steps = target_examples_per_step // examples_per_step
assert args.target_examples_per_step % examples_per_step == 0, "Target examples per step must be divisible by examples per step"
grad_accum_steps = args.target_examples_per_step // examples_per_step
print0(f"=> Setting grad accum steps: {grad_accum_steps}")
if num_iterations == -1:
if args.num_iterations == -1:
# derive num_iterations from num_epochs and the size of the dataset
assert num_epochs > 0, "num_epochs must be positive if num_iterations is -1"
num_iterations = (len(train_ds) // target_examples_per_step) * num_epochs
train_loader = sft_data_generator(train_ds, batch_size=device_batch_size)
build_val_loader = lambda: sft_data_generator(val_ds, batch_size=device_batch_size)
assert args.num_epochs > 0, "num_epochs must be positive if num_iterations is -1"
num_iterations = (len(train_ds) // args.target_examples_per_step) * args.num_epochs
else:
num_iterations = args.num_iterations
train_loader = sft_data_generator(train_ds, batch_size=args.device_batch_size)
build_val_loader = lambda: sft_data_generator(val_ds, batch_size=args.device_batch_size)
# -----------------------------------------------------------------------------
# Initialize the Optimizer
optimizers = model.setup_optimizers(
unembedding_lr=unembedding_lr,
embedding_lr=embedding_lr,
matrix_lr=matrix_lr,
weight_decay=weight_decay,
unembedding_lr=args.unembedding_lr,
embedding_lr=args.embedding_lr,
matrix_lr=args.matrix_lr,
weight_decay=args.weight_decay,
)
# Set the initial learning rate as a fraction of the base learning rate
for opt in optimizers:
for group in opt.param_groups:
group["lr"] = group["lr"] * init_lr_frac
group["lr"] = group["lr"] * args.init_lr_frac
group["initial_lr"] = group["lr"] # save the initial learning so we can decay easily later
# -----------------------------------------------------------------------------
@ -171,11 +176,11 @@ for step in range(num_iterations):
last_step = step == num_iterations - 1
# evaluate the validation loss
if last_step or step % eval_every == 0:
if last_step or step % args.eval_every == 0:
model.eval()
val_loader = build_val_loader()
losses = []
for _ in range(eval_steps):
for _ in range(args.eval_steps):
val_inputs, val_targets = next(val_loader)
with torch.no_grad(), autocast_ctx:
loss = model(val_inputs, val_targets)
@ -192,13 +197,13 @@ for step in range(num_iterations):
model.train()
# evaluate accuracy of the multiple choice tasks (which are quick to run)
if last_step or (step > 0 and step % eval_metrics_every == 0):
if last_step or (step > 0 and step % args.eval_metrics_every == 0):
model.eval()
metrics = {}
with torch.no_grad(), autocast_ctx:
# note that because these are inside no_grad, we can usually afford at least ~2X the batch size
metrics["mmlu_acc"] = run_chat_eval("MMLU", model, tokenizer, engine, batch_size=device_batch_size*2, max_problems=eval_metrics_max_problems)
metrics["arc_easy_acc"] = run_chat_eval("ARC-Easy", model, tokenizer, engine, batch_size=device_batch_size*2, max_problems=eval_metrics_max_problems)
metrics["mmlu_acc"] = run_chat_eval("MMLU", model, tokenizer, engine, batch_size=args.device_batch_size*2, max_problems=args.eval_metrics_max_problems)
metrics["arc_easy_acc"] = run_chat_eval("ARC-Easy", model, tokenizer, engine, batch_size=args.device_batch_size*2, max_problems=args.eval_metrics_max_problems)
metrics_str = ', '.join(f'{k}: {v:.6f}' for k, v in metrics.items())
print0(f"Step {step:05d} | {metrics_str}")
wandb_run.log({
@ -250,7 +255,7 @@ for step in range(num_iterations):
if master_process:
base_dir = get_base_dir()
depth = model.config.n_layer
output_dirname = model_tag if model_tag else f"d{depth}" # e.g. d12
output_dirname = args.model_tag if args.model_tag else f"d{depth}" # e.g. d12
checkpoint_dir = os.path.join(base_dir, "chatsft_checkpoints", output_dirname)
model_config_kwargs = model.config.__dict__ # slightly naughty, abusing the simplicity of GPTConfig, TODO nicer
save_checkpoint(

View File

@ -6,10 +6,10 @@ python -m scripts.mid_train
Or torchrun for training:
torchrun --standalone --nproc_per_node=8 -m scripts.mid_train -- --device_batch_size=16
torchrun --standalone --nproc_per_node=8 -m scripts.mid_train -- --device-batch-size=16
"""
from collections import deque
import argparse
import os
os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"
import time
@ -31,65 +31,75 @@ from tasks.customjson import CustomJSON
from tasks.spellingbee import SimpleSpelling, SpellingBee
# -----------------------------------------------------------------------------
run = "dummy" # wandb run name default ("dummy" is special - we won't log to wandb)
device_type = "" # cuda|cpu|mps (empty => autodetect)
model_tag = None # model tag to load the model from (base model or midtrained model)
step = None # step to load the model from (base model or midtrained model)
dtype = "bfloat16"
num_iterations = -1 # explicit number of steps of the optimization (-1 = disable)
max_seq_len = 2048
device_batch_size = 32
unembedding_lr = 0.004
embedding_lr = 0.2
matrix_lr = 0.02
init_lr_frac = 1.0 # initial learning rate is this fraction of the base learning rate
weight_decay = 0.0
eval_every = 150 # -1 = disable
eval_tokens = 20*524288
total_batch_size = 524288
dry_run = 0 # dry_run=1 is for experiments: we will log to wandb but we won't write checkpoints or report
config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
user_config = {k: globals()[k] for k in config_keys} # possibly useful for logging
# CLI arguments
parser = argparse.ArgumentParser(description="Midtrain the model")
# Logging
parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)")
# Runtime
parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
parser.add_argument("--dtype", type=str, default="bfloat16", help="float32|bfloat16")
# Model loading
parser.add_argument("--model-tag", type=str, default=None, help="model tag to load from")
parser.add_argument("--model-step", type=int, default=None, help="model step to load from")
# Training horizon
parser.add_argument("--num-iterations", type=int, default=-1, help="number of optimization steps (-1 = full epoch)")
# Batch sizes
parser.add_argument("--max-seq-len", type=int, default=2048, help="max context length")
parser.add_argument("--device-batch-size", type=int, default=32, help="per-device batch size")
parser.add_argument("--total-batch-size", type=int, default=524288, help="total batch size in tokens")
# Optimization
parser.add_argument("--embedding-lr", type=float, default=0.2, help="learning rate for embedding parameters (Adam)")
parser.add_argument("--unembedding-lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)")
parser.add_argument("--matrix-lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)")
parser.add_argument("--weight-decay", type=float, default=0.0, help="weight decay for embedding/unembedding parameters (Adam)")
parser.add_argument("--init-lr-frac", type=float, default=1.0, help="initial LR as fraction of base LR")
# Evaluation
parser.add_argument("--eval-every", type=int, default=150, help="evaluate val bpb every N steps (-1 = disable)")
parser.add_argument("--eval-tokens", type=int, default=20*524288, help="number of tokens to evaluate val loss on")
# Output
parser.add_argument("--dry-run", action="store_true", help="log to wandb but skip checkpoints/report")
args = parser.parse_args()
user_config = vars(args).copy()
# -----------------------------------------------------------------------------
# Compute init
device_type = autodetect_device_type() if device_type == "" else device_type
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
master_process = ddp_rank == 0
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None
get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0
# wandb logging init
use_dummy_wandb = run == "dummy" or not master_process
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-mid", name=run, config=user_config)
use_dummy_wandb = args.run == "dummy" or not master_process
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-mid", name=args.run, config=user_config)
# Load the model and tokenizer
model, tokenizer, meta = load_model("base", device, phase="train", model_tag=model_tag, step=step)
model, tokenizer, meta = load_model("base", device, phase="train", model_tag=args.model_tag, step=args.model_step)
pretrain_batch_size = meta.get("device_batch_size", None)
if pretrain_batch_size is not None and device_batch_size > pretrain_batch_size:
print0(f"FOOTGUN WARNING: base model training used device_batch_size {pretrain_batch_size}, did you pass in a good --device_batch_size to this script?")
if pretrain_batch_size is not None and args.device_batch_size > pretrain_batch_size:
print0(f"FOOTGUN WARNING: base model training used device_batch_size {pretrain_batch_size}, did you pass in a good --device-batch-size to this script?")
orig_model = model
model = torch.compile(model, dynamic=False)
depth = model.config.n_layer
num_flops_per_token = model.estimate_flops()
tokens_per_fwdbwd = device_batch_size * max_seq_len # tokens per iteration for a single rank
tokens_per_fwdbwd = args.device_batch_size * args.max_seq_len # tokens per iteration for a single rank
world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size # total tokens per iteration for all ranks
assert total_batch_size % world_tokens_per_fwdbwd == 0
grad_accum_steps = total_batch_size // world_tokens_per_fwdbwd
print0(f"Tokens / micro-batch / rank: {device_batch_size} x {max_seq_len} = {tokens_per_fwdbwd:,}")
assert args.total_batch_size % world_tokens_per_fwdbwd == 0
grad_accum_steps = args.total_batch_size // world_tokens_per_fwdbwd
print0(f"Tokens / micro-batch / rank: {args.device_batch_size} x {args.max_seq_len} = {tokens_per_fwdbwd:,}")
print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}")
print0(f"Total batch size {total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}")
print0(f"Total batch size {args.total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}")
token_bytes = get_token_bytes(device=device)
# Initialize the Optimizer (Muon for Linear layers, AdamW for embedding and lm_head)
optimizers = model.setup_optimizers(unembedding_lr=unembedding_lr, embedding_lr=embedding_lr, matrix_lr=matrix_lr, weight_decay=weight_decay)
optimizers = model.setup_optimizers(unembedding_lr=args.unembedding_lr, embedding_lr=args.embedding_lr, matrix_lr=args.matrix_lr, weight_decay=args.weight_decay)
adamw_optimizer, muon_optimizer = optimizers
# Override the initial learning rate as a fraction of the base learning rate
for opt in optimizers:
for group in opt.param_groups:
group["lr"] = group["lr"] * init_lr_frac
group["lr"] = group["lr"] * args.init_lr_frac
group["initial_lr"] = group["lr"] # save the initial learning so we can decay easily later
# Midtraining data mixture and DataLoader
@ -114,49 +124,100 @@ val_dataset = TaskMixture([
# these two global variables and update them from within the data generator.
last_step = False # we will toggle this to True when we reach the end of the training dataset
approx_progress = 0.0 # will go from 0 to 1 over the course of the epoch
def mid_data_generator(split):
global last_step, approx_progress
current_epoch = 1 # track epoch for logging
def mid_data_generator_bos_bestfit(split, buffer_size=100):
"""
BOS-aligned dataloader for midtraining with bestfit-crop packing.
Each row in the batch starts with BOS (beginning of a conversation).
Conversations are packed using best-fit algorithm to minimize cropping.
This matches the BOS-aligned approach used in pretraining.
"""
global last_step, approx_progress, current_epoch
assert split in {"train", "val"}, "split must be 'train' or 'val'"
dataset = train_dataset if split == "train" else val_dataset
dataset_size = len(dataset)
assert dataset_size > 0
needed_tokens = device_batch_size * max_seq_len + 1 # to form one training batch of inputs,targets
token_buffer = deque()
# CUDA supports memory pinning for faster transfers between CPU and GPU:
scratch = torch.empty(needed_tokens, dtype=torch.int64, pin_memory=(device_type == "cuda"))
cursor = ddp_rank # increments by ddp_world_size each time, so each rank processes unique documents
it = 0 # iteration counter
while True:
# Accumulate enough tokens for one iteration before yielding
while len(token_buffer) < needed_tokens:
row_capacity = args.max_seq_len + 1 # +1 for target at last position
# Conversation buffer: list of token lists
conv_buffer = []
cursor = ddp_rank # Each rank processes different conversations (for fetching)
consumed = ddp_rank # Track actual consumption separately from buffering
epoch = 1
it = 0 # iteration counter
def refill_buffer():
nonlocal cursor, epoch
while len(conv_buffer) < buffer_size:
conversation = dataset[cursor]
ids, _ = tokenizer.render_conversation(conversation)
token_buffer.extend(ids)
conv_buffer.append(ids)
cursor += ddp_world_size
if cursor >= dataset_size:
cursor -= dataset_size # wrap around for another epoch
if split == "train":
last_step = True # toggle last_step to True, which will terminate the training loop
cursor = cursor % dataset_size
epoch += 1
# Note: last_step is now triggered based on consumption, not fetching
while True:
rows = []
for _ in range(args.device_batch_size):
row = []
while len(row) < row_capacity:
# Ensure buffer has conversations
while len(conv_buffer) < buffer_size:
refill_buffer()
remaining = row_capacity - len(row)
# Find largest conversation that fits entirely
best_idx = -1
best_len = 0
for i, conv in enumerate(conv_buffer):
conv_len = len(conv)
if conv_len <= remaining and conv_len > best_len:
best_idx = i
best_len = conv_len
if best_idx >= 0:
# Found a conversation that fits - use it entirely
conv = conv_buffer.pop(best_idx)
row.extend(conv)
consumed += ddp_world_size # Track actual consumption
else:
# No conversation fits - crop first conversation to fill remaining
conv = conv_buffer.pop(0)
row.extend(conv[:remaining])
consumed += ddp_world_size # Track actual consumption
rows.append(row[:row_capacity])
# Stopping condition to respect num_iterations, if given
it += 1
if 0 < num_iterations <= it and split == "train":
last_step = True # toggle last_step to True, which will terminate the training loop
# Build up inputs/targets and yield
for i in range(needed_tokens):
scratch[i] = token_buffer.popleft()
inputs_cpu = scratch[:-1].to(dtype=torch.int32)
targets_cpu = scratch[1:]
inputs = inputs_cpu.view(device_batch_size, max_seq_len).to(device=device, dtype=torch.int32, non_blocking=True)
targets = targets_cpu.view(device_batch_size, max_seq_len).to(device=device, dtype=torch.int64, non_blocking=True)
if 0 < args.num_iterations <= it and split == "train":
last_step = True
# Update progress tracking (based on consumed, not cursor, to account for buffering)
if split == "train":
if num_iterations > 0:
approx_progress = it / num_iterations # calculate progress from the max number of iterations
current_epoch = epoch
if args.num_iterations > 0:
approx_progress = it / args.num_iterations
else:
approx_progress = cursor / dataset_size # approximate progress as a fraction of the dataset
approx_progress = consumed / dataset_size
# Trigger last_step when we've consumed enough (instead of when cursor wraps)
if consumed >= dataset_size:
last_step = True
# Build tensors
use_cuda = device_type == "cuda"
batch_tensor = torch.tensor(rows, dtype=torch.long, pin_memory=use_cuda)
inputs = batch_tensor[:, :-1].to(device=device, dtype=torch.int32, non_blocking=use_cuda)
targets = batch_tensor[:, 1:].to(device=device, dtype=torch.int64, non_blocking=use_cuda)
yield inputs, targets
train_loader = mid_data_generator("train")
build_val_loader = lambda: mid_data_generator("val")
train_loader = mid_data_generator_bos_bestfit("train")
build_val_loader = lambda: mid_data_generator_bos_bestfit("val")
progress = 0 # will go from 0 to 1 over the course of the epoch
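To make the packing strategy easier to follow in isolation, here is a minimal standalone sketch of the best-fit-crop idea with toy token lists. It deliberately omits the buffer refilling, epoch tracking, and consumption accounting that the real generator above handles:

```python
# Best-fit-crop packing sketch: fill a row by repeatedly taking the longest buffered
# conversation that still fits; if none fits, crop one to fill the remaining space.
def pack_row(conv_buffer, row_capacity):
    row = []
    while len(row) < row_capacity and conv_buffer:
        remaining = row_capacity - len(row)
        fitting = [i for i, conv in enumerate(conv_buffer) if len(conv) <= remaining]
        if fitting:
            best = max(fitting, key=lambda i: len(conv_buffer[i]))  # largest that fits entirely
            row.extend(conv_buffer.pop(best))
        else:
            row.extend(conv_buffer.pop(0)[:remaining])  # no fit: crop to fill the row exactly
    return row

convs = [[1] * 5, [2] * 3, [3] * 9, [4] * 4]  # toy "tokenized conversations"
print(pack_row(convs, 8))  # -> [1, 1, 1, 1, 1, 2, 2, 2]: the 5- and 3-token convs fill the row
```

Cropping only happens when nothing in the buffer fits, so keeping the buffer reasonably large (buffer_size=100 above) makes it more likely that some conversation fits entirely and less text gets cropped.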
# Learning rate scheduler
@ -179,7 +240,7 @@ ema_beta = 0.9 # EMA decay factor
total_training_time = 0 # total wall-clock time of training
step = 0
while True:
flops_so_far = num_flops_per_token * total_batch_size * step
flops_so_far = num_flops_per_token * args.total_batch_size * step
# Synchronize last_step across all ranks to avoid hangs in the distributed setting
if ddp:
@ -188,10 +249,10 @@ while True:
last_step = bool(last_step_tensor.item())
# once in a while: evaluate the val bpb (all ranks participate)
if eval_every > 0 and (last_step or step % eval_every == 0):
if args.eval_every > 0 and (last_step or step % args.eval_every == 0):
model.eval()
val_loader = build_val_loader()
eval_steps = eval_tokens // (device_batch_size * max_seq_len * ddp_world_size)
eval_steps = args.eval_tokens // (args.device_batch_size * args.max_seq_len * ddp_world_size)
with autocast_ctx:
val_bpb = evaluate_bpb(model, val_loader, eval_steps, token_bytes)
print0(f"Step {step:05d} | Validation bpb: {val_bpb:.4f}")
@ -206,8 +267,8 @@ while True:
model.train()
# save checkpoint at the end of the run (only on master process)
if master_process and last_step and not dry_run:
output_dirname = model_tag if model_tag else f"d{depth}" # e.g. d12
if master_process and last_step and not args.dry_run:
output_dirname = args.model_tag if args.model_tag else f"d{depth}" # e.g. d12
checkpoint_dir = os.path.join(base_dir, "mid_checkpoints", output_dirname)
save_checkpoint(
checkpoint_dir,
@ -218,7 +279,7 @@ while True:
"step": step,
"val_bpb": val_bpb, # loss at last step
"model_config": {
"sequence_len": max_seq_len,
"sequence_len": args.max_seq_len,
"vocab_size": tokenizer.get_vocab_size(),
"n_layer": depth,
"n_head": model.config.n_head,
@ -268,13 +329,13 @@ while True:
smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss.item() # EMA the training loss
debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**(step + 1)) # debias the EMA
pct_done = 100 * progress
tok_per_sec = int(total_batch_size / dt)
flops_per_sec = num_flops_per_token * total_batch_size / dt
tok_per_sec = int(args.total_batch_size / dt)
flops_per_sec = num_flops_per_token * args.total_batch_size / dt
promised_flops_per_sec_h100 = 989e12 * ddp_world_size # bfloat16 H100 SXM and without 2:4 sparsity
mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in %
if step > 10:
total_training_time += dt # only count the time after the first 10 steps
print0(f"step {step:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time/60:.2f}m")
print0(f"step {step:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | epoch: {current_epoch} | total time: {total_training_time/60:.2f}m")
if step % 10 == 0:
wandb_run.log({
"step": step,
@ -285,6 +346,7 @@ while True:
"train/dt": dt,
"train/tok_per_sec": tok_per_sec,
"train/mfu": mfu,
"train/epoch": current_epoch,
})
# print a few more stats
@ -293,7 +355,7 @@ print0(f"Total training time: {total_training_time/60:.2f}m")
print0(f"Minimum validation bpb: {min_val_bpb:.4f}")
# Log to report
if not dry_run:
if not args.dry_run:
from nanochat.report import get_report
get_report().log(section="Midtraining", data=[
user_config, # CLI args

View File

@ -1,5 +1,5 @@
"""
Train a tokenizer using the HuggingFace Tokenizers library.
Train a tokenizer using our own BPE Tokenizer library.
In the style of GPT-4 tokenizer.
"""
import os
@ -14,9 +14,9 @@ from nanochat.dataset import parquets_iter_batched
# Parse command line arguments
parser = argparse.ArgumentParser(description='Train a BPE tokenizer')
parser.add_argument('--max_chars', type=int, default=10_000_000_000, help='Maximum characters to train on (default: 10B)')
parser.add_argument('--doc_cap', type=int, default=10_000, help='Maximum characters per document (default: 10,000)')
parser.add_argument('--vocab_size', type=int, default=65536, help='Vocabulary size (default: 65536 = 2^16)')
parser.add_argument('--max-chars', type=int, default=10_000_000_000, help='Maximum characters to train on (default: 10B)')
parser.add_argument('--doc-cap', type=int, default=10_000, help='Maximum characters per document (default: 10,000)')
parser.add_argument('--vocab-size', type=int, default=32768, help='Vocabulary size (default: 32768 = 2^15)')
args = parser.parse_args()
print(f"max_chars: {args.max_chars:,}")
print(f"doc_cap: {args.doc_cap:,}")

View File

@ -48,13 +48,6 @@ python -m nanochat.report reset
# -----------------------------------------------------------------------------
# Tokenizer
# Install Rust / Cargo
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
source "$HOME/.cargo/env"
# Build the rustbpe Tokenizer
uv run maturin develop --release --manifest-path rustbpe/Cargo.toml
# Download the first ~2B characters of pretraining dataset
# look at dev/repackage_data_reference.py for details on how this data was prepared
# each data shard is ~250M chars
@ -62,11 +55,11 @@ uv run maturin develop --release --manifest-path rustbpe/Cargo.toml
# each shard is ~100MB of text (compressed), so this is about ~800MB of data on disk
python -m nanochat.dataset -n 8
# Immediately also kick off downloading more shards in the background while tokenizer trains
# See comment below for why 240 is the right number here
python -m nanochat.dataset -n 240 &
# See comment below for why 370 is the right number here
python -m nanochat.dataset -n 370 &
DATASET_DOWNLOAD_PID=$!
# train the tokenizer with vocab size 2**16 = 65536 on ~2B characters of data
python -m scripts.tok_train --max_chars=2000000000
python -m scripts.tok_train --max-chars=2000000000 --vocab-size=65536
# evaluate the tokenizer (report compression ratio etc.)
python -m scripts.tok_eval
@ -77,7 +70,9 @@ python -m scripts.tok_eval
# Chinchilla says #tokens = 20X #params, so we need 561e6 * 20 = 11.2B tokens.
# Assume our tokenizer is 4.8 chars/token, this is 11.2B * 4.8 ~= 54B chars.
# At 250M chars/shard, this is 54B / 250M ~= 216 shards needed for pretraining.
# Round up to 240 for safety. At ~100MB/shard, this downloads ~24GB of data to disk.
# Round up to 240 for safety. Also, the new DataLoader wastes about 35% of tokens to cropping
# so 240 / (1 - 0.35) = 370 shards are needed.
# At ~100MB/shard, this downloads ~37GB of data to disk.
# (The total number of shards available in the entire dataset is 1822.)
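The shard arithmetic in the comment above is easy to reproduce; the sketch below just restates it in Python (the parameter count, chars/token ratio, and 35% cropping overhead are the comment's estimates, not values measured here):

```python
import math

params = 561e6                                          # d20 model size from the comment above
tokens_needed = 20 * params                             # Chinchilla: ~20 tokens/param -> ~11.2B tokens
chars_needed = tokens_needed * 4.8                      # ~4.8 chars/token -> ~54B chars
shards_needed = math.ceil(chars_needed / 250e6)         # ~250M chars/shard -> 216 shards
shards_safe = 240                                       # rounded up for safety
shards_with_crop = math.ceil(shards_safe / (1 - 0.35))  # ~35% lost to cropping -> 370 shards
print(shards_needed, shards_with_crop)                  # -> 216 370
```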
echo "Waiting for dataset download to complete..."
wait $DATASET_DOWNLOAD_PID
@ -86,7 +81,7 @@ wait $DATASET_DOWNLOAD_PID
NPROC_PER_NODE=8
# pretrain the d20 model
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train -- --depth=20 --run=$WANDB_RUN
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train -- --depth=20 --target-param-data-ratio=20 --run=$WANDB_RUN
# evaluate the model on a larger chunk of train/val data and draw some samples
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_loss
# evaluate the model on CORE tasks

View File

@ -39,13 +39,9 @@ class MockModel:
def forward(self, ids, kv_cache=None):
"""Return uniform logits so sampling is spread across vocab."""
B, T = ids.shape
# Simulate what a real transformer does: insert k,v into the cache for each layer
# With FA3, flash_attn_with_kvcache updates cache in-place and we advance position
if kv_cache is not None:
head_dim = self.config.n_embd // self.config.n_head
for layer_idx in range(self.config.n_layer):
k = torch.zeros(B, self.config.n_kv_head, T, head_dim)
v = torch.zeros(B, self.config.n_kv_head, T, head_dim)
kv_cache.insert_kv(layer_idx, k, v)
kv_cache.advance(T)
# Uniform logits -> equal probability for all tokens
logits = torch.zeros(B, T, self.vocab_size)
return logits
@ -85,16 +81,11 @@ class ByteTokenizer:
byte_tokens = [t for t in tokens if t < 256]
return bytes(byte_tokens).decode("utf-8", errors="replace")
def test_kv_cache_resize():
"""
The KV cache was not resized correctly, more information here:
https://github.com/karpathy/nanochat/pull/186
This test reproduces the issue and will be merged alongside the fix.
"""
def test_kv_cache_basic():
"""Test basic KVCache functionality for FA3."""
batch_size = 2
num_heads = 3
seq_len = 4
seq_len = 64
head_dim = 5
num_layers = 6
@ -103,45 +94,64 @@ def test_kv_cache_resize():
num_heads=num_heads,
seq_len=seq_len,
head_dim=head_dim,
num_layers=num_layers
num_layers=num_layers,
device="cpu",
)
# Insert a single token with a distinct fill value to all layers
def insert_token(token_idx):
for layer_idx in range(num_layers):
k = torch.full((batch_size, num_heads, 1, head_dim), fill_value=float(token_idx), dtype=torch.float32)
v = torch.full((batch_size, num_heads, 1, head_dim), fill_value=float(token_idx * 100), dtype=torch.float32)
kv_cache.insert_kv(layer_idx, k, v)
# Check initial state
assert kv_cache.get_pos() == 0
assert kv_cache.k_cache.shape == (num_layers, batch_size, seq_len, num_heads, head_dim)
assert kv_cache.v_cache.shape == (num_layers, batch_size, seq_len, num_heads, head_dim)
# Insert 4 tokens (fills the initial seq_len=4)
for i in range(4):
insert_token(i)
# Test advance
kv_cache.advance(10)
assert kv_cache.get_pos() == 10
# Record the original state of the cache
original_cache = kv_cache.kv_cache.clone()
original_seq_len = original_cache.shape[4]
kv_cache.advance(5)
assert kv_cache.get_pos() == 15
# Insert the 5th token, which will trigger a resize
insert_token(4)
# Verify that the cache actually resized
new_seq_len = kv_cache.kv_cache.shape[4]
assert new_seq_len > original_seq_len, f"Cache did not resize: original seq_len={original_seq_len}, new seq_len={new_seq_len}"
# Test reset
kv_cache.reset()
assert kv_cache.get_pos() == 0
# Verify that the original 4 tokens are still intact after resize
for layer_idx in range(num_layers):
for token_idx in range(4):
# Check that resized cache matches expected values
expected_k = float(token_idx)
expected_v = float(token_idx * 100)
actual_k = kv_cache.kv_cache[layer_idx, 0, :, :, token_idx, :]
actual_v = kv_cache.kv_cache[layer_idx, 1, :, :, token_idx, :]
assert (actual_k == expected_k).all(), f"Layer {layer_idx}, token {token_idx}: key corrupted, expected {expected_k}"
assert (actual_v == expected_v).all(), f"Layer {layer_idx}, token {token_idx}: value corrupted, expected {expected_v}"
# And that the original cache matches resized cache
original_k = original_cache[layer_idx, 0, :, :, token_idx, :]
original_v = original_cache[layer_idx, 1, :, :, token_idx, :]
assert (actual_k == original_k).all(), f"Layer {layer_idx}, token {token_idx}: key doesn't match original"
assert (actual_v == original_v).all(), f"Layer {layer_idx}, token {token_idx}: value doesn't match original"
# Test get_layer_cache returns correct views
k_layer0, v_layer0 = kv_cache.get_layer_cache(0)
assert k_layer0.shape == (batch_size, seq_len, num_heads, head_dim)
assert v_layer0.shape == (batch_size, seq_len, num_heads, head_dim)
def test_kv_cache_prefill():
"""Test KVCache.prefill() copies data correctly."""
batch_size = 1
num_heads = 4
head_dim = 8
num_layers = 2
# Create source cache and advance it
src_cache = KVCache(
batch_size=batch_size, num_heads=num_heads, seq_len=32,
head_dim=head_dim, num_layers=num_layers, device="cpu",
)
# Write some data to source cache
src_cache.k_cache[0, 0, :16, :, :] = 1.0
src_cache.v_cache[0, 0, :16, :, :] = 2.0
src_cache.advance(16)
# Create destination cache with larger seq_len
dst_cache = KVCache(
batch_size=batch_size, num_heads=num_heads, seq_len=64,
head_dim=head_dim, num_layers=num_layers, device="cpu",
)
# Prefill
dst_cache.prefill(src_cache)
# Check position was copied
assert dst_cache.get_pos() == 16
# Check data was copied
assert (dst_cache.k_cache[0, 0, :16, :, :] == 1.0).all()
assert (dst_cache.v_cache[0, 0, :16, :, :] == 2.0).all()
def test_multi_sample_first_token_diversity():

View File

@ -1,718 +0,0 @@
"""
Comparing the training of:
1. (very slow) Python reference implementation
2. Optimized Python implementation
3. HuggingFace tokenizers training implementation
4. Our own custom RustBPE training implementation
All of these should calculate the same merges and produce
the same vocabulary and tokenizations.
Finally, for inference we will use tiktoken for efficiency.
So we want to make sure we can export our rustbpe tokenizer
into tiktoken and use it for inference with identical results.
Run with:
python -m pytest tests/test_rustbpe.py -v -s
-v is verbose, -s is show prints
"""
import regex as re
from collections import Counter, defaultdict
import time
import warnings
import rustbpe
import tiktoken
import pytest
GPT4_SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
# -----------------------------------------------------------------------------
# Reference tokenizer, pretty much copy pasted and pruned a bit from minbpe
def get_stats(ids, counts=None):
"""
Given a list of integers, return a dictionary of counts of consecutive pairs
Example: [1, 2, 3, 1, 2] -> {(1, 2): 2, (2, 3): 1, (3, 1): 1}
Optionally allows updating an existing dictionary of counts
"""
counts = {} if counts is None else counts
for pair in zip(ids, ids[1:]): # iterate consecutive elements
counts[pair] = counts.get(pair, 0) + 1
return counts
def merge(ids, pair, idx):
"""
In the list of integers (ids), replace all consecutive occurrences
of pair with the new integer token idx
Example: ids=[1, 2, 3, 1, 2], pair=(1, 2), idx=4 -> [4, 3, 4]
"""
newids = []
i = 0
while i < len(ids):
# if not at the very last position AND the pair matches, replace it
if ids[i] == pair[0] and i < len(ids) - 1 and ids[i+1] == pair[1]:
newids.append(idx)
i += 2
else:
newids.append(ids[i])
i += 1
return newids
class RegexTokenizer:
def __init__(self, pattern=None):
"""
- pattern: optional string to override the default (GPT-4 split pattern)
- special_tokens: str -> int dictionary of special tokens
example: {'<|endoftext|>': 100257}
"""
self.pattern = GPT4_SPLIT_PATTERN if pattern is None else pattern
self.merges = {} # (int, int) -> int
self.compiled_pattern = re.compile(self.pattern)
self.special_tokens = {}
self.inverse_special_tokens = {}
self.vocab = self._build_vocab()
def _build_vocab(self):
# vocab is simply and deterministically derived from merges
vocab = {idx: bytes([idx]) for idx in range(256)}
for (p0, p1), idx in self.merges.items():
vocab[idx] = vocab[p0] + vocab[p1]
for special, idx in self.special_tokens.items():
vocab[idx] = special.encode("utf-8")
return vocab
def train(self, text, vocab_size, verbose=False):
assert vocab_size >= 256
num_merges = vocab_size - 256
# keep track of whether at any point during training the merge is ambiguous (counts of pairs are not unique)
ambiguous = False
# split the text up into text chunks
text_chunks = re.findall(self.compiled_pattern, text)
# input text preprocessing
ids = [list(ch.encode("utf-8")) for ch in text_chunks]
# iteratively merge the most common pairs to create new tokens
merges = {} # (int, int) -> int
vocab = {idx: bytes([idx]) for idx in range(256)} # idx -> bytes
for i in range(num_merges):
# count the number of times every consecutive pair appears
stats = {}
for chunk_ids in ids:
# passing in stats will update it in place, adding up counts
get_stats(chunk_ids, stats)
# find the pair with the highest count
pair = max(stats, key=stats.get)
# check if the merge is ambiguous - i.e. the max value is not unique
pair_count = stats[pair]
pairs_with_max_count = [pair for pair, count in stats.items() if count == pair_count]
if len(pairs_with_max_count) > 1:
# print the top 10 pairs with their counts
# print(f"{i} Merge is ambiguous! {pair} has {pair_count} occurrences")
# for print_pair, print_count in sorted(stats.items(), key=lambda x: x[1], reverse=True)[:10]:
# print(f"{print_pair}: {print_count}")
ambiguous = True
# mint a new token: assign it the next available id
idx = 256 + i
# replace all occurrences of pair in ids with idx
ids = [merge(chunk_ids, pair, idx) for chunk_ids in ids]
# save the merge
merges[pair] = idx
vocab[idx] = vocab[pair[0]] + vocab[pair[1]]
# prints
if verbose:
print(f"merge {i+1}/{num_merges}: {pair} -> {idx} ({vocab[idx]}) had {stats[pair]} occurrences")
# save class variables
self.merges = merges # used in encode()
self.vocab = vocab # used in decode()
return ambiguous
def _encode_chunk(self, text_bytes):
# return the token ids
# let's begin. first, convert all bytes to integers in range 0..255
ids = list(text_bytes)
while len(ids) >= 2:
# find the pair with the lowest merge index
stats = get_stats(ids)
pair = min(stats, key=lambda p: self.merges.get(p, float("inf")))
# subtle: if there are no more merges available, the key will
# result in an inf for every single pair, and the min will be
# just the first pair in the list, arbitrarily
# we can detect this terminating case by a membership check
if pair not in self.merges:
break # nothing else can be merged anymore
# otherwise let's merge the best pair (lowest merge index)
idx = self.merges[pair]
ids = merge(ids, pair, idx)
return ids
def encode_ordinary(self, text):
"""Encoding that ignores any special tokens."""
# split text into chunks of text by categories defined in regex pattern
text_chunks = re.findall(self.compiled_pattern, text)
# all chunks of text are encoded separately, then results are joined
ids = []
for chunk in text_chunks:
chunk_bytes = chunk.encode("utf-8") # raw bytes
chunk_ids = self._encode_chunk(chunk_bytes)
ids.extend(chunk_ids)
return ids
# -----------------------------------------------------------------------------
# Faster Python tokenizer, optimized version of the reference tokenizer
def fast_merge_inplace(ids, pair, idx):
"""
In the list of integers (ids), replace all consecutive occurrences
of pair with the new integer token idx in place
Example: ids=[1, 2, 3, 1, 2], pair=(1, 2), idx=4 -> [4, 3, 4]
"""
# Find all positions where the pair occurs
i = 0
while i < len(ids) - 1:
if ids[i] == pair[0] and ids[i+1] == pair[1]:
ids[i] = idx
ids.pop(i+1)
else:
i += 1
return ids
class FastRegexTokenizer:
def __init__(self, pattern=None):
"""
- pattern: optional string to override the default (GPT-4 split pattern)
- special_tokens: str -> int dictionary of special tokens
example: {'<|endoftext|>': 100257}
"""
self.pattern = GPT4_SPLIT_PATTERN if pattern is None else pattern
self.compiled_pattern = re.compile(self.pattern)
self.special_tokens = {}
self.inverse_special_tokens = {}
self.merges = {}
self.vocab = self._build_vocab()
def _build_vocab(self):
# vocab is simply and deterministically derived from merges
# (dict insertion order guarantees each merge only references tokens built before it)
vocab = {idx: bytes([idx]) for idx in range(256)}
for (p0, p1), idx in self.merges.items():
vocab[idx] = vocab[p0] + vocab[p1]
for special, idx in self.special_tokens.items():
vocab[idx] = special.encode("utf-8")
return vocab
def train(self, text, vocab_size, verbose=False):
"""
A number of optimizations are introduced:
- delete function call overhead by inlining functions
- modifying list of ids in place with .pop() instead of creating a new list
- collapse identical chunks to just the unique ones
- update counts more cleverly - only around the affected chunks
"""
assert vocab_size >= 256
num_merges = vocab_size - 256
# split the text up into text chunks
text_chunks = re.findall(self.compiled_pattern, text)
# many, many chunks are identical, so we can "collapse" them to just the unique ones
counts = Counter(text_chunks)
unique_chunks = list(counts.keys())
chunk_counts = list(counts.values())
# input text preprocessing
ids = [list(ch.encode("utf-8")) for ch in unique_chunks]
# iteratively merge the most common pairs to create new tokens
merges = {} # (int, int) -> int
vocab = {idx: bytes([idx]) for idx in range(256)} # idx -> bytes
# Initial count: build stats and position tracking
stats = defaultdict(int)
positions = defaultdict(set) # pair -> set of chunk indices that contain this pair
for chunk_idx, (chunk_ids, count) in enumerate(zip(ids, chunk_counts)):
for pair in zip(chunk_ids, chunk_ids[1:]):
stats[pair] += count
positions[pair].add(chunk_idx)
for i in range(num_merges):
if not stats:
break
# find the pair with the highest count
pair = max(stats, key=stats.get)
# mint a new token: assign it the next available id
idx = 256 + i
# Get chunks that contain this pair
affected_chunks = positions[pair]
# Track count changes for incremental update
count_changes = defaultdict(int)
# Replace all occurrences of pair in affected chunks only
for chunk_idx in affected_chunks:
chunk_ids = ids[chunk_idx]
chunk_count = chunk_counts[chunk_idx]
ix = 0
while ix < len(chunk_ids) - 1:
if chunk_ids[ix] == pair[0] and chunk_ids[ix+1] == pair[1]:
# Track what pairs are being removed/added
# Remove: (prev, A), (A, B), (B, next)
if ix > 0:
old_left = (chunk_ids[ix-1], chunk_ids[ix])
count_changes[old_left] -= chunk_count
# The merged pair disappears
count_changes[pair] -= chunk_count
if ix + 2 < len(chunk_ids):
old_right = (chunk_ids[ix+1], chunk_ids[ix+2])
count_changes[old_right] -= chunk_count
# Apply the merge
chunk_ids[ix] = idx
chunk_ids.pop(ix+1)
# Add: (prev, C), (C, next)
if ix > 0:
new_left = (chunk_ids[ix-1], chunk_ids[ix])
count_changes[new_left] += chunk_count
if ix + 1 < len(chunk_ids):
new_right = (chunk_ids[ix], chunk_ids[ix+1])
count_changes[new_right] += chunk_count
else:
ix += 1
# Apply incremental changes to stats and positions
for changed_pair, delta in count_changes.items():
if changed_pair == pair:
# The merged pair should disappear completely
continue
stats[changed_pair] += delta
# Update positions for changed pairs - only check affected chunks
for chunk_idx in affected_chunks:
chunk_ids = ids[chunk_idx]
contains_pair = any((chunk_ids[j], chunk_ids[j+1]) == changed_pair
for j in range(len(chunk_ids) - 1))
if contains_pair:
positions[changed_pair].add(chunk_idx)
else:
positions[changed_pair].discard(chunk_idx)
# Remove the merged pair completely
del stats[pair]
del positions[pair]
# save the merge
merges[pair] = idx
vocab[idx] = vocab[pair[0]] + vocab[pair[1]]
# save class variables
self.merges = merges # used in encode()
self.vocab = vocab # used in decode()
def register_special_tokens(self, special_tokens):
# special_tokens is a dictionary of str -> int
# example: {"<|endoftext|>": 100257}
self.special_tokens = special_tokens
self.inverse_special_tokens = {v: k for k, v in special_tokens.items()}
def decode(self, ids):
# given ids (list of integers), return Python string
part_bytes = []
for idx in ids:
if idx in self.vocab:
part_bytes.append(self.vocab[idx])
elif idx in self.inverse_special_tokens:
part_bytes.append(self.inverse_special_tokens[idx].encode("utf-8"))
else:
raise ValueError(f"invalid token id: {idx}")
text_bytes = b"".join(part_bytes)
text = text_bytes.decode("utf-8", errors="replace")
return text
def _encode_chunk(self, text_bytes):
# return the token ids
# let's begin. first, convert all bytes to integers in range 0..255
ids = list(text_bytes)
while len(ids) >= 2:
# find the pair with the lowest merge index
stats = get_stats(ids)
pair = min(stats, key=lambda p: self.merges.get(p, float("inf")))
# subtle: if there are no more merges available, the key will
# result in an inf for every single pair, and the min will be
# just the first pair in the list, arbitrarily
# we can detect this terminating case by a membership check
if pair not in self.merges:
break # nothing else can be merged anymore
# otherwise let's merge the best pair (lowest merge index)
idx = self.merges[pair]
ids = fast_merge_inplace(ids, pair, idx)
return ids
def encode_ordinary(self, text):
"""Encoding that ignores any special tokens."""
# split text into chunks of text by categories defined in regex pattern
text_chunks = re.findall(self.compiled_pattern, text)
# all chunks of text are encoded separately, then results are joined
ids = []
for chunk in text_chunks:
chunk_bytes = chunk.encode("utf-8") # raw bytes
chunk_ids = self._encode_chunk(chunk_bytes)
ids.extend(chunk_ids)
return ids
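# Minimal usage sketch of FastRegexTokenizer (illustrative, not part of the test suite):
#   tok = FastRegexTokenizer()
#   tok.train("hello hello hello world", vocab_size=256 + 4)
#   ids = tok.encode_ordinary("hello world")
#   assert tok.decode(ids) == "hello world"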
# -----------------------------------------------------------------------------
# HuggingFace tokenizer
from tokenizers import Tokenizer as HFTokenizer
from tokenizers import pre_tokenizers, decoders, Regex
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
class HuggingFaceTokenizer:
"""Light wrapper around HuggingFace Tokenizer for some utilities"""
def __init__(self, tokenizer):
self.tokenizer = tokenizer
@classmethod
def train_from_iterator(cls, text_iterator, vocab_size):
# train from an iterator of text
# Configure the HuggingFace Tokenizer
tokenizer = HFTokenizer(BPE(
byte_fallback=True, # needed!
unk_token=None,
fuse_unk=False,
))
# Normalizer: None
tokenizer.normalizer = None
# Pre-tokenizer: GPT-4 style
gpt4_split_regex = Regex(GPT4_SPLIT_PATTERN) # huggingface demands that you wrap it in Regex!!
tokenizer.pre_tokenizer = pre_tokenizers.Sequence([
pre_tokenizers.Split(pattern=gpt4_split_regex, behavior="isolated", invert=False),
pre_tokenizers.ByteLevel(add_prefix_space=False, use_regex=False)
])
# Decoder: ByteLevel (it pairs together with the ByteLevel pre-tokenizer)
tokenizer.decoder = decoders.ByteLevel()
# Post-processor: None
tokenizer.post_processor = None
# Trainer: BPE
trainer = BpeTrainer(
vocab_size=vocab_size,
show_progress=True,
min_frequency=0, # no minimum frequency
initial_alphabet=pre_tokenizers.ByteLevel.alphabet(),
special_tokens=[], # no special tokens
)
# Kick off the training
tokenizer.train_from_iterator(text_iterator, trainer)
return cls(tokenizer)
def encode_ordinary(self, text):
ids = self.tokenizer.encode(text, add_special_tokens=False).ids
return ids
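# Minimal usage sketch (illustrative): train the wrapper from an iterator of strings,
# then encode without special tokens:
#   hf_tok = HuggingFaceTokenizer.train_from_iterator(["some training text"], vocab_size=300)
#   ids = hf_tok.encode_ordinary("some training text")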
# -----------------------------------------------------------------------------
# Test all of the above
@pytest.fixture(scope="module")
def enwik8_path():
"""Fixture to download and cache enwik8 dataset."""
import os
import zipfile
from nanochat.common import get_base_dir
base_dir = get_base_dir()
# download and unzip enwik8 to .cache directory
enwik8_url = "https://mattmahoney.net/dc/enwik8.zip"
enwik8_local_path = os.path.join(base_dir, "enwik8")
enwik8_local_path_zip = os.path.join(base_dir, "enwik8.zip")
if not os.path.exists(enwik8_local_path):
print(f"Downloading enwik8 to {enwik8_local_path_zip}")
import requests
response = requests.get(enwik8_url)
with open(enwik8_local_path_zip, "wb") as f:
f.write(response.content)
with zipfile.ZipFile(enwik8_local_path_zip, "r") as zip_ref:
zip_ref.extractall(base_dir)
print(f"Unzipped enwik8 to {enwik8_local_path}")
os.remove(enwik8_local_path_zip)
print(f"Removed {enwik8_local_path_zip}")
else:
print(f"Using existing enwik8 at {enwik8_local_path}")
return enwik8_local_path
@pytest.fixture(scope="module")
def enwik8_small(enwik8_path):
"""Fixture providing 100KB of enwik8 for quick tests."""
with open(enwik8_path, "r", encoding="utf-8") as f:
return f.read(100_000)
@pytest.fixture(scope="module")
def enwik8_large(enwik8_path):
"""Fixture providing 10MB of enwik8 for performance tests."""
with open(enwik8_path, "r", encoding="utf-8") as f:
return f.read(10**7)
def time_function(func, *args, **kwargs):
"""Time a function call and return the result and elapsed time"""
start_time = time.time()
result = func(*args, **kwargs)
end_time = time.time()
elapsed = end_time - start_time
return result, elapsed
def test_correctness(enwik8_small):
"""Test that all tokenizer implementations produce the same results."""
text = enwik8_small
encode_text = text
vocab_size = 256 + 20 # 20 merges
# Train slow reference
print("\nTraining slow reference...")
slow_reference_tokenizer = RegexTokenizer()
ambiguous_flag, slow_reference_train_time = time_function(slow_reference_tokenizer.train, text, vocab_size)
slow_reference_ids, slow_reference_encode_time = time_function(slow_reference_tokenizer.encode_ordinary, encode_text)
print(f"Slow reference train time: {slow_reference_train_time:.4f}s")
print(f"Slow reference encode time: {slow_reference_encode_time:.4f}s")
print(slow_reference_ids[:20])
if ambiguous_flag:
print("‼️ WARNING: merge order was detected to be ambiguous given current text and vocab size")
print("The implementation could be correct but we might see different results below")
else:
print("✅ Merge order is NOT ambiguous")
# Train fast reference
print("\nTraining fast reference...")
fast_reference_tokenizer = FastRegexTokenizer()
_, fast_reference_train_time = time_function(fast_reference_tokenizer.train, text, vocab_size)
fast_reference_ids, fast_reference_encode_time = time_function(fast_reference_tokenizer.encode_ordinary, encode_text)
print(f"Fast reference train time: {fast_reference_train_time:.4f}s")
print(f"Fast reference encode time: {fast_reference_encode_time:.4f}s")
print(fast_reference_ids[:20])
# Assert fast equals slow
assert fast_reference_ids == slow_reference_ids, "Fast reference should match slow reference"
print("✅ Fast == Slow")
# Train HuggingFace
print("\nTraining HuggingFace...")
hf_tokenizer, hf_train_time = time_function(HuggingFaceTokenizer.train_from_iterator, [text], vocab_size)
hf_ids, hf_encode_time = time_function(hf_tokenizer.encode_ordinary, encode_text)
print(f"HuggingFace train time: {hf_train_time:.4f}s")
print(f"HuggingFace encode time: {hf_encode_time:.4f}s")
print(hf_ids[:20])
# HuggingFace's ByteLevel pre-tokenizer assigns its own ids to the 256 raw bytes, so
# byte-level tokens (< 256) only need to map consistently (a permutation), while merged
# tokens (>= 256) must match exactly - hence the custom comparison below
def custom_match(ids1, ids2):
perm = {}
for x, y in zip(ids1, ids2):
if x < 256:
if x in perm:
if perm[x] != y:
return False
perm[x] = y
if x >= 256 and x != y:
return False
return True
assert custom_match(hf_ids, fast_reference_ids), "HuggingFace should match fast reference"
print("✅ HuggingFace == Fast")
# Finally use our own Rust implementation
print("\nTraining rustbpe...")
rustbpe_tokenizer = rustbpe.Tokenizer()
_, rustbpe_train_time = time_function(rustbpe_tokenizer.train_from_iterator, [text], vocab_size)
rustbpe_ids, rustbpe_encode_time = time_function(rustbpe_tokenizer.encode, encode_text)
print(f"RustBPE train time: {rustbpe_train_time:.4f}s")
print(f"RustBPE encode time: {rustbpe_encode_time:.4f}s")
print(rustbpe_ids[:20])
assert rustbpe_ids == fast_reference_ids, "RustBPE should match fast reference"
print("✅ RustBPE == Fast")
# Now export rustbpe to tiktoken for more efficient inference
print("\nTesting tiktoken export...")
pattern = rustbpe_tokenizer.get_pattern()
mergeable_ranks_list = rustbpe_tokenizer.get_mergeable_ranks()
mergeable_ranks = {bytes(k): v for k, v in mergeable_ranks_list}
enc = tiktoken.Encoding(
name="rustbpe",
pat_str=pattern,
mergeable_ranks=mergeable_ranks,
special_tokens={},
)
tiktoken_ids, tiktoken_encode_time = time_function(enc.encode, encode_text)
print(f"Tiktoken encode time: {tiktoken_encode_time:.4f}s")
print(tiktoken_ids[:20])
assert tiktoken_ids == rustbpe_ids, "Tiktoken should match RustBPE"
print("✅ Tiktoken == RustBPE")
@pytest.mark.slow
def test_training_performance(enwik8_large):
"""Use a bigger dataset and compare the training speed of the optimized tokenizers (Python, Rust, HuggingFace)."""
text = enwik8_large
vocab_size = 2048
print(f"\nText length: {len(text)}")
# Commenting out because it's just way too slow to matter
# Train optimized python version
# print("Training optimized python version...")
# optimized_python_tokenizer = FastRegexTokenizer()
# _, optimized_python_train_time = time_function(optimized_python_tokenizer.train, text, vocab_size)
# print(f"Optimized python train time: {optimized_python_train_time:.4f}s")
# Train rustbpe
print("\nTraining rustbpe...")
rustbpe_tokenizer = rustbpe.Tokenizer()
_, rustbpe_train_time = time_function(rustbpe_tokenizer.train_from_iterator, [text], vocab_size)
print(f"RustBPE train time: {rustbpe_train_time:.4f}s")
assert rustbpe_train_time > 0, "Training should take some time"
# Train HuggingFace
print("\nTraining HuggingFace...")
hf_tokenizer, hf_train_time = time_function(HuggingFaceTokenizer.train_from_iterator, [text], vocab_size)
print(f"HuggingFace train time: {hf_train_time:.4f}s")
assert hf_train_time > 0, "Training should take some time"
# Print comparison
print(f"\n📊 Performance comparison:")
print(f" RustBPE: {rustbpe_train_time:.4f}s")
print(f" HuggingFace: {hf_train_time:.4f}s")
print(f" Speedup: {hf_train_time/rustbpe_train_time:.2f}x")
def test_interface(enwik8_small):
"""Test the RustBPETokenizer interface for training, encoding, decoding, and serialization."""
import tempfile
from nanochat.tokenizer import RustBPETokenizer
# Simple train test
vocab_size = 300
tok = RustBPETokenizer.train_from_iterator([enwik8_small], vocab_size)
assert tok.get_vocab_size() == vocab_size, f"Expected vocab size {vocab_size}, got {tok.get_vocab_size()}"
print(f"✅ Trained tokenizer with vocab size {vocab_size}")
# Encode/decode text
encode_text = "Hello world! How are you? 🙃"
ids = tok.encode(encode_text)
print(f"\nInput text: {encode_text}")
print(f"IDs: {ids}")
decoded = tok.decode(ids)
print(f"Decoded: {decoded}")
assert decoded == encode_text, f"Decoded text doesn't match: {decoded} != {encode_text}"
print("✅ Encode/decode test passed")
# Encode batch test
ids_new = tok.encode([encode_text, encode_text])
assert all(x == ids for x in ids_new), "Batch encoding should produce identical results"
print("✅ Encode batch OK")
# append/prepend functionality
ids_special = tok.encode(encode_text, prepend="<|bos|>", append="<|bos|>")
bos_token_id = tok.encode_special("<|bos|>")
assert ids_special == [bos_token_id] + ids + [bos_token_id], "Special tokens not correctly added"
print("✅ append/prepend OK")
# Save/load test through a temporary directory
with tempfile.TemporaryDirectory() as tmp_dir:
tok.save(tmp_dir)
tok_reloaded = RustBPETokenizer.from_directory(tmp_dir)
ids_reloaded = tok_reloaded.encode(encode_text)
assert ids_reloaded == ids, "Reloaded tokenizer should produce same results"
print("✅ Save/load through temporary directory OK")
def test_batch_encode_correctness(enwik8_small):
"""Quick correctness test for batch_encode()"""
text = enwik8_small
vocab_size = 512
tokenizer = rustbpe.Tokenizer()
tokenizer.train_from_iterator([text], vocab_size)
# Test with various batch sizes and edge cases
test_texts = [
"Hello world",
"The quick brown fox",
"jumps over the lazy dog",
"", # empty string
"a", # single char
]
# Compare batch vs individual encoding
individual = [tokenizer.encode(t) for t in test_texts]
batched = tokenizer.batch_encode(test_texts)
assert individual == batched, "Batch encoding should match individual encoding"
print("✅ batch_encode() correctness verified")
@pytest.mark.slow
def test_batch_encode_performance(enwik8_large):
"""
Benchmark batch_encode() vs sequential encode() loop.
Demonstrates parallelization speedup.
"""
# Setup
text = enwik8_large # 10MB dataset
vocab_size = 2048
# Train tokenizer
print("\nTraining tokenizer...")
tokenizer = rustbpe.Tokenizer()
tokenizer.train_from_iterator([text], vocab_size)
# Create test batch: split text into chunks
chunk_size = 50_000 # ~50KB per chunk
chunks = [text[i:i+chunk_size] for i in range(0, len(text), chunk_size)]
chunks = chunks[:20] # Use first 20 chunks (~1MB total)
print(f"\nBatch encoding benchmark:")
print(f" Number of texts: {len(chunks)}")
print(f" Avg text length: {sum(len(c) for c in chunks) / len(chunks):.0f} chars")
# Benchmark 1: Sequential encoding (baseline)
print("\n [1/3] Sequential encode() loop...")
sequential_results, sequential_time = time_function(
lambda: [tokenizer.encode(chunk) for chunk in chunks]
)
print(f" Time: {sequential_time:.4f}s")
# Benchmark 2: Parallel batch_encode()
print(" [2/3] Parallel batch_encode()...")
batch_results, batch_time = time_function(
tokenizer.batch_encode, chunks
)
print(f" Time: {batch_time:.4f}s")
# Verify correctness
print(" [3/3] Verifying correctness...")
assert len(batch_results) == len(sequential_results), "Result count mismatch"
for i, (seq, batch) in enumerate(zip(sequential_results, batch_results)):
assert seq == batch, f"Mismatch at index {i}"
print(" ✓ All results match")
# Report speedup
speedup = sequential_time / batch_time
print(f"\n Performance Results:")
print(f" Sequential: {sequential_time:.4f}s")
print(f" Batch: {batch_time:.4f}s")
print(f" Speedup: {speedup:.2f}x")
# Warn if speedup is low (can vary by machine/load)
if speedup < 1.5:
warnings.warn(f"batch_encode() speedup was only {speedup:.2f}x (expected >1.5x)")
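# To run with pytest (file path is illustrative):
#   python -m pytest -v <this test file>                   # run everything in this file
#   python -m pytest -v -m "not slow" <this test file>     # skip the @pytest.mark.slow benchmarks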

1749
uv.lock

File diff suppressed because it is too large