mirror of
https://github.com/karpathy/nanochat.git
synced 2026-04-12 11:58:23 +00:00
Merge branch 'master' into topk-validation-fix
# Please enter a commit message to explain why this merge is necessary, # especially if it merges an updated upstream into a topic branch. # # Lines starting with '#' will be ignored, and an empty message aborts # the commit.
This commit is contained in:
commit
8f6e616d12
6
.gitignore
vendored
6
.gitignore
vendored
|
|
@ -12,9 +12,3 @@ eval_bundle/
|
|||
.claude
|
||||
CLAUDE.md
|
||||
wandb/
|
||||
|
||||
# Local experimentation
|
||||
experiments/
|
||||
ignore/
|
||||
knowledge/
|
||||
ideas/
|
||||
|
|
|
|||
48
README.md
48
README.md
|
|
@ -4,28 +4,29 @@
|
|||
|
||||
> The best ChatGPT that $100 can buy.
|
||||
|
||||
This repo is a full-stack implementation of an LLM like ChatGPT in a single, clean, minimal, hackable, dependency-lite codebase. nanochat is designed to run on a single 8XH100 node via scripts like [speedrun.sh](speedrun.sh), that run the entire pipeline start to end. This includes tokenization, pretraining, finetuning, evaluation, inference, and web serving over a simple UI so that you can talk to your own LLM just like ChatGPT. nanochat will become the capstone project of the course LLM101n being developed by Eureka Labs.
|
||||
|
||||
## Talk to it
|
||||
|
||||
To get a sense of the endpoint of this repo, you can currently find [nanochat d34](https://github.com/karpathy/nanochat/discussions/314) hosted on [nanochat.karpathy.ai](https://nanochat.karpathy.ai/). "d34" means that this model has 34 layers in the Transformer neural network. This model has 2.2 billion parameters, it was trained on 88 billion tokens by simply running the training script [run1000.sh](run1000.sh) with `--target_param_data_ratio=40` (2x longer than Chinchilla-optimal), and the total cost of training was ~$2,500 (about 100 hours training time on 8XH100 GPU node). While today this is enough to outperform GPT-2 of 2019, it falls dramatically short of modern Large Language Models like GPT-5. When talking to these micro models, you'll see that they make a lot of mistakes, they are a little bit naive and silly and they hallucinate a ton, a bit like children. It's kind of amusing. But what makes nanochat unique is that it is fully yours - fully configurable, tweakable, hackable, and trained by you from start to end. To train and talk to your own, we turn to...
|
||||
This repo is a full-stack implementation of an LLM like ChatGPT in a single, clean, minimal, hackable, dependency-lite codebase. nanochat is designed to run on a single 8XH100 node via scripts like [speedrun.sh](runs/speedrun.sh), that run the entire pipeline start to end. This includes tokenization, pretraining, finetuning, evaluation, inference, and web serving over a simple UI so that you can talk to your own LLM just like ChatGPT. nanochat will become the capstone project of the course LLM101n being developed by Eureka Labs.
|
||||
|
||||
## Updates
|
||||
|
||||
- (Jan 7 2026) See new post: [nanochat Miniseries v1](https://github.com/karpathy/nanochat/discussions/420) and the associated script [miniseries.sh](miniseries.sh).
|
||||
- (Jan 16 2026) The repo is in active development, I am currently fleshing out the pretraining stage.
|
||||
- (Jan 7 2026) See new post: [nanochat Miniseries v1](https://github.com/karpathy/nanochat/discussions/420) and the associated script [miniseries.sh](runs/miniseries.sh).
|
||||
|
||||
## Talk to it
|
||||
|
||||
To get a sense of the endpoint of this repo, you can currently find [nanochat d34](https://github.com/karpathy/nanochat/discussions/314) hosted on [nanochat.karpathy.ai](https://nanochat.karpathy.ai/). "d34" means that this model has 34 layers in the Transformer neural network. This model has 2.2 billion parameters, it was trained on 88 billion tokens by simply running the training script [run1000.sh](runs/run1000.sh) with `--target_param_data_ratio=40` (2x longer than Chinchilla-optimal), and the total cost of training was ~$2,500 (about 100 hours training time on 8XH100 GPU node). While today this is enough to outperform GPT-2 of 2019, it falls dramatically short of modern Large Language Models like GPT-5. When talking to these micro models, you'll see that they make a lot of mistakes, they are a little bit naive and silly and they hallucinate a ton, a bit like children. It's kind of amusing. But what makes nanochat unique is that it is fully yours - fully configurable, tweakable, hackable, and trained by you from start to end. To train and talk to your own, we turn to...
|
||||
|
||||
## Quick start
|
||||
|
||||
The fastest way to feel the magic is to run the speedrun script [speedrun.sh](speedrun.sh), which trains and inferences the $100 tier of nanochat. On an 8XH100 node at $24/hr, this gives a total run time of about 4 hours. Boot up a new 8XH100 GPU box from your favorite provider (e.g. I use and like [Lambda](https://lambda.ai/service/gpu-cloud)), and kick off the training script:
|
||||
The fastest way to feel the magic is to run the speedrun script [speedrun.sh](runs/speedrun.sh), which trains and inferences the $100 tier of nanochat. On an 8XH100 node at $24/hr, this gives a total run time of about 4 hours. Boot up a new 8XH100 GPU box from your favorite provider (e.g. I use and like [Lambda](https://lambda.ai/service/gpu-cloud)), and kick off the training script:
|
||||
|
||||
```bash
|
||||
bash speedrun.sh
|
||||
bash runs/speedrun.sh
|
||||
```
|
||||
|
||||
Alternatively, since the script runs for 4 hours, I like to launch it like this inside a new screen session `speedrun` (and also log output to `speedrun.log`):
|
||||
|
||||
```bash
|
||||
screen -L -Logfile speedrun.log -S speedrun bash speedrun.sh
|
||||
screen -L -Logfile speedrun.log -S speedrun bash runs/speedrun.sh
|
||||
```
|
||||
|
||||
See the [screen cheatsheet](https://gist.github.com/jctosta/af918e1618682638aa82) if you are less familiar. You can watch it go inside the screen session, or detach with `Ctrl-a d` and `tail speedrun.log` to view progress. Now wait 4 hours. Once it's done, you can talk to your LLM via the ChatGPT-like web UI. Make sure again that your local uv virtual environment is active (run `source .venv/bin/activate`), and serve it:
|
||||
|
|
@ -72,7 +73,7 @@ Total wall clock time: 3h51m
|
|||
|
||||
Unsurprisingly, $100 is not enough to train a highly performant ChatGPT clone. In fact, LLMs are famous for their multi-million dollar capex. For our purposes, I think there are two more scales of interest. First is the ~$300 tier d26 model (i.e. depth=26) that trains in ~12 hours, which slightly outperforms GPT-2 CORE score. Second is the $1000 tier (~41.6 hours), just because it's a nice round number. But both of these are not yet fully supported and therefore not attached here in the master branch yet.
|
||||
|
||||
That said, to give a sense, the example changes needed for the [speedrun.sh](speedrun.sh) file to train a GPT-2 grade model d26 only involve three changes:
|
||||
That said, to give a sense, the example changes needed for the [speedrun.sh](runs/speedrun.sh) file to train a GPT-2 grade model d26 only involve three changes:
|
||||
|
||||
```bash
|
||||
...
|
||||
|
|
@ -82,10 +83,10 @@ That said, to give a sense, the example changes needed for the [speedrun.sh](spe
|
|||
python -m nanochat.dataset -n 450 &
|
||||
...
|
||||
# use --depth to increase model size. to not oom, halve device batch size 32 -> 16:
|
||||
torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- --depth=26 --device_batch_size=16
|
||||
torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- --depth=26 --device-batch-size=16
|
||||
...
|
||||
# make sure to use the same later during midtraining:
|
||||
torchrun --standalone --nproc_per_node=8 -m scripts.mid_train -- --device_batch_size=16
|
||||
torchrun --standalone --nproc_per_node=8 -m scripts.mid_train -- --device-batch-size=16
|
||||
```
|
||||
|
||||
That's it! The biggest thing to pay attention to is making sure you have enough data shards to train on (the code will loop and do more epochs over the same training set otherwise, decreasing learning speed a bit), and managing your memory/VRAM, primarily by decreasing the `device_batch_size` until things fit (the scripts automatically compensate by increasing the number of gradient accumulation loops, simply turning parallel compute to sequential compute).
|
||||
|
|
@ -99,7 +100,7 @@ And a bit more about computing environments that will run nanochat:
|
|||
|
||||
## Running on CPU / MPS
|
||||
|
||||
nanochat can be run on CPU or on MPS (if you're on Macbook), and will automatically try to detect what device is best to run on. You're not going to get too far without GPUs, but at least you'll be able to run the code paths and maybe train a tiny LLM with some patience. For an example of how to make all the run commands much smaller (feel free to tune!), you can refer to [dev/runcpu.sh](dev/runcpu.sh) file. You'll see that I'm essentially restricting all scripts to train smaller models, to run for shorter number of iterations, etc. This functionality is new, slightly gnarly (touched a lot of code), and was merged in this [CPU|MPS PR](https://github.com/karpathy/nanochat/pull/88) on Oct 21, 2025.
|
||||
nanochat can be run on CPU or on MPS (if you're on Macbook) in principle, and will automatically try to detect what device is best to run on. The script [runcpu.sh](runs/runcpu.sh) shows a very simple example that will exercise the code paths but basically produce garbage results. Unless you know what you're doing, I basically don't recommend using this script right now and hope to tune it a bit more in the future.
|
||||
|
||||
## Customization
|
||||
|
||||
|
|
@ -109,15 +110,9 @@ Additionally, to add new abilities to nanochat, see [Guide: counting r in strawb
|
|||
|
||||
## Questions
|
||||
|
||||
nanochat is designed to be short and sweet. One big advantage of this is that we can package up all of the files together and copy paste them to your favorite LLM to ask arbitrary questions. As an example, I like to package up the repo using the [files-to-prompt](https://github.com/simonw/files-to-prompt) utility like so:
|
||||
I recommend using [DeepWiki](https://deepwiki.com/karpathy/nanochat) from Devin/Cognition to ask questions of this repo. In the URL of this repo, simply change github.com to deepwiki.com, and you're off.
|
||||
|
||||
```bash
|
||||
files-to-prompt . -e py -e md -e html -e toml -e sh --cxml > packaged.txt
|
||||
```
|
||||
|
||||
This includes all py, html, toml, sh files and chooses the cxml output format. Everything is written to the `packaged.txt` file, which atm measures ~330KB (i.e. well below ~100K tokens for a state of the art LLM), and ~8K lines of code in 45 files.
|
||||
|
||||
Alternatively, I recommend using [DeepWiki](https://deepwiki.com/karpathy/nanochat) from Devin/Cognition to ask questions of this repo. In the URL of this repo, simply change github.com to deepwiki.com, and you're off.
|
||||
You can also come to the [#nanochat Discord channel](https://discord.com/channels/1020383067459821711/1427295580895314031) to ask questions, or use the Discussions.
|
||||
|
||||
## Tests
|
||||
|
||||
|
|
@ -137,8 +132,7 @@ python -m pytest tests/test_engine.py -v -s
|
|||
│ ├── gen_synthetic_data.py # Example synthetic data for identity
|
||||
│ ├── generate_logo.html
|
||||
│ ├── nanochat.png
|
||||
│ ├── repackage_data_reference.py # Pretraining data shard generation
|
||||
│ └── runcpu.sh # Small example of how to run on CPU/MPS
|
||||
│ └── repackage_data_reference.py # Pretraining data shard generation
|
||||
├── nanochat
|
||||
│ ├── __init__.py # empty
|
||||
│ ├── adamw.py # Distributed AdamW optimizer
|
||||
|
|
@ -157,7 +151,12 @@ python -m pytest tests/test_engine.py -v -s
|
|||
│ ├── tokenizer.py # BPE Tokenizer wrapper in style of GPT-4
|
||||
│ └── ui.html # HTML/CSS/JS for nanochat frontend
|
||||
├── pyproject.toml
|
||||
├── run1000.sh # Train the ~$800 nanochat d32
|
||||
├── runs
|
||||
│ ├── miniseries.sh # Miniseries training script
|
||||
│ ├── run1000.sh # Train the ~$800 nanochat d32
|
||||
│ ├── runcpu.sh # Small example of how to run on CPU/MPS
|
||||
│ ├── scaling_laws.sh # Scaling laws experiments
|
||||
│ └── speedrun.sh # Train the ~$100 nanochat d20
|
||||
├── scripts
|
||||
│ ├── base_eval.py # Base model: calculate CORE score
|
||||
│ ├── base_loss.py # Base model: calculate bits per byte, sample
|
||||
|
|
@ -170,7 +169,6 @@ python -m pytest tests/test_engine.py -v -s
|
|||
│ ├── mid_train.py # Chat model: midtraining
|
||||
│ ├── tok_eval.py # Tokenizer: evaluate compression rate
|
||||
│ └── tok_train.py # Tokenizer: train it
|
||||
├── speedrun.sh # Train the ~$100 nanochat d20
|
||||
├── tasks
|
||||
│ ├── arc.py # Multiple choice science questions
|
||||
│ ├── common.py # TaskMixture | TaskSequence
|
||||
|
|
|
|||
467
dev/LOG.md
467
dev/LOG.md
|
|
@ -4,6 +4,469 @@ A running summary documenting some experiments and findings. Started ~Jan 7 2026
|
|||
|
||||
---
|
||||
|
||||
## 2026-01-17: Various experiments
|
||||
|
||||
Modded-nanogpt uses [Value Embeddings](https://arxiv.org/abs/2410.17897) (VEs) in a funny U-shaped structure, 3 of them in total and with gates. I tried a large number of tweaks on this today:
|
||||
|
||||
- VEs at every layer, at alternating layers, U shaped, front and back. Alternating layers worked best, i.e. we end up with *a lot* more VEs than modded-nanogpt, at every other layer. It works better.
|
||||
- Many parameters sharing ideas to reduce new parameter count, nothing here worked. All failed.
|
||||
- Many ideas to reduce parameter count, the LLM hates all of them: low rank decompositions, projections. All failed.
|
||||
- Gated yes or no and how much. Gate helps.
|
||||
|
||||
Long story short is that the models *love* Value Embeddings. It is a way to add a huge amount of capacity (parameters) to the model at almost zero cost of FLOPs, because these embeddings are simply added to the Values tensor. Any attempt to reduce the capacity of value embeddings (param sharing, low rank, projections) fail. The model wants many of them, and with all the capacity, and doing so wins across all x axes of steps, flops and wall clock. I re-ran the scaling laws and, because the models are now very parameter bloated, the optimal ratio has halved from 8 to 4! Way down lower than Chinchilla's 20 at this point.
|
||||
|
||||
Other experiments, looking at val/bpb as a function of all of steps, flops and wall clock time:
|
||||
|
||||
- Aspect ratio of 128 is worse than 64, I tried a sweep fixing FLOPs == 1e18 and 64 outperforms. The LLM prefers to be slightly thinner and longer.
|
||||
- Head dim definitely prefers to be 128 instead of 64, i.e. fewer bigger heads
|
||||
- Bunch of other random stuff like that.
|
||||
|
||||
Keeping all of this work on a private branch for now but hope to push shortly.
|
||||
|
||||
---
|
||||
|
||||
## 2026-01-17: Modded-nanogpt Ideas Sweep (Continued)
|
||||
|
||||
Continued testing ideas from modded-nanogpt.
|
||||
|
||||
| Idea | Result | Notes |
|
||||
|------|--------|-------|
|
||||
| Attention gates | No improvement | Per-head learnable gates on attention output. +1GB memory, decreased efficiency. |
|
||||
| Batch size schedule | Abandoned | 8→16→24 with LR scaling. Made training script too bloated/complex, not worth cognitive overhead. |
|
||||
| Value embeddings | Helps a lot | Experiments still ongoing, more on this later. |
|
||||
|
||||
---
|
||||
|
||||
## 2026-01-16: Flash Attention 3 Fallback to SDPA
|
||||
|
||||
Added automatic fallback from Flash Attention 3 to PyTorch's `scaled_dot_product_attention` (SDPA) for users without Hopper GPUs. This enables nanochat to run on older CUDA GPUs, CPU, and MPS (Apple Silicon).
|
||||
|
||||
### Implementation
|
||||
|
||||
Created `nanochat/flash_attention.py` - a unified interface that:
|
||||
- Detects FA3 availability at import time (requires sm90+ / Hopper)
|
||||
- Exports a `flash_attn` object matching FA3's API exactly (`flash_attn.flash_attn_func`, `flash_attn.flash_attn_with_kvcache`)
|
||||
- Automatically routes to FA3 or SDPA based on hardware
|
||||
- Handles tensor layout differences: FA3 uses (B, T, H, D), SDPA uses (B, H, T, D)
|
||||
- Implements sliding window attention via explicit masks for SDPA
|
||||
- Manages KV cache manually for SDPA (FA3 does it in-place)
|
||||
|
||||
### Changes to Existing Files
|
||||
|
||||
Changes to existing code were intentionally kept extremely minimal.
|
||||
|
||||
**gpt.py**: Only the import line changed and a comment
|
||||
|
||||
**engine.py**: Zero changes needed
|
||||
|
||||
**base_train.py**: Added status print and warnings:
|
||||
- Prints whether FA3 or SDPA fallback is being used
|
||||
- Warns about efficiency loss without FA3
|
||||
- Warns about sliding window support if `--window-pattern` is not "L"
|
||||
|
||||
### Testing
|
||||
|
||||
Tests are split into two classes due to dtype/device constraints:
|
||||
|
||||
1. **TestFA3VsSDPA**: Comparison tests requiring Hopper GPU + bfloat16. Run both implementations on identical inputs and verify outputs match (max diff typically 0, at most ~0.004 for sliding window).
|
||||
|
||||
2. **TestSDPAOnly**: SDPA-only tests that run on any device with appropriate dtype. Verify forward pass, backward pass, and KV cache work correctly.
|
||||
|
||||
Added `_override_impl` mechanism for testing - can force 'fa3' or 'sdpa' to directly compare implementations.
|
||||
|
||||
### Notes
|
||||
|
||||
- SDPA fallback is significantly slower than FA3 especially in that it lacks the sliding window attention support
|
||||
- Recommend `--window-pattern L` (full context) when using SDPA fallback
|
||||
|
||||
---
|
||||
|
||||
## 2026-01-16: Modded-nanogpt Ideas Sweep (Mostly Negative)
|
||||
|
||||
Tested several architectural ideas from modded-nanogpt to see if they transfer to nanochat. All of these did not help:
|
||||
|
||||
| Idea | Result | Notes |
|
||||
|------|--------|-------|
|
||||
| Half-truncated RoPE | No improvement | Only first half of head dims get RoPE (base 1024, linspace). Second half "stationary". |
|
||||
| Asymmetric softcap | Slightly worse | `23 * sigmoid((x+5)/7.5)` vs our symmetric `15 * tanh(x/15)`. May only help with FP8. |
|
||||
| Smear gate | Negligible | Blend each token with predecessor via learned gate. Tiny improvement not worth n_embd² params. |
|
||||
| Backout | No improvement | Save activations at ~60% through network, subtract scaled version at end. |
|
||||
| Skip connection | Slightly worse | Save at layer ~25%, add at layer ~50%. Also +2GB memory from storing activations. |
|
||||
|
||||
Value Embeddings do show promise. I need a more elaborate exploration of a few related ideas, which I leave for tomorrow.
|
||||
|
||||
---
|
||||
|
||||
## 2026-01-15: Olmo pretraining mix (Negative result)
|
||||
|
||||
I attempted to train on the Olmo 3 pretraining dataset [allenai/dolma3_mix-6T](https://huggingface.co/datasets/allenai/dolma3_mix-6T) instead of FineWeb-edu. I ran into a number of [errors and issues](https://huggingface.co/datasets/allenai/dolma3_mix-6T/discussions/2) trying to both download and process the dataset and then noticed some quality issues (e.g. some documents seem to be extremely short, like "5".). I managed to work around these with some sensible hacks (e.g. reject documents less than 100 characters in length) and tried to process the dataset exactly as FineWeb, re-trained the tokenizer and trained a d16 model. The CORE score decreased from 15.5 to 13.8, i.e. the result is quite a bit worse.
|
||||
|
||||
I am still looking to try the [DCLM dataset](https://arxiv.org/abs/2406.11794), which according to the paper should be better that FineWeb-edu. I do have some concerns that the same group both prepared the DCLM dataset *and* introduced the CORE score so I'm a bit hesitant in case there was some overfitting to CORE score adjacent data distribution.
|
||||
|
||||
Classifying as negative result and reverting back to FineWeb-edu for now.
|
||||
|
||||
---
|
||||
|
||||
## 2026-01-13: Varlen Attention (Negative Result)
|
||||
|
||||
Attempted to prevent attention from "leaking" across document boundaries using Flash Attention's `flash_attn_varlen_func`, similar to modded-nanogpt's approach.
|
||||
|
||||
### Background
|
||||
|
||||
With the BOS-aligned dataloader, multiple documents are packed into each row. Standard attention allows tokens to attend across document boundaries within a row. The hypothesis was that preventing this "leakage" via varlen attention might improve training.
|
||||
|
||||
### Approach: Compute cu_seqlens from inputs
|
||||
|
||||
- Find BOS positions: `(inputs.view(-1) == bos_token_id).nonzero()`
|
||||
- Gotcha 1: Variable-length `cu_seqlens` caused torch.compile recompilation (25s/iter!) - fixed by padding to fixed size
|
||||
- Gotcha 2: `nonzero()` inside compiled model hit recompile limit - fixed by moving computation outside compiled region
|
||||
|
||||
### Final Results (d16)
|
||||
|
||||
| Metric | Baseline | Varlen |
|
||||
|--------|----------|--------|
|
||||
| val_bpb | 0.85427 | 0.85407 |
|
||||
| MFU | ~same | ~same |
|
||||
| tok/sec | ~same | ~same |
|
||||
|
||||
Essentially identical. The 0.0002 bpb improvement is almost noise.
|
||||
|
||||
### Conclusion
|
||||
|
||||
Not worth the code complexity. The "leakage" across document boundaries within a row is not harmful - the model handles it fine. The BOS-aligned dataloader already provides the key benefit (every row starts with proper context). Not merging to master.
|
||||
|
||||
---
|
||||
|
||||
## 2026-01-13: BOS-Aligned Dataloader with Bin Packing
|
||||
|
||||
Redesigned the pretraining and midtraining dataloader to ensure every sequence starts with a BOS token, and explored bin-packing algorithms to minimize wasted tokens.
|
||||
|
||||
### Problem Statement
|
||||
|
||||
The original dataloader streams tokens into a flat buffer and reshapes into batches. This means some rows start mid-document (no BOS), which could confuse the model during training. We want every row to start with BOS and contain well-formed documents.
|
||||
|
||||
### Approach 1: Greedy-Crop BOS (Simple)
|
||||
|
||||
Each row is built independently:
|
||||
- Start with a document (which has BOS prepended)
|
||||
- Pack more documents until row is full
|
||||
- If a document doesn't fit, **crop it** to fill remaining space (discard the rest)
|
||||
- 100% utilization (no padding), but wastes cropped tokens
|
||||
|
||||
### Waste Analysis
|
||||
|
||||
Measured token waste empirically on real data (T=2048):
|
||||
- **39.4% of tokens are cropped** (discarded when docs don't fit)
|
||||
- **22.9% is the theoretical minimum** (tokens in docs longer than T+1 that can never fit)
|
||||
- The extra ~16.5% comes from "unlucky" cropping when a long doc starts near the end of a row
|
||||
|
||||
### Bin Packing Algorithms Explored
|
||||
|
||||
| Algorithm | Util% | Crop% | Pad% | Notes |
|
||||
|-----------|-------|-------|------|-------|
|
||||
| Greedy-Crop (baseline) | 100% | 39.4% | 0% | Simple, no wasted compute |
|
||||
| Greedy-Pad | 78% | 23.0% | 22% | Pads instead of crops - wastes compute |
|
||||
| First-Fit Decreasing (FFD) | 99.7% | 23.0% | 0.3% | Near-optimal packing, minimal padding |
|
||||
| **BestFit-Crop** | 100% | 34.6% | 0% | Smart cropping, no padding |
|
||||
|
||||
### BestFit-Crop Algorithm
|
||||
|
||||
A middle ground that maintains 100% utilization while reducing cropping:
|
||||
|
||||
1. Buffer N documents
|
||||
2. For each row, greedily pick the **largest doc that fits entirely**
|
||||
3. Repeat until nothing fits
|
||||
4. When nothing fits, crop a doc to fill remaining space exactly
|
||||
|
||||
This avoids "unlucky" crops by searching the buffer for better-fitting documents.
|
||||
|
||||
**Results (T=2048):**
|
||||
- Crop waste reduced from 39.4% → 34.6% (~12% relative improvement)
|
||||
- Still achieves 100% utilization (no padding, every token trains)
|
||||
- Slightly more rows than baseline (uses more documents per batch)
|
||||
|
||||
### Decision: Keep Two Implementations
|
||||
|
||||
1. Keep the original implementation which is very simple, efficient and has 100% token utilization in the batch (no padding with ignore tokens), but creates slightly more confusing token streams for the LLM because documents during training can start abruptly from the middle with no context. Note that this never happens at test time, where BOS is always present.
|
||||
|
||||
2. **`_bos_bestfit` (BestFit-Crop, new default)**: Slightly more complex but still keeps 100% token utilization in the batch (no padding), but at the cost of discarding documents when they don't fit. In practice, about 34% of tokens are discarded with this approach. This is ok because for most models we care about we have plenty of data without having to go to multiple epochs. One more subtle effect is that it does skew the data distribution a tiny bit because, reliably and necessarily, tokens at the tails of long documents will be discarded. However, this doesn't seem to impact actual downstream performance.
|
||||
|
||||
### Midtraining
|
||||
|
||||
The midtraining dataloader was also updated. Because conversations are on average a lot shorter than pretraining documents, only about 3.3% of tokens get cropped.
|
||||
|
||||
### NOTE: loss scale
|
||||
|
||||
Do note that switching to the BOS dataloader changes the validation loss and makes all previous experiments not comparable in absolute value of the loss, because we have a lot fewer "confusing" tokens in the train/val batches. All tokens can look back and find the BOS token and have the full context of that document to make predictions. Therefore, the loss appears lower but this is "fake" to some extent, and the expectation is that the vast majority of relative comparisons done so far would agree with those before and after this change.
|
||||
|
||||
---
|
||||
|
||||
## 2026-01-13: Number Token Split Pattern
|
||||
|
||||
Validated the `\p{N}{1,2}` pattern in `SPLIT_PATTERN` (tokenizer.py line 30), which I only guessed earlier and had a TODO for to validate. GPT-4 uses `\p{N}{1,3}` to group number sequences of up to 3 digits into tokens, but we suspected smaller vocab sizes benefit from grouping fewer digits per token.
|
||||
|
||||
**Results (d12, vocab=32K):**
|
||||
| Pattern | val_bpb |
|
||||
|---------|---------|
|
||||
| `\p{N}{1,1}` | 0.969 |
|
||||
| `\p{N}{1,2}` | **0.965** |
|
||||
| `\p{N}{1,3}` | 0.972 |
|
||||
|
||||
**Conclusion:** `{1,2}` is optimal for vocab size 32K. Grouping 3 digits wastes tokens on rare 3-digit combinations; grouping 1 digit is too fine-grained and bloats token sequences. Keeping `{1,2}` as default.
|
||||
|
||||
---
|
||||
|
||||
## 2026-01-13: FP8 Training for lm_head
|
||||
|
||||
Attempted to use FP8 (8-bit floating point) for the lm_head layer to speed up the large vocab projection matmul. H100 GPUs have FP8 tensor cores that can theoretically provide ~2x speedup over BF16.
|
||||
|
||||
### Implementation Approaches Tried
|
||||
|
||||
**1. Dynamic Scaling (failed)**
|
||||
- Compute `x.abs().max()` and `w.abs().max()` each forward to determine scales
|
||||
- Problem: `.item()` calls cause graph breaks with torch.compile
|
||||
- Tried `@torch._dynamo.allow_in_graph` pattern (like torchao.float8) - worked but no speedup
|
||||
- Tried `torch.library.custom_op` with float scales - caused NaN gradients after first optimizer step
|
||||
- Root cause: interaction between custom ops, dynamic scale computation, and torch.compile is fragile
|
||||
|
||||
**2. Static Scaling (partial success)**
|
||||
- Pre-set scales at init time like modded-nanogpt: `x_scale=10/448, w_scale=0.1/448`
|
||||
- `grad_scale` computed dynamically from batch size (safe since it's just `1/(B*T)/57344` due to the gradient expression of cross entropy). modded-nanogpt has a bug here probably because they set `grad_scale = 0.75/448`, but grads are in E5M2 so this should probably be `1/57344`, 1 being the amax of any individual element of cross entropy loss, and no normalization by B,T because they use sum reduction not mean reduction.
|
||||
- Uses `torch.library.custom_op` with `@torch.compile` on inner kernels
|
||||
- This works correctly - no NaNs, proper gradients
|
||||
|
||||
### Results (d12)
|
||||
|
||||
| Metric | BF16 Baseline | FP8 lm_head |
|
||||
|--------|---------------|-------------|
|
||||
| GPU Memory | 34 GB | 36 GB |
|
||||
| tok/sec | baseline | ~1% faster |
|
||||
|
||||
### The Memory Mystery
|
||||
|
||||
FP8 *should* save memory since we store `x_f8` (1 byte) instead of `x` (2 bytes) for backward. But we see 2GB *increase*. Suspected causes:
|
||||
- `torch.compile` on inner kernels creating extra buffers/specializations
|
||||
- `torch._scaled_mm` internal workspace allocations
|
||||
- Custom op registration machinery overhead
|
||||
|
||||
Tried saving original weight `w` (just a reference to parameter) instead of `w_f8` in backward, then re-quantizing on the spot during backward - didn't help. Still saw bump.
|
||||
|
||||
### Microbenchmark vs Reality
|
||||
|
||||
Raw microbenchmark showed promise:
|
||||
- BF16 matmul: 16.95 ms
|
||||
- FP8 matmul (static scales): 10.31 ms (1.64x faster)
|
||||
- FP8 with dynamic scaling: 12.25 ms (1.38x faster)
|
||||
|
||||
But in full training, the ~1% tok/sec improvement doesn't justify the 2GB memory increase and the added code complexity and the need to tune scale factors for both x and w.
|
||||
|
||||
### Code Artifacts
|
||||
|
||||
See the branch `fp8_attempt_fail` for:
|
||||
|
||||
- `nanochat/fp8_static.py` - Static scaling implementation (working)
|
||||
- `nanochat/fp8_dynamic.py` - Dynamic scaling implementation (torchao-style, working but slow)
|
||||
- `gpt.py` imports `fp8_static.LinearFP8` and simply swaps it for `lm_head` in `gpt.py`.
|
||||
|
||||
### Open Questions
|
||||
|
||||
- Why does the custom op approach use more memory than vanilla BF16?
|
||||
- Why is the bump in tok_per_sec so low? We should see ~1.6X speedup in both the forward pass and also (twice) in backward pass for the gradients. Granted, Ahmdal's law is part of the solution because our vocab_size is only 32K so the final layer isn't a huge part of the profile but the expected speedup is still not fully realized.
|
||||
|
||||
**Conclusion:** Negative result for now. The implementation works correctly but provides marginal speedup with *increased* memory usage. I'm not understanding the torch.compile interaction here. The complexity of FP8 custom ops isn't justified for lm_head alone. TODO to study in more detail the way this is implemented in other libraries, e.g. torchao.
|
||||
|
||||
---
|
||||
|
||||
## 2026-01-12: Multi-Token Prediction (MTP)
|
||||
|
||||
Ported multi-token prediction from modded-nanogpt. Instead of predicting just the next token, predict the next n tokens at each position with weighted loss.
|
||||
|
||||
### Implementation
|
||||
|
||||
- Instead of calling the loss `n_predict` times, uses a fancy batched computation using `unfold` + `gather` + cross-entropy decomposition (`CE = logsumexp - logits[target]`)
|
||||
- Schedule anneals from 3-token to 1-token prediction:
|
||||
- 0-33%: `[1.0, 0.5, 0.25→0]` (3rd token fades)
|
||||
- 33-67%: `[1.0, 0.5→0]` (2nd token fades)
|
||||
- 67-100%: `[1.0]` (standard next-token)
|
||||
- Weights normalized to sum to 1
|
||||
|
||||
### Results (d12)
|
||||
|
||||
| Metric | Baseline | MTP |
|
||||
|--------|----------|-----|
|
||||
| GPU Memory | 34 GB | 47 GB |
|
||||
| MFU | 41% | 40% |
|
||||
| val/bpb (per step) | baseline | same/slightly worse |
|
||||
| val/bpb (wall clock) | baseline | noticeably worse |
|
||||
|
||||
**Conclusion:** Negative result for nanochat. The extra memory and compute overhead from predicting multiple tokens doesn't pay off, in fact the results get worse. The auxiliary loss signal may help in other settings (larger models, different architectures?), but for our setup it's pure overhead at the moment.
|
||||
|
||||
---
|
||||
|
||||
## 2026-01-11: Sliding Window Attention
|
||||
|
||||
Added configurable sliding window attention, inspired by GPT-3's alternating short/long pattern.
|
||||
|
||||
**Pattern string configuration:**
|
||||
- New `--window_pattern` CLI arg and `GPTConfig.window_pattern` field
|
||||
- Pattern is tiled across layers (e.g., `SSSL` for 20 layers → `SSSLSSSLSSSLSSSLSSSL`)
|
||||
- Final layer always forced to L (full context) regardless of pattern
|
||||
- Short window = `sequence_len // 2`
|
||||
- Long window = `sequence_len` (full context)
|
||||
- All previous models so far have been simply `L` and checkpoint loading is modified accordingly to fill in this param for old models, see `_patch_missing_config_keys`
|
||||
|
||||
Quick experiments showed `SSSL` (every 4th layer is long) works well - provides a good balance between compute savings and model quality. This is now the default.
|
||||
|
||||
---
|
||||
|
||||
## 2026-01-11: Flash Attention 3 Integration
|
||||
|
||||
Replaced PyTorch's `scaled_dot_product_attention` (FA2) with Flash Attention 3 for training and inference.
|
||||
|
||||
### Changes Made
|
||||
|
||||
**1. FA3 via `kernels` package**
|
||||
- Official FA3 is "beta" and requires building from source (painful)
|
||||
- Using `kernels` package from HuggingFace Hub: `get_kernel('varunneal/flash-attention-3')`
|
||||
- Loads pre-built wheels, works out of the box on H100
|
||||
|
||||
**2. Simplified attention code**
|
||||
- FA3 uses `(B, T, H, D)` layout matching our projection output directly - no transpose needed
|
||||
- Training: `flash_attn.flash_attn_func(q, k, v, causal=True)`
|
||||
- Inference: `flash_attn.flash_attn_with_kvcache()` handles all cache cases in one call
|
||||
- Removed 3 separate FA2 code paths (training, single-token, chunk inference)
|
||||
- GQA handled automatically when n_kv_heads < n_heads
|
||||
|
||||
**3. Rewrote KVCache for FA3**
|
||||
- Old format: `(num_layers, 2, B, H, T, D)` combined tensor
|
||||
- New format: separate `k_cache` and `v_cache` of shape `(num_layers, B, T, H, D)`
|
||||
- FA3 updates cache in-place during `flash_attn_with_kvcache`
|
||||
- Position tracked via `cache_seqlens` tensor (int32, per batch element)
|
||||
- Simpler API: `get_layer_cache()`, `advance()`, `reset()`, `prefill()`
|
||||
|
||||
### Results
|
||||
|
||||
- **~9% improvement in tok/sec** during training out of the box
|
||||
- Benchmarks showed FA3 is 2x faster than FA2 at realistic training sizes (batch=32, seq=2048)
|
||||
- FA3 supports sliding window via `window_size=(left, 0)`, which is huge and expected to give further improvements. This is ready to tune but keeping full context for now.
|
||||
|
||||
---
|
||||
|
||||
## 2026-01-11: Per-Layer Residual Scalars (x0 & resid lambdas)
|
||||
|
||||
Cherry-picked an idea from modded-nanogpt around learnable per-layer residual connections.
|
||||
|
||||
### Changes Made
|
||||
|
||||
**1. x0_lambdas (x0 residual connections)**
|
||||
- Save initial normalized embedding as `x0` after `norm(wte(idx))`
|
||||
- At each layer, blend x0 back in: `x = resid_lambdas[i] * x + x0_lambdas[i] * x0`
|
||||
- Zero-initialized, so disabled at start; model learns which layers benefit from the shortcut
|
||||
- Provides direct path from embedding to deep layers, helps preserve token information
|
||||
|
||||
**2. resid_lambdas (residual stream scaling)**
|
||||
- Per-layer multiplicative scaling of the residual stream
|
||||
- Initialized to 1.0 (neutral, standard transformer behavior)
|
||||
- Allows model to learn to amplify/dampen residual at each layer
|
||||
|
||||
**3. DistAdamW small parameter handling**
|
||||
- Added support for parameters with < 1024 elements (like the scalar lambdas)
|
||||
- Small params use `all_reduce` instead of `reduce_scatter`/`all_gather`
|
||||
- Fixes crash when param shape isn't divisible by world_size
|
||||
|
||||
### Key Finding: Different LR Sensitivity
|
||||
|
||||
The two scalar types need very different learning rates:
|
||||
- **x0_lambdas (additive)**: Can use normal LR (~0.5). Adding a fraction of x0 is forgiving.
|
||||
- **resid_lambdas (multiplicative)**: Needs ~100x smaller LR (~0.005). Multiplying the residual compounds through layers.
|
||||
|
||||
Implementation: `resid_params` gets `scalar_lr * 0.01`, `x0_params` gets full `scalar_lr`.
|
||||
|
||||
### Experiment Results
|
||||
|
||||
Swept `--scalar_lr` (controlling x0_lambdas) at multiple depths:
|
||||
|
||||
| Depth | Baseline (disabled) | Best scalar_lr | Best val_bpb | Δ bpb |
|
||||
|-------|---------------------|----------------|--------------|-------|
|
||||
| d8 | 1.0885 | 0.20 | 1.0782 | -0.0103 |
|
||||
| d12 | 0.9770 | 0.60 | 0.9693 | -0.0077 |
|
||||
| d16 | 0.9059 | 0.20 | 0.9002 | -0.0057 |
|
||||
| d20 | 0.8565 | 0.10 | 0.8526 | -0.0039 |
|
||||
|
||||
**Observations:**
|
||||
- Consistent improvement across all model sizes
|
||||
- Optimal LR varies by depth; default of 0.5 is reasonable, but 0.6 is better for d12
|
||||
- Adding resid_lambdas (with 0.01x LR) gives small additional improvement over x0 alone
|
||||
|
||||
### Meta Device Footgun
|
||||
|
||||
Important lesson: `__init__` runs in meta device context, so any tensor values set there are fake. Must initialize actual values in `init_weights()`. Added docstring warning to `__init__`.
|
||||
|
||||
### Summary
|
||||
|
||||
Added `--scalar_lr` (default 0.5) controlling learnable per-layer scalars. The formula `x = resid_lambdas[i] * x + x0_lambdas[i] * x0` gives the model control over residual scaling and direct shortcuts to the initial embedding. Solid improvement with essentially no compute overhead.
|
||||
|
||||
---
|
||||
|
||||
## 2026-01-10: Muon Optimizer Upgrades & Cautious Weight Decay
|
||||
|
||||
Cherry-picked improvements from NorMuon (modded-nanogpt) into our simpler Muon implementation. Decided against using NorMuon directly due to hard-coded architecture assumptions (expects 32 params split 10 attn + 22 mlp), parameter labeling requirements, and complexity.
|
||||
|
||||
### Changes Made
|
||||
|
||||
**1. Polar Express Orthogonalization**
|
||||
- Replaced Newton-Schulz iteration with "Polar Express Sign Method" from [arxiv.org/pdf/2505.16932](https://arxiv.org/pdf/2505.16932)
|
||||
- Uses 5 different coefficient tuples (one per iteration) instead of fixed coefficients
|
||||
- Both methods kept in code for easy comparison (`zeropower_via_polar_express` vs `zeropower_via_newtonschulz5`)
|
||||
- **Result:** No dramatic/noticeable difference in training, but keeping the new Polar Express as default.
|
||||
|
||||
**2. Variance Reduction (NorMuon-style)**
|
||||
- Added low-rank variance estimator similar to Adafactor ([arxiv.org/pdf/2510.05491](https://arxiv.org/pdf/2510.05491))
|
||||
- Maintains `second_momentum_buffer` with shape `[rows, 1]` or `[1, cols]` (whichever is smaller)
|
||||
- Normalizes updates based on running per-row/col variance estimate (beta2=0.95)
|
||||
- Memory overhead: ~1/max(rows, cols) per param, negligible
|
||||
- **Result:** Led to a very small improvement, kept and enabled by default.
|
||||
|
||||
**3. Cautious Weight Decay**
|
||||
- Only decays weights where `update * weight >= 0` (same sign) from [arxiv.org/abs/2411.16085](https://arxiv.org/abs/2411.16085)
|
||||
- Standard WD always pulls toward zero; cautious WD skips decay when gradient is pushing weight away from zero
|
||||
- **Implementation note:** Had to inline the logic rather than use a separate `@torch.compile` function. Passing changing float values (like `weight_decay` during scheduling) as function arguments triggers recompilation. Reading from `group["weight_decay"]` inside the step avoids this.
|
||||
- **Result:** Solid improvements, especially the cautious version was better than standard wd.
|
||||
- Now defaults to ON for Muon via the `weight_decay` param. AdamW still has no weight decay and is hardcoded to 0 weight decay, might try to re-tune this later.
|
||||
|
||||
**4. Weight decay schedule**
|
||||
- Added a linear schedule to weight decay that is default on from 1.0 to 0.0 (i.e. start with max weight decay in the beginning of training, them ramp to 0 by the end). Worked better than a static setting in experiments. (modded-nanogpt has the same schedule but it is imlpemented in a more confusing way by multiplying twice by the learning rate, which is already wired up to a decay schedule).
|
||||
|
||||
### Weight Decay Scaling Experiments
|
||||
|
||||
Swept weight decay values at d8, d12, d16, d20 to find optimal values and scaling law.
|
||||
|
||||
**Optimal Values Found:**
|
||||
| Depth | Width (channels) | Optimal WD |
|
||||
|-------|------------------|------------|
|
||||
| d8 | 512 | ~0.40 |
|
||||
| d12 | 768 | ~0.22 |
|
||||
| d16 | 1024 | ~0.10 |
|
||||
| d20 | 1280 | ~0.08 |
|
||||
|
||||
**Scaling Law:**
|
||||
- Fit power law: `WD = k / channels^α` in log-log space
|
||||
- Found α ≈ 1.97 (approximately 2), meaning WD ∝ 1/width²
|
||||
|
||||
**Practical Formula:**
|
||||
```
|
||||
WD_target = WD_reference × (d_reference / d_target)²
|
||||
```
|
||||
Example: If d12 optimal is 0.22, then d20 optimal ≈ 0.22 × (12/20)² ≈ 0.08
|
||||
|
||||
**Reference:** Moonlight paper uses fixed WD=0.1 for their 15B MoE model. Our experiments indicated a scaling law where the optimal WD changed with depth, so we go along with the empirical scaling law.
|
||||
|
||||
### Summary
|
||||
|
||||
Muon was changed to use Polar Express, added Adafactor-style variance reduction, and cautious weight decay with schedule that ramps linearly to zero. All of these changes follow modded-nanogpt repo, but all of them were also validated piece by piece to yield improvements in nanochat with the exception of the Polar Express change which was in the noise. This is default on and configurable with `--weight_decay`, using simply 0.2 and ∝ 1/width² scaling. The kwarg `--weight_decay` is therefore changing as of this change. It used to configure AdamW via standard weight decay and now it becomes exclusively used in Muon (AdamW is hardcoded to 0.0), and it is scaled based on depth.
|
||||
|
||||
---
|
||||
|
||||
## 2026-01-08: exp_grad_clip - Gradient Clipping
|
||||
|
||||
**Hypothesis:** Gradient clipping may be unnecessary overhead. Tested L2 norm clipping at various thresholds (0.25, 0.5, 1.0, 2.0) and elementwise clipping.
|
||||
|
|
@ -18,6 +481,4 @@ A running summary documenting some experiments and findings. Started ~Jan 7 2026
|
|||
|
||||
**Observartion:** modded-nanogpt does not appear to clip either right now.
|
||||
|
||||
**Recommendation:** Disable by default (`--grad_clip=0.0`). The code naturally produces well-behaved gradients.
|
||||
|
||||
---
|
||||
**Summary:** Deleted all grad-clip code paths. The code naturally produces well-behaved gradients. This improves a bit of MFU because we don't have to calculate and sync grad norms.
|
||||
|
|
|
|||
|
|
@ -1,77 +0,0 @@
|
|||
#!/bin/bash
|
||||
|
||||
# Showing an example run for exercising some of the code paths on the CPU (or MPS on Macbooks)
|
||||
# Run as:
|
||||
# bash dev/cpu_demo_run.sh
|
||||
|
||||
# NOTE: Training LLMs requires GPU compute and $$$. You will not get far on your Macbook.
|
||||
# Think of this run as educational/fun demo, not something you should expect to work well.
|
||||
# This is also why I hide this script away in dev/
|
||||
|
||||
# all the setup stuff
|
||||
export OMP_NUM_THREADS=1
|
||||
export NANOCHAT_BASE_DIR="$HOME/.cache/nanochat"
|
||||
mkdir -p $NANOCHAT_BASE_DIR
|
||||
command -v uv &> /dev/null || curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
[ -d ".venv" ] || uv venv
|
||||
uv sync --extra cpu
|
||||
source .venv/bin/activate
|
||||
if [ -z "$WANDB_RUN" ]; then
|
||||
WANDB_RUN=dummy
|
||||
fi
|
||||
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
|
||||
source "$HOME/.cargo/env"
|
||||
uv run maturin develop --release --manifest-path rustbpe/Cargo.toml
|
||||
|
||||
# wipe the report
|
||||
python -m nanochat.report reset
|
||||
|
||||
# train tokenizer on ~1B characters
|
||||
python -m nanochat.dataset -n 4
|
||||
python -m scripts.tok_train --max_chars=1000000000
|
||||
python -m scripts.tok_eval
|
||||
|
||||
# train a very small 4 layer model on the CPU
|
||||
# each optimization step processes a single sequence of 1024 tokens
|
||||
# we only run 50 steps of optimization (bump this to get better results)
|
||||
python -m scripts.base_train \
|
||||
--depth=4 \
|
||||
--max_seq_len=1024 \
|
||||
--device_batch_size=1 \
|
||||
--total_batch_size=1024 \
|
||||
--eval_every=50 \
|
||||
--eval_tokens=4096 \
|
||||
--core_metric_every=50 \
|
||||
--core_metric_max_per_task=12 \
|
||||
--sample_every=50 \
|
||||
--num_iterations=50
|
||||
python -m scripts.base_loss --device_batch_size=1 --split_tokens=4096
|
||||
python -m scripts.base_eval --max-per-task=16
|
||||
|
||||
# midtraining
|
||||
python -m scripts.mid_train \
|
||||
--max_seq_len=1024 \
|
||||
--device_batch_size=1 \
|
||||
--eval_every=50 \
|
||||
--eval_tokens=4096 \
|
||||
--total_batch_size=1024 \
|
||||
--num_iterations=100
|
||||
# eval results will be terrible, this is just to execute the code paths.
|
||||
# note that we lower the execution memory limit to 1MB to avoid warnings on smaller systems
|
||||
python -m scripts.chat_eval --source=mid --max-new-tokens=128 --max-problems=20
|
||||
|
||||
# SFT
|
||||
python -m scripts.chat_sft \
|
||||
--device_batch_size=1 \
|
||||
--target_examples_per_step=4 \
|
||||
--num_iterations=100 \
|
||||
--eval_steps=4 \
|
||||
--eval_metrics_max_problems=16
|
||||
|
||||
# Chat CLI
|
||||
# python -m scripts.chat_cli -p "Why is the sky blue?"
|
||||
|
||||
# Chat Web
|
||||
# python -m scripts.chat_web
|
||||
|
||||
python -m nanochat.report generate
|
||||
|
|
@ -1,11 +1,42 @@
|
|||
"""
|
||||
Borrowed from modded-nanogpt. By Keller, @vagrawal, et al.
|
||||
Not a general optimizer! But works for our specific use.
|
||||
Distributed AdamW optimizer with a fused step function.
|
||||
A bunch of ideas (e.g. dist comms in slices) are borrowed from modded-nanogpt.
|
||||
"""
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch import Tensor
|
||||
|
||||
@torch.compile(dynamic=False, fullgraph=True)
|
||||
def adamw_step_fused(
|
||||
p: Tensor,
|
||||
grad: Tensor,
|
||||
exp_avg: Tensor,
|
||||
exp_avg_sq: Tensor,
|
||||
step_t: Tensor,
|
||||
lr_t: Tensor,
|
||||
beta1_t: Tensor,
|
||||
beta2_t: Tensor,
|
||||
eps_t: Tensor,
|
||||
wd_t: Tensor,
|
||||
) -> None:
|
||||
"""
|
||||
Fused AdamW step: weight_decay -> momentum_update -> bias_correction -> param_update
|
||||
All in one compiled graph to eliminate Python overhead between ops.
|
||||
The 0-D CPU tensors avoid recompilation when hyperparameter values change.
|
||||
"""
|
||||
# Weight decay (decoupled, applied before the update)
|
||||
p.mul_(1 - lr_t * wd_t)
|
||||
# Update running averages (lerp_ is cleaner and fuses well)
|
||||
exp_avg.lerp_(grad, 1 - beta1_t)
|
||||
exp_avg_sq.lerp_(grad.square(), 1 - beta2_t)
|
||||
# Bias corrections
|
||||
bias1 = 1 - beta1_t ** step_t
|
||||
bias2 = 1 - beta2_t ** step_t
|
||||
# Compute update and apply
|
||||
denom = (exp_avg_sq / bias2).sqrt() + eps_t
|
||||
step_size = lr_t / bias1
|
||||
p.add_(exp_avg / denom, alpha=-step_size)
|
||||
|
||||
|
||||
class DistAdamW(torch.optim.Optimizer):
|
||||
"""
|
||||
|
|
@ -14,25 +45,51 @@ class DistAdamW(torch.optim.Optimizer):
|
|||
"""
|
||||
def __init__(self, param_groups, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01):
|
||||
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
|
||||
rank = dist.get_rank()
|
||||
world_size = dist.get_world_size()
|
||||
# Validate
|
||||
if rank == 0:
|
||||
for group in param_groups:
|
||||
assert isinstance(group, dict), "expecting param_groups to be a list of dicts"
|
||||
assert isinstance(group['params'], list), "expecting group['params'] to be a list of tensors"
|
||||
for p in group['params']:
|
||||
sliced = p.numel() >= 1024
|
||||
print(f"AdamW: 1 param of shape {p.shape}, sliced={sliced}")
|
||||
if sliced: # large parameter tensors will be operated on in slices
|
||||
assert p.shape[0] % world_size == 0, f"First dim of parameter shape {p.shape} must be divisible by world size {world_size}"
|
||||
super().__init__(param_groups, defaults)
|
||||
# 0-D CPU tensors to avoid torch.compile recompilation when values change
|
||||
self._step_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._beta1_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._eps_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
|
||||
@torch.compile
|
||||
@torch.no_grad()
|
||||
def step(self):
|
||||
rank = dist.get_rank()
|
||||
world_size = dist.get_world_size()
|
||||
reduce_scatter_futures: list[torch.Future] = []
|
||||
all_reduce_futures: list[torch.Future] = []
|
||||
reduce_futures: list[torch.Future] = []
|
||||
gather_futures: list[torch.Future] = []
|
||||
grad_slices = []
|
||||
is_small = [] # track which params are small (use all_reduce) vs large (use reduce_scatter)
|
||||
|
||||
for group in self.param_groups:
|
||||
params: list[Tensor] = group["params"]
|
||||
for base_i in range(len(params)):
|
||||
assert params[base_i].shape[0] % world_size == 0, f"First dim of parameter shape {params[base_i].shape} must be divisible by world size {world_size}"
|
||||
grad = params[base_i].grad
|
||||
rank_size = grad.shape[0] // world_size
|
||||
grad_slice = torch.empty_like(grad[:rank_size])
|
||||
reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future())
|
||||
grad_slices.append(grad_slice)
|
||||
for p in params:
|
||||
grad = p.grad
|
||||
# Small params: use all_reduce (no scatter/gather needed)
|
||||
if p.numel() < 1024:
|
||||
is_small.append(True)
|
||||
reduce_futures.append(dist.all_reduce(grad, op=dist.ReduceOp.AVG, async_op=True).get_future())
|
||||
grad_slices.append(grad)
|
||||
else:
|
||||
is_small.append(False)
|
||||
rank_size = grad.shape[0] // world_size # p.shape[0] % world_size == 0 is checked in __init__
|
||||
grad_slice = torch.empty_like(grad[:rank_size])
|
||||
reduce_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future())
|
||||
grad_slices.append(grad_slice)
|
||||
|
||||
idx = 0
|
||||
for group in self.param_groups:
|
||||
|
|
@ -40,38 +97,47 @@ class DistAdamW(torch.optim.Optimizer):
|
|||
eps = group['eps']
|
||||
wd = group['weight_decay']
|
||||
params = group['params']
|
||||
for base in range(len(params)):
|
||||
reduce_scatter_futures[idx].wait()
|
||||
p = params[base]
|
||||
rank_size = p.shape[0] // world_size
|
||||
p_slice = p[rank * rank_size:(rank + 1) * rank_size]
|
||||
for p in params:
|
||||
reduce_futures[idx].wait()
|
||||
g_slice = grad_slices[idx]
|
||||
lr = group['lr'] * getattr(p, "lr_mul", 1.0)
|
||||
state = self.state[p]
|
||||
g_slice = grad_slices[idx]
|
||||
|
||||
# For small params, operate on full param; for large, operate on slice
|
||||
if is_small[idx]:
|
||||
p_slice = p
|
||||
else:
|
||||
rank_size = p.shape[0] // world_size
|
||||
p_slice = p[rank * rank_size:(rank + 1) * rank_size]
|
||||
|
||||
# State init
|
||||
if not state:
|
||||
state['step'] = torch.tensor(0, dtype=torch.int64, device=p.device)
|
||||
state['step'] = 0
|
||||
state['exp_avg'] = torch.zeros_like(p_slice)
|
||||
state['exp_avg_sq'] = torch.zeros_like(p_slice)
|
||||
exp_avg = state['exp_avg']
|
||||
exp_avg_sq = state['exp_avg_sq']
|
||||
state['step'] += 1
|
||||
t = state['step']
|
||||
# weight decay
|
||||
if wd != 0:
|
||||
eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0)
|
||||
p_slice.mul_(1 - eff_weight_decay)
|
||||
# update running averages
|
||||
exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1)
|
||||
exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2)
|
||||
# bias corrections
|
||||
bias1 = 1 - beta1 ** t
|
||||
bias2 = 1 - beta2 ** t
|
||||
# compute step
|
||||
denom = exp_avg_sq.sqrt().add_(eps)
|
||||
step_size = lr * (torch.sqrt(bias2) / bias1)
|
||||
update = exp_avg.div(denom).mul_(step_size)
|
||||
p_slice.add_(other=update, alpha=-1.0)
|
||||
|
||||
# Fill 0-D tensors with current values
|
||||
eff_wd = wd * getattr(p, "wd_mul", 1.0)
|
||||
self._step_t.fill_(state['step'])
|
||||
self._lr_t.fill_(lr)
|
||||
self._beta1_t.fill_(beta1)
|
||||
self._beta2_t.fill_(beta2)
|
||||
self._eps_t.fill_(eps)
|
||||
self._wd_t.fill_(eff_wd)
|
||||
|
||||
# Fused update: weight_decay -> momentum -> bias_correction -> param_update
|
||||
adamw_step_fused(
|
||||
p_slice, g_slice, exp_avg, exp_avg_sq,
|
||||
self._step_t, self._lr_t, self._beta1_t, self._beta2_t, self._eps_t, self._wd_t,
|
||||
)
|
||||
|
||||
# Only large params need all_gather
|
||||
if not is_small[idx]:
|
||||
gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future())
|
||||
idx += 1
|
||||
all_reduce_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future())
|
||||
torch.futures.collect_all(all_reduce_futures).wait()
|
||||
|
||||
if gather_futures:
|
||||
torch.futures.collect_all(gather_futures).wait()
|
||||
|
|
|
|||
|
|
@ -20,6 +20,25 @@ def log0(message):
|
|||
if int(os.environ.get('RANK', 0)) == 0:
|
||||
logger.info(message)
|
||||
|
||||
def _patch_missing_config_keys(model_config_kwargs):
|
||||
"""Add default values for new config keys missing in old checkpoints."""
|
||||
# Old models were trained with full context (no sliding window)
|
||||
if "window_pattern" not in model_config_kwargs:
|
||||
model_config_kwargs["window_pattern"] = "L"
|
||||
log0(f"Patching missing window_pattern in model config to 'L'")
|
||||
|
||||
def _patch_missing_keys(model_data, model_config):
|
||||
"""Add default values for new parameters that may be missing in old checkpoints."""
|
||||
n_layer = model_config.n_layer
|
||||
# resid_lambdas defaults to 1.0 (identity scaling)
|
||||
if "resid_lambdas" not in model_data:
|
||||
model_data["resid_lambdas"] = torch.ones(n_layer)
|
||||
log0(f"Patching missing resid_lambdas in model data to 1.0")
|
||||
# x0_lambdas defaults to 0.0 (disabled)
|
||||
if "x0_lambdas" not in model_data:
|
||||
model_data["x0_lambdas"] = torch.zeros(n_layer)
|
||||
log0(f"Patching missing x0_lambdas in model data to 0.0")
|
||||
|
||||
def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data, rank=0):
|
||||
if rank == 0:
|
||||
os.makedirs(checkpoint_dir, exist_ok=True)
|
||||
|
|
@ -74,8 +93,10 @@ def build_model(checkpoint_dir, step, device, phase):
|
|||
# Hack: fix torch compile issue, which prepends all keys with _orig_mod.
|
||||
model_data = {k.removeprefix("_orig_mod."): v for k, v in model_data.items()}
|
||||
model_config_kwargs = meta_data["model_config"]
|
||||
_patch_missing_config_keys(model_config_kwargs)
|
||||
log0(f"Building model with config: {model_config_kwargs}")
|
||||
model_config = GPTConfig(**model_config_kwargs)
|
||||
_patch_missing_keys(model_data, model_config)
|
||||
with torch.device("meta"):
|
||||
model = GPT(model_config)
|
||||
# Load the model state
|
||||
|
|
@ -90,7 +111,7 @@ def build_model(checkpoint_dir, step, device, phase):
|
|||
# Load the Tokenizer
|
||||
tokenizer = get_tokenizer()
|
||||
# Sanity check: compatibility between model and tokenizer
|
||||
assert tokenizer.get_vocab_size() == model_config_kwargs["vocab_size"]
|
||||
assert tokenizer.get_vocab_size() == model_config_kwargs["vocab_size"], f"Tokenizer vocab size {tokenizer.get_vocab_size()} does not match model config vocab size {model_config_kwargs['vocab_size']}"
|
||||
return model, tokenizer, meta_data
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -200,3 +200,77 @@ class DummyWandb:
|
|||
pass
|
||||
def finish(self):
|
||||
pass
|
||||
|
||||
# hardcoded BF16 peak flops for various GPUs
|
||||
# inspired by torchtitan: https://github.com/pytorch/torchtitan/blob/main/torchtitan/tools/utils.py
|
||||
# and PR: https://github.com/karpathy/nanochat/pull/147
|
||||
def get_peak_flops(device_name: str) -> float:
|
||||
name = device_name.lower()
|
||||
|
||||
# --- NVIDIA Blackwell ---
|
||||
if "gb200" in name or "grace blackwell" in name:
|
||||
return 2.5e15
|
||||
if "b200" in name:
|
||||
return 2.25e15
|
||||
if "b100" in name:
|
||||
return 1.8e15
|
||||
|
||||
# --- NVIDIA Hopper (H100/H200/H800) ---
|
||||
if "h200" in name:
|
||||
if "nvl" in name or "pcie" in name:
|
||||
return 836e12
|
||||
return 989e12 # H200 SXM
|
||||
if "h100" in name:
|
||||
if "nvl" in name:
|
||||
return 835e12
|
||||
if "pcie" in name:
|
||||
return 756e12
|
||||
return 989e12 # H100 SXM
|
||||
if "h800" in name:
|
||||
if "nvl" in name:
|
||||
return 989e12
|
||||
return 756e12 # H800 PCIe
|
||||
|
||||
# --- NVIDIA Ampere data center ---
|
||||
if "a100" in name or "a800" in name:
|
||||
return 312e12
|
||||
if "a40" in name:
|
||||
return 149.7e12
|
||||
if "a30" in name:
|
||||
return 165e12
|
||||
|
||||
# --- NVIDIA Ada data center ---
|
||||
if "l40s" in name or "l40-s" in name or "l40 s" in name:
|
||||
return 362e12
|
||||
if "l4" in name:
|
||||
return 121e12
|
||||
|
||||
# --- AMD CDNA accelerators ---
|
||||
if "mi355" in name:
|
||||
return 2.5e15
|
||||
if "mi325" in name or "mi300x" in name:
|
||||
return 1.3074e15
|
||||
if "mi300a" in name:
|
||||
return 980.6e12
|
||||
if "mi250x" in name:
|
||||
return 383e12
|
||||
if "mi250" in name:
|
||||
return 362.1e12
|
||||
|
||||
# --- Intel ---
|
||||
if "data center gpu max 1550" in name:
|
||||
# Ponte Vecchio (PVC) - dynamic based on compute units
|
||||
max_comp_units = torch.xpu.get_device_properties("xpu").max_compute_units
|
||||
return 512 * max_comp_units * 1300 * 10**6
|
||||
|
||||
# --- Consumer RTX (for hobbyists) ---
|
||||
if "5090" in name:
|
||||
return 209.5e12
|
||||
if "4090" in name:
|
||||
return 165.2e12
|
||||
if "3090" in name:
|
||||
return 71e12
|
||||
|
||||
# Unknown GPU - return inf so MFU shows as 0% rather than a wrong guess
|
||||
logger.warning(f"Peak flops undefined for: {device_name}, MFU will show as 0%")
|
||||
return float('inf')
|
||||
|
|
|
|||
|
|
@ -1,94 +1,199 @@
|
|||
from collections import deque
|
||||
"""
|
||||
Distributed dataloaders for pretraining.
|
||||
|
||||
Two implementations are provided:
|
||||
|
||||
1. Original (tokenizing_distributed_data_loader):
|
||||
- Streams tokens into a flat buffer, reshapes to (B, T)
|
||||
- Rows may start mid-document (no guaranteed BOS at position 0)
|
||||
- 100% token utilization, simple and efficient
|
||||
|
||||
2. BOS-aligned bestfit (tokenizing_distributed_data_loader_bos_bestfit):
|
||||
- Every row starts with BOS token
|
||||
- Documents packed using best-fit algorithm to minimize cropping
|
||||
- When no document fits remaining space, crops a document to fill exactly
|
||||
- 100% utilization (no padding), ~35% tokens cropped at T=2048
|
||||
|
||||
The tradeoff: BOS-aligned loses ~35% of tokens to cropping, but ensures that
|
||||
there are fewer "confusing" tokens in the train/val batches as every token can
|
||||
now attend back to the BOS token and sees the full context of the document.
|
||||
(2) is the new default if you have enough data.
|
||||
Fallback to (1) if you have very limited data AND long documents.
|
||||
"""
|
||||
|
||||
import torch
|
||||
import pyarrow.parquet as pq
|
||||
|
||||
from nanochat.common import get_dist_info
|
||||
from nanochat.dataset import list_parquet_files
|
||||
from nanochat.tokenizer import get_tokenizer
|
||||
|
||||
def tokenizing_distributed_data_loader_with_state(B, T, split, tokenizer_threads=4, tokenizer_batch_size=128, device="cuda", resume_state_dict=None):
|
||||
def _document_batches(split, resume_state_dict, tokenizer_batch_size):
|
||||
"""
|
||||
Infinite iterator over document batches (list of text strings) from parquet files.
|
||||
|
||||
Handles DDP sharding and approximate resume. Each yield is (text_batch, (pq_idx, rg_idx, epoch))
|
||||
where text_batch is a list of document strings, indices track position for resumption,
|
||||
and epoch counts how many times we've cycled through the dataset (starts at 1).
|
||||
"""
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
|
||||
|
||||
parquet_paths = list_parquet_files()
|
||||
assert len(parquet_paths) != 0, "No dataset parquet files found, did you run dataset.py?"
|
||||
parquet_paths = parquet_paths[:-1] if split == "train" else parquet_paths[-1:]
|
||||
|
||||
resume_pq_idx = resume_state_dict["pq_idx"] if resume_state_dict is not None else 0
|
||||
resume_rg_idx = resume_state_dict["rg_idx"] if resume_state_dict is not None else None
|
||||
resume_epoch = resume_state_dict.get("epoch", 1) if resume_state_dict is not None else 1
|
||||
first_pass = True
|
||||
pq_idx = resume_pq_idx
|
||||
epoch = resume_epoch
|
||||
|
||||
while True: # iterate infinitely (multi-epoch)
|
||||
pq_idx = resume_pq_idx if first_pass else 0
|
||||
while pq_idx < len(parquet_paths):
|
||||
filepath = parquet_paths[pq_idx]
|
||||
pf = pq.ParquetFile(filepath)
|
||||
# Start from resume point if resuming on same file, otherwise from DDP rank
|
||||
if first_pass and (resume_rg_idx is not None) and (pq_idx == resume_pq_idx):
|
||||
base_idx = resume_rg_idx // ddp_world_size
|
||||
base_idx += 1 # advance by 1 so we don't repeat data after resuming
|
||||
rg_idx = base_idx * ddp_world_size + ddp_rank
|
||||
if rg_idx >= pf.num_row_groups:
|
||||
pq_idx += 1
|
||||
continue
|
||||
resume_rg_idx = None # only do this once
|
||||
else:
|
||||
rg_idx = ddp_rank
|
||||
while rg_idx < pf.num_row_groups:
|
||||
rg = pf.read_row_group(rg_idx)
|
||||
batch = rg.column('text').to_pylist()
|
||||
for i in range(0, len(batch), tokenizer_batch_size):
|
||||
yield batch[i:i+tokenizer_batch_size], (pq_idx, rg_idx, epoch)
|
||||
rg_idx += ddp_world_size
|
||||
pq_idx += 1
|
||||
first_pass = False
|
||||
epoch += 1
|
||||
|
||||
|
||||
def tokenizing_distributed_data_loader_with_state(tokenizer, B, T, split, tokenizer_threads=4, tokenizer_batch_size=128, device="cuda", resume_state_dict=None):
|
||||
"""
|
||||
Stream pretraining text from parquet files, tokenize, yield training batches.
|
||||
|
||||
This implementation became a bit more complex because we wish to support approximate resume training.
|
||||
Instead of turning this into a Class, we opt to return the state_dict with every batch,
|
||||
and then the caller can pass in a state_dict to resume training from a desired point.
|
||||
Note that this resumption is atm only *approximate* for simplicity.
|
||||
We won't repeat the same documents but we might skip a few.
|
||||
The state_dict that is returned can be later passed into this function via `resume_state_dict` to approximately resume.
|
||||
This is the original dataloader that streams tokens into a flat buffer and reshapes.
|
||||
Rows may start mid-document (no guaranteed BOS at position 0).
|
||||
|
||||
Perfect state resumption is possible but would be a lot more bloated, probably not worth it atm.
|
||||
Supports approximate resume via state_dict.
|
||||
"""
|
||||
assert split in ["train", "val"], "split must be 'train' or 'val'"
|
||||
|
||||
# infinite iterator over document batches (list of text strings)
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
|
||||
def document_batches():
|
||||
parquet_paths = list_parquet_files()
|
||||
assert len(parquet_paths) != 0, "No dataset parquet files found, did you run dataset.py?"
|
||||
parquet_paths = parquet_paths[:-1] if split == "train" else parquet_paths[-1:]
|
||||
resume_pq_idx = resume_state_dict["pq_idx"] if resume_state_dict is not None else 0
|
||||
resume_rg_idx = resume_state_dict["rg_idx"] if resume_state_dict is not None else None
|
||||
first_pass = True
|
||||
pq_idx = resume_pq_idx # we kick off parquet files at the resume index (or by default just 0)
|
||||
while True: # iterate infinitely (multi-epoch)
|
||||
pq_idx = resume_pq_idx if first_pass else 0
|
||||
while pq_idx < len(parquet_paths): # iterate over all parquet files
|
||||
filepath = parquet_paths[pq_idx]
|
||||
pf = pq.ParquetFile(filepath)
|
||||
# Start from resume point if resuming on same file, otherwise from DDP rank
|
||||
# I know this state resumption is a little bit tricky and a little bit hacky... sigh.
|
||||
if first_pass and (resume_rg_idx is not None) and (pq_idx == resume_pq_idx):
|
||||
base_idx = resume_rg_idx // ddp_world_size # in units of ddp_world_size
|
||||
base_idx += 1 # advance by 1 so that we definitely don't repeat data after resuming
|
||||
rg_idx = base_idx * ddp_world_size + ddp_rank
|
||||
if rg_idx >= pf.num_row_groups:
|
||||
pq_idx += 1
|
||||
continue
|
||||
resume_rg_idx = None # set to None as we only want to do this a single time
|
||||
else:
|
||||
rg_idx = ddp_rank
|
||||
while rg_idx < pf.num_row_groups:
|
||||
rg = pf.read_row_group(rg_idx)
|
||||
batch = rg.column('text').to_pylist() # each batch is a parquet group, e.g. 1024 rows
|
||||
# the tokenizer encode might want to go in even smaller batches, e.g. 128 rows
|
||||
for i in range(0, len(batch), tokenizer_batch_size):
|
||||
yield batch[i:i+tokenizer_batch_size], (pq_idx, rg_idx)
|
||||
rg_idx += ddp_world_size # advance to the next row group (in DDP)
|
||||
pq_idx += 1 # advance to the next parquet file
|
||||
first_pass = False
|
||||
batches = document_batches()
|
||||
|
||||
# Now emit batches of tokens.
|
||||
needed_tokens = B * T + 1 # +1 is because we also need the target at the last token
|
||||
# get the tokenizer and the bos token
|
||||
tokenizer = get_tokenizer()
|
||||
batches = _document_batches(split, resume_state_dict, tokenizer_batch_size)
|
||||
needed_tokens = B * T + 1 # +1 for target at last position
|
||||
bos_token = tokenizer.get_bos_token_id()
|
||||
# scratch buffer holds the tokens for one iteration
|
||||
token_buffer = deque() # we stream tokens on the right and pop from the left
|
||||
token_buffer = []
|
||||
pq_idx, rg_idx, epoch = 0, 0, 1
|
||||
|
||||
while True:
|
||||
# Accumulate enough tokens for one iteration before yielding.
|
||||
|
||||
# Accumulate enough tokens
|
||||
while len(token_buffer) < needed_tokens:
|
||||
doc_batch, (pq_idx, rg_idx) = next(batches)
|
||||
doc_batch, (pq_idx, rg_idx, epoch) = next(batches)
|
||||
token_lists = tokenizer.encode(doc_batch, prepend=bos_token, num_threads=tokenizer_threads)
|
||||
for tokens in token_lists:
|
||||
token_buffer.extend(tokens)
|
||||
# Move tokens from the deque into the scratch buffer
|
||||
tokens = [token_buffer.popleft() for _ in range(needed_tokens)]
|
||||
# CUDA supports memory pinning for asynchronous transfers between CPU and GPU
|
||||
use_cuda_optimizations = device == "cuda"
|
||||
scratch = torch.tensor(tokens, dtype=torch.long, pin_memory=use_cuda_optimizations) # in PyTorch, long=int64
|
||||
# Create the inputs/targets as 1D tensors
|
||||
inputs_cpu = scratch[:-1]
|
||||
targets_cpu = scratch[1:]
|
||||
# Reshape to 2D and move to GPU async
|
||||
inputs = inputs_cpu.view(B, T).to(device=device, non_blocking=use_cuda_optimizations)
|
||||
targets = targets_cpu.view(B, T).to(device=device, non_blocking=use_cuda_optimizations)
|
||||
state_dict = {"pq_idx": pq_idx, "rg_idx": rg_idx} # we need this in case we wish to approximately resume training
|
||||
yield inputs, targets, state_dict
|
||||
tokens = token_buffer[:needed_tokens] # Read B*T+1 tokens (+1 is only for the target for the last token)
|
||||
token_buffer = token_buffer[B*T:] # Advance by B*T tokens, so we move exactly one window of B*T tokens over
|
||||
|
||||
# Package tokens into inputs and targets, yield
|
||||
use_cuda = device == "cuda"
|
||||
scratch = torch.tensor(tokens, dtype=torch.long, pin_memory=use_cuda)
|
||||
inputs = scratch[:-1].view(B, T).to(device=device, non_blocking=use_cuda)
|
||||
targets = scratch[1:].view(B, T).to(device=device, non_blocking=use_cuda)
|
||||
yield inputs, targets, {"pq_idx": pq_idx, "rg_idx": rg_idx, "epoch": epoch}
|
||||
|
||||
|
||||
def tokenizing_distributed_data_loader(*args, **kwargs):
|
||||
# helper function that only emits the inputs/targets and not the state_dict
|
||||
"""Helper that omits state_dict from yields."""
|
||||
for inputs, targets, state_dict in tokenizing_distributed_data_loader_with_state(*args, **kwargs):
|
||||
yield inputs, targets
|
||||
|
||||
|
||||
def tokenizing_distributed_data_loader_with_state_bos_bestfit(
|
||||
tokenizer, B, T, split,
|
||||
tokenizer_threads=4, tokenizer_batch_size=128,
|
||||
device="cuda", resume_state_dict=None,
|
||||
buffer_size=1000
|
||||
):
|
||||
"""
|
||||
BOS-aligned dataloader with Best-Fit Cropping.
|
||||
|
||||
Reduces token waste compared to simple greedy cropping by searching a buffer
|
||||
for documents that fit well, while maintaining 100% utilization (no padding).
|
||||
|
||||
Algorithm for each row:
|
||||
1. From buffered docs, pick the LARGEST doc that fits entirely
|
||||
2. Repeat until no doc fits
|
||||
3. When nothing fits, crop a doc to fill remaining space exactly
|
||||
|
||||
Key properties:
|
||||
- Every row starts with BOS
|
||||
- 100% utilization (no padding, every token is trained on)
|
||||
- Approximately 35% of all tokens are discarded due to cropping
|
||||
"""
|
||||
assert split in ["train", "val"], "split must be 'train' or 'val'"
|
||||
|
||||
row_capacity = T + 1
|
||||
batches = _document_batches(split, resume_state_dict, tokenizer_batch_size)
|
||||
bos_token = tokenizer.get_bos_token_id()
|
||||
doc_buffer = []
|
||||
pq_idx, rg_idx, epoch = 0, 0, 1
|
||||
|
||||
def refill_buffer():
|
||||
nonlocal pq_idx, rg_idx, epoch
|
||||
doc_batch, (pq_idx, rg_idx, epoch) = next(batches)
|
||||
token_lists = tokenizer.encode(doc_batch, prepend=bos_token, num_threads=tokenizer_threads)
|
||||
for tokens in token_lists:
|
||||
doc_buffer.append(tokens)
|
||||
|
||||
while True:
|
||||
rows = []
|
||||
for _ in range(B):
|
||||
row = []
|
||||
while len(row) < row_capacity:
|
||||
# Ensure buffer has documents
|
||||
while len(doc_buffer) < buffer_size:
|
||||
refill_buffer()
|
||||
|
||||
remaining = row_capacity - len(row)
|
||||
|
||||
# Find largest doc that fits entirely
|
||||
best_idx = -1
|
||||
best_len = 0
|
||||
for i, doc in enumerate(doc_buffer):
|
||||
doc_len = len(doc)
|
||||
if doc_len <= remaining and doc_len > best_len:
|
||||
best_idx = i
|
||||
best_len = doc_len
|
||||
|
||||
if best_idx >= 0:
|
||||
doc = doc_buffer.pop(best_idx)
|
||||
row.extend(doc)
|
||||
else:
|
||||
# No doc fits - crop shortest in buffer to fill remaining and minimize waste
|
||||
shortest_idx = min(range(len(doc_buffer)), key=lambda i: len(doc_buffer[i]))
|
||||
doc = doc_buffer.pop(shortest_idx)
|
||||
row.extend(doc[:remaining])
|
||||
|
||||
rows.append(row[:row_capacity])
|
||||
|
||||
use_cuda = device == "cuda"
|
||||
batch_tensor = torch.tensor(rows, dtype=torch.long, pin_memory=use_cuda)
|
||||
inputs = batch_tensor[:, :-1].to(device=device, non_blocking=use_cuda)
|
||||
targets = batch_tensor[:, 1:].to(device=device, non_blocking=use_cuda)
|
||||
|
||||
yield inputs, targets, {"pq_idx": pq_idx, "rg_idx": rg_idx, "epoch": epoch}
|
||||
|
||||
|
||||
def tokenizing_distributed_data_loader_bos_bestfit(*args, **kwargs):
|
||||
"""Helper that omits state_dict from yields."""
|
||||
for inputs, targets, state_dict in tokenizing_distributed_data_loader_with_state_bos_bestfit(*args, **kwargs):
|
||||
yield inputs, targets
|
||||
|
|
|
|||
|
|
@ -82,83 +82,54 @@ def use_calculator(expr):
|
|||
# -----------------------------------------------------------------------------
|
||||
class KVCache:
|
||||
"""
|
||||
Works hand-in-hand with the GPT model to maintain the KV cache.
|
||||
Note that the .pos advances automatically after the last layer of the Transformer inserts.
|
||||
KV Cache designed for Flash Attention 3's flash_attn_with_kvcache API.
|
||||
|
||||
Key differences from FA2-style cache:
|
||||
- Tensors are (B, T, H, D) not (B, H, T, D)
|
||||
- FA3 updates the cache in-place during flash_attn_with_kvcache
|
||||
- Position tracked per batch element via cache_seqlens tensor
|
||||
"""
|
||||
|
||||
def __init__(self, batch_size, num_heads, seq_len, head_dim, num_layers):
|
||||
# Each of K/V is of shape (B, H, T, D) and we have one per layer of the Transformer.
|
||||
self.kv_shape = (num_layers, 2, batch_size, num_heads, seq_len, head_dim)
|
||||
self.kv_cache = None
|
||||
self.pos = 0 # current position in time in the cache
|
||||
def __init__(self, batch_size, num_heads, seq_len, head_dim, num_layers, device, dtype):
|
||||
self.batch_size = batch_size
|
||||
self.max_seq_len = seq_len
|
||||
self.n_layers = num_layers
|
||||
self.n_heads = num_heads
|
||||
self.head_dim = head_dim
|
||||
# Pre-allocate cache tensors: (n_layers, B, T, H, D)
|
||||
self.k_cache = torch.zeros(num_layers, batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype)
|
||||
self.v_cache = torch.zeros(num_layers, batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype)
|
||||
# Current sequence length per batch element (FA3 needs int32)
|
||||
self.cache_seqlens = torch.zeros(batch_size, dtype=torch.int32, device=device)
|
||||
|
||||
def reset(self):
|
||||
self.pos = 0
|
||||
"""Reset cache to empty state."""
|
||||
self.cache_seqlens.zero_()
|
||||
|
||||
def get_pos(self):
|
||||
return self.pos
|
||||
"""Get current position (assumes all batch elements at same position)."""
|
||||
return self.cache_seqlens[0].item()
|
||||
|
||||
def get_layer_cache(self, layer_idx):
|
||||
"""Return (k_cache, v_cache) views for a specific layer."""
|
||||
return self.k_cache[layer_idx], self.v_cache[layer_idx]
|
||||
|
||||
def advance(self, num_tokens):
|
||||
"""Advance the cache position by num_tokens."""
|
||||
self.cache_seqlens += num_tokens
|
||||
|
||||
def prefill(self, other):
|
||||
"""
|
||||
Prefill given another KV cache. Optionally expand along batch dim.
|
||||
This is used when we do batch 1 prefill and then want to generate
|
||||
multiple samples in parallel from there.
|
||||
Copy cached KV from another cache into this one.
|
||||
Used when we do batch=1 prefill and then want to generate multiple samples in parallel.
|
||||
"""
|
||||
# 1) validate the shapes
|
||||
assert self.kv_cache is None, "Cannot prefill a non-empty KV cache"
|
||||
assert other.kv_cache is not None, "Cannot prefill with a None KV cache"
|
||||
|
||||
# Extract dimensions explicitly
|
||||
self_layers, self_kv, self_batch, self_heads, self_seq, self_head_dim = self.kv_shape
|
||||
other_layers, other_kv, other_batch, other_heads, other_seq, other_head_dim = other.kv_shape
|
||||
|
||||
# Validate dimensions
|
||||
assert self_layers == other_layers, f"Layer count mismatch: {self_layers} != {other_layers}"
|
||||
assert self_kv == other_kv, f"K/V dimension mismatch: {self_kv} != {other_kv}"
|
||||
assert self_heads == other_heads, f"Head count mismatch: {self_heads} != {other_heads}"
|
||||
assert self_head_dim == other_head_dim, f"Head dim mismatch: {self_head_dim} != {other_head_dim}"
|
||||
|
||||
# Batch size can be expanded (other can be 1, self can be larger)
|
||||
assert self_batch == other_batch or other_batch == 1, f"Batch size mismatch: {self_batch} vs {other_batch} (other must be 1 or equal)"
|
||||
|
||||
# Sequence length: self must be longer than other
|
||||
assert self_seq >= other_seq, f"Sequence length mismatch: {self_seq} < {other_seq}"
|
||||
|
||||
# 2) initialize the cache
|
||||
dtype, device = other.kv_cache.dtype, other.kv_cache.device
|
||||
self.kv_cache = torch.empty(self.kv_shape, dtype=dtype, device=device)
|
||||
# 3) copy the data over
|
||||
self.kv_cache[:, :, :, :, :other.pos, :] = other.kv_cache
|
||||
# 4) update the pos
|
||||
self.pos = other.pos
|
||||
|
||||
def insert_kv(self, layer_idx, k, v):
|
||||
# Lazy initialize the cache here because we need to know the dtype/device
|
||||
if self.kv_cache is None:
|
||||
self.kv_cache = torch.empty(self.kv_shape, dtype=k.dtype, device=k.device)
|
||||
# Insert new keys/values to the cache and return the full cache so far
|
||||
B, H, T_add, D = k.size()
|
||||
t0, t1 = self.pos, self.pos + T_add
|
||||
# Dynamically grow the cache if needed
|
||||
if t1 > self.kv_cache.size(4):
|
||||
t_needed = t1 + 1024 # as much as we need plus buffer of 1024
|
||||
t_needed = (t_needed + 1023) & ~1023 # then round up to the nearest multiple of 1024
|
||||
additional_shape = list(self.kv_cache.shape)
|
||||
additional_shape[4] = t_needed - self.kv_cache.size(4)
|
||||
additional_cache = torch.empty(additional_shape, dtype=k.dtype, device=k.device)
|
||||
self.kv_cache = torch.cat([self.kv_cache, additional_cache], dim=4).contiguous()
|
||||
self.kv_shape = self.kv_cache.shape
|
||||
# Insert k, v into the cache
|
||||
self.kv_cache[layer_idx, 0, :, :, t0:t1, :] = k
|
||||
self.kv_cache[layer_idx, 1, :, :, t0:t1, :] = v
|
||||
# Return the full cached keys/values up to current position (as a view)
|
||||
key_view = self.kv_cache[layer_idx, 0, :, :, :t1, :]
|
||||
value_view = self.kv_cache[layer_idx, 1, :, :, :t1, :]
|
||||
# Increment pos after the last layer of the Transformer processes
|
||||
if layer_idx == self.kv_cache.size(0) - 1:
|
||||
self.pos = t1
|
||||
return key_view, value_view
|
||||
|
||||
assert self.get_pos() == 0, "Cannot prefill a non-empty KV cache"
|
||||
assert self.n_layers == other.n_layers and self.n_heads == other.n_heads and self.head_dim == other.head_dim
|
||||
assert self.max_seq_len >= other.max_seq_len
|
||||
other_pos = other.get_pos()
|
||||
self.k_cache[:, :, :other_pos, :, :] = other.k_cache[:, :, :other_pos, :, :]
|
||||
self.v_cache[:, :, :other_pos, :, :] = other.v_cache[:, :, :other_pos, :, :]
|
||||
self.cache_seqlens.fill_(other_pos)
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@torch.inference_mode()
|
||||
|
|
@ -201,6 +172,13 @@ class Engine:
|
|||
"""Same as generate, but does single prefill and then clones the KV cache."""
|
||||
assert isinstance(tokens, list) and isinstance(tokens[0], int), "expecting list of ints"
|
||||
device = self.model.get_device()
|
||||
# NOTE: setting the dtype here and in this way is an ugly hack.
|
||||
# Currently the repo assumes that cuda -> bfloat16 and everything else -> float32.
|
||||
# We need to know the dtype here to call __init__ on KVCache and pre-allocate its tensors.
|
||||
# As a quick hack, we're making generate() function inherit and know about this repo-wise assumption.
|
||||
# I think there has to be a bigger refactor to deal with device/dtype tracking across the codebase.
|
||||
# In particular, the KVCache should allocate its tensors lazily
|
||||
dtype = torch.bfloat16 if device.type == "cuda" else torch.float32
|
||||
rng = torch.Generator(device=device)
|
||||
rng.manual_seed(seed)
|
||||
|
||||
|
|
@ -219,6 +197,8 @@ class Engine:
|
|||
kv_cache_prefill = KVCache(
|
||||
batch_size=1,
|
||||
seq_len=len(tokens),
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
**kv_model_kwargs,
|
||||
)
|
||||
ids = torch.tensor([tokens], dtype=torch.long, device=device)
|
||||
|
|
@ -230,6 +210,8 @@ class Engine:
|
|||
kv_cache_decode = KVCache(
|
||||
batch_size=num_samples,
|
||||
seq_len=kv_length_hint,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
**kv_model_kwargs,
|
||||
)
|
||||
kv_cache_decode.prefill(kv_cache_prefill)
|
||||
|
|
@ -324,8 +306,8 @@ if __name__ == "__main__":
|
|||
"""
|
||||
import time
|
||||
# init compute
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
|
||||
device_type = autodetect_device_type()
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
|
||||
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
|
||||
|
||||
# load the model and tokenizer
|
||||
|
|
|
|||
178
nanochat/flash_attention.py
Normal file
178
nanochat/flash_attention.py
Normal file
|
|
@ -0,0 +1,178 @@
|
|||
"""
|
||||
Unified Flash Attention interface with automatic FA3/SDPA switching.
|
||||
|
||||
Exports `flash_attn` module that matches the FA3 API exactly, but falls back
|
||||
to PyTorch SDPA on non-Hopper GPUs, MPS, and CPU.
|
||||
|
||||
Usage (drop-in replacement for FA3):
|
||||
from nanochat.flash_attention import flash_attn
|
||||
|
||||
# Training (no KV cache)
|
||||
y = flash_attn.flash_attn_func(q, k, v, causal=True, window_size=window_size)
|
||||
|
||||
# Inference (with KV cache)
|
||||
y = flash_attn.flash_attn_with_kvcache(q, k_cache, v_cache, k=k, v=v, ...)
|
||||
"""
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Detection: Try to load FA3 on Hopper+ GPUs
|
||||
# =============================================================================
|
||||
def _load_flash_attention_3():
|
||||
"""Try to load Flash Attention 3 (requires Hopper+ GPU)."""
|
||||
if not torch.cuda.is_available():
|
||||
return None
|
||||
try:
|
||||
major, _ = torch.cuda.get_device_capability()
|
||||
if major < 9: # Hopper is sm90
|
||||
return None
|
||||
import os
|
||||
os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
|
||||
from kernels import get_kernel
|
||||
return get_kernel('varunneal/flash-attention-3').flash_attn_interface
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
_fa3 = _load_flash_attention_3()
|
||||
HAS_FA3 = _fa3 is not None
|
||||
|
||||
# Override for testing: set to 'fa3', 'sdpa', or None (auto)
|
||||
_override_impl = None
|
||||
|
||||
|
||||
def _use_fa3():
|
||||
"""Determine whether to use FA3 based on availability and override."""
|
||||
if _override_impl == 'fa3':
|
||||
assert HAS_FA3, "Cannot override to FA3: not available on this hardware"
|
||||
return True
|
||||
if _override_impl == 'sdpa':
|
||||
return False
|
||||
return HAS_FA3 # auto
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# SDPA helpers
|
||||
# =============================================================================
|
||||
def _sdpa_attention(q, k, v, window_size, enable_gqa):
|
||||
"""
|
||||
SDPA attention with sliding window support.
|
||||
q, k, v are (B, H, T, D) format.
|
||||
"""
|
||||
Tq = q.size(2)
|
||||
Tk = k.size(2)
|
||||
window = window_size[0]
|
||||
|
||||
# Full context, same length
|
||||
if (window < 0 or window >= Tq) and Tq == Tk:
|
||||
return F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=enable_gqa)
|
||||
|
||||
# Single token generation
|
||||
if Tq == 1:
|
||||
return F.scaled_dot_product_attention(q, k, v, is_causal=False, enable_gqa=enable_gqa)
|
||||
|
||||
# Need explicit mask
|
||||
device = q.device
|
||||
if Tq == Tk:
|
||||
# Causal + sliding window
|
||||
mask = torch.tril(torch.ones(Tq, Tk, device=device, dtype=torch.bool))
|
||||
if window > 0 and window < Tq:
|
||||
row_idx = torch.arange(Tq, device=device).unsqueeze(1)
|
||||
col_idx = torch.arange(Tk, device=device).unsqueeze(0)
|
||||
mask = mask & ((row_idx - col_idx) <= window)
|
||||
else:
|
||||
# Chunk inference: attend to prefix + causal within chunk
|
||||
prefix_len = Tk - Tq
|
||||
mask = torch.zeros(Tq, Tk, device=device, dtype=torch.bool)
|
||||
mask[:, :prefix_len] = True
|
||||
mask[:, prefix_len:] = torch.tril(torch.ones(Tq, Tq, device=device, dtype=torch.bool))
|
||||
|
||||
return F.scaled_dot_product_attention(q, k, v, attn_mask=mask, enable_gqa=enable_gqa)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Public API: Same interface as FA3
|
||||
# =============================================================================
|
||||
def flash_attn_func(q, k, v, causal=False, window_size=(-1, -1)):
|
||||
"""
|
||||
Flash Attention for training (no KV cache).
|
||||
|
||||
Args:
|
||||
q, k, v: Tensors of shape (B, T, H, D)
|
||||
causal: Whether to use causal masking
|
||||
window_size: (left, right) sliding window. -1 means unlimited.
|
||||
|
||||
Returns:
|
||||
Output tensor of shape (B, T, H, D)
|
||||
"""
|
||||
if _use_fa3():
|
||||
return _fa3.flash_attn_func(q, k, v, causal=causal, window_size=window_size)
|
||||
|
||||
# SDPA fallback: transpose (B, T, H, D) -> (B, H, T, D)
|
||||
q = q.transpose(1, 2)
|
||||
k = k.transpose(1, 2)
|
||||
v = v.transpose(1, 2)
|
||||
enable_gqa = q.size(1) != k.size(1)
|
||||
y = _sdpa_attention(q, k, v, window_size, enable_gqa)
|
||||
return y.transpose(1, 2) # back to (B, T, H, D)
|
||||
|
||||
|
||||
def flash_attn_with_kvcache(q, k_cache, v_cache, k=None, v=None, cache_seqlens=None,
|
||||
causal=False, window_size=(-1, -1)):
|
||||
"""
|
||||
Flash Attention with KV cache for inference.
|
||||
|
||||
FA3 updates k_cache/v_cache in-place. Our SDPA fallback does the same.
|
||||
|
||||
Args:
|
||||
q: Queries, shape (B, T_new, H, D)
|
||||
k_cache, v_cache: Pre-allocated cache tensors, shape (B, T_max, H_kv, D)
|
||||
k, v: New keys/values to insert, shape (B, T_new, H_kv, D)
|
||||
cache_seqlens: Current position in cache, shape (B,) int32
|
||||
causal: Whether to use causal masking
|
||||
window_size: (left, right) sliding window. -1 means unlimited.
|
||||
|
||||
Returns:
|
||||
Output tensor of shape (B, T_new, H, D)
|
||||
"""
|
||||
if _use_fa3():
|
||||
return _fa3.flash_attn_with_kvcache(
|
||||
q, k_cache, v_cache, k=k, v=v, cache_seqlens=cache_seqlens,
|
||||
causal=causal, window_size=window_size
|
||||
)
|
||||
|
||||
# SDPA fallback: manually manage KV cache
|
||||
B, T_new, H, D = q.shape
|
||||
pos = cache_seqlens[0].item() # assume uniform position across batch
|
||||
|
||||
# Insert new k, v into cache (in-place, matching FA3 behavior)
|
||||
if k is not None and v is not None:
|
||||
k_cache[:, pos:pos+T_new, :, :] = k
|
||||
v_cache[:, pos:pos+T_new, :, :] = v
|
||||
|
||||
# Get full cache up to current position + new tokens
|
||||
end_pos = pos + T_new
|
||||
k_full = k_cache[:, :end_pos, :, :]
|
||||
v_full = v_cache[:, :end_pos, :, :]
|
||||
|
||||
# Transpose to SDPA layout: (B, T, H, D) -> (B, H, T, D)
|
||||
q_sdpa = q.transpose(1, 2)
|
||||
k_sdpa = k_full.transpose(1, 2)
|
||||
v_sdpa = v_full.transpose(1, 2)
|
||||
|
||||
enable_gqa = q_sdpa.size(1) != k_sdpa.size(1)
|
||||
y_sdpa = _sdpa_attention(q_sdpa, k_sdpa, v_sdpa, window_size, enable_gqa)
|
||||
|
||||
return y_sdpa.transpose(1, 2) # back to (B, T, H, D)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Export: flash_attn module interface (drop-in replacement for FA3)
|
||||
# =============================================================================
|
||||
from types import SimpleNamespace
|
||||
flash_attn = SimpleNamespace(
|
||||
flash_attn_func=flash_attn_func,
|
||||
flash_attn_with_kvcache=flash_attn_with_kvcache,
|
||||
)
|
||||
192
nanochat/gpt.py
192
nanochat/gpt.py
|
|
@ -9,9 +9,9 @@ Notable features:
|
|||
- no learnable params in rmsnorm
|
||||
- no bias in linear layers
|
||||
- Group-Query Attention (GQA) support for more efficient inference
|
||||
- Flash Attention 3 integration
|
||||
"""
|
||||
|
||||
import math
|
||||
from functools import partial
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
|
@ -23,14 +23,21 @@ from nanochat.common import get_dist_info, print0
|
|||
from nanochat.muon import Muon, DistMuon
|
||||
from nanochat.adamw import DistAdamW
|
||||
|
||||
# Our custom Flash Attention module that automatically uses FA3 on Hopper+ and SDPA fallback elsewhere
|
||||
from nanochat.flash_attention import flash_attn
|
||||
|
||||
@dataclass
|
||||
class GPTConfig:
|
||||
sequence_len: int = 1024
|
||||
vocab_size: int = 50304
|
||||
sequence_len: int = 2048
|
||||
vocab_size: int = 32768
|
||||
n_layer: int = 12
|
||||
n_head: int = 6 # number of query heads
|
||||
n_kv_head: int = 6 # number of key/value heads (GQA)
|
||||
n_embd: int = 768
|
||||
# Sliding window attention pattern string, tiled across layers. Final layer always L.
|
||||
# Characters: L=long (full context), S=short (half context)
|
||||
# Examples: "L"=all full context, "SL"=alternating, "SSL"=two short then one long
|
||||
window_pattern: str = "SSSL"
|
||||
|
||||
|
||||
def norm(x):
|
||||
|
|
@ -38,6 +45,10 @@ def norm(x):
|
|||
return F.rms_norm(x, (x.size(-1),))
|
||||
|
||||
|
||||
def has_ve(layer_idx, n_layer):
|
||||
"""Returns True if GPT layer should have Value Embedding (alternating, last layer always included)."""
|
||||
return layer_idx % 2 == (n_layer - 1) % 2
|
||||
|
||||
def apply_rotary_emb(x, cos, sin):
|
||||
assert x.ndim == 4 # multihead attention
|
||||
d = x.shape[3] // 2
|
||||
|
|
@ -60,49 +71,50 @@ class CausalSelfAttention(nn.Module):
|
|||
self.c_k = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
|
||||
self.c_v = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
|
||||
self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)
|
||||
self.ve_gate_channels = 32
|
||||
self.ve_gate = nn.Linear(self.ve_gate_channels, self.n_kv_head, bias=False) if has_ve(layer_idx, config.n_layer) else None
|
||||
|
||||
def forward(self, x, cos_sin, kv_cache):
|
||||
def forward(self, x, ve, cos_sin, window_size, kv_cache):
|
||||
B, T, C = x.size()
|
||||
|
||||
# Project the input to get queries, keys, and values
|
||||
# Shape: (B, T, H, D) - FA3's native layout, no transpose needed!
|
||||
q = self.c_q(x).view(B, T, self.n_head, self.head_dim)
|
||||
k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim)
|
||||
v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim)
|
||||
|
||||
# Value residual (ResFormer): mix in value embedding with input-dependent gate per head
|
||||
if ve is not None:
|
||||
ve = ve.view(B, T, self.n_kv_head, self.head_dim)
|
||||
gate = 2 * torch.sigmoid(self.ve_gate(x[..., :self.ve_gate_channels])) # (B, T, n_kv_head), range (0, 2)
|
||||
v = v + gate.unsqueeze(-1) * ve
|
||||
|
||||
# Apply Rotary Embeddings to queries and keys to get relative positional encoding
|
||||
cos, sin = cos_sin
|
||||
q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin) # QK rotary embedding
|
||||
q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin)
|
||||
q, k = norm(q), norm(k) # QK norm
|
||||
q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) # make head be batch dim, i.e. (B, T, H, D) -> (B, H, T, D)
|
||||
|
||||
# Apply KV cache: insert current k,v into cache, get the full view so far
|
||||
if kv_cache is not None:
|
||||
k, v = kv_cache.insert_kv(self.layer_idx, k, v)
|
||||
Tq = q.size(2) # number of queries in this forward pass
|
||||
Tk = k.size(2) # number of keys/values in total (in the cache + current forward pass)
|
||||
|
||||
# Attention: queries attend to keys/values autoregressively. A few cases to handle:
|
||||
enable_gqa = self.n_head != self.n_kv_head # Group Query Attention (GQA): duplicate key/value heads to match query heads if desired
|
||||
if kv_cache is None or Tq == Tk:
|
||||
# During training (no KV cache), attend as usual with causal attention
|
||||
# And even if there is KV cache, we can still use this simple version when Tq == Tk
|
||||
y = F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=enable_gqa)
|
||||
elif Tq == 1:
|
||||
# During inference but with a single query in this forward pass:
|
||||
# The query has to attend to all the keys/values in the cache
|
||||
y = F.scaled_dot_product_attention(q, k, v, is_causal=False, enable_gqa=enable_gqa)
|
||||
# Flash Attention (FA3 on Hopper+, PyTorch SDPA fallback elsewhere)
|
||||
# window_size is (left, right) tuple: (N, 0) for causal, (-1, 0) for full context
|
||||
if kv_cache is None:
|
||||
# Training: causal attention with optional sliding window
|
||||
y = flash_attn.flash_attn_func(q, k, v, causal=True, window_size=window_size)
|
||||
else:
|
||||
# During inference AND we have a chunk of queries in this forward pass:
|
||||
# First, each query attends to all the cached keys/values (i.e. full prefix)
|
||||
attn_mask = torch.zeros((Tq, Tk), dtype=torch.bool, device=q.device) # True = keep, False = mask
|
||||
prefix_len = Tk - Tq
|
||||
attn_mask[:, :prefix_len] = True
|
||||
# Then, causal attention within this chunk
|
||||
attn_mask[:, prefix_len:] = torch.tril(torch.ones((Tq, Tq), dtype=torch.bool, device=q.device))
|
||||
y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, enable_gqa=enable_gqa)
|
||||
# Inference: use flash_attn_with_kvcache which handles cache management
|
||||
k_cache, v_cache = kv_cache.get_layer_cache(self.layer_idx)
|
||||
y = flash_attn.flash_attn_with_kvcache(
|
||||
q, k_cache, v_cache,
|
||||
k=k, v=v,
|
||||
cache_seqlens=kv_cache.cache_seqlens,
|
||||
causal=True,
|
||||
window_size=window_size,
|
||||
)
|
||||
# Advance position after last layer processes
|
||||
if self.layer_idx == kv_cache.n_layers - 1:
|
||||
kv_cache.advance(T)
|
||||
|
||||
# Re-assemble the heads side by side and project back to residual stream
|
||||
y = y.transpose(1, 2).contiguous().view(B, T, -1)
|
||||
# Re-assemble the heads and project back to residual stream
|
||||
y = y.contiguous().view(B, T, -1)
|
||||
y = self.c_proj(y)
|
||||
return y
|
||||
|
||||
|
|
@ -126,26 +138,44 @@ class Block(nn.Module):
|
|||
self.attn = CausalSelfAttention(config, layer_idx)
|
||||
self.mlp = MLP(config)
|
||||
|
||||
def forward(self, x, cos_sin, kv_cache):
|
||||
x = x + self.attn(norm(x), cos_sin, kv_cache)
|
||||
def forward(self, x, ve, cos_sin, window_size, kv_cache):
|
||||
x = x + self.attn(norm(x), ve, cos_sin, window_size, kv_cache)
|
||||
x = x + self.mlp(norm(x))
|
||||
return x
|
||||
|
||||
|
||||
class GPT(nn.Module):
|
||||
def __init__(self, config, pad_vocab_size_to=64):
|
||||
"""
|
||||
NOTE a major footgun: this __init__ function runs in meta device context (!!)
|
||||
Therefore, any calculations inside here are shapes and dtypes only, no actual data.
|
||||
=> We actually initialize all data (parameters, buffers, etc.) in init_weights() instead.
|
||||
"""
|
||||
super().__init__()
|
||||
self.config = config
|
||||
# For DDP, we want vocab_size divisible by world_size. Also, there are potential performance benefits, see:
|
||||
# Compute per-layer window sizes for sliding window attention
|
||||
# window_size is (left, right) tuple: (-1, 0) for full context, (N, 0) for sliding window
|
||||
self.window_sizes = self._compute_window_sizes(config)
|
||||
# Pad vocab for efficiency (DDP, tensor cores). This is just an optimization - outputs are cropped in forward().
|
||||
# https://huggingface.co/docs/transformers/main_classes/model#transformers.PreTrainedModel.resize_token_embeddings
|
||||
padded_vocab_size = ((config.vocab_size + pad_vocab_size_to - 1) // pad_vocab_size_to) * pad_vocab_size_to
|
||||
if padded_vocab_size != config.vocab_size:
|
||||
print0(f"Padding vocab_size from {config.vocab_size} to {padded_vocab_size} to be divisible by {pad_vocab_size_to}")
|
||||
print0(f"Padding vocab_size from {config.vocab_size} to {padded_vocab_size} for efficiency")
|
||||
self.transformer = nn.ModuleDict({
|
||||
"wte": nn.Embedding(padded_vocab_size, config.n_embd),
|
||||
"h": nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)]),
|
||||
})
|
||||
self.lm_head = nn.Linear(config.n_embd, padded_vocab_size, bias=False)
|
||||
# Per-layer learnable scalars (inspired by modded-nanogpt)
|
||||
# resid_lambdas: scales the residual stream at each layer (init 1.0 = neutral)
|
||||
# x0_lambdas: blends initial embedding back in at each layer (init 0.0 = disabled)
|
||||
# Separate parameters so they can have different optimizer treatment
|
||||
self.resid_lambdas = nn.Parameter(torch.ones(config.n_layer)) # fake init, real init in init_weights()
|
||||
self.x0_lambdas = nn.Parameter(torch.zeros(config.n_layer)) # fake init, real init in init_weights()
|
||||
# Value embeddings (ResFormer-style): alternating layers, last layer always included
|
||||
head_dim = config.n_embd // config.n_head
|
||||
kv_dim = config.n_kv_head * head_dim
|
||||
self.value_embeds = nn.ModuleDict({str(i): nn.Embedding(padded_vocab_size, kv_dim) for i in range(config.n_layer) if has_ve(i, config.n_layer)})
|
||||
# To support meta device initialization, we init the rotary embeddings here, but it's just "fake" meta tensors only.
|
||||
# As for rotary_seq_len, these rotary embeddings are pretty small/cheap in memory,
|
||||
# so let's just over-compute them by 10X, but assert fail if we ever reach that amount.
|
||||
|
|
@ -156,6 +186,7 @@ class GPT(nn.Module):
|
|||
self.register_buffer("cos", cos, persistent=False) # persistent=False means it's not saved to the checkpoint
|
||||
self.register_buffer("sin", sin, persistent=False)
|
||||
|
||||
@torch.no_grad()
|
||||
def init_weights(self):
|
||||
"""
|
||||
Initialize the full model in this one function for maximum clarity.
|
||||
|
|
@ -186,14 +217,29 @@ class GPT(nn.Module):
|
|||
torch.nn.init.uniform_(block.mlp.c_fc.weight, -s, s)
|
||||
torch.nn.init.zeros_(block.mlp.c_proj.weight)
|
||||
|
||||
# Per-layer scalars
|
||||
self.resid_lambdas.fill_(1.0) # 1.0 => typical residual connections at init
|
||||
self.x0_lambdas.fill_(0.0) # 0.0 => skip connection to input is disabled at init
|
||||
|
||||
# Value embeddings (init like c_v: uniform with same std)
|
||||
for ve in self.value_embeds.values():
|
||||
torch.nn.init.uniform_(ve.weight, -s, s)
|
||||
|
||||
# Gate weights init to zero so gates start at sigmoid(0) = 0.5, scaled by 2 -> 1.0 (neutral)
|
||||
for block in self.transformer.h:
|
||||
if block.attn.ve_gate is not None:
|
||||
torch.nn.init.zeros_(block.attn.ve_gate.weight)
|
||||
|
||||
# Rotary embeddings
|
||||
head_dim = self.config.n_embd // self.config.n_head
|
||||
cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
|
||||
self.cos, self.sin = cos, sin
|
||||
|
||||
# Cast token embeddings to bf16: optimizer can tolerate it and it saves memory
|
||||
# Cast embeddings to bf16: optimizer can tolerate it and it saves memory
|
||||
if self.transformer.wte.weight.device.type == "cuda":
|
||||
self.transformer.wte.to(dtype=torch.bfloat16)
|
||||
for ve in self.value_embeds.values():
|
||||
ve.to(dtype=torch.bfloat16)
|
||||
|
||||
def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None):
|
||||
# TODO: bump base theta more? e.g. 100K is more common more recently
|
||||
|
|
@ -212,6 +258,35 @@ class GPT(nn.Module):
|
|||
cos, sin = cos[None, :, None, :], sin[None, :, None, :] # add batch and head dims for later broadcasting
|
||||
return cos, sin
|
||||
|
||||
def _compute_window_sizes(self, config):
|
||||
"""
|
||||
Compute per-layer window sizes for sliding window attention.
|
||||
|
||||
Returns list of (left, right) tuples for FA3's window_size parameter:
|
||||
- left: how many tokens before current position to attend to (-1 = unlimited)
|
||||
- right: how many tokens after current position to attend to (0 for causal)
|
||||
|
||||
Pattern string is tiled across layers. Final layer always gets L (full context).
|
||||
Characters: L=long (full context), S=short (half context)
|
||||
"""
|
||||
pattern = config.window_pattern.upper()
|
||||
assert all(c in "SL" for c in pattern), f"Invalid window_pattern: {pattern}. Use only S and L."
|
||||
# Map characters to window sizes
|
||||
long_window = config.sequence_len
|
||||
short_window = long_window // 2
|
||||
char_to_window = {
|
||||
"L": (long_window, 0),
|
||||
"S": (short_window, 0),
|
||||
}
|
||||
# Tile pattern across layers
|
||||
window_sizes = []
|
||||
for layer_idx in range(config.n_layer):
|
||||
char = pattern[layer_idx % len(pattern)]
|
||||
window_sizes.append(char_to_window[char])
|
||||
# Final layer always gets full context
|
||||
window_sizes[-1] = (long_window, 0)
|
||||
return window_sizes
|
||||
|
||||
def get_device(self):
|
||||
return self.transformer.wte.weight.device
|
||||
|
||||
|
|
@ -220,16 +295,26 @@ class GPT(nn.Module):
|
|||
Return the estimated FLOPs per token for the model (forward + backward).
|
||||
Each matmul weight parameter contributes 2 FLOPs (multiply *, accumulate +) in forward, and 2X that in backward => 2+4=6.
|
||||
Cleanest explanation of this: https://medium.com/@dzmitrybahdanau/the-flops-calculus-of-language-model-training-3b19c1f025e4
|
||||
On top of that, the term 12 * l * h * q * t accounts for key @ query matmul flops inside attention.
|
||||
On top of that, 12 * h * q * effective_seq_len accounts for key @ query matmul flops inside attention.
|
||||
With sliding windows, effective_seq_len varies per layer (capped by window size).
|
||||
Ref: https://arxiv.org/abs/2204.02311 (PaLM paper).
|
||||
This is ~1% off from the exact formulas of Chinchilla paper, the difference is:
|
||||
- Chinchilla counts the embedding layer as flops (? weird, it's just a lookup => we ignore)
|
||||
- Chinchilla counts exp/sum/divide in attention softmax as flops (a little sus and very tiny => we ignore)
|
||||
"""
|
||||
nparams = sum(p.numel() for p in self.parameters())
|
||||
nparams_embedding = self.transformer.wte.weight.numel()
|
||||
l, h, q, t = self.config.n_layer, self.config.n_head, self.config.n_embd // self.config.n_head, self.config.sequence_len
|
||||
num_flops_per_token = 6 * (nparams - nparams_embedding) + 12 * l * h * q * t
|
||||
# Exclude non-matmul params: embeddings and per-layer scalars
|
||||
value_embeds_numel = sum(ve.weight.numel() for ve in self.value_embeds.values())
|
||||
nparams_exclude = (self.transformer.wte.weight.numel() + value_embeds_numel +
|
||||
self.resid_lambdas.numel() + self.x0_lambdas.numel())
|
||||
h, q, t = self.config.n_head, self.config.n_embd // self.config.n_head, self.config.sequence_len
|
||||
# Sum attention FLOPs per layer, accounting for sliding window
|
||||
attn_flops = 0
|
||||
for window_size in self.window_sizes:
|
||||
window = window_size[0] # (left, right) tuple, we use left
|
||||
effective_seq = t if window < 0 else min(window, t)
|
||||
attn_flops += 12 * h * q * effective_seq
|
||||
num_flops_per_token = 6 * (nparams - nparams_exclude) + attn_flops
|
||||
return num_flops_per_token
|
||||
|
||||
def num_scaling_params(self):
|
||||
|
|
@ -244,27 +329,33 @@ class GPT(nn.Module):
|
|||
nparams = sum(p.numel() for p in self.parameters())
|
||||
return nparams
|
||||
|
||||
def setup_optimizers(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0, adam_betas=(0.8, 0.95)):
|
||||
def setup_optimizers(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0, adam_betas=(0.8, 0.95), scalar_lr=0.5):
|
||||
model_dim = self.config.n_embd
|
||||
ddp, rank, local_rank, world_size = get_dist_info()
|
||||
# Separate out all parameters into 3 groups (matrix, embedding, lm_head)
|
||||
# Separate out all parameters into groups
|
||||
matrix_params = list(self.transformer.h.parameters())
|
||||
value_embeds_params = list(self.value_embeds.parameters())
|
||||
embedding_params = list(self.transformer.wte.parameters())
|
||||
lm_head_params = list(self.lm_head.parameters())
|
||||
assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params)
|
||||
# Create the AdamW optimizer for the embedding and lm_head
|
||||
resid_params = [self.resid_lambdas]
|
||||
x0_params = [self.x0_lambdas]
|
||||
assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params)
|
||||
# Create the AdamW optimizer for the embedding, lm_head, and per-layer scalars
|
||||
# Scale the LR for the AdamW parameters by ∝1/√dmodel (having tuned the LRs for 768 dim model)
|
||||
dmodel_lr_scale = (model_dim / 768) ** -0.5
|
||||
print0(f"Scaling the LR for the AdamW parameters ∝1/√({model_dim}/768) = {dmodel_lr_scale:.6f}")
|
||||
adam_groups = [
|
||||
dict(params=lm_head_params, lr=unembedding_lr * dmodel_lr_scale),
|
||||
dict(params=embedding_params, lr=embedding_lr * dmodel_lr_scale),
|
||||
dict(params=value_embeds_params, lr=embedding_lr * dmodel_lr_scale), # same LR as token embedding
|
||||
dict(params=resid_params, lr=scalar_lr * 0.01), # these are a lot more sensitive because they accumulate in the residual stream
|
||||
dict(params=x0_params, lr=scalar_lr),
|
||||
]
|
||||
adamw_kwargs = dict(betas=adam_betas, eps=1e-10, weight_decay=weight_decay)
|
||||
adamw_kwargs = dict(betas=adam_betas, eps=1e-10, weight_decay=0.0) # NOTE: weight decay is hardcoded to 0.0 for AdamW, only used in Muon
|
||||
AdamWFactory = DistAdamW if ddp else partial(torch.optim.AdamW, fused=True)
|
||||
adamw_optimizer = AdamWFactory(adam_groups, **adamw_kwargs)
|
||||
# Create the Muon optimizer for the linear layers
|
||||
muon_kwargs = dict(lr=matrix_lr, momentum=0.95)
|
||||
muon_kwargs = dict(lr=matrix_lr, momentum=0.95, weight_decay=weight_decay)
|
||||
MuonFactory = DistMuon if ddp else Muon
|
||||
muon_optimizer = MuonFactory(matrix_params, **muon_kwargs)
|
||||
# Combine them the two optimizers into one list
|
||||
|
|
@ -288,8 +379,11 @@ class GPT(nn.Module):
|
|||
# Forward the trunk of the Transformer
|
||||
x = self.transformer.wte(idx)
|
||||
x = norm(x)
|
||||
for block in self.transformer.h:
|
||||
x = block(x, cos_sin, kv_cache)
|
||||
x0 = x # save initial normalized embedding for x0 residual
|
||||
for i, block in enumerate(self.transformer.h):
|
||||
x = self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0
|
||||
ve = self.value_embeds[str(i)](idx) if str(i) in self.value_embeds else None
|
||||
x = block(x, ve, cos_sin, self.window_sizes[i], kv_cache)
|
||||
x = norm(x)
|
||||
|
||||
# Forward the lm_head (compute logits)
|
||||
|
|
|
|||
401
nanochat/muon.py
401
nanochat/muon.py
|
|
@ -1,39 +1,96 @@
|
|||
"""
|
||||
Muon optimizer from Keller et al.
|
||||
Also a lot of borrowing of ideas from modded-nanogpt.
|
||||
Muon optimizer adapted and simplified from modded-nanogpt.
|
||||
https://github.com/KellerJordan/modded-nanogpt
|
||||
|
||||
Background:
|
||||
Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
|
||||
quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
|
||||
of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
|
||||
zero even beyond the point where the iteration no longer converges all the way to one everywhere
|
||||
on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
|
||||
where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
|
||||
performance at all relative to UV^T, where USV^T = G is the SVD.
|
||||
|
||||
Here, an alternative to Newton-Schulz iteration with potentially better convergence properties:
|
||||
Polar Express Sign Method for orthogonalization.
|
||||
https://arxiv.org/pdf/2505.16932
|
||||
by Noah Amsel, David Persson, Christopher Musco, Robert M. Gower.
|
||||
|
||||
Some of the changes in nanochat implementation:
|
||||
- Uses a simpler, more general approach to parameter grouping and stacking
|
||||
- Uses a single fused kernel for the momentum -> polar_express -> variance_reduction -> update step
|
||||
- Makes no assumptions about model architecture (e.g. that attention weights are fused into QKVO format)
|
||||
"""
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
import torch.distributed as dist
|
||||
|
||||
@torch.compile
|
||||
def zeropower_via_newtonschulz5(G: Tensor, steps: int) -> Tensor:
|
||||
"""
|
||||
Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
|
||||
quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
|
||||
of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
|
||||
zero even beyond the point where the iteration no longer converges all the way to one everywhere
|
||||
on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
|
||||
where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
|
||||
performance at all relative to UV^T, where USV^T = G is the SVD.
|
||||
"""
|
||||
assert G.ndim >= 2 # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng
|
||||
a, b, c = (3.4445, -4.7750, 2.0315)
|
||||
X = G.bfloat16()
|
||||
if G.size(-2) > G.size(-1):
|
||||
X = X.mT
|
||||
# Coefficients for Polar Express (computed for num_iters=5, safety_factor=2e-2, cushion=2)
|
||||
# From https://arxiv.org/pdf/2505.16932
|
||||
polar_express_coeffs = [
|
||||
(8.156554524902461, -22.48329292557795, 15.878769915207462),
|
||||
(4.042929935166739, -2.808917465908714, 0.5000178451051316),
|
||||
(3.8916678022926607, -2.772484153217685, 0.5060648178503393),
|
||||
(3.285753657755655, -2.3681294933425376, 0.46449024233003106),
|
||||
(2.3465413258596377, -1.7097828382687081, 0.42323551169305323),
|
||||
]
|
||||
|
||||
# Ensure spectral norm is at most 1
|
||||
X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
|
||||
# Perform the NS iterations
|
||||
for _ in range(steps):
|
||||
@torch.compile(dynamic=False, fullgraph=True)
|
||||
def muon_step_fused(
|
||||
stacked_grads: Tensor,
|
||||
stacked_params: Tensor,
|
||||
momentum_buffer: Tensor,
|
||||
second_momentum_buffer: Tensor,
|
||||
momentum_t: Tensor,
|
||||
lr_t: Tensor,
|
||||
wd_t: Tensor,
|
||||
beta2_t: Tensor,
|
||||
ns_steps: int,
|
||||
red_dim: int,
|
||||
) -> None:
|
||||
"""
|
||||
Fused Muon step: momentum -> polar_express -> variance_reduction -> cautious_update
|
||||
All in one compiled graph to eliminate Python overhead between ops.
|
||||
Some of the constants are 0-D CPU tensors to avoid recompilation when values change.
|
||||
"""
|
||||
|
||||
# Nesterov momentum
|
||||
momentum = momentum_t.to(stacked_grads.dtype)
|
||||
momentum_buffer.lerp_(stacked_grads, 1 - momentum)
|
||||
g = stacked_grads.lerp_(momentum_buffer, momentum)
|
||||
|
||||
# Polar express
|
||||
X = g.bfloat16()
|
||||
if g.size(-2) > g.size(-1):
|
||||
X = X.mT
|
||||
X = X / (X.norm(dim=(-2, -1), keepdim=True) * 1.02 + 1e-6)
|
||||
for a, b, c in polar_express_coeffs[:ns_steps]:
|
||||
A = X @ X.mT
|
||||
B = b * A + c * A @ A # quintic computation strategy adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
|
||||
B = b * A + c * (A @ A)
|
||||
X = a * X + B @ X
|
||||
|
||||
if G.size(-2) > G.size(-1):
|
||||
if g.size(-2) > g.size(-1):
|
||||
X = X.mT
|
||||
return X
|
||||
g = X
|
||||
|
||||
# Variance reduction
|
||||
beta2 = beta2_t.to(g.dtype)
|
||||
v_mean = g.float().square().mean(dim=red_dim, keepdim=True)
|
||||
red_dim_size = g.size(red_dim)
|
||||
v_norm_sq = v_mean.sum(dim=(-2, -1), keepdim=True) * red_dim_size
|
||||
v_norm = v_norm_sq.sqrt()
|
||||
second_momentum_buffer.lerp_(v_mean.to(dtype=second_momentum_buffer.dtype), 1 - beta2)
|
||||
step_size = second_momentum_buffer.clamp_min(1e-10).rsqrt()
|
||||
scaled_sq_sum = (v_mean * red_dim_size) * step_size.float().square()
|
||||
v_norm_new = scaled_sq_sum.sum(dim=(-2, -1), keepdim=True).sqrt()
|
||||
final_scale = step_size * (v_norm / v_norm_new.clamp_min(1e-10))
|
||||
g = g * final_scale.to(g.dtype)
|
||||
|
||||
# Cautious weight decay + parameter update
|
||||
lr = lr_t.to(g.dtype)
|
||||
wd = wd_t.to(g.dtype)
|
||||
mask = (g * stacked_params) >= 0
|
||||
stacked_params.sub_(lr * g + lr * wd * stacked_params * mask)
|
||||
|
||||
class Muon(torch.optim.Optimizer):
|
||||
"""
|
||||
|
|
@ -54,74 +111,112 @@ class Muon(torch.optim.Optimizer):
|
|||
Arguments:
|
||||
lr: The learning rate used by the internal SGD.
|
||||
momentum: The momentum used by the internal SGD.
|
||||
nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
|
||||
ns_steps: The number of Newton-Schulz iteration steps to use.
|
||||
beta2: The decay rate for the second moment (variance) estimate. Set to None to disable.
|
||||
weight_decay: Cautious weight decay coefficient. Only decays where update and weight agree.
|
||||
"""
|
||||
def __init__(self, params, lr=0.02, momentum=0.95, nesterov=True, ns_steps=5):
|
||||
defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps)
|
||||
params: list[Tensor] = [*params]
|
||||
def __init__(self, params, lr=0.02, momentum=0.95, ns_steps=5, beta2=0.95, weight_decay=0.0):
|
||||
defaults = dict(lr=lr, momentum=momentum, ns_steps=ns_steps, beta2=beta2, weight_decay=weight_decay)
|
||||
assert all(p.ndim == 2 for p in params), "Muon expects 2D parameters only"
|
||||
params = list(params) # ensure we have a list, not an e.g. (exhaustible) iterator
|
||||
# Group by shape so we can stack tensors
|
||||
shapes = sorted({p.shape for p in params})
|
||||
param_groups = []
|
||||
for size in {p.numel() for p in params}:
|
||||
group = dict(params=[p for p in params if p.numel() == size])
|
||||
param_groups.append(group)
|
||||
for shape in shapes:
|
||||
group_params = [p for p in params if p.shape == shape]
|
||||
param_groups.append(dict(params=group_params))
|
||||
super().__init__(param_groups, defaults)
|
||||
# 0-D CPU tensors to avoid torch.compile recompilation when values change
|
||||
self._momentum_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
|
||||
@torch.no_grad()
|
||||
def step(self):
|
||||
for group in self.param_groups:
|
||||
params: list[Tensor] = group["params"]
|
||||
for p in params:
|
||||
g = p.grad
|
||||
assert g is not None
|
||||
state = self.state[p]
|
||||
if "momentum_buffer" not in state:
|
||||
state["momentum_buffer"] = torch.zeros_like(g)
|
||||
buf: Tensor = state["momentum_buffer"]
|
||||
buf.lerp_(g, 1 - group["momentum"])
|
||||
g = g.lerp_(buf, group["momentum"]) if group["nesterov"] else buf
|
||||
g = zeropower_via_newtonschulz5(g, steps=group["ns_steps"])
|
||||
p.add_(g, alpha=-group["lr"] * max(1, p.size(-2) / p.size(-1))**0.5)
|
||||
if not params:
|
||||
continue
|
||||
|
||||
# Get or create group-level buffers (stored in first param's state for convenience)
|
||||
state = self.state[params[0]]
|
||||
num_params = len(params) # e.g.: 12 (for a d12 model)
|
||||
# e.g.: shape = (768, 3072), device = cuda:0, dtype = torch.float32, for one of the MLP projections
|
||||
shape, device, dtype = params[0].shape, params[0].device, params[0].dtype
|
||||
|
||||
# Momentum for every individual parameter
|
||||
if "momentum_buffer" not in state:
|
||||
state["momentum_buffer"] = torch.zeros(num_params, *shape, dtype=dtype, device=device)
|
||||
momentum_buffer = state["momentum_buffer"] # e.g.: (12, 768, 3072)
|
||||
|
||||
# Second momentum buffer is factored, either per-row or per-column
|
||||
if "second_momentum_buffer" not in state:
|
||||
if shape[-2] >= shape[-1]:
|
||||
state["second_momentum_buffer"] = torch.zeros(num_params, shape[-2], 1, dtype=dtype, device=device)
|
||||
else:
|
||||
state["second_momentum_buffer"] = torch.zeros(num_params, 1, shape[-1], dtype=dtype, device=device)
|
||||
second_momentum_buffer = state["second_momentum_buffer"] # (12, 1, 3072)
|
||||
red_dim = -1 if shape[-2] >= shape[-1] else -2 # e.g.: -2
|
||||
|
||||
# Stack grads and params
|
||||
stacked_grads = torch.stack([p.grad for p in params]) # (12, 768, 3072)
|
||||
stacked_params = torch.stack(params) # (12, 768, 3072)
|
||||
|
||||
# Fill all the 0-D tensors with current values
|
||||
self._momentum_t.fill_(group["momentum"])
|
||||
self._beta2_t.fill_(group["beta2"] if group["beta2"] is not None else 0.0)
|
||||
self._lr_t.fill_(group["lr"] * max(1.0, shape[-2] / shape[-1])**0.5)
|
||||
self._wd_t.fill_(group["weight_decay"])
|
||||
|
||||
# Single fused kernel: momentum -> polar_express -> variance_reduction -> update
|
||||
muon_step_fused(
|
||||
stacked_grads,
|
||||
stacked_params,
|
||||
momentum_buffer,
|
||||
second_momentum_buffer,
|
||||
self._momentum_t,
|
||||
self._lr_t,
|
||||
self._wd_t,
|
||||
self._beta2_t,
|
||||
group["ns_steps"],
|
||||
red_dim,
|
||||
)
|
||||
|
||||
# Copy back to original params: [(768, 3072), (768, 3072), ...] <- (12, 768, 3072)
|
||||
torch._foreach_copy_(params, list(stacked_params.unbind(0)))
|
||||
|
||||
|
||||
class DistMuon(torch.optim.Optimizer):
|
||||
"""
|
||||
Muon: SGD-momentum + (optional) Nesterov, then orthogonalize the 2D update via Newton–Schulz,
|
||||
finally apply aspect-ratio scaled step. Performs its own distributed synchronization:
|
||||
- reduce_scatter(AVG) for gradient averaging
|
||||
- all_gather to replicate updated weights
|
||||
|
||||
Notes:
|
||||
* Designed for 2D parameters (e.g., linear/conv kernels reshaped to 2D). Do not use for 0D/1D
|
||||
params like embeddings or scalars.
|
||||
* Momentum buffers are maintained only on the 'owner' rank for each parameter (rank chosen
|
||||
by block-cyclic assignment below). If you checkpoint optimizer state on a single rank,
|
||||
consolidate states beforehand.
|
||||
|
||||
Args:
|
||||
params: iterable of Tensors
|
||||
lr: learning rate
|
||||
momentum: momentum coefficient in [0,1)
|
||||
nesterov: if True, Nesterov-style update (g <- lerp(g, buf, momentum)); else use buf
|
||||
ns_steps: number of Newton–Schulz iterations for the orthogonalization
|
||||
Distributed version of the Muon optimizer.
|
||||
"""
|
||||
def __init__(self, params, lr: float = 0.02, momentum: float = 0.95,
|
||||
nesterov: bool = True, ns_steps: int = 5):
|
||||
defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps)
|
||||
params = list(params)
|
||||
ns_steps: int = 5, beta2: float = 0.95, weight_decay: float = 0.0):
|
||||
defaults = dict(lr=lr, momentum=momentum, ns_steps=ns_steps, beta2=beta2, weight_decay=weight_decay)
|
||||
assert all(p.ndim == 2 for p in params), "Muon expects 2D parameters only"
|
||||
params = list(params)
|
||||
world_size = dist.get_world_size()
|
||||
rank = dist.get_rank()
|
||||
# Group all parameters by their shape
|
||||
shapes = sorted({p.shape for p in params}) # sort to ensure consistent / deterministic ordering
|
||||
shapes = sorted({p.shape for p in params}) # sort for deterministic ordering across ranks
|
||||
param_groups = []
|
||||
for shape in shapes:
|
||||
group_params = [p for p in params if p.shape == shape]
|
||||
device, dtype = group_params[0].device, group_params[0].dtype
|
||||
assert all(p.device == device for p in group_params)
|
||||
assert all(p.dtype == dtype for p in group_params)
|
||||
# Compute chunk size for this group (how many params each rank owns)
|
||||
chunk_size = (len(group_params) + world_size - 1) // world_size
|
||||
if rank == 0:
|
||||
print(f"Muon: Grouping {len(group_params)} params of shape {shape}, device {device}, dtype {dtype}")
|
||||
param_groups.append(dict(params=group_params, zero_buffer=torch.zeros_like(group_params[0])))
|
||||
print(f"Muon: {len(group_params)} params of shape {shape}, chunk_size={chunk_size}")
|
||||
param_groups.append(dict(params=group_params, chunk_size=chunk_size))
|
||||
super().__init__(param_groups, defaults)
|
||||
# 0-D CPU tensors to avoid torch.compile recompilation when values change
|
||||
self._momentum_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
|
||||
@torch.no_grad()
|
||||
def step(self):
|
||||
|
|
@ -131,57 +226,127 @@ class DistMuon(torch.optim.Optimizer):
|
|||
# Ensure all grads exist
|
||||
assert all(p.grad is not None for group in self.param_groups for p in group["params"]), "All params must have grads"
|
||||
|
||||
# Kick off all the reduce scatter operations to average up the gradients across all ranks
|
||||
all_reduce_futures = []
|
||||
# First pass: stack grads and kick off reduce_scatter for each group
|
||||
group_infos = []
|
||||
for group in self.param_groups:
|
||||
params = group["params"]
|
||||
zero_buffer = group["zero_buffer"]
|
||||
# Go through params in groups of world_size.
|
||||
for base_i in range(0, len(params), world_size):
|
||||
# The compute owner of each param is rank i % world_size
|
||||
owner_idx = base_i + rank
|
||||
# each rank stacks up its chunk of world_size params into a list
|
||||
rs_input = [p.grad for p in params[base_i:base_i + world_size]]
|
||||
# pad rs_input with the zero buffer to complete the group
|
||||
rs_input.extend([zero_buffer] * (world_size - len(rs_input)))
|
||||
# the output buffer gets strided across the group based on the rank
|
||||
rs_output = params[owner_idx].grad if owner_idx < len(params) else torch.empty_like(zero_buffer)
|
||||
# reduce scatter the gradients within this group of world_size params
|
||||
work = dist.reduce_scatter(rs_output, rs_input, op=dist.ReduceOp.AVG, async_op=True).get_future()
|
||||
all_reduce_futures.append(work)
|
||||
params: list[Tensor] = group["params"]
|
||||
chunk_size = group["chunk_size"]
|
||||
padded_num_params = chunk_size * world_size
|
||||
shape = params[0].shape
|
||||
device, dtype = params[0].device, params[0].dtype
|
||||
|
||||
# Now each rank computes the update and gathers
|
||||
future_idx = 0
|
||||
# Stack all gradients into a single tensor (single kernel via torch.stack)
|
||||
grad_stack = torch.stack([p.grad for p in params])
|
||||
stacked_grads = torch.empty(padded_num_params, *shape, dtype=dtype, device=device)
|
||||
stacked_grads[:len(params)].copy_(grad_stack)
|
||||
# Zero-pad if we have fewer params than padded size
|
||||
if len(params) < padded_num_params:
|
||||
stacked_grads[len(params):].zero_()
|
||||
|
||||
# Output buffer for this rank's chunk
|
||||
grad_chunk = torch.empty(chunk_size, *shape, dtype=dtype, device=device)
|
||||
|
||||
# Async reduce_scatter on the stacked tensor
|
||||
reduce_future = dist.reduce_scatter_tensor(
|
||||
grad_chunk, stacked_grads, op=dist.ReduceOp.AVG, async_op=True
|
||||
).get_future()
|
||||
|
||||
group_infos.append(dict(
|
||||
grad_chunk=grad_chunk,
|
||||
reduce_future=reduce_future,
|
||||
stacked_grads=stacked_grads, # reuse for all_gather output
|
||||
))
|
||||
|
||||
# Second pass: wait for reduce, compute batched updates, kick off all_gather
|
||||
all_gather_futures = []
|
||||
for group in self.param_groups:
|
||||
params = group["params"]
|
||||
zero_buffer = group["zero_buffer"]
|
||||
# Go through params in groups of world_size.
|
||||
for base_i in range(0, len(params), world_size):
|
||||
# The compute owner of each param is rank i % world_size
|
||||
owner_idx = base_i + rank # calculate the index of the param that this rank owns
|
||||
# Wait for the reduce scatter to complete
|
||||
all_reduce_futures[future_idx].wait() # possibly later we could use wait_any polling instead
|
||||
future_idx += 1
|
||||
# Owner computes the Muon update, result is in its param
|
||||
if owner_idx < len(params):
|
||||
p = params[owner_idx]
|
||||
g = p.grad # now averaged across ranks
|
||||
state = self.state[p]
|
||||
if "momentum_buffer" not in state:
|
||||
state["momentum_buffer"] = torch.zeros_like(g)
|
||||
buf: Tensor = state["momentum_buffer"]
|
||||
buf.lerp_(g, 1.0 - group["momentum"])
|
||||
g = g.lerp_(buf, group["momentum"]) if group["nesterov"] else buf
|
||||
g = zeropower_via_newtonschulz5(g, steps=group["ns_steps"])
|
||||
scale = (max(1.0, p.size(-2) / p.size(-1)) ** 0.5)
|
||||
p.add_(g, alpha=-group["lr"] * scale)
|
||||
# Replicate updated parameters to all ranks
|
||||
ag_input = params[owner_idx] if owner_idx < len(params) else zero_buffer
|
||||
ag_output = params[base_i:base_i + world_size]
|
||||
ag_output.extend([torch.empty_like(zero_buffer) for _ in range(world_size - len(ag_output))]) # pad
|
||||
work = dist.all_gather(ag_output, ag_input, async_op=True).get_future()
|
||||
all_gather_futures.append(work)
|
||||
for group, info in zip(self.param_groups, group_infos):
|
||||
info["reduce_future"].wait()
|
||||
|
||||
# Wait for all work to finish
|
||||
torch.futures.collect_all(all_gather_futures).wait()
|
||||
params = group["params"]
|
||||
chunk_size = group["chunk_size"]
|
||||
shape = params[0].shape
|
||||
device, dtype = params[0].device, params[0].dtype
|
||||
grad_chunk = info["grad_chunk"]
|
||||
|
||||
# How many params does this rank actually own?
|
||||
start_idx = rank * chunk_size
|
||||
num_owned = min(chunk_size, max(0, len(params) - start_idx))
|
||||
|
||||
# Get or create group-level state (stored keyed by first param)
|
||||
state = self.state[params[0]]
|
||||
|
||||
# Momentum buffer
|
||||
if "momentum_buffer" not in state:
|
||||
state["momentum_buffer"] = torch.zeros(chunk_size, *shape, dtype=dtype, device=device)
|
||||
momentum_buffer = state["momentum_buffer"]
|
||||
|
||||
# Second momentum buffer is factored, either per-row or per-column
|
||||
if "second_momentum_buffer" not in state:
|
||||
if shape[-2] >= shape[-1]:
|
||||
state["second_momentum_buffer"] = torch.zeros(chunk_size, shape[-2], 1, dtype=dtype, device=device)
|
||||
else:
|
||||
state["second_momentum_buffer"] = torch.zeros(chunk_size, 1, shape[-1], dtype=dtype, device=device)
|
||||
second_momentum_buffer = state["second_momentum_buffer"]
|
||||
red_dim = -1 if shape[-2] >= shape[-1] else -2
|
||||
|
||||
# Build updated_params tensor for all_gather
|
||||
updated_params = torch.empty(chunk_size, *shape, dtype=dtype, device=device)
|
||||
|
||||
if num_owned > 0:
|
||||
# Stack owned params (single kernel via torch.stack)
|
||||
owned_params = [params[start_idx + i] for i in range(num_owned)]
|
||||
stacked_owned_params = torch.stack(owned_params)
|
||||
|
||||
# Get owned slices of buffers and grads
|
||||
owned_grads = grad_chunk[:num_owned]
|
||||
owned_momentum = momentum_buffer[:num_owned]
|
||||
owned_second_momentum = second_momentum_buffer[:num_owned]
|
||||
|
||||
# Fill 0-D tensors with current values
|
||||
self._momentum_t.fill_(group["momentum"])
|
||||
self._beta2_t.fill_(group["beta2"] if group["beta2"] is not None else 0.0)
|
||||
self._lr_t.fill_(group["lr"] * max(1.0, shape[-2] / shape[-1])**0.5)
|
||||
self._wd_t.fill_(group["weight_decay"])
|
||||
|
||||
# Single fused kernel: momentum -> polar_express -> variance_reduction -> update
|
||||
muon_step_fused(
|
||||
owned_grads,
|
||||
stacked_owned_params,
|
||||
owned_momentum,
|
||||
owned_second_momentum,
|
||||
self._momentum_t,
|
||||
self._lr_t,
|
||||
self._wd_t,
|
||||
self._beta2_t,
|
||||
group["ns_steps"],
|
||||
red_dim,
|
||||
)
|
||||
|
||||
# Copy updated params to output buffer
|
||||
updated_params[:num_owned].copy_(stacked_owned_params)
|
||||
|
||||
# Zero-pad the rest (for ranks that own fewer params)
|
||||
if num_owned < chunk_size:
|
||||
updated_params[num_owned:].zero_()
|
||||
|
||||
# Reuse stacked_grads buffer for all_gather output
|
||||
stacked_params = info["stacked_grads"]
|
||||
|
||||
# Async all_gather to replicate updated params to all ranks
|
||||
gather_future = dist.all_gather_into_tensor(
|
||||
stacked_params, updated_params, async_op=True
|
||||
).get_future()
|
||||
|
||||
all_gather_futures.append(dict(
|
||||
gather_future=gather_future,
|
||||
stacked_params=stacked_params,
|
||||
params=params,
|
||||
))
|
||||
|
||||
# Final pass: wait for all_gather and copy back to params
|
||||
for info in all_gather_futures:
|
||||
info["gather_future"].wait()
|
||||
stacked_params = info["stacked_params"]
|
||||
params = info["params"]
|
||||
# Batched copy back (single kernel instead of N individual copies)
|
||||
torch._foreach_copy_(params, list(stacked_params[:len(params)].unbind(0)))
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ SPECIAL_TOKENS = [
|
|||
|
||||
# NOTE: this split pattern deviates from GPT-4 in that we use \p{N}{1,2} instead of \p{N}{1,3}
|
||||
# I did this because I didn't want to "waste" too many tokens on numbers for smaller vocab sizes.
|
||||
# I haven't validated that this is actually a good idea, TODO.
|
||||
# I verified that 2 is the sweet spot for vocab size of 32K. 1 is a bit worse, 3 was worse still.
|
||||
SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,2}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
|
|
@ -103,9 +103,10 @@ class HuggingFaceTokenizer:
|
|||
def id_to_token(self, id):
|
||||
return self.tokenizer.id_to_token(id)
|
||||
|
||||
def _encode_one(self, text, prepend=None, append=None):
|
||||
def _encode_one(self, text, prepend=None, append=None, num_threads=None):
|
||||
# encode a single string
|
||||
# prepend/append can be either a string of a special token or a token id directly.
|
||||
# num_threads is ignored (only used by the nanochat Tokenizer for parallel encoding)
|
||||
assert isinstance(text, str)
|
||||
ids = []
|
||||
if prepend is not None:
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ dependencies = [
|
|||
"datasets>=4.0.0",
|
||||
"fastapi>=0.117.1",
|
||||
"ipykernel>=7.1.0",
|
||||
"kernels>=0.11.7",
|
||||
"matplotlib>=3.10.8",
|
||||
"psutil>=7.1.0",
|
||||
"python-dotenv>=1.2.1",
|
||||
|
|
@ -22,6 +23,7 @@ dependencies = [
|
|||
"transformers>=4.57.3",
|
||||
"uvicorn>=0.36.0",
|
||||
"wandb>=0.21.3",
|
||||
"zstandard>=0.25.0",
|
||||
]
|
||||
|
||||
[dependency-groups]
|
||||
|
|
|
|||
|
|
@ -1,29 +1,40 @@
|
|||
#!/bin/bash
|
||||
|
||||
# See speedrun.sh for more comments
|
||||
# Usage: ./miniseries.sh [series_name]
|
||||
# Example: ./miniseries.sh jan11
|
||||
# Default series name is today's date (e.g., jan11)
|
||||
|
||||
export OMP_NUM_THREADS=1
|
||||
export NANOCHAT_BASE_DIR="$HOME/.cache/nanochat"
|
||||
mkdir -p $NANOCHAT_BASE_DIR
|
||||
|
||||
# uv
|
||||
command -v uv &> /dev/null || curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
[ -d ".venv" ] || uv venv
|
||||
uv sync --extra gpu
|
||||
source .venv/bin/activate
|
||||
# Setup (skip with SKIP_SETUP=1)
|
||||
if [ -z "$SKIP_SETUP" ]; then
|
||||
# uv
|
||||
command -v uv &> /dev/null || curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
[ -d ".venv" ] || uv venv
|
||||
uv sync --extra gpu
|
||||
source .venv/bin/activate
|
||||
|
||||
# Tokenizer
|
||||
python -m nanochat.dataset -n 240
|
||||
python -m scripts.tok_train --max_chars=2000000000 --vocab_size=32768
|
||||
# Tokenizer, download 1000 shards for pretraining
|
||||
# (probably this can be reduced but it's tricky to determine the exact right number, TODO).
|
||||
python -m nanochat.dataset -n 1000
|
||||
python -m scripts.tok_train --max-chars=2000000000 --vocab-size=32768
|
||||
else
|
||||
source .venv/bin/activate
|
||||
fi
|
||||
|
||||
# Series name: from arg, env var, or default to today's date (e.g., jan11)
|
||||
SERIES_NAME="${1:-${SERIES_NAME:-$(date +%b%d | tr '[:upper:]' '[:lower:]')}}"
|
||||
# Depths to train (the "miniseries")
|
||||
DEPTHS=(10 11 12 13 14 15 16 17 18 19 20)
|
||||
# Hardware
|
||||
NPROC_PER_NODE="${NPROC_PER_NODE:-8}"
|
||||
# Logging
|
||||
WANDB_RUN="${WANDB_RUN:-jan7_miniseries}"
|
||||
WANDB_RUN="${WANDB_RUN:-${SERIES_NAME}_miniseries}"
|
||||
|
||||
RESULTS_DIR="$NANOCHAT_BASE_DIR/jan7_miniseries_results"
|
||||
RESULTS_DIR="$NANOCHAT_BASE_DIR/${SERIES_NAME}_miniseries_results"
|
||||
mkdir -p "$RESULTS_DIR"
|
||||
RESULTS_FILE="$RESULTS_DIR/results.csv"
|
||||
|
||||
|
|
@ -37,26 +48,25 @@ log() {
|
|||
}
|
||||
|
||||
log "=============================================="
|
||||
log "Jan 7 Miniseries Training"
|
||||
log "${SERIES_NAME} Miniseries Training"
|
||||
log "=============================================="
|
||||
|
||||
for d in "${DEPTHS[@]}"; do
|
||||
log "Training d=$d..."
|
||||
|
||||
TAG="jan7_miniseries_d${d}"
|
||||
TAG="${SERIES_NAME}_miniseries_d${d}"
|
||||
START_TIME=$(date +%s)
|
||||
|
||||
# Train the model with natural horizon (target_param_data_ratio default)
|
||||
# No --target_flops, let it use the default ratio from base_train
|
||||
# No --target-flops, let it use the default ratio from base_train
|
||||
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train -- \
|
||||
--depth=$d \
|
||||
--target_param_data_ratio=8 \
|
||||
--run="${WANDB_RUN}_d${d}" \
|
||||
--model_tag="${TAG}" \
|
||||
--core_metric_every=999999 \
|
||||
--core_metric_max_per_task=-1 \
|
||||
--sample_every=-1 \
|
||||
--save_every=-1 \
|
||||
--model-tag="${TAG}" \
|
||||
--core-metric-every=999999 \
|
||||
--core-metric-max-per-task=-1 \
|
||||
--sample-every=-1 \
|
||||
--save-every=-1 \
|
||||
2>&1 | tee "$RESULTS_DIR/${TAG}_train.log"
|
||||
|
||||
END_TIME=$(date +%s)
|
||||
|
|
@ -84,7 +94,7 @@ for d in "${DEPTHS[@]}"; do
|
|||
done
|
||||
|
||||
log "=============================================="
|
||||
log "Jan 7 Miniseries Complete!"
|
||||
log "${SERIES_NAME} Miniseries Complete!"
|
||||
log "=============================================="
|
||||
log "Results saved to: $RESULTS_FILE"
|
||||
echo ""
|
||||
|
|
@ -20,18 +20,18 @@ curl -L -o $NANOCHAT_BASE_DIR/identity_conversations.jsonl https://karpathy-publ
|
|||
|
||||
# train tokenizer on ~4B characters and kick off download of the rest for pretraining
|
||||
python -m nanochat.dataset -n 16
|
||||
# start downloading the rest of the shards for a total of 800 (see below why 800)
|
||||
python -m nanochat.dataset -n 800 &
|
||||
# start downloading the rest of the shards for a total of 1200 (see below why 1200)
|
||||
python -m nanochat.dataset -n 1200 &
|
||||
# todo: download the rest of it
|
||||
python -m scripts.tok_train --max_chars=4000000000 --vocab_size=65536
|
||||
python -m scripts.tok_train --max-chars=4000000000 --vocab-size=65536
|
||||
python -m scripts.tok_eval
|
||||
|
||||
# Documenting my process for determining the hyperparameters for this run1000.sh script:
|
||||
# We want a budget of approx. $1000 ~= 41.6 hours of 8XH100 compute
|
||||
# 1) I guessed the model size for this to be about depth=32
|
||||
# 2) Determine the device_batch_size that fits:
|
||||
# Running the base_train.py script with --depth=32, I saw that --device_batch_size=16
|
||||
# runs out of memory, but --device_batch_size=8 fits. Inspecting `nvidia-smi` during training,
|
||||
# Running the base_train.py script with --depth=32, I saw that --device-batch-size=16
|
||||
# runs out of memory, but --device-batch-size=8 fits. Inspecting `nvidia-smi` during training,
|
||||
# I saw all GPUs were at about 78/80GB VRAM, so it just barely fits and we have good MFU at ~50%.
|
||||
# So the training script was running ok and showed:
|
||||
# Vocab size: 65,536
|
||||
|
|
@ -62,7 +62,9 @@ python -m scripts.tok_eval
|
|||
# The tok_eval.py script reports about ~4.8 chars/token on average for the default tokenizer settings.
|
||||
# So ~38B tokens # ~4.8 chars/token = ~185B chars.
|
||||
# Each data shard is ~250M chars, so we need ~185B / 250M ~= 740 shards.
|
||||
# For safety, I bumped that up to 800 shards, and that's why up above I used -n 800 when pre-downloading dataset shards.
|
||||
# For safety, I bumped that up to 800 shards.
|
||||
# The new DataLoader wastes about 35% of tokens to cropping, so 800 / (1 - 0.35) ~= 1200 shards are needed.
|
||||
# => why up above I used -n 1200 when pre-downloading dataset shards.
|
||||
# If we didn't have enough data, the training script would loop around and do multiple epochs over the same data,
|
||||
# which would decrease model performance. Possibly 2, 3 or so epochs is ~ok, but certainly not ideal and at 10+ epochs we'd
|
||||
# start to overfit hard.
|
||||
|
|
@ -71,13 +73,13 @@ python -m scripts.tok_eval
|
|||
# Number of processes/GPUs to use
|
||||
NPROC_PER_NODE=8
|
||||
|
||||
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train -- --depth=32 --target_param_data_ratio=20 --device_batch_size=8 --run=$WANDB_RUN
|
||||
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train -- --depth=32 --target-param-data-ratio=20 --device-batch-size=8 --run=$WANDB_RUN
|
||||
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_loss
|
||||
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_eval
|
||||
|
||||
# midtrain
|
||||
# NOTE: ensure that we use the same device_batch_size here as the base training script.
|
||||
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.mid_train -- --device_batch_size=8 --run=$WANDB_RUN
|
||||
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.mid_train -- --device-batch-size=8 --run=$WANDB_RUN
|
||||
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_eval -- -i mid
|
||||
|
||||
# sft
|
||||
70
runs/runcpu.sh
Executable file
70
runs/runcpu.sh
Executable file
|
|
@ -0,0 +1,70 @@
|
|||
#!/bin/bash
|
||||
|
||||
# Showing an example run for exercising some of the code paths on the CPU (or MPS on Macbooks)
|
||||
# This script was last updated/tuned on Jan 17, 2026.
|
||||
|
||||
# Run as:
|
||||
# bash dev/cpu_demo_run.sh
|
||||
|
||||
# NOTE: Training LLMs requires GPU compute and $$$. You will not get far on your Macbook.
|
||||
# Think of this run as educational/fun demo, not something you should expect to work well.
|
||||
# (This is why I hide this script away in dev/)
|
||||
# You may also want to run this script manually and one by one, copy pasting commands into your terminal.
|
||||
|
||||
# all the setup stuff
|
||||
export OMP_NUM_THREADS=1
|
||||
export NANOCHAT_BASE_DIR="$HOME/.cache/nanochat"
|
||||
mkdir -p $NANOCHAT_BASE_DIR
|
||||
command -v uv &> /dev/null || curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
[ -d ".venv" ] || uv venv
|
||||
uv sync --extra cpu
|
||||
source .venv/bin/activate
|
||||
if [ -z "$WANDB_RUN" ]; then
|
||||
WANDB_RUN=dummy
|
||||
fi
|
||||
|
||||
# train tokenizer on ~2B characters (~34 seconds on my MacBook Pro M3 Max)
|
||||
python -m nanochat.dataset -n 8
|
||||
python -m scripts.tok_train --max-chars=2000000000
|
||||
python -m scripts.tok_eval
|
||||
|
||||
# train a small 4 layer model
|
||||
# I tuned this run to complete in about 30 minutes on my MacBook Pro M3 Max.
|
||||
# To get better results, try increasing num_iterations, or get other ideas from your favorite LLM.
|
||||
python -m scripts.base_train \
|
||||
--depth=6 \
|
||||
--head-dim=64 \
|
||||
--window-pattern=L \
|
||||
--max-seq-len=512 \
|
||||
--device-batch-size=32 \
|
||||
--total-batch-size=16384 \
|
||||
--eval-every=100 \
|
||||
--eval-tokens=524288 \
|
||||
--core-metric-every=-1 \
|
||||
--sample-every=100 \
|
||||
--num-iterations=5000 \
|
||||
--run=$WANDB_RUN
|
||||
python -m scripts.base_loss --device-batch-size=1 --split-tokens=16384
|
||||
python -m scripts.base_eval --max-per-task=16
|
||||
|
||||
# midtraining (~10 minutes on my MacBook Pro M3 Max)
|
||||
curl -L -o $NANOCHAT_BASE_DIR/identity_conversations.jsonl https://karpathy-public.s3.us-west-2.amazonaws.com/identity_conversations.jsonl
|
||||
python -m scripts.mid_train \
|
||||
--max-seq-len=512 \
|
||||
--device-batch-size=32 \
|
||||
--total-batch-size=16384 \
|
||||
--eval-every=200 \
|
||||
--eval-tokens=524288 \
|
||||
--num-iterations=1500 \
|
||||
--run=$WANDB_RUN
|
||||
|
||||
# (it's ~ok to skip SFT)
|
||||
|
||||
# Chat with the model over CLI
|
||||
# The model should be able to say that it is Paris.
|
||||
# It might even know that the color of the sky is blue.
|
||||
# Sometimes the model likes it if you first say Hi before you ask it questions.
|
||||
# python -m scripts.chat_cli -i mid -p "What is the capital of France?"
|
||||
|
||||
# Chat with the model over a pretty WebUI ChatGPT style
|
||||
# python -m scripts.chat_web -i mid
|
||||
|
|
@ -1,20 +1,23 @@
|
|||
#!/bin/bash
|
||||
|
||||
LABEL="jan16"
|
||||
|
||||
FLOPS_BUDGETS=(
|
||||
1e18
|
||||
3e18
|
||||
6e18
|
||||
)
|
||||
DEPTHS=(8 10 12 14 16 18 20)
|
||||
DEPTHS=(6 7 8 9 10 11 12 13 14)
|
||||
|
||||
NPROC_PER_NODE="${NPROC_PER_NODE:-8}"
|
||||
WANDB_RUN="${WANDB_RUN:-scaling}"
|
||||
WANDB_RUN="${WANDB_RUN:-scaling_${LABEL}}"
|
||||
EVAL_TOKENS=$((100 * 524288)) # ~100M tokens for final eval (default is ~10M)
|
||||
|
||||
export OMP_NUM_THREADS=1
|
||||
export NANOCHAT_BASE_DIR="${NANOCHAT_BASE_DIR:-$HOME/.cache/nanochat}"
|
||||
source .venv/bin/activate
|
||||
|
||||
RESULTS_DIR="$NANOCHAT_BASE_DIR/scaling_laws_results"
|
||||
RESULTS_DIR="$NANOCHAT_BASE_DIR/scaling_laws_results_${LABEL}"
|
||||
mkdir -p "$RESULTS_DIR"
|
||||
RESULTS_FILE="$RESULTS_DIR/results.csv"
|
||||
|
||||
|
|
@ -64,15 +67,15 @@ for flops in "${FLOPS_BUDGETS[@]}"; do
|
|||
# CORE eval happens once at the end (999999 ensures only final step)
|
||||
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train -- \
|
||||
--depth=$d \
|
||||
--target_flops=$flops \
|
||||
--target_param_data_ratio=-1 \
|
||||
--target-flops=$flops \
|
||||
--target-param-data-ratio=-1 \
|
||||
--run="${WANDB_RUN}_${TAG}" \
|
||||
--model_tag="${TAG}" \
|
||||
--eval_tokens=$EVAL_TOKENS \
|
||||
--core_metric_every=999999 \
|
||||
--core_metric_max_per_task=-1 \
|
||||
--sample_every=-1 \
|
||||
--save_every=-1 \
|
||||
--model-tag="${TAG}" \
|
||||
--eval-tokens=$EVAL_TOKENS \
|
||||
--core-metric-every=999999 \
|
||||
--core-metric-max-per-task=-1 \
|
||||
--sample-every=-1 \
|
||||
--save-every=-1 \
|
||||
2>&1 | tee "$RESULTS_DIR/${TAG}_train.log"
|
||||
|
||||
END_TIME=$(date +%s)
|
||||
|
|
@ -55,11 +55,11 @@ python -m nanochat.report reset
|
|||
# each shard is ~100MB of text (compressed), so this is about ~800MB of data on disk
|
||||
python -m nanochat.dataset -n 8
|
||||
# Immediately also kick off downloading more shards in the background while tokenizer trains
|
||||
# See comment below for why 240 is the right number here
|
||||
python -m nanochat.dataset -n 240 &
|
||||
# See comment below for why 370 is the right number here
|
||||
python -m nanochat.dataset -n 370 &
|
||||
DATASET_DOWNLOAD_PID=$!
|
||||
# train the tokenizer with vocab size 2**16 = 65536 on ~2B characters of data
|
||||
python -m scripts.tok_train --max_chars=2000000000 --vocab_size=65536
|
||||
python -m scripts.tok_train --max-chars=2000000000 --vocab-size=65536
|
||||
# evaluate the tokenizer (report compression ratio etc.)
|
||||
python -m scripts.tok_eval
|
||||
|
||||
|
|
@ -70,7 +70,9 @@ python -m scripts.tok_eval
|
|||
# Chinchilla says #tokens = 20X #params, so we need 561e6 * 20 = 11.2B tokens.
|
||||
# Assume our tokenizer is 4.8 chars/token, this is 11.2B * 4.8 ~= 54B chars.
|
||||
# At 250M chars/shard, this is 54B / 250M ~= 216 shards needed for pretraining.
|
||||
# Round up to 240 for safety. At ~100MB/shard, this downloads ~24GB of data to disk.
|
||||
# Round up to 240 for safety. Also, the new DataLoader wastes about 35% of tokens to cropping
|
||||
# so 240 / (1 - 0.35) = 370 shards are needed.
|
||||
# At ~100MB/shard, this downloads ~37GB of data to disk.
|
||||
# (The total number of shards available in the entire dataset is 1822.)
|
||||
echo "Waiting for dataset download to complete..."
|
||||
wait $DATASET_DOWNLOAD_PID
|
||||
|
|
@ -79,7 +81,7 @@ wait $DATASET_DOWNLOAD_PID
|
|||
NPROC_PER_NODE=8
|
||||
|
||||
# pretrain the d20 model
|
||||
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train -- --depth=20 --target_param_data_ratio=20 --run=$WANDB_RUN
|
||||
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train -- --depth=20 --target-param-data-ratio=20 --run=$WANDB_RUN
|
||||
# evaluate the model on a larger chunk of train/val data and draw some samples
|
||||
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_loss
|
||||
# evaluate the model on CORE tasks
|
||||
|
|
@ -5,49 +5,108 @@ Loads a checkpoint, and:
|
|||
|
||||
Example run as:
|
||||
torchrun --standalone --nproc_per_node=8 -m scripts.base_loss
|
||||
|
||||
To evaluate a HuggingFace model:
|
||||
python -m scripts.base_loss --hf-path openai-community/gpt2
|
||||
"""
|
||||
import argparse
|
||||
from contextlib import nullcontext
|
||||
import torch
|
||||
from nanochat.checkpoint_manager import load_model
|
||||
from nanochat.common import compute_init, print0, compute_cleanup, autodetect_device_type
|
||||
from nanochat.dataloader import tokenizing_distributed_data_loader
|
||||
from nanochat.tokenizer import get_token_bytes
|
||||
from nanochat.dataloader import tokenizing_distributed_data_loader_bos_bestfit
|
||||
from nanochat.tokenizer import get_token_bytes, HuggingFaceTokenizer
|
||||
from nanochat.loss_eval import evaluate_bpb
|
||||
from nanochat.engine import Engine
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# HuggingFace loading utilities, making the APIs match up to those of nanochat
|
||||
|
||||
class ModelWrapper:
|
||||
"""Lightweight wrapper for a HuggingFace model"""
|
||||
def __init__(self, model, max_seq_len=None):
|
||||
self.model = model
|
||||
self.max_seq_len = max_seq_len
|
||||
|
||||
def __call__(self, input_ids, targets=None, loss_reduction='mean'):
|
||||
logits = self.model(input_ids).logits
|
||||
if targets is None:
|
||||
return logits
|
||||
else:
|
||||
loss = torch.nn.functional.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1, reduction=loss_reduction)
|
||||
return loss
|
||||
|
||||
def get_device(self):
|
||||
return next(self.model.parameters()).device
|
||||
|
||||
def load_hf_model(hf_path: str, device):
|
||||
print0(f"Loading model from: {hf_path}")
|
||||
from transformers import AutoModelForCausalLM
|
||||
model = AutoModelForCausalLM.from_pretrained(hf_path)
|
||||
model.to(device)
|
||||
model.eval()
|
||||
max_seq_len = 1024 if "openai-community/gpt2" in hf_path else None
|
||||
model = ModelWrapper(model, max_seq_len=max_seq_len)
|
||||
tokenizer = HuggingFaceTokenizer.from_pretrained(hf_path)
|
||||
return model, tokenizer
|
||||
|
||||
def get_hf_token_bytes(tokenizer, device="cpu"):
|
||||
"""Compute token_bytes tensor for a HuggingFace tokenizer."""
|
||||
vocab_size = tokenizer.tokenizer.get_vocab_size()
|
||||
token_bytes = torch.zeros(vocab_size, dtype=torch.int64, device=device)
|
||||
for token_id in range(vocab_size):
|
||||
token_str = tokenizer.tokenizer.decode([token_id])
|
||||
token_bytes[token_id] = len(token_str.encode('utf-8')) # Count UTF-8 bytes
|
||||
return token_bytes
|
||||
|
||||
# CLI arguments
|
||||
parser = argparse.ArgumentParser(description="Evaluate loss on train/val splits and sample from model")
|
||||
parser.add_argument("--device_batch_size", type=int, default=32, help="per-device batch size")
|
||||
parser.add_argument("--split_tokens", type=int, default=20*524288, help="number of tokens to evaluate per split")
|
||||
parser.add_argument("--model_tag", type=str, default=None, help="model tag for checkpoint directory")
|
||||
parser.add_argument("--model_step", type=int, default=None, help="model step to load")
|
||||
parser.add_argument("--device_type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
|
||||
parser.add_argument("--device-batch-size", type=int, default=32, help="per-device batch size")
|
||||
parser.add_argument("--split-tokens", type=int, default=40*524288, help="number of tokens to evaluate per split")
|
||||
parser.add_argument("--model-tag", type=str, default=None, help="model tag for checkpoint directory")
|
||||
parser.add_argument("--model-step", type=int, default=None, help="model step to load")
|
||||
parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
|
||||
parser.add_argument("--hf-path", type=str, default=None, help="HuggingFace model path (e.g. openai-community/gpt2)")
|
||||
args = parser.parse_args()
|
||||
|
||||
# Load the base model and the tokenizer
|
||||
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
|
||||
model, tokenizer, meta = load_model("base", device, phase="eval", model_tag=args.model_tag, step=args.model_step)
|
||||
sequence_len = meta["model_config"]["sequence_len"] # could be arbitrary really
|
||||
print0(f"Device: {device} | DDP rank: {ddp_rank} | DDP local rank: {ddp_local_rank} | DDP world size: {ddp_world_size}")
|
||||
|
||||
if args.hf_path is not None:
|
||||
# Load HuggingFace model
|
||||
model, tokenizer = load_hf_model(args.hf_path, device)
|
||||
sequence_len = model.max_seq_len if model.max_seq_len else 1024
|
||||
token_bytes = get_hf_token_bytes(tokenizer, device=device)
|
||||
model_name = args.hf_path
|
||||
else:
|
||||
# Load local nanochat model
|
||||
model, tokenizer, meta = load_model("base", device, phase="eval", model_tag=args.model_tag, step=args.model_step)
|
||||
sequence_len = meta["model_config"]["sequence_len"]
|
||||
token_bytes = get_token_bytes(device=device)
|
||||
model_name = f"base_model (step {meta['step']})"
|
||||
|
||||
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
|
||||
|
||||
print0(f"Evaluating model: {model_name}")
|
||||
|
||||
# Evaluate the loss on each split
|
||||
tokens_per_step = args.device_batch_size * sequence_len * ddp_world_size
|
||||
assert args.split_tokens % tokens_per_step == 0, "split_tokens must be divisible by tokens_per_step"
|
||||
steps = args.split_tokens // tokens_per_step
|
||||
token_bytes = get_token_bytes(device=device)
|
||||
bpb_results = {}
|
||||
for split_name in ["train", "val"]:
|
||||
loader = tokenizing_distributed_data_loader(args.device_batch_size, sequence_len, split_name, device=device)
|
||||
loader = tokenizing_distributed_data_loader_bos_bestfit(tokenizer, args.device_batch_size, sequence_len, split_name, device=device)
|
||||
with autocast_ctx:
|
||||
bpb = evaluate_bpb(model, loader, steps, token_bytes)
|
||||
print0(f"{split_name} bpb: {bpb:.4f}")
|
||||
bpb_results[split_name] = bpb
|
||||
print0(f"Model: {model_name}, {split_name} bpb: {bpb:.6f}")
|
||||
|
||||
# Master process also samples from the model
|
||||
# Master process also samples from the model for some basic knowledge-eliciting prompts (only for nanochat models)
|
||||
samples = []
|
||||
if ddp_rank == 0:
|
||||
if ddp_rank == 0 and args.hf_path is None:
|
||||
prompts = [
|
||||
"The capital of France is",
|
||||
"The chemical symbol of gold is",
|
||||
|
|
@ -63,17 +122,33 @@ if ddp_rank == 0:
|
|||
with autocast_ctx:
|
||||
sample, _ = engine.generate_batch(tokens, num_samples=1, max_tokens=16, temperature=0)
|
||||
sample_str = tokenizer.decode(sample[0])
|
||||
print0("-" * 80)
|
||||
print0(sample_str)
|
||||
samples.append(sample_str)
|
||||
|
||||
# Draw some unconditioned samples from the model (only for nanochat models)
|
||||
unconditioned_samples = []
|
||||
if ddp_rank == 0 and args.hf_path is None:
|
||||
engine = Engine(model, tokenizer)
|
||||
tokens = tokenizer("", prepend="<|bos|>")
|
||||
with autocast_ctx:
|
||||
samples, _ = engine.generate_batch(tokens, num_samples=8, max_tokens=128, temperature=1.0)
|
||||
for sample in samples:
|
||||
sample_str = tokenizer.decode(sample)
|
||||
print0("-" * 80)
|
||||
print0(sample_str)
|
||||
unconditioned_samples.append(sample_str)
|
||||
|
||||
# Log to report
|
||||
from nanochat.report import get_report
|
||||
get_report().log(section="Base model loss", data=[
|
||||
{
|
||||
"model": model_name,
|
||||
"train bpb": bpb_results["train"],
|
||||
"val bpb": bpb_results["val"],
|
||||
},
|
||||
{f"sample {i}": sample for i, sample in enumerate(samples)},
|
||||
{f"unconditioned sample {i}": sample for i, sample in enumerate(unconditioned_samples)},
|
||||
])
|
||||
|
||||
# Cleanup
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ or distributed as:
|
|||
torchrun --nproc_per_node=8 -m scripts.base_train.py
|
||||
|
||||
If you are only on CPU/Macbook, you'll want to train a much much smaller LLM. Example:
|
||||
python -m scripts.base_train --depth=4 --max_seq_len=512 --device_batch_size=1 --eval_tokens=512 --core_metric_every=-1 --total_batch_size=512 --num_iterations=20
|
||||
python -m scripts.base_train --depth=4 --max-seq-len=512 --device-batch-size=1 --eval-tokens=512 --core-metric-every=-1 --total-batch-size=512 --num-iterations=20
|
||||
"""
|
||||
|
||||
import os
|
||||
|
|
@ -21,12 +21,13 @@ import wandb
|
|||
import torch
|
||||
|
||||
from nanochat.gpt import GPT, GPTConfig
|
||||
from nanochat.dataloader import tokenizing_distributed_data_loader, tokenizing_distributed_data_loader_with_state
|
||||
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir, autodetect_device_type
|
||||
from nanochat.dataloader import tokenizing_distributed_data_loader_bos_bestfit, tokenizing_distributed_data_loader_with_state_bos_bestfit
|
||||
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir, autodetect_device_type, get_peak_flops
|
||||
from nanochat.tokenizer import get_tokenizer, get_token_bytes
|
||||
from nanochat.checkpoint_manager import save_checkpoint, load_checkpoint
|
||||
from nanochat.loss_eval import evaluate_bpb
|
||||
from nanochat.engine import Engine
|
||||
from nanochat.flash_attention import HAS_FA3
|
||||
from scripts.base_eval import evaluate_model
|
||||
print_banner()
|
||||
|
||||
|
|
@ -36,38 +37,40 @@ parser = argparse.ArgumentParser(description="Pretrain base model")
|
|||
# Logging
|
||||
parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)")
|
||||
# Runtime
|
||||
parser.add_argument("--device_type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
|
||||
parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
|
||||
# Model architecture
|
||||
parser.add_argument("--depth", type=int, default=20, help="depth of the Transformer model")
|
||||
parser.add_argument("--aspect_ratio", type=int, default=64, help="model_dim = depth * aspect_ratio")
|
||||
parser.add_argument("--head_dim", type=int, default=128, help="target head dimension for attention")
|
||||
parser.add_argument("--max_seq_len", type=int, default=2048, help="max context length")
|
||||
parser.add_argument("--aspect-ratio", type=int, default=64, help="model_dim = depth * aspect_ratio")
|
||||
parser.add_argument("--head-dim", type=int, default=128, help="target head dimension for attention")
|
||||
parser.add_argument("--max-seq-len", type=int, default=2048, help="max context length")
|
||||
parser.add_argument("--window-pattern", type=str, default="SSSL", help="sliding window pattern tiled across layers: L=full, S=half context (e.g. 'SSL')")
|
||||
# Training horizon (only one used, in order of precedence)
|
||||
parser.add_argument("--num_iterations", type=int, default=-1, help="explicit number of optimization steps (-1 = disable)")
|
||||
parser.add_argument("--target_flops", type=float, default=-1.0, help="calculate num_iterations to reach target_flops (-1 = disable)")
|
||||
parser.add_argument("--target_param_data_ratio", type=int, default=8, help="calculate num_iterations to maintain data:param ratio (Chinchilla=20, -1 = disable)")
|
||||
parser.add_argument("--num-iterations", type=int, default=-1, help="explicit number of optimization steps (-1 = disable)")
|
||||
parser.add_argument("--target-flops", type=float, default=-1.0, help="calculate num_iterations to reach target_flops (-1 = disable)")
|
||||
parser.add_argument("--target-param-data-ratio", type=int, default=4, help="calculate num_iterations to maintain data:param ratio (Chinchilla=20, -1 = disable)")
|
||||
# Optimization
|
||||
parser.add_argument("--device_batch_size", type=int, default=32, help="per-device batch size")
|
||||
parser.add_argument("--total_batch_size", type=int, default=524288, help="total batch size in tokens")
|
||||
parser.add_argument("--embedding_lr", type=float, default=0.3, help="learning rate for embedding parameters (Adam)")
|
||||
parser.add_argument("--unembedding_lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)")
|
||||
parser.add_argument("--weight_decay", type=float, default=0.0, help="weight decay for embedding/unembedding parameters (Adam)")
|
||||
parser.add_argument("--matrix_lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)")
|
||||
parser.add_argument("--adam_beta1", type=float, default=0.8, help="Adam beta1 for embedding/unembedding")
|
||||
parser.add_argument("--adam_beta2", type=float, default=0.95, help="Adam beta2 for embedding/unembedding")
|
||||
parser.add_argument("--warmup_ratio", type=float, default=0.0, help="ratio of iterations for LR warmup")
|
||||
parser.add_argument("--warmdown_ratio", type=float, default=0.4, help="ratio of iterations for LR warmdown")
|
||||
parser.add_argument("--final_lr_frac", type=float, default=0.0, help="final LR as fraction of initial LR")
|
||||
parser.add_argument("--resume_from_step", type=int, default=-1, help="resume training from this step (-1 = disable)")
|
||||
parser.add_argument("--device-batch-size", type=int, default=32, help="per-device batch size")
|
||||
parser.add_argument("--total-batch-size", type=int, default=524288, help="total batch size in tokens")
|
||||
parser.add_argument("--embedding-lr", type=float, default=0.3, help="learning rate for embedding parameters (Adam)")
|
||||
parser.add_argument("--unembedding-lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)")
|
||||
parser.add_argument("--weight-decay", type=float, default=0.2, help="cautious weight decay for the Muon optimizer (for weights)")
|
||||
parser.add_argument("--matrix-lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)")
|
||||
parser.add_argument("--scalar-lr", type=float, default=0.5, help="learning rate for scalars (resid_lambdas, x0_lambdas)")
|
||||
parser.add_argument("--adam-beta1", type=float, default=0.8, help="Adam beta1 for embedding/unembedding")
|
||||
parser.add_argument("--adam-beta2", type=float, default=0.95, help="Adam beta2 for embedding/unembedding")
|
||||
parser.add_argument("--warmup-ratio", type=float, default=0.0, help="ratio of iterations for LR warmup")
|
||||
parser.add_argument("--warmdown-ratio", type=float, default=0.4, help="ratio of iterations for LR warmdown")
|
||||
parser.add_argument("--final-lr-frac", type=float, default=0.0, help="final LR as fraction of initial LR")
|
||||
parser.add_argument("--resume-from-step", type=int, default=-1, help="resume training from this step (-1 = disable)")
|
||||
# Evaluation
|
||||
parser.add_argument("--eval_every", type=int, default=250, help="evaluate val bpb every N steps (-1 = disable)")
|
||||
parser.add_argument("--eval_tokens", type=int, default=20*524288, help="number of tokens to evaluate val loss on")
|
||||
parser.add_argument("--core_metric_every", type=int, default=2000, help="evaluate CORE metric every N steps (-1 = disable)")
|
||||
parser.add_argument("--core_metric_max_per_task", type=int, default=500, help="examples per task for CORE metric")
|
||||
parser.add_argument("--sample_every", type=int, default=2000, help="sample from model every N steps (-1 = disable)")
|
||||
parser.add_argument("--save_every", type=int, default=-1, help="save checkpoints every N steps (-1 = only at end)")
|
||||
parser.add_argument("--eval-every", type=int, default=250, help="evaluate val bpb every N steps (-1 = disable)")
|
||||
parser.add_argument("--eval-tokens", type=int, default=20*524288, help="number of tokens to evaluate val loss on")
|
||||
parser.add_argument("--core-metric-every", type=int, default=2000, help="evaluate CORE metric every N steps (-1 = disable)")
|
||||
parser.add_argument("--core-metric-max-per-task", type=int, default=500, help="examples per task for CORE metric")
|
||||
parser.add_argument("--sample-every", type=int, default=2000, help="sample from model every N steps (-1 = disable)")
|
||||
parser.add_argument("--save-every", type=int, default=-1, help="save checkpoints every N steps (-1 = only at end)")
|
||||
# Output
|
||||
parser.add_argument("--model_tag", type=str, default=None, help="override model tag for checkpoint directory name")
|
||||
parser.add_argument("--model-tag", type=str, default=None, help="override model tag for checkpoint directory name")
|
||||
args = parser.parse_args()
|
||||
user_config = vars(args).copy() # for logging
|
||||
# -----------------------------------------------------------------------------
|
||||
|
|
@ -79,11 +82,29 @@ master_process = ddp_rank == 0 # this process will do logging, checkpointing etc
|
|||
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
|
||||
synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None
|
||||
get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0
|
||||
if device_type == "cuda":
|
||||
gpu_device_name = torch.cuda.get_device_name(0)
|
||||
gpu_peak_flops = get_peak_flops(gpu_device_name)
|
||||
print0(f"GPU: {gpu_device_name} | Peak FLOPS (BF16): {gpu_peak_flops:.2e}")
|
||||
else:
|
||||
gpu_peak_flops = float('inf') # MFU not meaningful for CPU/MPS
|
||||
|
||||
# wandb logging init
|
||||
use_dummy_wandb = args.run == "dummy" or not master_process
|
||||
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat", name=args.run, config=user_config)
|
||||
|
||||
# Flash Attention status
|
||||
if HAS_FA3:
|
||||
print0("✓ Using Flash Attention 3 (Hopper GPU detected), efficient, new and awesome.")
|
||||
else:
|
||||
print0("!" * 80)
|
||||
print0("WARNING: Flash Attention 3 not available, using PyTorch SDPA fallback")
|
||||
print0("WARNING: Training will be less efficient without FA3")
|
||||
if args.window_pattern != "L":
|
||||
print0(f"WARNING: SDPA has no support for sliding window attention (window_pattern='{args.window_pattern}'). Your GPU utilization will be terrible.")
|
||||
print0("WARNING: Recommend using --window-pattern L for full context attention without alternating sliding window patterns.")
|
||||
print0("!" * 80)
|
||||
|
||||
# Tokenizer will be useful for evaluation, also we need the vocab size
|
||||
tokenizer = get_tokenizer()
|
||||
token_bytes = get_token_bytes(device=device)
|
||||
|
|
@ -91,21 +112,19 @@ vocab_size = tokenizer.get_vocab_size()
|
|||
print0(f"Vocab size: {vocab_size:,}")
|
||||
|
||||
# Model kwargs are derived from the desired depth of the model
|
||||
# We nudge model_dim up to the nearest multiple of head_dim to ensure clean division
|
||||
# (FA3 requires head_dim divisible by 8, and this guarantees head_dim == args.head_dim exactly)
|
||||
# (For very small depths, this gives a slight "unfair" advantage to models with odd depths)
|
||||
num_layers = args.depth
|
||||
model_dim = args.depth * args.aspect_ratio
|
||||
def find_num_heads(model_dim, target_head_dim):
|
||||
# Find num_heads that divides model_dim evenly, with head_dim closest to target.
|
||||
ideal = max(1, round(model_dim / target_head_dim))
|
||||
for offset in range(model_dim):
|
||||
for candidate in [ideal + offset, ideal - offset]:
|
||||
if candidate > 0 and model_dim % candidate == 0:
|
||||
return candidate
|
||||
return 1
|
||||
num_heads = find_num_heads(model_dim, args.head_dim)
|
||||
base_dim = args.depth * args.aspect_ratio
|
||||
model_dim = ((base_dim + args.head_dim - 1) // args.head_dim) * args.head_dim
|
||||
num_heads = model_dim // args.head_dim
|
||||
num_kv_heads = num_heads # default is 1:1 GQA (Group Query Attention) ratio (i.e. GQA is disabled)
|
||||
head_dim = model_dim // num_heads
|
||||
print0(f"num_layers: {num_layers}")
|
||||
print0(f"model_dim: {model_dim}")
|
||||
print0(f"model_dim: {model_dim} (base: {base_dim}, nudge: {model_dim - base_dim:+d})")
|
||||
print0(f"num_heads: {num_heads}")
|
||||
print0(f"head_dim: {head_dim}")
|
||||
print0(f"num_kv_heads: {num_kv_heads}")
|
||||
|
||||
# Optimizer / data / training length related hyperparameters
|
||||
|
|
@ -129,11 +148,16 @@ if batch_ratio != 1.0:
|
|||
batch_lr_scale = batch_ratio ** 0.5
|
||||
print0(f"Scaling LRs by {batch_lr_scale:.4f} for batch size {args.total_batch_size:,} (reference: {reference_batch_size:,})")
|
||||
|
||||
# Weight decay is tuned at d12 and its scaling seems to be \propto 1/channels^2 (or equivalently, \propto 1/depth^2 due to constant aspect ratio)
|
||||
weight_decay_scaled = args.weight_decay * (12 / args.depth)**2
|
||||
if args.depth != 12:
|
||||
print0(f"Scaling weight decay from {args.weight_decay:.6f} to {weight_decay_scaled:.6f} for depth {args.depth}")
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Initialize the Model
|
||||
|
||||
# Create a new model with random weights
|
||||
model_config_kwargs = dict(sequence_len=args.max_seq_len, vocab_size=vocab_size, n_layer=num_layers, n_head=num_heads, n_kv_head=num_kv_heads, n_embd=model_dim)
|
||||
model_config_kwargs = dict(sequence_len=args.max_seq_len, vocab_size=vocab_size, n_layer=num_layers, n_head=num_heads, n_kv_head=num_kv_heads, n_embd=model_dim, window_pattern=args.window_pattern)
|
||||
with torch.device("meta"):
|
||||
# All tensors are created as meta tensors (they have shape/dtype but no data)
|
||||
model_config = GPTConfig(**model_config_kwargs)
|
||||
|
|
@ -188,8 +212,9 @@ optimizers = model.setup_optimizers(
|
|||
unembedding_lr=args.unembedding_lr * batch_lr_scale,
|
||||
embedding_lr=args.embedding_lr * batch_lr_scale,
|
||||
matrix_lr=args.matrix_lr * batch_lr_scale,
|
||||
weight_decay=args.weight_decay,
|
||||
weight_decay=weight_decay_scaled,
|
||||
adam_betas=adam_betas,
|
||||
scalar_lr=args.scalar_lr * batch_lr_scale,
|
||||
)
|
||||
adamw_optimizer, muon_optimizer = optimizers
|
||||
|
||||
|
|
@ -200,10 +225,9 @@ if resuming:
|
|||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Initialize the DataLoaders for train/val
|
||||
tokens_dir = os.path.join(base_dir, "tokenized_data")
|
||||
dataloader_resume_state_dict = None if not resuming else meta_data["dataloader_state_dict"]
|
||||
train_loader = tokenizing_distributed_data_loader_with_state(args.device_batch_size, args.max_seq_len, split="train", device=device, resume_state_dict=dataloader_resume_state_dict)
|
||||
build_val_loader = lambda: tokenizing_distributed_data_loader(args.device_batch_size, args.max_seq_len, split="val", device=device)
|
||||
train_loader = tokenizing_distributed_data_loader_with_state_bos_bestfit(tokenizer, args.device_batch_size, args.max_seq_len, split="train", device=device, resume_state_dict=dataloader_resume_state_dict)
|
||||
build_val_loader = lambda: tokenizing_distributed_data_loader_bos_bestfit(tokenizer, args.device_batch_size, args.max_seq_len, split="val", device=device)
|
||||
x, y, dataloader_state_dict = next(train_loader) # kick off load of the very first batch of data
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
|
|
@ -227,6 +251,10 @@ def get_muon_momentum(it):
|
|||
momentum = (1 - frac) * 0.85 + frac * 0.95
|
||||
return momentum
|
||||
|
||||
# Weight decay scheduler for Muon optimizer (linear to zero over the course of training)
|
||||
def get_weight_decay(it):
|
||||
return weight_decay_scaled * (1 - it / num_iterations)
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Loop state (variables updated by the training loop)
|
||||
|
||||
|
|
@ -257,7 +285,7 @@ while True:
|
|||
eval_steps = args.eval_tokens // (args.device_batch_size * args.max_seq_len * ddp_world_size)
|
||||
with autocast_ctx:
|
||||
val_bpb = evaluate_bpb(model, val_loader, eval_steps, token_bytes)
|
||||
print0(f"Step {step:05d} | Validation bpb: {val_bpb:.4f}")
|
||||
print0(f"Step {step:05d} | Validation bpb: {val_bpb:.6f}")
|
||||
if val_bpb < min_val_bpb:
|
||||
min_val_bpb = val_bpb
|
||||
wandb_run.log({
|
||||
|
|
@ -351,25 +379,27 @@ while True:
|
|||
for group in opt.param_groups:
|
||||
group["lr"] = group["initial_lr"] * lrm
|
||||
muon_momentum = get_muon_momentum(step)
|
||||
muon_weight_decay = get_weight_decay(step)
|
||||
for group in muon_optimizer.param_groups:
|
||||
group["momentum"] = muon_momentum
|
||||
group["weight_decay"] = muon_weight_decay
|
||||
for opt in optimizers:
|
||||
opt.step()
|
||||
model.zero_grad(set_to_none=True)
|
||||
train_loss_f = train_loss.item() # .item() is a CPU-GPU sync point
|
||||
synchronize()
|
||||
t1 = time.time()
|
||||
dt = t1 - t0
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
# logging
|
||||
# logging (CPU action only)
|
||||
ema_beta = 0.9 # EMA decay factor for some smoothing just for nicer logging
|
||||
smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss.item() # EMA the training loss
|
||||
smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss_f # EMA the training loss
|
||||
debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**(step + 1)) # debias the EMA
|
||||
pct_done = 100 * step / num_iterations
|
||||
tok_per_sec = int(args.total_batch_size / dt)
|
||||
flops_per_sec = num_flops_per_token * args.total_batch_size / dt
|
||||
promised_flops_per_sec_h100 = 989e12 * ddp_world_size # bfloat16 H100 SXM and without 2:4 sparsity
|
||||
mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in %
|
||||
mfu = 100 * flops_per_sec / (gpu_peak_flops * ddp_world_size)
|
||||
if step > 10:
|
||||
total_training_time += dt # only count the time after the first 10 steps
|
||||
# Calculate ETA based on average time per step (excluding first 10 steps)
|
||||
|
|
@ -381,7 +411,8 @@ while True:
|
|||
eta_str = f" | eta: {eta_seconds/60:.1f}m"
|
||||
else:
|
||||
eta_str = ""
|
||||
print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time/60:.2f}m{eta_str}")
|
||||
epoch = dataloader_state_dict["epoch"]
|
||||
print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | epoch: {epoch} | total time: {total_training_time/60:.2f}m{eta_str}")
|
||||
if step % 100 == 0:
|
||||
log_data = {
|
||||
"step": step,
|
||||
|
|
@ -392,6 +423,7 @@ while True:
|
|||
"train/dt": dt,
|
||||
"train/tok_per_sec": tok_per_sec,
|
||||
"train/mfu": mfu,
|
||||
"train/epoch": epoch,
|
||||
}
|
||||
wandb_run.log(log_data)
|
||||
|
||||
|
|
@ -402,7 +434,7 @@ while True:
|
|||
print0(f"Peak memory usage: {get_max_memory() / 1024 / 1024:.2f}MiB")
|
||||
print0(f"Total training time: {total_training_time/60:.2f}m")
|
||||
if val_bpb is not None:
|
||||
print0(f"Minimum validation bpb: {min_val_bpb:.4f}")
|
||||
print0(f"Minimum validation bpb: {min_val_bpb:.6f}")
|
||||
|
||||
# Log to report
|
||||
from nanochat.report import get_report
|
||||
|
|
|
|||
|
|
@ -4,8 +4,8 @@ All the generic code lives here, and all the evaluation-specific
|
|||
code lives in nanochat directory and is imported from here.
|
||||
|
||||
Example runs:
|
||||
python -m scripts.chat_eval -a ARC-Easy
|
||||
torchrun --nproc_per_node=8 -m scripts.chat_eval -- -a ARC-Easy
|
||||
python -m scripts.chat_eval -i mid -a ARC-Easy
|
||||
torchrun --nproc_per_node=8 -m scripts.chat_eval -- -i mid -a ARC-Easy
|
||||
"""
|
||||
|
||||
import argparse
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ simpler and more similar to just REINFORCE:
|
|||
|
||||
1) Delete trust region, so there is no KL regularization to a reference model
|
||||
2) We are on policy, so there's no need for PPO ratio+clip.
|
||||
3) We use GAPO style normalization that is token-level, not sequence-level.
|
||||
3) We use DAPO style normalization that is token-level, not sequence-level.
|
||||
4) Instead of z-score normalization (r - mu)/sigma, only use (r - mu) as the advantage.
|
||||
|
||||
1 GPU:
|
||||
|
|
@ -35,32 +35,32 @@ parser = argparse.ArgumentParser(description="Reinforcement learning on GSM8K")
|
|||
# Logging
|
||||
parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)")
|
||||
# Runtime
|
||||
parser.add_argument("--device_type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
|
||||
parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
|
||||
parser.add_argument("--dtype", type=str, default="bfloat16", help="float32|bfloat16")
|
||||
# Model loading
|
||||
parser.add_argument("--source", type=str, default="sft", help="mid|sft - which checkpoint to load from")
|
||||
parser.add_argument("--model_tag", type=str, default=None, help="model tag to load from")
|
||||
parser.add_argument("--model_step", type=int, default=None, help="model step to load from")
|
||||
parser.add_argument("--model-tag", type=str, default=None, help="model tag to load from")
|
||||
parser.add_argument("--model-step", type=int, default=None, help="model step to load from")
|
||||
# Training horizon
|
||||
parser.add_argument("--num_epochs", type=int, default=1, help="number of epochs over GSM8K")
|
||||
parser.add_argument("--num-epochs", type=int, default=1, help="number of epochs over GSM8K")
|
||||
# Batch sizes / sampling
|
||||
parser.add_argument("--device_batch_size", type=int, default=8, help="max batch size per forward pass")
|
||||
parser.add_argument("--examples_per_step", type=int, default=16, help="total examples per optimization step across all ranks")
|
||||
parser.add_argument("--num_samples", type=int, default=16, help="number of samples per example/question")
|
||||
parser.add_argument("--device-batch-size", type=int, default=8, help="max batch size per forward pass")
|
||||
parser.add_argument("--examples-per-step", type=int, default=16, help="total examples per optimization step across all ranks")
|
||||
parser.add_argument("--num-samples", type=int, default=16, help="number of samples per example/question")
|
||||
# Generation
|
||||
parser.add_argument("--max_new_tokens", type=int, default=256, help="max tokens to generate per sample")
|
||||
parser.add_argument("--max-new-tokens", type=int, default=256, help="max tokens to generate per sample")
|
||||
parser.add_argument("--temperature", type=float, default=1.0, help="sampling temperature")
|
||||
parser.add_argument("--top_k", type=int, default=50, help="top-k sampling (0 = disabled)")
|
||||
parser.add_argument("--top-k", type=int, default=50, help="top-k sampling (0 = disabled)")
|
||||
# Optimization
|
||||
parser.add_argument("--embedding_lr", type=float, default=0.2, help="learning rate for embedding parameters (Adam)")
|
||||
parser.add_argument("--unembedding_lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)")
|
||||
parser.add_argument("--matrix_lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)")
|
||||
parser.add_argument("--weight_decay", type=float, default=0.0, help="weight decay for embedding/unembedding parameters (Adam)")
|
||||
parser.add_argument("--init_lr_frac", type=float, default=0.05, help="initial LR as fraction of base LR")
|
||||
parser.add_argument("--embedding-lr", type=float, default=0.2, help="learning rate for embedding parameters (Adam)")
|
||||
parser.add_argument("--unembedding-lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)")
|
||||
parser.add_argument("--matrix-lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)")
|
||||
parser.add_argument("--weight-decay", type=float, default=0.0, help="weight decay for embedding/unembedding parameters (Adam)")
|
||||
parser.add_argument("--init-lr-frac", type=float, default=0.05, help="initial LR as fraction of base LR")
|
||||
# Evaluation / checkpointing
|
||||
parser.add_argument("--eval_every", type=int, default=60, help="evaluate pass@k every N steps")
|
||||
parser.add_argument("--eval_examples", type=int, default=400, help="number of examples for pass@k evaluation")
|
||||
parser.add_argument("--save_every", type=int, default=60, help="save checkpoint every N steps")
|
||||
parser.add_argument("--eval-every", type=int, default=60, help="evaluate pass@k every N steps")
|
||||
parser.add_argument("--eval-examples", type=int, default=400, help="number of examples for pass@k evaluation")
|
||||
parser.add_argument("--save-every", type=int, default=60, help="save checkpoint every N steps")
|
||||
args = parser.parse_args()
|
||||
user_config = vars(args).copy()
|
||||
# -----------------------------------------------------------------------------
|
||||
|
|
|
|||
|
|
@ -37,29 +37,29 @@ parser = argparse.ArgumentParser(description="Supervised finetuning for chat")
|
|||
# Logging
|
||||
parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)")
|
||||
# Runtime
|
||||
parser.add_argument("--device_type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
|
||||
parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
|
||||
parser.add_argument("--dtype", type=str, default="bfloat16", help="float32|bfloat16")
|
||||
# Model loading
|
||||
parser.add_argument("--source", type=str, default="mid", help="base|mid - which checkpoint to load from")
|
||||
parser.add_argument("--model_tag", type=str, default=None, help="model tag to load from")
|
||||
parser.add_argument("--model_step", type=int, default=None, help="model step to load from")
|
||||
parser.add_argument("--model-tag", type=str, default=None, help="model tag to load from")
|
||||
parser.add_argument("--model-step", type=int, default=None, help="model step to load from")
|
||||
# Training horizon
|
||||
parser.add_argument("--num_epochs", type=int, default=1, help="number of epochs")
|
||||
parser.add_argument("--num_iterations", type=int, default=-1, help="override number of iterations (-1 = use num_epochs)")
|
||||
parser.add_argument("--num-epochs", type=int, default=1, help="number of epochs")
|
||||
parser.add_argument("--num-iterations", type=int, default=-1, help="override number of iterations (-1 = use num_epochs)")
|
||||
# Batch sizes
|
||||
parser.add_argument("--device_batch_size", type=int, default=4, help="per-device batch size")
|
||||
parser.add_argument("--target_examples_per_step", type=int, default=32, help="target examples per optimization step")
|
||||
parser.add_argument("--device-batch-size", type=int, default=4, help="per-device batch size")
|
||||
parser.add_argument("--target-examples-per-step", type=int, default=32, help="target examples per optimization step")
|
||||
# Optimization
|
||||
parser.add_argument("--embedding_lr", type=float, default=0.2, help="learning rate for embedding parameters (Adam)")
|
||||
parser.add_argument("--unembedding_lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)")
|
||||
parser.add_argument("--matrix_lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)")
|
||||
parser.add_argument("--weight_decay", type=float, default=0.0, help="weight decay for embedding/unembedding parameters (Adam)")
|
||||
parser.add_argument("--init_lr_frac", type=float, default=0.02, help="initial LR as fraction of base LR")
|
||||
parser.add_argument("--embedding-lr", type=float, default=0.2, help="learning rate for embedding parameters (Adam)")
|
||||
parser.add_argument("--unembedding-lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)")
|
||||
parser.add_argument("--matrix-lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)")
|
||||
parser.add_argument("--weight-decay", type=float, default=0.0, help="weight decay for embedding/unembedding parameters (Adam)")
|
||||
parser.add_argument("--init-lr-frac", type=float, default=0.02, help="initial LR as fraction of base LR")
|
||||
# Evaluation
|
||||
parser.add_argument("--eval_every", type=int, default=100, help="evaluate val loss every N steps")
|
||||
parser.add_argument("--eval_steps", type=int, default=100, help="number of batches for val loss evaluation")
|
||||
parser.add_argument("--eval_metrics_every", type=int, default=200, help="evaluate accuracy metrics every N steps")
|
||||
parser.add_argument("--eval_metrics_max_problems", type=int, default=1024, help="max problems per metric evaluation")
|
||||
parser.add_argument("--eval-every", type=int, default=100, help="evaluate val loss every N steps")
|
||||
parser.add_argument("--eval-steps", type=int, default=100, help="number of batches for val loss evaluation")
|
||||
parser.add_argument("--eval-metrics-every", type=int, default=200, help="evaluate accuracy metrics every N steps")
|
||||
parser.add_argument("--eval-metrics-max-problems", type=int, default=1024, help="max problems per metric evaluation")
|
||||
args = parser.parse_args()
|
||||
user_config = vars(args).copy()
|
||||
# -----------------------------------------------------------------------------
|
||||
|
|
|
|||
|
|
@ -6,11 +6,10 @@ python -m scripts.mid_train
|
|||
|
||||
Or torchrun for training:
|
||||
|
||||
torchrun --standalone --nproc_per_node=8 -m scripts.mid_train -- --device_batch_size=16
|
||||
torchrun --standalone --nproc_per_node=8 -m scripts.mid_train -- --device-batch-size=16
|
||||
"""
|
||||
|
||||
import argparse
|
||||
from collections import deque
|
||||
import os
|
||||
os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"
|
||||
import time
|
||||
|
|
@ -37,28 +36,28 @@ parser = argparse.ArgumentParser(description="Midtrain the model")
|
|||
# Logging
|
||||
parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)")
|
||||
# Runtime
|
||||
parser.add_argument("--device_type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
|
||||
parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
|
||||
parser.add_argument("--dtype", type=str, default="bfloat16", help="float32|bfloat16")
|
||||
# Model loading
|
||||
parser.add_argument("--model_tag", type=str, default=None, help="model tag to load from")
|
||||
parser.add_argument("--model_step", type=int, default=None, help="model step to load from")
|
||||
parser.add_argument("--model-tag", type=str, default=None, help="model tag to load from")
|
||||
parser.add_argument("--model-step", type=int, default=None, help="model step to load from")
|
||||
# Training horizon
|
||||
parser.add_argument("--num_iterations", type=int, default=-1, help="number of optimization steps (-1 = full epoch)")
|
||||
parser.add_argument("--num-iterations", type=int, default=-1, help="number of optimization steps (-1 = full epoch)")
|
||||
# Batch sizes
|
||||
parser.add_argument("--max_seq_len", type=int, default=2048, help="max context length")
|
||||
parser.add_argument("--device_batch_size", type=int, default=32, help="per-device batch size")
|
||||
parser.add_argument("--total_batch_size", type=int, default=524288, help="total batch size in tokens")
|
||||
parser.add_argument("--max-seq-len", type=int, default=2048, help="max context length")
|
||||
parser.add_argument("--device-batch-size", type=int, default=32, help="per-device batch size")
|
||||
parser.add_argument("--total-batch-size", type=int, default=524288, help="total batch size in tokens")
|
||||
# Optimization
|
||||
parser.add_argument("--embedding_lr", type=float, default=0.2, help="learning rate for embedding parameters (Adam)")
|
||||
parser.add_argument("--unembedding_lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)")
|
||||
parser.add_argument("--matrix_lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)")
|
||||
parser.add_argument("--weight_decay", type=float, default=0.0, help="weight decay for embedding/unembedding parameters (Adam)")
|
||||
parser.add_argument("--init_lr_frac", type=float, default=1.0, help="initial LR as fraction of base LR")
|
||||
parser.add_argument("--embedding-lr", type=float, default=0.2, help="learning rate for embedding parameters (Adam)")
|
||||
parser.add_argument("--unembedding-lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)")
|
||||
parser.add_argument("--matrix-lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)")
|
||||
parser.add_argument("--weight-decay", type=float, default=0.0, help="weight decay for embedding/unembedding parameters (Adam)")
|
||||
parser.add_argument("--init-lr-frac", type=float, default=1.0, help="initial LR as fraction of base LR")
|
||||
# Evaluation
|
||||
parser.add_argument("--eval_every", type=int, default=150, help="evaluate val bpb every N steps (-1 = disable)")
|
||||
parser.add_argument("--eval_tokens", type=int, default=20*524288, help="number of tokens to evaluate val loss on")
|
||||
parser.add_argument("--eval-every", type=int, default=150, help="evaluate val bpb every N steps (-1 = disable)")
|
||||
parser.add_argument("--eval-tokens", type=int, default=20*524288, help="number of tokens to evaluate val loss on")
|
||||
# Output
|
||||
parser.add_argument("--dry_run", action="store_true", help="log to wandb but skip checkpoints/report")
|
||||
parser.add_argument("--dry-run", action="store_true", help="log to wandb but skip checkpoints/report")
|
||||
args = parser.parse_args()
|
||||
user_config = vars(args).copy()
|
||||
# -----------------------------------------------------------------------------
|
||||
|
|
@ -80,7 +79,7 @@ wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-mi
|
|||
model, tokenizer, meta = load_model("base", device, phase="train", model_tag=args.model_tag, step=args.model_step)
|
||||
pretrain_batch_size = meta.get("device_batch_size", None)
|
||||
if pretrain_batch_size is not None and args.device_batch_size > pretrain_batch_size:
|
||||
print0(f"FOOTGUN WARNING: base model training used device_batch_size {pretrain_batch_size}, did you pass in a good --device_batch_size to this script?")
|
||||
print0(f"FOOTGUN WARNING: base model training used device_batch_size {pretrain_batch_size}, did you pass in a good --device-batch-size to this script?")
|
||||
orig_model = model
|
||||
model = torch.compile(model, dynamic=False)
|
||||
depth = model.config.n_layer
|
||||
|
|
@ -125,49 +124,100 @@ val_dataset = TaskMixture([
|
|||
# these two global variables and update them from within the data generator.
|
||||
last_step = False # we will toggle this to True when we reach the end of the training dataset
|
||||
approx_progress = 0.0 # will go from 0 to 1 over the course of the epoch
|
||||
def mid_data_generator(split):
|
||||
global last_step, approx_progress
|
||||
current_epoch = 1 # track epoch for logging
|
||||
def mid_data_generator_bos_bestfit(split, buffer_size=100):
|
||||
"""
|
||||
BOS-aligned dataloader for midtraining with bestfit-crop packing.
|
||||
|
||||
Each row in the batch starts with BOS (beginning of a conversation).
|
||||
Conversations are packed using best-fit algorithm to minimize cropping.
|
||||
This matches the BOS-aligned approach used in pretraining.
|
||||
"""
|
||||
global last_step, approx_progress, current_epoch
|
||||
assert split in {"train", "val"}, "split must be 'train' or 'val'"
|
||||
dataset = train_dataset if split == "train" else val_dataset
|
||||
dataset_size = len(dataset)
|
||||
assert dataset_size > 0
|
||||
needed_tokens = args.device_batch_size * args.max_seq_len + 1 # to form one training batch of inputs,targets
|
||||
token_buffer = deque()
|
||||
# CUDA supports memory pinning for faster transfers between CPU and GPU:
|
||||
scratch = torch.empty(needed_tokens, dtype=torch.int64, pin_memory=(device_type == "cuda"))
|
||||
cursor = ddp_rank # increments by ddp_world_size each time, so each rank processes unique documents
|
||||
it = 0 # iteration counter
|
||||
while True:
|
||||
# Accumulate enough tokens for one iteration before yielding
|
||||
while len(token_buffer) < needed_tokens:
|
||||
row_capacity = args.max_seq_len + 1 # +1 for target at last position
|
||||
|
||||
# Conversation buffer: list of token lists
|
||||
conv_buffer = []
|
||||
cursor = ddp_rank # Each rank processes different conversations (for fetching)
|
||||
consumed = ddp_rank # Track actual consumption separately from buffering
|
||||
epoch = 1
|
||||
it = 0 # iteration counter
|
||||
|
||||
def refill_buffer():
|
||||
nonlocal cursor, epoch
|
||||
while len(conv_buffer) < buffer_size:
|
||||
conversation = dataset[cursor]
|
||||
ids, _ = tokenizer.render_conversation(conversation)
|
||||
token_buffer.extend(ids)
|
||||
conv_buffer.append(ids)
|
||||
cursor += ddp_world_size
|
||||
if cursor >= dataset_size:
|
||||
cursor -= dataset_size # wrap around for another epoch
|
||||
if split == "train":
|
||||
last_step = True # toggle last_step to True, which will terminate the training loop
|
||||
cursor = cursor % dataset_size
|
||||
epoch += 1
|
||||
# Note: last_step is now triggered based on consumption, not fetching
|
||||
|
||||
while True:
|
||||
rows = []
|
||||
for _ in range(args.device_batch_size):
|
||||
row = []
|
||||
while len(row) < row_capacity:
|
||||
# Ensure buffer has conversations
|
||||
while len(conv_buffer) < buffer_size:
|
||||
refill_buffer()
|
||||
|
||||
remaining = row_capacity - len(row)
|
||||
|
||||
# Find largest conversation that fits entirely
|
||||
best_idx = -1
|
||||
best_len = 0
|
||||
for i, conv in enumerate(conv_buffer):
|
||||
conv_len = len(conv)
|
||||
if conv_len <= remaining and conv_len > best_len:
|
||||
best_idx = i
|
||||
best_len = conv_len
|
||||
|
||||
if best_idx >= 0:
|
||||
# Found a conversation that fits - use it entirely
|
||||
conv = conv_buffer.pop(best_idx)
|
||||
row.extend(conv)
|
||||
consumed += ddp_world_size # Track actual consumption
|
||||
else:
|
||||
# No conversation fits - crop first conversation to fill remaining
|
||||
conv = conv_buffer.pop(0)
|
||||
row.extend(conv[:remaining])
|
||||
consumed += ddp_world_size # Track actual consumption
|
||||
|
||||
rows.append(row[:row_capacity])
|
||||
|
||||
# Stopping condition to respect num_iterations, if given
|
||||
it += 1
|
||||
if 0 < args.num_iterations <= it and split == "train":
|
||||
last_step = True # toggle last_step to True, which will terminate the training loop
|
||||
# Build up inputs/targets and yield
|
||||
for i in range(needed_tokens):
|
||||
scratch[i] = token_buffer.popleft()
|
||||
inputs_cpu = scratch[:-1].to(dtype=torch.int32)
|
||||
targets_cpu = scratch[1:]
|
||||
inputs = inputs_cpu.view(args.device_batch_size, args.max_seq_len).to(device=device, dtype=torch.int32, non_blocking=True)
|
||||
targets = targets_cpu.view(args.device_batch_size, args.max_seq_len).to(device=device, dtype=torch.int64, non_blocking=True)
|
||||
last_step = True
|
||||
|
||||
# Update progress tracking (based on consumed, not cursor, to account for buffering)
|
||||
if split == "train":
|
||||
current_epoch = epoch
|
||||
if args.num_iterations > 0:
|
||||
approx_progress = it / args.num_iterations # calculate progress from the max number of iterations
|
||||
approx_progress = it / args.num_iterations
|
||||
else:
|
||||
approx_progress = cursor / dataset_size # approximate progress as a fraction of the dataset
|
||||
approx_progress = consumed / dataset_size
|
||||
# Trigger last_step when we've consumed enough (instead of when cursor wraps)
|
||||
if consumed >= dataset_size:
|
||||
last_step = True
|
||||
|
||||
# Build tensors
|
||||
use_cuda = device_type == "cuda"
|
||||
batch_tensor = torch.tensor(rows, dtype=torch.long, pin_memory=use_cuda)
|
||||
inputs = batch_tensor[:, :-1].to(device=device, dtype=torch.int32, non_blocking=use_cuda)
|
||||
targets = batch_tensor[:, 1:].to(device=device, dtype=torch.int64, non_blocking=use_cuda)
|
||||
|
||||
yield inputs, targets
|
||||
|
||||
train_loader = mid_data_generator("train")
|
||||
build_val_loader = lambda: mid_data_generator("val")
|
||||
train_loader = mid_data_generator_bos_bestfit("train")
|
||||
build_val_loader = lambda: mid_data_generator_bos_bestfit("val")
|
||||
progress = 0 # will go from 0 to 1 over the course of the epoch
|
||||
|
||||
# Learning rate scheduler
|
||||
|
|
@ -199,7 +249,7 @@ while True:
|
|||
last_step = bool(last_step_tensor.item())
|
||||
|
||||
# once in a while: evaluate the val bpb (all ranks participate)
|
||||
if args.eval_every > 0 and (last_step or step % args.eval_every == 0):
|
||||
if last_step or (args.eval_every > 0 and step % args.eval_every == 0):
|
||||
model.eval()
|
||||
val_loader = build_val_loader()
|
||||
eval_steps = args.eval_tokens // (args.device_batch_size * args.max_seq_len * ddp_world_size)
|
||||
|
|
@ -285,7 +335,7 @@ while True:
|
|||
mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in %
|
||||
if step > 10:
|
||||
total_training_time += dt # only count the time after the first 10 steps
|
||||
print0(f"step {step:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time/60:.2f}m")
|
||||
print0(f"step {step:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | epoch: {current_epoch} | total time: {total_training_time/60:.2f}m")
|
||||
if step % 10 == 0:
|
||||
wandb_run.log({
|
||||
"step": step,
|
||||
|
|
@ -296,6 +346,7 @@ while True:
|
|||
"train/dt": dt,
|
||||
"train/tok_per_sec": tok_per_sec,
|
||||
"train/mfu": mfu,
|
||||
"train/epoch": current_epoch,
|
||||
})
|
||||
|
||||
# print a few more stats
|
||||
|
|
|
|||
|
|
@ -14,9 +14,9 @@ from nanochat.dataset import parquets_iter_batched
|
|||
# Parse command line arguments
|
||||
|
||||
parser = argparse.ArgumentParser(description='Train a BPE tokenizer')
|
||||
parser.add_argument('--max_chars', type=int, default=10_000_000_000, help='Maximum characters to train on (default: 10B)')
|
||||
parser.add_argument('--doc_cap', type=int, default=10_000, help='Maximum characters per document (default: 10,000)')
|
||||
parser.add_argument('--vocab_size', type=int, default=32768, help='Vocabulary size (default: 32768 = 2^15)')
|
||||
parser.add_argument('--max-chars', type=int, default=10_000_000_000, help='Maximum characters to train on (default: 10B)')
|
||||
parser.add_argument('--doc-cap', type=int, default=10_000, help='Maximum characters per document (default: 10,000)')
|
||||
parser.add_argument('--vocab-size', type=int, default=32768, help='Vocabulary size (default: 32768 = 2^15)')
|
||||
args = parser.parse_args()
|
||||
print(f"max_chars: {args.max_chars:,}")
|
||||
print(f"doc_cap: {args.doc_cap:,}")
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@ class CustomJSON(Task):
|
|||
print("-" * 80)
|
||||
print(f"Warning: File {filepath} does not exist")
|
||||
print("HINT (Oct 21 2025)")
|
||||
print("If you recently did a git pull and suddely see this, it might be due to the new addition of identity conversations")
|
||||
print("If you recently did a git pull and suddenly see this, it might be due to the new addition of identity conversations")
|
||||
print("See this discussion for more details: https://github.com/karpathy/nanochat/discussions/139")
|
||||
print("Quick fix: simply run the following command to download the file and you're done:")
|
||||
print(f"curl -L -o {filepath} https://karpathy-public.s3.us-west-2.amazonaws.com/identity_conversations.jsonl")
|
||||
|
|
|
|||
338
tests/test_attention_fallback.py
Normal file
338
tests/test_attention_fallback.py
Normal file
|
|
@ -0,0 +1,338 @@
|
|||
"""
|
||||
Test Flash Attention unified interface - verify FA3 and SDPA produce identical results.
|
||||
|
||||
Run: python -m pytest tests/test_attention_fallback.py -v -s
|
||||
|
||||
Note on test structure:
|
||||
Tests are split into two classes due to dtype/device constraints:
|
||||
|
||||
1. TestFA3VsSDPA: Comparison tests that run both FA3 and SDPA on the same inputs
|
||||
and verify they produce identical results. These require a Hopper GPU (FA3 only
|
||||
works on sm90+) and use bfloat16 (FA3 doesn't support float32).
|
||||
|
||||
2. TestSDPAOnly: Tests that only exercise the SDPA fallback path. These can run
|
||||
on any device (CUDA, CPU, MPS) with the appropriate dtype for that device.
|
||||
"""
|
||||
import torch
|
||||
import pytest
|
||||
import nanochat.flash_attention as fa_module
|
||||
from nanochat.flash_attention import flash_attn, HAS_FA3
|
||||
from nanochat.engine import KVCache
|
||||
|
||||
|
||||
def set_impl(impl):
|
||||
"""Set the implementation override ('fa3', 'sdpa', or None for auto)."""
|
||||
fa_module._override_impl = impl
|
||||
|
||||
|
||||
def run_both_impls(fn):
|
||||
"""Run a function with both FA3 and SDPA, return both outputs."""
|
||||
set_impl('fa3')
|
||||
out_fa3 = fn()
|
||||
set_impl('sdpa')
|
||||
out_sdpa = fn()
|
||||
set_impl(None) # reset
|
||||
return out_fa3, out_sdpa
|
||||
|
||||
|
||||
def assert_close(t1, t2, name, atol=1e-2, rtol=1e-2):
|
||||
"""Assert two tensors are close, with helpful error message."""
|
||||
max_diff = (t1 - t2).abs().max().item()
|
||||
mean_diff = (t1 - t2).abs().mean().item()
|
||||
assert torch.allclose(t1, t2, atol=atol, rtol=rtol), \
|
||||
f"{name}: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}"
|
||||
return max_diff, mean_diff
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# FA3 vs SDPA comparison tests (require Hopper GPU)
|
||||
# =============================================================================
|
||||
@pytest.mark.skipif(not HAS_FA3, reason="FA3 required to compare implementations")
|
||||
class TestFA3VsSDPA:
|
||||
"""Compare FA3 and SDPA produce identical results. Requires Hopper GPU."""
|
||||
|
||||
DEVICE = "cuda"
|
||||
DTYPE = torch.bfloat16
|
||||
|
||||
def test_basic_causal(self):
|
||||
"""Basic causal attention."""
|
||||
B, T, H, D = 2, 64, 4, 32
|
||||
q = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
k = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
v = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
|
||||
def run():
|
||||
return flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(T, 0))
|
||||
|
||||
y_fa3, y_sdpa = run_both_impls(run)
|
||||
max_diff, mean_diff = assert_close(y_fa3, y_sdpa, "basic_causal")
|
||||
print(f"basic_causal: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
|
||||
|
||||
def test_full_context(self):
|
||||
"""Full context (window_size=-1)."""
|
||||
B, T, H, D = 2, 128, 4, 32
|
||||
q = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
k = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
v = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
|
||||
def run():
|
||||
return flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(-1, -1))
|
||||
|
||||
y_fa3, y_sdpa = run_both_impls(run)
|
||||
max_diff, mean_diff = assert_close(y_fa3, y_sdpa, "full_context")
|
||||
print(f"full_context: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
|
||||
|
||||
def test_sliding_window(self):
|
||||
"""Sliding window attention."""
|
||||
B, T, H, D = 2, 128, 4, 32
|
||||
window = 32
|
||||
q = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
k = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
v = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
|
||||
def run():
|
||||
return flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(window, 0))
|
||||
|
||||
y_fa3, y_sdpa = run_both_impls(run)
|
||||
max_diff, mean_diff = assert_close(y_fa3, y_sdpa, "sliding_window")
|
||||
print(f"sliding_window: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
|
||||
|
||||
def test_gqa(self):
|
||||
"""Group Query Attention (fewer KV heads than Q heads)."""
|
||||
B, T, D = 2, 64, 32
|
||||
n_heads = 8
|
||||
n_kv_heads = 2
|
||||
|
||||
q = torch.randn(B, T, n_heads, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
k = torch.randn(B, T, n_kv_heads, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
v = torch.randn(B, T, n_kv_heads, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
|
||||
def run():
|
||||
return flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(T, 0))
|
||||
|
||||
y_fa3, y_sdpa = run_both_impls(run)
|
||||
max_diff, mean_diff = assert_close(y_fa3, y_sdpa, "gqa")
|
||||
print(f"gqa: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
|
||||
|
||||
def test_larger_model(self):
|
||||
"""Larger dimensions closer to real model."""
|
||||
B, T, H, D = 4, 256, 12, 64
|
||||
q = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
k = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
v = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
|
||||
def run():
|
||||
return flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(-1, -1))
|
||||
|
||||
y_fa3, y_sdpa = run_both_impls(run)
|
||||
max_diff, mean_diff = assert_close(y_fa3, y_sdpa, "larger_model")
|
||||
print(f"larger_model: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
|
||||
|
||||
def test_kvcache_prefill(self):
|
||||
"""Test prefill (inserting multiple tokens into empty cache)."""
|
||||
B, T_max, H, D = 2, 64, 4, 32
|
||||
T_prefill = 16
|
||||
|
||||
q = torch.randn(B, T_prefill, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
k = torch.randn(B, T_prefill, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
v = torch.randn(B, T_prefill, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
|
||||
def run():
|
||||
k_cache = torch.zeros(B, T_max, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
v_cache = torch.zeros(B, T_max, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
cache_seqlens = torch.zeros(B, dtype=torch.int32, device=self.DEVICE)
|
||||
return flash_attn.flash_attn_with_kvcache(
|
||||
q, k_cache, v_cache, k=k, v=v,
|
||||
cache_seqlens=cache_seqlens,
|
||||
causal=True, window_size=(T_max, 0)
|
||||
)
|
||||
|
||||
y_fa3, y_sdpa = run_both_impls(run)
|
||||
max_diff, mean_diff = assert_close(y_fa3, y_sdpa, "prefill")
|
||||
print(f"prefill: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
|
||||
|
||||
def test_kvcache_single_token(self):
|
||||
"""Test single token generation (cache already has content)."""
|
||||
B, T_max, H, D = 2, 64, 4, 32
|
||||
T_prefill = 16
|
||||
|
||||
k_init = torch.randn(B, T_prefill, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
v_init = torch.randn(B, T_prefill, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
q_single = torch.randn(B, 1, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
k_single = torch.randn(B, 1, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
v_single = torch.randn(B, 1, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
|
||||
def run():
|
||||
k_cache = torch.zeros(B, T_max, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
v_cache = torch.zeros(B, T_max, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
k_cache[:, :T_prefill, :, :] = k_init
|
||||
v_cache[:, :T_prefill, :, :] = v_init
|
||||
cache_seqlens = torch.full((B,), T_prefill, dtype=torch.int32, device=self.DEVICE)
|
||||
return flash_attn.flash_attn_with_kvcache(
|
||||
q_single, k_cache, v_cache, k=k_single, v=v_single,
|
||||
cache_seqlens=cache_seqlens,
|
||||
causal=True, window_size=(T_max, 0)
|
||||
)
|
||||
|
||||
y_fa3, y_sdpa = run_both_impls(run)
|
||||
max_diff, mean_diff = assert_close(y_fa3, y_sdpa, "single_token")
|
||||
print(f"single_token: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
|
||||
|
||||
def test_backward_gradients_match(self):
|
||||
"""Verify gradients are similar between FA3 and SDPA."""
|
||||
B, T, H, D = 2, 32, 4, 16
|
||||
|
||||
q_data = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
k_data = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
v_data = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
|
||||
def run():
|
||||
q = q_data.clone().requires_grad_(True)
|
||||
k = k_data.clone().requires_grad_(True)
|
||||
v = v_data.clone().requires_grad_(True)
|
||||
y = flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(T, 0))
|
||||
loss = y.sum()
|
||||
loss.backward()
|
||||
return y.detach(), q.grad.detach(), k.grad.detach(), v.grad.detach()
|
||||
|
||||
set_impl('fa3')
|
||||
y_fa3, q_grad_fa3, k_grad_fa3, v_grad_fa3 = run()
|
||||
set_impl('sdpa')
|
||||
y_sdpa, q_grad_sdpa, k_grad_sdpa, v_grad_sdpa = run()
|
||||
set_impl(None)
|
||||
|
||||
max_diff, mean_diff = assert_close(y_fa3, y_sdpa, "backward_output")
|
||||
print(f"backward_output: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
|
||||
|
||||
max_diff, mean_diff = assert_close(q_grad_fa3, q_grad_sdpa, "q_grad", atol=0.05, rtol=0.05)
|
||||
print(f"q_grad: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
|
||||
|
||||
max_diff, mean_diff = assert_close(k_grad_fa3, k_grad_sdpa, "k_grad", atol=0.05, rtol=0.05)
|
||||
print(f"k_grad: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
|
||||
|
||||
max_diff, mean_diff = assert_close(v_grad_fa3, v_grad_sdpa, "v_grad", atol=0.05, rtol=0.05)
|
||||
print(f"v_grad: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# SDPA-only tests (run on any device)
|
||||
# =============================================================================
|
||||
class TestSDPAOnly:
|
||||
"""Test SDPA fallback works correctly. Runs on any device."""
|
||||
|
||||
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
DTYPE = torch.bfloat16 if torch.cuda.is_available() else torch.float32
|
||||
|
||||
def test_basic_forward(self):
|
||||
"""Test SDPA forward pass produces valid output."""
|
||||
set_impl('sdpa')
|
||||
B, T, H, D = 2, 64, 4, 32
|
||||
q = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
k = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
v = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
|
||||
y = flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(T, 0))
|
||||
|
||||
assert y.shape == (B, T, H, D)
|
||||
assert not torch.isnan(y).any(), "Output contains NaN"
|
||||
set_impl(None)
|
||||
|
||||
def test_backward(self):
|
||||
"""Test gradients flow through SDPA."""
|
||||
set_impl('sdpa')
|
||||
B, T, H, D = 2, 32, 4, 16
|
||||
q = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE, requires_grad=True)
|
||||
k = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE, requires_grad=True)
|
||||
v = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE, requires_grad=True)
|
||||
|
||||
y = flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(T, 0))
|
||||
loss = y.sum()
|
||||
loss.backward()
|
||||
|
||||
assert q.grad is not None, "No gradient for q"
|
||||
assert k.grad is not None, "No gradient for k"
|
||||
assert v.grad is not None, "No gradient for v"
|
||||
assert not torch.isnan(q.grad).any(), "NaN in q gradient"
|
||||
set_impl(None)
|
||||
|
||||
def test_kvcache(self):
|
||||
"""Test SDPA with KV cache."""
|
||||
set_impl('sdpa')
|
||||
B, T_max, H, D = 2, 64, 4, 32
|
||||
n_layers = 1
|
||||
|
||||
cache = KVCache(
|
||||
batch_size=B, num_heads=H, seq_len=T_max, head_dim=D,
|
||||
num_layers=n_layers, device=self.DEVICE, dtype=self.DTYPE
|
||||
)
|
||||
k_cache, v_cache = cache.get_layer_cache(0)
|
||||
|
||||
# Prefill
|
||||
T_prefill = 16
|
||||
q = torch.randn(B, T_prefill, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
k = torch.randn(B, T_prefill, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
v = torch.randn(B, T_prefill, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
|
||||
y = flash_attn.flash_attn_with_kvcache(
|
||||
q, k_cache, v_cache, k=k, v=v,
|
||||
cache_seqlens=cache.cache_seqlens,
|
||||
causal=True, window_size=(T_max, 0)
|
||||
)
|
||||
cache.advance(T_prefill)
|
||||
|
||||
assert y.shape == (B, T_prefill, H, D)
|
||||
assert cache.get_pos() == T_prefill
|
||||
|
||||
# Generate single token
|
||||
q_single = torch.randn(B, 1, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
k_single = torch.randn(B, 1, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
v_single = torch.randn(B, 1, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
|
||||
y_single = flash_attn.flash_attn_with_kvcache(
|
||||
q_single, k_cache, v_cache, k=k_single, v=v_single,
|
||||
cache_seqlens=cache.cache_seqlens,
|
||||
causal=True, window_size=(T_max, 0)
|
||||
)
|
||||
cache.advance(1)
|
||||
|
||||
assert y_single.shape == (B, 1, H, D)
|
||||
assert cache.get_pos() == T_prefill + 1
|
||||
set_impl(None)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Override mechanism tests
|
||||
# =============================================================================
|
||||
class TestOverrideMechanism:
|
||||
"""Test that the override mechanism works correctly."""
|
||||
|
||||
@pytest.mark.skipif(not HAS_FA3, reason="FA3 required")
|
||||
def test_override_fa3(self):
|
||||
"""Test that override='fa3' uses FA3."""
|
||||
set_impl('fa3')
|
||||
assert fa_module._use_fa3() == True
|
||||
set_impl(None)
|
||||
|
||||
def test_override_sdpa(self):
|
||||
"""Test that override='sdpa' uses SDPA."""
|
||||
set_impl('sdpa')
|
||||
assert fa_module._use_fa3() == False
|
||||
set_impl(None)
|
||||
|
||||
def test_override_auto(self):
|
||||
"""Test that override=None uses auto-detection."""
|
||||
set_impl(None)
|
||||
assert fa_module._use_fa3() == HAS_FA3
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print(f"PyTorch version: {torch.__version__}")
|
||||
print(f"CUDA available: {torch.cuda.is_available()}")
|
||||
if torch.cuda.is_available():
|
||||
print(f"CUDA device: {torch.cuda.get_device_name()}")
|
||||
major, minor = torch.cuda.get_device_capability()
|
||||
print(f"Compute capability: {major}.{minor}")
|
||||
print(f"HAS_FA3: {HAS_FA3}")
|
||||
print()
|
||||
|
||||
pytest.main([__file__, "-v", "-s"])
|
||||
|
|
@ -39,13 +39,9 @@ class MockModel:
|
|||
def forward(self, ids, kv_cache=None):
|
||||
"""Return uniform logits so sampling is spread across vocab."""
|
||||
B, T = ids.shape
|
||||
# Simulate what a real transformer does: insert k,v into the cache for each layer
|
||||
# With FA3, flash_attn_with_kvcache updates cache in-place and we advance position
|
||||
if kv_cache is not None:
|
||||
head_dim = self.config.n_embd // self.config.n_head
|
||||
for layer_idx in range(self.config.n_layer):
|
||||
k = torch.zeros(B, self.config.n_kv_head, T, head_dim)
|
||||
v = torch.zeros(B, self.config.n_kv_head, T, head_dim)
|
||||
kv_cache.insert_kv(layer_idx, k, v)
|
||||
kv_cache.advance(T)
|
||||
# Uniform logits -> equal probability for all tokens
|
||||
logits = torch.zeros(B, T, self.vocab_size)
|
||||
return logits
|
||||
|
|
@ -85,16 +81,11 @@ class ByteTokenizer:
|
|||
byte_tokens = [t for t in tokens if t < 256]
|
||||
return bytes(byte_tokens).decode("utf-8", errors="replace")
|
||||
|
||||
def test_kv_cache_resize():
|
||||
"""
|
||||
The KV cache was not resized correctly, more information here:
|
||||
https://github.com/karpathy/nanochat/pull/186
|
||||
This test reproduces the issue and will be merged alongside the fix.
|
||||
"""
|
||||
|
||||
def test_kv_cache_basic():
|
||||
"""Test basic KVCache functionality for FA3."""
|
||||
batch_size = 2
|
||||
num_heads = 3
|
||||
seq_len = 4
|
||||
seq_len = 64
|
||||
head_dim = 5
|
||||
num_layers = 6
|
||||
|
||||
|
|
@ -103,45 +94,65 @@ def test_kv_cache_resize():
|
|||
num_heads=num_heads,
|
||||
seq_len=seq_len,
|
||||
head_dim=head_dim,
|
||||
num_layers=num_layers
|
||||
num_layers=num_layers,
|
||||
device="cpu",
|
||||
dtype=torch.float32,
|
||||
)
|
||||
|
||||
# Insert a single token with a distinct fill value to all layers
|
||||
def insert_token(token_idx):
|
||||
for layer_idx in range(num_layers):
|
||||
k = torch.full((batch_size, num_heads, 1, head_dim), fill_value=float(token_idx), dtype=torch.float32)
|
||||
v = torch.full((batch_size, num_heads, 1, head_dim), fill_value=float(token_idx * 100), dtype=torch.float32)
|
||||
kv_cache.insert_kv(layer_idx, k, v)
|
||||
# Check initial state
|
||||
assert kv_cache.get_pos() == 0
|
||||
assert kv_cache.k_cache.shape == (num_layers, batch_size, seq_len, num_heads, head_dim)
|
||||
assert kv_cache.v_cache.shape == (num_layers, batch_size, seq_len, num_heads, head_dim)
|
||||
|
||||
# Insert 4 tokens (fills the initial seq_len=4)
|
||||
for i in range(4):
|
||||
insert_token(i)
|
||||
# Test advance
|
||||
kv_cache.advance(10)
|
||||
assert kv_cache.get_pos() == 10
|
||||
|
||||
# Record the original state of the cache
|
||||
original_cache = kv_cache.kv_cache.clone()
|
||||
original_seq_len = original_cache.shape[4]
|
||||
kv_cache.advance(5)
|
||||
assert kv_cache.get_pos() == 15
|
||||
|
||||
# Insert the 5th token, which will trigger a resize
|
||||
insert_token(4)
|
||||
# Verify that the cache actually resized
|
||||
new_seq_len = kv_cache.kv_cache.shape[4]
|
||||
assert new_seq_len > original_seq_len, f"Cache did not resize: original seq_len={original_seq_len}, new seq_len={new_seq_len}"
|
||||
# Test reset
|
||||
kv_cache.reset()
|
||||
assert kv_cache.get_pos() == 0
|
||||
|
||||
# Verify that the original 4 tokens are still intact after resize
|
||||
for layer_idx in range(num_layers):
|
||||
for token_idx in range(4):
|
||||
# Check that resized cache matches expected values
|
||||
expected_k = float(token_idx)
|
||||
expected_v = float(token_idx * 100)
|
||||
actual_k = kv_cache.kv_cache[layer_idx, 0, :, :, token_idx, :]
|
||||
actual_v = kv_cache.kv_cache[layer_idx, 1, :, :, token_idx, :]
|
||||
assert (actual_k == expected_k).all(), f"Layer {layer_idx}, token {token_idx}: key corrupted, expected {expected_k}"
|
||||
assert (actual_v == expected_v).all(), f"Layer {layer_idx}, token {token_idx}: value corrupted, expected {expected_v}"
|
||||
# And that the original cache matches resized cache
|
||||
original_k = original_cache[layer_idx, 0, :, :, token_idx, :]
|
||||
original_v = original_cache[layer_idx, 1, :, :, token_idx, :]
|
||||
assert (actual_k == original_k).all(), f"Layer {layer_idx}, token {token_idx}: key doesn't match original"
|
||||
assert (actual_v == original_v).all(), f"Layer {layer_idx}, token {token_idx}: value doesn't match original"
|
||||
# Test get_layer_cache returns correct views
|
||||
k_layer0, v_layer0 = kv_cache.get_layer_cache(0)
|
||||
assert k_layer0.shape == (batch_size, seq_len, num_heads, head_dim)
|
||||
assert v_layer0.shape == (batch_size, seq_len, num_heads, head_dim)
|
||||
|
||||
|
||||
def test_kv_cache_prefill():
|
||||
"""Test KVCache.prefill() copies data correctly."""
|
||||
batch_size = 1
|
||||
num_heads = 4
|
||||
head_dim = 8
|
||||
num_layers = 2
|
||||
|
||||
# Create source cache and advance it
|
||||
src_cache = KVCache(
|
||||
batch_size=batch_size, num_heads=num_heads, seq_len=32,
|
||||
head_dim=head_dim, num_layers=num_layers, device="cpu", dtype=torch.float32,
|
||||
)
|
||||
# Write some data to source cache
|
||||
src_cache.k_cache[0, 0, :16, :, :] = 1.0
|
||||
src_cache.v_cache[0, 0, :16, :, :] = 2.0
|
||||
src_cache.advance(16)
|
||||
|
||||
# Create destination cache with larger seq_len
|
||||
dst_cache = KVCache(
|
||||
batch_size=batch_size, num_heads=num_heads, seq_len=64,
|
||||
head_dim=head_dim, num_layers=num_layers, device="cpu", dtype=torch.float32,
|
||||
)
|
||||
|
||||
# Prefill
|
||||
dst_cache.prefill(src_cache)
|
||||
|
||||
# Check position was copied
|
||||
assert dst_cache.get_pos() == 16
|
||||
|
||||
# Check data was copied
|
||||
assert (dst_cache.k_cache[0, 0, :16, :, :] == 1.0).all()
|
||||
assert (dst_cache.v_cache[0, 0, :16, :, :] == 2.0).all()
|
||||
|
||||
|
||||
def test_multi_sample_first_token_diversity():
|
||||
|
|
@ -185,3 +196,72 @@ def test_multi_sample_first_token_diversity():
|
|||
f"With uniform logits, this is statistically impossible (~10^-36 probability) "
|
||||
f"unless tokens are being broadcast instead of independently sampled."
|
||||
)
|
||||
|
||||
|
||||
def test_seed_reproducibility():
|
||||
"""Same seed must produce identical output."""
|
||||
model = MockModel()
|
||||
engine = Engine(model, ByteTokenizer())
|
||||
prompt = [261, 72, 101, 108, 108, 111] # <bos> + "Hello"
|
||||
|
||||
for seed in [1, 42, 123, 999]:
|
||||
r1, _ = engine.generate_batch(prompt, max_tokens=5, seed=seed)
|
||||
r2, _ = engine.generate_batch(prompt, max_tokens=5, seed=seed)
|
||||
r3, _ = engine.generate_batch(prompt, max_tokens=5, seed=seed)
|
||||
assert r1 == r2 == r3, "Same seed must produce identical output for the same prompt."
|
||||
|
||||
|
||||
def test_temperature_zero_determinism():
|
||||
"""Temperature=0 is deterministic regardless of seed."""
|
||||
model = MockModel()
|
||||
engine = Engine(model, ByteTokenizer())
|
||||
prompt = [261, 72, 101, 108, 108, 111]
|
||||
|
||||
r1, _ = engine.generate_batch(prompt, temperature=0.0, max_tokens=5, seed=1)
|
||||
r2, _ = engine.generate_batch(prompt, temperature=0.0, max_tokens=5, seed=42)
|
||||
r3, _ = engine.generate_batch(prompt, temperature=0.0, max_tokens=5, seed=123)
|
||||
assert r1 == r2 == r3, "Temperature=0 must result in the same output for the same prompt regardless of seed."
|
||||
|
||||
|
||||
def test_max_tokens_respected():
|
||||
"""Generation stops at max_tokens limit."""
|
||||
model = MockModel()
|
||||
engine = Engine(model, ByteTokenizer())
|
||||
prompt = [261, 72, 101, 108, 108, 111]
|
||||
|
||||
for max_tokens in [1, 4, 16, 64]:
|
||||
results, _ = engine.generate_batch(prompt, max_tokens=max_tokens)
|
||||
num_generated_tokens = len(results[0]) - len(prompt)
|
||||
assert num_generated_tokens <= max_tokens, f"Generated {num_generated_tokens} tokens, expected max_tokens={max_tokens} or less."
|
||||
|
||||
|
||||
def test_num_samples_count():
|
||||
"""num_samples=N produces exactly N sequences."""
|
||||
model = MockModel()
|
||||
engine = Engine(model, ByteTokenizer())
|
||||
prompt = [261, 72, 101, 108, 108, 111]
|
||||
|
||||
for num_samples in [1, 4, 16, 64]:
|
||||
results, _ = engine.generate_batch(prompt, num_samples=num_samples, max_tokens=3)
|
||||
assert len(results) == num_samples, f"Expected {num_samples} sequences from {num_samples} samples, got {len(results)}"
|
||||
|
||||
|
||||
def test_different_seeds_introduce_variation_when_temperature_nonzero():
|
||||
"""With temperature > 0, different seeds should introduce sampling variation."""
|
||||
model = MockModel()
|
||||
engine = Engine(model, ByteTokenizer())
|
||||
prompt = [261, 72, 101, 108, 108, 111] # <bos> + "Hello"
|
||||
|
||||
outputs = set()
|
||||
|
||||
for seed in [1, 42, 123, 999, 1000, 1001, 1002, 1003, 1004, 1005]:
|
||||
results, _ = engine.generate_batch(
|
||||
prompt,
|
||||
temperature=1.0,
|
||||
max_tokens=5,
|
||||
seed=seed,
|
||||
)
|
||||
outputs.add(tuple(results[0]))
|
||||
|
||||
# Sanity check: sampling actually introduces variation
|
||||
assert len(outputs) > 1, "All seeds produced the same output which is statistically highly improbable."
|
||||
|
|
|
|||
109
uv.lock
109
uv.lock
|
|
@ -1089,6 +1089,21 @@ wheels = [
|
|||
{ url = "https://files.pythonhosted.org/packages/e7/e7/80988e32bf6f73919a113473a604f5a8f09094de312b9d52b79c2df7612b/jupyter_core-5.9.1-py3-none-any.whl", hash = "sha256:ebf87fdc6073d142e114c72c9e29a9d7ca03fad818c5d300ce2adc1fb0743407", size = 29032, upload-time = "2025-10-16T19:19:16.783Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "kernels"
|
||||
version = "0.11.7"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "huggingface-hub" },
|
||||
{ name = "packaging" },
|
||||
{ name = "pyyaml" },
|
||||
{ name = "tomli", marker = "python_full_version < '3.11' or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/d6/c8/2d4fea16366d34069af6d4c4f61218f55e5d0daea5d4c24d58849e9fd626/kernels-0.11.7.tar.gz", hash = "sha256:99c3aa518965518902f4dc26053d6051f06abc904ae33d9486c28674a2ea0fa5", size = 50282, upload-time = "2026-01-08T15:41:57.383Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/ab/49/e62183353374ec71306ef354781233ac8d12fdfd1cf3d47c875055a99603/kernels-0.11.7-py3-none-any.whl", hash = "sha256:1421791b1e501fcb0a7f0a4d763c5385591756d9d6ed12ed8baa1e0d71bcd21a", size = 46501, upload-time = "2026-01-08T15:41:55.784Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "kiwisolver"
|
||||
version = "1.4.9"
|
||||
|
|
@ -1478,6 +1493,7 @@ dependencies = [
|
|||
{ name = "datasets" },
|
||||
{ name = "fastapi" },
|
||||
{ name = "ipykernel" },
|
||||
{ name = "kernels" },
|
||||
{ name = "matplotlib" },
|
||||
{ name = "psutil" },
|
||||
{ name = "python-dotenv" },
|
||||
|
|
@ -1497,6 +1513,7 @@ dependencies = [
|
|||
{ name = "transformers" },
|
||||
{ name = "uvicorn" },
|
||||
{ name = "wandb" },
|
||||
{ name = "zstandard" },
|
||||
]
|
||||
|
||||
[package.optional-dependencies]
|
||||
|
|
@ -1518,6 +1535,7 @@ requires-dist = [
|
|||
{ name = "datasets", specifier = ">=4.0.0" },
|
||||
{ name = "fastapi", specifier = ">=0.117.1" },
|
||||
{ name = "ipykernel", specifier = ">=7.1.0" },
|
||||
{ name = "kernels", specifier = ">=0.11.7" },
|
||||
{ name = "matplotlib", specifier = ">=3.10.8" },
|
||||
{ name = "psutil", specifier = ">=7.1.0" },
|
||||
{ name = "python-dotenv", specifier = ">=1.2.1" },
|
||||
|
|
@ -1534,6 +1552,7 @@ requires-dist = [
|
|||
{ name = "transformers", specifier = ">=4.57.3" },
|
||||
{ name = "uvicorn", specifier = ">=0.36.0" },
|
||||
{ name = "wandb", specifier = ">=0.21.3" },
|
||||
{ name = "zstandard", specifier = ">=0.25.0" },
|
||||
]
|
||||
provides-extras = ["cpu", "gpu"]
|
||||
|
||||
|
|
@ -3602,3 +3621,93 @@ wheels = [
|
|||
{ url = "https://files.pythonhosted.org/packages/94/c3/b2e9f38bc3e11191981d57ea08cab2166e74ea770024a646617c9cddd9f6/yarl-1.20.1-cp313-cp313t-win_amd64.whl", hash = "sha256:541d050a355bbbc27e55d906bc91cb6fe42f96c01413dd0f4ed5a5240513874f", size = 93003, upload-time = "2025-06-10T00:45:27.752Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/b4/2d/2345fce04cfd4bee161bf1e7d9cdc702e3e16109021035dbb24db654a622/yarl-1.20.1-py3-none-any.whl", hash = "sha256:83b8eb083fe4683c6115795d9fc1cfaf2cbbefb19b3a1cb68f6527460f483a77", size = 46542, upload-time = "2025-06-10T00:46:07.521Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "zstandard"
|
||||
version = "0.25.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/fd/aa/3e0508d5a5dd96529cdc5a97011299056e14c6505b678fd58938792794b1/zstandard-0.25.0.tar.gz", hash = "sha256:7713e1179d162cf5c7906da876ec2ccb9c3a9dcbdffef0cc7f70c3667a205f0b", size = 711513, upload-time = "2025-09-14T22:15:54.002Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/56/7a/28efd1d371f1acd037ac64ed1c5e2b41514a6cc937dd6ab6a13ab9f0702f/zstandard-0.25.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e59fdc271772f6686e01e1b3b74537259800f57e24280be3f29c8a0deb1904dd", size = 795256, upload-time = "2025-09-14T22:15:56.415Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/96/34/ef34ef77f1ee38fc8e4f9775217a613b452916e633c4f1d98f31db52c4a5/zstandard-0.25.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:4d441506e9b372386a5271c64125f72d5df6d2a8e8a2a45a0ae09b03cb781ef7", size = 640565, upload-time = "2025-09-14T22:15:58.177Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/9d/1b/4fdb2c12eb58f31f28c4d28e8dc36611dd7205df8452e63f52fb6261d13e/zstandard-0.25.0-cp310-cp310-manylinux2010_i686.manylinux2014_i686.manylinux_2_12_i686.manylinux_2_17_i686.whl", hash = "sha256:ab85470ab54c2cb96e176f40342d9ed41e58ca5733be6a893b730e7af9c40550", size = 5345306, upload-time = "2025-09-14T22:16:00.165Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/73/28/a44bdece01bca027b079f0e00be3b6bd89a4df180071da59a3dd7381665b/zstandard-0.25.0-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:e05ab82ea7753354bb054b92e2f288afb750e6b439ff6ca78af52939ebbc476d", size = 5055561, upload-time = "2025-09-14T22:16:02.22Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/e9/74/68341185a4f32b274e0fc3410d5ad0750497e1acc20bd0f5b5f64ce17785/zstandard-0.25.0-cp310-cp310-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:78228d8a6a1c177a96b94f7e2e8d012c55f9c760761980da16ae7546a15a8e9b", size = 5402214, upload-time = "2025-09-14T22:16:04.109Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/8b/67/f92e64e748fd6aaffe01e2b75a083c0c4fd27abe1c8747fee4555fcee7dd/zstandard-0.25.0-cp310-cp310-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:2b6bd67528ee8b5c5f10255735abc21aa106931f0dbaf297c7be0c886353c3d0", size = 5449703, upload-time = "2025-09-14T22:16:06.312Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/fd/e5/6d36f92a197c3c17729a2125e29c169f460538a7d939a27eaaa6dcfcba8e/zstandard-0.25.0-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:4b6d83057e713ff235a12e73916b6d356e3084fd3d14ced499d84240f3eecee0", size = 5556583, upload-time = "2025-09-14T22:16:08.457Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/d7/83/41939e60d8d7ebfe2b747be022d0806953799140a702b90ffe214d557638/zstandard-0.25.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:9174f4ed06f790a6869b41cba05b43eeb9a35f8993c4422ab853b705e8112bbd", size = 5045332, upload-time = "2025-09-14T22:16:10.444Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/b3/87/d3ee185e3d1aa0133399893697ae91f221fda79deb61adbe998a7235c43f/zstandard-0.25.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:25f8f3cd45087d089aef5ba3848cd9efe3ad41163d3400862fb42f81a3a46701", size = 5572283, upload-time = "2025-09-14T22:16:12.128Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/0a/1d/58635ae6104df96671076ac7d4ae7816838ce7debd94aecf83e30b7121b0/zstandard-0.25.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:3756b3e9da9b83da1796f8809dd57cb024f838b9eeafde28f3cb472012797ac1", size = 4959754, upload-time = "2025-09-14T22:16:14.225Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/75/d6/57e9cb0a9983e9a229dd8fd2e6e96593ef2aa82a3907188436f22b111ccd/zstandard-0.25.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:81dad8d145d8fd981b2962b686b2241d3a1ea07733e76a2f15435dfb7fb60150", size = 5266477, upload-time = "2025-09-14T22:16:16.343Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/d1/a9/ee891e5edf33a6ebce0a028726f0bbd8567effe20fe3d5808c42323e8542/zstandard-0.25.0-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:a5a419712cf88862a45a23def0ae063686db3d324cec7edbe40509d1a79a0aab", size = 5440914, upload-time = "2025-09-14T22:16:18.453Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/58/08/a8522c28c08031a9521f27abc6f78dbdee7312a7463dd2cfc658b813323b/zstandard-0.25.0-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:e7360eae90809efd19b886e59a09dad07da4ca9ba096752e61a2e03c8aca188e", size = 5819847, upload-time = "2025-09-14T22:16:20.559Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/6f/11/4c91411805c3f7b6f31c60e78ce347ca48f6f16d552fc659af6ec3b73202/zstandard-0.25.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:75ffc32a569fb049499e63ce68c743155477610532da1eb38e7f24bf7cd29e74", size = 5363131, upload-time = "2025-09-14T22:16:22.206Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ef/d6/8c4bd38a3b24c4c7676a7a3d8de85d6ee7a983602a734b9f9cdefb04a5d6/zstandard-0.25.0-cp310-cp310-win32.whl", hash = "sha256:106281ae350e494f4ac8a80470e66d1fe27e497052c8d9c3b95dc4cf1ade81aa", size = 436469, upload-time = "2025-09-14T22:16:25.002Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/93/90/96d50ad417a8ace5f841b3228e93d1bb13e6ad356737f42e2dde30d8bd68/zstandard-0.25.0-cp310-cp310-win_amd64.whl", hash = "sha256:ea9d54cc3d8064260114a0bbf3479fc4a98b21dffc89b3459edd506b69262f6e", size = 506100, upload-time = "2025-09-14T22:16:23.569Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/2a/83/c3ca27c363d104980f1c9cee1101cc8ba724ac8c28a033ede6aab89585b1/zstandard-0.25.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:933b65d7680ea337180733cf9e87293cc5500cc0eb3fc8769f4d3c88d724ec5c", size = 795254, upload-time = "2025-09-14T22:16:26.137Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ac/4d/e66465c5411a7cf4866aeadc7d108081d8ceba9bc7abe6b14aa21c671ec3/zstandard-0.25.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:a3f79487c687b1fc69f19e487cd949bf3aae653d181dfb5fde3bf6d18894706f", size = 640559, upload-time = "2025-09-14T22:16:27.973Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/12/56/354fe655905f290d3b147b33fe946b0f27e791e4b50a5f004c802cb3eb7b/zstandard-0.25.0-cp311-cp311-manylinux2010_i686.manylinux2014_i686.manylinux_2_12_i686.manylinux_2_17_i686.whl", hash = "sha256:0bbc9a0c65ce0eea3c34a691e3c4b6889f5f3909ba4822ab385fab9057099431", size = 5348020, upload-time = "2025-09-14T22:16:29.523Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/3b/13/2b7ed68bd85e69a2069bcc72141d378f22cae5a0f3b353a2c8f50ef30c1b/zstandard-0.25.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:01582723b3ccd6939ab7b3a78622c573799d5d8737b534b86d0e06ac18dbde4a", size = 5058126, upload-time = "2025-09-14T22:16:31.811Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/c9/dd/fdaf0674f4b10d92cb120ccff58bbb6626bf8368f00ebfd2a41ba4a0dc99/zstandard-0.25.0-cp311-cp311-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:5f1ad7bf88535edcf30038f6919abe087f606f62c00a87d7e33e7fc57cb69fcc", size = 5405390, upload-time = "2025-09-14T22:16:33.486Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/0f/67/354d1555575bc2490435f90d67ca4dd65238ff2f119f30f72d5cde09c2ad/zstandard-0.25.0-cp311-cp311-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:06acb75eebeedb77b69048031282737717a63e71e4ae3f77cc0c3b9508320df6", size = 5452914, upload-time = "2025-09-14T22:16:35.277Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/bb/1f/e9cfd801a3f9190bf3e759c422bbfd2247db9d7f3d54a56ecde70137791a/zstandard-0.25.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:9300d02ea7c6506f00e627e287e0492a5eb0371ec1670ae852fefffa6164b072", size = 5559635, upload-time = "2025-09-14T22:16:37.141Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/21/88/5ba550f797ca953a52d708c8e4f380959e7e3280af029e38fbf47b55916e/zstandard-0.25.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:bfd06b1c5584b657a2892a6014c2f4c20e0db0208c159148fa78c65f7e0b0277", size = 5048277, upload-time = "2025-09-14T22:16:38.807Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/46/c0/ca3e533b4fa03112facbe7fbe7779cb1ebec215688e5df576fe5429172e0/zstandard-0.25.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:f373da2c1757bb7f1acaf09369cdc1d51d84131e50d5fa9863982fd626466313", size = 5574377, upload-time = "2025-09-14T22:16:40.523Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/12/9b/3fb626390113f272abd0799fd677ea33d5fc3ec185e62e6be534493c4b60/zstandard-0.25.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:6c0e5a65158a7946e7a7affa6418878ef97ab66636f13353b8502d7ea03c8097", size = 4961493, upload-time = "2025-09-14T22:16:43.3Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/cb/d3/23094a6b6a4b1343b27ae68249daa17ae0651fcfec9ed4de09d14b940285/zstandard-0.25.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:c8e167d5adf59476fa3e37bee730890e389410c354771a62e3c076c86f9f7778", size = 5269018, upload-time = "2025-09-14T22:16:45.292Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/8c/a7/bb5a0c1c0f3f4b5e9d5b55198e39de91e04ba7c205cc46fcb0f95f0383c1/zstandard-0.25.0-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:98750a309eb2f020da61e727de7d7ba3c57c97cf6213f6f6277bb7fb42a8e065", size = 5443672, upload-time = "2025-09-14T22:16:47.076Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/27/22/503347aa08d073993f25109c36c8d9f029c7d5949198050962cb568dfa5e/zstandard-0.25.0-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:22a086cff1b6ceca18a8dd6096ec631e430e93a8e70a9ca5efa7561a00f826fa", size = 5822753, upload-time = "2025-09-14T22:16:49.316Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/e2/be/94267dc6ee64f0f8ba2b2ae7c7a2df934a816baaa7291db9e1aa77394c3c/zstandard-0.25.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:72d35d7aa0bba323965da807a462b0966c91608ef3a48ba761678cb20ce5d8b7", size = 5366047, upload-time = "2025-09-14T22:16:51.328Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/7b/a3/732893eab0a3a7aecff8b99052fecf9f605cf0fb5fb6d0290e36beee47a4/zstandard-0.25.0-cp311-cp311-win32.whl", hash = "sha256:f5aeea11ded7320a84dcdd62a3d95b5186834224a9e55b92ccae35d21a8b63d4", size = 436484, upload-time = "2025-09-14T22:16:55.005Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/43/a3/c6155f5c1cce691cb80dfd38627046e50af3ee9ddc5d0b45b9b063bfb8c9/zstandard-0.25.0-cp311-cp311-win_amd64.whl", hash = "sha256:daab68faadb847063d0c56f361a289c4f268706b598afbf9ad113cbe5c38b6b2", size = 506183, upload-time = "2025-09-14T22:16:52.753Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/8c/3e/8945ab86a0820cc0e0cdbf38086a92868a9172020fdab8a03ac19662b0e5/zstandard-0.25.0-cp311-cp311-win_arm64.whl", hash = "sha256:22a06c5df3751bb7dc67406f5374734ccee8ed37fc5981bf1ad7041831fa1137", size = 462533, upload-time = "2025-09-14T22:16:53.878Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/82/fc/f26eb6ef91ae723a03e16eddb198abcfce2bc5a42e224d44cc8b6765e57e/zstandard-0.25.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:7b3c3a3ab9daa3eed242d6ecceead93aebbb8f5f84318d82cee643e019c4b73b", size = 795738, upload-time = "2025-09-14T22:16:56.237Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/aa/1c/d920d64b22f8dd028a8b90e2d756e431a5d86194caa78e3819c7bf53b4b3/zstandard-0.25.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:913cbd31a400febff93b564a23e17c3ed2d56c064006f54efec210d586171c00", size = 640436, upload-time = "2025-09-14T22:16:57.774Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/53/6c/288c3f0bd9fcfe9ca41e2c2fbfd17b2097f6af57b62a81161941f09afa76/zstandard-0.25.0-cp312-cp312-manylinux2010_i686.manylinux2014_i686.manylinux_2_12_i686.manylinux_2_17_i686.whl", hash = "sha256:011d388c76b11a0c165374ce660ce2c8efa8e5d87f34996aa80f9c0816698b64", size = 5343019, upload-time = "2025-09-14T22:16:59.302Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/1e/15/efef5a2f204a64bdb5571e6161d49f7ef0fffdbca953a615efbec045f60f/zstandard-0.25.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:6dffecc361d079bb48d7caef5d673c88c8988d3d33fb74ab95b7ee6da42652ea", size = 5063012, upload-time = "2025-09-14T22:17:01.156Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/b7/37/a6ce629ffdb43959e92e87ebdaeebb5ac81c944b6a75c9c47e300f85abdf/zstandard-0.25.0-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:7149623bba7fdf7e7f24312953bcf73cae103db8cae49f8154dd1eadc8a29ecb", size = 5394148, upload-time = "2025-09-14T22:17:03.091Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/e3/79/2bf870b3abeb5c070fe2d670a5a8d1057a8270f125ef7676d29ea900f496/zstandard-0.25.0-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:6a573a35693e03cf1d67799fd01b50ff578515a8aeadd4595d2a7fa9f3ec002a", size = 5451652, upload-time = "2025-09-14T22:17:04.979Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/53/60/7be26e610767316c028a2cbedb9a3beabdbe33e2182c373f71a1c0b88f36/zstandard-0.25.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:5a56ba0db2d244117ed744dfa8f6f5b366e14148e00de44723413b2f3938a902", size = 5546993, upload-time = "2025-09-14T22:17:06.781Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/85/c7/3483ad9ff0662623f3648479b0380d2de5510abf00990468c286c6b04017/zstandard-0.25.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:10ef2a79ab8e2974e2075fb984e5b9806c64134810fac21576f0668e7ea19f8f", size = 5046806, upload-time = "2025-09-14T22:17:08.415Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/08/b3/206883dd25b8d1591a1caa44b54c2aad84badccf2f1de9e2d60a446f9a25/zstandard-0.25.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:aaf21ba8fb76d102b696781bddaa0954b782536446083ae3fdaa6f16b25a1c4b", size = 5576659, upload-time = "2025-09-14T22:17:10.164Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/9d/31/76c0779101453e6c117b0ff22565865c54f48f8bd807df2b00c2c404b8e0/zstandard-0.25.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:1869da9571d5e94a85a5e8d57e4e8807b175c9e4a6294e3b66fa4efb074d90f6", size = 4953933, upload-time = "2025-09-14T22:17:11.857Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/18/e1/97680c664a1bf9a247a280a053d98e251424af51f1b196c6d52f117c9720/zstandard-0.25.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:809c5bcb2c67cd0ed81e9229d227d4ca28f82d0f778fc5fea624a9def3963f91", size = 5268008, upload-time = "2025-09-14T22:17:13.627Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/1e/73/316e4010de585ac798e154e88fd81bb16afc5c5cb1a72eeb16dd37e8024a/zstandard-0.25.0-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:f27662e4f7dbf9f9c12391cb37b4c4c3cb90ffbd3b1fb9284dadbbb8935fa708", size = 5433517, upload-time = "2025-09-14T22:17:16.103Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/5b/60/dd0f8cfa8129c5a0ce3ea6b7f70be5b33d2618013a161e1ff26c2b39787c/zstandard-0.25.0-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:99c0c846e6e61718715a3c9437ccc625de26593fea60189567f0118dc9db7512", size = 5814292, upload-time = "2025-09-14T22:17:17.827Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/fc/5f/75aafd4b9d11b5407b641b8e41a57864097663699f23e9ad4dbb91dc6bfe/zstandard-0.25.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:474d2596a2dbc241a556e965fb76002c1ce655445e4e3bf38e5477d413165ffa", size = 5360237, upload-time = "2025-09-14T22:17:19.954Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ff/8d/0309daffea4fcac7981021dbf21cdb2e3427a9e76bafbcdbdf5392ff99a4/zstandard-0.25.0-cp312-cp312-win32.whl", hash = "sha256:23ebc8f17a03133b4426bcc04aabd68f8236eb78c3760f12783385171b0fd8bd", size = 436922, upload-time = "2025-09-14T22:17:24.398Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/79/3b/fa54d9015f945330510cb5d0b0501e8253c127cca7ebe8ba46a965df18c5/zstandard-0.25.0-cp312-cp312-win_amd64.whl", hash = "sha256:ffef5a74088f1e09947aecf91011136665152e0b4b359c42be3373897fb39b01", size = 506276, upload-time = "2025-09-14T22:17:21.429Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ea/6b/8b51697e5319b1f9ac71087b0af9a40d8a6288ff8025c36486e0c12abcc4/zstandard-0.25.0-cp312-cp312-win_arm64.whl", hash = "sha256:181eb40e0b6a29b3cd2849f825e0fa34397f649170673d385f3598ae17cca2e9", size = 462679, upload-time = "2025-09-14T22:17:23.147Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/35/0b/8df9c4ad06af91d39e94fa96cc010a24ac4ef1378d3efab9223cc8593d40/zstandard-0.25.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:ec996f12524f88e151c339688c3897194821d7f03081ab35d31d1e12ec975e94", size = 795735, upload-time = "2025-09-14T22:17:26.042Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/3f/06/9ae96a3e5dcfd119377ba33d4c42a7d89da1efabd5cb3e366b156c45ff4d/zstandard-0.25.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:a1a4ae2dec3993a32247995bdfe367fc3266da832d82f8438c8570f989753de1", size = 640440, upload-time = "2025-09-14T22:17:27.366Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/d9/14/933d27204c2bd404229c69f445862454dcc101cd69ef8c6068f15aaec12c/zstandard-0.25.0-cp313-cp313-manylinux2010_i686.manylinux2014_i686.manylinux_2_12_i686.manylinux_2_17_i686.whl", hash = "sha256:e96594a5537722fdfb79951672a2a63aec5ebfb823e7560586f7484819f2a08f", size = 5343070, upload-time = "2025-09-14T22:17:28.896Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/6d/db/ddb11011826ed7db9d0e485d13df79b58586bfdec56e5c84a928a9a78c1c/zstandard-0.25.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:bfc4e20784722098822e3eee42b8e576b379ed72cca4a7cb856ae733e62192ea", size = 5063001, upload-time = "2025-09-14T22:17:31.044Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/db/00/87466ea3f99599d02a5238498b87bf84a6348290c19571051839ca943777/zstandard-0.25.0-cp313-cp313-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:457ed498fc58cdc12fc48f7950e02740d4f7ae9493dd4ab2168a47c93c31298e", size = 5394120, upload-time = "2025-09-14T22:17:32.711Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/2b/95/fc5531d9c618a679a20ff6c29e2b3ef1d1f4ad66c5e161ae6ff847d102a9/zstandard-0.25.0-cp313-cp313-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:fd7a5004eb1980d3cefe26b2685bcb0b17989901a70a1040d1ac86f1d898c551", size = 5451230, upload-time = "2025-09-14T22:17:34.41Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/63/4b/e3678b4e776db00f9f7b2fe58e547e8928ef32727d7a1ff01dea010f3f13/zstandard-0.25.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:8e735494da3db08694d26480f1493ad2cf86e99bdd53e8e9771b2752a5c0246a", size = 5547173, upload-time = "2025-09-14T22:17:36.084Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/4e/d5/ba05ed95c6b8ec30bd468dfeab20589f2cf709b5c940483e31d991f2ca58/zstandard-0.25.0-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:3a39c94ad7866160a4a46d772e43311a743c316942037671beb264e395bdd611", size = 5046736, upload-time = "2025-09-14T22:17:37.891Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/50/d5/870aa06b3a76c73eced65c044b92286a3c4e00554005ff51962deef28e28/zstandard-0.25.0-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:172de1f06947577d3a3005416977cce6168f2261284c02080e7ad0185faeced3", size = 5576368, upload-time = "2025-09-14T22:17:40.206Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/5d/35/398dc2ffc89d304d59bc12f0fdd931b4ce455bddf7038a0a67733a25f550/zstandard-0.25.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:3c83b0188c852a47cd13ef3bf9209fb0a77fa5374958b8c53aaa699398c6bd7b", size = 4954022, upload-time = "2025-09-14T22:17:41.879Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/9a/5c/36ba1e5507d56d2213202ec2b05e8541734af5f2ce378c5d1ceaf4d88dc4/zstandard-0.25.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:1673b7199bbe763365b81a4f3252b8e80f44c9e323fc42940dc8843bfeaf9851", size = 5267889, upload-time = "2025-09-14T22:17:43.577Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/70/e8/2ec6b6fb7358b2ec0113ae202647ca7c0e9d15b61c005ae5225ad0995df5/zstandard-0.25.0-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:0be7622c37c183406f3dbf0cba104118eb16a4ea7359eeb5752f0794882fc250", size = 5433952, upload-time = "2025-09-14T22:17:45.271Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/7b/01/b5f4d4dbc59ef193e870495c6f1275f5b2928e01ff5a81fecb22a06e22fb/zstandard-0.25.0-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:5f5e4c2a23ca271c218ac025bd7d635597048b366d6f31f420aaeb715239fc98", size = 5814054, upload-time = "2025-09-14T22:17:47.08Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/b2/e5/fbd822d5c6f427cf158316d012c5a12f233473c2f9c5fe5ab1ae5d21f3d8/zstandard-0.25.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:4f187a0bb61b35119d1926aee039524d1f93aaf38a9916b8c4b78ac8514a0aaf", size = 5360113, upload-time = "2025-09-14T22:17:48.893Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/8e/e0/69a553d2047f9a2c7347caa225bb3a63b6d7704ad74610cb7823baa08ed7/zstandard-0.25.0-cp313-cp313-win32.whl", hash = "sha256:7030defa83eef3e51ff26f0b7bfb229f0204b66fe18e04359ce3474ac33cbc09", size = 436936, upload-time = "2025-09-14T22:17:52.658Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/d9/82/b9c06c870f3bd8767c201f1edbdf9e8dc34be5b0fbc5682c4f80fe948475/zstandard-0.25.0-cp313-cp313-win_amd64.whl", hash = "sha256:1f830a0dac88719af0ae43b8b2d6aef487d437036468ef3c2ea59c51f9d55fd5", size = 506232, upload-time = "2025-09-14T22:17:50.402Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/d4/57/60c3c01243bb81d381c9916e2a6d9e149ab8627c0c7d7abb2d73384b3c0c/zstandard-0.25.0-cp313-cp313-win_arm64.whl", hash = "sha256:85304a43f4d513f5464ceb938aa02c1e78c2943b29f44a750b48b25ac999a049", size = 462671, upload-time = "2025-09-14T22:17:51.533Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/3d/5c/f8923b595b55fe49e30612987ad8bf053aef555c14f05bb659dd5dbe3e8a/zstandard-0.25.0-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:e29f0cf06974c899b2c188ef7f783607dbef36da4c242eb6c82dcd8b512855e3", size = 795887, upload-time = "2025-09-14T22:17:54.198Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/8d/09/d0a2a14fc3439c5f874042dca72a79c70a532090b7ba0003be73fee37ae2/zstandard-0.25.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:05df5136bc5a011f33cd25bc9f506e7426c0c9b3f9954f056831ce68f3b6689f", size = 640658, upload-time = "2025-09-14T22:17:55.423Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/5d/7c/8b6b71b1ddd517f68ffb55e10834388d4f793c49c6b83effaaa05785b0b4/zstandard-0.25.0-cp314-cp314-manylinux2010_i686.manylinux_2_12_i686.manylinux_2_28_i686.whl", hash = "sha256:f604efd28f239cc21b3adb53eb061e2a205dc164be408e553b41ba2ffe0ca15c", size = 5379849, upload-time = "2025-09-14T22:17:57.372Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/a4/86/a48e56320d0a17189ab7a42645387334fba2200e904ee47fc5a26c1fd8ca/zstandard-0.25.0-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:223415140608d0f0da010499eaa8ccdb9af210a543fac54bce15babbcfc78439", size = 5058095, upload-time = "2025-09-14T22:17:59.498Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/f8/ad/eb659984ee2c0a779f9d06dbfe45e2dc39d99ff40a319895df2d3d9a48e5/zstandard-0.25.0-cp314-cp314-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:2e54296a283f3ab5a26fc9b8b5d4978ea0532f37b231644f367aa588930aa043", size = 5551751, upload-time = "2025-09-14T22:18:01.618Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/61/b3/b637faea43677eb7bd42ab204dfb7053bd5c4582bfe6b1baefa80ac0c47b/zstandard-0.25.0-cp314-cp314-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:ca54090275939dc8ec5dea2d2afb400e0f83444b2fc24e07df7fdef677110859", size = 6364818, upload-time = "2025-09-14T22:18:03.769Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/31/dc/cc50210e11e465c975462439a492516a73300ab8caa8f5e0902544fd748b/zstandard-0.25.0-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:e09bb6252b6476d8d56100e8147b803befa9a12cea144bbe629dd508800d1ad0", size = 5560402, upload-time = "2025-09-14T22:18:05.954Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/c9/ae/56523ae9c142f0c08efd5e868a6da613ae76614eca1305259c3bf6a0ed43/zstandard-0.25.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:a9ec8c642d1ec73287ae3e726792dd86c96f5681eb8df274a757bf62b750eae7", size = 4955108, upload-time = "2025-09-14T22:18:07.68Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/98/cf/c899f2d6df0840d5e384cf4c4121458c72802e8bda19691f3b16619f51e9/zstandard-0.25.0-cp314-cp314-musllinux_1_2_i686.whl", hash = "sha256:a4089a10e598eae6393756b036e0f419e8c1d60f44a831520f9af41c14216cf2", size = 5269248, upload-time = "2025-09-14T22:18:09.753Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/1b/c0/59e912a531d91e1c192d3085fc0f6fb2852753c301a812d856d857ea03c6/zstandard-0.25.0-cp314-cp314-musllinux_1_2_ppc64le.whl", hash = "sha256:f67e8f1a324a900e75b5e28ffb152bcac9fbed1cc7b43f99cd90f395c4375344", size = 5430330, upload-time = "2025-09-14T22:18:11.966Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/a0/1d/7e31db1240de2df22a58e2ea9a93fc6e38cc29353e660c0272b6735d6669/zstandard-0.25.0-cp314-cp314-musllinux_1_2_s390x.whl", hash = "sha256:9654dbc012d8b06fc3d19cc825af3f7bf8ae242226df5f83936cb39f5fdc846c", size = 5811123, upload-time = "2025-09-14T22:18:13.907Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/f6/49/fac46df5ad353d50535e118d6983069df68ca5908d4d65b8c466150a4ff1/zstandard-0.25.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:4203ce3b31aec23012d3a4cf4a2ed64d12fea5269c49aed5e4c3611b938e4088", size = 5359591, upload-time = "2025-09-14T22:18:16.465Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/c2/38/f249a2050ad1eea0bb364046153942e34abba95dd5520af199aed86fbb49/zstandard-0.25.0-cp314-cp314-win32.whl", hash = "sha256:da469dc041701583e34de852d8634703550348d5822e66a0c827d39b05365b12", size = 444513, upload-time = "2025-09-14T22:18:20.61Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/3a/43/241f9615bcf8ba8903b3f0432da069e857fc4fd1783bd26183db53c4804b/zstandard-0.25.0-cp314-cp314-win_amd64.whl", hash = "sha256:c19bcdd826e95671065f8692b5a4aa95c52dc7a02a4c5a0cac46deb879a017a2", size = 516118, upload-time = "2025-09-14T22:18:17.849Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/f0/ef/da163ce2450ed4febf6467d77ccb4cd52c4c30ab45624bad26ca0a27260c/zstandard-0.25.0-cp314-cp314-win_arm64.whl", hash = "sha256:d7541afd73985c630bafcd6338d2518ae96060075f9463d7dc14cfb33514383d", size = 476940, upload-time = "2025-09-14T22:18:19.088Z" },
|
||||
]
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user