mirror of https://github.com/karpathy/nanochat.git, synced 2026-01-24 04:14:27 +00:00

Commit 65865df300: Merge branch 'master' into master_goderr

.gitignore (vendored):

@@ -1,7 +1,14 @@
.venv/
__pycache__/
*.pyc
rustbpe/target/
dev-ignore/
report.md
eval_bundle/

# Secrets
.env

# Local setup
.claude
CLAUDE.md
wandb/
README.md:

@@ -10,6 +10,10 @@ This repo is a full-stack implementation of an LLM like ChatGPT in a single, cle
|
|||
|
||||
To get a sense of the endpoint of this repo, you can currently find [nanochat d34](https://github.com/karpathy/nanochat/discussions/314) hosted on [nanochat.karpathy.ai](https://nanochat.karpathy.ai/). "d34" means that this model has 34 layers in the Transformer neural network. The model has 2.2 billion parameters and was trained on 88 billion tokens by simply running the training script [run1000.sh](run1000.sh) with `--target_param_data_ratio=40` (2x longer than Chinchilla-optimal), for a total training cost of ~$2,500 (about 100 hours of training time on an 8XH100 GPU node). While today this is enough to outperform GPT-2 of 2019, it falls dramatically short of modern Large Language Models like GPT-5. When talking to these micro models, you'll see that they make a lot of mistakes, they are a little bit naive and silly, and they hallucinate a ton, a bit like children. It's kind of amusing. But what makes nanochat unique is that it is fully yours - fully configurable, tweakable, hackable, and trained by you from start to end. To train and talk to your own, we turn to...
|
||||
|
||||
## Updates
|
||||
|
||||
- (Jan 7 2026) See new post: [nanochat Miniseries v1](https://github.com/karpathy/nanochat/discussions/420) and the associated script [miniseries.sh](miniseries.sh).
|
||||
|
||||
## Quick start
|
||||
|
||||
The fastest way to feel the magic is to run the speedrun script [speedrun.sh](speedrun.sh), which trains and inferences the $100 tier of nanochat. On an 8XH100 node at $24/hr, this gives a total run time of about 4 hours. Boot up a new 8XH100 GPU box from your favorite provider (e.g. I use and like [Lambda](https://lambda.ai/service/gpu-cloud)), and kick off the training script:
|
||||
|
|
@ -108,10 +112,10 @@ Additionally, to add new abilities to nanochat, see [Guide: counting r in strawb
|
|||
nanochat is designed to be short and sweet. One big advantage of this is that we can package up all of the files together and copy paste them to your favorite LLM to ask arbitrary questions. As an example, I like to package up the repo using the [files-to-prompt](https://github.com/simonw/files-to-prompt) utility like so:
|
||||
|
||||
```bash
|
||||
files-to-prompt . -e py -e md -e rs -e html -e toml -e sh --ignore "*target*" --cxml > packaged.txt
|
||||
files-to-prompt . -e py -e md -e html -e toml -e sh --cxml > packaged.txt
|
||||
```
|
||||
|
||||
This includes all py, rs, html, toml, sh files, excludes the `rustbpe/target` folder, and chooses the cxml output format. Everything is written to the `packaged.txt` file, which atm measures ~330KB (i.e. well below ~100K tokens for a state of the art LLM), and ~8K lines of code in 45 files.
|
||||
This includes all py, html, toml, sh files and chooses the cxml output format. Everything is written to the `packaged.txt` file, which atm measures ~330KB (i.e. well below ~100K tokens for a state of the art LLM), and ~8K lines of code in 45 files.
|
||||
|
||||
Alternatively, I recommend using [DeepWiki](https://deepwiki.com/karpathy/nanochat) from Devin/Cognition to ask questions of this repo. In the URL of this repo, simply change github.com to deepwiki.com, and you're off.
|
||||
|
||||
|
|
@ -120,7 +124,7 @@ Alternatively, I recommend using [DeepWiki](https://deepwiki.com/karpathy/nanoch
|
|||
I haven't invested too much here but some tests exist, especially for the tokenizer. Run e.g. as:
|
||||
|
||||
```bash
|
||||
python -m pytest tests/test_rustbpe.py -v -s
|
||||
python -m pytest tests/test_engine.py -v -s
|
||||
```
|
||||
|
||||
## File structure
|
||||
|
|
@ -140,7 +144,6 @@ python -m pytest tests/test_rustbpe.py -v -s
|
|||
│ ├── adamw.py # Distributed AdamW optimizer
|
||||
│ ├── checkpoint_manager.py # Save/Load model checkpoints
|
||||
│ ├── common.py # Misc small utilities, quality of life
|
||||
│ ├── configurator.py # A superior alternative to argparse
|
||||
│ ├── core_eval.py # Evaluates base model CORE score (DCLM paper)
|
||||
│ ├── dataloader.py # Tokenizing Distributed Data Loader
|
||||
│ ├── dataset.py # Download/read utils for pretraining data
|
||||
|
|
@ -155,12 +158,6 @@ python -m pytest tests/test_rustbpe.py -v -s
|
|||
│ └── ui.html # HTML/CSS/JS for nanochat frontend
|
||||
├── pyproject.toml
|
||||
├── run1000.sh # Train the ~$800 nanochat d32
|
||||
├── rustbpe # Custom Rust BPE tokenizer trainer
|
||||
│ ├── Cargo.lock
|
||||
│ ├── Cargo.toml
|
||||
│ ├── README.md # see for why this even exists
|
||||
│ └── src
|
||||
│ └── lib.rs
|
||||
├── scripts
|
||||
│ ├── base_eval.py # Base model: calculate CORE score
|
||||
│ ├── base_loss.py # Base model: calculate bits per byte, sample
|
||||
|
|
@ -185,7 +182,6 @@ python -m pytest tests/test_rustbpe.py -v -s
|
|||
│ └── spellingbee.py # Task teaching model to spell/count letters
|
||||
├── tests
|
||||
│ └── test_engine.py
|
||||
│ └── test_rustbpe.py
|
||||
└── uv.lock
|
||||
```
|
||||
|
||||
|
|
|
|||
dev/LOG.md (new file, 381 lines):

@@ -0,0 +1,381 @@
|
|||
# Experiment Log
|
||||
|
||||
A running summary documenting some experiments and findings. Started ~Jan 7 2026.
|
||||
|
||||
---
|
||||
|
||||
## 2026-01-13: Varlen Attention (Negative Result)
|
||||
|
||||
Attempted to prevent attention from "leaking" across document boundaries using Flash Attention's `flash_attn_varlen_func`, similar to modded-nanogpt's approach.
|
||||
|
||||
### Background
|
||||
|
||||
With the BOS-aligned dataloader, multiple documents are packed into each row. Standard attention allows tokens to attend across document boundaries within a row. The hypothesis was that preventing this "leakage" via varlen attention might improve training.
|
||||
|
||||
### Approach: Compute cu_seqlens from inputs
|
||||
|
||||
- Find BOS positions: `(inputs.view(-1) == bos_token_id).nonzero()`
|
||||
- Gotcha 1: Variable-length `cu_seqlens` caused torch.compile recompilation (25s/iter!) - fixed by padding to fixed size
|
||||
- Gotcha 2: `nonzero()` inside compiled model hit recompile limit - fixed by moving computation outside compiled region
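
For reference, a minimal sketch of deriving `cu_seqlens` from BOS positions for `flash_attn_varlen_func`, assuming `inputs` is a (B, T) batch from the BOS-aligned dataloader and `max_docs` is a hypothetical cap used for the fixed-size padding mentioned above:

```python
import torch

def compute_cu_seqlens(inputs: torch.Tensor, bos_token_id: int, max_docs: int) -> torch.Tensor:
    # runs outside the compiled region (nonzero() inside the compiled model hit the recompile limit)
    flat = inputs.view(-1)
    total = flat.numel()
    starts = (flat == bos_token_id).nonzero(as_tuple=False).flatten()
    # cu_seqlens = [start_0, ..., start_{n-1}, total]; with BOS-aligned rows, start_0 == 0
    cu = torch.cat([starts, starts.new_tensor([total])]).to(torch.int32)
    # pad with `total` to a fixed length so torch.compile sees a static shape
    # (padded entries become zero-length segments); assumes at most max_docs docs per batch
    pad = cu.new_full((max_docs + 1 - cu.numel(),), total)
    return torch.cat([cu, pad])
```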
|
||||
|
||||
### Final Results (d16)
|
||||
|
||||
| Metric | Baseline | Varlen |
|
||||
|--------|----------|--------|
|
||||
| val_bpb | 0.85427 | 0.85407 |
|
||||
| MFU | ~same | ~same |
|
||||
| tok/sec | ~same | ~same |
|
||||
|
||||
Essentially identical. The 0.0002 bpb improvement is almost noise.
|
||||
|
||||
### Conclusion
|
||||
|
||||
Not worth the code complexity. The "leakage" across document boundaries within a row is not harmful - the model handles it fine. The BOS-aligned dataloader already provides the key benefit (every row starts with proper context). Not merging to master.
|
||||
|
||||
---
|
||||
|
||||
## 2026-01-13: BOS-Aligned Dataloader with Bin Packing
|
||||
|
||||
Redesigned the pretraining and midtraining dataloader to ensure every sequence starts with a BOS token, and explored bin-packing algorithms to minimize wasted tokens.
|
||||
|
||||
### Problem Statement
|
||||
|
||||
The original dataloader streams tokens into a flat buffer and reshapes into batches. This means some rows start mid-document (no BOS), which could confuse the model during training. We want every row to start with BOS and contain well-formed documents.
|
||||
|
||||
### Approach 1: Greedy-Crop BOS (Simple)
|
||||
|
||||
Each row is built independently:
|
||||
- Start with a document (which has BOS prepended)
|
||||
- Pack more documents until row is full
|
||||
- If a document doesn't fit, **crop it** to fill remaining space (discard the rest)
|
||||
- 100% utilization (no padding), but wastes cropped tokens
|
||||
|
||||
### Waste Analysis
|
||||
|
||||
Measured token waste empirically on real data (T=2048):
|
||||
- **39.4% of tokens are cropped** (discarded when docs don't fit)
|
||||
- **22.9% is the theoretical minimum** (tokens in docs longer than T+1 that can never fit)
|
||||
- The extra ~16.5% comes from "unlucky" cropping when a long doc starts near the end of a row
|
||||
|
||||
### Bin Packing Algorithms Explored
|
||||
|
||||
| Algorithm | Util% | Crop% | Pad% | Notes |
|
||||
|-----------|-------|-------|------|-------|
|
||||
| Greedy-Crop (baseline) | 100% | 39.4% | 0% | Simple, no wasted compute |
|
||||
| Greedy-Pad | 78% | 23.0% | 22% | Pads instead of crops - wastes compute |
|
||||
| First-Fit Decreasing (FFD) | 99.7% | 23.0% | 0.3% | Near-optimal packing, minimal padding |
|
||||
| **BestFit-Crop** | 100% | 34.6% | 0% | Smart cropping, no padding |
|
||||
|
||||
### BestFit-Crop Algorithm
|
||||
|
||||
A middle ground that maintains 100% utilization while reducing cropping:
|
||||
|
||||
1. Buffer N documents
|
||||
2. For each row, greedily pick the **largest doc that fits entirely**
|
||||
3. Repeat until nothing fits
|
||||
4. When nothing fits, crop a doc to fill remaining space exactly
|
||||
|
||||
This avoids "unlucky" crops by searching the buffer for better-fitting documents.
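
A minimal sketch of this packing loop in pure Python (names are illustrative; the real dataloader streams documents and operates on token ids):

```python
def pack_rows_bestfit_crop(docs, row_len):
    """docs: list of token lists (each starting with BOS); row_len: tokens per packed row."""
    buffer = list(docs)            # the real dataloader keeps a streaming buffer of N docs
    rows = []
    while buffer:
        row = []
        while len(row) < row_len and buffer:
            remaining = row_len - len(row)
            fitting = [d for d in buffer if len(d) <= remaining]
            if fitting:
                doc = max(fitting, key=len)   # greedily pick the largest doc that fits entirely
                buffer.remove(doc)
                row.extend(doc)
            else:
                doc = buffer.pop(0)           # nothing fits: crop a doc to fill the row exactly
                row.extend(doc[:remaining])   # the cropped tail is discarded
        if len(row) == row_len:
            rows.append(row)                  # any final partial row is dropped for simplicity
    return rows
```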
|
||||
|
||||
**Results (T=2048):**
|
||||
- Crop waste reduced from 39.4% → 34.6% (~12% relative improvement)
|
||||
- Still achieves 100% utilization (no padding, every token trains)
|
||||
- Slightly more rows than baseline (uses more documents per batch)
|
||||
|
||||
### Decision: Keep Two Implementations
|
||||
|
||||
1. **Original (simple streaming, kept as a fallback)**: very simple, efficient, and has 100% token utilization in the batch (no padding with ignore tokens), but it creates slightly more confusing token streams for the LLM because documents during training can start abruptly from the middle with no context. Note that this never happens at test time, where BOS is always present.
|
||||
|
||||
2. **`_bos_bestfit` (BestFit-Crop, new default)**: Slightly more complex, but it still keeps 100% token utilization in the batch (no padding), at the cost of discarding document tails when they don't fit. In practice, about 34% of tokens are discarded with this approach. This is ok because for most models we care about we have plenty of data without having to go to multiple epochs. One more subtle effect is that it does skew the data distribution a tiny bit because, reliably and necessarily, tokens at the tails of long documents will be discarded. However, this doesn't seem to impact actual downstream performance.
|
||||
|
||||
### Midtraining
|
||||
|
||||
The midtraining dataloader was also updated. Because conversations are on average a lot shorter than pretraining documents, only about 3.3% of tokens get cropped.
|
||||
|
||||
### NOTE: loss scale
|
||||
|
||||
Do note that switching to the BOS dataloader changes the validation loss and makes all previous experiments not comparable in absolute loss value, because we have a lot fewer "confusing" tokens in the train/val batches: every token can look back, find the BOS token, and use the full context of its document to make predictions. The loss therefore appears lower, but this is "fake" to some extent; the expectation is that the vast majority of relative comparisons would come out the same before and after this change.
|
||||
|
||||
---
|
||||
|
||||
## 2026-01-13: Number Token Split Pattern
|
||||
|
||||
Validated the `\p{N}{1,2}` pattern in `SPLIT_PATTERN` (tokenizer.py line 30), which I had only guessed earlier and had left a TODO to validate. GPT-4 uses `\p{N}{1,3}` to group number sequences of up to 3 digits into tokens, but we suspected smaller vocab sizes benefit from grouping fewer digits per token.
|
||||
|
||||
**Results (d12, vocab=32K):**
|
||||
| Pattern | val_bpb |
|
||||
|---------|---------|
|
||||
| `\p{N}{1,1}` | 0.969 |
|
||||
| `\p{N}{1,2}` | **0.965** |
|
||||
| `\p{N}{1,3}` | 0.972 |
|
||||
|
||||
**Conclusion:** `{1,2}` is optimal for vocab size 32K. Grouping 3 digits wastes tokens on rare 3-digit combinations; grouping 1 digit is too fine-grained and bloats token sequences. Keeping `{1,2}` as default.
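
For a concrete feel for the difference, a small sketch isolating the digit-grouping part of the pattern (this uses the `regex` package since `re` lacks `\p{N}`; the full `SPLIT_PATTERN` has more alternatives):

```python
import regex

# how a run of digits gets chunked before BPE, for the two candidate patterns
print(regex.findall(r"\p{N}{1,2}", "123456789"))  # ['12', '34', '56', '78', '9']
print(regex.findall(r"\p{N}{1,3}", "123456789"))  # ['123', '456', '789']
```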
|
||||
|
||||
---
|
||||
|
||||
## 2026-01-13: FP8 Training for lm_head
|
||||
|
||||
Attempted to use FP8 (8-bit floating point) for the lm_head layer to speed up the large vocab projection matmul. H100 GPUs have FP8 tensor cores that can theoretically provide ~2x speedup over BF16.
|
||||
|
||||
### Implementation Approaches Tried
|
||||
|
||||
**1. Dynamic Scaling (failed)**
|
||||
- Compute `x.abs().max()` and `w.abs().max()` each forward to determine scales
|
||||
- Problem: `.item()` calls cause graph breaks with torch.compile
|
||||
- Tried `@torch._dynamo.allow_in_graph` pattern (like torchao.float8) - worked but no speedup
|
||||
- Tried `torch.library.custom_op` with float scales - caused NaN gradients after first optimizer step
|
||||
- Root cause: interaction between custom ops, dynamic scale computation, and torch.compile is fragile
|
||||
|
||||
**2. Static Scaling (partial success)**
|
||||
- Pre-set scales at init time like modded-nanogpt: `x_scale=10/448, w_scale=0.1/448`
|
||||
- `grad_scale` is computed dynamically from batch size (safe since it's just `1/(B*T)/57344` due to the gradient expression of cross entropy). modded-nanogpt probably has a bug here: they set `grad_scale = 0.75/448`, but grads are in E5M2, so this should probably be `1/57344`, with 1 being the amax of any individual element of the cross entropy loss and no normalization by B,T because they use sum reduction, not mean reduction.
|
||||
- Uses `torch.library.custom_op` with `@torch.compile` on inner kernels
|
||||
- This works correctly - no NaNs, proper gradients
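
A minimal sketch of the static-scale forward path, using the scales quoted above; the `torch._scaled_mm` signature and return type vary across PyTorch versions, so treat the call as illustrative rather than the repo's exact `fp8_static.py`:

```python
import torch

X_SCALE = 10.0 / 448.0   # static input scale from above (448 is the E4M3 max)
W_SCALE = 0.1 / 448.0    # static weight scale from above

def quantize_e4m3(t: torch.Tensor, scale: float) -> torch.Tensor:
    # scale down so values fit the E4M3 range, then cast to FP8
    return (t / scale).clamp(-448.0, 448.0).to(torch.float8_e4m3fn)

def fp8_lm_head_forward(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # x: (N, n_embd), w: (vocab_size, n_embd), both bf16 on an H100
    x_f8 = quantize_e4m3(x, X_SCALE)
    w_f8 = quantize_e4m3(w, W_SCALE)
    scale_x = torch.tensor(X_SCALE, device=x.device)   # dequantization scales (float32)
    scale_w = torch.tensor(W_SCALE, device=x.device)
    # _scaled_mm multiplies the FP8 operands and rescales the result; mat2 must be column-major,
    # hence the .t() of a row-major tensor (return convention differs across PyTorch versions)
    return torch._scaled_mm(x_f8, w_f8.t(), scale_a=scale_x, scale_b=scale_w, out_dtype=torch.bfloat16)
```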
|
||||
|
||||
### Results (d12)
|
||||
|
||||
| Metric | BF16 Baseline | FP8 lm_head |
|
||||
|--------|---------------|-------------|
|
||||
| GPU Memory | 34 GB | 36 GB |
|
||||
| tok/sec | baseline | ~1% faster |
|
||||
|
||||
### The Memory Mystery
|
||||
|
||||
FP8 *should* save memory since we store `x_f8` (1 byte) instead of `x` (2 bytes) for backward. But we see 2GB *increase*. Suspected causes:
|
||||
- `torch.compile` on inner kernels creating extra buffers/specializations
|
||||
- `torch._scaled_mm` internal workspace allocations
|
||||
- Custom op registration machinery overhead
|
||||
|
||||
Tried saving the original weight `w` (just a reference to the parameter) instead of `w_f8` in backward, then re-quantizing on the spot during backward - didn't help, still saw the bump.
|
||||
|
||||
### Microbenchmark vs Reality
|
||||
|
||||
Raw microbenchmark showed promise:
|
||||
- BF16 matmul: 16.95 ms
|
||||
- FP8 matmul (static scales): 10.31 ms (1.64x faster)
|
||||
- FP8 with dynamic scaling: 12.25 ms (1.38x faster)
|
||||
|
||||
But in full training, the ~1% tok/sec improvement doesn't justify the 2GB memory increase, the added code complexity, or the need to tune scale factors for both x and w.
|
||||
|
||||
### Code Artifacts
|
||||
|
||||
See the branch `fp8_attempt_fail` for:
|
||||
|
||||
- `nanochat/fp8_static.py` - Static scaling implementation (working)
|
||||
- `nanochat/fp8_dynamic.py` - Dynamic scaling implementation (torchao-style, working but slow)
|
||||
- `gpt.py` imports `fp8_static.LinearFP8` and simply swaps it in for `lm_head`.
|
||||
|
||||
### Open Questions
|
||||
|
||||
- Why does the custom op approach use more memory than vanilla BF16?
|
||||
- Why is the bump in tok_per_sec so low? We should see ~1.6X speedup in the forward pass and also (twice) in the backward pass for the gradients. Granted, Amdahl's law is part of the explanation because our vocab_size is only 32K, so the final layer isn't a huge part of the profile, but the expected speedup is still not fully realized.
|
||||
|
||||
**Conclusion:** Negative result for now. The implementation works correctly but provides marginal speedup with *increased* memory usage. I'm not understanding the torch.compile interaction here. The complexity of FP8 custom ops isn't justified for lm_head alone. TODO to study in more detail the way this is implemented in other libraries, e.g. torchao.
|
||||
|
||||
---
|
||||
|
||||
## 2026-01-12: Multi-Token Prediction (MTP)
|
||||
|
||||
Ported multi-token prediction from modded-nanogpt. Instead of predicting just the next token, predict the next n tokens at each position with weighted loss.
|
||||
|
||||
### Implementation
|
||||
|
||||
- Instead of calling the loss `n_predict` times, uses a fancy batched computation using `unfold` + `gather` + cross-entropy decomposition (`CE = logsumexp - logits[target]`)
|
||||
- Schedule anneals from 3-token to 1-token prediction:
|
||||
- 0-33%: `[1.0, 0.5, 0.25→0]` (3rd token fades)
|
||||
- 33-67%: `[1.0, 0.5→0]` (2nd token fades)
|
||||
- 67-100%: `[1.0]` (standard next-token)
|
||||
- Weights normalized to sum to 1
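
A minimal sketch of the batched loss under stated assumptions: `logits` of shape (B, T, V) are reused to score the next n targets at each position, and `targets` are already shifted by one (this is not the exact repo code):

```python
import torch

def mtp_loss(logits: torch.Tensor, targets: torch.Tensor, weights: list) -> torch.Tensor:
    # logits: (B, T, V); targets: (B, T) with targets[:, t] = token at position t+1
    B, T, V = logits.shape
    n = len(weights)
    tgt = targets.unfold(dimension=1, size=n, step=1)    # (B, T-n+1, n): next n targets per position
    lg = logits[:, : T - n + 1]                          # (B, T-n+1, V)
    lse = torch.logsumexp(lg, dim=-1, keepdim=True)      # normalizer shared by all n targets
    picked = torch.gather(lg, -1, tgt)                   # logit of each of the n target tokens
    ce = lse - picked                                    # CE = logsumexp - logits[target]
    w = torch.tensor(weights, device=logits.device, dtype=ce.dtype)
    w = w / w.sum()                                      # weights normalized to sum to 1
    return (ce * w).sum(dim=-1).mean()
```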
|
||||
|
||||
### Results (d12)
|
||||
|
||||
| Metric | Baseline | MTP |
|
||||
|--------|----------|-----|
|
||||
| GPU Memory | 34 GB | 47 GB |
|
||||
| MFU | 41% | 40% |
|
||||
| val/bpb (per step) | baseline | same/slightly worse |
|
||||
| val/bpb (wall clock) | baseline | noticeably worse |
|
||||
|
||||
**Conclusion:** Negative result for nanochat. The extra memory and compute overhead from predicting multiple tokens doesn't pay off, in fact the results get worse. The auxiliary loss signal may help in other settings (larger models, different architectures?), but for our setup it's pure overhead at the moment.
|
||||
|
||||
---
|
||||
|
||||
## 2026-01-11: Sliding Window Attention
|
||||
|
||||
Added configurable sliding window attention, inspired by GPT-3's alternating short/long pattern.
|
||||
|
||||
**Pattern string configuration:**
|
||||
- New `--window_pattern` CLI arg and `GPTConfig.window_pattern` field
|
||||
- Pattern is tiled across layers (e.g., `SSSL` for 20 layers → `SSSLSSSLSSSLSSSLSSSL`)
|
||||
- Final layer always forced to L (full context) regardless of pattern
|
||||
- Short window = `sequence_len // 2`
|
||||
- Long window = `sequence_len` (full context)
|
||||
- All previous models have been simply `L`; checkpoint loading is modified accordingly to fill in this param for old models, see `_patch_missing_config_keys`
|
||||
|
||||
Quick experiments showed `SSSL` (every 4th layer is long) works well - provides a good balance between compute savings and model quality. This is now the default.
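
A minimal sketch of the pattern-to-window mapping described above (function name and exact handling are illustrative):

```python
def layer_window_sizes(window_pattern: str, n_layer: int, sequence_len: int):
    # tile the pattern across layers, e.g. "SSSL" over 8 layers -> "SSSLSSSL"
    pattern = (window_pattern * ((n_layer // len(window_pattern)) + 1))[:n_layer]
    pattern = pattern[:-1] + "L"                  # final layer always forced to full context
    short, long = sequence_len // 2, sequence_len
    return [short if c == "S" else long for c in pattern]

# e.g. layer_window_sizes("SSSL", 8, 2048) -> [1024, 1024, 1024, 2048, 1024, 1024, 1024, 2048]
```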
|
||||
|
||||
---
|
||||
|
||||
## 2026-01-11: Flash Attention 3 Integration
|
||||
|
||||
Replaced PyTorch's `scaled_dot_product_attention` (FA2) with Flash Attention 3 for training and inference.
|
||||
|
||||
### Changes Made
|
||||
|
||||
**1. FA3 via `kernels` package**
|
||||
- Official FA3 is "beta" and requires building from source (painful)
|
||||
- Using `kernels` package from HuggingFace Hub: `get_kernel('varunneal/flash-attention-3')`
|
||||
- Loads pre-built wheels, works out of the box on H100
|
||||
|
||||
**2. Simplified attention code**
|
||||
- FA3 uses `(B, T, H, D)` layout matching our projection output directly - no transpose needed
|
||||
- Training: `flash_attn.flash_attn_func(q, k, v, causal=True)`
|
||||
- Inference: `flash_attn.flash_attn_with_kvcache()` handles all cache cases in one call
|
||||
- Removed 3 separate FA2 code paths (training, single-token, chunk inference)
|
||||
- GQA handled automatically when n_kv_heads < n_heads
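
To make the call pattern concrete, a minimal sketch of the load-and-use path for the training case (tensor shapes are illustrative, and the exact return convention of FA3 builds is an assumption):

```python
import torch
from kernels import get_kernel

flash_attn = get_kernel('varunneal/flash-attention-3')  # pre-built FA3 wheels from the HF Hub

B, T, H, D = 2, 128, 8, 64
q, k, v = (torch.randn(B, T, H, D, device='cuda', dtype=torch.bfloat16) for _ in range(3))

# training path: (B, T, H, D) in, (B, T, H, D) out, no transposes needed
out = flash_attn.flash_attn_func(q, k, v, causal=True)
```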
|
||||
|
||||
**3. Rewrote KVCache for FA3**
|
||||
- Old format: `(num_layers, 2, B, H, T, D)` combined tensor
|
||||
- New format: separate `k_cache` and `v_cache` of shape `(num_layers, B, T, H, D)`
|
||||
- FA3 updates cache in-place during `flash_attn_with_kvcache`
|
||||
- Position tracked via `cache_seqlens` tensor (int32, per batch element)
|
||||
- Simpler API: `get_layer_cache()`, `advance()`, `reset()`, `prefill()`
|
||||
|
||||
### Results
|
||||
|
||||
- **~9% improvement in tok/sec** during training out of the box
|
||||
- Benchmarks showed FA3 is 2x faster than FA2 at realistic training sizes (batch=32, seq=2048)
|
||||
- FA3 supports sliding window via `window_size=(left, 0)`, which is huge and expected to give further improvements. This is ready to tune but keeping full context for now.
|
||||
|
||||
---
|
||||
|
||||
## 2026-01-11: Per-Layer Residual Scalars (x0 & resid lambdas)
|
||||
|
||||
Cherry-picked an idea from modded-nanogpt around learnable per-layer residual connections.
|
||||
|
||||
### Changes Made
|
||||
|
||||
**1. x0_lambdas (x0 residual connections)**
|
||||
- Save initial normalized embedding as `x0` after `norm(wte(idx))`
|
||||
- At each layer, blend x0 back in: `x = resid_lambdas[i] * x + x0_lambdas[i] * x0`
|
||||
- Zero-initialized, so disabled at start; model learns which layers benefit from the shortcut
|
||||
- Provides direct path from embedding to deep layers, helps preserve token information
|
||||
|
||||
**2. resid_lambdas (residual stream scaling)**
|
||||
- Per-layer multiplicative scaling of the residual stream
|
||||
- Initialized to 1.0 (neutral, standard transformer behavior)
|
||||
- Allows model to learn to amplify/dampen residual at each layer
|
||||
|
||||
**3. DistAdamW small parameter handling**
|
||||
- Added support for parameters with < 1024 elements (like the scalar lambdas)
|
||||
- Small params use `all_reduce` instead of `reduce_scatter`/`all_gather`
|
||||
- Fixes crash when param shape isn't divisible by world_size
|
||||
|
||||
### Key Finding: Different LR Sensitivity
|
||||
|
||||
The two scalar types need very different learning rates:
|
||||
- **x0_lambdas (additive)**: Can use normal LR (~0.5). Adding a fraction of x0 is forgiving.
|
||||
- **resid_lambdas (multiplicative)**: Needs ~100x smaller LR (~0.005). Multiplying the residual compounds through layers.
|
||||
|
||||
Implementation: `resid_params` gets `scalar_lr * 0.01`, `x0_params` gets full `scalar_lr`.
|
||||
|
||||
### Experiment Results
|
||||
|
||||
Swept `--scalar_lr` (controlling x0_lambdas) at multiple depths:
|
||||
|
||||
| Depth | Baseline (disabled) | Best scalar_lr | Best val_bpb | Δ bpb |
|
||||
|-------|---------------------|----------------|--------------|-------|
|
||||
| d8 | 1.0885 | 0.20 | 1.0782 | -0.0103 |
|
||||
| d12 | 0.9770 | 0.60 | 0.9693 | -0.0077 |
|
||||
| d16 | 0.9059 | 0.20 | 0.9002 | -0.0057 |
|
||||
| d20 | 0.8565 | 0.10 | 0.8526 | -0.0039 |
|
||||
|
||||
**Observations:**
|
||||
- Consistent improvement across all model sizes
|
||||
- Optimal LR varies by depth; default of 0.5 is reasonable, but 0.6 is better for d12
|
||||
- Adding resid_lambdas (with 0.01x LR) gives small additional improvement over x0 alone
|
||||
|
||||
### Meta Device Footgun
|
||||
|
||||
Important lesson: `__init__` runs in meta device context, so any tensor values set there are fake. Must initialize actual values in `init_weights()`. Added docstring warning to `__init__`.
|
||||
|
||||
### Summary
|
||||
|
||||
Added `--scalar_lr` (default 0.5) controlling learnable per-layer scalars. The formula `x = resid_lambdas[i] * x + x0_lambdas[i] * x0` gives the model control over residual scaling and direct shortcuts to the initial embedding. Solid improvement with essentially no compute overhead.
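
A minimal sketch of where these scalars sit in the forward pass (module names are illustrative and `F.rms_norm` assumes a recent PyTorch; real values must be set in `init_weights()` because `__init__` may run on the meta device):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyGPTCore(nn.Module):
    def __init__(self, vocab_size, n_embd, n_layer, block_factory):
        super().__init__()
        self.wte = nn.Embedding(vocab_size, n_embd)
        self.blocks = nn.ModuleList([block_factory() for _ in range(n_layer)])
        # NOTE: actual values must be (re)initialized in init_weights(); see the meta device footgun
        self.resid_lambdas = nn.Parameter(torch.ones(n_layer))   # multiplicative, init 1.0
        self.x0_lambdas = nn.Parameter(torch.zeros(n_layer))     # additive shortcut, init 0.0

    def forward(self, idx):
        x = F.rms_norm(self.wte(idx), (self.wte.embedding_dim,))  # x0 = norm(wte(idx))
        x0 = x                                                    # keep as the shortcut source
        for i, block in enumerate(self.blocks):
            x = self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0
            x = block(x)
        return x

# e.g. TinyGPTCore(32768, 768, 12, block_factory=lambda: nn.Identity())
```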
|
||||
|
||||
---
|
||||
|
||||
## 2026-01-10: Muon Optimizer Upgrades & Cautious Weight Decay
|
||||
|
||||
Cherry-picked improvements from NorMuon (modded-nanogpt) into our simpler Muon implementation. Decided against using NorMuon directly due to hard-coded architecture assumptions (expects 32 params split 10 attn + 22 mlp), parameter labeling requirements, and complexity.
|
||||
|
||||
### Changes Made
|
||||
|
||||
**1. Polar Express Orthogonalization**
|
||||
- Replaced Newton-Schulz iteration with "Polar Express Sign Method" from [arxiv.org/pdf/2505.16932](https://arxiv.org/pdf/2505.16932)
|
||||
- Uses 5 different coefficient tuples (one per iteration) instead of fixed coefficients
|
||||
- Both methods kept in code for easy comparison (`zeropower_via_polar_express` vs `zeropower_via_newtonschulz5`)
|
||||
- **Result:** No dramatic/noticeable difference in training, but keeping the new Polar Express as default.
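
For reference, a minimal sketch of the Newton-Schulz variant kept for comparison; the quintic coefficients are the ones commonly used in public Muon implementations and are an assumption here (not copied from this repo), while Polar Express instead uses a different coefficient tuple per iteration:

```python
import torch

@torch.no_grad()
def zeropower_via_newtonschulz5_sketch(G: torch.Tensor, steps: int = 5) -> torch.Tensor:
    a, b, c = 3.4445, -4.7750, 2.0315          # commonly used quintic coefficients (assumed)
    X = G.bfloat16()
    transposed = G.size(-2) > G.size(-1)
    if transposed:
        X = X.mT                                # work with the wide orientation
    X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)   # ensure top singular value <= 1
    for _ in range(steps):
        A = X @ X.mT
        B = b * A + c * A @ A
        X = a * X + B @ X                       # X converges toward an orthogonal factor of G
    if transposed:
        X = X.mT
    return X
```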
|
||||
|
||||
**2. Variance Reduction (NorMuon-style)**
|
||||
- Added low-rank variance estimator similar to Adafactor ([arxiv.org/pdf/2510.05491](https://arxiv.org/pdf/2510.05491))
|
||||
- Maintains `second_momentum_buffer` with shape `[rows, 1]` or `[1, cols]` (whichever is smaller)
|
||||
- Normalizes updates based on running per-row/col variance estimate (beta2=0.95)
|
||||
- Memory overhead: ~1/max(rows, cols) per param, negligible
|
||||
- **Result:** Led to a very small improvement, kept and enabled by default.
|
||||
|
||||
**3. Cautious Weight Decay**
|
||||
- Only decays weights where `update * weight >= 0` (same sign) from [arxiv.org/abs/2411.16085](https://arxiv.org/abs/2411.16085)
|
||||
- Standard WD always pulls toward zero; cautious WD skips decay when gradient is pushing weight away from zero
|
||||
- **Implementation note:** Had to inline the logic rather than use a separate `@torch.compile` function. Passing changing float values (like `weight_decay` during scheduling) as function arguments triggers recompilation. Reading from `group["weight_decay"]` inside the step avoids this.
|
||||
- **Result:** Solid improvements, especially the cautious version was better than standard wd.
|
||||
- Now defaults to ON for Muon via the `weight_decay` param. AdamW is still hardcoded to 0 weight decay; might try to re-tune this later.
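
A minimal sketch of the cautious decay applied during the parameter update (illustrative, not the repo's inlined Muon code):

```python
import torch

@torch.no_grad()
def cautious_wd_step(p: torch.Tensor, update: torch.Tensor, lr: float, weight_decay: float):
    # decay only entries where the update and the weight agree in sign, i.e. where the
    # step is not already pushing the weight away from zero
    mask = (update * p) >= 0
    p.mul_(1.0 - lr * weight_decay * mask.to(p.dtype))
    p.add_(update, alpha=-lr)   # then apply the (orthogonalized) Muon update
```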
|
||||
|
||||
**4. Weight decay schedule**
|
||||
- Added a linear schedule to weight decay that is default on, going from 1.0 to 0.0 (i.e. start with max weight decay at the beginning of training, then ramp to 0 by the end). Worked better than a static setting in experiments. (modded-nanogpt has the same schedule but it is implemented in a more confusing way, by multiplying twice by the learning rate, which is already wired up to a decay schedule.)
|
||||
|
||||
### Weight Decay Scaling Experiments
|
||||
|
||||
Swept weight decay values at d8, d12, d16, d20 to find optimal values and scaling law.
|
||||
|
||||
**Optimal Values Found:**
|
||||
| Depth | Width (channels) | Optimal WD |
|
||||
|-------|------------------|------------|
|
||||
| d8 | 512 | ~0.40 |
|
||||
| d12 | 768 | ~0.22 |
|
||||
| d16 | 1024 | ~0.10 |
|
||||
| d20 | 1280 | ~0.08 |
|
||||
|
||||
**Scaling Law:**
|
||||
- Fit power law: `WD = k / channels^α` in log-log space
|
||||
- Found α ≈ 1.97 (approximately 2), meaning WD ∝ 1/width²
|
||||
|
||||
**Practical Formula:**
|
||||
```
|
||||
WD_target = WD_reference × (d_reference / d_target)²
|
||||
```
|
||||
Example: If d12 optimal is 0.22, then d20 optimal ≈ 0.22 × (12/20)² ≈ 0.08
|
||||
|
||||
**Reference:** The Moonlight paper uses a fixed WD=0.1 for their 15B MoE model. Our experiments instead indicated that the optimal WD changes with model size, so we go with the empirical scaling law above.
|
||||
|
||||
### Summary
|
||||
|
||||
Muon was changed to use Polar Express orthogonalization, Adafactor-style variance reduction, and cautious weight decay with a schedule that ramps linearly to zero. All of these changes follow the modded-nanogpt repo, but all of them were also validated piece by piece to yield improvements in nanochat, with the exception of the Polar Express change, which was in the noise. Weight decay is default on and configurable with `--weight_decay`, using simply 0.2 and ∝ 1/width² scaling. The meaning of the `--weight_decay` kwarg therefore changes as of this commit: it used to configure standard weight decay in AdamW, and it is now used exclusively by Muon (AdamW is hardcoded to 0.0) and scaled based on depth.
|
||||
|
||||
---
|
||||
|
||||
## 2026-01-08: exp_grad_clip - Gradient Clipping
|
||||
|
||||
**Hypothesis:** Gradient clipping may be unnecessary overhead. Tested L2 norm clipping at various thresholds (0.25, 0.5, 1.0, 2.0) and elementwise clipping.
|
||||
|
||||
**Results:**
|
||||
- No benefit at any scale tested (d12, d20)
|
||||
- All variants within noise (~0.9827 val_bpb)
|
||||
- Grad norm never exceeds 1.0 naturally, so clipping is always inactive
|
||||
- Clipping adds ~2% time overhead from the all-reduce
|
||||
|
||||
**Bug Found:** Original implementation clipped local gradients before sync. Since this codebase doesn't use DDP (gradient sync is in the optimizers), each rank was clipping based on its own local norm. Fixed on the branch with proper distributed all-reduce.
|
||||
|
||||
**Observation:** modded-nanogpt does not appear to clip either right now.
|
||||
|
||||
**Summary:** Deleted all grad-clip code paths. The code naturally produces well-behaved gradients. This improves MFU a bit because we don't have to calculate and sync grad norms.
|
||||
dev/estimate_gpt3_core.ipynb (new file, 2190 lines): diff suppressed because it is too large.
|
|
@ -24,8 +24,7 @@ prompt:
|
|||
manually generate any kind of entropy you can think of and include it in your prompts
|
||||
to maintain healthy and good diversity in the data.
|
||||
|
||||
NOTE: You need OpenRouter API key in a file called "openroutertoken.txt" in the root directory of the repo.
|
||||
(obviously you can tune this arbitrarily to your liking)
|
||||
NOTE: You need OPENROUTER_API_KEY set in .env or as an environment variable.
|
||||
NOTE: For more details see this discussion: https://github.com/karpathy/nanochat/discussions/139
|
||||
"""
|
||||
import requests
|
||||
|
|
@ -34,10 +33,12 @@ import os
|
|||
import copy
|
||||
import random
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from nanochat.common import get_base_dir
|
||||
|
||||
api_key = open("openroutertoken.txt", "r", encoding="utf-8").read().strip()
|
||||
load_dotenv()
|
||||
api_key = os.environ["OPENROUTER_API_KEY"]
|
||||
|
||||
url = "https://openrouter.ai/api/v1/chat/completions"
|
||||
headers = {
|
||||
|
|
|
|||
|
|
@ -19,16 +19,13 @@ source .venv/bin/activate
|
|||
if [ -z "$WANDB_RUN" ]; then
|
||||
WANDB_RUN=dummy
|
||||
fi
|
||||
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
|
||||
source "$HOME/.cargo/env"
|
||||
uv run maturin develop --release --manifest-path rustbpe/Cargo.toml
|
||||
|
||||
# wipe the report
|
||||
python -m nanochat.report reset
|
||||
|
||||
# train tokenizer on ~1B characters
|
||||
python -m nanochat.dataset -n 4
|
||||
python -m scripts.tok_train --max_chars=1000000000
|
||||
python -m scripts.tok_train --max-chars=1000000000
|
||||
python -m scripts.tok_eval
|
||||
|
||||
# train a very small 4 layer model on the CPU
|
||||
|
|
@ -36,37 +33,37 @@ python -m scripts.tok_eval
|
|||
# we only run 50 steps of optimization (bump this to get better results)
|
||||
python -m scripts.base_train \
|
||||
--depth=4 \
|
||||
--max_seq_len=1024 \
|
||||
--device_batch_size=1 \
|
||||
--total_batch_size=1024 \
|
||||
--eval_every=50 \
|
||||
--eval_tokens=4096 \
|
||||
--core_metric_every=50 \
|
||||
--core_metric_max_per_task=12 \
|
||||
--sample_every=50 \
|
||||
--num_iterations=50
|
||||
python -m scripts.base_loss --device_batch_size=1 --split_tokens=4096
|
||||
--max-seq-len=1024 \
|
||||
--device-batch-size=1 \
|
||||
--total-batch-size=1024 \
|
||||
--eval-every=50 \
|
||||
--eval-tokens=4096 \
|
||||
--core-metric-every=50 \
|
||||
--core-metric-max-per-task=12 \
|
||||
--sample-every=50 \
|
||||
--num-iterations=50
|
||||
python -m scripts.base_loss --device-batch-size=1 --split-tokens=4096
|
||||
python -m scripts.base_eval --max-per-task=16
|
||||
|
||||
# midtraining
|
||||
python -m scripts.mid_train \
|
||||
--max_seq_len=1024 \
|
||||
--device_batch_size=1 \
|
||||
--eval_every=50 \
|
||||
--eval_tokens=4096 \
|
||||
--total_batch_size=1024 \
|
||||
--num_iterations=100
|
||||
--max-seq-len=1024 \
|
||||
--device-batch-size=1 \
|
||||
--eval-every=50 \
|
||||
--eval-tokens=4096 \
|
||||
--total-batch-size=1024 \
|
||||
--num-iterations=100
|
||||
# eval results will be terrible, this is just to execute the code paths.
|
||||
# note that we lower the execution memory limit to 1MB to avoid warnings on smaller systems
|
||||
python -m scripts.chat_eval --source=mid --max-new-tokens=128 --max-problems=20
|
||||
|
||||
# SFT
|
||||
python -m scripts.chat_sft \
|
||||
--device_batch_size=1 \
|
||||
--target_examples_per_step=4 \
|
||||
--num_iterations=100 \
|
||||
--eval_steps=4 \
|
||||
--eval_metrics_max_problems=16
|
||||
--device-batch-size=1 \
|
||||
--target-examples-per-step=4 \
|
||||
--num-iterations=100 \
|
||||
--eval-steps=4 \
|
||||
--eval-metrics-max-problems=16
|
||||
|
||||
# Chat CLI
|
||||
# python -m scripts.chat_cli -p "Why is the sky blue?"
|
||||
|
|
|
|||
dev/scaling_analysis.ipynb (new file, 227 lines):

@@ -0,0 +1,227 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Scaling Laws Analysis\n",
|
||||
"\n",
|
||||
"Analyze results from `scaling_laws.sh` to find the optimal param:data ratio for nanochat."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"import pandas as pd\n",
|
||||
"import numpy as np\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"\n",
|
||||
"# Load results\n",
|
||||
"base_dir = os.environ.get('NANOCHAT_BASE_DIR', os.path.expanduser('~/.cache/nanochat'))\n",
|
||||
"results_path = os.path.join(base_dir, 'scaling_laws_results', 'results.csv')\n",
|
||||
"\n",
|
||||
"df = pd.read_csv(results_path)\n",
|
||||
"flops_budgets = sorted(df['flops_budget'].unique())\n",
|
||||
"print(f\"Loaded {len(df)} runs across {len(flops_budgets)} FLOPs budgets\")\n",
|
||||
"print(f\"Columns: {list(df.columns)}\")\n",
|
||||
"df"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## IsoFLOP Curves (à la Chinchilla)\n",
|
||||
"\n",
|
||||
"For each compute budget, plot loss vs model size. Looking for the U-shape valley that reveals the optimal model size for each FLOPs budget."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"fig, axes = plt.subplots(1, 3, figsize=(16, 5))\n",
|
||||
"\n",
|
||||
"# Plot 1: IsoFLOP curves - Val BPB vs Parameters (the Chinchilla plot!)\n",
|
||||
"ax = axes[0]\n",
|
||||
"colors = plt.cm.viridis(np.linspace(0, 0.9, len(flops_budgets)))\n",
|
||||
"optimal_by_bpb = []\n",
|
||||
"\n",
|
||||
"for flops, color in zip(flops_budgets, colors):\n",
|
||||
" subset = df[df['flops_budget'] == flops].sort_values('num_scaling_params')\n",
|
||||
" ax.plot(subset['num_scaling_params'], subset['val_bpb'], 'o', color=color, label=f'{flops:.0e}', markersize=8)\n",
|
||||
"\n",
|
||||
" # Fit quadratic in log-space: val_bpb = a*(log N)^2 + b*(log N) + c\n",
|
||||
" log_params = np.log10(subset['num_scaling_params'])\n",
|
||||
" coeffs = np.polyfit(log_params, subset['val_bpb'], 2)\n",
|
||||
" a, b, c = coeffs\n",
|
||||
"\n",
|
||||
" # Plot fitted curve (dashed)\n",
|
||||
" log_fit_x = np.linspace(log_params.min() - 0.1, log_params.max() + 0.1, 100)\n",
|
||||
" fit_y = a * log_fit_x**2 + b * log_fit_x + c\n",
|
||||
" ax.plot(10**log_fit_x, fit_y, '--', color=color, linewidth=2)\n",
|
||||
"\n",
|
||||
" # Find minimum of quadratic: d/dx(ax^2 + bx + c) = 0 => x = -b/(2a)\n",
|
||||
" if a > 0: # parabola opens upward (has a minimum)\n",
|
||||
" log_opt = -b / (2 * a)\n",
|
||||
" opt_params = 10**log_opt\n",
|
||||
" opt_bpb = a * log_opt**2 + b * log_opt + c\n",
|
||||
" # Mark the fitted optimal\n",
|
||||
" ax.scatter([opt_params], [opt_bpb], s=150, color=color,\n",
|
||||
" zorder=5, edgecolors='black', linewidths=2, marker='*')\n",
|
||||
" # Interpolate tokens and ratio from actual data (don't use C≈6ND approximation)\n",
|
||||
" opt_tokens = np.interp(np.log10(opt_params), log_params, subset['tokens_trained'])\n",
|
||||
" opt_ratio = np.interp(np.log10(opt_params), log_params, subset['param_data_ratio'])\n",
|
||||
" optimal_by_bpb.append({'flops': flops, 'params': opt_params, 'tokens': opt_tokens, 'ratio': opt_ratio, 'bpb': opt_bpb})\n",
|
||||
" else:\n",
|
||||
" # Fallback to raw minimum if quadratic doesn't have minimum\n",
|
||||
" best_idx = subset['val_bpb'].idxmin()\n",
|
||||
" best = subset.loc[best_idx]\n",
|
||||
" ax.scatter([best['num_scaling_params']], [best['val_bpb']], s=150, color=color,\n",
|
||||
" zorder=5, edgecolors='black', linewidths=2)\n",
|
||||
" optimal_by_bpb.append({'flops': flops, 'params': best['num_scaling_params'],\n",
|
||||
" 'tokens': best['tokens_trained'], 'ratio': best['param_data_ratio'], 'bpb': best['val_bpb']})\n",
|
||||
"\n",
|
||||
"ax.set_xscale('log')\n",
|
||||
"ax.set_xlabel('Parameters')\n",
|
||||
"ax.set_ylabel('Validation Loss (bpb)')\n",
|
||||
"ax.set_title('IsoFLOP Curves')\n",
|
||||
"ax.legend(title='FLOPs', loc='upper right')\n",
|
||||
"ax.grid(True, alpha=0.3)\n",
|
||||
"\n",
|
||||
"opt_df = pd.DataFrame(optimal_by_bpb)\n",
|
||||
"\n",
|
||||
"# Plot 2: Optimal model size vs compute (power law)\n",
|
||||
"ax = axes[1]\n",
|
||||
"ax.loglog(opt_df['flops'], opt_df['params'], 'o', markersize=10, color='#2ecc71')\n",
|
||||
"ax.set_xlabel('FLOPs')\n",
|
||||
"ax.set_ylabel('Optimal Parameters')\n",
|
||||
"ax.set_title('Optimal Model Size')\n",
|
||||
"ax.grid(True, alpha=0.3)\n",
|
||||
"\n",
|
||||
"# Fit and show power law\n",
|
||||
"if len(opt_df) >= 2:\n",
|
||||
" log_f = np.log10(opt_df['flops'])\n",
|
||||
" log_p = np.log10(opt_df['params'])\n",
|
||||
" slope, intercept = np.polyfit(log_f, log_p, 1)\n",
|
||||
" fit_f = np.logspace(log_f.min() - 0.5, log_f.max() + 0.5, 100)\n",
|
||||
" fit_p = 10**(intercept + slope * np.log10(fit_f))\n",
|
||||
" ax.plot(fit_f, fit_p, 'r--', alpha=0.7, label=f'N ∝ C^{slope:.2f}')\n",
|
||||
" ax.legend()\n",
|
||||
"\n",
|
||||
"# Plot 3: Optimal tokens vs compute (power law)\n",
|
||||
"ax = axes[2]\n",
|
||||
"ax.loglog(opt_df['flops'], opt_df['tokens'], 'o', markersize=10, color='#e74c3c')\n",
|
||||
"ax.set_xlabel('FLOPs')\n",
|
||||
"ax.set_ylabel('Optimal Tokens')\n",
|
||||
"ax.set_title('Optimal Training Tokens')\n",
|
||||
"ax.grid(True, alpha=0.3)\n",
|
||||
"\n",
|
||||
"# Fit and show power law\n",
|
||||
"if len(opt_df) >= 2:\n",
|
||||
" log_f = np.log10(opt_df['flops'])\n",
|
||||
" log_t = np.log10(opt_df['tokens'])\n",
|
||||
" slope, intercept = np.polyfit(log_f, log_t, 1)\n",
|
||||
" fit_f = np.logspace(log_f.min() - 0.5, log_f.max() + 0.5, 100)\n",
|
||||
" fit_t = 10**(intercept + slope * np.log10(fit_f))\n",
|
||||
" ax.plot(fit_f, fit_t, 'r--', alpha=0.7, label=f'D ∝ C^{slope:.2f}')\n",
|
||||
" ax.legend()\n",
|
||||
"\n",
|
||||
"plt.tight_layout()\n",
|
||||
"plt.show()\n",
|
||||
"\n",
|
||||
"# Print the optimal points (from quadratic fits)\n",
|
||||
"print(\"\\nOptimal configurations (from quadratic fits):\")\n",
|
||||
"print(f\"{'FLOPs':<12} {'Params':<15} {'Tokens':<15} {'Ratio':<10} {'Val BPB':<10}\")\n",
|
||||
"print(\"-\" * 65)\n",
|
||||
"for _, row in opt_df.iterrows():\n",
|
||||
" print(f\"{row['flops']:<12.0e} {int(row['params']):<15,} {int(row['tokens']):<15,} {row['ratio']:<10.1f} {row['bpb']:<10.4f}\")\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Val BPB vs Depth and Ratio"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"fig, axes = plt.subplots(1, 2, figsize=(14, 5))\n",
|
||||
"\n",
|
||||
"# Plot 1: Val BPB vs Depth\n",
|
||||
"ax = axes[0]\n",
|
||||
"for flops in flops_budgets:\n",
|
||||
" subset = df[df['flops_budget'] == flops].sort_values('depth')\n",
|
||||
" ax.plot(subset['depth'], subset['val_bpb'], 'o-', label=f'{flops:.0e}')\n",
|
||||
" # Mark the best (lowest)\n",
|
||||
" best_idx = subset['val_bpb'].idxmin()\n",
|
||||
" best = subset.loc[best_idx]\n",
|
||||
" ax.scatter([best['depth']], [best['val_bpb']], s=100, zorder=5, edgecolors='black', linewidths=2)\n",
|
||||
"\n",
|
||||
"ax.set_xlabel('Depth')\n",
|
||||
"ax.set_ylabel('Val BPB (lower is better)')\n",
|
||||
"ax.set_title('Validation BPB vs Model Depth')\n",
|
||||
"ax.legend(title='FLOPs')\n",
|
||||
"ax.grid(True, alpha=0.3)\n",
|
||||
"\n",
|
||||
"# Plot 2: Val BPB vs Param:Data Ratio\n",
|
||||
"ax = axes[1]\n",
|
||||
"for flops in flops_budgets:\n",
|
||||
" subset = df[df['flops_budget'] == flops].sort_values('param_data_ratio')\n",
|
||||
" ax.plot(subset['param_data_ratio'], subset['val_bpb'], 'o-', label=f'{flops:.0e}')\n",
|
||||
" best_idx = subset['val_bpb'].idxmin()\n",
|
||||
" best = subset.loc[best_idx]\n",
|
||||
" ax.scatter([best['param_data_ratio']], [best['val_bpb']], s=100, zorder=5, edgecolors='black', linewidths=2)\n",
|
||||
"\n",
|
||||
"ax.axvline(x=20, color='red', linestyle='--', alpha=0.5, label='Chinchilla (20)')\n",
|
||||
"ax.set_xlabel('Param:Data Ratio (tokens/param)')\n",
|
||||
"ax.set_ylabel('Val BPB (lower is better)')\n",
|
||||
"ax.set_title('Val BPB vs Param:Data Ratio')\n",
|
||||
"ax.legend(title='FLOPs')\n",
|
||||
"ax.grid(True, alpha=0.3)\n",
|
||||
"\n",
|
||||
"plt.tight_layout()\n",
|
||||
"plt.show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": ".venv",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.12"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
||||
miniseries.sh (new file, 103 lines):

@@ -0,0 +1,103 @@
|
|||
#!/bin/bash
|
||||
|
||||
# See speedrun.sh for more comments
|
||||
# Usage: ./miniseries.sh [series_name]
|
||||
# Example: ./miniseries.sh jan11
|
||||
# Default series name is today's date (e.g., jan11)
|
||||
|
||||
export OMP_NUM_THREADS=1
|
||||
export NANOCHAT_BASE_DIR="$HOME/.cache/nanochat"
|
||||
mkdir -p $NANOCHAT_BASE_DIR
|
||||
|
||||
# Setup (skip with SKIP_SETUP=1)
|
||||
if [ -z "$SKIP_SETUP" ]; then
|
||||
# uv
|
||||
command -v uv &> /dev/null || curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
[ -d ".venv" ] || uv venv
|
||||
uv sync --extra gpu
|
||||
source .venv/bin/activate
|
||||
|
||||
# Tokenizer, download 1000 shards for pretraining
|
||||
# (probably this can be reduced but it's tricky to determine the exact right number, TODO).
|
||||
python -m nanochat.dataset -n 1000
|
||||
python -m scripts.tok_train --max-chars=2000000000 --vocab-size=32768
|
||||
else
|
||||
source .venv/bin/activate
|
||||
fi
|
||||
|
||||
# Series name: from arg, env var, or default to today's date (e.g., jan11)
|
||||
SERIES_NAME="${1:-${SERIES_NAME:-$(date +%b%d | tr '[:upper:]' '[:lower:]')}}"
|
||||
# Depths to train (the "miniseries")
|
||||
DEPTHS=(10 11 12 13 14 15 16 17 18 19 20)
|
||||
# Hardware
|
||||
NPROC_PER_NODE="${NPROC_PER_NODE:-8}"
|
||||
# Logging
|
||||
WANDB_RUN="${WANDB_RUN:-${SERIES_NAME}_miniseries}"
|
||||
|
||||
RESULTS_DIR="$NANOCHAT_BASE_DIR/${SERIES_NAME}_miniseries_results"
|
||||
mkdir -p "$RESULTS_DIR"
|
||||
RESULTS_FILE="$RESULTS_DIR/results.csv"
|
||||
|
||||
# Write CSV header only if file doesn't exist
|
||||
if [ ! -f "$RESULTS_FILE" ]; then
|
||||
echo "depth,model_dim,num_params,num_scaling_params,num_iterations,tokens_trained,param_data_ratio,val_bpb,core_score,train_time_sec" > "$RESULTS_FILE"
|
||||
fi
|
||||
|
||||
log() {
|
||||
echo "[$(date '+%Y-%m-%d %H:%M:%S')] $1"
|
||||
}
|
||||
|
||||
log "=============================================="
|
||||
log "${SERIES_NAME} Miniseries Training"
|
||||
log "=============================================="
|
||||
|
||||
for d in "${DEPTHS[@]}"; do
|
||||
log "Training d=$d..."
|
||||
|
||||
TAG="${SERIES_NAME}_miniseries_d${d}"
|
||||
START_TIME=$(date +%s)
|
||||
|
||||
# Train the model with natural horizon (target_param_data_ratio default)
|
||||
# No --target-flops, let it use the default ratio from base_train
|
||||
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train -- \
|
||||
--depth=$d \
|
||||
--target-param-data-ratio=8 \
|
||||
--run="${WANDB_RUN}_d${d}" \
|
||||
--model-tag="${TAG}" \
|
||||
--core-metric-every=999999 \
|
||||
--core-metric-max-per-task=-1 \
|
||||
--sample-every=-1 \
|
||||
--save-every=-1 \
|
||||
2>&1 | tee "$RESULTS_DIR/${TAG}_train.log"
|
||||
|
||||
END_TIME=$(date +%s)
|
||||
TRAIN_TIME=$((END_TIME - START_TIME))
|
||||
|
||||
# Extract stats from log
|
||||
LOG_FILE="$RESULTS_DIR/${TAG}_train.log"
|
||||
NUM_PARAMS=$(grep "Number of parameters:" "$LOG_FILE" | tail -1 | grep -oP '[\d,]+' | head -1 | tr -d ',')
|
||||
NUM_SCALING_PARAMS=$(grep "Number of parameters:" "$LOG_FILE" | tail -1 | grep -oP 'scaling: [\d,]+' | grep -oP '[\d,]+' | tr -d ',')
|
||||
NUM_ITERS=$(grep "Calculated number of iterations" "$LOG_FILE" | tail -1 | sed 's/.*: //' | tr -d ',')
|
||||
TOKENS_TRAINED=$((NUM_ITERS * 524288))
|
||||
PARAM_DATA_RATIO=$(python -c "print(f'{$TOKENS_TRAINED / $NUM_SCALING_PARAMS:.2f}')")
|
||||
MODEL_DIM=$((d * 64))
|
||||
VAL_BPB=$(grep "Validation bpb:" "$LOG_FILE" | tail -1 | grep -oP '[\d.]+$')
|
||||
CORE_SCORE=$(grep "CORE metric:" "$LOG_FILE" | tail -1 | awk '{print $NF}')
|
||||
|
||||
if [ -z "$CORE_SCORE" ]; then
|
||||
CORE_SCORE="0.0"
|
||||
fi
|
||||
|
||||
log " d=$d: params=$NUM_PARAMS, scaling=$NUM_SCALING_PARAMS, ratio=$PARAM_DATA_RATIO, bpb=$VAL_BPB, CORE=$CORE_SCORE, time=${TRAIN_TIME}s"
|
||||
|
||||
# Append to CSV
|
||||
echo "$d,$MODEL_DIM,$NUM_PARAMS,$NUM_SCALING_PARAMS,$NUM_ITERS,$TOKENS_TRAINED,$PARAM_DATA_RATIO,$VAL_BPB,$CORE_SCORE,$TRAIN_TIME" >> "$RESULTS_FILE"
|
||||
done
|
||||
|
||||
log "=============================================="
|
||||
log "${SERIES_NAME} Miniseries Complete!"
|
||||
log "=============================================="
|
||||
log "Results saved to: $RESULTS_FILE"
|
||||
echo ""
|
||||
echo "Results:"
|
||||
column -t -s',' "$RESULTS_FILE"
|
||||
|
|
@ -16,23 +16,31 @@ class DistAdamW(torch.optim.Optimizer):
|
|||
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
|
||||
super().__init__(param_groups, defaults)
|
||||
|
||||
@torch.compile
|
||||
@torch.no_grad()
|
||||
def step(self):
|
||||
rank = dist.get_rank()
|
||||
world_size = dist.get_world_size()
|
||||
reduce_scatter_futures: list[torch.Future] = []
|
||||
all_reduce_futures: list[torch.Future] = []
|
||||
reduce_futures: list[torch.Future] = []
|
||||
gather_futures: list[torch.Future] = []
|
||||
grad_slices = []
|
||||
is_small = [] # track which params are small (use all_reduce) vs large (use reduce_scatter)
|
||||
|
||||
for group in self.param_groups:
|
||||
params: list[Tensor] = group["params"]
|
||||
for base_i in range(len(params)):
|
||||
assert params[base_i].shape[0] % world_size == 0, f"First dim of parameter shape {params[base_i].shape} must be divisible by world size {world_size}"
|
||||
grad = params[base_i].grad
|
||||
rank_size = grad.shape[0] // world_size
|
||||
grad_slice = torch.empty_like(grad[:rank_size])
|
||||
reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future())
|
||||
grad_slices.append(grad_slice)
|
||||
for p in params:
|
||||
grad = p.grad
|
||||
# Small params: use all_reduce (no scatter/gather needed)
|
||||
if p.numel() < 1024:
|
||||
is_small.append(True)
|
||||
reduce_futures.append(dist.all_reduce(grad, op=dist.ReduceOp.AVG, async_op=True).get_future())
|
||||
grad_slices.append(grad)
|
||||
else:
|
||||
is_small.append(False)
|
||||
assert p.shape[0] % world_size == 0, f"First dim of parameter shape {p.shape} must be divisible by world size {world_size}"
|
||||
rank_size = grad.shape[0] // world_size
|
||||
grad_slice = torch.empty_like(grad[:rank_size])
|
||||
reduce_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future())
|
||||
grad_slices.append(grad_slice)
|
||||
|
||||
idx = 0
|
||||
for group in self.param_groups:
|
||||
|
|
@ -40,14 +48,19 @@ class DistAdamW(torch.optim.Optimizer):
|
|||
eps = group['eps']
|
||||
wd = group['weight_decay']
|
||||
params = group['params']
|
||||
for base in range(len(params)):
|
||||
reduce_scatter_futures[idx].wait()
|
||||
p = params[base]
|
||||
rank_size = p.shape[0] // world_size
|
||||
p_slice = p[rank * rank_size:(rank + 1) * rank_size]
|
||||
for p in params:
|
||||
reduce_futures[idx].wait()
|
||||
g_slice = grad_slices[idx]
|
||||
lr = group['lr'] * getattr(p, "lr_mul", 1.0)
|
||||
state = self.state[p]
|
||||
g_slice = grad_slices[idx]
|
||||
|
||||
# For small params, operate on full param; for large, operate on slice
|
||||
if is_small[idx]:
|
||||
p_slice = p
|
||||
else:
|
||||
rank_size = p.shape[0] // world_size
|
||||
p_slice = p[rank * rank_size:(rank + 1) * rank_size]
|
||||
|
||||
# State init
|
||||
if not state:
|
||||
state['step'] = torch.tensor(0, dtype=torch.int64, device=p.device)
|
||||
|
|
@ -68,10 +81,15 @@ class DistAdamW(torch.optim.Optimizer):
|
|||
bias1 = 1 - beta1 ** t
|
||||
bias2 = 1 - beta2 ** t
|
||||
# compute step
|
||||
denom = exp_avg_sq.sqrt().add_(eps)
|
||||
step_size = lr * (torch.sqrt(bias2) / bias1)
|
||||
denom = (exp_avg_sq / bias2).sqrt().add_(eps)
|
||||
step_size = lr / bias1
|
||||
update = exp_avg.div(denom).mul_(step_size)
|
||||
p_slice.add_(other=update, alpha=-1.0)
|
||||
|
||||
# Only large params need all_gather
|
||||
if not is_small[idx]:
|
||||
gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future())
|
||||
idx += 1
|
||||
all_reduce_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future())
|
||||
torch.futures.collect_all(all_reduce_futures).wait()
|
||||
|
||||
if gather_futures:
|
||||
torch.futures.collect_all(gather_futures).wait()
|
||||
|
|
|
|||
|
|
@ -20,6 +20,25 @@ def log0(message):
|
|||
if int(os.environ.get('RANK', 0)) == 0:
|
||||
logger.info(message)
|
||||
|
||||
def _patch_missing_config_keys(model_config_kwargs):
|
||||
"""Add default values for new config keys missing in old checkpoints."""
|
||||
# Old models were trained with full context (no sliding window)
|
||||
if "window_pattern" not in model_config_kwargs:
|
||||
model_config_kwargs["window_pattern"] = "L"
|
||||
log0(f"Patching missing window_pattern in model config to 'L'")
|
||||
|
||||
def _patch_missing_keys(model_data, model_config):
|
||||
"""Add default values for new parameters that may be missing in old checkpoints."""
|
||||
n_layer = model_config.n_layer
|
||||
# resid_lambdas defaults to 1.0 (identity scaling)
|
||||
if "resid_lambdas" not in model_data:
|
||||
model_data["resid_lambdas"] = torch.ones(n_layer)
|
||||
log0(f"Patching missing resid_lambdas in model data to 1.0")
|
||||
# x0_lambdas defaults to 0.0 (disabled)
|
||||
if "x0_lambdas" not in model_data:
|
||||
model_data["x0_lambdas"] = torch.zeros(n_layer)
|
||||
log0(f"Patching missing x0_lambdas in model data to 0.0")
|
||||
|
||||
def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data, rank=0):
|
||||
if rank == 0:
|
||||
os.makedirs(checkpoint_dir, exist_ok=True)
|
||||
|
|
@ -74,8 +93,10 @@ def build_model(checkpoint_dir, step, device, phase):
|
|||
# Hack: fix torch compile issue, which prepends all keys with _orig_mod.
|
||||
model_data = {k.removeprefix("_orig_mod."): v for k, v in model_data.items()}
|
||||
model_config_kwargs = meta_data["model_config"]
|
||||
_patch_missing_config_keys(model_config_kwargs)
|
||||
log0(f"Building model with config: {model_config_kwargs}")
|
||||
model_config = GPTConfig(**model_config_kwargs)
|
||||
_patch_missing_keys(model_data, model_config)
|
||||
with torch.device("meta"):
|
||||
model = GPT(model_config)
|
||||
# Load the model state
|
||||
|
|
@ -90,7 +111,7 @@ def build_model(checkpoint_dir, step, device, phase):
|
|||
# Load the Tokenizer
|
||||
tokenizer = get_tokenizer()
|
||||
# Sanity check: compatibility between model and tokenizer
|
||||
assert tokenizer.get_vocab_size() == model_config_kwargs["vocab_size"]
|
||||
assert tokenizer.get_vocab_size() == model_config_kwargs["vocab_size"], f"Tokenizer vocab size {tokenizer.get_vocab_size()} does not match model config vocab size {model_config_kwargs['vocab_size']}"
|
||||
return model, tokenizer, meta_data
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,56 +0,0 @@
|
|||
"""
|
||||
Poor Man's Configurator. Probably a terrible idea. Example usage:
|
||||
$ python train.py config/override_file.py --batch_size=32
|
||||
this will first run config/override_file.py, then override batch_size to 32
|
||||
|
||||
The code in this file will be run as follows from e.g. train.py:
|
||||
>>> exec(open('configurator.py').read())
|
||||
|
||||
So it's not a Python module, it's just shuttling this code away from train.py
|
||||
The code in this script then overrides the globals()
|
||||
|
||||
I know people are not going to love this, I just really dislike configuration
|
||||
complexity and having to prepend config. to every single variable. If someone
|
||||
comes up with a better simple Python solution I am all ears.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
from ast import literal_eval
|
||||
|
||||
def print0(s="",**kwargs):
|
||||
ddp_rank = int(os.environ.get('RANK', 0))
|
||||
if ddp_rank == 0:
|
||||
print(s, **kwargs)
|
||||
|
||||
for arg in sys.argv[1:]:
|
||||
if '=' not in arg:
|
||||
# assume it's the name of a config file
|
||||
assert not arg.startswith('--')
|
||||
config_file = arg
|
||||
print0(f"Overriding config with {config_file}:")
|
||||
with open(config_file) as f:
|
||||
print0(f.read())
|
||||
exec(open(config_file).read())
|
||||
else:
|
||||
# assume it's a --key=value argument
|
||||
assert arg.startswith('--')
|
||||
key, val = arg.split('=')
|
||||
key = key[2:]
|
||||
if key in globals():
|
||||
try:
|
||||
# attempt to eval it it (e.g. if bool, number, or etc)
|
||||
attempt = literal_eval(val)
|
||||
except (SyntaxError, ValueError):
|
||||
# if that goes wrong, just use the string
|
||||
attempt = val
|
||||
# ensure the types match ok
|
||||
if globals()[key] is not None:
|
||||
attempt_type = type(attempt)
|
||||
default_type = type(globals()[key])
|
||||
assert attempt_type == default_type, f"Type mismatch: {attempt_type} != {default_type}"
|
||||
# cross fingers
|
||||
print0(f"Overriding: {key} = {attempt}")
|
||||
globals()[key] = attempt
|
||||
else:
|
||||
raise ValueError(f"Unknown config key: {key}")
|
||||
|
|
@ -1,94 +1,198 @@
|
|||
from collections import deque
"""
Distributed dataloaders for pretraining.

Two implementations are provided:

1. Original (tokenizing_distributed_data_loader):
   - Streams tokens into a flat buffer, reshapes to (B, T)
   - Rows may start mid-document (no guaranteed BOS at position 0)
   - 100% token utilization, simple and efficient

2. BOS-aligned bestfit (tokenizing_distributed_data_loader_bos_bestfit):
   - Every row starts with a BOS token
   - Documents are packed using a best-fit algorithm to minimize cropping
   - When no document fits the remaining space, a document is cropped to fill it exactly
   - 100% utilization (no padding), ~35% of tokens cropped at T=2048

The tradeoff: BOS-aligned loses ~35% of tokens to cropping, but ensures that
there are fewer "confusing" tokens in the train/val batches, as every token can
now attend back to the BOS token and see the full context of the document.
(2) is the new default if you have enough data.
Fall back to (1) if you have very limited data AND long documents.
"""

import torch
import pyarrow.parquet as pq

from nanochat.common import get_dist_info
from nanochat.dataset import list_parquet_files
from nanochat.tokenizer import get_tokenizer
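To make the best-fit cropping concrete, here is a toy, self-contained sketch of the packing rule the docstring describes; `pack_row` and the toy token lists are illustrative only and not part of the repo (the real loader refills its buffer from the parquet stream instead of a fixed list):

```python
# Toy sketch of BOS-aligned best-fit packing: pick the largest document that fits,
# and when nothing fits, crop one document to fill the row exactly (no padding).
def pack_row(docs, capacity):
    row = []
    while len(row) < capacity:
        remaining = capacity - len(row)
        fitting = [d for d in docs if len(d) <= remaining]
        if fitting:
            doc = max(fitting, key=len)    # largest doc that fits entirely
            docs.remove(doc)
        else:
            doc = docs.pop(0)[:remaining]  # crop to fill the remaining space
        row.extend(doc)
    return row

docs = [[1, 2, 3], [4, 5], [6, 7, 8, 9], [10]]
print(pack_row(docs, capacity=6))  # [6, 7, 8, 9, 4, 5]
```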
def tokenizing_distributed_data_loader_with_state(B, T, split, tokenizer_threads=4, tokenizer_batch_size=128, device="cuda", resume_state_dict=None):
|
||||
def _document_batches(split, resume_state_dict, tokenizer_batch_size):
|
||||
"""
|
||||
Infinite iterator over document batches (list of text strings) from parquet files.
|
||||
|
||||
Handles DDP sharding and approximate resume. Each yield is (text_batch, (pq_idx, rg_idx, epoch))
|
||||
where text_batch is a list of document strings, indices track position for resumption,
|
||||
and epoch counts how many times we've cycled through the dataset (starts at 1).
|
||||
"""
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
|
||||
|
||||
parquet_paths = list_parquet_files()
|
||||
assert len(parquet_paths) != 0, "No dataset parquet files found, did you run dataset.py?"
|
||||
parquet_paths = parquet_paths[:-1] if split == "train" else parquet_paths[-1:]
|
||||
|
||||
resume_pq_idx = resume_state_dict["pq_idx"] if resume_state_dict is not None else 0
|
||||
resume_rg_idx = resume_state_dict["rg_idx"] if resume_state_dict is not None else None
|
||||
resume_epoch = resume_state_dict.get("epoch", 1) if resume_state_dict is not None else 1
|
||||
first_pass = True
|
||||
pq_idx = resume_pq_idx
|
||||
epoch = resume_epoch
|
||||
|
||||
while True: # iterate infinitely (multi-epoch)
|
||||
pq_idx = resume_pq_idx if first_pass else 0
|
||||
while pq_idx < len(parquet_paths):
|
||||
filepath = parquet_paths[pq_idx]
|
||||
pf = pq.ParquetFile(filepath)
|
||||
# Start from resume point if resuming on same file, otherwise from DDP rank
|
||||
if first_pass and (resume_rg_idx is not None) and (pq_idx == resume_pq_idx):
|
||||
base_idx = resume_rg_idx // ddp_world_size
|
||||
base_idx += 1 # advance by 1 so we don't repeat data after resuming
|
||||
rg_idx = base_idx * ddp_world_size + ddp_rank
|
||||
if rg_idx >= pf.num_row_groups:
|
||||
pq_idx += 1
|
||||
continue
|
||||
resume_rg_idx = None # only do this once
|
||||
else:
|
||||
rg_idx = ddp_rank
|
||||
while rg_idx < pf.num_row_groups:
|
||||
rg = pf.read_row_group(rg_idx)
|
||||
batch = rg.column('text').to_pylist()
|
||||
for i in range(0, len(batch), tokenizer_batch_size):
|
||||
yield batch[i:i+tokenizer_batch_size], (pq_idx, rg_idx, epoch)
|
||||
rg_idx += ddp_world_size
|
||||
pq_idx += 1
|
||||
first_pass = False
|
||||
epoch += 1
|
||||
|
||||
|
||||
def tokenizing_distributed_data_loader_with_state(tokenizer, B, T, split, tokenizer_threads=4, tokenizer_batch_size=128, device="cuda", resume_state_dict=None):
|
||||
"""
|
||||
Stream pretraining text from parquet files, tokenize, yield training batches.
|
||||
|
||||
This implementation became a bit more complex because we wish to support approximate resume training.
|
||||
Instead of turning this into a Class, we opt to return the state_dict with every batch,
|
||||
and then the caller can pass in a state_dict to resume training from a desired point.
|
||||
Note that this resumption is atm only *approximate* for simplicity.
|
||||
We won't repeat the same documents but we might skip a few.
|
||||
The state_dict that is returned can be later passed into this function via `resume_state_dict` to approximately resume.
|
||||
This is the original dataloader that streams tokens into a flat buffer and reshapes.
|
||||
Rows may start mid-document (no guaranteed BOS at position 0).
|
||||
|
||||
Perfect state resumption is possible but would be a lot more bloated, probably not worth it atm.
|
||||
Supports approximate resume via state_dict.
|
||||
"""
|
||||
assert split in ["train", "val"], "split must be 'train' or 'val'"
|
||||
|
||||
# infinite iterator over document batches (list of text strings)
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
|
||||
def document_batches():
|
||||
parquet_paths = list_parquet_files()
|
||||
assert len(parquet_paths) != 0, "No dataset parquet files found, did you run dataset.py?"
|
||||
parquet_paths = parquet_paths[:-1] if split == "train" else parquet_paths[-1:]
|
||||
resume_pq_idx = resume_state_dict["pq_idx"] if resume_state_dict is not None else 0
|
||||
resume_rg_idx = resume_state_dict["rg_idx"] if resume_state_dict is not None else None
|
||||
first_pass = True
|
||||
pq_idx = resume_pq_idx # we kick off parquet files at the resume index (or by default just 0)
|
||||
while True: # iterate infinitely (multi-epoch)
|
||||
pq_idx = resume_pq_idx if first_pass else 0
|
||||
while pq_idx < len(parquet_paths): # iterate over all parquet files
|
||||
filepath = parquet_paths[pq_idx]
|
||||
pf = pq.ParquetFile(filepath)
|
||||
# Start from resume point if resuming on same file, otherwise from DDP rank
|
||||
# I know this state resumption is a little bit tricky and a little bit hacky... sigh.
|
||||
if first_pass and (resume_rg_idx is not None) and (pq_idx == resume_pq_idx):
|
||||
base_idx = resume_rg_idx // ddp_world_size # in units of ddp_world_size
|
||||
base_idx += 1 # advance by 1 so that we definitely don't repeat data after resuming
|
||||
rg_idx = base_idx * ddp_world_size + ddp_rank
|
||||
if rg_idx >= pf.num_row_groups:
|
||||
pq_idx += 1
|
||||
continue
|
||||
resume_rg_idx = None # set to None as we only want to do this a single time
|
||||
else:
|
||||
rg_idx = ddp_rank
|
||||
while rg_idx < pf.num_row_groups:
|
||||
rg = pf.read_row_group(rg_idx)
|
||||
batch = rg.column('text').to_pylist() # each batch is a parquet group, e.g. 1024 rows
|
||||
# the tokenizer encode might want to go in even smaller batches, e.g. 128 rows
|
||||
for i in range(0, len(batch), tokenizer_batch_size):
|
||||
yield batch[i:i+tokenizer_batch_size], (pq_idx, rg_idx)
|
||||
rg_idx += ddp_world_size # advance to the next row group (in DDP)
|
||||
pq_idx += 1 # advance to the next parquet file
|
||||
first_pass = False
|
||||
batches = document_batches()
|
||||
|
||||
# Now emit batches of tokens.
|
||||
needed_tokens = B * T + 1 # +1 is because we also need the target at the last token
|
||||
# get the tokenizer and the bos token
|
||||
tokenizer = get_tokenizer()
|
||||
batches = _document_batches(split, resume_state_dict, tokenizer_batch_size)
|
||||
needed_tokens = B * T + 1 # +1 for target at last position
|
||||
bos_token = tokenizer.get_bos_token_id()
|
||||
# scratch buffer holds the tokens for one iteration
|
||||
token_buffer = deque() # we stream tokens on the right and pop from the left
|
||||
token_buffer = []
|
||||
pq_idx, rg_idx, epoch = 0, 0, 1
|
||||
|
||||
while True:
|
||||
# Accumulate enough tokens for one iteration before yielding.
|
||||
|
||||
# Accumulate enough tokens
|
||||
while len(token_buffer) < needed_tokens:
|
||||
doc_batch, (pq_idx, rg_idx) = next(batches)
|
||||
doc_batch, (pq_idx, rg_idx, epoch) = next(batches)
|
||||
token_lists = tokenizer.encode(doc_batch, prepend=bos_token, num_threads=tokenizer_threads)
|
||||
for tokens in token_lists:
|
||||
token_buffer.extend(tokens)
|
||||
# Move tokens from the deque into the scratch buffer
|
||||
tokens = [token_buffer.popleft() for _ in range(needed_tokens)]
|
||||
# CUDA supports memory pinning for asynchronous transfers between CPU and GPU
|
||||
use_cuda_optimizations = device == "cuda"
|
||||
scratch = torch.tensor(tokens, dtype=torch.long, pin_memory=use_cuda_optimizations) # in PyTorch, long=int64
|
||||
# Create the inputs/targets as 1D tensors
|
||||
inputs_cpu = scratch[:-1]
|
||||
targets_cpu = scratch[1:]
|
||||
# Reshape to 2D and move to GPU async
|
||||
inputs = inputs_cpu.view(B, T).to(device=device, non_blocking=use_cuda_optimizations)
|
||||
targets = targets_cpu.view(B, T).to(device=device, non_blocking=use_cuda_optimizations)
|
||||
state_dict = {"pq_idx": pq_idx, "rg_idx": rg_idx} # we need this in case we wish to approximately resume training
|
||||
yield inputs, targets, state_dict
|
||||
tokens = token_buffer[:needed_tokens] # Read B*T+1 tokens (+1 is only for the target for the last token)
|
||||
token_buffer = token_buffer[B*T:] # Advance by B*T tokens, so we move exactly one window of B*T tokens over
|
||||
|
||||
# Package tokens into inputs and targets, yield
|
||||
use_cuda = device == "cuda"
|
||||
scratch = torch.tensor(tokens, dtype=torch.long, pin_memory=use_cuda)
|
||||
inputs = scratch[:-1].view(B, T).to(device=device, non_blocking=use_cuda)
|
||||
targets = scratch[1:].view(B, T).to(device=device, non_blocking=use_cuda)
|
||||
yield inputs, targets, {"pq_idx": pq_idx, "rg_idx": rg_idx, "epoch": epoch}
|
||||
|
||||
|
||||
def tokenizing_distributed_data_loader(*args, **kwargs):
|
||||
# helper function that only emits the inputs/targets and not the state_dict
|
||||
"""Helper that omits state_dict from yields."""
|
||||
for inputs, targets, state_dict in tokenizing_distributed_data_loader_with_state(*args, **kwargs):
|
||||
yield inputs, targets
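As a usage sketch of the approximate-resume contract described above (keyword names follow this diff; the B/T values and step count are made up, and the snippet assumes a tokenizer and dataset are already set up):

```python
# Hedged sketch: consume a few batches, keep the last state_dict, then resume from it.
loader = tokenizing_distributed_data_loader_with_state(tokenizer, B=4, T=2048, split="train")
state = None
for step, (inputs, targets, state) in zip(range(100), loader):
    ...  # training step
# state looks like {"pq_idx": ..., "rg_idx": ..., "epoch": ...}; persist it with the checkpoint.
resumed_loader = tokenizing_distributed_data_loader_with_state(
    tokenizer, B=4, T=2048, split="train", resume_state_dict=state,
)
```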
|
||||
|
||||
|
||||
def tokenizing_distributed_data_loader_with_state_bos_bestfit(
|
||||
tokenizer, B, T, split,
|
||||
tokenizer_threads=4, tokenizer_batch_size=128,
|
||||
device="cuda", resume_state_dict=None,
|
||||
buffer_size=1000
|
||||
):
|
||||
"""
|
||||
BOS-aligned dataloader with Best-Fit Cropping.
|
||||
|
||||
Reduces token waste compared to simple greedy cropping by searching a buffer
|
||||
for documents that fit well, while maintaining 100% utilization (no padding).
|
||||
|
||||
Algorithm for each row:
|
||||
1. From buffered docs, pick the LARGEST doc that fits entirely
|
||||
2. Repeat until no doc fits
|
||||
3. When nothing fits, crop a doc to fill remaining space exactly
|
||||
|
||||
Key properties:
|
||||
- Every row starts with BOS
|
||||
- 100% utilization (no padding, every token is trained on)
|
||||
- Approximately 35% of all tokens are discarded due to cropping
|
||||
"""
|
||||
assert split in ["train", "val"], "split must be 'train' or 'val'"
|
||||
|
||||
row_capacity = T + 1
|
||||
batches = _document_batches(split, resume_state_dict, tokenizer_batch_size)
|
||||
bos_token = tokenizer.get_bos_token_id()
|
||||
doc_buffer = []
|
||||
pq_idx, rg_idx, epoch = 0, 0, 1
|
||||
|
||||
def refill_buffer():
|
||||
nonlocal pq_idx, rg_idx, epoch
|
||||
doc_batch, (pq_idx, rg_idx, epoch) = next(batches)
|
||||
token_lists = tokenizer.encode(doc_batch, prepend=bos_token, num_threads=tokenizer_threads)
|
||||
for tokens in token_lists:
|
||||
doc_buffer.append(tokens)
|
||||
|
||||
while True:
|
||||
rows = []
|
||||
for _ in range(B):
|
||||
row = []
|
||||
while len(row) < row_capacity:
|
||||
# Ensure buffer has documents
|
||||
while len(doc_buffer) < buffer_size:
|
||||
refill_buffer()
|
||||
|
||||
remaining = row_capacity - len(row)
|
||||
|
||||
# Find largest doc that fits entirely
|
||||
best_idx = -1
|
||||
best_len = 0
|
||||
for i, doc in enumerate(doc_buffer):
|
||||
doc_len = len(doc)
|
||||
if doc_len <= remaining and doc_len > best_len:
|
||||
best_idx = i
|
||||
best_len = doc_len
|
||||
|
||||
if best_idx >= 0:
|
||||
doc = doc_buffer.pop(best_idx)
|
||||
row.extend(doc)
|
||||
else:
|
||||
# No doc fits - crop first doc to fill remaining
|
||||
doc = doc_buffer.pop(0)
|
||||
row.extend(doc[:remaining])
|
||||
|
||||
rows.append(row[:row_capacity])
|
||||
|
||||
use_cuda = device == "cuda"
|
||||
batch_tensor = torch.tensor(rows, dtype=torch.long, pin_memory=use_cuda)
|
||||
inputs = batch_tensor[:, :-1].to(device=device, non_blocking=use_cuda)
|
||||
targets = batch_tensor[:, 1:].to(device=device, non_blocking=use_cuda)
|
||||
|
||||
yield inputs, targets, {"pq_idx": pq_idx, "rg_idx": rg_idx, "epoch": epoch}
|
||||
|
||||
|
||||
def tokenizing_distributed_data_loader_bos_bestfit(*args, **kwargs):
|
||||
"""Helper that omits state_dict from yields."""
|
||||
for inputs, targets, state_dict in tokenizing_distributed_data_loader_with_state_bos_bestfit(*args, **kwargs):
|
||||
yield inputs, targets
@@ -82,83 +82,54 @@ def use_calculator(expr):
# -----------------------------------------------------------------------------
|
||||
class KVCache:
|
||||
"""
|
||||
Works hand-in-hand with the GPT model to maintain the KV cache.
|
||||
Note that the .pos advances automatically after the last layer of the Transformer inserts.
|
||||
KV Cache designed for Flash Attention 3's flash_attn_with_kvcache API.
|
||||
|
||||
Key differences from FA2-style cache:
|
||||
- Tensors are (B, T, H, D) not (B, H, T, D)
|
||||
- FA3 updates the cache in-place during flash_attn_with_kvcache
|
||||
- Position tracked per batch element via cache_seqlens tensor
|
||||
"""
|
||||
|
||||
def __init__(self, batch_size, num_heads, seq_len, head_dim, num_layers):
|
||||
# Each of K/V is of shape (B, H, T, D) and we have one per layer of the Transformer.
|
||||
self.kv_shape = (num_layers, 2, batch_size, num_heads, seq_len, head_dim)
|
||||
self.kv_cache = None
|
||||
self.pos = 0 # current position in time in the cache
|
||||
def __init__(self, batch_size, num_heads, seq_len, head_dim, num_layers, device, dtype=torch.bfloat16):
|
||||
self.batch_size = batch_size
|
||||
self.max_seq_len = seq_len
|
||||
self.n_layers = num_layers
|
||||
self.n_heads = num_heads
|
||||
self.head_dim = head_dim
|
||||
# Pre-allocate cache tensors: (n_layers, B, T, H, D)
|
||||
self.k_cache = torch.zeros(num_layers, batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype)
|
||||
self.v_cache = torch.zeros(num_layers, batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype)
|
||||
# Current sequence length per batch element (FA3 needs int32)
|
||||
self.cache_seqlens = torch.zeros(batch_size, dtype=torch.int32, device=device)
|
||||
|
||||
def reset(self):
|
||||
self.pos = 0
|
||||
"""Reset cache to empty state."""
|
||||
self.cache_seqlens.zero_()
|
||||
|
||||
def get_pos(self):
|
||||
return self.pos
|
||||
"""Get current position (assumes all batch elements at same position)."""
|
||||
return self.cache_seqlens[0].item()
|
||||
|
||||
def get_layer_cache(self, layer_idx):
|
||||
"""Return (k_cache, v_cache) views for a specific layer."""
|
||||
return self.k_cache[layer_idx], self.v_cache[layer_idx]
|
||||
|
||||
def advance(self, num_tokens):
|
||||
"""Advance the cache position by num_tokens."""
|
||||
self.cache_seqlens += num_tokens
|
||||
|
||||
def prefill(self, other):
|
||||
"""
|
||||
Prefill given another KV cache. Optionally expand along batch dim.
|
||||
This is used when we do batch 1 prefill and then want to generate
|
||||
multiple samples in parallel from there.
|
||||
Copy cached KV from another cache into this one.
|
||||
Used when we do batch=1 prefill and then want to generate multiple samples in parallel.
|
||||
"""
|
||||
# 1) validate the shapes
|
||||
assert self.kv_cache is None, "Cannot prefill a non-empty KV cache"
|
||||
assert other.kv_cache is not None, "Cannot prefill with a None KV cache"
|
||||
|
||||
# Extract dimensions explicitly
|
||||
self_layers, self_kv, self_batch, self_heads, self_seq, self_head_dim = self.kv_shape
|
||||
other_layers, other_kv, other_batch, other_heads, other_seq, other_head_dim = other.kv_shape
|
||||
|
||||
# Validate dimensions
|
||||
assert self_layers == other_layers, f"Layer count mismatch: {self_layers} != {other_layers}"
|
||||
assert self_kv == other_kv, f"K/V dimension mismatch: {self_kv} != {other_kv}"
|
||||
assert self_heads == other_heads, f"Head count mismatch: {self_heads} != {other_heads}"
|
||||
assert self_head_dim == other_head_dim, f"Head dim mismatch: {self_head_dim} != {other_head_dim}"
|
||||
|
||||
# Batch size can be expanded (other can be 1, self can be larger)
|
||||
assert self_batch == other_batch or other_batch == 1, f"Batch size mismatch: {self_batch} vs {other_batch} (other must be 1 or equal)"
|
||||
|
||||
# Sequence length: self must be longer than other
|
||||
assert self_seq >= other_seq, f"Sequence length mismatch: {self_seq} < {other_seq}"
|
||||
|
||||
# 2) initialize the cache
|
||||
dtype, device = other.kv_cache.dtype, other.kv_cache.device
|
||||
self.kv_cache = torch.empty(self.kv_shape, dtype=dtype, device=device)
|
||||
# 3) copy the data over
|
||||
self.kv_cache[:, :, :, :, :other.pos, :] = other.kv_cache
|
||||
# 4) update the pos
|
||||
self.pos = other.pos
|
||||
|
||||
def insert_kv(self, layer_idx, k, v):
|
||||
# Lazy initialize the cache here because we need to know the dtype/device
|
||||
if self.kv_cache is None:
|
||||
self.kv_cache = torch.empty(self.kv_shape, dtype=k.dtype, device=k.device)
|
||||
# Insert new keys/values to the cache and return the full cache so far
|
||||
B, H, T_add, D = k.size()
|
||||
t0, t1 = self.pos, self.pos + T_add
|
||||
# Dynamically grow the cache if needed
|
||||
if t1 > self.kv_cache.size(4):
|
||||
t_needed = t1 + 1024 # as much as we need plus buffer of 1024
|
||||
t_needed = (t_needed + 1023) & ~1023 # then round up to the nearest multiple of 1024
|
||||
additional_shape = list(self.kv_cache.shape)
|
||||
additional_shape[4] = t_needed - self.kv_cache.size(4)
|
||||
additional_cache = torch.empty(additional_shape, dtype=k.dtype, device=k.device)
|
||||
self.kv_cache = torch.cat([self.kv_cache, additional_cache], dim=4).contiguous()
|
||||
self.kv_shape = self.kv_cache.shape
|
||||
# Insert k, v into the cache
|
||||
self.kv_cache[layer_idx, 0, :, :, t0:t1, :] = k
|
||||
self.kv_cache[layer_idx, 1, :, :, t0:t1, :] = v
|
||||
# Return the full cached keys/values up to current position (as a view)
|
||||
key_view = self.kv_cache[layer_idx, 0, :, :, :t1, :]
|
||||
value_view = self.kv_cache[layer_idx, 1, :, :, :t1, :]
|
||||
# Increment pos after the last layer of the Transformer processes
|
||||
if layer_idx == self.kv_cache.size(0) - 1:
|
||||
self.pos = t1
|
||||
return key_view, value_view
|
||||
|
||||
assert self.get_pos() == 0, "Cannot prefill a non-empty KV cache"
|
||||
assert self.n_layers == other.n_layers and self.n_heads == other.n_heads and self.head_dim == other.head_dim
|
||||
assert self.max_seq_len >= other.max_seq_len
|
||||
other_pos = other.get_pos()
|
||||
self.k_cache[:, :, :other_pos, :, :] = other.k_cache[:, :, :other_pos, :, :]
|
||||
self.v_cache[:, :, :other_pos, :, :] = other.v_cache[:, :, :other_pos, :, :]
|
||||
self.cache_seqlens.fill_(other_pos)
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@torch.inference_mode()
@@ -167,7 +138,7 @@ def sample_next_token(logits, rng, temperature=1.0, top_k=None):
assert temperature >= 0.0, "temperature must be non-negative"
|
||||
if temperature == 0.0:
|
||||
return torch.argmax(logits, dim=-1, keepdim=True)
|
||||
if top_k is not None:
|
||||
if top_k is not None and top_k > 0:
|
||||
k = min(top_k, logits.size(-1))
|
||||
vals, idx = torch.topk(logits, k, dim=-1)
|
||||
vals = vals / temperature
|
@@ -219,6 +190,7 @@ class Engine:
kv_cache_prefill = KVCache(
|
||||
batch_size=1,
|
||||
seq_len=len(tokens),
|
||||
device=device,
|
||||
**kv_model_kwargs,
|
||||
)
|
||||
ids = torch.tensor([tokens], dtype=torch.long, device=device)
@@ -230,6 +202,7 @@ class Engine:
kv_cache_decode = KVCache(
|
||||
batch_size=num_samples,
|
||||
seq_len=kv_length_hint,
|
||||
device=device,
|
||||
**kv_model_kwargs,
|
||||
)
|
||||
kv_cache_decode.prefill(kv_cache_prefill)
236
nanochat/gpt.py
@@ -9,9 +9,9 @@ Notable features:
- no learnable params in rmsnorm
|
||||
- no bias in linear layers
|
||||
- Group-Query Attention (GQA) support for more efficient inference
|
||||
- Flash Attention 3 integration
|
||||
"""
|
||||
|
||||
import math
|
||||
from functools import partial
|
||||
from dataclasses import dataclass
@@ -23,6 +23,14 @@ from nanochat.common import get_dist_info, print0
from nanochat.muon import Muon, DistMuon
|
||||
from nanochat.adamw import DistAdamW
|
||||
|
||||
# Load Flash Attention 3 from HuggingFace Hub (and silence the progress bar)
|
||||
import os
|
||||
os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
|
||||
# Official docs of FA3 label it as "beta" and want you to install FA3 from source, which is a pain.
|
||||
# Wishing for official FA3 wheels soon, for now this seems to be a fast way to get them (ty varunneal)
|
||||
from kernels import get_kernel
|
||||
flash_attn = get_kernel('varunneal/flash-attention-3').flash_attn_interface
|
||||
|
||||
@dataclass
|
||||
class GPTConfig:
|
||||
sequence_len: int = 1024
@@ -31,6 +39,10 @@ class GPTConfig:
n_head: int = 6 # number of query heads
|
||||
n_kv_head: int = 6 # number of key/value heads (GQA)
|
||||
n_embd: int = 768
|
||||
# Sliding window attention pattern string, tiled across layers. Final layer always L.
|
||||
# Characters: L=long (full context), S=short (half context)
|
||||
# Examples: "L"=all full context, "SL"=alternating, "SSL"=two short then one long
|
||||
window_pattern: str = "L"
|
||||
|
||||
|
||||
def norm(x):
@@ -61,48 +73,42 @@ class CausalSelfAttention(nn.Module):
self.c_v = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
|
||||
self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)
|
||||
|
||||
def forward(self, x, cos_sin, kv_cache):
|
||||
def forward(self, x, cos_sin, window_size, kv_cache):
|
||||
B, T, C = x.size()
|
||||
|
||||
# Project the input to get queries, keys, and values
|
||||
# Shape: (B, T, H, D) - FA3's native layout, no transpose needed!
|
||||
q = self.c_q(x).view(B, T, self.n_head, self.head_dim)
|
||||
k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim)
|
||||
v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim)
|
||||
|
||||
# Apply Rotary Embeddings to queries and keys to get relative positional encoding
|
||||
cos, sin = cos_sin
|
||||
q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin) # QK rotary embedding
|
||||
q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin)
|
||||
q, k = norm(q), norm(k) # QK norm
|
||||
q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) # make head be batch dim, i.e. (B, T, H, D) -> (B, H, T, D)
|
||||
|
||||
# Apply KV cache: insert current k,v into cache, get the full view so far
|
||||
if kv_cache is not None:
|
||||
k, v = kv_cache.insert_kv(self.layer_idx, k, v)
|
||||
Tq = q.size(2) # number of queries in this forward pass
|
||||
Tk = k.size(2) # number of keys/values in total (in the cache + current forward pass)
|
||||
|
||||
# Attention: queries attend to keys/values autoregressively. A few cases to handle:
|
||||
enable_gqa = self.n_head != self.n_kv_head # Group Query Attention (GQA): duplicate key/value heads to match query heads if desired
|
||||
if kv_cache is None or Tq == Tk:
|
||||
# During training (no KV cache), attend as usual with causal attention
|
||||
# And even if there is KV cache, we can still use this simple version when Tq == Tk
|
||||
y = F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=enable_gqa)
|
||||
elif Tq == 1:
|
||||
# During inference but with a single query in this forward pass:
|
||||
# The query has to attend to all the keys/values in the cache
|
||||
y = F.scaled_dot_product_attention(q, k, v, is_causal=False, enable_gqa=enable_gqa)
|
||||
# Attention with Flash Attention 3
|
||||
# FA3 handles GQA automatically when n_kv_heads < n_heads
|
||||
# window_size is (left, right) tuple: (N, 0) for causal, (-1, 0) for full context
|
||||
if kv_cache is None:
|
||||
# Training: causal attention with optional sliding window
|
||||
y = flash_attn.flash_attn_func(q, k, v, causal=True, window_size=window_size)
|
||||
else:
|
||||
# During inference AND we have a chunk of queries in this forward pass:
|
||||
# First, each query attends to all the cached keys/values (i.e. full prefix)
|
||||
attn_mask = torch.zeros((Tq, Tk), dtype=torch.bool, device=q.device) # True = keep, False = mask
|
||||
prefix_len = Tk - Tq
|
||||
attn_mask[:, :prefix_len] = True
|
||||
# Then, causal attention within this chunk
|
||||
attn_mask[:, prefix_len:] = torch.tril(torch.ones((Tq, Tq), dtype=torch.bool, device=q.device))
|
||||
y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, enable_gqa=enable_gqa)
|
||||
# Inference: use flash_attn_with_kvcache which handles cache management
|
||||
k_cache, v_cache = kv_cache.get_layer_cache(self.layer_idx)
|
||||
y = flash_attn.flash_attn_with_kvcache(
|
||||
q, k_cache, v_cache,
|
||||
k=k, v=v,
|
||||
cache_seqlens=kv_cache.cache_seqlens,
|
||||
causal=True,
|
||||
window_size=window_size,
|
||||
)
|
||||
# Advance position after last layer processes
|
||||
if self.layer_idx == kv_cache.n_layers - 1:
|
||||
kv_cache.advance(T)
|
||||
|
||||
# Re-assemble the heads side by side and project back to residual stream
|
||||
y = y.transpose(1, 2).contiguous().view(B, T, -1)
|
||||
# Re-assemble the heads and project back to residual stream
|
||||
y = y.contiguous().view(B, T, -1)
|
||||
y = self.c_proj(y)
|
||||
return y
@@ -126,29 +132,43 @@ class Block(nn.Module):
self.attn = CausalSelfAttention(config, layer_idx)
|
||||
self.mlp = MLP(config)
|
||||
|
||||
def forward(self, x, cos_sin, kv_cache):
|
||||
x = x + self.attn(norm(x), cos_sin, kv_cache)
|
||||
def forward(self, x, cos_sin, window_size, kv_cache):
|
||||
x = x + self.attn(norm(x), cos_sin, window_size, kv_cache)
|
||||
x = x + self.mlp(norm(x))
|
||||
return x
|
||||
|
||||
|
||||
class GPT(nn.Module):
|
||||
def __init__(self, config, pad_vocab_size_to=64):
|
||||
"""
|
||||
NOTE a major footgun: this __init__ function runs in meta device context (!!)
|
||||
Therefore, any calculations inside here are shapes and dtypes only, no actual data.
|
||||
=> We actually initialize all data (parameters, buffers, etc.) in init_weights() instead.
|
||||
"""
|
||||
super().__init__()
|
||||
self.config = config
|
||||
# For DDP, we want vocab_size divisible by world_size. Also, there are potential performance benefits, see:
|
||||
# Compute per-layer window sizes for sliding window attention
|
||||
# window_size is (left, right) tuple: (-1, 0) for full context, (N, 0) for sliding window
|
||||
self.window_sizes = self._compute_window_sizes(config)
|
||||
# Pad vocab for efficiency (DDP, tensor cores). This is just an optimization - outputs are cropped in forward().
|
||||
# https://huggingface.co/docs/transformers/main_classes/model#transformers.PreTrainedModel.resize_token_embeddings
|
||||
padded_vocab_size = ((config.vocab_size + pad_vocab_size_to - 1) // pad_vocab_size_to) * pad_vocab_size_to
|
||||
if padded_vocab_size != config.vocab_size:
|
||||
print0(f"Padding vocab_size from {config.vocab_size} to {padded_vocab_size} to be divisible by {pad_vocab_size_to}")
|
||||
print0(f"Padding vocab_size from {config.vocab_size} to {padded_vocab_size} for efficiency")
|
||||
self.transformer = nn.ModuleDict({
|
||||
"wte": nn.Embedding(padded_vocab_size, config.n_embd),
|
||||
"h": nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)]),
|
||||
})
|
||||
self.lm_head = nn.Linear(config.n_embd, padded_vocab_size, bias=False)
|
||||
# To support meta device initialization, we init the rotary embeddings here, but it's fake
|
||||
# Per-layer learnable scalars (inspired by modded-nanogpt)
|
||||
# resid_lambdas: scales the residual stream at each layer (init 1.0 = neutral)
|
||||
# x0_lambdas: blends initial embedding back in at each layer (init 0.0 = disabled)
|
||||
# Separate parameters so they can have different optimizer treatment
|
||||
self.resid_lambdas = nn.Parameter(torch.ones(config.n_layer)) # fake init, real init in init_weights()
|
||||
self.x0_lambdas = nn.Parameter(torch.zeros(config.n_layer)) # fake init, real init in init_weights()
|
||||
# To support meta device initialization, we init the rotary embeddings here, but it's just "fake" meta tensors only.
|
||||
# As for rotary_seq_len, these rotary embeddings are pretty small/cheap in memory,
|
||||
# so let's just over-compute them, but assert fail if we ever reach that amount.
|
||||
# so let's just over-compute them by 10X, but assert fail if we ever reach that amount.
|
||||
# In the future we can dynamically grow the cache, for now it's fine.
|
||||
self.rotary_seq_len = config.sequence_len * 10 # 10X over-compute should be enough, TODO make nicer?
|
||||
head_dim = config.n_embd // config.n_head
@@ -157,35 +177,51 @@ class GPT(nn.Module):
self.register_buffer("sin", sin, persistent=False)
|
||||
|
||||
def init_weights(self):
|
||||
self.apply(self._init_weights)
|
||||
# zero out classifier weights
|
||||
torch.nn.init.zeros_(self.lm_head.weight)
|
||||
# zero out c_proj weights in all blocks
|
||||
"""
|
||||
Initialize the full model in this one function for maximum clarity.
|
||||
|
||||
wte (embedding): normal, std=1.0
|
||||
lm_head: normal, std=0.001
|
||||
for each block:
|
||||
attn.c_q: uniform, std=1/sqrt(n_embd)
|
||||
attn.c_k: uniform, std=1/sqrt(n_embd)
|
||||
attn.c_v: uniform, std=1/sqrt(n_embd)
|
||||
attn.c_proj: zeros
|
||||
mlp.c_fc: uniform, std=1/sqrt(n_embd)
|
||||
mlp.c_proj: zeros
|
||||
"""
|
||||
|
||||
# Embedding and unembedding
|
||||
torch.nn.init.normal_(self.transformer.wte.weight, mean=0.0, std=1.0)
|
||||
torch.nn.init.normal_(self.lm_head.weight, mean=0.0, std=0.001)
|
||||
|
||||
# Transformer blocks: uniform init with bound = sqrt(3) * std (same standard deviation as normal)
|
||||
n_embd = self.config.n_embd
|
||||
s = 3**0.5 * n_embd**-0.5 # sqrt(3) multiplier makes sure Uniform achieves the same std as Normal
|
||||
for block in self.transformer.h:
|
||||
torch.nn.init.uniform_(block.attn.c_q.weight, -s, s) # weights use Uniform to avoid outliers
|
||||
torch.nn.init.uniform_(block.attn.c_k.weight, -s, s)
|
||||
torch.nn.init.uniform_(block.attn.c_v.weight, -s, s)
|
||||
torch.nn.init.zeros_(block.attn.c_proj.weight) # projections are zero
|
||||
torch.nn.init.uniform_(block.mlp.c_fc.weight, -s, s)
|
||||
torch.nn.init.zeros_(block.mlp.c_proj.weight)
|
||||
torch.nn.init.zeros_(block.attn.c_proj.weight)
|
||||
# init the rotary embeddings
|
||||
|
||||
# Per-layer scalars
|
||||
with torch.no_grad():
|
||||
self.resid_lambdas.fill_(1.0) # 1.0 => typical residual connections at init
|
||||
self.x0_lambdas.fill_(0.0) # 0.0 => skip connection to input is disabled at init
|
||||
|
||||
# Rotary embeddings
|
||||
head_dim = self.config.n_embd // self.config.n_head
|
||||
cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
|
||||
self.cos, self.sin = cos, sin
|
||||
# Cast the embeddings from fp32 to bf16: optim can tolerate it and it saves memory: both in the model and the activations
|
||||
|
||||
# Cast token embeddings to bf16: optimizer can tolerate it and it saves memory
|
||||
if self.transformer.wte.weight.device.type == "cuda":
|
||||
self.transformer.wte.to(dtype=torch.bfloat16)
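A quick numerical check of the sqrt(3) factor used for the uniform init above (standalone sketch, assuming an n_embd of 768):

```python
import torch
n_embd = 768
s = 3**0.5 * n_embd**-0.5             # bound used for the uniform init
w = torch.empty(1_000_000).uniform_(-s, s)
print(w.std().item(), n_embd**-0.5)   # both ~0.036: Uniform(-s, s) has std s/sqrt(3)
```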
|
||||
|
||||
def _init_weights(self, module):
|
||||
if isinstance(module, nn.Linear):
|
||||
# https://arxiv.org/pdf/2310.17813
|
||||
fan_out = module.weight.size(0)
|
||||
fan_in = module.weight.size(1)
|
||||
std = 1.0 / math.sqrt(fan_in) * min(1.0, math.sqrt(fan_out / fan_in))
|
||||
torch.nn.init.normal_(module.weight, mean=0.0, std=std)
|
||||
if module.bias is not None:
|
||||
torch.nn.init.zeros_(module.bias)
|
||||
elif isinstance(module, nn.Embedding):
|
||||
torch.nn.init.normal_(module.weight, mean=0.0, std=1.0)
|
||||
|
||||
# TODO: bump base theta more, e.g. 100K is more common more recently
|
||||
def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None):
|
||||
# TODO: bump base theta more? e.g. 100K is more common more recently
|
||||
# autodetect the device from model embeddings
|
||||
if device is None:
|
||||
device = self.transformer.wte.weight.device
@@ -201,38 +237,100 @@ class GPT(nn.Module):
cos, sin = cos[None, :, None, :], sin[None, :, None, :] # add batch and head dims for later broadcasting
|
||||
return cos, sin
|
||||
|
||||
def _compute_window_sizes(self, config):
|
||||
"""
|
||||
Compute per-layer window sizes for sliding window attention.
|
||||
|
||||
Returns list of (left, right) tuples for FA3's window_size parameter:
|
||||
- left: how many tokens before current position to attend to (-1 = unlimited)
|
||||
- right: how many tokens after current position to attend to (0 for causal)
|
||||
|
||||
Pattern string is tiled across layers. Final layer always gets L (full context).
|
||||
Characters: L=long (full context), S=short (half context)
|
||||
"""
|
||||
pattern = config.window_pattern.upper()
|
||||
assert all(c in "SL" for c in pattern), f"Invalid window_pattern: {pattern}. Use only S and L."
|
||||
# Map characters to window sizes
|
||||
long_window = config.sequence_len
|
||||
short_window = long_window // 2
|
||||
char_to_window = {
|
||||
"L": (long_window, 0),
|
||||
"S": (short_window, 0),
|
||||
}
|
||||
# Tile pattern across layers
|
||||
window_sizes = []
|
||||
for layer_idx in range(config.n_layer):
|
||||
char = pattern[layer_idx % len(pattern)]
|
||||
window_sizes.append(char_to_window[char])
|
||||
# Final layer always gets full context
|
||||
window_sizes[-1] = (long_window, 0)
|
||||
return window_sizes
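To illustrate the tiling rule, here is what a "SSL" pattern would produce for a hypothetical 6-layer model with sequence_len=2048 (the numbers are made up; the logic mirrors the method above):

```python
pattern, n_layer, seq_len = "SSL", 6, 2048
char_to_window = {"L": (seq_len, 0), "S": (seq_len // 2, 0)}
sizes = [char_to_window[pattern[i % len(pattern)]] for i in range(n_layer)]
sizes[-1] = (seq_len, 0)  # final layer always gets full context
print(sizes)  # [(1024, 0), (1024, 0), (2048, 0), (1024, 0), (1024, 0), (2048, 0)]
```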
|
||||
|
||||
def get_device(self):
|
||||
return self.transformer.wte.weight.device
|
||||
|
||||
def estimate_flops(self):
|
||||
""" Return the estimated FLOPs per token for the model. Ref: https://arxiv.org/abs/2204.02311 """
|
||||
"""
|
||||
Return the estimated FLOPs per token for the model (forward + backward).
|
||||
Each matmul weight parameter contributes 2 FLOPs (multiply *, accumulate +) in forward, and 2X that in backward => 2+4=6.
|
||||
Cleanest explanation of this: https://medium.com/@dzmitrybahdanau/the-flops-calculus-of-language-model-training-3b19c1f025e4
|
||||
On top of that, 12 * h * q * effective_seq_len accounts for key @ query matmul flops inside attention.
|
||||
With sliding windows, effective_seq_len varies per layer (capped by window size).
|
||||
Ref: https://arxiv.org/abs/2204.02311 (PaLM paper).
|
||||
This is ~1% off from the exact formulas of Chinchilla paper, the difference is:
|
||||
- Chinchilla counts the embedding layer as flops (? weird, it's just a lookup => we ignore)
|
||||
- Chinchilla counts exp/sum/divide in attention softmax as flops (a little sus and very tiny => we ignore)
|
||||
"""
|
||||
nparams = sum(p.numel() for p in self.parameters())
|
||||
nparams_embedding = self.transformer.wte.weight.numel()
|
||||
l, h, q, t = self.config.n_layer, self.config.n_head, self.config.n_embd // self.config.n_head, self.config.sequence_len
|
||||
num_flops_per_token = 6 * (nparams - nparams_embedding) + 12 * l * h * q * t
|
||||
# Exclude non-matmul params: embeddings and per-layer scalars
|
||||
nparams_exclude = self.transformer.wte.weight.numel() + self.resid_lambdas.numel() + self.x0_lambdas.numel()
|
||||
h, q, t = self.config.n_head, self.config.n_embd // self.config.n_head, self.config.sequence_len
|
||||
# Sum attention FLOPs per layer, accounting for sliding window
|
||||
attn_flops = 0
|
||||
for window_size in self.window_sizes:
|
||||
window = window_size[0] # (left, right) tuple, we use left
|
||||
effective_seq = t if window < 0 else min(window, t)
|
||||
attn_flops += 12 * h * q * effective_seq
|
||||
num_flops_per_token = 6 * (nparams - nparams_exclude) + attn_flops
|
||||
return num_flops_per_token
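As a rough worked example of the counting rule above (a hypothetical d12-style config with all-L windows, ignoring the lm_head for brevity):

```python
n_layer, n_head, n_embd, seq_len = 12, 6, 768, 2048
head_dim = n_embd // n_head                          # 128
matmul_params = n_layer * 12 * n_embd * n_embd       # ~4*d^2 attention + 8*d^2 MLP per layer
attn_flops = n_layer * 12 * n_head * head_dim * seq_len
print(6 * matmul_params + attn_flops)                # ~7.4e8 FLOPs per token
```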
|
||||
|
||||
def setup_optimizers(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0):
|
||||
def num_scaling_params(self):
|
||||
"""
|
||||
Return all of the parameters, same as Chinchilla paper.
|
||||
Kaplan et al. did not include embedding parameters and said that this led to cleaner scaling laws.
|
||||
But Kaplan et al. also had a bug in their results (as pointed out by Chinchilla).
|
||||
My own experiments in nanochat confirm the Chinchilla approach gives the much cleaner scaling law.
|
||||
Ref: https://arxiv.org/abs/2203.15556 (Chinchilla paper <- good).
|
||||
Ref: https://arxiv.org/abs/2001.08361 (Kaplan et al. original scaling laws paper <- bad)
|
||||
"""
|
||||
nparams = sum(p.numel() for p in self.parameters())
|
||||
return nparams
|
||||
|
||||
def setup_optimizers(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0, adam_betas=(0.8, 0.95), scalar_lr=0.5):
|
||||
model_dim = self.config.n_embd
|
||||
ddp, rank, local_rank, world_size = get_dist_info()
|
||||
# Separate out all parameters into 3 groups (matrix, embedding, lm_head)
|
||||
# Separate out all parameters into 5 groups (matrix, embedding, lm_head, resid_lambdas, x0_lambdas)
|
||||
matrix_params = list(self.transformer.h.parameters())
|
||||
embedding_params = list(self.transformer.wte.parameters())
|
||||
lm_head_params = list(self.lm_head.parameters())
|
||||
assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params)
|
||||
# Create the AdamW optimizer for the embedding and lm_head
|
||||
resid_params = [self.resid_lambdas]
|
||||
x0_params = [self.x0_lambdas]
|
||||
assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(resid_params) + len(x0_params)
|
||||
# Create the AdamW optimizer for the embedding, lm_head, and per-layer scalars
|
||||
# Scale the LR for the AdamW parameters by ∝1/√dmodel (having tuned the LRs for 768 dim model)
|
||||
dmodel_lr_scale = (model_dim / 768) ** -0.5
|
||||
print0(f"Scaling the LR for the AdamW parameters ∝1/√({model_dim}/768) = {dmodel_lr_scale:.6f}")
|
||||
adam_groups = [
|
||||
dict(params=lm_head_params, lr=unembedding_lr * dmodel_lr_scale),
|
||||
dict(params=embedding_params, lr=embedding_lr * dmodel_lr_scale),
|
||||
dict(params=resid_params, lr=scalar_lr * 0.01), # these are a lot more sensitive because they accumulate in the residual stream
|
||||
dict(params=x0_params, lr=scalar_lr),
|
||||
]
|
||||
adamw_kwargs = dict(betas=(0.8, 0.95), eps=1e-10, weight_decay=weight_decay)
|
||||
adamw_kwargs = dict(betas=adam_betas, eps=1e-10, weight_decay=0.0) # NOTE: weight decay is hardcoded to 0.0 for AdamW, only used in Muon
|
||||
AdamWFactory = DistAdamW if ddp else partial(torch.optim.AdamW, fused=True)
|
||||
adamw_optimizer = AdamWFactory(adam_groups, **adamw_kwargs)
|
||||
# Create the Muon optimizer for the linear layers
|
||||
muon_kwargs = dict(lr=matrix_lr, momentum=0.95)
|
||||
muon_kwargs = dict(lr=matrix_lr, momentum=0.95, weight_decay=weight_decay)
|
||||
MuonFactory = DistMuon if ddp else Muon
|
||||
muon_optimizer = MuonFactory(matrix_params, **muon_kwargs)
|
||||
# Combine the two optimizers into one list
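For a sense of the 1/√d_model LR scaling applied above, a hypothetical 1280-dim model would scale its AdamW learning rates like this (the 768 baseline and the 0.004 unembedding LR are from this diff; the 1280 is made up):

```python
model_dim = 1280
dmodel_lr_scale = (model_dim / 768) ** -0.5
print(round(dmodel_lr_scale, 4))          # 0.7746
print(round(0.004 * dmodel_lr_scale, 5))  # unembedding_lr 0.004 -> ~0.0031
```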
@@ -256,8 +354,10 @@ class GPT(nn.Module):
# Forward the trunk of the Transformer
|
||||
x = self.transformer.wte(idx)
|
||||
x = norm(x)
|
||||
for block in self.transformer.h:
|
||||
x = block(x, cos_sin, kv_cache)
|
||||
x0 = x # save initial normalized embedding for x0 residual
|
||||
for i, block in enumerate(self.transformer.h):
|
||||
x = self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0
|
||||
x = block(x, cos_sin, self.window_sizes[i], kv_cache)
|
||||
x = norm(x)
|
||||
|
||||
# Forward the lm_head (compute logits)
401
nanochat/muon.py
@@ -1,39 +1,96 @@
"""
|
||||
Muon optimizer from Keller et al.
|
||||
Also a lot of borrowing of ideas from modded-nanogpt.
|
||||
Muon optimizer adapted and simplified from modded-nanogpt.
|
||||
https://github.com/KellerJordan/modded-nanogpt
|
||||
|
||||
Background:
|
||||
Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
|
||||
quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
|
||||
of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
|
||||
zero even beyond the point where the iteration no longer converges all the way to one everywhere
|
||||
on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
|
||||
where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
|
||||
performance at all relative to UV^T, where USV^T = G is the SVD.
|
||||
|
||||
Here, an alternative to Newton-Schulz iteration with potentially better convergence properties:
|
||||
Polar Express Sign Method for orthogonalization.
|
||||
https://arxiv.org/pdf/2505.16932
|
||||
by Noah Amsel, David Persson, Christopher Musco, Robert M. Gower.
|
||||
|
||||
Some of the changes in nanochat implementation:
|
||||
- Uses a simpler, more general approach to parameter grouping and stacking
|
||||
- Uses a single fused kernel for the momentum -> polar_express -> variance_reduction -> update step
|
||||
- Makes no assumptions about model architecture (e.g. that attention weights are fused into QKVO format)
|
||||
"""
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
import torch.distributed as dist
|
||||
|
||||
@torch.compile
|
||||
def zeropower_via_newtonschulz5(G: Tensor, steps: int) -> Tensor:
|
||||
"""
|
||||
Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
|
||||
quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
|
||||
of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
|
||||
zero even beyond the point where the iteration no longer converges all the way to one everywhere
|
||||
on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
|
||||
where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
|
||||
performance at all relative to UV^T, where USV^T = G is the SVD.
|
||||
"""
|
||||
assert G.ndim >= 2 # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng
|
||||
a, b, c = (3.4445, -4.7750, 2.0315)
|
||||
X = G.bfloat16()
|
||||
if G.size(-2) > G.size(-1):
|
||||
X = X.mT
|
||||
# Coefficients for Polar Express (computed for num_iters=5, safety_factor=2e-2, cushion=2)
|
||||
# From https://arxiv.org/pdf/2505.16932
|
||||
polar_express_coeffs = [
|
||||
(8.156554524902461, -22.48329292557795, 15.878769915207462),
|
||||
(4.042929935166739, -2.808917465908714, 0.5000178451051316),
|
||||
(3.8916678022926607, -2.772484153217685, 0.5060648178503393),
|
||||
(3.285753657755655, -2.3681294933425376, 0.46449024233003106),
|
||||
(2.3465413258596377, -1.7097828382687081, 0.42323551169305323),
|
||||
]
|
||||
|
||||
# Ensure spectral norm is at most 1
|
||||
X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
|
||||
# Perform the NS iterations
|
||||
for _ in range(steps):
|
||||
@torch.compile(dynamic=False, fullgraph=True)
|
||||
def muon_step_fused(
|
||||
stacked_grads: Tensor,
|
||||
stacked_params: Tensor,
|
||||
momentum_buffer: Tensor,
|
||||
second_momentum_buffer: Tensor,
|
||||
momentum_t: Tensor,
|
||||
lr_t: Tensor,
|
||||
wd_t: Tensor,
|
||||
beta2_t: Tensor,
|
||||
ns_steps: int,
|
||||
red_dim: int,
|
||||
) -> None:
|
||||
"""
|
||||
Fused Muon step: momentum -> polar_express -> variance_reduction -> cautious_update
|
||||
All in one compiled graph to eliminate Python overhead between ops.
|
||||
Some of the constants are 0-D CPU tensors to avoid recompilation when values change.
|
||||
"""
|
||||
|
||||
# Nesterov momentum
|
||||
momentum = momentum_t.to(stacked_grads.dtype)
|
||||
momentum_buffer.lerp_(stacked_grads, 1 - momentum)
|
||||
g = stacked_grads.lerp_(momentum_buffer, momentum)
|
||||
|
||||
# Polar express
|
||||
X = g.bfloat16()
|
||||
if g.size(-2) > g.size(-1):
|
||||
X = X.mT
|
||||
X = X / (X.norm(dim=(-2, -1), keepdim=True) * 1.02 + 1e-6)
|
||||
for a, b, c in polar_express_coeffs[:ns_steps]:
|
||||
A = X @ X.mT
|
||||
B = b * A + c * A @ A # quintic computation strategy adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
|
||||
B = b * A + c * (A @ A)
|
||||
X = a * X + B @ X
|
||||
|
||||
if G.size(-2) > G.size(-1):
|
||||
if g.size(-2) > g.size(-1):
|
||||
X = X.mT
|
||||
return X
|
||||
g = X
|
||||
|
||||
# Variance reduction
|
||||
beta2 = beta2_t.to(g.dtype)
|
||||
v_mean = g.float().square().mean(dim=red_dim, keepdim=True)
|
||||
red_dim_size = g.size(red_dim)
|
||||
v_norm_sq = v_mean.sum(dim=(-2, -1), keepdim=True) * red_dim_size
|
||||
v_norm = v_norm_sq.sqrt()
|
||||
second_momentum_buffer.lerp_(v_mean.to(dtype=second_momentum_buffer.dtype), 1 - beta2)
|
||||
step_size = second_momentum_buffer.clamp_min(1e-10).rsqrt()
|
||||
scaled_sq_sum = (v_mean * red_dim_size) * step_size.float().square()
|
||||
v_norm_new = scaled_sq_sum.sum(dim=(-2, -1), keepdim=True).sqrt()
|
||||
final_scale = step_size * (v_norm / v_norm_new.clamp_min(1e-10))
|
||||
g = g * final_scale.to(g.dtype)
|
||||
|
||||
# Cautious weight decay + parameter update
|
||||
lr = lr_t.to(g.dtype)
|
||||
wd = wd_t.to(g.dtype)
|
||||
mask = (g * stacked_params) >= 0
|
||||
stacked_params.sub_(lr * g + lr * wd * stacked_params * mask)
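A tiny standalone illustration of the cautious weight decay used above: decay is applied only where the update and the parameter agree in sign (all numbers below are made up):

```python
import torch
p = torch.tensor([1.0, -1.0, 1.0])
g = torch.tensor([0.5, -0.5, -0.5])
lr, wd = 0.1, 0.1
mask = (g * p) >= 0                 # [True, True, False]
p -= lr * g + lr * wd * p * mask    # third element is updated but not decayed
print(p)                            # tensor([ 0.9400, -0.9400,  1.0500])
```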
|
||||
|
||||
class Muon(torch.optim.Optimizer):
|
||||
"""
@@ -54,74 +111,112 @@ class Muon(torch.optim.Optimizer):
Arguments:
|
||||
lr: The learning rate used by the internal SGD.
|
||||
momentum: The momentum used by the internal SGD.
|
||||
nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
|
||||
ns_steps: The number of Newton-Schulz iteration steps to use.
|
||||
beta2: The decay rate for the second moment (variance) estimate. Set to None to disable.
|
||||
weight_decay: Cautious weight decay coefficient. Only decays where update and weight agree.
|
||||
"""
|
||||
def __init__(self, params, lr=0.02, momentum=0.95, nesterov=True, ns_steps=5):
|
||||
defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps)
|
||||
params: list[Tensor] = [*params]
|
||||
def __init__(self, params, lr=0.02, momentum=0.95, ns_steps=5, beta2=0.95, weight_decay=0.0):
|
||||
defaults = dict(lr=lr, momentum=momentum, ns_steps=ns_steps, beta2=beta2, weight_decay=weight_decay)
|
||||
assert all(p.ndim == 2 for p in params), "Muon expects 2D parameters only"
|
||||
params = list(params) # ensure we have a list, not an e.g. (exhaustible) iterator
|
||||
# Group by shape so we can stack tensors
|
||||
shapes = sorted({p.shape for p in params})
|
||||
param_groups = []
|
||||
for size in {p.numel() for p in params}:
|
||||
group = dict(params=[p for p in params if p.numel() == size])
|
||||
param_groups.append(group)
|
||||
for shape in shapes:
|
||||
group_params = [p for p in params if p.shape == shape]
|
||||
param_groups.append(dict(params=group_params))
|
||||
super().__init__(param_groups, defaults)
|
||||
# 0-D CPU tensors to avoid torch.compile recompilation when values change
|
||||
self._momentum_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
|
||||
@torch.no_grad()
|
||||
def step(self):
|
||||
for group in self.param_groups:
|
||||
params: list[Tensor] = group["params"]
|
||||
for p in params:
|
||||
g = p.grad
|
||||
assert g is not None
|
||||
state = self.state[p]
|
||||
if "momentum_buffer" not in state:
|
||||
state["momentum_buffer"] = torch.zeros_like(g)
|
||||
buf: Tensor = state["momentum_buffer"]
|
||||
buf.lerp_(g, 1 - group["momentum"])
|
||||
g = g.lerp_(buf, group["momentum"]) if group["nesterov"] else buf
|
||||
g = zeropower_via_newtonschulz5(g, steps=group["ns_steps"])
|
||||
p.add_(g, alpha=-group["lr"] * max(1, p.size(-2) / p.size(-1))**0.5)
|
||||
if not params:
|
||||
continue
|
||||
|
||||
# Get or create group-level buffers (stored in first param's state for convenience)
|
||||
state = self.state[params[0]]
|
||||
num_params = len(params) # e.g.: 12 (for a d12 model)
|
||||
# e.g.: shape = (768, 3072), device = cuda:0, dtype = torch.float32, for one of the MLP projections
|
||||
shape, device, dtype = params[0].shape, params[0].device, params[0].dtype
|
||||
|
||||
# Momentum for every individual parameter
|
||||
if "momentum_buffer" not in state:
|
||||
state["momentum_buffer"] = torch.zeros(num_params, *shape, dtype=dtype, device=device)
|
||||
momentum_buffer = state["momentum_buffer"] # e.g.: (12, 768, 3072)
|
||||
|
||||
# Second momentum buffer is factored, either per-row or per-column
|
||||
if "second_momentum_buffer" not in state:
|
||||
if shape[-2] >= shape[-1]:
|
||||
state["second_momentum_buffer"] = torch.zeros(num_params, shape[-2], 1, dtype=dtype, device=device)
|
||||
else:
|
||||
state["second_momentum_buffer"] = torch.zeros(num_params, 1, shape[-1], dtype=dtype, device=device)
|
||||
second_momentum_buffer = state["second_momentum_buffer"] # (12, 1, 3072)
|
||||
red_dim = -1 if shape[-2] >= shape[-1] else -2 # e.g.: -2
|
||||
|
||||
# Stack grads and params
|
||||
stacked_grads = torch.stack([p.grad for p in params]) # (12, 768, 3072)
|
||||
stacked_params = torch.stack(params) # (12, 768, 3072)
|
||||
|
||||
# Fill all the 0-D tensors with current values
|
||||
self._momentum_t.fill_(group["momentum"])
|
||||
self._beta2_t.fill_(group["beta2"] if group["beta2"] is not None else 0.0)
|
||||
self._lr_t.fill_(group["lr"] * max(1.0, shape[-2] / shape[-1])**0.5)
|
||||
self._wd_t.fill_(group["weight_decay"])
|
||||
|
||||
# Single fused kernel: momentum -> polar_express -> variance_reduction -> update
|
||||
muon_step_fused(
|
||||
stacked_grads,
|
||||
stacked_params,
|
||||
momentum_buffer,
|
||||
second_momentum_buffer,
|
||||
self._momentum_t,
|
||||
self._lr_t,
|
||||
self._wd_t,
|
||||
self._beta2_t,
|
||||
group["ns_steps"],
|
||||
red_dim,
|
||||
)
|
||||
|
||||
# Copy back to original params: [(768, 3072), (768, 3072), ...] <- (12, 768, 3072)
|
||||
torch._foreach_copy_(params, list(stacked_params.unbind(0)))
|
||||
|
||||
|
||||
class DistMuon(torch.optim.Optimizer):
|
||||
"""
|
||||
Muon: SGD-momentum + (optional) Nesterov, then orthogonalize the 2D update via Newton–Schulz,
|
||||
finally apply aspect-ratio scaled step. Performs its own distributed synchronization:
|
||||
- reduce_scatter(AVG) for gradient averaging
|
||||
- all_gather to replicate updated weights
|
||||
|
||||
Notes:
|
||||
* Designed for 2D parameters (e.g., linear/conv kernels reshaped to 2D). Do not use for 0D/1D
|
||||
params like embeddings or scalars.
|
||||
* Momentum buffers are maintained only on the 'owner' rank for each parameter (rank chosen
|
||||
by block-cyclic assignment below). If you checkpoint optimizer state on a single rank,
|
||||
consolidate states beforehand.
|
||||
|
||||
Args:
|
||||
params: iterable of Tensors
|
||||
lr: learning rate
|
||||
momentum: momentum coefficient in [0,1)
|
||||
nesterov: if True, Nesterov-style update (g <- lerp(g, buf, momentum)); else use buf
|
||||
ns_steps: number of Newton–Schulz iterations for the orthogonalization
|
||||
Distributed version of the Muon optimizer.
|
||||
"""
|
||||
def __init__(self, params, lr: float = 0.02, momentum: float = 0.95,
|
||||
nesterov: bool = True, ns_steps: int = 5):
|
||||
defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps)
|
||||
params = list(params)
|
||||
ns_steps: int = 5, beta2: float = 0.95, weight_decay: float = 0.0):
|
||||
defaults = dict(lr=lr, momentum=momentum, ns_steps=ns_steps, beta2=beta2, weight_decay=weight_decay)
|
||||
assert all(p.ndim == 2 for p in params), "Muon expects 2D parameters only"
|
||||
params = list(params)
|
||||
world_size = dist.get_world_size()
|
||||
rank = dist.get_rank()
|
||||
# Group all parameters by their shape
|
||||
shapes = sorted({p.shape for p in params}) # sort to ensure consistent / deterministic ordering
|
||||
shapes = sorted({p.shape for p in params}) # sort for deterministic ordering across ranks
|
||||
param_groups = []
|
||||
for shape in shapes:
|
||||
group_params = [p for p in params if p.shape == shape]
|
||||
device, dtype = group_params[0].device, group_params[0].dtype
|
||||
assert all(p.device == device for p in group_params)
|
||||
assert all(p.dtype == dtype for p in group_params)
|
||||
# Compute chunk size for this group (how many params each rank owns)
|
||||
chunk_size = (len(group_params) + world_size - 1) // world_size
|
||||
if rank == 0:
|
||||
print(f"Muon: Grouping {len(group_params)} params of shape {shape}, device {device}, dtype {dtype}")
|
||||
param_groups.append(dict(params=group_params, zero_buffer=torch.zeros_like(group_params[0])))
|
||||
print(f"Muon: {len(group_params)} params of shape {shape}, chunk_size={chunk_size}")
|
||||
param_groups.append(dict(params=group_params, chunk_size=chunk_size))
|
||||
super().__init__(param_groups, defaults)
|
||||
# 0-D CPU tensors to avoid torch.compile recompilation when values change
|
||||
self._momentum_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
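The block ownership arithmetic behind chunk_size (set up here and used again in step below) can be illustrated with made-up numbers:

```python
# Hypothetical group: 10 params of the same shape across 4 ranks -> chunk_size 3, padded to 12 slots.
num_params, world_size = 10, 4
chunk_size = (num_params + world_size - 1) // world_size   # 3
for rank in range(world_size):
    start_idx = rank * chunk_size
    num_owned = min(chunk_size, max(0, num_params - start_idx))
    print(rank, list(range(start_idx, start_idx + num_owned)))
# rank 0 owns [0, 1, 2], rank 1 owns [3, 4, 5], rank 2 owns [6, 7, 8], rank 3 owns [9]
```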
|
||||
|
||||
@torch.no_grad()
|
||||
def step(self):
@@ -131,57 +226,127 @@ class DistMuon(torch.optim.Optimizer):
# Ensure all grads exist
|
||||
assert all(p.grad is not None for group in self.param_groups for p in group["params"]), "All params must have grads"
|
||||
|
||||
# Kick off all the reduce scatter operations to average up the gradients across all ranks
|
||||
all_reduce_futures = []
|
||||
# First pass: stack grads and kick off reduce_scatter for each group
|
||||
group_infos = []
|
||||
for group in self.param_groups:
|
||||
params = group["params"]
|
||||
zero_buffer = group["zero_buffer"]
|
||||
# Go through params in groups of world_size.
|
||||
for base_i in range(0, len(params), world_size):
|
||||
# The compute owner of each param is rank i % world_size
|
||||
owner_idx = base_i + rank
|
||||
# each rank stacks up its chunk of world_size params into a list
|
||||
rs_input = [p.grad for p in params[base_i:base_i + world_size]]
|
||||
# pad rs_input with the zero buffer to complete the group
|
||||
rs_input.extend([zero_buffer] * (world_size - len(rs_input)))
|
||||
# the output buffer gets strided across the group based on the rank
|
||||
rs_output = params[owner_idx].grad if owner_idx < len(params) else torch.empty_like(zero_buffer)
|
||||
# reduce scatter the gradients within this group of world_size params
|
||||
work = dist.reduce_scatter(rs_output, rs_input, op=dist.ReduceOp.AVG, async_op=True).get_future()
|
||||
all_reduce_futures.append(work)
|
||||
params: list[Tensor] = group["params"]
|
||||
chunk_size = group["chunk_size"]
|
||||
padded_num_params = chunk_size * world_size
|
||||
shape = params[0].shape
|
||||
device, dtype = params[0].device, params[0].dtype
|
||||
|
||||
# Now each rank computes the update and gathers
|
||||
future_idx = 0
|
||||
# Stack all gradients into a single tensor (single kernel via torch.stack)
|
||||
grad_stack = torch.stack([p.grad for p in params])
|
||||
stacked_grads = torch.empty(padded_num_params, *shape, dtype=dtype, device=device)
|
||||
stacked_grads[:len(params)].copy_(grad_stack)
|
||||
# Zero-pad if we have fewer params than padded size
|
||||
if len(params) < padded_num_params:
|
||||
stacked_grads[len(params):].zero_()
|
||||
|
||||
# Output buffer for this rank's chunk
|
||||
grad_chunk = torch.empty(chunk_size, *shape, dtype=dtype, device=device)
|
||||
|
||||
# Async reduce_scatter on the stacked tensor
|
||||
reduce_future = dist.reduce_scatter_tensor(
|
||||
grad_chunk, stacked_grads, op=dist.ReduceOp.AVG, async_op=True
|
||||
).get_future()
|
||||
|
||||
group_infos.append(dict(
|
||||
grad_chunk=grad_chunk,
|
||||
reduce_future=reduce_future,
|
||||
stacked_grads=stacked_grads, # reuse for all_gather output
|
||||
))
|
||||
|
||||
# Second pass: wait for reduce, compute batched updates, kick off all_gather
|
||||
all_gather_futures = []
|
||||
for group in self.param_groups:
|
||||
params = group["params"]
|
||||
zero_buffer = group["zero_buffer"]
|
||||
# Go through params in groups of world_size.
|
||||
for base_i in range(0, len(params), world_size):
|
||||
# The compute owner of each param is rank i % world_size
|
||||
owner_idx = base_i + rank # calculate the index of the param that this rank owns
|
||||
# Wait for the reduce scatter to complete
|
||||
all_reduce_futures[future_idx].wait() # possibly later we could use wait_any polling instead
|
||||
future_idx += 1
|
||||
# Owner computes the Muon update, result is in its param
|
||||
if owner_idx < len(params):
|
||||
p = params[owner_idx]
|
||||
g = p.grad # now averaged across ranks
|
||||
state = self.state[p]
|
||||
if "momentum_buffer" not in state:
|
||||
state["momentum_buffer"] = torch.zeros_like(g)
|
||||
buf: Tensor = state["momentum_buffer"]
|
||||
buf.lerp_(g, 1.0 - group["momentum"])
|
||||
g = g.lerp_(buf, group["momentum"]) if group["nesterov"] else buf
|
||||
g = zeropower_via_newtonschulz5(g, steps=group["ns_steps"])
|
||||
scale = (max(1.0, p.size(-2) / p.size(-1)) ** 0.5)
|
||||
p.add_(g, alpha=-group["lr"] * scale)
|
||||
# Replicate updated parameters to all ranks
|
||||
ag_input = params[owner_idx] if owner_idx < len(params) else zero_buffer
|
||||
ag_output = params[base_i:base_i + world_size]
|
||||
ag_output.extend([torch.empty_like(zero_buffer) for _ in range(world_size - len(ag_output))]) # pad
|
||||
work = dist.all_gather(ag_output, ag_input, async_op=True).get_future()
|
||||
all_gather_futures.append(work)
|
||||
for group, info in zip(self.param_groups, group_infos):
|
||||
info["reduce_future"].wait()
|
||||
|
||||
# Wait for all work to finish
|
||||
torch.futures.collect_all(all_gather_futures).wait()
|
||||
params = group["params"]
|
||||
chunk_size = group["chunk_size"]
|
||||
shape = params[0].shape
|
||||
device, dtype = params[0].device, params[0].dtype
|
||||
grad_chunk = info["grad_chunk"]
|
||||
|
||||
# How many params does this rank actually own?
|
||||
start_idx = rank * chunk_size
|
||||
num_owned = min(chunk_size, max(0, len(params) - start_idx))
|
||||
|
||||
# Get or create group-level state (stored keyed by first param)
|
||||
state = self.state[params[0]]
|
||||
|
||||
# Momentum buffer
|
||||
if "momentum_buffer" not in state:
|
||||
state["momentum_buffer"] = torch.zeros(chunk_size, *shape, dtype=dtype, device=device)
|
||||
momentum_buffer = state["momentum_buffer"]
|
||||
|
||||
# Second momentum buffer is factored, either per-row or per-column
|
||||
if "second_momentum_buffer" not in state:
|
||||
if shape[-2] >= shape[-1]:
|
||||
state["second_momentum_buffer"] = torch.zeros(chunk_size, shape[-2], 1, dtype=dtype, device=device)
|
||||
else:
|
||||
state["second_momentum_buffer"] = torch.zeros(chunk_size, 1, shape[-1], dtype=dtype, device=device)
|
||||
second_momentum_buffer = state["second_momentum_buffer"]
|
||||
red_dim = -1 if shape[-2] >= shape[-1] else -2
|
||||
|
||||
# Build updated_params tensor for all_gather
|
||||
updated_params = torch.empty(chunk_size, *shape, dtype=dtype, device=device)
|
||||
|
||||
if num_owned > 0:
|
||||
# Stack owned params (single kernel via torch.stack)
|
||||
owned_params = [params[start_idx + i] for i in range(num_owned)]
|
||||
stacked_owned_params = torch.stack(owned_params)
|
||||
|
||||
# Get owned slices of buffers and grads
|
||||
owned_grads = grad_chunk[:num_owned]
|
||||
owned_momentum = momentum_buffer[:num_owned]
|
||||
owned_second_momentum = second_momentum_buffer[:num_owned]
|
||||
|
||||
# Fill 0-D tensors with current values
|
||||
self._momentum_t.fill_(group["momentum"])
|
||||
self._beta2_t.fill_(group["beta2"] if group["beta2"] is not None else 0.0)
|
||||
self._lr_t.fill_(group["lr"] * max(1.0, shape[-2] / shape[-1])**0.5)
|
||||
self._wd_t.fill_(group["weight_decay"])
|
||||
|
||||
# Single fused kernel: momentum -> polar_express -> variance_reduction -> update
|
||||
muon_step_fused(
|
||||
owned_grads,
|
||||
stacked_owned_params,
|
||||
owned_momentum,
|
||||
owned_second_momentum,
|
||||
self._momentum_t,
|
||||
self._lr_t,
|
||||
self._wd_t,
|
||||
self._beta2_t,
|
||||
group["ns_steps"],
|
||||
red_dim,
|
||||
)
|
||||
|
||||
# Copy updated params to output buffer
|
||||
updated_params[:num_owned].copy_(stacked_owned_params)
|
||||
|
||||
# Zero-pad the rest (for ranks that own fewer params)
|
||||
if num_owned < chunk_size:
|
||||
updated_params[num_owned:].zero_()
|
||||
|
||||
# Reuse stacked_grads buffer for all_gather output
|
||||
stacked_params = info["stacked_grads"]
|
||||
|
||||
# Async all_gather to replicate updated params to all ranks
|
||||
gather_future = dist.all_gather_into_tensor(
|
||||
stacked_params, updated_params, async_op=True
|
||||
).get_future()
|
||||
|
||||
all_gather_futures.append(dict(
|
||||
gather_future=gather_future,
|
||||
stacked_params=stacked_params,
|
||||
params=params,
|
||||
))
|
||||
|
||||
# Final pass: wait for all_gather and copy back to params
|
||||
for info in all_gather_futures:
|
||||
info["gather_future"].wait()
|
||||
stacked_params = info["stacked_params"]
|
||||
params = info["params"]
|
||||
# Batched copy back (single kernel instead of N individual copies)
|
||||
torch._foreach_copy_(params, list(stacked_params[:len(params)].unbind(0)))
|
||||
|
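For readers skimming the diff, the new step() above boils down to a shard → update → replicate pattern applied per group of same-shaped parameters. Below is a minimal, synchronous sketch of that pattern (a hypothetical `sharded_update` helper, with plain SGD standing in for the fused Muon update and no async futures); it assumes torch.distributed is already initialized, e.g. via torchrun, and that gradients have been computed.

```python
import torch
import torch.distributed as dist

@torch.no_grad()
def sharded_update(params, lr=0.02):
    """Average grads across ranks, update one shard of params per rank, then replicate."""
    world_size, rank = dist.get_world_size(), dist.get_rank()
    shape, device, dtype = params[0].shape, params[0].device, params[0].dtype
    chunk_size = (len(params) + world_size - 1) // world_size
    padded = chunk_size * world_size

    # Stack all gradients into one tensor, zero-padded to a multiple of world_size.
    stacked_grads = torch.zeros(padded, *shape, dtype=dtype, device=device)
    stacked_grads[:len(params)].copy_(torch.stack([p.grad for p in params]))

    # Each rank receives the cross-rank average of its own chunk of the stack
    # (ReduceOp.AVG assumes the NCCL backend).
    grad_chunk = torch.empty(chunk_size, *shape, dtype=dtype, device=device)
    dist.reduce_scatter_tensor(grad_chunk, stacked_grads, op=dist.ReduceOp.AVG)

    # Update only the chunk this rank owns (plain SGD in place of the fused Muon step).
    start = rank * chunk_size
    num_owned = min(chunk_size, max(0, len(params) - start))
    updated = torch.zeros(chunk_size, *shape, dtype=dtype, device=device)
    if num_owned > 0:
        owned = torch.stack([params[start + i] for i in range(num_owned)])
        owned.add_(grad_chunk[:num_owned], alpha=-lr)
        updated[:num_owned].copy_(owned)

    # Replicate every rank's updated chunk and write the results back into the params.
    stacked_params = torch.empty(padded, *shape, dtype=dtype, device=device)
    dist.all_gather_into_tensor(stacked_params, updated)
    torch._foreach_copy_(list(params), list(stacked_params[:len(params)].unbind(0)))
```

The real implementation additionally overlaps the collectives with compute by keeping the reduce_scatter/all_gather futures in flight across parameter groups.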
|
|
|||
|
|
@@ -16,8 +16,11 @@ def run_command(cmd):
|
|||
"""Run a shell command and return output, or None if it fails."""
|
||||
try:
|
||||
result = subprocess.run(cmd, shell=True, capture_output=True, text=True, timeout=5)
|
||||
if result.returncode == 0:
|
||||
# Return stdout if we got output (even if some files in xargs failed)
|
||||
if result.stdout.strip():
|
||||
return result.stdout.strip()
|
||||
if result.returncode == 0:
|
||||
return ""
|
||||
return None
|
||||
except:
|
||||
return None
|
||||
|
|
@@ -160,12 +163,23 @@ Generated: {timestamp}
|
|||
|
||||
"""
|
||||
|
||||
# bloat metrics: package all of the source code and assess its weight
|
||||
packaged = run_command('files-to-prompt . -e py -e md -e rs -e html -e toml -e sh --ignore "*target*" --cxml')
|
||||
num_chars = len(packaged)
|
||||
num_lines = len(packaged.split('\n'))
|
||||
num_files = len([x for x in packaged.split('\n') if x.startswith('<source>')])
|
||||
num_tokens = num_chars // 4 # assume approximately 4 chars per token
|
||||
# bloat metrics: count lines/chars in git-tracked source files only
|
||||
extensions = ['py', 'md', 'rs', 'html', 'toml', 'sh']
|
||||
git_patterns = ' '.join(f"'*.{ext}'" for ext in extensions)
|
||||
files_output = run_command(f"git ls-files -- {git_patterns}")
|
||||
file_list = [f for f in (files_output or '').split('\n') if f]
|
||||
num_files = len(file_list)
|
||||
num_lines = 0
|
||||
num_chars = 0
|
||||
if num_files > 0:
|
||||
wc_output = run_command(f"git ls-files -- {git_patterns} | xargs wc -lc 2>/dev/null")
|
||||
if wc_output:
|
||||
total_line = wc_output.strip().split('\n')[-1]
|
||||
parts = total_line.split()
|
||||
if len(parts) >= 2:
|
||||
num_lines = int(parts[0])
|
||||
num_chars = int(parts[1])
|
||||
num_tokens = num_chars // 4 # assume approximately 4 chars per token
|
||||
|
||||
# count dependencies via uv.lock
|
||||
uv_lock_lines = 0
|
||||
|
|
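A small aside on the new bloat metric: `wc -lc` prints one row per file plus a trailing `total` row with the summed line and byte counts, which is why the code above parses only the last line of the output. A toy illustration (with made-up file names and counts):

```python
# Why the bloat metric takes the last line of `git ls-files ... | xargs wc -lc`:
# wc emits one row per file plus a final "total" row, and the total is what we want.
wc_output = """\
  120   4096 nanochat/report.py
  300  10240 scripts/base_train.py
  420  14336 total"""
total_line = wc_output.strip().split('\n')[-1]
num_lines, num_chars = int(total_line.split()[0]), int(total_line.split()[1])
print(num_lines, num_chars)  # 420 14336
```

Note that if xargs ever splits the file list across several wc invocations, more than one total row can appear and only the last one is counted; for a repo of this size a single invocation is the normal case.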
|
|||
|
|
@@ -26,7 +26,7 @@ SPECIAL_TOKENS = [
|
|||
|
||||
# NOTE: this split pattern deviates from GPT-4 in that we use \p{N}{1,2} instead of \p{N}{1,3}
|
||||
# I did this because I didn't want to "waste" too many tokens on numbers for smaller vocab sizes.
|
||||
# I haven't validated that this is actually a good idea, TODO.
|
||||
# I verified that 2 is the sweet spot for vocab size of 32K. 1 is a bit worse, 3 was worse still.
|
||||
SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,2}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
|
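To see what the `\p{N}{1,2}` choice above does in practice, here is a tiny standalone check using the third-party `regex` package (already a project dependency, and needed for the `\p{...}` classes); the pattern string is copied verbatim from the line above.

```python
# Digit runs are split into chunks of at most two digits, so long numbers never
# become single tokens; letters still absorb one leading non-letter (e.g. a space).
import regex

SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,2}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
print(regex.findall(SPLIT_PATTERN, "12345"))        # ['12', '34', '5']
print(regex.findall(SPLIT_PATTERN, "hello world"))  # ['hello', ' world']
```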
|
@@ -103,9 +103,10 @@ class HuggingFaceTokenizer:
|
|||
def id_to_token(self, id):
|
||||
return self.tokenizer.id_to_token(id)
|
||||
|
||||
def _encode_one(self, text, prepend=None, append=None):
|
||||
def _encode_one(self, text, prepend=None, append=None, num_threads=None):
|
||||
# encode a single string
|
||||
# prepend/append can be either a string of a special token or a token id directly.
|
||||
# num_threads is ignored (only used by the nanochat Tokenizer for parallel encoding)
|
||||
assert isinstance(text, str)
|
||||
ids = []
|
||||
if prepend is not None:
|
||||
|
|
@@ -122,7 +123,14 @@ class HuggingFaceTokenizer:
|
|||
return self.tokenizer.token_to_id(text)
|
||||
|
||||
def get_bos_token_id(self):
|
||||
# Different HuggingFace models use different BOS tokens and there is little consistency
|
||||
# 1) attempt to find a <|bos|> token
|
||||
bos = self.encode_special("<|bos|>")
|
||||
# 2) if that fails, attempt to find a <|endoftext|> token (e.g. GPT-2 models)
|
||||
if bos is None:
|
||||
bos = self.encode_special("<|endoftext|>")
|
||||
# 3) if these fail, it's better to crash than to silently return None
|
||||
assert bos is not None, "Failed to find BOS token in tokenizer"
|
||||
return bos
|
||||
|
||||
def encode(self, text, *args, **kwargs):
|
||||
|
|
|
|||
|
|
@@ -17,6 +17,11 @@
|
|||
box-sizing: border-box;
|
||||
}
|
||||
|
||||
html, body{
|
||||
height: 100%;
|
||||
margin: 0;
|
||||
}
|
||||
|
||||
body {
|
||||
font-family: ui-sans-serif, -apple-system, system-ui, "Segoe UI", Helvetica, "Apple Color Emoji", Arial, sans-serif, "Segoe UI Emoji", "Segoe UI Symbol";
|
||||
background-color: #ffffff;
|
||||
|
|
@@ -113,7 +118,6 @@
|
|||
.message.assistant .message-content {
|
||||
background: transparent;
|
||||
border: none;
|
||||
padding: 0.25rem 0;
|
||||
cursor: pointer;
|
||||
border-radius: 0.5rem;
|
||||
padding: 0.5rem;
|
||||
|
|
|
|||
|
|
@@ -7,30 +7,26 @@ requires-python = ">=3.10"
|
|||
dependencies = [
|
||||
"datasets>=4.0.0",
|
||||
"fastapi>=0.117.1",
|
||||
"files-to-prompt>=0.6",
|
||||
"ipykernel>=7.1.0",
|
||||
"kernels>=0.11.7",
|
||||
"matplotlib>=3.10.8",
|
||||
"psutil>=7.1.0",
|
||||
"python-dotenv>=1.2.1",
|
||||
"regex>=2025.9.1",
|
||||
"rustbpe>=0.1.0",
|
||||
"scipy>=1.15.3",
|
||||
"setuptools>=80.9.0",
|
||||
"tabulate>=0.9.0",
|
||||
"tiktoken>=0.11.0",
|
||||
"tokenizers>=0.22.0",
|
||||
"torch>=2.8.0",
|
||||
"torch>=2.9.0",
|
||||
"transformers>=4.57.3",
|
||||
"uvicorn>=0.36.0",
|
||||
"wandb>=0.21.3",
|
||||
]
|
||||
|
||||
[build-system]
|
||||
requires = ["maturin>=1.7,<2.0"]
|
||||
build-backend = "maturin"
|
||||
|
||||
[tool.maturin]
|
||||
module-name = "rustbpe"
|
||||
bindings = "pyo3"
|
||||
python-source = "."
|
||||
manifest-path = "rustbpe/Cargo.toml"
|
||||
|
||||
[dependency-groups]
|
||||
dev = [
|
||||
"maturin>=1.9.4",
|
||||
"pytest>=8.0.0",
|
||||
]
|
||||
|
||||
|
|
@@ -45,33 +41,33 @@ python_functions = ["test_*"]
|
|||
|
||||
# target torch to cuda 12.8 or CPU
|
||||
[tool.uv.sources]
|
||||
torch = [
|
||||
{ index = "pytorch-cpu", extra = "cpu" },
|
||||
{ index = "pytorch-cu128", extra = "gpu" },
|
||||
torch = [
|
||||
{ index = "pytorch-cpu", extra = "cpu" },
|
||||
{ index = "pytorch-cu128", extra = "gpu" },
|
||||
]
|
||||
|
||||
[[tool.uv.index]]
|
||||
name = "pytorch-cpu"
|
||||
url = "https://download.pytorch.org/whl/cpu"
|
||||
explicit = true
|
||||
|
||||
[[tool.uv.index]]
|
||||
name = "pytorch-cu128"
|
||||
url = "https://download.pytorch.org/whl/cu128"
|
||||
[[tool.uv.index]]
|
||||
name = "pytorch-cpu"
|
||||
url = "https://download.pytorch.org/whl/cpu"
|
||||
explicit = true
|
||||
|
||||
[project.optional-dependencies]
|
||||
cpu = [
|
||||
"torch>=2.8.0",
|
||||
]
|
||||
gpu = [
|
||||
"torch>=2.8.0",
|
||||
]
|
||||
|
||||
[tool.uv]
|
||||
conflicts = [
|
||||
[
|
||||
{ extra = "cpu" },
|
||||
{ extra = "gpu" },
|
||||
],
|
||||
]
|
||||
[[tool.uv.index]]
|
||||
name = "pytorch-cu128"
|
||||
url = "https://download.pytorch.org/whl/cu128"
|
||||
explicit = true
|
||||
|
||||
[project.optional-dependencies]
|
||||
cpu = [
|
||||
"torch>=2.9.1",
|
||||
]
|
||||
gpu = [
|
||||
"torch>=2.9.1",
|
||||
]
|
||||
|
||||
[tool.uv]
|
||||
conflicts = [
|
||||
[
|
||||
{ extra = "cpu" },
|
||||
{ extra = "gpu" },
|
||||
],
|
||||
]
|
||||
|
|
|
|||
21
run1000.sh
21
run1000.sh
|
|
@@ -16,25 +16,22 @@ if [ -z "$WANDB_RUN" ]; then
|
|||
WANDB_RUN=dummy
|
||||
fi
|
||||
python -m nanochat.report reset
|
||||
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
|
||||
source "$HOME/.cargo/env"
|
||||
uv run maturin develop --release --manifest-path rustbpe/Cargo.toml
|
||||
curl -L -o $NANOCHAT_BASE_DIR/identity_conversations.jsonl https://karpathy-public.s3.us-west-2.amazonaws.com/identity_conversations.jsonl
|
||||
|
||||
# train tokenizer on ~4B characters and kick off download of the rest for pretraining
|
||||
python -m nanochat.dataset -n 16
|
||||
# start downloading the rest of the shards for a total of 800 (see below why 800)
|
||||
python -m nanochat.dataset -n 800 &
|
||||
# start downloading the rest of the shards for a total of 1200 (see below why 1200)
|
||||
python -m nanochat.dataset -n 1200 &
|
||||
# todo: download the rest of it
|
||||
python -m scripts.tok_train --max_chars=4000000000
|
||||
python -m scripts.tok_train --max-chars=4000000000 --vocab-size=65536
|
||||
python -m scripts.tok_eval
|
||||
|
||||
# Documenting my process for determining the hyperparameters for this run1000.sh script:
|
||||
# We want a budget of approx. $1000 ~= 41.6 hours of 8XH100 compute
|
||||
# 1) I guessed the model size for this to be about depth=32
|
||||
# 2) Determine the device_batch_size that fits:
|
||||
# Running the base_train.py script with --depth=32, I saw that --device_batch_size=16
|
||||
# runs out of memory, but --device_batch_size=8 fits. Inspecting `nvidia-smi` during training,
|
||||
# Running the base_train.py script with --depth=32, I saw that --device-batch-size=16
|
||||
# runs out of memory, but --device-batch-size=8 fits. Inspecting `nvidia-smi` during training,
|
||||
# I saw all GPUs were at about 78/80GB VRAM, so it just barely fits and we have good MFU at ~50%.
|
||||
# So the training script was running ok and showed:
|
||||
# Vocab size: 65,536
|
||||
|
|
@@ -65,7 +62,9 @@ python -m scripts.tok_eval
|
|||
# The tok_eval.py script reports about ~4.8 chars/token on average for the default tokenizer settings.
|
||||
# So ~38B tokens * ~4.8 chars/token = ~185B chars.
|
||||
# Each data shard is ~250M chars, so we need ~185B / 250M ~= 740 shards.
|
||||
# For safety, I bumped that up to 800 shards, and that's why up above I used -n 800 when pre-downloading dataset shards.
|
||||
# For safety, I bumped that up to 800 shards.
|
||||
# The new DataLoader wastes about 35% of tokens to cropping, so 800 / (1 - 0.35) ~= 1200 shards are needed.
|
||||
# => why up above I used -n 1200 when pre-downloading dataset shards.
|
||||
# If we didn't have enough data, the training script would loop around and do multiple epochs over the same data,
|
||||
# which would decrease model performance. Possibly 2, 3 or so epochs is ~ok, but certainly not ideal and at 10+ epochs we'd
|
||||
# start to overfit hard.
|
||||
|
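The shard arithmetic in the comments above is easy to re-derive; a tiny standalone calculation using the figures quoted there (4.8 chars/token, ~250M chars per shard, ~35% cropping waste):

```python
# Re-deriving the shard counts quoted in the comments above.
tokens = 38e9             # ~38B training tokens
chars_per_token = 4.8     # from tok_eval.py
chars_per_shard = 250e6   # ~250M chars per data shard
crop_waste = 0.35         # fraction of tokens lost to cropping by the new DataLoader

chars_needed = tokens * chars_per_token     # ~1.8e11, i.e. the ~185B chars above
shards = chars_needed / chars_per_shard     # ~730, bumped up to 800 for safety
shards_with_waste = 800 / (1 - crop_waste)  # ~1230, which the script rounds to -n 1200
print(round(shards), round(shards_with_waste))  # 730 1231
```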
|
@@ -74,13 +73,13 @@ python -m scripts.tok_eval
|
|||
# Number of processes/GPUs to use
|
||||
NPROC_PER_NODE=8
|
||||
|
||||
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train -- --depth=32 --device_batch_size=8 --run=$WANDB_RUN
|
||||
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train -- --depth=32 --target-param-data-ratio=20 --device-batch-size=8 --run=$WANDB_RUN
|
||||
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_loss
|
||||
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_eval
|
||||
|
||||
# midtrain
|
||||
# NOTE: ensure that we use the same device_batch_size here as the base training script.
|
||||
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.mid_train -- --device_batch_size=8 --run=$WANDB_RUN
|
||||
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.mid_train -- --device-batch-size=8 --run=$WANDB_RUN
|
||||
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_eval -- -i mid
|
||||
|
||||
# sft
|
||||
|
|
|
|||
458
rustbpe/Cargo.lock
generated
458
rustbpe/Cargo.lock
generated
|
|
@@ -1,458 +0,0 @@
|
|||
# This file is automatically @generated by Cargo.
|
||||
# It is not intended for manual editing.
|
||||
version = 4
|
||||
|
||||
[[package]]
|
||||
name = "ahash"
|
||||
version = "0.8.12"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5a15f179cd60c4584b8a8c596927aadc462e27f2ca70c04e0071964a73ba7a75"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"getrandom",
|
||||
"once_cell",
|
||||
"version_check",
|
||||
"zerocopy",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "aho-corasick"
|
||||
version = "1.1.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8e60d3430d3a69478ad0993f19238d2df97c507009a52b3c10addcd7f6bcb916"
|
||||
dependencies = [
|
||||
"memchr",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "arc-swap"
|
||||
version = "1.7.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "69f7f8c3906b62b754cd5326047894316021dcfe5a194c8ea52bdd94934a3457"
|
||||
|
||||
[[package]]
|
||||
name = "autocfg"
|
||||
version = "1.5.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8"
|
||||
|
||||
[[package]]
|
||||
name = "bit-set"
|
||||
version = "0.8.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "08807e080ed7f9d5433fa9b275196cfc35414f66a0c79d864dc51a0d825231a3"
|
||||
dependencies = [
|
||||
"bit-vec",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "bit-vec"
|
||||
version = "0.8.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5e764a1d40d510daf35e07be9eb06e75770908c27d411ee6c92109c9840eaaf7"
|
||||
|
||||
[[package]]
|
||||
name = "castaway"
|
||||
version = "0.2.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "dec551ab6e7578819132c713a93c022a05d60159dc86e7a7050223577484c55a"
|
||||
dependencies = [
|
||||
"rustversion",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cfg-if"
|
||||
version = "1.0.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2fd1289c04a9ea8cb22300a459a72a385d7c73d3259e2ed7dcb2af674838cfa9"
|
||||
|
||||
[[package]]
|
||||
name = "compact_str"
|
||||
version = "0.9.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3fdb1325a1cece981e8a296ab8f0f9b63ae357bd0784a9faaf548cc7b480707a"
|
||||
dependencies = [
|
||||
"castaway",
|
||||
"cfg-if",
|
||||
"itoa",
|
||||
"rustversion",
|
||||
"ryu",
|
||||
"static_assertions",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "crossbeam-deque"
|
||||
version = "0.8.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9dd111b7b7f7d55b72c0a6ae361660ee5853c9af73f70c3c2ef6858b950e2e51"
|
||||
dependencies = [
|
||||
"crossbeam-epoch",
|
||||
"crossbeam-utils",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "crossbeam-epoch"
|
||||
version = "0.9.18"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e"
|
||||
dependencies = [
|
||||
"crossbeam-utils",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "crossbeam-utils"
|
||||
version = "0.8.21"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28"
|
||||
|
||||
[[package]]
|
||||
name = "dary_heap"
|
||||
version = "0.3.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "04d2cd9c18b9f454ed67da600630b021a8a80bf33f8c95896ab33aaf1c26b728"
|
||||
|
||||
[[package]]
|
||||
name = "either"
|
||||
version = "1.15.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719"
|
||||
|
||||
[[package]]
|
||||
name = "equivalent"
|
||||
version = "1.0.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f"
|
||||
|
||||
[[package]]
|
||||
name = "fancy-regex"
|
||||
version = "0.16.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bf04c5ec15464ace8355a7b440a33aece288993475556d461154d7a62ad9947c"
|
||||
dependencies = [
|
||||
"bit-set",
|
||||
"regex-automata",
|
||||
"regex-syntax",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "getrandom"
|
||||
version = "0.3.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "26145e563e54f2cadc477553f1ec5ee650b00862f0a58bcd12cbdc5f0ea2d2f4"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"libc",
|
||||
"r-efi",
|
||||
"wasi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "hashbrown"
|
||||
version = "0.15.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1"
|
||||
|
||||
[[package]]
|
||||
name = "heck"
|
||||
version = "0.5.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea"
|
||||
|
||||
[[package]]
|
||||
name = "indexmap"
|
||||
version = "2.11.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f2481980430f9f78649238835720ddccc57e52df14ffce1c6f37391d61b563e9"
|
||||
dependencies = [
|
||||
"equivalent",
|
||||
"hashbrown",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "indoc"
|
||||
version = "2.0.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f4c7245a08504955605670dbf141fceab975f15ca21570696aebe9d2e71576bd"
|
||||
|
||||
[[package]]
|
||||
name = "itoa"
|
||||
version = "1.0.15"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4a5f13b858c8d314ee3e8f639011f7ccefe71f97f96e50151fb991f267928e2c"
|
||||
|
||||
[[package]]
|
||||
name = "libc"
|
||||
version = "0.2.175"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6a82ae493e598baaea5209805c49bbf2ea7de956d50d7da0da1164f9c6d28543"
|
||||
|
||||
[[package]]
|
||||
name = "log"
|
||||
version = "0.4.28"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "34080505efa8e45a4b816c349525ebe327ceaa8559756f0356cba97ef3bf7432"
|
||||
|
||||
[[package]]
|
||||
name = "memchr"
|
||||
version = "2.7.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "32a282da65faaf38286cf3be983213fcf1d2e2a58700e808f83f4ea9a4804bc0"
|
||||
|
||||
[[package]]
|
||||
name = "memoffset"
|
||||
version = "0.9.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "488016bfae457b036d996092f6cb448677611ce4449e970ceaf42695203f218a"
|
||||
dependencies = [
|
||||
"autocfg",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "once_cell"
|
||||
version = "1.21.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d"
|
||||
|
||||
[[package]]
|
||||
name = "portable-atomic"
|
||||
version = "1.11.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f84267b20a16ea918e43c6a88433c2d54fa145c92a811b5b047ccbe153674483"
|
||||
|
||||
[[package]]
|
||||
name = "proc-macro2"
|
||||
version = "1.0.101"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "89ae43fd86e4158d6db51ad8e2b80f313af9cc74f5c0e03ccb87de09998732de"
|
||||
dependencies = [
|
||||
"unicode-ident",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pyo3"
|
||||
version = "0.23.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7778bffd85cf38175ac1f545509665d0b9b92a198ca7941f131f85f7a4f9a872"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"indoc",
|
||||
"libc",
|
||||
"memoffset",
|
||||
"once_cell",
|
||||
"portable-atomic",
|
||||
"pyo3-build-config",
|
||||
"pyo3-ffi",
|
||||
"pyo3-macros",
|
||||
"unindent",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pyo3-build-config"
|
||||
version = "0.23.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "94f6cbe86ef3bf18998d9df6e0f3fc1050a8c5efa409bf712e661a4366e010fb"
|
||||
dependencies = [
|
||||
"once_cell",
|
||||
"target-lexicon",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pyo3-ffi"
|
||||
version = "0.23.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e9f1b4c431c0bb1c8fb0a338709859eed0d030ff6daa34368d3b152a63dfdd8d"
|
||||
dependencies = [
|
||||
"libc",
|
||||
"pyo3-build-config",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pyo3-log"
|
||||
version = "0.12.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "45192e5e4a4d2505587e27806c7b710c231c40c56f3bfc19535d0bb25df52264"
|
||||
dependencies = [
|
||||
"arc-swap",
|
||||
"log",
|
||||
"pyo3",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pyo3-macros"
|
||||
version = "0.23.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fbc2201328f63c4710f68abdf653c89d8dbc2858b88c5d88b0ff38a75288a9da"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"pyo3-macros-backend",
|
||||
"quote",
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pyo3-macros-backend"
|
||||
version = "0.23.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fca6726ad0f3da9c9de093d6f116a93c1a38e417ed73bf138472cf4064f72028"
|
||||
dependencies = [
|
||||
"heck",
|
||||
"proc-macro2",
|
||||
"pyo3-build-config",
|
||||
"quote",
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "quote"
|
||||
version = "1.0.40"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1885c039570dc00dcb4ff087a89e185fd56bae234ddc7f056a945bf36467248d"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "r-efi"
|
||||
version = "5.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f"
|
||||
|
||||
[[package]]
|
||||
name = "rayon"
|
||||
version = "1.11.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "368f01d005bf8fd9b1206fb6fa653e6c4a81ceb1466406b81792d87c5677a58f"
|
||||
dependencies = [
|
||||
"either",
|
||||
"rayon-core",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rayon-core"
|
||||
version = "1.13.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "22e18b0f0062d30d4230b2e85ff77fdfe4326feb054b9783a3460d8435c8ab91"
|
||||
dependencies = [
|
||||
"crossbeam-deque",
|
||||
"crossbeam-utils",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "regex-automata"
|
||||
version = "0.4.10"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6b9458fa0bfeeac22b5ca447c63aaf45f28439a709ccd244698632f9aa6394d6"
|
||||
dependencies = [
|
||||
"aho-corasick",
|
||||
"memchr",
|
||||
"regex-syntax",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "regex-syntax"
|
||||
version = "0.8.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "caf4aa5b0f434c91fe5c7f1ecb6a5ece2130b02ad2a590589dda5146df959001"
|
||||
|
||||
[[package]]
|
||||
name = "rustbpe"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"ahash",
|
||||
"compact_str",
|
||||
"dary_heap",
|
||||
"fancy-regex",
|
||||
"indexmap",
|
||||
"log",
|
||||
"pyo3",
|
||||
"pyo3-log",
|
||||
"rayon",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustversion"
|
||||
version = "1.0.22"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b39cdef0fa800fc44525c84ccb54a029961a8215f9619753635a9c0d2538d46d"
|
||||
|
||||
[[package]]
|
||||
name = "ryu"
|
||||
version = "1.0.20"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "28d3b2b1366ec20994f1fd18c3c594f05c5dd4bc44d8bb0c1c632c8d6829481f"
|
||||
|
||||
[[package]]
|
||||
name = "static_assertions"
|
||||
version = "1.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f"
|
||||
|
||||
[[package]]
|
||||
name = "syn"
|
||||
version = "2.0.106"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ede7c438028d4436d71104916910f5bb611972c5cfd7f89b8300a8186e6fada6"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"unicode-ident",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "target-lexicon"
|
||||
version = "0.12.16"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1"
|
||||
|
||||
[[package]]
|
||||
name = "unicode-ident"
|
||||
version = "1.0.18"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5a5f39404a5da50712a4c1eecf25e90dd62b613502b7e925fd4e4d19b5c96512"
|
||||
|
||||
[[package]]
|
||||
name = "unindent"
|
||||
version = "0.2.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7264e107f553ccae879d21fbea1d6724ac785e8c3bfc762137959b5802826ef3"
|
||||
|
||||
[[package]]
|
||||
name = "version_check"
|
||||
version = "0.9.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a"
|
||||
|
||||
[[package]]
|
||||
name = "wasi"
|
||||
version = "0.14.4+wasi-0.2.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "88a5f4a424faf49c3c2c344f166f0662341d470ea185e939657aaff130f0ec4a"
|
||||
dependencies = [
|
||||
"wit-bindgen",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "wit-bindgen"
|
||||
version = "0.45.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5c573471f125075647d03df72e026074b7203790d41351cd6edc96f46bcccd36"
|
||||
|
||||
[[package]]
|
||||
name = "zerocopy"
|
||||
version = "0.8.26"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1039dd0d3c310cf05de012d8a39ff557cb0d23087fd44cad61df08fc31907a2f"
|
||||
dependencies = [
|
||||
"zerocopy-derive",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "zerocopy-derive"
|
||||
version = "0.8.26"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9ecf5b4cc5364572d7f4c329661bcc82724222973f2cab6f050a4e5c22f75181"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
]
|
||||
|
|
@@ -1,15 +0,0 @@
|
|||
[package]
|
||||
name = "rustbpe"
|
||||
version = "0.1.0"
|
||||
edition = "2024"
|
||||
|
||||
[dependencies]
|
||||
dary_heap = "0.3"
|
||||
indexmap = "2.2"
|
||||
fancy-regex = "0.16.1"
|
||||
log = "0.4.28"
|
||||
pyo3 = { version = "0.23.3", features = ["extension-module"] }
|
||||
pyo3-log = "0.12.4"
|
||||
ahash = "0.8.12"
|
||||
rayon = "1.11.0"
|
||||
compact_str = "0.9.0"
|
||||
|
|
@@ -1,5 +0,0 @@
|
|||
# rustbpe
|
||||
|
||||
> The missing tiktoken training code
|
||||
|
||||
A very lightweight Rust library for training a GPT tokenizer. The issue is that the inference library [tiktoken](https://github.com/openai/tiktoken) is great, but only does inference. Separately, the huggingface [tokenizers](https://github.com/huggingface/tokenizers) library does training, but it is rather bloated and really hard to navigate because it has to support all the different historical baggage of how people dealt with tokenizers over the years. More recently, I also wrote the [minbpe](https://github.com/karpathy/minbpe) library which does both training and inference, but only in inefficient Python. Basically what I really want is a non-fancy, super simple, but still relatively efficient training code for GPT tokenizer (more efficient than minbpe, much cleaner/simpler than tokenizers), and then export the trained vocab for inference with tiktoken. Does that make sense? So here we are. There are more opportunities for optimization here, I just stopped a bit early because unlike minbpe before it, rustbpe is now simple and fast enough, and not a significant bottleneck for nanochat.
|
||||
|
|
@@ -1,491 +0,0 @@
|
|||
use std::cmp::Ordering;
|
||||
use std::collections::HashMap as StdHashMap;
|
||||
|
||||
use dary_heap::OctonaryHeap;
|
||||
use fancy_regex::Regex;
|
||||
use pyo3::prelude::*;
|
||||
|
||||
use ahash::{AHashMap, AHashSet};
|
||||
use compact_str::CompactString;
|
||||
use rayon::prelude::*;
|
||||
|
||||
// Default GPT-4 style regex pattern for splitting text
|
||||
const GPT4_PATTERN: &str = r"'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+";
|
||||
|
||||
type Pair = (u32, u32);
|
||||
|
||||
/// A Byte Pair Encoding tokenizer that matches the GPT-4 style implementation
|
||||
#[pyclass]
|
||||
pub struct Tokenizer {
|
||||
/// Maps pairs of token IDs to their merged token ID
|
||||
pub merges: StdHashMap<Pair, u32>,
|
||||
/// The regex pattern used for text splitting
|
||||
pub pattern: String,
|
||||
/// Compiled regex for efficiency
|
||||
compiled_pattern: Regex,
|
||||
}
|
||||
|
||||
// ------------------------ internal helpers ------------------------
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
struct Word {
|
||||
ids: Vec<u32>,
|
||||
}
|
||||
|
||||
impl Word {
|
||||
#[inline]
|
||||
fn new(ids: Vec<u32>) -> Self {
|
||||
Self { ids }
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn pairs<'a>(&'a self) -> impl Iterator<Item = Pair> + 'a {
|
||||
self.ids.windows(2).map(|w| (w[0], w[1]))
|
||||
}
|
||||
|
||||
/// Merge all non-overlapping occurrences of pair -> new_id.
|
||||
/// Returns a small Vec of local pair-count deltas for THIS word only:
|
||||
/// -1 for removed pairs, +1 for newly created pairs.
|
||||
///
|
||||
/// NOTE: this version deliberately avoids a HashMap in the hot loop.
|
||||
fn merge_pair(&mut self, pair: Pair, new_id: u32) -> Vec<(Pair, i32)> {
|
||||
let (a, b) = pair;
|
||||
let n = self.ids.len();
|
||||
if n < 2 {
|
||||
return Vec::new();
|
||||
}
|
||||
|
||||
let mut out: Vec<u32> = Vec::with_capacity(n);
|
||||
let mut deltas: Vec<(Pair, i32)> = Vec::with_capacity(6);
|
||||
|
||||
let mut i = 0;
|
||||
while i < n {
|
||||
if i + 1 < n && self.ids[i] == a && self.ids[i + 1] == b {
|
||||
let left = out.last().copied();
|
||||
let right = if i + 2 < n { Some(self.ids[i + 2]) } else { None };
|
||||
|
||||
// remove old pairs
|
||||
if let Some(x) = left {
|
||||
deltas.push(((x, a), -1));
|
||||
deltas.push(((x, new_id), 1));
|
||||
}
|
||||
deltas.push(((a, b), -1));
|
||||
if let Some(y) = right {
|
||||
deltas.push(((b, y), -1));
|
||||
deltas.push(((new_id, y), 1));
|
||||
}
|
||||
|
||||
// write merged token
|
||||
out.push(new_id);
|
||||
i += 2; // skip 'a' and 'b'
|
||||
} else {
|
||||
out.push(self.ids[i]);
|
||||
i += 1;
|
||||
}
|
||||
}
|
||||
|
||||
self.ids = out;
|
||||
deltas
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Eq)]
|
||||
struct MergeJob {
|
||||
pair: Pair,
|
||||
count: u64,
|
||||
/// set of word indices where this pair may occur and needs processing
|
||||
pos: AHashSet<usize>,
|
||||
}
|
||||
|
||||
impl PartialEq for MergeJob {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.count == other.count && self.pair == other.pair
|
||||
}
|
||||
}
|
||||
|
||||
impl PartialOrd for MergeJob {
|
||||
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
|
||||
Some(self.cmp(other))
|
||||
}
|
||||
}
|
||||
|
||||
impl Ord for MergeJob {
|
||||
fn cmp(&self, other: &Self) -> Ordering {
|
||||
// Max-heap by count; tie-break to ascending pair order (deterministic)
|
||||
if self.count != other.count {
|
||||
self.count.cmp(&other.count)
|
||||
} else {
|
||||
// ascending order on the pair when counts tie
|
||||
other.pair.cmp(&self.pair)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn count_pairs_parallel(
|
||||
words: &[Word],
|
||||
counts: &[i32],
|
||||
) -> (AHashMap<Pair, i32>, AHashMap<Pair, AHashSet<usize>>) {
|
||||
words
|
||||
.par_iter()
|
||||
.enumerate()
|
||||
.map(|(i, w)| {
|
||||
let mut local_pc: AHashMap<Pair, i32> = AHashMap::new();
|
||||
let mut local_wtu: AHashMap<Pair, AHashSet<usize>> = AHashMap::new();
|
||||
if w.ids.len() >= 2 && counts[i] != 0 {
|
||||
for (a, b) in w.pairs() {
|
||||
*local_pc.entry((a, b)).or_default() += counts[i];
|
||||
local_wtu.entry((a, b)).or_default().insert(i);
|
||||
}
|
||||
}
|
||||
(local_pc, local_wtu)
|
||||
})
|
||||
.reduce(
|
||||
|| (AHashMap::new(), AHashMap::new()),
|
||||
|(mut acc_pc, mut acc_wtu), (pc, wtu)| {
|
||||
for (k, v) in pc {
|
||||
*acc_pc.entry(k).or_default() += v;
|
||||
}
|
||||
for (k, s) in wtu {
|
||||
acc_wtu.entry(k).or_default().extend(s);
|
||||
}
|
||||
(acc_pc, acc_wtu)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
// ------------------------ END helpers ------------------------
|
||||
|
||||
impl Tokenizer {
|
||||
|
||||
/// Core incremental BPE training given unique words and their counts.
|
||||
/// `words`: one entry per unique chunk (Vec<u32> of token-ids/bytes).
|
||||
/// `counts`: same length as `words`, count per chunk.
|
||||
fn train_core_incremental(&mut self, mut words: Vec<Word>, counts: Vec<i32>, vocab_size: u32) {
|
||||
assert!(vocab_size >= 256, "vocab_size must be at least 256");
|
||||
let num_merges = vocab_size - 256;
|
||||
log::info!("Starting BPE training: {} merges to compute", num_merges);
|
||||
self.merges.clear();
|
||||
|
||||
// ---- Initial pair_counts and where_to_update (parallel) ----
|
||||
log::info!("Computing initial pair counts from {} unique sequences", words.len());
|
||||
let (mut pair_counts, mut where_to_update) = count_pairs_parallel(&words, &counts);
|
||||
|
||||
// ---- Build heap ----
|
||||
log::info!("Building heap with {} unique pairs", pair_counts.len());
|
||||
let mut heap = OctonaryHeap::with_capacity(pair_counts.len());
|
||||
for (pair, pos) in where_to_update.drain() {
|
||||
let c = *pair_counts.get(&pair).unwrap_or(&0);
|
||||
if c > 0 {
|
||||
heap.push(MergeJob {
|
||||
pair,
|
||||
count: c as u64,
|
||||
pos,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// ---- Merge loop ----
|
||||
log::info!("Starting merge loop");
|
||||
let mut merges_done = 0u32;
|
||||
let mut last_log_percent = 0u32;
|
||||
|
||||
while merges_done < num_merges {
|
||||
let Some(mut top) = heap.pop() else { break; };
|
||||
|
||||
// Lazy refresh
|
||||
let current = *pair_counts.get(&top.pair).unwrap_or(&0);
|
||||
if top.count != current as u64 {
|
||||
top.count = current as u64;
|
||||
if top.count > 0 {
|
||||
heap.push(top);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
if top.count == 0 {
|
||||
break;
|
||||
}
|
||||
|
||||
// Record merge
|
||||
let new_id = 256 + merges_done;
|
||||
self.merges.insert(top.pair, new_id);
|
||||
|
||||
// Merge this pair in all words where it occurs
|
||||
let mut local_pos_updates: AHashMap<Pair, AHashSet<usize>> = AHashMap::new();
|
||||
for &word_idx in &top.pos {
|
||||
// Apply merge to this word and collect pair-count deltas
|
||||
let changes = words[word_idx].merge_pair(top.pair, new_id);
|
||||
// Update global pair counts based on this word's count
|
||||
for (pair, delta) in changes {
|
||||
let delta_total = delta * counts[word_idx];
|
||||
if delta_total != 0 {
|
||||
*pair_counts.entry(pair).or_default() += delta_total;
|
||||
if delta > 0 {
|
||||
local_pos_updates.entry(pair).or_default().insert(word_idx);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Add the updated pair counts back to the heap
|
||||
for (pair, pos) in local_pos_updates {
|
||||
let cnt = *pair_counts.get(&pair).unwrap_or(&0);
|
||||
if cnt > 0 {
|
||||
heap.push(MergeJob {
|
||||
pair,
|
||||
count: cnt as u64,
|
||||
pos,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
merges_done += 1;
|
||||
|
||||
// Log progress every 1%
|
||||
let current_percent = (merges_done * 100) / num_merges;
|
||||
if current_percent > last_log_percent {
|
||||
log::info!(
|
||||
"Progress: {}% ({}/{} merges) - Last merge: {:?} -> {} (frequency: {})",
|
||||
current_percent, merges_done, num_merges, top.pair, new_id, top.count
|
||||
);
|
||||
last_log_percent = current_percent;
|
||||
}
|
||||
}
|
||||
|
||||
log::info!("Finished training: {} merges completed", merges_done);
|
||||
}
|
||||
}
|
||||
|
||||
/// Public methods for the Tokenizer class that will be exposed to Python.
|
||||
#[pymethods]
|
||||
impl Tokenizer {
|
||||
/// Create a new Tokenizer
|
||||
#[new]
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
merges: StdHashMap::new(),
|
||||
pattern: String::new(),
|
||||
compiled_pattern: Regex::new("").expect("Empty regex should be valid"),
|
||||
}
|
||||
}
|
||||
|
||||
/// Train from a streaming iterator (parallel ingestion).
|
||||
/// We refill a Rust Vec<String> buffer under the GIL, then release the GIL
|
||||
/// to do the heavy splitting and counting **in parallel** with rayon.
|
||||
#[pyo3(signature = (iterator, vocab_size, buffer_size=8192, pattern=None))]
|
||||
#[pyo3(text_signature = "(self, iterator, vocab_size, buffer_size=8192, pattern=None)")]
|
||||
pub fn train_from_iterator(
|
||||
&mut self,
|
||||
py: pyo3::Python<'_>,
|
||||
iterator: &pyo3::Bound<'_, pyo3::PyAny>,
|
||||
vocab_size: u32,
|
||||
buffer_size: usize,
|
||||
pattern: Option<String>,
|
||||
) -> PyResult<()> {
|
||||
// Use provided pattern or default to GPT-4 pattern
|
||||
let pattern_str = pattern.unwrap_or_else(|| GPT4_PATTERN.to_string());
|
||||
|
||||
// Update the stored pattern and compile it
|
||||
self.pattern = pattern_str.clone();
|
||||
self.compiled_pattern = Regex::new(&pattern_str)
|
||||
.map_err(|e| pyo3::exceptions::PyValueError::new_err(format!("Invalid regex pattern: {}", e)))?;
|
||||
|
||||
// Prepare a true Python iterator object
|
||||
let py_iter: pyo3::Py<pyo3::PyAny> = unsafe {
|
||||
pyo3::Py::from_owned_ptr_or_err(py, pyo3::ffi::PyObject_GetIter(iterator.as_ptr()))?
|
||||
};
|
||||
|
||||
// Global chunk counts
|
||||
let mut counts: AHashMap<CompactString, i32> = AHashMap::new();
|
||||
|
||||
// Temporary buffer we refill under the GIL
|
||||
let mut buf: Vec<String> = Vec::with_capacity(buffer_size);
|
||||
|
||||
log::info!("Processing sequences from iterator (buffer_size: {})", buffer_size);
|
||||
let mut total_sequences = 0u64;
|
||||
|
||||
// Helper: refill `buf` with up to `buffer_size` strings from the Python iterator.
|
||||
// Returns Ok(true) if the iterator is exhausted, Ok(false) otherwise.
|
||||
let refill = |buf: &mut Vec<String>| -> PyResult<bool> {
|
||||
pyo3::Python::with_gil(|py| {
|
||||
buf.clear();
|
||||
let it = py_iter.bind(py);
|
||||
loop {
|
||||
if buf.len() >= buffer_size {
|
||||
return Ok(false);
|
||||
}
|
||||
// next(it)
|
||||
let next_obj = unsafe {
|
||||
pyo3::Bound::from_owned_ptr_or_opt(py, pyo3::ffi::PyIter_Next(it.as_ptr()))
|
||||
};
|
||||
match next_obj {
|
||||
Some(obj) => {
|
||||
let s: String = obj.extract()?;
|
||||
buf.push(s);
|
||||
}
|
||||
None => {
|
||||
if pyo3::PyErr::occurred(py) {
|
||||
return Err(pyo3::PyErr::fetch(py));
|
||||
} else {
|
||||
return Ok(true); // exhausted
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
};
|
||||
|
||||
// Stream ingestion loop: refill under GIL, process without GIL (parallel)
|
||||
loop {
|
||||
let exhausted = refill(&mut buf)?;
|
||||
if buf.is_empty() && exhausted {
|
||||
break;
|
||||
}
|
||||
|
||||
total_sequences += buf.len() as u64;
|
||||
|
||||
let pattern = self.compiled_pattern.clone();
|
||||
let local: AHashMap<CompactString, i32> = py.allow_threads(|| {
|
||||
buf.par_iter()
|
||||
.map(|s| {
|
||||
let mut m: AHashMap<CompactString, i32> = AHashMap::new();
|
||||
for mat in pattern.find_iter(s) {
|
||||
let piece = mat.expect("regex match failed").as_str();
|
||||
*m.entry(CompactString::from(piece)).or_default() += 1;
|
||||
}
|
||||
m
|
||||
})
|
||||
.reduce(
|
||||
|| AHashMap::new(),
|
||||
|mut a, b| {
|
||||
for (k, v) in b {
|
||||
*a.entry(k).or_default() += v;
|
||||
}
|
||||
a
|
||||
},
|
||||
)
|
||||
});
|
||||
|
||||
// Merge local into global (single-threaded)
|
||||
for (k, v) in local {
|
||||
*counts.entry(k).or_default() += v;
|
||||
}
|
||||
|
||||
if exhausted {
|
||||
break;
|
||||
}
|
||||
}
|
||||
log::info!("Processed {} sequences total, {} unique", total_sequences, counts.len());
|
||||
|
||||
// Materialize words & counts
|
||||
let mut words = Vec::with_capacity(counts.len());
|
||||
let mut cvec = Vec::with_capacity(counts.len());
|
||||
for (chunk, c) in counts.into_iter() {
|
||||
words.push(Word::new(chunk.as_bytes().iter().map(|&b| b as u32).collect()));
|
||||
cvec.push(c);
|
||||
}
|
||||
|
||||
self.train_core_incremental(words, cvec, vocab_size);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Return the regex pattern
|
||||
pub fn get_pattern(&self) -> String {
|
||||
self.pattern.clone()
|
||||
}
|
||||
|
||||
/// Return the mergeable ranks (token bytes -> token id / rank)
|
||||
pub fn get_mergeable_ranks(&self) -> Vec<(Vec<u8>, u32)> {
|
||||
let mut mergeable_ranks = Vec::new();
|
||||
|
||||
// Build vocabulary incrementally from low to high token IDs
|
||||
let mut token_bytes: Vec<Vec<u8>> = (0..256_u32).map(|i| vec![i as u8]).collect();
|
||||
|
||||
for (i, bytes) in token_bytes.iter().enumerate() {
|
||||
mergeable_ranks.push((bytes.clone(), i as u32));
|
||||
}
|
||||
|
||||
// Sort merges by token id (so we can reconstruct bytes progressively)
|
||||
let mut sorted_merges: Vec<_> = self.merges.iter().collect();
|
||||
sorted_merges.sort_by_key(|&(_, &token_id)| token_id);
|
||||
|
||||
for (&pair, &merged_id) in sorted_merges {
|
||||
let (left, right) = pair;
|
||||
let mut merged_bytes = token_bytes[left as usize].clone();
|
||||
merged_bytes.extend(&token_bytes[right as usize]);
|
||||
|
||||
if token_bytes.len() <= merged_id as usize {
|
||||
token_bytes.resize(merged_id as usize + 1, Vec::new());
|
||||
}
|
||||
token_bytes[merged_id as usize] = merged_bytes.clone();
|
||||
|
||||
mergeable_ranks.push((merged_bytes, merged_id));
|
||||
}
|
||||
|
||||
mergeable_ranks
|
||||
}
|
||||
|
||||
/// Encode a string into token IDs
|
||||
pub fn encode(&self, text: &str) -> Vec<u32> {
|
||||
let mut all_ids = Vec::new();
|
||||
|
||||
// Split text using the regex pattern
|
||||
for m in self.compiled_pattern.find_iter(text) {
|
||||
let chunk = m.expect("regex match failed").as_str();
|
||||
|
||||
// Convert chunk to bytes then to u32 IDs
|
||||
let mut ids: Vec<u32> = chunk.bytes().map(|b| b as u32).collect();
|
||||
|
||||
// Apply merges iteratively
|
||||
while ids.len() >= 2 {
|
||||
// Find the best pair to merge
|
||||
let mut best_pair: Option<(usize, Pair, u32)> = None;
|
||||
|
||||
for i in 0..ids.len() - 1 {
|
||||
let pair: Pair = (ids[i], ids[i + 1]);
|
||||
if let Some(&new_id) = self.merges.get(&pair) {
|
||||
if best_pair.is_none() || new_id < best_pair.unwrap().2 {
|
||||
best_pair = Some((i, pair, new_id));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If we found a pair to merge, apply it
|
||||
if let Some((idx, _pair, new_id)) = best_pair {
|
||||
ids[idx] = new_id;
|
||||
ids.remove(idx + 1);
|
||||
} else {
|
||||
// No more merges possible
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
all_ids.extend(ids);
|
||||
}
|
||||
|
||||
all_ids
|
||||
}
|
||||
|
||||
/// Encode multiple texts in parallel using rayon.
|
||||
/// Returns a list of token ID vectors, one per input text.
|
||||
#[pyo3(signature = (texts))]
|
||||
#[pyo3(text_signature = "(self, texts)")]
|
||||
pub fn batch_encode(&self, py: Python<'_>, texts: Vec<String>) -> PyResult<Vec<Vec<u32>>> {
|
||||
// Release Python GIL and encode in parallel using rayon
|
||||
let results = py.allow_threads(|| {
|
||||
texts
|
||||
.par_iter()
|
||||
.map(|text| self.encode(text))
|
||||
.collect::<Vec<Vec<u32>>>()
|
||||
});
|
||||
|
||||
Ok(results)
|
||||
}
|
||||
}
|
||||
|
||||
#[pymodule]
|
||||
fn rustbpe(m: &Bound<'_, PyModule>) -> PyResult<()> {
|
||||
pyo3_log::init(); // forwards Rust `log` to Python's `logging`
|
||||
m.add_class::<Tokenizer>()?;
|
||||
Ok(())
|
||||
}
|
||||
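Before the Rust file scrolls past: the algorithm it implements is ordinary greedy BPE, just with incremental pair-count updates, a lazily refreshed heap, and rayon-parallel counting layered on top for speed. A minimal (and deliberately slow) Python sketch of the same greedy loop, in the spirit of the minbpe library mentioned above:

```python
# Minimal greedy BPE training: repeatedly find the most frequent adjacent pair of
# token ids and replace it with a new id. No heap, no incremental counts, no regex
# pre-splitting -- just the core merge loop that the Rust code above optimizes.
from collections import Counter

def train_bpe(text: bytes, vocab_size: int):
    assert vocab_size >= 256
    ids = list(text)
    merges = {}
    for new_id in range(256, vocab_size):
        pairs = Counter(zip(ids, ids[1:]))
        if not pairs:
            break
        pair = max(pairs, key=pairs.get)   # most frequent adjacent pair
        merges[pair] = new_id
        out, i = [], 0
        while i < len(ids):                # merge non-overlapping occurrences
            if i + 1 < len(ids) and (ids[i], ids[i + 1]) == pair:
                out.append(new_id)
                i += 2
            else:
                out.append(ids[i])
                i += 1
        ids = out
    return merges

merges = train_bpe(b"aaabdaaabac", 259)
print(merges)  # {(97, 97): 256, (256, 97): 257, (257, 98): 258}
```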
115
scaling_laws.sh
Normal file
115
scaling_laws.sh
Normal file
|
|
@@ -0,0 +1,115 @@
|
|||
#!/bin/bash
|
||||
|
||||
FLOPS_BUDGETS=(
|
||||
1e18
|
||||
3e18
|
||||
6e18
|
||||
)
|
||||
DEPTHS=(8 10 12 14 16 18 20)
|
||||
NPROC_PER_NODE="${NPROC_PER_NODE:-8}"
|
||||
WANDB_RUN="${WANDB_RUN:-scaling}"
|
||||
EVAL_TOKENS=$((100 * 524288)) # ~100M tokens for final eval (default is ~10M)
|
||||
|
||||
export OMP_NUM_THREADS=1
|
||||
export NANOCHAT_BASE_DIR="${NANOCHAT_BASE_DIR:-$HOME/.cache/nanochat}"
|
||||
source .venv/bin/activate
|
||||
|
||||
RESULTS_DIR="$NANOCHAT_BASE_DIR/scaling_laws_results"
|
||||
mkdir -p "$RESULTS_DIR"
|
||||
RESULTS_FILE="$RESULTS_DIR/results.csv"
|
||||
|
||||
# Write CSV header only if file doesn't exist
|
||||
if [ ! -f "$RESULTS_FILE" ]; then
|
||||
echo "flops_budget,depth,model_dim,num_params,num_scaling_params,num_iterations,tokens_trained,param_data_ratio,val_bpb,core_score,train_time_sec" > "$RESULTS_FILE"
|
||||
fi
|
||||
|
||||
log() {
|
||||
echo "[$(date '+%Y-%m-%d %H:%M:%S')] $1"
|
||||
}
|
||||
|
||||
# Check if a run already exists in results
|
||||
run_exists() {
|
||||
local flops=$1
|
||||
local depth=$2
|
||||
grep -q "^${flops},${depth}," "$RESULTS_FILE" 2>/dev/null
|
||||
}
|
||||
|
||||
# =============================================================================
|
||||
# Main Loop
|
||||
# =============================================================================
|
||||
|
||||
for flops in "${FLOPS_BUDGETS[@]}"; do
|
||||
log "=============================================="
|
||||
log "Compute budget: $flops FLOPs"
|
||||
log "=============================================="
|
||||
|
||||
for d in "${DEPTHS[@]}"; do
|
||||
|
||||
# Skip if already completed
|
||||
if run_exists "$flops" "$d"; then
|
||||
log "Skipping d=$d at $flops FLOPs (already in results)"
|
||||
continue
|
||||
fi
|
||||
|
||||
log "Training d=$d at $flops FLOPs..."
|
||||
|
||||
# Unique tag for this run
|
||||
TAG="scaling_${flops}_d${d}"
|
||||
|
||||
# Record start time
|
||||
START_TIME=$(date +%s)
|
||||
|
||||
# Train the model with fixed flops budget
|
||||
# The script will auto-calculate num_iterations to hit target_flops
|
||||
# CORE eval happens once at the end (999999 ensures only final step)
|
||||
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train -- \
|
||||
--depth=$d \
|
||||
--target-flops=$flops \
|
||||
--target-param-data-ratio=-1 \
|
||||
--run="${WANDB_RUN}_${TAG}" \
|
||||
--model-tag="${TAG}" \
|
||||
--eval-tokens=$EVAL_TOKENS \
|
||||
--core-metric-every=999999 \
|
||||
--core-metric-max-per-task=-1 \
|
||||
--sample-every=-1 \
|
||||
--save-every=-1 \
|
||||
2>&1 | tee "$RESULTS_DIR/${TAG}_train.log"
|
||||
|
||||
END_TIME=$(date +%s)
|
||||
TRAIN_TIME=$((END_TIME - START_TIME))
|
||||
|
||||
# Extract training stats from the log
|
||||
LOG_FILE="$RESULTS_DIR/${TAG}_train.log"
|
||||
NUM_PARAMS=$(grep "Number of parameters:" "$LOG_FILE" | tail -1 | grep -oP '[\d,]+' | head -1 | tr -d ',')
|
||||
NUM_SCALING_PARAMS=$(grep "Number of parameters:" "$LOG_FILE" | tail -1 | grep -oP 'scaling: [\d,]+' | grep -oP '[\d,]+' | tr -d ',')
|
||||
NUM_ITERS=$(grep "Calculated number of iterations" "$LOG_FILE" | tail -1 | sed 's/.*: //' | tr -d ',')
|
||||
# Calculate tokens trained (iterations * batch_size, default 524288)
|
||||
TOKENS_TRAINED=$((NUM_ITERS * 524288))
|
||||
# Param:data ratio (using scaling params per Kaplan et al.)
|
||||
PARAM_DATA_RATIO=$(python -c "print(f'{$TOKENS_TRAINED / $NUM_SCALING_PARAMS:.2f}')")
|
||||
# Model dim
|
||||
MODEL_DIM=$((d * 64))
|
||||
# Val BPB from final eval
|
||||
VAL_BPB=$(grep "Validation bpb:" "$LOG_FILE" | tail -1 | grep -oP '[\d.]+$')
|
||||
|
||||
# Extract CORE score from training log (evaluated on final step)
|
||||
CORE_SCORE=$(grep "CORE metric:" "$LOG_FILE" | tail -1 | awk '{print $NF}')
|
||||
if [ -z "$CORE_SCORE" ]; then
|
||||
log "WARNING: Could not extract CORE score for d=$d"
|
||||
CORE_SCORE="0.0"
|
||||
fi
|
||||
|
||||
log " Params: $NUM_PARAMS, Iters: $NUM_ITERS, Ratio: $PARAM_DATA_RATIO, Val BPB: $VAL_BPB, CORE: $CORE_SCORE"
|
||||
|
||||
# Append to CSV
|
||||
echo "$flops,$d,$MODEL_DIM,$NUM_PARAMS,$NUM_SCALING_PARAMS,$NUM_ITERS,$TOKENS_TRAINED,$PARAM_DATA_RATIO,$VAL_BPB,$CORE_SCORE,$TRAIN_TIME" >> "$RESULTS_FILE"
|
||||
done
|
||||
done
|
||||
|
||||
log "=============================================="
|
||||
log "Scaling Laws Sweep Complete"
|
||||
log "=============================================="
|
||||
log "Results saved to: $RESULTS_FILE"
|
||||
echo ""
|
||||
echo "Results:"
|
||||
column -t -s',' "$RESULTS_FILE"
|
||||
|
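Once `results.csv` has a few rows, the compute/loss trend can be summarized with a quick fit. A hedged sketch (not part of the repo) that assumes the CSV layout written by the header above and uses scipy, which is already a project dependency; it needs at least three distinct compute budgets to pin down the three fit parameters.

```python
# Fit val_bpb ~= a * C^(-b) + c over the best (lowest-bpb) model at each compute budget,
# i.e. a crude compute-optimal frontier from the sweep above. Assumes results.csv has
# the columns written by scaling_laws.sh.
import csv
from collections import defaultdict
import numpy as np
from scipy.optimize import curve_fit

best = defaultdict(lambda: float("inf"))
with open("results.csv") as f:
    for row in csv.DictReader(f):
        flops = float(row["flops_budget"])
        best[flops] = min(best[flops], float(row["val_bpb"]))

x = np.array(sorted(best))
y = np.array([best[f] for f in x])
power_law = lambda f, a, b, c: a * f**(-b) + c
(a, b, c), _ = curve_fit(power_law, x, y, p0=(1.0, 0.1, 0.5), maxfev=10000)
print(f"val_bpb ~= {a:.3g} * C^(-{b:.3g}) + {c:.3g}")
```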
|
@@ -5,48 +5,108 @@ Loads a checkpoint, and:
|
|||
|
||||
Example run as:
|
||||
torchrun --standalone --nproc_per_node=8 -m scripts.base_loss
|
||||
|
||||
To evaluate a HuggingFace model:
|
||||
python -m scripts.base_loss --hf-path openai-community/gpt2
|
||||
"""
|
||||
import os
|
||||
import argparse
|
||||
from contextlib import nullcontext
|
||||
import torch
|
||||
from nanochat.checkpoint_manager import load_model
|
||||
from nanochat.common import compute_init, print0, compute_cleanup, autodetect_device_type
|
||||
from nanochat.dataloader import tokenizing_distributed_data_loader
|
||||
from nanochat.tokenizer import get_token_bytes
|
||||
from nanochat.dataloader import tokenizing_distributed_data_loader_bos_bestfit
|
||||
from nanochat.tokenizer import get_token_bytes, HuggingFaceTokenizer
|
||||
from nanochat.loss_eval import evaluate_bpb
|
||||
from nanochat.engine import Engine
|
||||
|
||||
# Configuration
|
||||
device_batch_size = 32
|
||||
split_tokens = 20*524288 # number of tokens to evaluate per split
|
||||
model_tag = None # optional model tag for the output directory name
|
||||
model_step = None # optional model step for the output directory name
|
||||
device_type = "" # cuda|cpu|mps (empty => autodetect)
|
||||
exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
|
||||
# -----------------------------------------------------------------------------
|
||||
# HuggingFace loading utilities, making the APIs match up to those of nanochat
|
||||
|
||||
class ModelWrapper:
|
||||
"""Lightweight wrapper for a HuggingFace model"""
|
||||
def __init__(self, model, max_seq_len=None):
|
||||
self.model = model
|
||||
self.max_seq_len = max_seq_len
|
||||
|
||||
def __call__(self, input_ids, targets=None, loss_reduction='mean'):
|
||||
logits = self.model(input_ids).logits
|
||||
if targets is None:
|
||||
return logits
|
||||
else:
|
||||
loss = torch.nn.functional.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1, reduction=loss_reduction)
|
||||
return loss
|
||||
|
||||
def get_device(self):
|
||||
return next(self.model.parameters()).device
|
||||
|
||||
def load_hf_model(hf_path: str, device):
|
||||
print0(f"Loading model from: {hf_path}")
|
||||
from transformers import AutoModelForCausalLM
|
||||
model = AutoModelForCausalLM.from_pretrained(hf_path)
|
||||
model.to(device)
|
||||
model.eval()
|
||||
max_seq_len = 1024 if "openai-community/gpt2" in hf_path else None
|
||||
model = ModelWrapper(model, max_seq_len=max_seq_len)
|
||||
tokenizer = HuggingFaceTokenizer.from_pretrained(hf_path)
|
||||
return model, tokenizer
|
||||
|
||||
def get_hf_token_bytes(tokenizer, device="cpu"):
|
||||
"""Compute token_bytes tensor for a HuggingFace tokenizer."""
|
||||
vocab_size = tokenizer.tokenizer.get_vocab_size()
|
||||
token_bytes = torch.zeros(vocab_size, dtype=torch.int64, device=device)
|
||||
for token_id in range(vocab_size):
|
||||
token_str = tokenizer.tokenizer.decode([token_id])
|
||||
token_bytes[token_id] = len(token_str.encode('utf-8')) # Count UTF-8 bytes
|
||||
return token_bytes
|
||||
|
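The `token_bytes` tensor built here feeds the bits-per-byte metric. As a reminder of what that normalization computes, here is a sketch of the usual definition (not a copy of the repo's `evaluate_bpb`, which may differ in details such as masking and distributed reduction):

```python
# Bits-per-byte from per-token cross-entropy (in nats) and per-token UTF-8 byte counts:
# convert total nats to bits (divide by ln 2) and divide by the total bytes predicted.
import math
import torch

def bits_per_byte(nll_nats: torch.Tensor, token_bytes: torch.Tensor) -> float:
    total_bits = nll_nats.sum().item() / math.log(2)
    total_bytes = token_bytes.sum().item()
    return total_bits / total_bytes

nll = torch.tensor([2.3, 1.1, 0.7])   # example per-token losses in nats
nbytes = torch.tensor([4, 3, 5])      # example UTF-8 byte counts of the target tokens
print(bits_per_byte(nll, nbytes))
```

This is why tokenizer choice cancels out of the comparison against HuggingFace models: both are scored per byte of text rather than per token.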
||||
# CLI arguments
|
||||
parser = argparse.ArgumentParser(description="Evaluate loss on train/val splits and sample from model")
|
||||
parser.add_argument("--device-batch-size", type=int, default=32, help="per-device batch size")
|
||||
parser.add_argument("--split-tokens", type=int, default=40*524288, help="number of tokens to evaluate per split")
|
||||
parser.add_argument("--model-tag", type=str, default=None, help="model tag for checkpoint directory")
|
||||
parser.add_argument("--model-step", type=int, default=None, help="model step to load")
|
||||
parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
|
||||
parser.add_argument("--hf-path", type=str, default=None, help="HuggingFace model path (e.g. openai-community/gpt2)")
|
||||
args = parser.parse_args()
|
||||
|
||||
# Load the base model and the tokenizer
|
||||
device_type = autodetect_device_type() if device_type == "" else device_type
|
||||
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
|
||||
model, tokenizer, meta = load_model("base", device, phase="eval", model_tag=model_tag, step=model_step)
|
||||
sequence_len = meta["model_config"]["sequence_len"] # could be arbitrary really
|
||||
print0(f"Device: {device} | DDP rank: {ddp_rank} | DDP local rank: {ddp_local_rank} | DDP world size: {ddp_world_size}")
|
||||
|
||||
if args.hf_path is not None:
|
||||
# Load HuggingFace model
|
||||
model, tokenizer = load_hf_model(args.hf_path, device)
|
||||
sequence_len = model.max_seq_len if model.max_seq_len else 1024
|
||||
token_bytes = get_hf_token_bytes(tokenizer, device=device)
|
||||
model_name = args.hf_path
|
||||
else:
|
||||
# Load local nanochat model
|
||||
model, tokenizer, meta = load_model("base", device, phase="eval", model_tag=args.model_tag, step=args.model_step)
|
||||
sequence_len = meta["model_config"]["sequence_len"]
|
||||
token_bytes = get_token_bytes(device=device)
|
||||
model_name = f"base_model (step {meta['step']})"
|
||||
|
||||
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
|
||||
|
||||
print0(f"Evaluating model: {model_name}")
|
||||
|
||||
# Evaluate the loss on each split
|
||||
tokens_per_step = device_batch_size * sequence_len * ddp_world_size
|
||||
assert split_tokens % tokens_per_step == 0, "split_tokens must be divisible by tokens_per_step"
|
||||
steps = split_tokens // tokens_per_step
|
||||
token_bytes = get_token_bytes(device=device)
|
||||
tokens_per_step = args.device_batch_size * sequence_len * ddp_world_size
|
||||
assert args.split_tokens % tokens_per_step == 0, "split_tokens must be divisible by tokens_per_step"
|
||||
steps = args.split_tokens // tokens_per_step
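For the defaults above the step count works out to a small round number; a quick sanity check, assuming an 8-GPU run and the 2048-token sequence length the checkpoint meta typically reports:

```python
device_batch_size = 32
sequence_len = 2048
ddp_world_size = 8
split_tokens = 40 * 524288          # argparse default above

tokens_per_step = device_batch_size * sequence_len * ddp_world_size   # 524,288
assert split_tokens % tokens_per_step == 0
steps = split_tokens // tokens_per_step                                # 40 evaluation steps per split
```

With fewer GPUs or a different sequence length the divisibility assert can trip, which is one reason `--split-tokens` is exposed as a flag.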
|
||||
bpb_results = {}
|
||||
for split_name in ["train", "val"]:
|
||||
loader = tokenizing_distributed_data_loader(device_batch_size, sequence_len, split_name, device=device)
|
||||
loader = tokenizing_distributed_data_loader_bos_bestfit(tokenizer, args.device_batch_size, sequence_len, split_name, device=device)
|
||||
with autocast_ctx:
|
||||
bpb = evaluate_bpb(model, loader, steps, token_bytes)
|
||||
print0(f"{split_name} bpb: {bpb:.4f}")
|
||||
bpb_results[split_name] = bpb
|
||||
print0(f"Model: {model_name}, {split_name} bpb: {bpb:.6f}")
|
||||
|
||||
# Master process also samples from the model
|
||||
# Master process also samples from the model (only for nanochat models)
|
||||
samples = []
|
||||
if ddp_rank == 0:
|
||||
if ddp_rank == 0 and args.hf_path is None:
|
||||
prompts = [
|
||||
"The capital of France is",
|
||||
"The chemical symbol of gold is",
|
||||
|
|
@ -69,6 +129,7 @@ if ddp_rank == 0:
|
|||
from nanochat.report import get_report
|
||||
get_report().log(section="Base model loss", data=[
|
||||
{
|
||||
"model": model_name,
|
||||
"train bpb": bpb_results["train"],
|
||||
"val bpb": bpb_results["val"],
|
||||
},
|
||||
|
|
|
|||
|
|
@ -1,18 +1,19 @@
|
|||
"""
|
||||
Train model. Run as:
|
||||
Train model. From root directory of the project, run as:
|
||||
|
||||
python base_train.py
|
||||
python -m scripts.base_train
|
||||
|
||||
or distributed as:
|
||||
|
||||
torchrun --nproc_per_node=8 base_train.py
|
||||
torchrun --nproc_per_node=8 -m scripts.base_train
|
||||
|
||||
If you are only on CPU/Macbook, you'll want to train a much much smaller LLM. Example:
|
||||
python -m scripts.base_train --depth=4 --max_seq_len=512 --device_batch_size=1 --eval_tokens=512 --core_metric_every=-1 --total_batch_size=512 --num_iterations=20
|
||||
python -m scripts.base_train --depth=4 --max-seq-len=512 --device-batch-size=1 --eval-tokens=512 --core-metric-every=-1 --total-batch-size=512 --num-iterations=20
|
||||
"""
|
||||
|
||||
import os
|
||||
os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"
|
||||
import argparse
|
||||
import time
|
||||
from contextlib import nullcontext
|
||||
|
||||
|
|
@ -20,7 +21,7 @@ import wandb
|
|||
import torch
|
||||
|
||||
from nanochat.gpt import GPT, GPTConfig
|
||||
from nanochat.dataloader import tokenizing_distributed_data_loader, tokenizing_distributed_data_loader_with_state
|
||||
from nanochat.dataloader import tokenizing_distributed_data_loader_bos_bestfit, tokenizing_distributed_data_loader_with_state_bos_bestfit
|
||||
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir, autodetect_device_type
|
||||
from nanochat.tokenizer import get_tokenizer, get_token_bytes
|
||||
from nanochat.checkpoint_manager import save_checkpoint, load_checkpoint
|
||||
|
|
@ -30,46 +31,51 @@ from scripts.base_eval import evaluate_model
|
|||
print_banner()
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# User settings
|
||||
run = "dummy" # wandb run name default ("dummy" is special - we won't log to wandb)
|
||||
# CLI arguments
|
||||
parser = argparse.ArgumentParser(description="Pretrain base model")
|
||||
# Logging
|
||||
parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)")
|
||||
# Runtime
|
||||
device_type = "" # cuda|cpu|mps (empty => autodetect good device type default, in order: CUDA > MPS > CPU)
|
||||
parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
|
||||
# Model architecture
|
||||
depth = 20 # the depth of the Transformer model to train, rest of the kwargs are derived
|
||||
max_seq_len = 2048 # max context length
|
||||
# Training horizon. Only one of these 3 will be used, in this order of precedence.
|
||||
num_iterations = -1 # explicit number of steps of the optimization (-1 = disable)
|
||||
target_flops = -1.0 # calculate num_iterations to reach target_flops. Useful for scaling laws experiments (-1 = disable)
|
||||
target_param_data_ratio = 20 # calculate num_iterations to maintain fixed data:param ratio (Chinchilla=20) (-1 = disable)
|
||||
parser.add_argument("--depth", type=int, default=20, help="depth of the Transformer model")
|
||||
parser.add_argument("--aspect-ratio", type=int, default=64, help="model_dim = depth * aspect_ratio")
|
||||
parser.add_argument("--head-dim", type=int, default=128, help="target head dimension for attention")
|
||||
parser.add_argument("--max-seq-len", type=int, default=2048, help="max context length")
|
||||
parser.add_argument("--window-pattern", type=str, default="SSSL", help="sliding window pattern tiled across layers: L=full, S=half context (e.g. 'SSL')")
|
||||
# Training horizon (only one used, in order of precedence)
|
||||
parser.add_argument("--num-iterations", type=int, default=-1, help="explicit number of optimization steps (-1 = disable)")
|
||||
parser.add_argument("--target-flops", type=float, default=-1.0, help="calculate num_iterations to reach target_flops (-1 = disable)")
|
||||
parser.add_argument("--target-param-data-ratio", type=int, default=8, help="calculate num_iterations to maintain data:param ratio (Chinchilla=20, -1 = disable)")
|
||||
# Optimization
|
||||
device_batch_size = 32 # per-device batch size (set to not OOM)
|
||||
total_batch_size = 524288 # total desired batch size, in #tokens
|
||||
embedding_lr = 0.2 # learning rate for the embedding parameters (Adam)
|
||||
unembedding_lr = 0.004 # learning rate for the unembedding parameters (Adam)
|
||||
weight_decay = 0.0 # weight decay for the embedding/unembedding parameters (Adam)
|
||||
matrix_lr = 0.02 # learning rate for the matrix parameters (Muon)
|
||||
grad_clip = 1.0 # gradient clipping value (0.0 = disabled)
|
||||
warmup_ratio = 0.0 # ratio of iterations for LR warmup
|
||||
warmdown_ratio = 0.2 # ratio of iterations for LR warmdown
|
||||
final_lr_frac = 0.0 # final LR is this fraction of the initial LR
|
||||
resume_from_step = -1 # resume training from this step of the optimization (-1 = disable)
|
||||
parser.add_argument("--device-batch-size", type=int, default=32, help="per-device batch size")
|
||||
parser.add_argument("--total-batch-size", type=int, default=524288, help="total batch size in tokens")
|
||||
parser.add_argument("--embedding-lr", type=float, default=0.3, help="learning rate for embedding parameters (Adam)")
|
||||
parser.add_argument("--unembedding-lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)")
|
||||
parser.add_argument("--weight-decay", type=float, default=0.2, help="cautious weight decay for the Muon optimizer (for weights)")
|
||||
parser.add_argument("--matrix-lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)")
|
||||
parser.add_argument("--scalar-lr", type=float, default=0.5, help="learning rate for scalars (resid_lambdas, x0_lambdas)")
|
||||
parser.add_argument("--adam-beta1", type=float, default=0.8, help="Adam beta1 for embedding/unembedding")
|
||||
parser.add_argument("--adam-beta2", type=float, default=0.95, help="Adam beta2 for embedding/unembedding")
|
||||
parser.add_argument("--warmup-ratio", type=float, default=0.0, help="ratio of iterations for LR warmup")
|
||||
parser.add_argument("--warmdown-ratio", type=float, default=0.4, help="ratio of iterations for LR warmdown")
|
||||
parser.add_argument("--final-lr-frac", type=float, default=0.0, help="final LR as fraction of initial LR")
|
||||
parser.add_argument("--resume-from-step", type=int, default=-1, help="resume training from this step (-1 = disable)")
|
||||
# Evaluation
|
||||
eval_every = 250 # every how many steps to evaluate the model for val bpb
|
||||
eval_tokens = 20*524288 # number of tokens to evaluate val loss on
|
||||
core_metric_every = 2000 # every how many steps to evaluate the core metric (-1 = disable)
|
||||
core_metric_max_per_task = 500 # examples per task in estimating the core metric
|
||||
sample_every = 2000 # every how many steps to sample from the model
|
||||
save_every = -1 # every how many steps to save model checkpoints (-1 = disable, and save only at the end of the run)
|
||||
parser.add_argument("--eval-every", type=int, default=250, help="evaluate val bpb every N steps (-1 = disable)")
|
||||
parser.add_argument("--eval-tokens", type=int, default=20*524288, help="number of tokens to evaluate val loss on")
|
||||
parser.add_argument("--core-metric-every", type=int, default=2000, help="evaluate CORE metric every N steps (-1 = disable)")
|
||||
parser.add_argument("--core-metric-max-per-task", type=int, default=500, help="examples per task for CORE metric")
|
||||
parser.add_argument("--sample-every", type=int, default=2000, help="sample from model every N steps (-1 = disable)")
|
||||
parser.add_argument("--save-every", type=int, default=-1, help="save checkpoints every N steps (-1 = only at end)")
|
||||
# Output
|
||||
model_tag = "" # optionally override the model tag for the output checkpoint directory name
|
||||
# now allow CLI to override the settings via the configurator lol
|
||||
config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
|
||||
exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
|
||||
user_config = {k: globals()[k] for k in config_keys} # will be useful for logging
|
||||
parser.add_argument("--model-tag", type=str, default=None, help="override model tag for checkpoint directory name")
|
||||
args = parser.parse_args()
|
||||
user_config = vars(args).copy() # for logging
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
# Compute init
|
||||
device_type = autodetect_device_type() if device_type == "" else device_type
|
||||
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
|
||||
master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
|
||||
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
|
||||
|
|
@ -77,8 +83,8 @@ synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None
|
|||
get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0
|
||||
|
||||
# wandb logging init
|
||||
use_dummy_wandb = run == "dummy" or not master_process
|
||||
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat", name=run, config=user_config)
|
||||
use_dummy_wandb = args.run == "dummy" or not master_process
|
||||
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat", name=args.run, config=user_config)
|
||||
|
||||
# Tokenizer will be useful for evaluation, also we need the vocab size
|
||||
tokenizer = get_tokenizer()
|
||||
|
|
@ -87,9 +93,17 @@ vocab_size = tokenizer.get_vocab_size()
|
|||
print0(f"Vocab size: {vocab_size:,}")
|
||||
|
||||
# Model kwargs are derived from the desired depth of the model
|
||||
num_layers = depth
|
||||
model_dim = depth * 64 # aspect ratio 64 (usually this is varied from 64 -> 128 as model size increases)
|
||||
num_heads = max(1, (model_dim + 127) // 128) # head dim 128 (the division here is ceil div)
|
||||
num_layers = args.depth
|
||||
model_dim = args.depth * args.aspect_ratio
|
||||
def find_num_heads(model_dim, target_head_dim):
|
||||
# Find num_heads that divides model_dim evenly, with head_dim closest to target.
|
||||
ideal = max(1, round(model_dim / target_head_dim))
|
||||
for offset in range(model_dim):
|
||||
for candidate in [ideal + offset, ideal - offset]:
|
||||
if candidate > 0 and model_dim % candidate == 0:
|
||||
return candidate
|
||||
return 1
|
||||
num_heads = find_num_heads(model_dim, args.head_dim)
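Two illustrative calls to the helper above (values chosen for illustration, not taken from this diff):

```python
# model_dim divisible by the 128 target: head count is simply model_dim // 128
find_num_heads(1280, 128)   # -> 10  (head_dim = 128)

# model_dim = 21 * 64 = 1344 is not divisible by 128: the search walks outward
# from round(1344 / 128) = 10 until it hits a divisor of 1344
find_num_heads(1344, 128)   # -> 12  (head_dim = 112)
```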
|
||||
num_kv_heads = num_heads # default is 1:1 GQA (Group Query Attention) ratio (i.e. GQA is disabled)
|
||||
print0(f"num_layers: {num_layers}")
|
||||
print0(f"model_dim: {model_dim}")
|
||||
|
|
@ -98,66 +112,93 @@ print0(f"num_kv_heads: {num_kv_heads}")
|
|||
|
||||
# Optimizer / data / training length related hyperparameters
|
||||
# figure out the needed gradient accumulation to reach the desired total batch size
|
||||
tokens_per_fwdbwd = device_batch_size * max_seq_len # tokens per iteration for a single rank
|
||||
tokens_per_fwdbwd = args.device_batch_size * args.max_seq_len # tokens per iteration for a single rank
|
||||
world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size # total tokens per iteration for all ranks
|
||||
assert total_batch_size % world_tokens_per_fwdbwd == 0
|
||||
grad_accum_steps = total_batch_size // world_tokens_per_fwdbwd
|
||||
print0(f"Tokens / micro-batch / rank: {device_batch_size} x {max_seq_len} = {tokens_per_fwdbwd:,}")
|
||||
assert args.total_batch_size % world_tokens_per_fwdbwd == 0
|
||||
grad_accum_steps = args.total_batch_size // world_tokens_per_fwdbwd
|
||||
print0(f"Tokens / micro-batch / rank: {args.device_batch_size} x {args.max_seq_len} = {tokens_per_fwdbwd:,}")
|
||||
print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}")
|
||||
print0(f"Total batch size {total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}")
|
||||
print0(f"Total batch size {args.total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}")
|
||||
|
||||
# Batch size scaling for learning rates (hyperparameters were tuned at reference batch size 2^19)
|
||||
batch_lr_scale = 1.0
|
||||
reference_batch_size = 2**19
|
||||
batch_ratio = args.total_batch_size / reference_batch_size
|
||||
if batch_ratio != 1.0:
|
||||
# SGD: linear scaling with batch size is standard (not used in nanochat)
|
||||
# AdamW: sqrt scaling is standard
|
||||
# Muon: sqrt scaling is an assumption - not fully studied, but it's a second-order-ish optimizer
|
||||
batch_lr_scale = batch_ratio ** 0.5
|
||||
print0(f"Scaling LRs by {batch_lr_scale:.4f} for batch size {args.total_batch_size:,} (reference: {reference_batch_size:,})")
|
||||
|
||||
# Weight decay is tuned at d12 and its scaling seems to be \propto 1/channels^2 (or equivalently, \propto 1/depth^2 due to constant aspect ratio)
|
||||
weight_decay_scaled = args.weight_decay * (12 / args.depth)**2
|
||||
if args.depth != 12:
|
||||
print0(f"Scaling weight decay from {args.weight_decay:.6f} to {weight_decay_scaled:.6f} for depth {args.depth}")
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Initialize the Model
|
||||
|
||||
# Create a new model with random weights
|
||||
model_config_kwargs = dict(sequence_len=max_seq_len, vocab_size=vocab_size, n_layer=num_layers, n_head=num_heads, n_kv_head=num_kv_heads, n_embd=model_dim)
|
||||
model_config_kwargs = dict(sequence_len=args.max_seq_len, vocab_size=vocab_size, n_layer=num_layers, n_head=num_heads, n_kv_head=num_kv_heads, n_embd=model_dim, window_pattern=args.window_pattern)
|
||||
with torch.device("meta"):
|
||||
# All tensors are created as meta tensors (they have shape/dtype but no data)
|
||||
model_config = GPTConfig(**model_config_kwargs)
|
||||
model = GPT(model_config)
|
||||
model.to_empty(device=device)
|
||||
model.init_weights()
|
||||
model.to_empty(device=device) # All tensors get storage on target device but with uninitialized (garbage) data
|
||||
model.init_weights() # All tensors get initialized
|
||||
|
||||
# If we are resuming, overwrite the model parameters with those of the checkpoint
|
||||
base_dir = get_base_dir()
|
||||
output_dirname = model_tag if model_tag else f"d{depth}" # e.g. d12
|
||||
output_dirname = args.model_tag if args.model_tag else f"d{args.depth}" # e.g. d12
|
||||
checkpoint_dir = os.path.join(base_dir, "base_checkpoints", output_dirname)
|
||||
resuming = resume_from_step != -1
|
||||
resuming = args.resume_from_step != -1
|
||||
if resuming:
|
||||
print0(f"Resuming optimization from step {resume_from_step}")
|
||||
model_data, optimizer_data, meta_data = load_checkpoint(checkpoint_dir, resume_from_step, device, load_optimizer=True, rank=ddp_rank)
|
||||
print0(f"Resuming optimization from step {args.resume_from_step}")
|
||||
model_data, optimizer_data, meta_data = load_checkpoint(checkpoint_dir, args.resume_from_step, device, load_optimizer=True, rank=ddp_rank)
|
||||
model.load_state_dict(model_data, strict=True, assign=True)
|
||||
del model_data # free up this memory after the copy
|
||||
|
||||
orig_model = model # original, uncompiled model, for saving the raw model state_dict and for inference/evaluation (because the input shapes may change)
|
||||
model = torch.compile(model, dynamic=False) # the inputs to model will never change shape so dynamic=False is safe
|
||||
num_params = sum(p.numel() for p in model.parameters())
|
||||
print0(f"Number of parameters: {num_params:,}")
|
||||
num_scaling_params = orig_model.num_scaling_params()
|
||||
print0(f"Number of parameters: {num_params:,} (scaling: {num_scaling_params:,})")
|
||||
num_flops_per_token = model.estimate_flops()
|
||||
print0(f"Estimated FLOPs per token: {num_flops_per_token:e}")
|
||||
|
||||
# Calculate number of iterations. Either it is given, or from target flops, or from target data:param ratio (in that order)
|
||||
assert num_iterations > 0 or target_param_data_ratio > 0 or target_flops > 0
|
||||
if num_iterations > 0:
|
||||
assert args.num_iterations > 0 or args.target_param_data_ratio > 0 or args.target_flops > 0
|
||||
if args.num_iterations > 0:
|
||||
num_iterations = args.num_iterations
|
||||
print0(f"Using user-provided number of iterations: {num_iterations:,}")
|
||||
elif target_flops > 0:
|
||||
elif args.target_flops > 0:
|
||||
# calculate the number of iterations from the target flops
|
||||
num_iterations = round(target_flops / (num_flops_per_token * total_batch_size))
|
||||
num_iterations = round(args.target_flops / (num_flops_per_token * args.total_batch_size))
|
||||
print0(f"Calculated number of iterations from target FLOPs: {num_iterations:,}")
|
||||
elif target_param_data_ratio > 0:
|
||||
# calculate the number of iterations from the target param data ratio
|
||||
target_tokens = target_param_data_ratio * num_params
|
||||
num_iterations = target_tokens // total_batch_size
|
||||
elif args.target_param_data_ratio > 0:
|
||||
# calculate the number of iterations from the target param data ratio (use scaling params per Kaplan et al.)
|
||||
target_tokens = args.target_param_data_ratio * num_scaling_params
|
||||
num_iterations = target_tokens // args.total_batch_size
|
||||
print0(f"Calculated number of iterations from target data:param ratio: {num_iterations:,}")
|
||||
else:
|
||||
raise ValueError("No training horizon specified")
|
||||
total_tokens = total_batch_size * num_iterations
|
||||
total_tokens = args.total_batch_size * num_iterations
|
||||
print0(f"Total number of training tokens: {total_tokens:,}")
|
||||
print0(f"Tokens : Params ratio: {total_batch_size * num_iterations / num_params:.2f}") # Chinchilla is ~20
|
||||
print0(f"Tokens : Params ratio: {args.total_batch_size * num_iterations / num_scaling_params:.2f}") # Chinchilla is ~20
|
||||
print0(f"Total training FLOPs estimate: {num_flops_per_token * total_tokens:e}")
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Initialize the Optimizer (Muon for Linear layers, AdamW for embedding and lm_head)
|
||||
optimizers = model.setup_optimizers(unembedding_lr=unembedding_lr, embedding_lr=embedding_lr, matrix_lr=matrix_lr, weight_decay=weight_decay)
|
||||
adam_betas = (args.adam_beta1, args.adam_beta2)
|
||||
optimizers = model.setup_optimizers(
|
||||
unembedding_lr=args.unembedding_lr * batch_lr_scale,
|
||||
embedding_lr=args.embedding_lr * batch_lr_scale,
|
||||
matrix_lr=args.matrix_lr * batch_lr_scale,
|
||||
weight_decay=weight_decay_scaled,
|
||||
adam_betas=adam_betas,
|
||||
scalar_lr=args.scalar_lr * batch_lr_scale,
|
||||
)
|
||||
adamw_optimizer, muon_optimizer = optimizers
|
||||
|
||||
if resuming:
|
||||
|
|
@ -169,8 +210,8 @@ if resuming:
|
|||
# Initialize the DataLoaders for train/val
|
||||
tokens_dir = os.path.join(base_dir, "tokenized_data")
|
||||
dataloader_resume_state_dict = None if not resuming else meta_data["dataloader_state_dict"]
|
||||
train_loader = tokenizing_distributed_data_loader_with_state(device_batch_size, max_seq_len, split="train", device=device, resume_state_dict=dataloader_resume_state_dict)
|
||||
build_val_loader = lambda: tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="val", device=device)
|
||||
train_loader = tokenizing_distributed_data_loader_with_state_bos_bestfit(tokenizer, args.device_batch_size, args.max_seq_len, split="train", device=device, resume_state_dict=dataloader_resume_state_dict)
|
||||
build_val_loader = lambda: tokenizing_distributed_data_loader_bos_bestfit(tokenizer, args.device_batch_size, args.max_seq_len, split="val", device=device)
|
||||
x, y, dataloader_state_dict = next(train_loader) # kick off load of the very first batch of data
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
|
|
@ -178,15 +219,15 @@ x, y, dataloader_state_dict = next(train_loader) # kick off load of the very fir
|
|||
|
||||
# Learning rate scheduler
|
||||
def get_lr_multiplier(it):
|
||||
warmup_iters = round(warmup_ratio * num_iterations)
|
||||
warmdown_iters = round(warmdown_ratio * num_iterations)
|
||||
warmup_iters = round(args.warmup_ratio * num_iterations)
|
||||
warmdown_iters = round(args.warmdown_ratio * num_iterations)
|
||||
if it < warmup_iters:
|
||||
return (it + 1) / warmup_iters
|
||||
elif it <= num_iterations - warmdown_iters:
|
||||
return 1.0
|
||||
else:
|
||||
progress = (num_iterations - it) / warmdown_iters
|
||||
return progress * 1.0 + (1 - progress) * final_lr_frac
|
||||
return progress * 1.0 + (1 - progress) * args.final_lr_frac
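With the new defaults (no warmup, 40% warmdown, `final_lr_frac` of 0) the multiplier is flat at 1.0 for the first 60% of training and then decays linearly to zero. A short usage example, assuming `num_iterations = 10,000` purely for illustration:

```python
# step:        0      3000    6000    8000    9999
# multiplier:  1.0    1.0     1.0     0.5     0.00025
for it in (0, 3000, 6000, 8000, 9999):
    print(it, get_lr_multiplier(it))
```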
|
||||
|
||||
# Momentum scheduler for Muon optimizer
|
||||
def get_muon_momentum(it):
|
||||
|
|
@ -194,11 +235,16 @@ def get_muon_momentum(it):
|
|||
momentum = (1 - frac) * 0.85 + frac * 0.95
|
||||
return momentum
|
||||
|
||||
# Weight decay scheduler for Muon optimizer (linear to zero over the course of training)
|
||||
def get_weight_decay(it):
|
||||
return weight_decay_scaled * (1 - it / num_iterations)
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Loop state (variables updated by the training loop)
|
||||
|
||||
if not resuming:
|
||||
step = 0
|
||||
val_bpb = None # will be set if eval_every > 0
|
||||
min_val_bpb = float("inf")
|
||||
smooth_train_loss = 0 # EMA of training loss
|
||||
total_training_time = 0 # total wall-clock time of training
|
||||
|
|
@ -214,16 +260,16 @@ else:
|
|||
# Training loop
|
||||
while True:
|
||||
last_step = step == num_iterations # loop runs num_iterations+1 times so that we can eval/save at the end
|
||||
flops_so_far = num_flops_per_token * total_batch_size * step
|
||||
flops_so_far = num_flops_per_token * args.total_batch_size * step
|
||||
|
||||
# once in a while: evaluate the val bpb (all ranks participate)
|
||||
if last_step or step % eval_every == 0:
|
||||
if args.eval_every > 0 and (last_step or step % args.eval_every == 0):
|
||||
model.eval()
|
||||
val_loader = build_val_loader()
|
||||
eval_steps = eval_tokens // (device_batch_size * max_seq_len * ddp_world_size)
|
||||
eval_steps = args.eval_tokens // (args.device_batch_size * args.max_seq_len * ddp_world_size)
|
||||
with autocast_ctx:
|
||||
val_bpb = evaluate_bpb(model, val_loader, eval_steps, token_bytes)
|
||||
print0(f"Step {step:05d} | Validation bpb: {val_bpb:.4f}")
|
||||
print0(f"Step {step:05d} | Validation bpb: {val_bpb:.6f}")
|
||||
if val_bpb < min_val_bpb:
|
||||
min_val_bpb = val_bpb
|
||||
wandb_run.log({
|
||||
|
|
@ -237,10 +283,10 @@ while True:
|
|||
# once in a while: estimate the CORE metric (all ranks participate)
|
||||
# use the original uncompiled model because the inputs keep changing shape
|
||||
results = {}
|
||||
if core_metric_every > 0 and (last_step or (step > 0 and step % core_metric_every == 0)):
|
||||
if args.core_metric_every > 0 and (last_step or (step > 0 and step % args.core_metric_every == 0)):
|
||||
model.eval()
|
||||
with autocast_ctx:
|
||||
results = evaluate_model(orig_model, tokenizer, device, max_per_task=core_metric_max_per_task)
|
||||
results = evaluate_model(orig_model, tokenizer, device, max_per_task=args.core_metric_max_per_task)
|
||||
print0(f"Step {step:05d} | CORE metric: {results['core_metric']:.4f}")
|
||||
wandb_run.log({
|
||||
"step": step,
|
||||
|
|
@ -252,7 +298,7 @@ while True:
|
|||
|
||||
# once in a while: sample from the model (only on master process)
|
||||
# use the original uncompiled model because the inputs keep changing shape
|
||||
if master_process and (last_step or (step > 0 and step % sample_every == 0)):
|
||||
if args.sample_every > 0 and master_process and (last_step or (step > 0 and step % args.sample_every == 0)):
|
||||
model.eval()
|
||||
prompts = [
|
||||
"The capital of France is",
|
||||
|
|
@ -272,7 +318,7 @@ while True:
|
|||
model.train()
|
||||
|
||||
# save checkpoint: at the end of the run, or every save_every steps, except at the first step or the resume step
|
||||
if last_step or (step > 0 and step != resume_from_step and save_every > 0 and step % save_every == 0):
|
||||
if last_step or (step > 0 and step != args.resume_from_step and args.save_every > 0 and step % args.save_every == 0):
|
||||
save_checkpoint(
|
||||
checkpoint_dir,
|
||||
step,
|
||||
|
|
@ -283,8 +329,8 @@ while True:
|
|||
"val_bpb": val_bpb, # loss at last step
|
||||
"model_config": model_config_kwargs,
|
||||
"user_config": user_config, # inputs to the training script
|
||||
"device_batch_size": device_batch_size,
|
||||
"max_seq_len": max_seq_len,
|
||||
"device_batch_size": args.device_batch_size,
|
||||
"max_seq_len": args.max_seq_len,
|
||||
"dataloader_state_dict": dataloader_state_dict,
|
||||
"loop_state": { # all loop state (other than step) so that we can resume training
|
||||
"min_val_bpb": min_val_bpb,
|
||||
|
|
@ -311,19 +357,16 @@ while True:
|
|||
loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here
|
||||
loss.backward()
|
||||
x, y, dataloader_state_dict = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward
|
||||
# gradient clipping
|
||||
grad_clip_enabled = grad_clip > 0.0
|
||||
if grad_clip_enabled:
|
||||
grad_norm_tensor = torch.nn.utils.clip_grad_norm_(orig_model.parameters(), grad_clip)
|
||||
grad_norm = grad_norm_tensor.item() # GPU tensor -> CPU float (note: cpu-gpu sync point)
|
||||
# step the optimizers
|
||||
lrm = get_lr_multiplier(step)
|
||||
for opt in optimizers:
|
||||
for group in opt.param_groups:
|
||||
group["lr"] = group["initial_lr"] * lrm
|
||||
muon_momentum = get_muon_momentum(step)
|
||||
muon_weight_decay = get_weight_decay(step)
|
||||
for group in muon_optimizer.param_groups:
|
||||
group["momentum"] = muon_momentum
|
||||
group["weight_decay"] = muon_weight_decay
|
||||
for opt in optimizers:
|
||||
opt.step()
|
||||
model.zero_grad(set_to_none=True)
|
||||
|
|
@ -337,14 +380,23 @@ while True:
|
|||
smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss.item() # EMA the training loss
|
||||
debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**(step + 1)) # debias the EMA
|
||||
pct_done = 100 * step / num_iterations
|
||||
tok_per_sec = int(total_batch_size / dt)
|
||||
flops_per_sec = num_flops_per_token * total_batch_size / dt
|
||||
tok_per_sec = int(args.total_batch_size / dt)
|
||||
flops_per_sec = num_flops_per_token * args.total_batch_size / dt
|
||||
promised_flops_per_sec_h100 = 989e12 * ddp_world_size # bfloat16 H100 SXM and without 2:4 sparsity
|
||||
mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in %
|
||||
if step > 10:
|
||||
total_training_time += dt # only count the time after the first 10 steps
|
||||
print_grad_norm = f" grad norm: {grad_norm:.4f} |" if grad_clip_enabled else ""
|
||||
print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} |{print_grad_norm} lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time/60:.2f}m")
|
||||
# Calculate ETA based on average time per step (excluding first 10 steps)
|
||||
steps_done = step - 10
|
||||
if steps_done > 0:
|
||||
avg_time_per_step = total_training_time / steps_done
|
||||
remaining_steps = num_iterations - step
|
||||
eta_seconds = remaining_steps * avg_time_per_step
|
||||
eta_str = f" | eta: {eta_seconds/60:.1f}m"
|
||||
else:
|
||||
eta_str = ""
|
||||
epoch = dataloader_state_dict["epoch"]
|
||||
print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | epoch: {epoch} | total time: {total_training_time/60:.2f}m{eta_str}")
|
||||
if step % 100 == 0:
|
||||
log_data = {
|
||||
"step": step,
|
||||
|
|
@ -355,9 +407,8 @@ while True:
|
|||
"train/dt": dt,
|
||||
"train/tok_per_sec": tok_per_sec,
|
||||
"train/mfu": mfu,
|
||||
"train/epoch": epoch,
|
||||
}
|
||||
if grad_clip_enabled:
|
||||
log_data["train/grad_norm"] = grad_norm
|
||||
wandb_run.log(log_data)
|
||||
|
||||
# state update
|
||||
|
|
@ -366,7 +417,8 @@ while True:
|
|||
# print a few more stats
|
||||
print0(f"Peak memory usage: {get_max_memory() / 1024 / 1024:.2f}MiB")
|
||||
print0(f"Total training time: {total_training_time/60:.2f}m")
|
||||
print0(f"Minimum validation bpb: {min_val_bpb:.4f}")
|
||||
if val_bpb is not None:
|
||||
print0(f"Minimum validation bpb: {min_val_bpb:.6f}")
|
||||
|
||||
# Log to report
|
||||
from nanochat.report import get_report
|
||||
|
|
@ -377,14 +429,14 @@ get_report().log(section="Base model training", data=[
|
|||
"Number of FLOPs per token": f"{num_flops_per_token:e}",
|
||||
"Calculated number of iterations": num_iterations,
|
||||
"Number of training tokens": total_tokens,
|
||||
"Tokens : Params ratio": total_batch_size * num_iterations / num_params,
|
||||
"Tokens : Params ratio": args.total_batch_size * num_iterations / num_params,
|
||||
"DDP world size": ddp_world_size,
|
||||
"warmup_ratio": warmup_ratio,
|
||||
"warmdown_ratio": warmdown_ratio,
|
||||
"final_lr_frac": final_lr_frac,
|
||||
"warmup_ratio": args.warmup_ratio,
|
||||
"warmdown_ratio": args.warmdown_ratio,
|
||||
"final_lr_frac": args.final_lr_frac,
|
||||
},
|
||||
{ # stats about training outcomes
|
||||
"Minimum validation bpb": min_val_bpb,
|
||||
"Minimum validation bpb": min_val_bpb if val_bpb is not None else None,
|
||||
"Final validation bpb": val_bpb,
|
||||
"CORE metric estimate": results.get("core_metric", None),
|
||||
"MFU %": f"{mfu:.2f}%",
|
||||
|
|
|
|||
|
|
@ -16,57 +16,68 @@ python -m scripts.chat_rl
|
|||
torchrun --standalone --nproc_per_node=8 -m scripts.chat_rl -- --run=default
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import itertools
|
||||
import re
|
||||
import wandb
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from contextlib import nullcontext
|
||||
|
||||
from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, DummyWandb
|
||||
from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, DummyWandb, autodetect_device_type
|
||||
from nanochat.checkpoint_manager import save_checkpoint, load_model
|
||||
from nanochat.engine import Engine
|
||||
from tasks.gsm8k import GSM8K
|
||||
|
||||
# RL hyperparameters
|
||||
run = "dummy" # wandb run name
|
||||
source = "sft" # mid|sft
|
||||
model_tag = None # model tag to load the model from (base model or midtrained model)
|
||||
step = None # step to load the model from (base model or midtrained model)
|
||||
dtype = "bfloat16"
|
||||
device_batch_size = 8 # no forward pass will go above this to not OOM
|
||||
examples_per_step = 16 # in total and across all ranks (note: examples, not samples/completions!)
|
||||
num_samples = 16 # number of samples per example (/question)
|
||||
max_new_tokens = 256
|
||||
temperature = 1.0
|
||||
top_k = 50 # TODO: try None?
|
||||
unembedding_lr = 0.004
|
||||
embedding_lr = 0.2
|
||||
matrix_lr = 0.02
|
||||
weight_decay = 0.0
|
||||
init_lr_frac = 0.05
|
||||
num_epochs = 1 # how many epochs of gsm8k to train on
|
||||
save_every = 60 # every how many steps to save the model
|
||||
eval_every = 60 # every how many steps to evaluate the model for val pass@k
|
||||
eval_examples = 400 # number of examples used for evaluating pass@k
|
||||
# now allow CLI to override the settings via the configurator lol
|
||||
config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
|
||||
exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
|
||||
user_config = {k: globals()[k] for k in config_keys} # will be useful for logging
|
||||
# -----------------------------------------------------------------------------
|
||||
# CLI arguments
|
||||
parser = argparse.ArgumentParser(description="Reinforcement learning on GSM8K")
|
||||
# Logging
|
||||
parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)")
|
||||
# Runtime
|
||||
parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
|
||||
parser.add_argument("--dtype", type=str, default="bfloat16", help="float32|bfloat16")
|
||||
# Model loading
|
||||
parser.add_argument("--source", type=str, default="sft", help="mid|sft - which checkpoint to load from")
|
||||
parser.add_argument("--model-tag", type=str, default=None, help="model tag to load from")
|
||||
parser.add_argument("--model-step", type=int, default=None, help="model step to load from")
|
||||
# Training horizon
|
||||
parser.add_argument("--num-epochs", type=int, default=1, help="number of epochs over GSM8K")
|
||||
# Batch sizes / sampling
|
||||
parser.add_argument("--device-batch-size", type=int, default=8, help="max batch size per forward pass")
|
||||
parser.add_argument("--examples-per-step", type=int, default=16, help="total examples per optimization step across all ranks")
|
||||
parser.add_argument("--num-samples", type=int, default=16, help="number of samples per example/question")
|
||||
# Generation
|
||||
parser.add_argument("--max-new-tokens", type=int, default=256, help="max tokens to generate per sample")
|
||||
parser.add_argument("--temperature", type=float, default=1.0, help="sampling temperature")
|
||||
parser.add_argument("--top-k", type=int, default=50, help="top-k sampling (0 = disabled)")
|
||||
# Optimization
|
||||
parser.add_argument("--embedding-lr", type=float, default=0.2, help="learning rate for embedding parameters (Adam)")
|
||||
parser.add_argument("--unembedding-lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)")
|
||||
parser.add_argument("--matrix-lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)")
|
||||
parser.add_argument("--weight-decay", type=float, default=0.0, help="weight decay for embedding/unembedding parameters (Adam)")
|
||||
parser.add_argument("--init-lr-frac", type=float, default=0.05, help="initial LR as fraction of base LR")
|
||||
# Evaluation / checkpointing
|
||||
parser.add_argument("--eval-every", type=int, default=60, help="evaluate pass@k every N steps")
|
||||
parser.add_argument("--eval-examples", type=int, default=400, help="number of examples for pass@k evaluation")
|
||||
parser.add_argument("--save-every", type=int, default=60, help="save checkpoint every N steps")
|
||||
args = parser.parse_args()
|
||||
user_config = vars(args).copy()
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
# Init compute/precision
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
|
||||
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
|
||||
master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
|
||||
dtype = torch.float32 if dtype == 'float32' else torch.bfloat16
|
||||
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=dtype)
|
||||
ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16
|
||||
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
|
||||
|
||||
# wandb logging init
|
||||
use_dummy_wandb = run == "dummy" or not master_process
|
||||
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-rl", name=run, config=user_config)
|
||||
use_dummy_wandb = args.run == "dummy" or not master_process
|
||||
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-rl", name=args.run, config=user_config)
|
||||
|
||||
# Init model and tokenizer
|
||||
model, tokenizer, meta = load_model(source, device, phase="eval", model_tag=model_tag, step=step)
|
||||
model, tokenizer, meta = load_model(args.source, device, phase="eval", model_tag=args.model_tag, step=args.model_step)
|
||||
engine = Engine(model, tokenizer) # for sampling rollouts
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
|
|
@ -74,7 +85,7 @@ engine = Engine(model, tokenizer) # for sampling rollouts
|
|||
|
||||
train_task = GSM8K(subset="main", split="train")
|
||||
val_task = GSM8K(subset="main", split="test")
|
||||
num_steps = (len(train_task) // examples_per_step) * num_epochs
|
||||
num_steps = (len(train_task) // args.examples_per_step) * args.num_epochs
|
||||
print0(f"Calculated number of steps: {num_steps}")
|
||||
|
||||
@torch.no_grad()
|
||||
|
|
@ -95,16 +106,16 @@ def get_batch():
|
|||
model.eval() # ensure the model is in eval mode
|
||||
generated_token_sequences = []
|
||||
masks = []
|
||||
num_sampling_steps = num_samples // device_batch_size # go sequentially to prevent OOMs
|
||||
num_sampling_steps = args.num_samples // args.device_batch_size # go sequentially to prevent OOMs
|
||||
for sampling_step in range(num_sampling_steps):
|
||||
seed = hash((step, example_idx, sampling_step)) & 0x7FFFFFFF # positive half of int32
|
||||
with autocast_ctx:
|
||||
generated_token_sequences_batch, masks_batch = engine.generate_batch(
|
||||
tokens,
|
||||
num_samples=device_batch_size,
|
||||
max_tokens=max_new_tokens,
|
||||
temperature=temperature,
|
||||
top_k=top_k,
|
||||
num_samples=args.device_batch_size,
|
||||
max_tokens=args.max_new_tokens,
|
||||
temperature=args.temperature,
|
||||
top_k=args.top_k,
|
||||
seed=seed, # must make sure to change the seed for each sampling step
|
||||
)
|
||||
generated_token_sequences.extend(generated_token_sequences_batch)
|
||||
|
|
@ -162,7 +173,7 @@ def run_gsm8k_eval(task, tokenizer, engine,
|
|||
tokens = tokenizer.render_for_completion(conversation)
|
||||
prefix_length = len(tokens)
|
||||
# Generate k samples using batched generation inside the Engine
|
||||
assert num_samples <= device_batch_size # usually this is true. we can add a loop if not...
|
||||
assert num_samples <= args.device_batch_size # usually this is true. we can add a loop if not...
|
||||
generated_token_sequences, masks = engine.generate_batch(
|
||||
tokens,
|
||||
num_samples=num_samples,
|
||||
|
|
@ -191,16 +202,16 @@ def run_gsm8k_eval(task, tokenizer, engine,
|
|||
|
||||
# Init the optimizer
|
||||
optimizers = model.setup_optimizers(
|
||||
unembedding_lr=unembedding_lr,
|
||||
embedding_lr=embedding_lr,
|
||||
matrix_lr=matrix_lr,
|
||||
weight_decay=weight_decay,
|
||||
unembedding_lr=args.unembedding_lr,
|
||||
embedding_lr=args.embedding_lr,
|
||||
matrix_lr=args.matrix_lr,
|
||||
weight_decay=args.weight_decay,
|
||||
)
|
||||
|
||||
# Set the initial learning rate as a fraction of the base learning rate
|
||||
for opt in optimizers:
|
||||
for group in opt.param_groups:
|
||||
group["lr"] = group["lr"] * init_lr_frac
|
||||
group["lr"] = group["lr"] * args.init_lr_frac
|
||||
group["initial_lr"] = group["lr"] # save the initial learning so we can decay easily later
|
||||
|
||||
# Learning rate scheduler: simple rampdown to zero over num_steps
|
||||
|
|
@ -209,9 +220,9 @@ def get_lr_multiplier(it):
|
|||
return lrm
|
||||
|
||||
# Calculate the number of examples each rank handles to achieve the desired examples_per_step
|
||||
print0(f"Total sequences per step: {examples_per_step * num_samples}") # total batch size in sequences/step
|
||||
assert examples_per_step % ddp_world_size == 0, "Desired examples per step must be divisible by the number of ranks"
|
||||
examples_per_rank = examples_per_step // ddp_world_size # per GPU
|
||||
print0(f"Total sequences per step: {args.examples_per_step * args.num_samples}") # total batch size in sequences/step
|
||||
assert args.examples_per_step % ddp_world_size == 0, "Desired examples per step must be divisible by the number of ranks"
|
||||
examples_per_rank = args.examples_per_step // ddp_world_size # per GPU
|
||||
print0(f"Calculated examples per rank: {examples_per_rank}")
|
||||
|
||||
# Kick off the training loop
|
||||
|
|
@ -219,22 +230,22 @@ batch_iterator = get_batch()
|
|||
for step in range(num_steps):
|
||||
|
||||
# Evaluate the model once in a while and log to wandb
|
||||
if step % eval_every == 0:
|
||||
if step % args.eval_every == 0:
|
||||
model.eval()
|
||||
passk = torch.zeros(device_batch_size, device=device) # pass@k for k=1..device_batch_size
|
||||
passk = torch.zeros(args.device_batch_size, device=device) # pass@k for k=1..device_batch_size
|
||||
with autocast_ctx:
|
||||
records_iter = run_gsm8k_eval(val_task, tokenizer, engine, num_samples=device_batch_size, max_examples=eval_examples, temperature=1.0)
|
||||
records_iter = run_gsm8k_eval(val_task, tokenizer, engine, num_samples=args.device_batch_size, max_examples=args.eval_examples, temperature=1.0)
|
||||
records = list(records_iter) # collect all records
|
||||
for k in range(1, device_batch_size + 1):
|
||||
for k in range(1, args.device_batch_size + 1):
|
||||
passk[k - 1] = sum(any(o["is_correct"] for o in r["outcomes"][:k]) for r in records)
|
||||
num_records = torch.tensor(len(records), dtype=torch.long, device=device)
|
||||
if ddp:
|
||||
dist.all_reduce(num_records, op=dist.ReduceOp.SUM)
|
||||
dist.all_reduce(passk, op=dist.ReduceOp.SUM)
|
||||
passk = passk / num_records.item() # normalize by the total number of records
|
||||
print_passk = [f"Pass@{k}: {passk[k - 1].item():.4f}" for k in range(1, device_batch_size + 1)]
|
||||
print_passk = [f"Pass@{k}: {passk[k - 1].item():.4f}" for k in range(1, args.device_batch_size + 1)]
|
||||
print0(f"Step {step} | {', '.join(print_passk)}")
|
||||
log_passk = {f"pass@{k}": passk[k - 1].item() for k in range(1, device_batch_size + 1)}
|
||||
log_passk = {f"pass@{k}": passk[k - 1].item() for k in range(1, args.device_batch_size + 1)}
|
||||
wandb_run.log({
|
||||
"step": step,
|
||||
**log_passk,
|
||||
|
|
@ -249,11 +260,11 @@ for step in range(num_steps):
|
|||
# Evaluate the loss and gradients
|
||||
model.train() # ensure the model is in train mode
|
||||
# We need one more loop because we can never exceed the device_batch_size
|
||||
assert inputs_all.size(0) % device_batch_size == 0
|
||||
num_passes = inputs_all.size(0) // device_batch_size
|
||||
assert inputs_all.size(0) % args.device_batch_size == 0
|
||||
num_passes = inputs_all.size(0) // args.device_batch_size
|
||||
for pass_idx in range(num_passes):
|
||||
# Pluck out the batch for this pass
|
||||
b0, b1 = pass_idx * device_batch_size, (pass_idx + 1) * device_batch_size
|
||||
b0, b1 = pass_idx * args.device_batch_size, (pass_idx + 1) * args.device_batch_size
|
||||
inputs = inputs_all[b0:b1]
|
||||
targets = targets_all[b0:b1]
|
||||
rewards = rewards_all[b0:b1]
|
||||
|
|
@ -306,10 +317,10 @@ for step in range(num_steps):
|
|||
})
|
||||
|
||||
# Master process saves the model once in a while. Skip first step. Save last step.
|
||||
if master_process and ((step > 0 and step % save_every == 0) or step == num_steps - 1):
|
||||
if master_process and ((step > 0 and step % args.save_every == 0) or step == num_steps - 1):
|
||||
base_dir = get_base_dir()
|
||||
depth = model.config.n_layer
|
||||
output_dirname = model_tag if model_tag else f"d{depth}" # base the model tag on the depth of the base model
|
||||
output_dirname = args.model_tag if args.model_tag else f"d{depth}" # base the model tag on the depth of the base model
|
||||
checkpoint_dir = os.path.join(base_dir, "chatrl_checkpoints", output_dirname)
|
||||
model_config_kwargs = model.config.__dict__ # slightly naughty, abusing the simplicity of GPTConfig, TODO nicer
|
||||
save_checkpoint(
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ Or torchrun for training:
|
|||
torchrun --standalone --nproc_per_node=8 -m scripts.chat_sft
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"
|
||||
|
||||
|
|
@ -31,49 +32,51 @@ from tasks.customjson import CustomJSON
|
|||
from tasks.spellingbee import SimpleSpelling, SpellingBee
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# SFT Hyperparameters
|
||||
run = "dummy" # wandb run name default ("dummy" is special - we won't log to wandb)
|
||||
# input model options
|
||||
source = "mid" # base|mid , which checkpoint to load the model from (base model or midtrained model)
|
||||
model_tag = None # model tag to load the model from (base model or midtrained model)
|
||||
step = None # step to load the model from (base model or midtrained model)
|
||||
# compute/precision
|
||||
device_type = "" # cuda|cpu|mps (empty => autodetect)
|
||||
dtype = "bfloat16"
|
||||
device_batch_size = 4 # max to avoid OOM
|
||||
# optimization
|
||||
num_epochs = 1
|
||||
num_iterations = -1 # override number of iterations (-1 = disable, use num_epochs to derive it)
|
||||
target_examples_per_step = 32
|
||||
unembedding_lr = 0.004
|
||||
embedding_lr = 0.2
|
||||
matrix_lr = 0.02
|
||||
weight_decay = 0.0
|
||||
init_lr_frac = 0.02
|
||||
# evaluation and logging thereof
|
||||
eval_every = 100
|
||||
eval_steps = 100
|
||||
eval_metrics_every = 200
|
||||
eval_metrics_max_problems = 1024
|
||||
# now allow CLI to override the settings via the configurator lol
|
||||
config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
|
||||
exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
|
||||
user_config = {k: globals()[k] for k in config_keys} # possibly useful for logging
|
||||
# CLI arguments
|
||||
parser = argparse.ArgumentParser(description="Supervised finetuning for chat")
|
||||
# Logging
|
||||
parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)")
|
||||
# Runtime
|
||||
parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
|
||||
parser.add_argument("--dtype", type=str, default="bfloat16", help="float32|bfloat16")
|
||||
# Model loading
|
||||
parser.add_argument("--source", type=str, default="mid", help="base|mid - which checkpoint to load from")
|
||||
parser.add_argument("--model-tag", type=str, default=None, help="model tag to load from")
|
||||
parser.add_argument("--model-step", type=int, default=None, help="model step to load from")
|
||||
# Training horizon
|
||||
parser.add_argument("--num-epochs", type=int, default=1, help="number of epochs")
|
||||
parser.add_argument("--num-iterations", type=int, default=-1, help="override number of iterations (-1 = use num_epochs)")
|
||||
# Batch sizes
|
||||
parser.add_argument("--device-batch-size", type=int, default=4, help="per-device batch size")
|
||||
parser.add_argument("--target-examples-per-step", type=int, default=32, help="target examples per optimization step")
|
||||
# Optimization
|
||||
parser.add_argument("--embedding-lr", type=float, default=0.2, help="learning rate for embedding parameters (Adam)")
|
||||
parser.add_argument("--unembedding-lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)")
|
||||
parser.add_argument("--matrix-lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)")
|
||||
parser.add_argument("--weight-decay", type=float, default=0.0, help="weight decay for embedding/unembedding parameters (Adam)")
|
||||
parser.add_argument("--init-lr-frac", type=float, default=0.02, help="initial LR as fraction of base LR")
|
||||
# Evaluation
|
||||
parser.add_argument("--eval-every", type=int, default=100, help="evaluate val loss every N steps")
|
||||
parser.add_argument("--eval-steps", type=int, default=100, help="number of batches for val loss evaluation")
|
||||
parser.add_argument("--eval-metrics-every", type=int, default=200, help="evaluate accuracy metrics every N steps")
|
||||
parser.add_argument("--eval-metrics-max-problems", type=int, default=1024, help="max problems per metric evaluation")
|
||||
args = parser.parse_args()
|
||||
user_config = vars(args).copy()
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
# Compute init
|
||||
device_type = autodetect_device_type() if device_type == "" else device_type
|
||||
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
|
||||
master_process = ddp_rank == 0
|
||||
ptdtype = torch.float32 if dtype == 'float32' else torch.bfloat16
|
||||
ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16
|
||||
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
|
||||
|
||||
# wandb logging init
|
||||
use_dummy_wandb = run == "dummy" or not master_process
|
||||
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-sft", name=run, config=user_config, save_code=True)
|
||||
use_dummy_wandb = args.run == "dummy" or not master_process
|
||||
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-sft", name=args.run, config=user_config, save_code=True)
|
||||
|
||||
# Load the model and tokenizer
|
||||
model, tokenizer, meta = load_model(source, device, phase="train", model_tag=model_tag, step=step)
|
||||
model, tokenizer, meta = load_model(args.source, device, phase="train", model_tag=args.model_tag, step=args.model_step)
|
||||
orig_model = model # original, uncompiled model
|
||||
# model = torch.compile(model, dynamic=True) # doesn't work super well because of variable lengths of inputs
|
||||
engine = Engine(model, tokenizer) # will be used for inline model evaluation only
|
||||
|
|
@ -127,34 +130,36 @@ def sft_data_generator(dataset, batch_size):
|
|||
yield collate_and_yield(batch)
|
||||
batch = []
|
||||
|
||||
examples_per_step = device_batch_size * ddp_world_size
|
||||
print0(f"Target examples per step: {target_examples_per_step}")
|
||||
print0(f"Device batch size: {device_batch_size}")
|
||||
examples_per_step = args.device_batch_size * ddp_world_size
|
||||
print0(f"Target examples per step: {args.target_examples_per_step}")
|
||||
print0(f"Device batch size: {args.device_batch_size}")
|
||||
print0(f"Examples per step is device_batch_size * ddp_world_size: {examples_per_step}")
|
||||
assert target_examples_per_step % examples_per_step == 0, "Target examples per step must be divisible by examples per step"
|
||||
grad_accum_steps = target_examples_per_step // examples_per_step
|
||||
assert args.target_examples_per_step % examples_per_step == 0, "Target examples per step must be divisible by examples per step"
|
||||
grad_accum_steps = args.target_examples_per_step // examples_per_step
|
||||
print0(f"=> Setting grad accum steps: {grad_accum_steps}")
|
||||
|
||||
if num_iterations == -1:
|
||||
if args.num_iterations == -1:
|
||||
# derive num_iterations from num_epochs and the size of the dataset
|
||||
assert num_epochs > 0, "num_epochs must be positive if num_iterations is -1"
|
||||
num_iterations = (len(train_ds) // target_examples_per_step) * num_epochs
|
||||
train_loader = sft_data_generator(train_ds, batch_size=device_batch_size)
|
||||
build_val_loader = lambda: sft_data_generator(val_ds, batch_size=device_batch_size)
|
||||
assert args.num_epochs > 0, "num_epochs must be positive if num_iterations is -1"
|
||||
num_iterations = (len(train_ds) // args.target_examples_per_step) * args.num_epochs
|
||||
else:
|
||||
num_iterations = args.num_iterations
|
||||
train_loader = sft_data_generator(train_ds, batch_size=args.device_batch_size)
|
||||
build_val_loader = lambda: sft_data_generator(val_ds, batch_size=args.device_batch_size)
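Putting the SFT batch bookkeeping together, with a placeholder dataset size (the real `len(train_ds)` is not visible in this diff):

```python
len_train_ds = 20_000                  # placeholder number of SFT conversations
target_examples_per_step = 32
num_epochs = 1
device_batch_size = 4
ddp_world_size = 8

examples_per_step = device_batch_size * ddp_world_size                    # 32
grad_accum_steps = target_examples_per_step // examples_per_step          # 1
num_iterations = (len_train_ds // target_examples_per_step) * num_epochs  # 625 steps
```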
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Initialize the Optimizer
|
||||
|
||||
optimizers = model.setup_optimizers(
|
||||
unembedding_lr=unembedding_lr,
|
||||
embedding_lr=embedding_lr,
|
||||
matrix_lr=matrix_lr,
|
||||
weight_decay=weight_decay,
|
||||
unembedding_lr=args.unembedding_lr,
|
||||
embedding_lr=args.embedding_lr,
|
||||
matrix_lr=args.matrix_lr,
|
||||
weight_decay=args.weight_decay,
|
||||
)
|
||||
# Set the initial learning rate as a fraction of the base learning rate
|
||||
for opt in optimizers:
|
||||
for group in opt.param_groups:
|
||||
group["lr"] = group["lr"] * init_lr_frac
|
||||
group["lr"] = group["lr"] * args.init_lr_frac
|
||||
group["initial_lr"] = group["lr"] # save the initial learning so we can decay easily later
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
|
|
@ -171,11 +176,11 @@ for step in range(num_iterations):
|
|||
last_step = step == num_iterations - 1
|
||||
|
||||
# evaluate the validation loss
|
||||
if last_step or step % eval_every == 0:
|
||||
if last_step or step % args.eval_every == 0:
|
||||
model.eval()
|
||||
val_loader = build_val_loader()
|
||||
losses = []
|
||||
for _ in range(eval_steps):
|
||||
for _ in range(args.eval_steps):
|
||||
val_inputs, val_targets = next(val_loader)
|
||||
with torch.no_grad(), autocast_ctx:
|
||||
loss = model(val_inputs, val_targets)
|
||||
|
|
@@ -192,13 +197,13 @@ for step in range(num_iterations):
|
|||
model.train()
|
||||
|
||||
# evaluate accuracy of the multiple choice tasks (which are quick to run)
|
||||
if last_step or (step > 0 and step % eval_metrics_every == 0):
|
||||
if last_step or (step > 0 and step % args.eval_metrics_every == 0):
|
||||
model.eval()
|
||||
metrics = {}
|
||||
with torch.no_grad(), autocast_ctx:
|
||||
# note that because these are inside no_grad, we can usually afford to at least ~2X the batch size
|
||||
metrics["mmlu_acc"] = run_chat_eval("MMLU", model, tokenizer, engine, batch_size=device_batch_size*2, max_problems=eval_metrics_max_problems)
|
||||
metrics["arc_easy_acc"] = run_chat_eval("ARC-Easy", model, tokenizer, engine, batch_size=device_batch_size*2, max_problems=eval_metrics_max_problems)
|
||||
metrics["mmlu_acc"] = run_chat_eval("MMLU", model, tokenizer, engine, batch_size=args.device_batch_size*2, max_problems=args.eval_metrics_max_problems)
|
||||
metrics["arc_easy_acc"] = run_chat_eval("ARC-Easy", model, tokenizer, engine, batch_size=args.device_batch_size*2, max_problems=args.eval_metrics_max_problems)
|
||||
metrics_str = ', '.join(f'{k}: {v:.6f}' for k, v in metrics.items())
|
||||
print0(f"Step {step:05d} | {metrics_str}")
|
||||
wandb_run.log({
|
||||
|
|
@@ -250,7 +255,7 @@ for step in range(num_iterations):
|
|||
if master_process:
|
||||
base_dir = get_base_dir()
|
||||
depth = model.config.n_layer
|
||||
output_dirname = model_tag if model_tag else f"d{depth}" # e.g. d12
|
||||
output_dirname = args.model_tag if args.model_tag else f"d{depth}" # e.g. d12
|
||||
checkpoint_dir = os.path.join(base_dir, "chatsft_checkpoints", output_dirname)
|
||||
model_config_kwargs = model.config.__dict__ # slightly naughty, abusing the simplicity of GPTConfig, TODO nicer
|
||||
save_checkpoint(
|
||||
|
|
|
|||
|
|
@@ -6,10 +6,10 @@ python -m scripts.mid_train
|
|||
|
||||
Or torchrun for training:
|
||||
|
||||
torchrun --standalone --nproc_per_node=8 -m scripts.mid_train -- --device_batch_size=16
|
||||
torchrun --standalone --nproc_per_node=8 -m scripts.mid_train -- --device-batch-size=16
|
||||
"""
|
||||
|
||||
from collections import deque
|
||||
import argparse
|
||||
import os
|
||||
os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"
|
||||
import time
|
||||
|
|
@@ -31,65 +31,75 @@ from tasks.customjson import CustomJSON
|
|||
from tasks.spellingbee import SimpleSpelling, SpellingBee
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
run = "dummy" # wandb run name default ("dummy" is special - we won't log to wandb)
|
||||
device_type = "" # cuda|cpu|mps (empty => autodetect)
|
||||
model_tag = None # model tag to load the model from (base model or midtrained model)
|
||||
step = None # step to load the model from (base model or midtrained model)
|
||||
dtype = "bfloat16"
|
||||
num_iterations = -1 # explicit number of steps of the optimization (-1 = disable)
|
||||
max_seq_len = 2048
|
||||
device_batch_size = 32
|
||||
unembedding_lr = 0.004
|
||||
embedding_lr = 0.2
|
||||
matrix_lr = 0.02
|
||||
init_lr_frac = 1.0 # initial learning rate is this fraction of the base learning rate
|
||||
weight_decay = 0.0
|
||||
eval_every = 150 # -1 = disable
|
||||
eval_tokens = 20*524288
|
||||
total_batch_size = 524288
|
||||
dry_run = 0 # dry_run=1 is for experiments: we will log to wandb but we won't write checkpoints or report
|
||||
config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
|
||||
exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
|
||||
user_config = {k: globals()[k] for k in config_keys} # possibly useful for logging
|
||||
# CLI arguments
|
||||
parser = argparse.ArgumentParser(description="Midtrain the model")
|
||||
# Logging
|
||||
parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)")
|
||||
# Runtime
|
||||
parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
|
||||
parser.add_argument("--dtype", type=str, default="bfloat16", help="float32|bfloat16")
|
||||
# Model loading
|
||||
parser.add_argument("--model-tag", type=str, default=None, help="model tag to load from")
|
||||
parser.add_argument("--model-step", type=int, default=None, help="model step to load from")
|
||||
# Training horizon
|
||||
parser.add_argument("--num-iterations", type=int, default=-1, help="number of optimization steps (-1 = full epoch)")
|
||||
# Batch sizes
|
||||
parser.add_argument("--max-seq-len", type=int, default=2048, help="max context length")
|
||||
parser.add_argument("--device-batch-size", type=int, default=32, help="per-device batch size")
|
||||
parser.add_argument("--total-batch-size", type=int, default=524288, help="total batch size in tokens")
|
||||
# Optimization
|
||||
parser.add_argument("--embedding-lr", type=float, default=0.2, help="learning rate for embedding parameters (Adam)")
|
||||
parser.add_argument("--unembedding-lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)")
|
||||
parser.add_argument("--matrix-lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)")
|
||||
parser.add_argument("--weight-decay", type=float, default=0.0, help="weight decay for embedding/unembedding parameters (Adam)")
|
||||
parser.add_argument("--init-lr-frac", type=float, default=1.0, help="initial LR as fraction of base LR")
|
||||
# Evaluation
|
||||
parser.add_argument("--eval-every", type=int, default=150, help="evaluate val bpb every N steps (-1 = disable)")
|
||||
parser.add_argument("--eval-tokens", type=int, default=20*524288, help="number of tokens to evaluate val loss on")
|
||||
# Output
|
||||
parser.add_argument("--dry-run", action="store_true", help="log to wandb but skip checkpoints/report")
|
||||
args = parser.parse_args()
|
||||
user_config = vars(args).copy()
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
# Compute init
|
||||
device_type = autodetect_device_type() if device_type == "" else device_type
|
||||
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
|
||||
master_process = ddp_rank == 0
|
||||
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
|
||||
ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16
|
||||
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
|
||||
synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None
|
||||
get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0
|
||||
|
||||
# wandb logging init
|
||||
use_dummy_wandb = run == "dummy" or not master_process
|
||||
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-mid", name=run, config=user_config)
|
||||
use_dummy_wandb = args.run == "dummy" or not master_process
|
||||
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-mid", name=args.run, config=user_config)
|
||||
|
||||
# Load the model and tokenizer
|
||||
model, tokenizer, meta = load_model("base", device, phase="train", model_tag=model_tag, step=step)
|
||||
model, tokenizer, meta = load_model("base", device, phase="train", model_tag=args.model_tag, step=args.model_step)
|
||||
pretrain_batch_size = meta.get("device_batch_size", None)
|
||||
if pretrain_batch_size is not None and device_batch_size > pretrain_batch_size:
|
||||
print0(f"FOOTGUN WARNING: base model training used device_batch_size {pretrain_batch_size}, did you pass in a good --device_batch_size to this script?")
|
||||
if pretrain_batch_size is not None and args.device_batch_size > pretrain_batch_size:
|
||||
print0(f"FOOTGUN WARNING: base model training used device_batch_size {pretrain_batch_size}, did you pass in a good --device-batch-size to this script?")
|
||||
orig_model = model
|
||||
model = torch.compile(model, dynamic=False)
|
||||
depth = model.config.n_layer
|
||||
num_flops_per_token = model.estimate_flops()
|
||||
tokens_per_fwdbwd = device_batch_size * max_seq_len # tokens per iteration for a single rank
|
||||
tokens_per_fwdbwd = args.device_batch_size * args.max_seq_len # tokens per iteration for a single rank
|
||||
world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size # total tokens per iteration for all ranks
|
||||
assert total_batch_size % world_tokens_per_fwdbwd == 0
|
||||
grad_accum_steps = total_batch_size // world_tokens_per_fwdbwd
|
||||
print0(f"Tokens / micro-batch / rank: {device_batch_size} x {max_seq_len} = {tokens_per_fwdbwd:,}")
|
||||
assert args.total_batch_size % world_tokens_per_fwdbwd == 0
|
||||
grad_accum_steps = args.total_batch_size // world_tokens_per_fwdbwd
|
||||
print0(f"Tokens / micro-batch / rank: {args.device_batch_size} x {args.max_seq_len} = {tokens_per_fwdbwd:,}")
|
||||
print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}")
|
||||
print0(f"Total batch size {total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}")
|
||||
print0(f"Total batch size {args.total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}")
|
||||
token_bytes = get_token_bytes(device=device)
|
||||
|
||||
# Initialize the Optimizer (Muon for Linear layers, AdamW for embedding and lm_head)
|
||||
optimizers = model.setup_optimizers(unembedding_lr=unembedding_lr, embedding_lr=embedding_lr, matrix_lr=matrix_lr, weight_decay=weight_decay)
|
||||
optimizers = model.setup_optimizers(unembedding_lr=args.unembedding_lr, embedding_lr=args.embedding_lr, matrix_lr=args.matrix_lr, weight_decay=args.weight_decay)
|
||||
adamw_optimizer, muon_optimizer = optimizers
|
||||
# Override the initial learning rate as a fraction of the base learning rate
|
||||
for opt in optimizers:
|
||||
for group in opt.param_groups:
|
||||
group["lr"] = group["lr"] * init_lr_frac
|
||||
group["lr"] = group["lr"] * args.init_lr_frac
|
||||
group["initial_lr"] = group["lr"] # save the initial learning so we can decay easily later
|
||||
|
||||
# Midtraining data mixture and DataLoader
|
||||
|
|
@@ -114,49 +124,100 @@ val_dataset = TaskMixture([
|
|||
# these two global variables and update them from within the data generator.
|
||||
last_step = False # we will toggle this to True when we reach the end of the training dataset
|
||||
approx_progress = 0.0 # will go from 0 to 1 over the course of the epoch
|
||||
def mid_data_generator(split):
|
||||
global last_step, approx_progress
|
||||
current_epoch = 1 # track epoch for logging
|
||||
def mid_data_generator_bos_bestfit(split, buffer_size=100):
|
||||
"""
|
||||
BOS-aligned dataloader for midtraining with bestfit-crop packing.
|
||||
|
||||
Each row in the batch starts with BOS (beginning of a conversation).
|
||||
Conversations are packed using a best-fit algorithm to minimize cropping.
|
||||
This matches the BOS-aligned approach used in pretraining.
|
||||
"""
|
||||
global last_step, approx_progress, current_epoch
|
||||
assert split in {"train", "val"}, "split must be 'train' or 'val'"
|
||||
dataset = train_dataset if split == "train" else val_dataset
|
||||
dataset_size = len(dataset)
|
||||
assert dataset_size > 0
|
||||
needed_tokens = device_batch_size * max_seq_len + 1 # to form one training batch of inputs,targets
|
||||
token_buffer = deque()
|
||||
# CUDA supports memory pinning for faster transfers between CPU and GPU:
|
||||
scratch = torch.empty(needed_tokens, dtype=torch.int64, pin_memory=(device_type == "cuda"))
|
||||
cursor = ddp_rank # increments by ddp_world_size each time, so each rank processes unique documents
|
||||
it = 0 # iteration counter
|
||||
while True:
|
||||
# Accumulate enough tokens for one iteration before yielding
|
||||
while len(token_buffer) < needed_tokens:
|
||||
row_capacity = args.max_seq_len + 1 # +1 for target at last position
|
||||
|
||||
# Conversation buffer: list of token lists
|
||||
conv_buffer = []
|
||||
cursor = ddp_rank # Each rank processes different conversations (for fetching)
|
||||
consumed = ddp_rank # Track actual consumption separately from buffering
|
||||
epoch = 1
|
||||
it = 0 # iteration counter
|
||||
|
||||
def refill_buffer():
|
||||
nonlocal cursor, epoch
|
||||
while len(conv_buffer) < buffer_size:
|
||||
conversation = dataset[cursor]
|
||||
ids, _ = tokenizer.render_conversation(conversation)
|
||||
token_buffer.extend(ids)
|
||||
conv_buffer.append(ids)
|
||||
cursor += ddp_world_size
|
||||
if cursor >= dataset_size:
|
||||
cursor -= dataset_size # wrap around for another epoch
|
||||
if split == "train":
|
||||
last_step = True # toggle last_step to True, which will terminate the training loop
|
||||
cursor = cursor % dataset_size
|
||||
epoch += 1
|
||||
# Note: last_step is now triggered based on consumption, not fetching
|
||||
|
||||
while True:
|
||||
rows = []
|
||||
for _ in range(args.device_batch_size):
|
||||
row = []
|
||||
while len(row) < row_capacity:
|
||||
# Ensure buffer has conversations
|
||||
while len(conv_buffer) < buffer_size:
|
||||
refill_buffer()
|
||||
|
||||
remaining = row_capacity - len(row)
|
||||
|
||||
# Find largest conversation that fits entirely
|
||||
best_idx = -1
|
||||
best_len = 0
|
||||
for i, conv in enumerate(conv_buffer):
|
||||
conv_len = len(conv)
|
||||
if conv_len <= remaining and conv_len > best_len:
|
||||
best_idx = i
|
||||
best_len = conv_len
|
||||
|
||||
if best_idx >= 0:
|
||||
# Found a conversation that fits - use it entirely
|
||||
conv = conv_buffer.pop(best_idx)
|
||||
row.extend(conv)
|
||||
consumed += ddp_world_size # Track actual consumption
|
||||
else:
|
||||
# No conversation fits - crop first conversation to fill remaining
|
||||
conv = conv_buffer.pop(0)
|
||||
row.extend(conv[:remaining])
|
||||
consumed += ddp_world_size # Track actual consumption
|
||||
|
||||
rows.append(row[:row_capacity])
|
||||
|
||||
# Stopping condition to respect num_iterations, if given
|
||||
it += 1
|
||||
if 0 < num_iterations <= it and split == "train":
|
||||
last_step = True # toggle last_step to True, which will terminate the training loop
|
||||
# Build up inputs/targets and yield
|
||||
for i in range(needed_tokens):
|
||||
scratch[i] = token_buffer.popleft()
|
||||
inputs_cpu = scratch[:-1].to(dtype=torch.int32)
|
||||
targets_cpu = scratch[1:]
|
||||
inputs = inputs_cpu.view(device_batch_size, max_seq_len).to(device=device, dtype=torch.int32, non_blocking=True)
|
||||
targets = targets_cpu.view(device_batch_size, max_seq_len).to(device=device, dtype=torch.int64, non_blocking=True)
|
||||
if 0 < args.num_iterations <= it and split == "train":
|
||||
last_step = True
|
||||
|
||||
# Update progress tracking (based on consumed, not cursor, to account for buffering)
|
||||
if split == "train":
|
||||
if num_iterations > 0:
|
||||
approx_progress = it / num_iterations # calculate progress from the max number of iterations
|
||||
current_epoch = epoch
|
||||
if args.num_iterations > 0:
|
||||
approx_progress = it / args.num_iterations
|
||||
else:
|
||||
approx_progress = cursor / dataset_size # approximate progress as a fraction of the dataset
|
||||
approx_progress = consumed / dataset_size
|
||||
# Trigger last_step when we've consumed enough (instead of when cursor wraps)
|
||||
if consumed >= dataset_size:
|
||||
last_step = True
|
||||
|
||||
# Build tensors
|
||||
use_cuda = device_type == "cuda"
|
||||
batch_tensor = torch.tensor(rows, dtype=torch.long, pin_memory=use_cuda)
|
||||
inputs = batch_tensor[:, :-1].to(device=device, dtype=torch.int32, non_blocking=use_cuda)
|
||||
targets = batch_tensor[:, 1:].to(device=device, dtype=torch.int64, non_blocking=use_cuda)
|
||||
|
||||
yield inputs, targets
|
||||
|
||||
train_loader = mid_data_generator("train")
|
||||
build_val_loader = lambda: mid_data_generator("val")
|
||||
train_loader = mid_data_generator_bos_bestfit("train")
|
||||
build_val_loader = lambda: mid_data_generator_bos_bestfit("val")
|
||||
progress = 0 # will go from 0 to 1 over the course of the epoch
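To see what the best-fit-crop packing above does, here is a tiny self-contained sketch of just the row-filling rule (largest buffered conversation that still fits, otherwise crop the oldest), with made-up conversation lengths standing in for tokenized conversations:

```python
# Minimal sketch of the best-fit-crop packing rule with made-up conversation lengths.
def pack_rows(conv_lengths, row_capacity, n_rows):
    buffer = list(conv_lengths)          # stand-ins for tokenized conversations
    rows = []
    for _ in range(n_rows):
        row, used = [], 0
        while used < row_capacity and buffer:
            remaining = row_capacity - used
            fits = [length for length in buffer if length <= remaining]
            if fits:
                take = max(fits)         # best fit: largest conversation that fits entirely
            else:
                take = buffer[0]         # nothing fits: crop the oldest conversation
            buffer.remove(take)
            used += min(take, remaining)
            row.append(min(take, remaining))
        rows.append(row)
    return rows

print(pack_rows([9, 3, 5, 7, 2, 8], row_capacity=10, n_rows=3))
# -> [[9, 1], [8, 2], [7, 3]]: only the length-3 and length-5 conversations get cropped to fill remainders
```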
|
||||
|
||||
# Learning rate scheduler
|
||||
|
|
@@ -179,7 +240,7 @@ ema_beta = 0.9 # EMA decay factor
|
|||
total_training_time = 0 # total wall-clock time of training
|
||||
step = 0
|
||||
while True:
|
||||
flops_so_far = num_flops_per_token * total_batch_size * step
|
||||
flops_so_far = num_flops_per_token * args.total_batch_size * step
|
||||
|
||||
# Synchronize last_step across all ranks to avoid hangs in the distributed setting
|
||||
if ddp:
|
||||
|
|
@@ -188,10 +249,10 @@ while True:
|
|||
last_step = bool(last_step_tensor.item())
|
||||
|
||||
# once in a while: evaluate the val bpb (all ranks participate)
|
||||
if eval_every > 0 and (last_step or step % eval_every == 0):
|
||||
if args.eval_every > 0 and (last_step or step % args.eval_every == 0):
|
||||
model.eval()
|
||||
val_loader = build_val_loader()
|
||||
eval_steps = eval_tokens // (device_batch_size * max_seq_len * ddp_world_size)
|
||||
eval_steps = args.eval_tokens // (args.device_batch_size * args.max_seq_len * ddp_world_size)
|
||||
with autocast_ctx:
|
||||
val_bpb = evaluate_bpb(model, val_loader, eval_steps, token_bytes)
|
||||
print0(f"Step {step:05d} | Validation bpb: {val_bpb:.4f}")
|
||||
|
|
@@ -206,8 +267,8 @@ while True:
|
|||
model.train()
|
||||
|
||||
# save checkpoint at the end of the run (only on master process)
|
||||
if master_process and last_step and not dry_run:
|
||||
output_dirname = model_tag if model_tag else f"d{depth}" # e.g. d12
|
||||
if master_process and last_step and not args.dry_run:
|
||||
output_dirname = args.model_tag if args.model_tag else f"d{depth}" # e.g. d12
|
||||
checkpoint_dir = os.path.join(base_dir, "mid_checkpoints", output_dirname)
|
||||
save_checkpoint(
|
||||
checkpoint_dir,
|
||||
|
|
@@ -218,7 +279,7 @@ while True:
|
|||
"step": step,
|
||||
"val_bpb": val_bpb, # loss at last step
|
||||
"model_config": {
|
||||
"sequence_len": max_seq_len,
|
||||
"sequence_len": args.max_seq_len,
|
||||
"vocab_size": tokenizer.get_vocab_size(),
|
||||
"n_layer": depth,
|
||||
"n_head": model.config.n_head,
|
||||
|
|
@@ -268,13 +329,13 @@ while True:
|
|||
smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss.item() # EMA the training loss
|
||||
debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**(step + 1)) # debias the EMA
|
||||
pct_done = 100 * progress
|
||||
tok_per_sec = int(total_batch_size / dt)
|
||||
flops_per_sec = num_flops_per_token * total_batch_size / dt
|
||||
tok_per_sec = int(args.total_batch_size / dt)
|
||||
flops_per_sec = num_flops_per_token * args.total_batch_size / dt
|
||||
promised_flops_per_sec_h100 = 989e12 * ddp_world_size # bfloat16 H100 SXM and without 2:4 sparsity
|
||||
mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in %
|
||||
if step > 10:
|
||||
total_training_time += dt # only count the time after the first 10 steps
|
||||
print0(f"step {step:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time/60:.2f}m")
|
||||
print0(f"step {step:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | epoch: {current_epoch} | total time: {total_training_time/60:.2f}m")
|
||||
if step % 10 == 0:
|
||||
wandb_run.log({
|
||||
"step": step,
|
||||
|
|
@@ -285,6 +346,7 @@ while True:
|
|||
"train/dt": dt,
|
||||
"train/tok_per_sec": tok_per_sec,
|
||||
"train/mfu": mfu,
|
||||
"train/epoch": current_epoch,
|
||||
})
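Two small numerical notes on the logging above: the `1 - ema_beta**(step+1)` factor is the standard zero-init bias correction for an EMA, and MFU is just achieved FLOP/s over the H100 bf16 peak of the node. A quick sketch (the loss and FLOP numbers are illustrative, and the zero initialization of the EMA is an assumption not shown in this hunk):

```python
ema_beta = 0.9
step = 0                                     # first logged step

# Bias correction: assuming the EMA starts at 0.0 (an assumption, the init is not in this hunk),
# the first update gives 0.1 * loss, and dividing by (1 - 0.9**(step+1)) = 0.1 recovers the loss.
smooth, loss = 0.0, 2.5                      # illustrative loss value
smooth = ema_beta * smooth + (1 - ema_beta) * loss       # 0.25
debiased = smooth / (1 - ema_beta ** (step + 1))         # 2.5

# MFU: achieved FLOP/s as a percentage of the node's promised bf16 peak.
num_flops_per_token = 3.5e9                  # illustrative; depends on model size
total_batch_size, dt, ddp_world_size = 524288, 0.5, 8    # tokens, seconds per step (illustrative), ranks
flops_per_sec = num_flops_per_token * total_batch_size / dt
mfu = 100 * flops_per_sec / (989e12 * ddp_world_size)    # ~46% for these made-up numbers
```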
|
||||
|
||||
# print a few more stats
|
||||
|
|
@@ -293,7 +355,7 @@ print0(f"Total training time: {total_training_time/60:.2f}m")
|
|||
print0(f"Minimum validation bpb: {min_val_bpb:.4f}")
|
||||
|
||||
# Log to report
|
||||
if not dry_run:
|
||||
if not args.dry_run:
|
||||
from nanochat.report import get_report
|
||||
get_report().log(section="Midtraining", data=[
|
||||
user_config, # CLI args
|
||||
|
|
|
|||
|
|
@@ -1,5 +1,5 @@
|
|||
"""
|
||||
Train a tokenizer using the HuggingFace Tokenizers library.
|
||||
Train a tokenizer using our own BPE Tokenizer library.
|
||||
In the style of GPT-4 tokenizer.
|
||||
"""
|
||||
import os
|
||||
|
|
@@ -14,9 +14,9 @@ from nanochat.dataset import parquets_iter_batched
|
|||
# Parse command line arguments
|
||||
|
||||
parser = argparse.ArgumentParser(description='Train a BPE tokenizer')
|
||||
parser.add_argument('--max_chars', type=int, default=10_000_000_000, help='Maximum characters to train on (default: 10B)')
|
||||
parser.add_argument('--doc_cap', type=int, default=10_000, help='Maximum characters per document (default: 10,000)')
|
||||
parser.add_argument('--vocab_size', type=int, default=65536, help='Vocabulary size (default: 65536 = 2^16)')
|
||||
parser.add_argument('--max-chars', type=int, default=10_000_000_000, help='Maximum characters to train on (default: 10B)')
|
||||
parser.add_argument('--doc-cap', type=int, default=10_000, help='Maximum characters per document (default: 10,000)')
|
||||
parser.add_argument('--vocab-size', type=int, default=32768, help='Vocabulary size (default: 32768 = 2^15)')
|
||||
args = parser.parse_args()
|
||||
print(f"max_chars: {args.max_chars:,}")
|
||||
print(f"doc_cap: {args.doc_cap:,}")
|
||||
|
|
|
|||
19
speedrun.sh
|
|
@@ -48,13 +48,6 @@ python -m nanochat.report reset
|
|||
# -----------------------------------------------------------------------------
|
||||
# Tokenizer
|
||||
|
||||
# Install Rust / Cargo
|
||||
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
|
||||
source "$HOME/.cargo/env"
|
||||
|
||||
# Build the rustbpe Tokenizer
|
||||
uv run maturin develop --release --manifest-path rustbpe/Cargo.toml
|
||||
|
||||
# Download the first ~2B characters of pretraining dataset
|
||||
# look at dev/repackage_data_reference.py for details on how this data was prepared
|
||||
# each data shard is ~250M chars
|
||||
|
|
@@ -62,11 +55,11 @@ uv run maturin develop --release --manifest-path rustbpe/Cargo.toml
|
|||
# each shard is ~100MB of text (compressed), so this is about ~800MB of data on disk
|
||||
python -m nanochat.dataset -n 8
|
||||
# Immediately also kick off downloading more shards in the background while tokenizer trains
|
||||
# See comment below for why 240 is the right number here
|
||||
python -m nanochat.dataset -n 240 &
|
||||
# See comment below for why 370 is the right number here
|
||||
python -m nanochat.dataset -n 370 &
|
||||
DATASET_DOWNLOAD_PID=$!
|
||||
# train the tokenizer with vocab size 2**16 = 65536 on ~2B characters of data
|
||||
python -m scripts.tok_train --max_chars=2000000000
|
||||
python -m scripts.tok_train --max-chars=2000000000 --vocab-size=65536
|
||||
# evaluate the tokenizer (report compression ratio etc.)
|
||||
python -m scripts.tok_eval
|
||||
|
||||
|
|
@@ -77,7 +70,9 @@ python -m scripts.tok_eval
|
|||
# Chinchilla says #tokens = 20X #params, so we need 561e6 * 20 = 11.2B tokens.
|
||||
# Assume our tokenizer is 4.8 chars/token, this is 11.2B * 4.8 ~= 54B chars.
|
||||
# At 250M chars/shard, this is 54B / 250M ~= 216 shards needed for pretraining.
|
||||
# Round up to 240 for safety. At ~100MB/shard, this downloads ~24GB of data to disk.
|
||||
# Round up to 240 for safety. Also, the new DataLoader wastes about 35% of tokens to cropping,
# so 240 / (1 - 0.35) ≈ 370 shards are needed.
|
||||
# At ~100MB/shard, this downloads ~37GB of data to disk.
|
||||
# (The total number of shards available in the entire dataset is 1822.)
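Restating the shard arithmetic from the comments above as a quick sanity check (all numbers are taken straight from the comments):

```python
params = 561e6
tokens_needed = 20 * params                  # Chinchilla: ~11.2B tokens
chars_needed = tokens_needed * 4.8           # ~54B characters at ~4.8 chars/token
shards_raw = chars_needed / 250e6            # ~215 shards of ~250M chars each
shards_padded = 240                          # rounded up for safety
shards_final = shards_padded / (1 - 0.35)    # ~369 -> download 370 to cover the ~35% cropping loss
disk_gb = 370 * 0.1                          # ~37 GB at ~100MB per shard
```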
|
||||
echo "Waiting for dataset download to complete..."
|
||||
wait $DATASET_DOWNLOAD_PID
|
||||
|
|
@@ -86,7 +81,7 @@ wait $DATASET_DOWNLOAD_PID
|
|||
NPROC_PER_NODE=8
|
||||
|
||||
# pretrain the d20 model
|
||||
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train -- --depth=20 --run=$WANDB_RUN
|
||||
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train -- --depth=20 --target-param-data-ratio=20 --run=$WANDB_RUN
|
||||
# evaluate the model on a larger chunk of train/val data and draw some samples
|
||||
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_loss
|
||||
# evaluate the model on CORE tasks
|
||||
|
|
|
|||
|
|
@@ -39,13 +39,9 @@ class MockModel:
|
|||
def forward(self, ids, kv_cache=None):
|
||||
"""Return uniform logits so sampling is spread across vocab."""
|
||||
B, T = ids.shape
|
||||
# Simulate what a real transformer does: insert k,v into the cache for each layer
|
||||
# With FA3, flash_attn_with_kvcache updates cache in-place and we advance position
|
||||
if kv_cache is not None:
|
||||
head_dim = self.config.n_embd // self.config.n_head
|
||||
for layer_idx in range(self.config.n_layer):
|
||||
k = torch.zeros(B, self.config.n_kv_head, T, head_dim)
|
||||
v = torch.zeros(B, self.config.n_kv_head, T, head_dim)
|
||||
kv_cache.insert_kv(layer_idx, k, v)
|
||||
kv_cache.advance(T)
|
||||
# Uniform logits -> equal probability for all tokens
|
||||
logits = torch.zeros(B, T, self.vocab_size)
|
||||
return logits
|
||||
|
|
@@ -85,16 +81,11 @@ class ByteTokenizer:
|
|||
byte_tokens = [t for t in tokens if t < 256]
|
||||
return bytes(byte_tokens).decode("utf-8", errors="replace")
|
||||
|
||||
def test_kv_cache_resize():
|
||||
"""
|
||||
The KV cache was not resized correctly, more information here:
|
||||
https://github.com/karpathy/nanochat/pull/186
|
||||
This test reproduces the issue and will be merged alongside the fix.
|
||||
"""
|
||||
|
||||
def test_kv_cache_basic():
|
||||
"""Test basic KVCache functionality for FA3."""
|
||||
batch_size = 2
|
||||
num_heads = 3
|
||||
seq_len = 4
|
||||
seq_len = 64
|
||||
head_dim = 5
|
||||
num_layers = 6
|
||||
|
||||
|
|
@@ -103,45 +94,64 @@ def test_kv_cache_resize():
|
|||
num_heads=num_heads,
|
||||
seq_len=seq_len,
|
||||
head_dim=head_dim,
|
||||
num_layers=num_layers
|
||||
num_layers=num_layers,
|
||||
device="cpu",
|
||||
)
|
||||
|
||||
# Insert a single token with a distinct fill value to all layers
|
||||
def insert_token(token_idx):
|
||||
for layer_idx in range(num_layers):
|
||||
k = torch.full((batch_size, num_heads, 1, head_dim), fill_value=float(token_idx), dtype=torch.float32)
|
||||
v = torch.full((batch_size, num_heads, 1, head_dim), fill_value=float(token_idx * 100), dtype=torch.float32)
|
||||
kv_cache.insert_kv(layer_idx, k, v)
|
||||
# Check initial state
|
||||
assert kv_cache.get_pos() == 0
|
||||
assert kv_cache.k_cache.shape == (num_layers, batch_size, seq_len, num_heads, head_dim)
|
||||
assert kv_cache.v_cache.shape == (num_layers, batch_size, seq_len, num_heads, head_dim)
|
||||
|
||||
# Insert 4 tokens (fills the initial seq_len=4)
|
||||
for i in range(4):
|
||||
insert_token(i)
|
||||
# Test advance
|
||||
kv_cache.advance(10)
|
||||
assert kv_cache.get_pos() == 10
|
||||
|
||||
# Record the original state of the cache
|
||||
original_cache = kv_cache.kv_cache.clone()
|
||||
original_seq_len = original_cache.shape[4]
|
||||
kv_cache.advance(5)
|
||||
assert kv_cache.get_pos() == 15
|
||||
|
||||
# Insert the 5th token, which will trigger a resize
|
||||
insert_token(4)
|
||||
# Verify that the cache actually resized
|
||||
new_seq_len = kv_cache.kv_cache.shape[4]
|
||||
assert new_seq_len > original_seq_len, f"Cache did not resize: original seq_len={original_seq_len}, new seq_len={new_seq_len}"
|
||||
# Test reset
|
||||
kv_cache.reset()
|
||||
assert kv_cache.get_pos() == 0
|
||||
|
||||
# Verify that the original 4 tokens are still intact after resize
|
||||
for layer_idx in range(num_layers):
|
||||
for token_idx in range(4):
|
||||
# Check that resized cache matches expected values
|
||||
expected_k = float(token_idx)
|
||||
expected_v = float(token_idx * 100)
|
||||
actual_k = kv_cache.kv_cache[layer_idx, 0, :, :, token_idx, :]
|
||||
actual_v = kv_cache.kv_cache[layer_idx, 1, :, :, token_idx, :]
|
||||
assert (actual_k == expected_k).all(), f"Layer {layer_idx}, token {token_idx}: key corrupted, expected {expected_k}"
|
||||
assert (actual_v == expected_v).all(), f"Layer {layer_idx}, token {token_idx}: value corrupted, expected {expected_v}"
|
||||
# And that the original cache matches resized cache
|
||||
original_k = original_cache[layer_idx, 0, :, :, token_idx, :]
|
||||
original_v = original_cache[layer_idx, 1, :, :, token_idx, :]
|
||||
assert (actual_k == original_k).all(), f"Layer {layer_idx}, token {token_idx}: key doesn't match original"
|
||||
assert (actual_v == original_v).all(), f"Layer {layer_idx}, token {token_idx}: value doesn't match original"
|
||||
# Test get_layer_cache returns correct views
|
||||
k_layer0, v_layer0 = kv_cache.get_layer_cache(0)
|
||||
assert k_layer0.shape == (batch_size, seq_len, num_heads, head_dim)
|
||||
assert v_layer0.shape == (batch_size, seq_len, num_heads, head_dim)
|
||||
|
||||
|
||||
def test_kv_cache_prefill():
|
||||
"""Test KVCache.prefill() copies data correctly."""
|
||||
batch_size = 1
|
||||
num_heads = 4
|
||||
head_dim = 8
|
||||
num_layers = 2
|
||||
|
||||
# Create source cache and advance it
|
||||
src_cache = KVCache(
|
||||
batch_size=batch_size, num_heads=num_heads, seq_len=32,
|
||||
head_dim=head_dim, num_layers=num_layers, device="cpu",
|
||||
)
|
||||
# Write some data to source cache
|
||||
src_cache.k_cache[0, 0, :16, :, :] = 1.0
|
||||
src_cache.v_cache[0, 0, :16, :, :] = 2.0
|
||||
src_cache.advance(16)
|
||||
|
||||
# Create destination cache with larger seq_len
|
||||
dst_cache = KVCache(
|
||||
batch_size=batch_size, num_heads=num_heads, seq_len=64,
|
||||
head_dim=head_dim, num_layers=num_layers, device="cpu",
|
||||
)
|
||||
|
||||
# Prefill
|
||||
dst_cache.prefill(src_cache)
|
||||
|
||||
# Check position was copied
|
||||
assert dst_cache.get_pos() == 16
|
||||
|
||||
# Check data was copied
|
||||
assert (dst_cache.k_cache[0, 0, :16, :, :] == 1.0).all()
|
||||
assert (dst_cache.v_cache[0, 0, :16, :, :] == 2.0).all()
|
||||
|
||||
|
||||
def test_multi_sample_first_token_diversity():
|
||||
|
|
|
|||
|
|
@@ -1,718 +0,0 @@
|
|||
"""
|
||||
Comparing the training of:
|
||||
|
||||
1. (very slow) Python reference implementation
|
||||
2. Optimized Python implementation
|
||||
3. HuggingFace tokenizers training implementation
|
||||
4. Our own custom RustBPE training implementation
|
||||
|
||||
All of these should calculate the same merges and produce
|
||||
the same vocabulary and tokenizations.
|
||||
|
||||
Finally, for inference we will use tiktoken for efficiency.
|
||||
So we want to make sure we can export our rustbpe tokenizer
|
||||
into tiktoken and use it for inference with identical results.
|
||||
|
||||
Run with:
|
||||
python -m pytest tests/test_rustbpe.py -v -s
|
||||
-v is verbose, -s is show prints
|
||||
"""
|
||||
|
||||
import regex as re
|
||||
from collections import Counter, defaultdict
|
||||
import time
|
||||
import warnings
|
||||
import rustbpe
|
||||
import tiktoken
|
||||
import pytest
|
||||
|
||||
GPT4_SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Reference tokenizer, pretty much copy pasted and pruned a bit from minbpe
|
||||
|
||||
def get_stats(ids, counts=None):
|
||||
"""
|
||||
Given a list of integers, return a dictionary of counts of consecutive pairs
|
||||
Example: [1, 2, 3, 1, 2] -> {(1, 2): 2, (2, 3): 1, (3, 1): 1}
|
||||
Optionally allows updating an existing dictionary of counts
|
||||
"""
|
||||
counts = {} if counts is None else counts
|
||||
for pair in zip(ids, ids[1:]): # iterate consecutive elements
|
||||
counts[pair] = counts.get(pair, 0) + 1
|
||||
return counts
|
||||
|
||||
def merge(ids, pair, idx):
|
||||
"""
|
||||
In the list of integers (ids), replace all consecutive occurrences
|
||||
of pair with the new integer token idx
|
||||
Example: ids=[1, 2, 3, 1, 2], pair=(1, 2), idx=4 -> [4, 3, 4]
|
||||
"""
|
||||
newids = []
|
||||
i = 0
|
||||
while i < len(ids):
|
||||
# if not at the very last position AND the pair matches, replace it
|
||||
if ids[i] == pair[0] and i < len(ids) - 1 and ids[i+1] == pair[1]:
|
||||
newids.append(idx)
|
||||
i += 2
|
||||
else:
|
||||
newids.append(ids[i])
|
||||
i += 1
|
||||
return newids
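A quick usage example tying `get_stats` and `merge` (defined above) into a single BPE training step; the token ids are the same toy sequence used in the docstrings:

```python
ids = [1, 2, 3, 1, 2]
stats = get_stats(ids)                  # {(1, 2): 2, (2, 3): 1, (3, 1): 1}
top_pair = max(stats, key=stats.get)    # (1, 2), the most frequent pair
ids = merge(ids, top_pair, 256)         # [256, 3, 256]: mint token 256 for that pair
```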
|
||||
|
||||
class RegexTokenizer:
|
||||
|
||||
def __init__(self, pattern=None):
|
||||
"""
|
||||
- pattern: optional string to override the default (GPT-4 split pattern)
|
||||
- special_tokens: str -> int dictionary of special tokens
|
||||
example: {'<|endoftext|>': 100257}
|
||||
"""
|
||||
self.pattern = GPT4_SPLIT_PATTERN if pattern is None else pattern
|
||||
self.merges = {} # (int, int) -> int
|
||||
self.compiled_pattern = re.compile(self.pattern)
|
||||
self.special_tokens = {}
|
||||
self.inverse_special_tokens = {}
|
||||
self.vocab = self._build_vocab()
|
||||
|
||||
def _build_vocab(self):
|
||||
# vocab is simply and deterministically derived from merges
|
||||
vocab = {idx: bytes([idx]) for idx in range(256)}
|
||||
for (p0, p1), idx in self.merges.items():
|
||||
vocab[idx] = vocab[p0] + vocab[p1]
|
||||
for special, idx in self.special_tokens.items():
|
||||
vocab[idx] = special.encode("utf-8")
|
||||
return vocab
|
||||
|
||||
def train(self, text, vocab_size, verbose=False):
|
||||
assert vocab_size >= 256
|
||||
num_merges = vocab_size - 256
|
||||
|
||||
# keep track of whether at any point during training the merge is ambiguous (counts of pairs are not unique)
|
||||
ambiguous = False
|
||||
|
||||
# split the text up into text chunks
|
||||
text_chunks = re.findall(self.compiled_pattern, text)
|
||||
|
||||
# input text preprocessing
|
||||
ids = [list(ch.encode("utf-8")) for ch in text_chunks]
|
||||
|
||||
# iteratively merge the most common pairs to create new tokens
|
||||
merges = {} # (int, int) -> int
|
||||
vocab = {idx: bytes([idx]) for idx in range(256)} # idx -> bytes
|
||||
for i in range(num_merges):
|
||||
# count the number of times every consecutive pair appears
|
||||
stats = {}
|
||||
for chunk_ids in ids:
|
||||
# passing in stats will update it in place, adding up counts
|
||||
get_stats(chunk_ids, stats)
|
||||
# find the pair with the highest count
|
||||
pair = max(stats, key=stats.get)
|
||||
# check if the merge is ambiguous - i.e. the max value is not unique
|
||||
pair_count = stats[pair]
|
||||
pairs_with_max_count = [pair for pair, count in stats.items() if count == pair_count]
|
||||
if len(pairs_with_max_count) > 1:
|
||||
# print the top 10 pairs with their counts
|
||||
# print(f"{i} Merge is ambiguous! {pair} has {pair_count} occurrences")
|
||||
# for print_pair, print_count in sorted(stats.items(), key=lambda x: x[1], reverse=True)[:10]:
|
||||
# print(f"{print_pair}: {print_count}")
|
||||
ambiguous = True
|
||||
# mint a new token: assign it the next available id
|
||||
idx = 256 + i
|
||||
# replace all occurrences of pair in ids with idx
|
||||
ids = [merge(chunk_ids, pair, idx) for chunk_ids in ids]
|
||||
# save the merge
|
||||
merges[pair] = idx
|
||||
vocab[idx] = vocab[pair[0]] + vocab[pair[1]]
|
||||
# prints
|
||||
if verbose:
|
||||
print(f"merge {i+1}/{num_merges}: {pair} -> {idx} ({vocab[idx]}) had {stats[pair]} occurrences")
|
||||
|
||||
# save class variables
|
||||
self.merges = merges # used in encode()
|
||||
self.vocab = vocab # used in decode()
|
||||
return ambiguous
|
||||
|
||||
def _encode_chunk(self, text_bytes):
|
||||
# return the token ids
|
||||
# let's begin. first, convert all bytes to integers in range 0..255
|
||||
ids = list(text_bytes)
|
||||
while len(ids) >= 2:
|
||||
# find the pair with the lowest merge index
|
||||
stats = get_stats(ids)
|
||||
pair = min(stats, key=lambda p: self.merges.get(p, float("inf")))
|
||||
# subtle: if there are no more merges available, the key will
|
||||
# result in an inf for every single pair, and the min will be
|
||||
# just the first pair in the list, arbitrarily
|
||||
# we can detect this terminating case by a membership check
|
||||
if pair not in self.merges:
|
||||
break # nothing else can be merged anymore
|
||||
# otherwise let's merge the best pair (lowest merge index)
|
||||
idx = self.merges[pair]
|
||||
ids = merge(ids, pair, idx)
|
||||
return ids
|
||||
|
||||
def encode_ordinary(self, text):
|
||||
"""Encoding that ignores any special tokens."""
|
||||
# split text into chunks of text by categories defined in regex pattern
|
||||
text_chunks = re.findall(self.compiled_pattern, text)
|
||||
# all chunks of text are encoded separately, then results are joined
|
||||
ids = []
|
||||
for chunk in text_chunks:
|
||||
chunk_bytes = chunk.encode("utf-8") # raw bytes
|
||||
chunk_ids = self._encode_chunk(chunk_bytes)
|
||||
ids.extend(chunk_ids)
|
||||
return ids
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Faster Python tokenizer, optimized version of the reference tokenizer
|
||||
|
||||
def fast_merge_inplace(ids, pair, idx):
|
||||
"""
|
||||
In the list of integers (ids), replace all consecutive occurrences
|
||||
of pair with the new integer token idx in place
|
||||
Example: ids=[1, 2, 3, 1, 2], pair=(1, 2), idx=4 -> [4, 3, 4]
|
||||
"""
|
||||
# Find all positions where the pair occurs
|
||||
i = 0
|
||||
while i < len(ids) - 1:
|
||||
if ids[i] == pair[0] and ids[i+1] == pair[1]:
|
||||
ids[i] = idx
|
||||
ids.pop(i+1)
|
||||
else:
|
||||
i += 1
|
||||
return ids
|
||||
|
||||
|
||||
class FastRegexTokenizer:
|
||||
|
||||
def __init__(self, pattern=None):
|
||||
"""
|
||||
- pattern: optional string to override the default (GPT-4 split pattern)
|
||||
- special_tokens: str -> int dictionary of special tokens
|
||||
example: {'<|endoftext|>': 100257}
|
||||
"""
|
||||
self.pattern = GPT4_SPLIT_PATTERN if pattern is None else pattern
|
||||
self.compiled_pattern = re.compile(self.pattern)
|
||||
self.special_tokens = {}
|
||||
self.inverse_special_tokens = {}
|
||||
self.merges = {}
|
||||
self.vocab = self._build_vocab()
|
||||
|
||||
def _build_vocab(self):
|
||||
# vocab is simply and deterministically derived from merges
|
||||
vocab = {idx: bytes([idx]) for idx in range(256)}
|
||||
for (p0, p1), idx in self.merges.items():
|
||||
vocab[idx] = vocab[p0] + vocab[p1]
|
||||
for special, idx in self.special_tokens.items():
|
||||
vocab[idx] = special.encode("utf-8")
|
||||
return vocab
|
||||
|
||||
def train(self, text, vocab_size, verbose=False):
|
||||
"""
|
||||
A number of optimizations are introduced:
|
||||
- remove function call overhead by inlining functions
|
||||
- modifying list of ids in place with .pop() instead of creating a new list
|
||||
- collapse identical chunks to just the unique ones
|
||||
- update counts more cleverly - only around the affected chunks
|
||||
"""
|
||||
assert vocab_size >= 256
|
||||
num_merges = vocab_size - 256
|
||||
|
||||
# split the text up into text chunks
|
||||
text_chunks = re.findall(self.compiled_pattern, text)
|
||||
|
||||
# many, many chunks are identical, so we can "collapse" them to just the unique ones
|
||||
counts = Counter(text_chunks)
|
||||
unique_chunks = [ch for ch, count in counts.items()]
|
||||
chunk_counts = [count for ch, count in counts.items()]
|
||||
|
||||
# input text preprocessing
|
||||
ids = [list(ch.encode("utf-8")) for ch in unique_chunks]
|
||||
# iteratively merge the most common pairs to create new tokens
|
||||
merges = {} # (int, int) -> int
|
||||
vocab = {idx: bytes([idx]) for idx in range(256)} # idx -> bytes
|
||||
|
||||
# Initial count: build stats and position tracking
|
||||
stats = defaultdict(int)
|
||||
positions = defaultdict(set) # pair -> set of chunk indices that contain this pair
|
||||
|
||||
for chunk_idx, (chunk_ids, count) in enumerate(zip(ids, chunk_counts)):
|
||||
for pair in zip(chunk_ids, chunk_ids[1:]):
|
||||
stats[pair] += count
|
||||
positions[pair].add(chunk_idx)
|
||||
|
||||
for i in range(num_merges):
|
||||
if not stats:
|
||||
break
|
||||
|
||||
# find the pair with the highest count
|
||||
pair = max(stats, key=stats.get)
|
||||
# mint a new token: assign it the next available id
|
||||
idx = 256 + i
|
||||
|
||||
# Get chunks that contain this pair
|
||||
affected_chunks = positions[pair]
|
||||
|
||||
# Track count changes for incremental update
|
||||
count_changes = defaultdict(int)
|
||||
|
||||
# Replace all occurrences of pair in affected chunks only
|
||||
for chunk_idx in affected_chunks:
|
||||
chunk_ids = ids[chunk_idx]
|
||||
chunk_count = chunk_counts[chunk_idx]
|
||||
ix = 0
|
||||
while ix < len(chunk_ids) - 1:
|
||||
if chunk_ids[ix] == pair[0] and chunk_ids[ix+1] == pair[1]:
|
||||
# Track what pairs are being removed/added
|
||||
# Remove: (prev, A), (A, B), (B, next)
|
||||
if ix > 0:
|
||||
old_left = (chunk_ids[ix-1], chunk_ids[ix])
|
||||
count_changes[old_left] -= chunk_count
|
||||
|
||||
# The merged pair disappears
|
||||
count_changes[pair] -= chunk_count
|
||||
|
||||
if ix + 2 < len(chunk_ids):
|
||||
old_right = (chunk_ids[ix+1], chunk_ids[ix+2])
|
||||
count_changes[old_right] -= chunk_count
|
||||
|
||||
# Apply the merge
|
||||
chunk_ids[ix] = idx
|
||||
chunk_ids.pop(ix+1)
|
||||
|
||||
# Add: (prev, C), (C, next)
|
||||
if ix > 0:
|
||||
new_left = (chunk_ids[ix-1], chunk_ids[ix])
|
||||
count_changes[new_left] += chunk_count
|
||||
|
||||
if ix + 1 < len(chunk_ids):
|
||||
new_right = (chunk_ids[ix], chunk_ids[ix+1])
|
||||
count_changes[new_right] += chunk_count
|
||||
else:
|
||||
ix += 1
|
||||
|
||||
# Apply incremental changes to stats and positions
|
||||
for changed_pair, delta in count_changes.items():
|
||||
if changed_pair == pair:
|
||||
# The merged pair should disappear completely
|
||||
continue
|
||||
|
||||
stats[changed_pair] += delta
|
||||
|
||||
# Update positions for changed pairs - only check affected chunks
|
||||
for chunk_idx in affected_chunks:
|
||||
chunk_ids = ids[chunk_idx]
|
||||
contains_pair = any((chunk_ids[j], chunk_ids[j+1]) == changed_pair
|
||||
for j in range(len(chunk_ids) - 1))
|
||||
if contains_pair:
|
||||
positions[changed_pair].add(chunk_idx)
|
||||
else:
|
||||
positions[changed_pair].discard(chunk_idx)
|
||||
|
||||
# Remove the merged pair completely
|
||||
del stats[pair]
|
||||
del positions[pair]
|
||||
|
||||
# save the merge
|
||||
merges[pair] = idx
|
||||
vocab[idx] = vocab[pair[0]] + vocab[pair[1]]
|
||||
|
||||
# save class variables
|
||||
self.merges = merges # used in encode()
|
||||
self.vocab = vocab # used in decode()
|
||||
|
||||
def register_special_tokens(self, special_tokens):
|
||||
# special_tokens is a dictionary of str -> int
|
||||
# example: {"<|endoftext|>": 100257}
|
||||
self.special_tokens = special_tokens
|
||||
self.inverse_special_tokens = {v: k for k, v in special_tokens.items()}
|
||||
|
||||
def decode(self, ids):
|
||||
# given ids (list of integers), return Python string
|
||||
part_bytes = []
|
||||
for idx in ids:
|
||||
if idx in self.vocab:
|
||||
part_bytes.append(self.vocab[idx])
|
||||
elif idx in self.inverse_special_tokens:
|
||||
part_bytes.append(self.inverse_special_tokens[idx].encode("utf-8"))
|
||||
else:
|
||||
raise ValueError(f"invalid token id: {idx}")
|
||||
text_bytes = b"".join(part_bytes)
|
||||
text = text_bytes.decode("utf-8", errors="replace")
|
||||
return text
|
||||
|
||||
def _encode_chunk(self, text_bytes):
|
||||
# return the token ids
|
||||
# let's begin. first, convert all bytes to integers in range 0..255
|
||||
ids = list(text_bytes)
|
||||
while len(ids) >= 2:
|
||||
# find the pair with the lowest merge index
|
||||
stats = get_stats(ids)
|
||||
pair = min(stats, key=lambda p: self.merges.get(p, float("inf")))
|
||||
# subtle: if there are no more merges available, the key will
|
||||
# result in an inf for every single pair, and the min will be
|
||||
# just the first pair in the list, arbitrarily
|
||||
# we can detect this terminating case by a membership check
|
||||
if pair not in self.merges:
|
||||
break # nothing else can be merged anymore
|
||||
# otherwise let's merge the best pair (lowest merge index)
|
||||
idx = self.merges[pair]
|
||||
ids = fast_merge_inplace(ids, pair, idx)
|
||||
return ids
|
||||
|
||||
def encode_ordinary(self, text):
|
||||
"""Encoding that ignores any special tokens."""
|
||||
# split text into chunks of text by categories defined in regex pattern
|
||||
text_chunks = re.findall(self.compiled_pattern, text)
|
||||
# all chunks of text are encoded separately, then results are joined
|
||||
ids = []
|
||||
for chunk in text_chunks:
|
||||
chunk_bytes = chunk.encode("utf-8") # raw bytes
|
||||
chunk_ids = self._encode_chunk(chunk_bytes)
|
||||
ids.extend(chunk_ids)
|
||||
return ids
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# HuggingFace tokenizer
|
||||
from tokenizers import Tokenizer as HFTokenizer
|
||||
from tokenizers import pre_tokenizers, decoders, Regex
|
||||
from tokenizers.models import BPE
|
||||
from tokenizers.trainers import BpeTrainer
|
||||
|
||||
class HuggingFaceTokenizer:
|
||||
"""Light wrapper around HuggingFace Tokenizer for some utilities"""
|
||||
|
||||
def __init__(self, tokenizer):
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
@classmethod
|
||||
def train_from_iterator(cls, text_iterator, vocab_size):
|
||||
# train from an iterator of text
|
||||
# Configure the HuggingFace Tokenizer
|
||||
tokenizer = HFTokenizer(BPE(
|
||||
byte_fallback=True, # needed!
|
||||
unk_token=None,
|
||||
fuse_unk=False,
|
||||
))
|
||||
# Normalizer: None
|
||||
tokenizer.normalizer = None
|
||||
# Pre-tokenizer: GPT-4 style
|
||||
gpt4_split_regex = Regex(GPT4_SPLIT_PATTERN) # huggingface demands that you wrap it in Regex!!
|
||||
tokenizer.pre_tokenizer = pre_tokenizers.Sequence([
|
||||
pre_tokenizers.Split(pattern=gpt4_split_regex, behavior="isolated", invert=False),
|
||||
pre_tokenizers.ByteLevel(add_prefix_space=False, use_regex=False)
|
||||
])
|
||||
# Decoder: ByteLevel (it pairs together with the ByteLevel pre-tokenizer)
|
||||
tokenizer.decoder = decoders.ByteLevel()
|
||||
# Post-processor: None
|
||||
tokenizer.post_processor = None
|
||||
# Trainer: BPE
|
||||
trainer = BpeTrainer(
|
||||
vocab_size=vocab_size,
|
||||
show_progress=True,
|
||||
min_frequency=0, # no minimum frequency
|
||||
initial_alphabet=pre_tokenizers.ByteLevel.alphabet(),
|
||||
special_tokens=[], # no special tokens
|
||||
)
|
||||
# Kick off the training
|
||||
tokenizer.train_from_iterator(text_iterator, trainer)
|
||||
return cls(tokenizer)
|
||||
|
||||
def encode_ordinary(self, text):
|
||||
ids = self.tokenizer.encode(text, add_special_tokens=False).ids
|
||||
return ids
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Test all of the above
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def enwik8_path():
|
||||
"""Fixture to download and cache enwik8 dataset."""
|
||||
import os
|
||||
import zipfile
|
||||
from nanochat.common import get_base_dir
|
||||
base_dir = get_base_dir()
|
||||
# download and unzip enwik8 to .cache directory
|
||||
enwik8_url = "https://mattmahoney.net/dc/enwik8.zip"
|
||||
enwik8_local_path = os.path.join(base_dir, "enwik8")
|
||||
enwik8_local_path_zip = os.path.join(base_dir, "enwik8.zip")
|
||||
if not os.path.exists(enwik8_local_path):
|
||||
print(f"Downloading enwik8 to {enwik8_local_path_zip}")
|
||||
import requests
|
||||
response = requests.get(enwik8_url)
|
||||
with open(enwik8_local_path_zip, "wb") as f:
|
||||
f.write(response.content)
|
||||
with zipfile.ZipFile(enwik8_local_path_zip, "r") as zip_ref:
|
||||
zip_ref.extractall(base_dir)
|
||||
print(f"Unzipped enwik8 to {enwik8_local_path}")
|
||||
os.remove(enwik8_local_path_zip)
|
||||
print(f"Removed {enwik8_local_path_zip}")
|
||||
else:
|
||||
print(f"Using existing enwik8 at {enwik8_local_path}")
|
||||
return enwik8_local_path
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def enwik8_small(enwik8_path):
|
||||
"""Fixture providing 100KB of enwik8 for quick tests."""
|
||||
with open(enwik8_path, "r", encoding="utf-8") as f:
|
||||
return f.read(100_000)
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def enwik8_large(enwik8_path):
|
||||
"""Fixture providing 10MB of enwik8 for performance tests."""
|
||||
with open(enwik8_path, "r", encoding="utf-8") as f:
|
||||
return f.read(10**7)
|
||||
|
||||
def time_function(func, *args, **kwargs):
|
||||
"""Time a function call and return the result and elapsed time"""
|
||||
start_time = time.time()
|
||||
result = func(*args, **kwargs)
|
||||
end_time = time.time()
|
||||
elapsed = end_time - start_time
|
||||
return result, elapsed
|
||||
|
||||
def test_correctness(enwik8_small):
|
||||
"""Test that all tokenizer implementations produce the same results."""
|
||||
text = enwik8_small
|
||||
encode_text = text
|
||||
vocab_size = 256 + 20 # 20 merges
|
||||
|
||||
# Train slow reference
|
||||
print("\nTraining slow reference...")
|
||||
slow_reference_tokenizer = RegexTokenizer()
|
||||
ambiguous_flag, slow_reference_train_time = time_function(slow_reference_tokenizer.train, text, vocab_size)
|
||||
slow_reference_ids, slow_reference_encode_time = time_function(slow_reference_tokenizer.encode_ordinary, encode_text)
|
||||
print(f"Slow reference train time: {slow_reference_train_time:.4f}s")
|
||||
print(f"Slow reference encode time: {slow_reference_encode_time:.4f}s")
|
||||
print(slow_reference_ids[:20])
|
||||
|
||||
if ambiguous_flag:
|
||||
print("‼️ WARNING: merge order was detected to be ambiguous given current text and vocab size")
|
||||
print("The implementation could be correct but we might see different results below")
|
||||
else:
|
||||
print("✅ Merge order is NOT ambiguous")
|
||||
|
||||
# Train fast reference
|
||||
print("\nTraining fast reference...")
|
||||
fast_reference_tokenizer = FastRegexTokenizer()
|
||||
_, fast_reference_train_time = time_function(fast_reference_tokenizer.train, text, vocab_size)
|
||||
fast_reference_ids, fast_reference_encode_time = time_function(fast_reference_tokenizer.encode_ordinary, encode_text)
|
||||
print(f"Fast reference train time: {fast_reference_train_time:.4f}s")
|
||||
print(f"Fast reference encode time: {fast_reference_encode_time:.4f}s")
|
||||
print(fast_reference_ids[:20])
|
||||
|
||||
# Assert fast equals slow
|
||||
assert fast_reference_ids == slow_reference_ids, "Fast reference should match slow reference"
|
||||
print("✅ Fast == Slow")
|
||||
|
||||
# Train HuggingFace
|
||||
print("\nTraining HuggingFace...")
|
||||
hf_tokenizer, hf_train_time = time_function(HuggingFaceTokenizer.train_from_iterator, [text], vocab_size)
|
||||
hf_ids, hf_encode_time = time_function(hf_tokenizer.encode_ordinary, encode_text)
|
||||
print(f"HuggingFace train time: {hf_train_time:.4f}s")
|
||||
print(f"HuggingFace encode time: {hf_encode_time:.4f}s")
|
||||
print(hf_ids[:20])
|
||||
|
||||
# HuggingFace has a different byte order, so we need custom matching
|
||||
def custom_match(ids1, ids2):
|
||||
perm = {}
|
||||
for x, y in zip(ids1, ids2):
|
||||
if x < 256:
|
||||
if x in perm:
|
||||
if perm[x] != y:
|
||||
return False
|
||||
perm[x] = y
|
||||
if x >= 256 and x != y:
|
||||
return False
|
||||
return True
|
||||
|
||||
assert custom_match(hf_ids, fast_reference_ids), "HuggingFace should match fast reference"
|
||||
print("✅ HuggingFace == Fast")
|
||||
|
||||
# Finally use our own Rust implementation
|
||||
print("\nTraining rustbpe...")
|
||||
rustbpe_tokenizer = rustbpe.Tokenizer()
|
||||
_, rustbpe_train_time = time_function(rustbpe_tokenizer.train_from_iterator, [text], vocab_size)
|
||||
rustbpe_ids, rustbpe_encode_time = time_function(rustbpe_tokenizer.encode, encode_text)
|
||||
print(f"RustBPE train time: {rustbpe_train_time:.4f}s")
|
||||
print(f"RustBPE encode time: {rustbpe_encode_time:.4f}s")
|
||||
print(rustbpe_ids[:20])
|
||||
|
||||
assert rustbpe_ids == fast_reference_ids, "RustBPE should match fast reference"
|
||||
print("✅ RustBPE == Fast")
|
||||
|
||||
# Now export rustbpe to tiktoken for more efficient inference
|
||||
print("\nTesting tiktoken export...")
|
||||
pattern = rustbpe_tokenizer.get_pattern()
|
||||
mergeable_ranks_list = rustbpe_tokenizer.get_mergeable_ranks()
|
||||
mergeable_ranks = {bytes(k): v for k, v in mergeable_ranks_list}
|
||||
enc = tiktoken.Encoding(
|
||||
name="rustbpe",
|
||||
pat_str=pattern,
|
||||
mergeable_ranks=mergeable_ranks,
|
||||
special_tokens={},
|
||||
)
|
||||
tiktoken_ids, tiktoken_encode_time = time_function(enc.encode, encode_text)
|
||||
print(f"Tiktoken encode time: {tiktoken_encode_time:.4f}s")
|
||||
print(tiktoken_ids[:20])
|
||||
|
||||
assert tiktoken_ids == rustbpe_ids, "Tiktoken should match RustBPE"
|
||||
print("✅ Tiktoken == RustBPE")
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_training_performance(enwik8_large):
|
||||
"""Use a bigger dataset and compare the training speed of the optimized tokenizers (Python, Rust, HuggingFace)."""
|
||||
text = enwik8_large
|
||||
vocab_size = 2048
|
||||
print(f"\nText length: {len(text)}")
|
||||
|
||||
# Commenting out because it's just way too slow to matter
|
||||
# Train optimized python version
|
||||
# print("Training optimized python version...")
|
||||
# optimized_python_tokenizer = FastRegexTokenizer()
|
||||
# _, optimized_python_train_time = time_function(optimized_python_tokenizer.train, text, vocab_size)
|
||||
# print(f"Optimized python train time: {optimized_python_train_time:.4f}s")
|
||||
|
||||
# Train rustbpe
|
||||
print("\nTraining rustbpe...")
|
||||
rustbpe_tokenizer = rustbpe.Tokenizer()
|
||||
_, rustbpe_train_time = time_function(rustbpe_tokenizer.train_from_iterator, [text], vocab_size)
|
||||
print(f"RustBPE train time: {rustbpe_train_time:.4f}s")
|
||||
assert rustbpe_train_time > 0, "Training should take some time"
|
||||
|
||||
# Train HuggingFace
|
||||
print("\nTraining HuggingFace...")
|
||||
hf_tokenizer, hf_train_time = time_function(HuggingFaceTokenizer.train_from_iterator, [text], vocab_size)
|
||||
print(f"HuggingFace train time: {hf_train_time:.4f}s")
|
||||
assert hf_train_time > 0, "Training should take some time"
|
||||
|
||||
# Print comparison
|
||||
print(f"\n📊 Performance comparison:")
|
||||
print(f" RustBPE: {rustbpe_train_time:.4f}s")
|
||||
print(f" HuggingFace: {hf_train_time:.4f}s")
|
||||
print(f" Speedup: {hf_train_time/rustbpe_train_time:.2f}x")
|
||||
|
||||
def test_interface(enwik8_small):
|
||||
"""Test the RustBPETokenizer interface for training, encoding, decoding, and serialization."""
|
||||
import tempfile
|
||||
from nanochat.tokenizer import RustBPETokenizer
|
||||
|
||||
# Simple train test
|
||||
vocab_size = 300
|
||||
tok = RustBPETokenizer.train_from_iterator([enwik8_small], vocab_size)
|
||||
assert tok.get_vocab_size() == vocab_size, f"Expected vocab size {vocab_size}, got {tok.get_vocab_size()}"
|
||||
print(f"✅ Trained tokenizer with vocab size {vocab_size}")
|
||||
|
||||
# Encode/decode text
|
||||
encode_text = "Hello world! How are you? 🙃"
|
||||
ids = tok.encode(encode_text)
|
||||
print(f"\nInput text: {encode_text}")
|
||||
print(f"IDs: {ids}")
|
||||
decoded = tok.decode(ids)
|
||||
print(f"Decoded: {decoded}")
|
||||
assert decoded == encode_text, f"Decoded text doesn't match: {decoded} != {encode_text}"
|
||||
print("✅ Encode/decode test passed")
|
||||
|
||||
# Encode batch test
|
||||
ids_new = tok.encode([encode_text, encode_text])
|
||||
assert all(x == ids for x in ids_new), "Batch encoding should produce identical results"
|
||||
print("✅ Encode batch OK")
|
||||
|
||||
# append/prepend functionality
|
||||
ids_special = tok.encode(encode_text, prepend="<|bos|>", append="<|bos|>")
|
||||
bos_token_id = tok.encode_special("<|bos|>")
|
||||
assert ids_special == [bos_token_id] + ids + [bos_token_id], "Special tokens not correctly added"
|
||||
print("✅ append/prepend OK")
|
||||
|
||||
# Save/load test through a temporary directory
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
tok.save(tmp_dir)
|
||||
tok_reloaded = RustBPETokenizer.from_directory(tmp_dir)
|
||||
ids_reloaded = tok_reloaded.encode(encode_text)
|
||||
assert ids_reloaded == ids, "Reloaded tokenizer should produce same results"
|
||||
print("✅ Save/load through temporary directory OK")
|
||||
|
||||
|
||||
def test_batch_encode_correctness(enwik8_small):
|
||||
"""Quick correctness test for batch_encode()"""
|
||||
text = enwik8_small
|
||||
vocab_size = 512
|
||||
|
||||
tokenizer = rustbpe.Tokenizer()
|
||||
tokenizer.train_from_iterator([text], vocab_size)
|
||||
|
||||
# Test with various batch sizes and edge cases
|
||||
test_texts = [
|
||||
"Hello world",
|
||||
"The quick brown fox",
|
||||
"jumps over the lazy dog",
|
||||
"", # empty string
|
||||
"a", # single char
|
||||
]
|
||||
|
||||
# Compare batch vs individual encoding
|
||||
individual = [tokenizer.encode(t) for t in test_texts]
|
||||
batched = tokenizer.batch_encode(test_texts)
|
||||
|
||||
assert individual == batched, "Batch encoding should match individual encoding"
|
||||
print("✅ batch_encode() correctness verified")
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_batch_encode_performance(enwik8_large):
|
||||
"""
|
||||
Benchmark batch_encode() vs sequential encode() loop.
|
||||
Demonstrates parallelization speedup.
|
||||
"""
|
||||
# Setup
|
||||
text = enwik8_large # 10MB dataset
|
||||
vocab_size = 2048
|
||||
|
||||
# Train tokenizer
|
||||
print("\nTraining tokenizer...")
|
||||
tokenizer = rustbpe.Tokenizer()
|
||||
tokenizer.train_from_iterator([text], vocab_size)
|
||||
|
||||
# Create test batch: split text into chunks
|
||||
chunk_size = 50_000 # ~50KB per chunk
|
||||
chunks = [text[i:i+chunk_size] for i in range(0, len(text), chunk_size)]
|
||||
chunks = chunks[:20] # Use first 20 chunks (~1MB total)
|
||||
|
||||
print(f"\nBatch encoding benchmark:")
|
||||
print(f" Number of texts: {len(chunks)}")
|
||||
print(f" Avg text length: {sum(len(c) for c in chunks) / len(chunks):.0f} chars")
|
||||
|
||||
# Benchmark 1: Sequential encoding (baseline)
|
||||
print("\n [1/3] Sequential encode() loop...")
|
||||
sequential_results, sequential_time = time_function(
|
||||
lambda: [tokenizer.encode(chunk) for chunk in chunks]
|
||||
)
|
||||
print(f" Time: {sequential_time:.4f}s")
|
||||
|
||||
# Benchmark 2: Parallel batch_encode()
|
||||
print(" [2/3] Parallel batch_encode()...")
|
||||
batch_results, batch_time = time_function(
|
||||
tokenizer.batch_encode, chunks
|
||||
)
|
||||
print(f" Time: {batch_time:.4f}s")
|
||||
|
||||
# Verify correctness
|
||||
print(" [3/3] Verifying correctness...")
|
||||
assert len(batch_results) == len(sequential_results), "Result count mismatch"
|
||||
for i, (seq, batch) in enumerate(zip(sequential_results, batch_results)):
|
||||
assert seq == batch, f"Mismatch at index {i}"
|
||||
print(" ✓ All results match")
|
||||
|
||||
# Report speedup
|
||||
speedup = sequential_time / batch_time
|
||||
print(f"\n Performance Results:")
|
||||
print(f" Sequential: {sequential_time:.4f}s")
|
||||
print(f" Batch: {batch_time:.4f}s")
|
||||
print(f" Speedup: {speedup:.2f}x")
|
||||
|
||||
# Warn if speedup is low (can vary by machine/load)
|
||||
if speedup < 1.5:
|
||||
warnings.warn(f"batch_encode() speedup was only {speedup:.2f}x (expected >1.5x)")
|
||||