mirror of https://github.com/karpathy/nanochat.git
synced 2026-03-18 10:53:13 +00:00

Merge branch 'master' into a2p2

This commit is contained in commit 65071f688c.

README.md | 26

@ -3,7 +3,7 @@
-nanochat is the simplest experimental harness for training LLMs. It is designed to run on a single GPU node, the code is minimal/hackable, and it covers all major LLM stages including tokenization, pretraining, finetuning, evaluation, inference, and a chat UI. For example, you can train your own GPT-2 capability LLM (which cost ~$43,000 to train in 2019) for only $72 (~3 hours of 8XH100 GPU node) and then talk to it in a familiar ChatGPT-like web UI. On a spot instance, the total cost can be closer to ~$20. More generally, nanochat is configured out of the box to train an entire miniseries of compute-optimal models by setting one single complexity dial: `--depth`, the number of layers in the GPT transformer model (GPT-2 capability happens to be approximately depth 26). All other hyperparameters (the width of the transformer, number of heads, learning rate adjustments, training horizons, weight decays, ...) are calculated automatically in an optimal way.
+nanochat is the simplest experimental harness for training LLMs. It is designed to run on a single GPU node, the code is minimal/hackable, and it covers all major LLM stages including tokenization, pretraining, finetuning, evaluation, inference, and a chat UI. For example, you can train your own GPT-2 capability LLM (which cost ~$43,000 to train in 2019) for only $48 (~2 hours of 8XH100 GPU node) and then talk to it in a familiar ChatGPT-like web UI. On a spot instance, the total cost can be closer to ~$15. More generally, nanochat is configured out of the box to train an entire miniseries of compute-optimal models by setting one single complexity dial: `--depth`, the number of layers in the GPT transformer model (GPT-2 capability happens to be approximately depth 26). All other hyperparameters (the width of the transformer, number of heads, learning rate adjustments, training horizons, weight decays, ...) are calculated automatically in an optimal way.

For questions about the repo, I recommend asking [DeepWiki](https://deepwiki.com/karpathy/nanochat) from Devin/Cognition, using the [Discussions tab](https://github.com/karpathy/nanochat/discussions), or coming by the [#nanochat](https://discord.com/channels/1020383067459821711/1427295580895314031) channel on Discord.
@ -17,8 +17,9 @@ Presently, the main focus of development is on tuning the pretraining stage, whi
| 1 | 3.04 | 0.74833 | 0.2585 | d24 baseline, slightly overtrained | Jan 29 2026 | 348fbb3 | @karpathy |
| 2 | 2.91 | 0.74504 | 0.2578 | d26 slightly undertrained **+fp8** | Feb 2 2026 | a67eba3 | @karpathy |
| 3 | 2.76 | 0.74645 | 0.2602 | bump total batch size to 1M tokens | Feb 5 2026 | 2c062aa | @karpathy |
| 4 | 2.02 | 0.71854 | 0.2571 | change dataset to NVIDIA ClimbMix | Mar 4 2026 | 324e69c | @ddudek @karpathy |
-The primary metric we care about is "time to GPT-2" - the wall clock time needed to outperform the GPT-2 (1.6B) CORE metric on an 8XH100 GPU node. The GPT-2 CORE score is 0.256525. In 2019, the training of GPT-2 cost approximately $43,000 so it is incredible that due to many advances over 7 years across the stack, we can now do so much faster and for well below $100 (e.g. at the current ~$3/GPU/hr, an 8XH100 node is ~$24/hr, so 3 hours is ~$72).
+The primary metric we care about is "time to GPT-2" - the wall clock time needed to outperform the GPT-2 (1.6B) CORE metric on an 8XH100 GPU node. The GPT-2 CORE score is 0.256525. In 2019, the training of GPT-2 cost approximately $43,000 so it is incredible that due to many advances over 7 years across the stack, we can now do so much faster and for well below $100 (e.g. at the current ~$3/GPU/hr, an 8XH100 node is ~$24/hr, so 2 hours is ~$48).
See [dev/LEADERBOARD.md](dev/LEADERBOARD.md) for more docs on how to interpret and contribute to the leaderboard.
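The cost arithmetic above can be checked in a couple of lines (a sketch; the ~$3/GPU/hr figure is simply the rate quoted in the text):

```python
gpus = 8
rate_per_gpu_hr = 3.0                 # ~$3/GPU/hr as quoted above
node_rate = gpus * rate_per_gpu_hr    # $/hr for an 8XH100 node
print(node_rate)                      # 24.0
print(2 * node_rate)                  # 48.0 -> the ~$48 two-hour run
print(3 * node_rate)                  # 72.0 -> the earlier ~$72 three-hour run
```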
@ -81,6 +82,27 @@ The important thing to note is that nanochat is written and configured around on
The script [runs/runcpu.sh](runs/runcpu.sh) shows a very simple example of running on CPU or Apple Silicon. It dramatically shrinks the LLM being trained so that training fits into a reasonable time interval of a few tens of minutes. You will not get strong results this way.
## Precision / dtype
nanochat does not use `torch.amp.autocast`. Instead, precision is managed explicitly through a single global `COMPUTE_DTYPE` (defined in `nanochat/common.py`). By default this is auto-detected based on your hardware:
| Hardware | Default dtype | Why |
|----------|--------------|-----|
| CUDA SM 80+ (A100, H100, ...) | `bfloat16` | Native bf16 tensor cores |
| CUDA SM < 80 (V100, T4, ...) | `float32` | No bf16; fp16 available via `NANOCHAT_DTYPE=float16` (uses GradScaler) |
| CPU / MPS | `float32` | No reduced-precision tensor cores |

You can override the default with the `NANOCHAT_DTYPE` environment variable:

```bash
NANOCHAT_DTYPE=float32 python -m scripts.chat_cli -p "hello" # force fp32
NANOCHAT_DTYPE=bfloat16 torchrun --nproc_per_node=8 -m scripts.base_train # force bf16
```

How it works: model weights are stored in fp32 (for optimizer precision), but our custom `Linear` layer casts them to `COMPUTE_DTYPE` during the forward pass. Embeddings are stored directly in `COMPUTE_DTYPE` to save memory. This gives us the same mixed-precision benefit as autocast but with full explicit control over what runs in which precision.
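A minimal standalone sketch of this mechanism (the class name here is illustrative; the real `Linear` lives in `nanochat/gpt.py`):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class CastLinear(nn.Linear):
    """Sketch of the weight-casting Linear: the master weights stay fp32,
    but the matmul runs in whatever dtype the activations arrive in."""
    def forward(self, x):
        return F.linear(x, self.weight.to(dtype=x.dtype))

layer = CastLinear(8, 8, bias=False)             # parameters are stored in fp32
x = torch.randn(2, 8, dtype=torch.bfloat16)      # activations in the compute dtype
y = layer(x)
print(layer.weight.dtype, y.dtype)               # weights fp32, output bf16
```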
Note: `float16` training automatically enables a `GradScaler` in `base_train.py` to prevent gradient underflow. SFT supports this too but RL currently does not. Inference in fp16 works fine everywhere.
## Guides

I've published a number of guides that might contain helpful information, most recent to least recent:
@ -147,3 +147,47 @@ Minimum validation bpb: 0.74645
```

The big change here is that the batch size was doubled from 0.5M to 1M, which works better for a d26 model and allowed me to decrease the number of optimization steps a bit via `--target-param-data-ratio` from 8.5 to 8.25. The TLDR is that the original batch size of 0.5M was tuned for d12, but bigger models (e.g. d26) prefer larger total batch size. I determined in experiments that d26 prefers 1M. Then I implemented and merged a principled way to calculate the optimal batch size given depth so that all nanochat models of all depths benefit. See [dev/LOG.md](dev/LOG.md) entry "2026-02-05: Auto Batch Size Scaling" for more detail.
## Run 4

Achieved Mar 3 2026 on commit `324e69c`. The big change is the switch from HuggingFace FineWeb-EDU to the NVIDIA ClimbMix dataset. `@karpathy` has tried to swap the dataset many times, each time with a negative result (FineWeb, DCLM, Olmo), but ClimbMix produced clear and immediate gains. Credit to `@ddudek` for originally discovering ClimbMix for nanochat and reporting the improvements, which kicked off the followup investigation.

To reproduce, use the commit above, download at least 150 data shards, and train the tokenizer:
```
python -m nanochat.dataset -n 150
python -m scripts.tok_train
```
Then kick off the run in the typical way, using a slightly lower-than-compute-optimal ratio of 9.5 (vs the compute-optimal 10.5), meaning the d24 is slightly undertrained.
```
OMP_NUM_THREADS=1 torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- \
    --depth=24 \
    --run="d24-climbmix" \
    --model-tag="d24-climbmix" \
    --sample-every=-1 \
    --save-every=-1 \
    --core-metric-max-per-task=-1 \
    --core-metric-every=999999 \
    --target-param-data-ratio=9.5 \
    --device-batch-size=16 \
    --fp8
```
I ran this command 7 individual times. Because our training is mildly non-deterministic, we get a spread of CORE scores, e.g.:
```
0.25373
0.2584
0.25489
0.2568
0.25732
0.26765
0.25119
```
Mean is 0.25714 (higher than the GPT-2 threshold needed), max-min is 0.01646. Something to investigate in the future is that even slightly better results can be obtained by randomly shuffling the data shards (i.e. just going in a different order). This is unexpected because the documents were fully shuffled during data construction, so one would expect a relatively uniform data distribution. Indeed, the current default order is unfortunately among the worse ("unlucky") ones you can obtain with different shuffle seeds, but it suffices to beat GPT-2 for now so I am merging. TODO: investigate a bit more later.
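The quoted mean and spread can be recomputed directly from the scores listed above:

```python
scores = [0.25373, 0.2584, 0.25489, 0.2568, 0.25732, 0.26765, 0.25119]
mean = sum(scores) / len(scores)
spread = max(scores) - min(scores)
print(round(mean, 5))    # 0.25714, above the GPT-2 CORE threshold of 0.256525
print(round(spread, 5))  # 0.01646
```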
NOTE: As of this run, `val_bpb` is *NOT* comparable to the previous 3 runs due to the data distribution change. This run happens to be at `0.71854` validation bpb. If the dataset is not changed, `val_bpb` is a great, smooth metric to track relative performance, with less noise than CORE.
dev/LOG.md | 62
@ -4,6 +4,68 @@ A running summary documenting some experiments and findings. Started ~Jan 7 2026
---

## 2026-03-04: Remove autocast, explicit dtype management, fp16 GradScaler

Replaced `torch.amp.autocast` throughout the codebase with explicit dtype management via a single `COMPUTE_DTYPE` global. Also added fp16 training support with GradScaler.

### Motivation

autocast is "magic we don't control" — it silently decides which ops run in which precision via internal allowlists. For this codebase, autocast was doing very little: the only thing it actually cast was `nn.Linear` weights from fp32 to bf16 for matmuls. `F.rms_norm`, `F.cross_entropy`, and Flash Attention all handle their own dtypes already. By making precision explicit, we gain fine-grained control (e.g. can experiment with fp32 norms) and eliminate an unnecessary layer of abstraction.

### What changed

**Core mechanism** (`nanochat/common.py`, `nanochat/gpt.py`):
- `COMPUTE_DTYPE` auto-detected from hardware: SM 80+ → bf16, pre-Ampere → fp32, CPU/MPS → fp32. Override via `NANOCHAT_DTYPE` env var.
- Custom `Linear(nn.Linear)` class that casts weights to match input dtype in forward: `F.linear(x, self.weight.to(dtype=x.dtype))`. This is the single mechanism that replaces autocast.
- Embeddings cast to `COMPUTE_DTYPE` at init (saves memory). Exception: fp16 keeps embeddings fp32 because GradScaler cannot unscale fp16 gradients.
- Embedding output explicitly cast to `COMPUTE_DTYPE` in `GPT.forward()` (no-op for bf16, active for fp16 path).
- RoPE cos/sin cache uses `COMPUTE_DTYPE` instead of hardcoded bf16.

**Autocast removal** (11 files):
- Deleted `--dtype` CLI flag, `ptdtype` variables, `autocast_ctx` definitions, and all `with autocast_ctx:` blocks from: `base_train.py`, `chat_sft.py`, `chat_rl.py`, `chat_cli.py`, `chat_eval.py`, `chat_web.py`, `base_eval.py`, `engine.py`, `bench_train_toks.py`, `test_e2e_pipeline.py`.

**fp16 + GradScaler** (`base_train.py`, `chat_sft.py`):
- `scaler = torch.amp.GradScaler() if COMPUTE_DTYPE == torch.float16 else None`
- Backward: `scaler.scale(loss).backward()` vs plain `loss.backward()`
- After accumulation: `scaler.unscale_(optimizer)` → distributed inf-sync via `scaler._found_inf_per_device(optimizer)` all-reduced with `ReduceOp.MAX` → `scaler.step(optimizer)` → `scaler.update()`
- Zero overhead for bf16/fp32 paths (scaler is None, no branching inside kernels).
**FP8 fix** (`nanochat/fp8.py`, `base_train.py`):
- `Float8Linear.forward` explicitly casts input to `COMPUTE_DTYPE` (previously relied on autocast).
- `disable_fp8` context manager now creates our custom `Linear` (not vanilla `nn.Linear`) when swapping out Float8Linear during eval.

**Flash Attention** (`flash_attention.py`):
- FA3 Hopper kernels don't support fp16 or fp32, so `USE_FA3` (a module-level constant, resolved once at import) is False in those cases, falling back to SDPA.
---
## 2026-03-04: Dataset upgrade: FineWeb-EDU 100B → ClimbMix 400B

Switched the pretraining dataset from FineWeb-EDU 100B to ClimbMix 400B. This is by far the single biggest improvement to nanochat's GPT-2 speedrun time, bringing it down from **2 hours 46 minutes to 2 hours 1 minute** — a 27% reduction.

### What is ClimbMix?

ClimbMix 400B is a curated 400B-token pretraining mixture hosted at `karpathy/climbmix-400b-shuffle` on HuggingFace. It comes from [NVIDIA](https://huggingface.co/datasets/nvidia/Nemotron-ClimbMix). It is a blend of high-quality web text, code, math, and other sources, designed to be a better general-purpose pretraining dataset than FineWeb-EDU alone.

### What changed

- **Dataset**: `karpathy/fineweb-edu-100b-shuffle` → `karpathy/climbmix-400b-shuffle` (up to 6543 shards available vs the previous 1823 data shards, allowing for longer training in the future)
- **Data directory**: `base_data/` → `base_data_climbmix/` (clean separation from legacy data)
- **Model depth**: d26 → d24. ClimbMix trains more efficiently, so a smaller model reaches GPT-2 capability
- **Shard count**: only ~150 data shards (~7B tokens) are now needed for GPT-2 capability
- **Eval batches**: doubled from 40 to 80 for more stable validation loss estimates
- **Legacy fallback**: added a migration warning in `list_parquet_files()` that detects the old `base_data/` directory and falls back gracefully, so existing users see clear upgrade instructions on `git pull`

### Context

This is the sixth attempt at beating FineWeb-EDU on CORE score — the previous five all failed (see entries on 2026-02-17, 2026-02-10, 2026-01-12 below). ClimbMix is the first dataset to convincingly surpass it, and the margin is large enough to also shrink the model from d26 to d24.
|
||||
|
||||
---
## 2026-03-02: SoftCap tuning
Quick experiment to tune logit softcap on d24 scale. Tried 5..30. 5 was terrible, the rest of them were all about equal with the exception of 20, which was the best. Minor but solid improvement: val loss improved by ~1e-3 (0.716 -> 0.715). Setting as default.
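For reference, logit softcapping is commonly implemented as `cap * tanh(logits / cap)`; a sketch assuming that standard formulation, with the tuned cap of 20 (names here are illustrative, not nanochat's actual code):

```python
import torch

def softcap(logits, cap=20.0):
    # Smoothly bounds logits to (-cap, cap); approximately the identity near zero
    return cap * torch.tanh(logits / cap)

x = torch.tensor([-100.0, -1.0, 0.0, 1.0, 100.0])
y = softcap(x)
print(y.abs().max() < 20.0)  # True: outputs are strictly bounded by the cap
```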
## 2026-02-19: Mixture of Experts (negative)

Implemented a DeepSeekV3-style Mixture of Experts layer as a drop-in replacement for the dense MLP. The MoE branch works and improves per-step validation loss, but is not a net improvement on wall clock time due to MoE overhead (at least for our scale of interest of approx GPT-2 capability).
@ -1,5 +1,5 @@
"""
-Repackage the FinewebEdu-100B dataset into shards:
+Repackage a given dataset into simple parquet shards:

- each shard is ~100MB in size (after zstd compression)
- parquets are written with row group size of 1000
@ -10,6 +10,16 @@ The big deal is that our DataLoader will be able to stream
the data and cache it along the way on disk, decreasing the
training latency.

Historical context:
Originally, nanochat used the FinewebEdu-100B dataset.
Then we switched to the ClimbMix-400B dataset due to superior performance.
This script documents how both were prepared.

The outputs are here:

https://huggingface.co/datasets/karpathy/fineweb-edu-100b-shuffle
https://huggingface.co/datasets/karpathy/climbmix-400b-shuffle

NOTE: This file is meant only as reference/documentation of the
dataset preparation and it is not used during the project runtime.
"""
@ -20,12 +30,37 @@ from datasets import load_dataset
import pyarrow.parquet as pq
import pyarrow as pa

# You can change these:
dataset_tag = "climbmix"
upload_to_hf = True

# Dataset configurations:
if dataset_tag == "fineweb_edu":
    dataset_kwargs = {
        "path": "HuggingFaceFW/fineweb-edu",
        "split": "train",
        "name": "sample-100BT", # ~100B GPT-2 tokens at ~3 chars/token => ~300B chars total
    }
    output_dirname = "fineweb_edu"
    data_column_name = "text"
    tokenizer = None
    upload_tag = "fineweb-edu-100b-shuffle"
elif dataset_tag == "climbmix":
    import tiktoken # the ClimbMix data is stored tokenized with GPT-2 tokenizer
    dataset_kwargs = {
        "path": "nvidia/Nemotron-ClimbMix",
        "split": "train",
    }
    output_dirname = "climbmix"
    data_column_name = "tokens"
    tokenizer = tiktoken.encoding_for_model("gpt-2")
    upload_tag = "climbmix-400b-shuffle"
else:
    raise ValueError(f"Unknown dataset tag: {dataset_tag}")

-# Source dataset
-dataset_kwargs = {
-    "path": "HuggingFaceFW/fineweb-edu",
-    "split": "train",
-    "name": "sample-100BT", # ~100B GPT-2 tokens at ~3 chars/token => ~300B chars total
-}
ds = load_dataset(**dataset_kwargs)

# Shuffle to scramble the order
@ -34,7 +69,7 @@ ndocs = len(ds) # total number of documents to process
print(f"Total number of documents: {ndocs}")

# Repackage into parquet files
-output_dir = "/home/ubuntu/.cache/nanochat/base_data"
+output_dir = f"/home/ubuntu/.cache/nanochat/base_data_{output_dirname}"
os.makedirs(output_dir, exist_ok=True)

# Write to parquet files
@ -47,7 +82,8 @@ total_docs_processed = 0
total_time_spent = 0
t0 = time.time()
for doc in ds:
-    text = doc['text']
+    data = doc[data_column_name]
+    text = tokenizer.decode(data) if tokenizer is not None else data
    shard_docs.append(text)
    shard_characters += len(text)
    collected_enough_chars = shard_characters >= chars_per_shard
@ -79,14 +115,12 @@ for doc in ds:
        shard_index += 1

# Demonstration of how the data was later uploaded to HuggingFace
-def upload():
-    import os
+if upload_to_hf:
    from huggingface_hub import HfApi
    token = os.getenv("HF_TOKEN")
    api = HfApi(token=token)
    api.upload_large_folder(
        folder_path=output_dir,
-        repo_id="karpathy/fineweb-edu-100b-shuffle",
+        repo_id=f"karpathy/{upload_tag}",
        repo_type="dataset",
    )
-# upload()
@ -10,6 +10,26 @@ import torch
import torch.distributed as dist
from filelock import FileLock

# The dtype used for compute (matmuls, activations). Master weights stay fp32 for optimizer precision.
# Linear layers cast their weights to this dtype in forward, replacing torch.amp.autocast.
# Override with NANOCHAT_DTYPE env var: "bfloat16", "float16", "float32"
_DTYPE_MAP = {"bfloat16": torch.bfloat16, "float16": torch.float16, "float32": torch.float32}
def _detect_compute_dtype():
    env = os.environ.get("NANOCHAT_DTYPE")
    if env is not None:
        return _DTYPE_MAP[env], f"set via NANOCHAT_DTYPE={env}"
    if torch.cuda.is_available():
        # bf16 requires SM 80+ (Ampere: A100, A10, etc.)
        # Older GPUs like V100 (SM 70) and T4 (SM 75) only have fp16 tensor cores
        capability = torch.cuda.get_device_capability()
        if capability >= (8, 0):
            return torch.bfloat16, f"auto-detected: CUDA SM {capability[0]}{capability[1]} (bf16 supported)"
        # fp16 training requires a GradScaler, so we default to fp32 on pre-Ampere GPUs.
        # Users can still force fp16 via NANOCHAT_DTYPE=float16 if they know what they're doing.
        return torch.float32, f"auto-detected: CUDA SM {capability[0]}{capability[1]} (pre-Ampere, bf16 not supported, using fp32)"
    return torch.float32, "auto-detected: no CUDA (CPU/MPS)"
COMPUTE_DTYPE, COMPUTE_DTYPE_REASON = _detect_compute_dtype()

class ColoredFormatter(logging.Formatter):
    """Custom formatter that adds colors to log messages."""
    # ANSI color codes
@ -32,7 +32,8 @@ def _document_batches(split, resume_state_dict, tokenizer_batch_size):
    """
    ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()

-    parquet_paths = list_parquet_files()
+    warn_on_legacy = ddp_rank == 0 and split == "train" # rank 0 on train split will warn on legacy
+    parquet_paths = list_parquet_files(warn_on_legacy=warn_on_legacy)
    assert len(parquet_paths) != 0, "No dataset parquet files found, did you run dataset.py?"
    parquet_paths = parquet_paths[:-1] if split == "train" else parquet_paths[-1:]
@ -20,19 +20,43 @@ from nanochat.common import get_base_dir
# The specifics of the current pretraining dataset

# The URL on the internet where the data is hosted and downloaded from on demand
-BASE_URL = "https://huggingface.co/datasets/karpathy/fineweb-edu-100b-shuffle/resolve/main"
-MAX_SHARD = 1822 # the last datashard is shard_01822.parquet
+BASE_URL = "https://huggingface.co/datasets/karpathy/climbmix-400b-shuffle/resolve/main"
+MAX_SHARD = 6542 # the last datashard is shard_06542.parquet
index_to_filename = lambda index: f"shard_{index:05d}.parquet" # format of the filenames
base_dir = get_base_dir()
-DATA_DIR = os.path.join(base_dir, "base_data")
-os.makedirs(DATA_DIR, exist_ok=True)
+DATA_DIR = os.path.join(base_dir, "base_data_climbmix")

# -----------------------------------------------------------------------------
# These functions are useful utilities to other modules, can/should be imported

-def list_parquet_files(data_dir=None):
+def list_parquet_files(data_dir=None, warn_on_legacy=False):
    """ Looks into a data dir and returns full paths to all parquet files. """
    data_dir = DATA_DIR if data_dir is None else data_dir

    # Legacy-supporting code due to the upgrade from FinewebEdu-100B to ClimbMix-400B
    # This code will eventually be deleted.
    if not os.path.exists(data_dir):
        if warn_on_legacy:
            print()
            print("=" * 80)
            print(" WARNING: DATASET UPGRADE REQUIRED")
            print("=" * 80)
            print()
            print(f" Could not find: {data_dir}")
            print()
            print(" nanochat recently switched from FinewebEdu-100B to ClimbMix-400B.")
            print(" Everyone who does `git pull` as of March 4, 2026 is expected to see this message.")
            print(" To upgrade to the new ClimbMix-400B dataset, run these two commands:")
            print()
            print("    python -m nanochat.dataset -n 170 # download ~170 shards, enough for GPT-2, adjust as desired")
            print("    python -m scripts.tok_train # re-train tokenizer on new ClimbMix data")
            print()
            print(" For now, falling back to your old FinewebEdu-100B dataset...")
            print("=" * 80)
            print()
        # attempt a fallback to the legacy data directory
        data_dir = os.path.join(base_dir, "base_data")

    parquet_files = sorted([
        f for f in os.listdir(data_dir)
        if f.endswith('.parquet') and not f.endswith('.tmp')
@ -110,13 +134,21 @@ def download_single_file(index):


if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Download FineWeb-Edu 100BT dataset shards")
-    parser.add_argument("-n", "--num-files", type=int, default=-1, help="Number of shards to download (default: -1), -1 = disable")
+    parser = argparse.ArgumentParser(description="Download pretraining dataset shards")
+    parser.add_argument("-n", "--num-files", type=int, default=-1, help="Number of train shards to download (default: -1), -1 = disable")
    parser.add_argument("-w", "--num-workers", type=int, default=4, help="Number of parallel download workers (default: 4)")
    args = parser.parse_args()

-    num = MAX_SHARD + 1 if args.num_files == -1 else min(args.num_files, MAX_SHARD + 1)
-    ids_to_download = list(range(num))
+    # Prepare the output directory
+    os.makedirs(DATA_DIR, exist_ok=True)
+
+    # The way this works is that the user specifies the number of train shards to download via the -n flag.
+    # In addition to that, the validation shard is *always* downloaded and is pinned to be the last shard.
+    num_train_shards = MAX_SHARD if args.num_files == -1 else min(args.num_files, MAX_SHARD)
+    ids_to_download = list(range(num_train_shards))
+    ids_to_download.append(MAX_SHARD) # always download the validation shard

    # Download the shards
    print(f"Downloading {len(ids_to_download)} shards using {args.num_workers} workers...")
    print(f"Target directory: {DATA_DIR}")
    print()
@ -19,7 +19,6 @@ from contextlib import contextmanager
from collections import deque
from nanochat.common import compute_init, autodetect_device_type
from nanochat.checkpoint_manager import load_model
-from contextlib import nullcontext

# -----------------------------------------------------------------------------
# Calculator tool helpers
@ -308,8 +307,6 @@ if __name__ == "__main__":
    # init compute
    device_type = autodetect_device_type()
    ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
-    autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()

    # load the model and tokenizer
    model, tokenizer, meta = load_model("base", device, phase="eval")
    bos_token_id = tokenizer.get_bos_token_id()
@ -322,11 +319,10 @@ if __name__ == "__main__":
    torch.cuda.synchronize()
    t0 = time.time()
    stream = model.generate(prompt_tokens, **kwargs)
-    with autocast_ctx:
-        for token in stream:
-            generated_tokens.append(token)
-            chunk = tokenizer.decode([token])
-            print(chunk, end="", flush=True)
+    for token in stream:
+        generated_tokens.append(token)
+        chunk = tokenizer.decode([token])
+        print(chunk, end="", flush=True)
    print()
    torch.cuda.synchronize()
    t1 = time.time()
@ -338,12 +334,11 @@ if __name__ == "__main__":
    stream = engine.generate(prompt_tokens, num_samples=1, **kwargs) # note: runs in fp32
    torch.cuda.synchronize()
    t0 = time.time()
-    with autocast_ctx:
-        for token_column, token_masks in stream:
-            token = token_column[0] # only print out the first row
-            generated_tokens.append(token)
-            chunk = tokenizer.decode([token])
-            print(chunk, end="", flush=True)
+    for token_column, token_masks in stream:
+        token = token_column[0] # only print out the first row
+        generated_tokens.append(token)
+        chunk = tokenizer.decode([token])
+        print(chunk, end="", flush=True)
    print()
    torch.cuda.synchronize()
    t1 = time.time()
@ -45,14 +45,22 @@ HAS_FA3 = _fa3 is not None
_override_impl = None


-def _use_fa3():
-    """Determine whether to use FA3 based on availability and override."""
+def _resolve_use_fa3():
+    """Decide once whether to use FA3, based on availability, override, and dtype."""
    if _override_impl == 'fa3':
        assert HAS_FA3, "Cannot override to FA3: not available on this hardware"
        return True
    if _override_impl == 'sdpa':
        return False
-    return HAS_FA3 # auto
+    if HAS_FA3:
+        # FA3 Hopper kernels only support bf16 and fp8; fp16/fp32 must use SDPA fallback
+        from nanochat.common import COMPUTE_DTYPE
+        if COMPUTE_DTYPE == torch.bfloat16:
+            return True
+        return False
+    return False
+
+USE_FA3 = _resolve_use_fa3()


# =============================================================================
@ -90,7 +98,7 @@ def _sdpa_attention(q, k, v, window_size, enable_gqa):
    # sliding window (left)
    if window >= 0 and window < Tk:
        mask = mask & ((row_idx - col_idx) <= window)

    return F.scaled_dot_product_attention(q, k, v, attn_mask=mask, enable_gqa=enable_gqa)

# =============================================================================
@ -108,7 +116,7 @@ def flash_attn_func(q, k, v, causal=False, window_size=(-1, -1)):
    Returns:
        Output tensor of shape (B, T, H, D)
    """
-    if _use_fa3():
+    if USE_FA3:
        return _fa3.flash_attn_func(q, k, v, causal=causal, window_size=window_size)

    # SDPA fallback: transpose (B, T, H, D) -> (B, H, T, D)
@ -138,7 +146,7 @@ def flash_attn_with_kvcache(q, k_cache, v_cache, k=None, v=None, cache_seqlens=N
    Returns:
        Output tensor of shape (B, T_new, H, D)
    """
-    if _use_fa3():
+    if USE_FA3:
        return _fa3.flash_attn_with_kvcache(
            q, k_cache, v_cache, k=k, v=v, cache_seqlens=cache_seqlens,
            causal=causal, window_size=window_size
@ -72,6 +72,8 @@ generates a different graph. Numerics are bitwise identical in eager mode.
import torch
import torch.nn as nn

+from nanochat.common import COMPUTE_DTYPE

# Avoid division by zero when computing scale from an all-zeros tensor
EPS = 1e-12
@ -123,7 +125,7 @@ def _to_col_major(x):
class _Float8Matmul(torch.autograd.Function):
    """Custom autograd for the three FP8 GEMMs of a Linear layer.

    The forward quantizes input and weight to FP8 and saves
    the quantized tensors + scales for backward.
    """
@ -198,11 +200,9 @@ class Float8Linear(nn.Linear):
    """

    def forward(self, input):
-        # Replicate the autocast behavior of F.linear — when autocast is active,
-        # we need to manually cast input to the autocast dtype (e.g. bf16),
-        # since we bypass F.linear's built-in autocast handling.
-        if torch.is_autocast_enabled():
-            input = input.to(torch.get_autocast_gpu_dtype())
+        # Cast input to COMPUTE_DTYPE (typically bf16) since _scaled_mm expects
+        # reduced precision input, and we no longer rely on autocast to do this.
+        input = input.to(COMPUTE_DTYPE)
        # _scaled_mm only works on 2D tensors, so flatten batch dimensions
        orig_shape = input.shape
        input_2d = input.reshape(-1, orig_shape[-1])
@ -19,7 +19,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F

-from nanochat.common import get_dist_info, print0
+from nanochat.common import get_dist_info, print0, COMPUTE_DTYPE
from nanochat.optim import MuonAdamW, DistMuonAdamW

# Our custom Flash Attention module that automatically uses FA3 on Hopper+ and SDPA fallback elsewhere
@ -46,8 +46,14 @@ class GPTConfig:


def norm(x):
    # Purely functional rmsnorm with no learnable params
-    return F.rms_norm(x, (x.size(-1),))
+    return F.rms_norm(x, (x.size(-1),)) # note that this will run in bf16, seems ok

+class Linear(nn.Linear):
+    """nn.Linear that casts weights to match input dtype in forward.
+    Replaces autocast: master weights stay fp32 for optimizer precision,
+    but matmuls run in the activation dtype (typically bf16 from embeddings)."""
+    def forward(self, x):
+        return F.linear(x, self.weight.to(dtype=x.dtype))


def has_ve(layer_idx, n_layer):
@ -72,12 +78,12 @@ class CausalSelfAttention(nn.Module):
|
|||
self.head_dim = self.n_embd // self.n_head
|
||||
assert self.n_embd % self.n_head == 0
|
||||
assert self.n_kv_head <= self.n_head and self.n_head % self.n_kv_head == 0
|
||||
self.c_q = nn.Linear(self.n_embd, self.n_head * self.head_dim, bias=False)
|
||||
self.c_k = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
|
||||
self.c_v = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
|
||||
self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)
|
||||
self.c_q = Linear(self.n_embd, self.n_head * self.head_dim, bias=False)
|
||||
self.c_k = Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
|
||||
self.c_v = Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
|
||||
self.c_proj = Linear(self.n_embd, self.n_embd, bias=False)
|
||||
self.ve_gate_channels = 32
|
||||
self.ve_gate = nn.Linear(self.ve_gate_channels, self.n_kv_head, bias=False) if has_ve(layer_idx, config.n_layer) else None
|
||||
self.ve_gate = Linear(self.ve_gate_channels, self.n_kv_head, bias=False) if has_ve(layer_idx, config.n_layer) else None
|
||||
|
||||
def forward(self, x, ve, cos_sin, window_size, kv_cache):
|
||||
B, T, C = x.size()
|
||||
|
|
@ -139,7 +145,6 @@ class MLP(nn.Module):
|
|||
else: # relu2
|
||||
self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=False)
|
||||
self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=False)
|
||||
|
||||
def forward(self, x):
|
||||
if self.mlp_type == "swiglu":
|
||||
return self.c_proj(F.silu(self.c_gate(x)) * self.c_up(x))
|
||||
|
|
@ -180,7 +185,7 @@ class GPT(nn.Module):
|
|||
"wte": nn.Embedding(padded_vocab_size, config.n_embd),
|
||||
"h": nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)]),
|
||||
})
|
||||
self.lm_head = nn.Linear(config.n_embd, padded_vocab_size, bias=False)
|
||||
self.lm_head = Linear(config.n_embd, padded_vocab_size, bias=False)
|
||||
# Per-layer learnable scalars (inspired by modded-nanogpt)
|
||||
# resid_lambdas: scales the residual stream at each layer (init 1.0 = neutral)
|
||||
# x0_lambdas: blends initial embedding back in at each layer (init 0.0 = disabled)
|
||||
|
|
@ -267,11 +272,13 @@ class GPT(nn.Module):
|
|||
cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
|
||||
self.cos, self.sin = cos, sin
|
||||
|
||||
# Cast embeddings to bf16: optimizer can tolerate it and it saves memory
|
||||
if self.transformer.wte.weight.device.type == "cuda":
|
||||
self.transformer.wte.to(dtype=torch.bfloat16)
|
||||
# Cast embeddings to COMPUTE_DTYPE: optimizer can tolerate reduced-precision
|
||||
# embeddings and it saves memory. Exception: fp16 requires fp32 embeddings
|
||||
# because GradScaler cannot unscale fp16 gradients.
|
||||
if COMPUTE_DTYPE != torch.float16:
|
||||
self.transformer.wte.to(dtype=COMPUTE_DTYPE)
|
||||
for ve in self.value_embeds.values():
|
||||
ve.to(dtype=torch.bfloat16)
|
||||
ve.to(dtype=COMPUTE_DTYPE)
|
||||
|
||||
def _precompute_rotary_embeddings(self, seq_len, head_dim, base=None, device=None):
|
||||
if base is None:
|
||||
|
|
@ -287,7 +294,7 @@ class GPT(nn.Module):
|
|||
# calculate the rotation frequencies at each (time, channel) pair
|
||||
freqs = torch.outer(t, inv_freq)
|
||||
cos, sin = freqs.cos(), freqs.sin()
|
||||
cos, sin = cos.bfloat16(), sin.bfloat16() # keep them in bfloat16
|
||||
cos, sin = cos.to(COMPUTE_DTYPE), sin.to(COMPUTE_DTYPE)
|
||||
cos, sin = cos[None, :, None, :], sin[None, :, None, :] # add batch and head dims for later broadcasting
|
||||
return cos, sin
|
||||
|
||||
|
|
@ -428,24 +435,25 @@ class GPT(nn.Module):
|
|||
# Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim/2))
|
||||
assert T <= self.cos.size(1), f"Sequence length grew beyond the rotary embeddings cache: {T} > {self.cos.size(1)}"
|
||||
assert idx.device == self.cos.device, f"Rotary embeddings and idx are on different devices: {idx.device} != {self.cos.device}"
|
||||
assert self.cos.dtype == torch.bfloat16, "Rotary embeddings must be in bfloat16"
|
||||
assert self.cos.dtype == COMPUTE_DTYPE, f"Rotary embeddings must be in {COMPUTE_DTYPE}, got {self.cos.dtype}"
|
||||
# if kv cache exists, we need to offset the rotary embeddings to the current position in the cache
|
||||
T0 = 0 if kv_cache is None else kv_cache.get_pos()
|
||||
cos_sin = self.cos[:, T0:T0+T], self.sin[:, T0:T0+T] # truncate cache to current sequence length
|
||||
|
||||
# Forward the trunk of the Transformer
|
||||
x = self.transformer.wte(idx) # embed current token
|
||||
x = x.to(COMPUTE_DTYPE) # ensure activations are in compute dtype (no-op usually, but active for fp16 code path)
|
||||
x = norm(x)
|
||||
x0 = x # save initial normalized embedding for x0 residual
|
||||
for i, block in enumerate(self.transformer.h):
|
||||
x = self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0
|
||||
ve = self.value_embeds[str(i)](idx) if str(i) in self.value_embeds else None
|
||||
ve = self.value_embeds[str(i)](idx).to(x.dtype) if str(i) in self.value_embeds else None
|
||||
x = block(x, ve, cos_sin, self.window_sizes[i], kv_cache)
|
||||
x_hidden = x # pre-norm hidden states, used by MTP heads
|
||||
x = norm(x)
|
||||
|
||||
# Forward the lm_head (compute logits)
|
||||
softcap = 15 # smoothly cap the logits to the range [-softcap, softcap]
|
||||
softcap = 20 # smoothly cap the logits to the range [-softcap, softcap]
|
||||
logits = self.lm_head(x) # (B, T, padded_vocab_size) <- very big tensor, large amount of memory
|
||||
logits = logits[..., :self.config.vocab_size] # slice to remove padding
|
||||
logits = logits.float() # switch to fp32 for logit softcap and loss computation
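The `softcap` constant above smoothly caps logits into `[-softcap, softcap]`. The standard way to do this is a tanh softcap; the exact expression is not shown in this hunk, so treat the following as an illustrative sketch rather than the repo's implementation:

```python
import math

def softcap_logits(logits, softcap=20.0):
    """Smoothly squash logits into (-softcap, softcap) with a tanh softcap.
    Near zero this is approximately the identity; large magnitudes saturate."""
    return [softcap * math.tanh(z / softcap) for z in logits]
```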
@@ -55,9 +55,9 @@ python -m nanochat.report reset
 # look at dev/repackage_data_reference.py for details on how this data was prepared
 python -m nanochat.dataset -n 8
 # Immediately also kick off downloading more shards in the background while tokenizer trains
-# Approximately 350 shards are needed for 10B tokens of data for pretraining.
-# The maximum total number of shards available in the entire dataset is 1822.
-python -m nanochat.dataset -n 370 &
+# Approximately 150 shards are needed for GPT-2 capability pretraining, add 20 for padding.
+# The maximum total number of shards available in the entire dataset is 6542.
+python -m nanochat.dataset -n 170 &
 DATASET_DOWNLOAD_PID=$!
 # train the tokenizer with vocab size 2**15 = 32768 on ~2B characters of data
 python -m scripts.tok_train

@@ -69,8 +69,8 @@ python -m scripts.tok_eval
 echo "Waiting for dataset download to complete..."
 wait $DATASET_DOWNLOAD_PID

-# d26 model (slightly undertrained to beat GPT-2 => decrease data:params ratio from compute optimal 10.5 (default) to 8.25)
-torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- --depth=26 --target-param-data-ratio=8.25 --device-batch-size=16 --fp8 --run=$WANDB_RUN
+# d24 model (slightly undertrained to beat GPT-2 => decrease data:params ratio from compute optimal 10.5 (default) to 9.5)
+torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- --depth=24 --target-param-data-ratio=9.5 --device-batch-size=16 --fp8 --run=$WANDB_RUN
 # evaluate the model: CORE metric, BPB on train/val, and draw samples
 torchrun --standalone --nproc_per_node=8 -m scripts.base_eval -- --device-batch-size=16
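The `--target-param-data-ratio` flag above sets the data:params ratio, i.e. the training token budget is the ratio times the parameter count. A quick sketch of that arithmetic (the 1.1B parameter count below is a made-up illustration, not the actual d24 size):

```python
def target_tokens(n_params, data_param_ratio):
    """Token budget implied by --target-param-data-ratio: tokens = ratio * params."""
    return int(data_param_ratio * n_params)

# Hypothetical example model size
n_params = 1_100_000_000
print(target_tokens(n_params, 10.5))  # compute-optimal default
print(target_tokens(n_params, 9.5))   # slightly undertrained, as in the d24 run above
```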
@@ -29,8 +29,6 @@ import random
 import zipfile
 import tempfile
 import argparse
-from contextlib import nullcontext

 import torch

 from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, autodetect_device_type, download_file_with_lock

@@ -199,8 +197,6 @@ def main():
     # Distributed / precision setup
     device_type = autodetect_device_type() if args.device_type == '' else args.device_type
     ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
-    autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()

     # Load model and tokenizer
     is_hf_model = args.hf_path is not None
     if is_hf_model:

@@ -244,8 +240,7 @@ def main():
     print0("\nConditioned samples:")
     for prompt in prompts:
         tokens = tokenizer(prompt, prepend="<|bos|>")
-        with autocast_ctx:
-            sample, _ = engine.generate_batch(tokens, num_samples=1, max_tokens=16, temperature=0)
+        sample, _ = engine.generate_batch(tokens, num_samples=1, max_tokens=16, temperature=0)
         sample_str = tokenizer.decode(sample[0])
         print0("-" * 80)
         print0(sample_str)

@@ -253,8 +248,7 @@ def main():

     print0("\nUnconditioned samples:")
     tokens = tokenizer("", prepend="<|bos|>")
-    with autocast_ctx:
-        uncond, _ = engine.generate_batch(tokens, num_samples=8, max_tokens=128, temperature=1.0)
+    uncond, _ = engine.generate_batch(tokens, num_samples=8, max_tokens=128, temperature=1.0)
     for sample in uncond:
         sample_str = tokenizer.decode(sample)
         print0("-" * 80)

@@ -277,8 +271,7 @@ def main():

     for split_name in ["train", "val"]:
         loader = tokenizing_distributed_data_loader_bos_bestfit(tokenizer, args.device_batch_size, sequence_len, split_name, device=device)
-        with autocast_ctx:
-            bpb = evaluate_bpb(model, loader, steps, token_bytes)
+        bpb = evaluate_bpb(model, loader, steps, token_bytes)
         bpb_results[split_name] = bpb
         print0(f"{split_name} bpb: {bpb:.6f}")

@@ -287,8 +280,7 @@ def main():
     print0("\n" + "="*80)
     print0("CORE Evaluation")
     print0("="*80)
-    with autocast_ctx:
-        core_results = evaluate_core(model, tokenizer, device, max_per_task=args.max_per_task)
+    core_results = evaluate_core(model, tokenizer, device, max_per_task=args.max_per_task)

     # Write CSV output
     if ddp_rank == 0:
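The `evaluate_bpb` calls above report bits per byte, a tokenizer-independent loss metric. Under the usual definition, a summed next-token negative log-likelihood in nats is converted by dividing by ln(2) and by the byte count of the evaluated text; a minimal sketch of that conversion (illustrative, not the repo's `evaluate_bpb` code):

```python
import math

def bits_per_byte(total_nll_nats, total_bytes):
    """bpb = NLL / ln(2) / bytes: nats -> bits, then normalize per byte of text."""
    return total_nll_nats / math.log(2) / total_bytes

# e.g. 1000 tokens at 2.0 nats each, over 4000 bytes of raw text
print(bits_per_byte(1000 * 2.0, 4000))
```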
@@ -19,14 +19,15 @@ import time
 import math
 import argparse
 from dataclasses import asdict
-from contextlib import nullcontext, contextmanager
+from contextlib import contextmanager

 import wandb
 import torch
 import torch.distributed as dist

-from nanochat.gpt import GPT, GPTConfig
+from nanochat.gpt import GPT, GPTConfig, Linear
 from nanochat.dataloader import tokenizing_distributed_data_loader_bos_bestfit, tokenizing_distributed_data_loader_with_state_bos_bestfit
-from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir, autodetect_device_type, get_peak_flops
+from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir, autodetect_device_type, get_peak_flops, COMPUTE_DTYPE, COMPUTE_DTYPE_REASON, is_ddp_initialized
 from nanochat.tokenizer import get_tokenizer, get_token_bytes
 from nanochat.checkpoint_manager import save_checkpoint, load_checkpoint
 from nanochat.loss_eval import evaluate_bpb

@@ -75,7 +76,7 @@ parser.add_argument("--final-lr-frac", type=float, default=0.0, help="final LR a
 parser.add_argument("--resume-from-step", type=int, default=-1, help="resume training from this step (-1 = disable)")
 # Evaluation
 parser.add_argument("--eval-every", type=int, default=250, help="evaluate val bpb every N steps (-1 = disable)")
-parser.add_argument("--eval-tokens", type=int, default=40*524288, help="number of tokens to evaluate val loss on")
+parser.add_argument("--eval-tokens", type=int, default=80*524288, help="number of tokens to evaluate val loss on")
 parser.add_argument("--core-metric-every", type=int, default=2000, help="evaluate CORE metric every N steps (-1 = disable)")
 parser.add_argument("--core-metric-max-per-task", type=int, default=500, help="examples per task for CORE metric")
 parser.add_argument("--sample-every", type=int, default=2000, help="sample from model every N steps (-1 = disable)")
@@ -90,7 +91,6 @@ user_config = vars(args).copy() # for logging
 device_type = autodetect_device_type() if args.device_type == "" else args.device_type
 ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
 master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
-autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
 synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None
 get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0
 if device_type == "cuda":

@@ -99,17 +99,23 @@ if device_type == "cuda":
     print0(f"GPU: {gpu_device_name} | Peak FLOPS (BF16): {gpu_peak_flops:.2e}")
 else:
     gpu_peak_flops = float('inf') # MFU not meaningful for CPU/MPS
+print0(f"COMPUTE_DTYPE: {COMPUTE_DTYPE} ({COMPUTE_DTYPE_REASON})")

 # wandb logging init
 use_dummy_wandb = args.run == "dummy" or not master_process
 wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat", name=args.run, config=user_config)

 # Flash Attention status
 if HAS_FA3:
     from nanochat.flash_attention import USE_FA3
     using_fa3 = USE_FA3
 if using_fa3:
     print0("✓ Using Flash Attention 3 (Hopper GPU detected), efficient, new and awesome.")
 else:
     print0("!" * 80)
     print0("WARNING: Flash Attention 3 not available, using PyTorch SDPA fallback")
+if HAS_FA3 and COMPUTE_DTYPE != torch.bfloat16:
+    print0(f"WARNING: Flash Attention 3 only supports bf16, but COMPUTE_DTYPE={COMPUTE_DTYPE}. Using PyTorch SDPA fallback")
+else:
+    print0("WARNING: Flash Attention 3 not available, using PyTorch SDPA fallback")
     print0("WARNING: Training will be less efficient without FA3")
 if args.window_pattern != "L":
     print0(f"WARNING: SDPA has no support for sliding window attention (window_pattern='{args.window_pattern}'). Your GPU utilization will be terrible.")
@@ -221,9 +227,9 @@ def disable_fp8(model):
         yield # No FP8 modules, nothing to do
         return

-    # Swap Float8Linear -> nn.Linear (shares the same weight tensor, no copy)
+    # Swap Float8Linear -> Linear (our custom class that casts weights to match input dtype)
     for parent, attr_name, fp8_module in fp8_locations:
-        linear = nn.Linear(
+        linear = Linear(
             fp8_module.in_features,
             fp8_module.out_features,
             bias=fp8_module.bias is not None,

@@ -323,6 +329,12 @@ if resuming:
     optimizer.load_state_dict(optimizer_data)
     del optimizer_data

+# -----------------------------------------------------------------------------
+# GradScaler for fp16 training (bf16/fp32 don't need it — bf16 has the same exponent range as fp32)
+scaler = torch.amp.GradScaler() if COMPUTE_DTYPE == torch.float16 else None
+if scaler is not None:
+    print0("GradScaler enabled for fp16 training")
+
 # -----------------------------------------------------------------------------
 # Initialize the DataLoaders for train/val
 dataloader_resume_state_dict = None if not resuming else meta_data["dataloader_state_dict"]
@@ -413,7 +425,7 @@ while True:
         model.eval()
         val_loader = build_val_loader()
         eval_steps = args.eval_tokens // (args.device_batch_size * args.max_seq_len * ddp_world_size)
-        with disable_fp8(model), autocast_ctx:
+        with disable_fp8(model):
             val_bpb = evaluate_bpb(model, val_loader, eval_steps, token_bytes)
         print0(f"Step {step:05d} | Validation bpb: {val_bpb:.6f}")
         if val_bpb < min_val_bpb:

@@ -432,7 +444,7 @@ while True:
     results = {}
     if args.core_metric_every > 0 and (last_step or (step > 0 and step % args.core_metric_every == 0)):
         model.eval()
-        with disable_fp8(orig_model), autocast_ctx:
+        with disable_fp8(orig_model):
             results = evaluate_core(orig_model, tokenizer, device, max_per_task=args.core_metric_max_per_task)
         print0(f"Step {step:05d} | CORE metric: {results['core_metric']:.4f}")
         wandb_run.log({

@@ -459,7 +471,7 @@ while True:
         engine = Engine(orig_model, tokenizer) # use orig_model to avoid recompilation
         for prompt in prompts:
             tokens = tokenizer(prompt, prepend="<|bos|>")
-            with disable_fp8(orig_model), autocast_ctx:
+            with disable_fp8(orig_model):
                 sample, _ = engine.generate_batch(tokens, num_samples=1, max_tokens=16, temperature=0)
             print0(tokenizer.decode(sample[0]))
         model.train()
@@ -499,11 +511,13 @@ while True:
     synchronize()
     t0 = time.time()
     for micro_step in range(grad_accum_steps):
-        with autocast_ctx:
-            loss = model(x, y)
+        loss = model(x, y)
         train_loss = loss.detach() # for logging
         loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here
-        loss.backward()
+        if scaler is not None:
+            scaler.scale(loss).backward()
+        else:
+            loss.backward()
         x, y, dataloader_state_dict = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward
     # step the optimizer
     lrm = get_lr_multiplier(step)

@@ -514,7 +528,18 @@ while True:
         if group['kind'] == 'muon':
             group["momentum"] = muon_momentum
             group["weight_decay"] = muon_weight_decay
-    optimizer.step()
+    if scaler is not None:
+        scaler.unscale_(optimizer)
+        # In distributed training, all ranks must agree on whether to skip the step.
+        # Each rank may independently encounter inf/nan gradients, so we all-reduce
+        # the found_inf flag (MAX = if any rank found inf, all ranks skip).
+        if is_ddp_initialized():
+            for v in scaler._found_inf_per_device(optimizer).values():
+                dist.all_reduce(v, op=dist.ReduceOp.MAX)
+        scaler.step(optimizer)
+        scaler.update()
+    else:
+        optimizer.step()
     model.zero_grad(set_to_none=True)
     train_loss_f = train_loss.item() # .item() is a CPU-GPU sync point
     synchronize()

@@ -541,7 +566,7 @@ while True:
         eta_str = f" | eta: {eta_seconds/60:.1f}m"
     else:
         eta_str = ""
-    epoch = dataloader_state_dict["epoch"]
+    epoch = f"{dataloader_state_dict['epoch']} pq: {dataloader_state_dict['pq_idx']} rg: {dataloader_state_dict['rg_idx']}"
     print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | bf16_mfu: {mfu:.2f} | epoch: {epoch} | total time: {total_training_time/60:.2f}m{eta_str}")
     if step % 100 == 0:
         log_data = {
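The all-reduce of GradScaler's `found_inf` flag in the hunk above makes every rank take the same step/skip decision. The agreement logic can be sketched without torch (this toy model replaces `dist.all_reduce(..., op=MAX)` with a plain `max` over per-rank flags; it is illustrative, not the repo's code):

```python
def ranks_should_skip(found_inf_per_rank):
    """If any rank saw an inf/nan gradient, the MAX-reduction makes every rank
    see flag=1, so all ranks skip the optimizer step together."""
    reduced = max(found_inf_per_rank)  # stand-in for dist.all_reduce with ReduceOp.MAX
    return [reduced == 1 for _ in found_inf_per_rank]
```

Without this reduction, ranks could diverge: the rank that found an inf would skip its step while the others applied theirs, desynchronizing the replicated weights.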
@@ -7,7 +7,6 @@ python -m scripts.chat_cli
 import argparse
 import torch
 from nanochat.common import compute_init, autodetect_device_type
-from contextlib import nullcontext
 from nanochat.engine import Engine
 from nanochat.checkpoint_manager import load_model

@@ -19,15 +18,12 @@ parser.add_argument('-p', '--prompt', type=str, default='', help='Prompt the mod
 parser.add_argument('-t', '--temperature', type=float, default=0.6, help='Temperature for generation')
 parser.add_argument('-k', '--top-k', type=int, default=50, help='Top-k sampling parameter')
 parser.add_argument('--device-type', type=str, default='', choices=['cuda', 'cpu', 'mps'], help='Device type for evaluation: cuda|cpu|mps. empty => autodetect')
-parser.add_argument('-d', '--dtype', type=str, default='bfloat16', choices=['float32', 'bfloat16'])
 args = parser.parse_args()

 # Init the model and tokenizer

 device_type = autodetect_device_type() if args.device_type == "" else args.device_type
 ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
-ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16
-autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
 model, tokenizer, meta = load_model(args.source, device, phase="eval", model_tag=args.model_tag, step=args.step)

 # Special tokens for the chat state machine

@@ -87,12 +83,11 @@ while True:
     }
     response_tokens = []
     print("\nAssistant: ", end="", flush=True)
-    with autocast_ctx:
-        for token_column, token_masks in engine.generate(conversation_tokens, **generate_kwargs):
-            token = token_column[0] # pop the batch dimension (num_samples=1)
-            response_tokens.append(token)
-            token_text = tokenizer.decode([token])
-            print(token_text, end="", flush=True)
+    for token_column, token_masks in engine.generate(conversation_tokens, **generate_kwargs):
+        token = token_column[0] # pop the batch dimension (num_samples=1)
+        response_tokens.append(token)
+        token_text = tokenizer.decode([token])
+        print(token_text, end="", flush=True)
     print()
     # we have to ensure that the assistant end token is the last token
     # so even if generation ends due to max tokens, we have to append it to the end
@@ -10,8 +10,6 @@ torchrun --nproc_per_node=8 -m scripts.chat_eval -- -a ARC-Easy

 import argparse
 from functools import partial
-from contextlib import nullcontext

 import torch
 import torch.distributed as dist

@@ -185,7 +183,6 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument('-i', '--source', type=str, required=True, help="Source of the model: sft|rl")
     parser.add_argument('-a', '--task-name', type=str, default=None, help="Task name. Default = all tasks. Use | to split multiple tasks.")
-    parser.add_argument('-d', '--dtype', type=str, default='bfloat16', choices=['float32', 'bfloat16'])
     parser.add_argument('-t', '--temperature', type=float, default=0.0)
     parser.add_argument('-m', '--max-new-tokens', type=int, default=512)
     parser.add_argument('-n', '--num-samples', type=int, default=1)

@@ -199,8 +196,6 @@ if __name__ == "__main__":

     device_type = autodetect_device_type() if args.device_type == "" else args.device_type
     ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
-    ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16
-    autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()

     model, tokenizer, meta = load_model(args.source, device, phase="eval", model_tag=args.model_tag, step=args.step)
     engine = Engine(model, tokenizer)

@@ -220,19 +215,18 @@ if __name__ == "__main__":
     # Run all the task evaluations sequentially
     results = {}
     for task_name in task_names:
-        with autocast_ctx:
-            acc = run_chat_eval(
-                task_name,
-                model, tokenizer, engine,
-                batch_size=args.batch_size,
-                num_samples=args.num_samples,
-                max_new_tokens=args.max_new_tokens,
-                temperature=args.temperature,
-                top_k=args.top_k,
-                max_problems=args.max_problems,
-            )
-        results[task_name] = acc
-        print0(f"{task_name} accuracy: {100 * acc:.2f}%")
+        acc = run_chat_eval(
+            task_name,
+            model, tokenizer, engine,
+            batch_size=args.batch_size,
+            num_samples=args.num_samples,
+            max_new_tokens=args.max_new_tokens,
+            temperature=args.temperature,
+            top_k=args.top_k,
+            max_problems=args.max_problems,
+        )
+        results[task_name] = acc
+        print0(f"{task_name} accuracy: {100 * acc:.2f}%")

     # Log to report
     from nanochat.report import get_report
@@ -22,8 +22,6 @@ import itertools
 import wandb
 import torch
 import torch.distributed as dist
-from contextlib import nullcontext

 from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, DummyWandb, autodetect_device_type
 from nanochat.checkpoint_manager import save_checkpoint, load_model
 from nanochat.engine import Engine

@@ -36,7 +34,6 @@ parser = argparse.ArgumentParser(description="Reinforcement learning on GSM8K")
 parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)")
 # Runtime
 parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
-parser.add_argument("--dtype", type=str, default="bfloat16", help="float32|bfloat16")
 # Model loading
 parser.add_argument("--model-tag", type=str, default=None, help="model tag to load from")
 parser.add_argument("--model-step", type=int, default=None, help="model step to load from")

@@ -68,8 +65,6 @@ user_config = vars(args).copy()
 device_type = autodetect_device_type() if args.device_type == "" else args.device_type
 ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
 master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
-ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16
-autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()

 # wandb logging init
 use_dummy_wandb = args.run == "dummy" or not master_process

@@ -108,15 +103,14 @@ def get_batch():
     num_sampling_steps = args.num_samples // args.device_batch_size # go sequentially to prevent OOMs
     for sampling_step in range(num_sampling_steps):
         seed = hash((step, example_idx, sampling_step)) & 0x7FFFFFFF # positive half of int32
-        with autocast_ctx:
-            generated_token_sequences_batch, masks_batch = engine.generate_batch(
-                tokens,
-                num_samples=args.device_batch_size,
-                max_tokens=args.max_new_tokens,
-                temperature=args.temperature,
-                top_k=args.top_k,
-                seed=seed, # must make sure to change the seed for each sampling step
-            )
+        generated_token_sequences_batch, masks_batch = engine.generate_batch(
+            tokens,
+            num_samples=args.device_batch_size,
+            max_tokens=args.max_new_tokens,
+            temperature=args.temperature,
+            top_k=args.top_k,
+            seed=seed, # must make sure to change the seed for each sampling step
+        )
         generated_token_sequences.extend(generated_token_sequences_batch)
         masks.extend(masks_batch)

@@ -231,9 +225,8 @@ for step in range(num_steps):
     if step % args.eval_every == 0:
         model.eval()
         passk = torch.zeros(args.device_batch_size, device=device) # pass@k for k=1..device_batch_size
-        with autocast_ctx:
-            records_iter = run_gsm8k_eval(val_task, tokenizer, engine, num_samples=args.device_batch_size, max_examples=args.eval_examples, temperature=1.0)
-            records = list(records_iter) # collect all records
+        records_iter = run_gsm8k_eval(val_task, tokenizer, engine, num_samples=args.device_batch_size, max_examples=args.eval_examples, temperature=1.0)
+        records = list(records_iter) # collect all records
         for k in range(1, args.device_batch_size + 1):
             passk[k - 1] = sum(any(o["is_correct"] for o in r["outcomes"][:k]) for r in records)
         num_records = torch.tensor(len(records), dtype=torch.long, device=device)

@@ -268,8 +261,7 @@ for step in range(num_steps):
         rewards = rewards_all[b0:b1]
         advantages = advantages_all[b0:b1]
         # Calculate log probabilities. Note that the loss calculates NLL = -logp, so we negate
-        with autocast_ctx:
-            logp = -model(inputs, targets, loss_reduction='none').view_as(inputs) # (B, T)
+        logp = -model(inputs, targets, loss_reduction='none').view_as(inputs) # (B, T)
         # Calculate the PG objective. Note that ignore_index=-1 ensures that invalid tokens have loss 0.
         pg_obj = (logp * advantages.unsqueeze(-1)).sum()
         # normalize by the number of valid tokens, number of passes, and examples_per_rank
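The `passk` loop in the eval hunk above counts, per problem, whether any of the first k samples is correct. The same counting as a standalone function, matching the record/outcome structure shown in the diff (a sketch, not imported from the repo):

```python
def pass_at_k(records, k):
    """Empirical pass@k: fraction of problems with at least one correct
    outcome among the first k samples."""
    hits = sum(any(o["is_correct"] for o in r["outcomes"][:k]) for r in records)
    return hits / len(records)
```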
@ -16,8 +16,7 @@ os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"
|
|||
import time
|
||||
import wandb
|
||||
import torch
|
||||
from contextlib import nullcontext
|
||||
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, get_base_dir, autodetect_device_type, get_peak_flops
|
||||
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, get_base_dir, autodetect_device_type, get_peak_flops, COMPUTE_DTYPE, COMPUTE_DTYPE_REASON, is_ddp_initialized
|
||||
from nanochat.tokenizer import get_token_bytes
|
||||
from nanochat.checkpoint_manager import save_checkpoint, load_model, load_optimizer_state
|
||||
from nanochat.loss_eval import evaluate_bpb
|
||||
|
|
@ -75,7 +74,7 @@ user_config = vars(args).copy()
|
|||
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
|
||||
master_process = ddp_rank == 0
|
||||
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
|
||||
print0(f"COMPUTE_DTYPE: {COMPUTE_DTYPE} ({COMPUTE_DTYPE_REASON})")
|
||||
synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None
|
||||
get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0
|
||||
if device_type == "cuda":
|
||||
|
|
@ -151,6 +150,11 @@ if args.load_optimizer:
|
|||
else:
|
||||
print0("WARNING: optimizer checkpoint not found, starting with fresh optimizer (slightly worse)")
|
||||
|
||||
# GradScaler for fp16 training (bf16/fp32 don't need it)
|
||||
scaler = torch.amp.GradScaler() if COMPUTE_DTYPE == torch.float16 else None
|
||||
if scaler is not None:
|
||||
print0("GradScaler enabled for fp16 training")
|
||||
|
||||
# Override the initial learning rate as a fraction of the base learning rate
|
||||
for group in optimizer.param_groups:
|
||||
group["lr"] = group["lr"] * args.init_lr_frac
|
||||
|
|
@ -197,7 +201,7 @@ def sft_data_generator_bos_bestfit(split, buffer_size=100):
|
|||
row_capacity = args.max_seq_len + 1 # +1 for target at last position
|
||||
bos_token = tokenizer.get_bos_token_id()
|
||||
|
||||
# Conversation buffer: list of token lists
|
||||
# Conversation buffer: list of (token_ids, loss_mask) tuples
|
||||
conv_buffer = []
|
||||
cursor = ddp_rank # Each rank processes different conversations (for fetching)
|
||||
consumed = ddp_rank # Track actual consumption separately from buffering
|
||||
|
|
@@ -208,8 +212,8 @@ def sft_data_generator_bos_bestfit(split, buffer_size=100):
         nonlocal cursor, epoch
         while len(conv_buffer) < buffer_size:
             conversation = dataset[cursor]
-            ids, _ = tokenizer.render_conversation(conversation)
-            conv_buffer.append(ids)
+            ids, mask = tokenizer.render_conversation(conversation)
+            conv_buffer.append((ids, mask))
             cursor += ddp_world_size
             if cursor >= dataset_size:
                 cursor = cursor % dataset_size
@@ -218,9 +222,11 @@ def sft_data_generator_bos_bestfit(split, buffer_size=100):
 
     while True:
         rows = []
+        mask_rows = []
         row_lengths = [] # Track actual content length (excluding padding) for each row
         for _ in range(args.device_batch_size):
             row = []
+            mask_row = []
             padded = False
             while len(row) < row_capacity:
                 # Ensure buffer has conversations
@@ -232,7 +238,7 @@ def sft_data_generator_bos_bestfit(split, buffer_size=100):
                 # Find largest conversation that fits entirely
                 best_idx = -1
                 best_len = 0
-                for i, conv in enumerate(conv_buffer):
+                for i, (conv, _) in enumerate(conv_buffer):
                     conv_len = len(conv)
                     if conv_len <= remaining and conv_len > best_len:
                         best_idx = i
@@ -240,14 +246,16 @@ def sft_data_generator_bos_bestfit(split, buffer_size=100):
 
                 if best_idx >= 0:
                     # Found a conversation that fits - use it entirely
-                    conv = conv_buffer.pop(best_idx)
+                    conv, conv_mask = conv_buffer.pop(best_idx)
                     row.extend(conv)
+                    mask_row.extend(conv_mask)
                     consumed += ddp_world_size # Track actual consumption
                 else:
                     # No conversation fits - pad the remainder instead of cropping
                     # This ensures we never discard any tokens
                     content_len = len(row)
                     row.extend([bos_token] * remaining) # Pad with BOS tokens
+                    mask_row.extend([0] * remaining)
                     padded = True
                     break # Row is now full (with padding)
 
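The packing logic in the two hunks above can be read in isolation: repeatedly take the longest buffered conversation that still fits in the row, and when nothing fits, pad the remainder with BOS (mask 0) rather than cropping tokens. A minimal torch-free sketch of that best-fit loop; `BOS_TOKEN` and the sample buffer contents are stand-ins for illustration.

```python
# Best-fit row packing: greedily place the largest conversation that fits,
# then pad the tail with BOS tokens carrying loss mask 0.

BOS_TOKEN = 0  # stand-in for tokenizer.get_bos_token_id()

def pack_row(conv_buffer, row_capacity):
    row, mask_row = [], []
    while len(row) < row_capacity:
        remaining = row_capacity - len(row)
        # Find the largest conversation that fits entirely
        best_idx, best_len = -1, 0
        for i, (conv, _) in enumerate(conv_buffer):
            if best_len < len(conv) <= remaining:
                best_idx, best_len = i, len(conv)
        if best_idx >= 0:
            conv, conv_mask = conv_buffer.pop(best_idx)
            row.extend(conv)
            mask_row.extend(conv_mask)
        else:
            # Nothing fits: pad the remainder, never crop tokens
            row.extend([BOS_TOKEN] * remaining)
            mask_row.extend([0] * remaining)
            break
    return row, mask_row

buf = [([1, 2, 3], [0, 1, 1]), ([4, 5], [0, 1]), ([6, 7, 8, 9], [0, 1, 1, 1])]
row, mask = pack_row(buf, row_capacity=8)
# picks the length-4 conv, then the length-3 conv, then pads 1 slot
```

The real generator additionally refills `conv_buffer` from the dataset and strides by `ddp_world_size` so each rank packs disjoint conversations.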
@@ -257,6 +265,7 @@ def sft_data_generator_bos_bestfit(split, buffer_size=100):
             else:
                 row_lengths.append(row_capacity)
             rows.append(row[:row_capacity])
+            mask_rows.append(mask_row[:row_capacity])
 
         # Stopping condition to respect num_iterations, if given
         it += 1
@@ -277,8 +286,15 @@ def sft_data_generator_bos_bestfit(split, buffer_size=100):
         # Build tensors
         use_cuda = device_type == "cuda"
         batch_tensor = torch.tensor(rows, dtype=torch.long, pin_memory=use_cuda)
-        inputs = batch_tensor[:, :-1].to(device=device, dtype=torch.int32, non_blocking=use_cuda)
-        targets = batch_tensor[:, 1:].to(device=device, dtype=torch.int64, non_blocking=use_cuda)
+        inputs = batch_tensor[:, :-1].to(device=device, dtype=torch.int32, non_blocking=use_cuda).contiguous()
+        targets = batch_tensor[:, 1:].to(device=device, dtype=torch.int64, non_blocking=use_cuda).contiguous()
+
+        # Apply the loss mask from render_conversation (mask=1 for assistant completions,
+        # mask=0 for user prompts, BOS, special tokens, tool outputs). mask[1:] aligns
+        # with targets (shifted by 1). Unmasked positions get -1 (ignore_index).
+        mask_tensor = torch.tensor(mask_rows, dtype=torch.int8)
+        mask_targets = mask_tensor[:, 1:].to(device=device)
+        targets[mask_targets == 0] = -1
 
         # Mask out padding positions in targets (set to -1 = ignore_index)
         # For each row, positions >= (content_length - 1) in targets should be masked
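The shift-by-one alignment in the hunk above is easy to get wrong, so here it is in plain Python: position `t` of `inputs` predicts token `t+1`, which is why `mask[1:]` (not `mask[:-1]`) lines up with `targets`, and why masked-out positions receive the `ignore_index` value `-1`. The helper name is illustrative, not from the patch.

```python
# Next-token alignment: inputs = ids[:-1], targets = ids[1:], and the loss
# mask is shifted the same way as the targets.

def make_inputs_targets(ids, mask, ignore_index=-1):
    inputs = ids[:-1]
    targets = ids[1:]
    # mask[1:] aligns with targets: the label at position t is token t+1
    targets = [t if m == 1 else ignore_index for t, m in zip(targets, mask[1:])]
    return inputs, targets

ids  = [10, 11, 12, 13]   # e.g. BOS, a user token, two assistant tokens
mask = [0,  0,  1,  1]    # loss only on the assistant tokens
inputs, targets = make_inputs_targets(ids, mask)
```

In the real code the same thing is done in one vectorized step, `targets[mask_targets == 0] = -1`, where `-1` is the `ignore_index` that the cross-entropy loss skips.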
@@ -332,8 +348,7 @@ while True:
         model.eval()
         val_loader = build_val_loader()
         eval_steps = args.eval_tokens // (args.device_batch_size * args.max_seq_len * ddp_world_size)
-        with autocast_ctx:
-            val_bpb = evaluate_bpb(model, val_loader, eval_steps, token_bytes)
+        val_bpb = evaluate_bpb(model, val_loader, eval_steps, token_bytes)
         print0(f"Step {step:05d} | Validation bpb: {val_bpb:.4f}")
         if val_bpb < min_val_bpb:
             min_val_bpb = val_bpb
@@ -361,9 +376,8 @@ while True:
         for task_name in all_tasks:
             limit = args.chatcore_max_cat if task_name in categorical_tasks else args.chatcore_max_sample
             max_problems = None if limit < 0 else limit # -1 means no limit
-            with autocast_ctx:
-                acc = run_chat_eval(task_name, orig_model, tokenizer, engine,
-                                    batch_size=args.device_batch_size, max_problems=max_problems)
+            acc = run_chat_eval(task_name, orig_model, tokenizer, engine,
+                                batch_size=args.device_batch_size, max_problems=max_problems)
             task_results[task_name] = acc
             print0(f" {task_name}: {100*acc:.2f}%")
         # Compute ChatCORE metrics (mean centered accuracy, ranges from 0=random to 1=perfect)
@@ -416,11 +430,13 @@ while True:
     synchronize()
     t0 = time.time()
     for micro_step in range(grad_accum_steps):
-        with autocast_ctx:
-            loss = model(x, y)
+        loss = model(x, y)
        train_loss = loss.detach() # for logging
         loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here
-        loss.backward()
+        if scaler is not None:
+            scaler.scale(loss).backward()
+        else:
+            loss.backward()
         x, y = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward
     progress = max(progress, approx_progress) # only increase progress monotonically
     # step the optimizer
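Why the `loss = loss / grad_accum_steps` line in the hunk above matters: consecutive `.backward()` calls *sum* gradients, so dividing each micro-batch loss by the number of accumulation steps makes the accumulated gradient match the gradient of the mean loss over the full batch. A scalar sketch (pretending the gradient of each micro-loss equals the loss itself, just to show the arithmetic):

```python
# Gradient accumulation normalization: sum of (loss / k) over k micro-batches
# equals the mean loss over the whole batch.

micro_losses = [4.0, 2.0, 6.0]          # per-micro-batch mean losses
grad_accum_steps = len(micro_losses)

accumulated = sum(l / grad_accum_steps for l in micro_losses)
full_batch = sum(micro_losses) / len(micro_losses)
# accumulated and full_batch agree (up to floating-point rounding)
```

Without the division, the effective learning rate would silently scale with `grad_accum_steps`.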
@@ -430,7 +446,15 @@ while True:
         group["lr"] = group["initial_lr"] * lrm
         if group['kind'] == 'muon':
             group["momentum"] = muon_momentum
-    optimizer.step()
+    if scaler is not None:
+        scaler.unscale_(optimizer)
+        if is_ddp_initialized():
+            for v in scaler._found_inf_per_device(optimizer).values():
+                dist.all_reduce(v, op=dist.ReduceOp.MAX)
+        scaler.step(optimizer)
+        scaler.update()
+    else:
+        optimizer.step()
     model.zero_grad(set_to_none=True)
     synchronize()
     t1 = time.time()

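The `all_reduce(..., MAX)` over `found_inf` in the hunk above exists because, under DDP, every rank must make the same skip-or-step decision: if only one rank detects an overflow and skips while the others step, the replicas diverge. A torch-free sketch of that agreement protocol; `allreduce_max` is a stand-in for `dist.all_reduce(v, op=dist.ReduceOp.MAX)`.

```python
# Simulated found_inf agreement across DDP ranks: after a MAX all-reduce,
# every rank sees the same flag and skips the optimizer step together.

def allreduce_max(values_per_rank):
    """Stand-in for dist.all_reduce(v, op=dist.ReduceOp.MAX)."""
    m = max(values_per_rank)
    return [m] * len(values_per_rank)

# rank 1 saw an overflow (found_inf = 1.0); the other ranks did not
found_inf = [0.0, 1.0, 0.0, 0.0]
found_inf = allreduce_max(found_inf)
skip_step = [v > 0 for v in found_inf]  # every rank now skips together
```

Note that the real code reaches into `scaler._found_inf_per_device`, a private GradScaler attribute, so this hunk is coupled to PyTorch internals.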
@@ -44,7 +44,6 @@ from fastapi.responses import StreamingResponse, HTMLResponse, FileResponse
 from pydantic import BaseModel
 from typing import List, Optional, AsyncGenerator
 from dataclasses import dataclass
-from contextlib import nullcontext
 from nanochat.common import compute_init, autodetect_device_type
 from nanochat.checkpoint_manager import load_model
 from nanochat.engine import Engine
@@ -69,7 +68,6 @@ parser.add_argument('-m', '--max-tokens', type=int, default=512, help='Default m
 parser.add_argument('-g', '--model-tag', type=str, default=None, help='Model tag to load')
 parser.add_argument('-s', '--step', type=int, default=None, help='Step to load')
 parser.add_argument('-p', '--port', type=int, default=8000, help='Port to run the server on')
-parser.add_argument('-d', '--dtype', type=str, default='bfloat16', choices=['float32', 'bfloat16'])
 parser.add_argument('--device-type', type=str, default='', choices=['cuda', 'cpu', 'mps'], help='Device type for evaluation: cuda|cpu|mps. empty => autodetect')
 parser.add_argument('--host', type=str, default='0.0.0.0', help='Host to bind the server to')
 args = parser.parse_args()
@@ -84,7 +82,6 @@ logger = logging.getLogger(__name__)
 
 device_type = autodetect_device_type() if args.device_type == "" else args.device_type
 ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
-ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16
 
 @dataclass
 class Worker:
@@ -93,7 +90,6 @@ class Worker:
     device: torch.device
     engine: Engine
     tokenizer: object
-    autocast_ctx: torch.amp.autocast
 
 class WorkerPool:
     """Pool of workers, each with a model replica on a different GPU."""
@@ -125,14 +121,11 @@ class WorkerPool:
-
         model, tokenizer, _ = load_model(source, device, phase="eval", model_tag=model_tag, step=step)
         engine = Engine(model, tokenizer)
-        autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
 
         worker = Worker(
             gpu_id=gpu_id,
             device=device,
             engine=engine,
             tokenizer=tokenizer,
-            autocast_ctx=autocast_ctx
         )
         self.workers.append(worker)
         await self.available_workers.put(worker)
@@ -279,34 +272,33 @@ async def generate_stream(
     # Track the last complete UTF-8 string (without replacement characters)
     last_clean_text = ""
 
-    with worker.autocast_ctx:
-        for token_column, token_masks in worker.engine.generate(
-            tokens,
-            num_samples=1,
-            max_tokens=max_new_tokens,
-            temperature=temperature,
-            top_k=top_k,
-            seed=random.randint(0, 2**31 - 1)
-        ):
-            token = token_column[0]
+    for token_column, token_masks in worker.engine.generate(
+        tokens,
+        num_samples=1,
+        max_tokens=max_new_tokens,
+        temperature=temperature,
+        top_k=top_k,
+        seed=random.randint(0, 2**31 - 1)
+    ):
+        token = token_column[0]
 
-            # Stopping criteria
-            if token == assistant_end or token == bos:
-                break
+        # Stopping criteria
+        if token == assistant_end or token == bos:
+            break
 
-            # Append the token to sequence
-            accumulated_tokens.append(token)
-            # Decode all accumulated tokens to get proper UTF-8 handling
-            # Note that decode is a quite efficient operation, basically table lookup and string concat
-            current_text = worker.tokenizer.decode(accumulated_tokens)
-            # Only emit text if it doesn't end with a replacement character
-            # This ensures we don't emit incomplete UTF-8 sequences
-            if not current_text.endswith('\ufffd'):
-                # Extract only the new text since last clean decode
-                new_text = current_text[len(last_clean_text):]
-                if new_text: # Only yield if there's new content
-                    yield f"data: {json.dumps({'token': new_text, 'gpu': worker.gpu_id}, ensure_ascii=False)}\n\n"
-                last_clean_text = current_text
+        # Append the token to sequence
+        accumulated_tokens.append(token)
+        # Decode all accumulated tokens to get proper UTF-8 handling
+        # Note that decode is a quite efficient operation, basically table lookup and string concat
+        current_text = worker.tokenizer.decode(accumulated_tokens)
+        # Only emit text if it doesn't end with a replacement character
+        # This ensures we don't emit incomplete UTF-8 sequences
+        if not current_text.endswith('\ufffd'):
+            # Extract only the new text since last clean decode
+            new_text = current_text[len(last_clean_text):]
+            if new_text: # Only yield if there's new content
+                yield f"data: {json.dumps({'token': new_text, 'gpu': worker.gpu_id}, ensure_ascii=False)}\n\n"
+            last_clean_text = current_text
 
     yield f"data: {json.dumps({'done': True})}\n\n"

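The replacement-character check in the streaming hunk above is worth isolating: BPE tokens can split a multi-byte UTF-8 character across two tokens, so the server only emits text once the accumulated stream decodes without a trailing U+FFFD. A byte-level sketch of the same logic; `stream_decode` and the sample chunks are illustrative.

```python
# Only surface text once the accumulated bytes decode cleanly, so a
# multi-byte character split across chunks is never emitted half-finished.

def stream_decode(chunks):
    emitted = []
    buf = b""
    last_clean = ""
    for chunk in chunks:
        buf += chunk
        text = buf.decode("utf-8", errors="replace")
        # A trailing U+FFFD means the last character is still incomplete
        if not text.endswith("\ufffd"):
            new_text = text[len(last_clean):]
            if new_text:
                emitted.append(new_text)
            last_clean = text
    return emitted

# "é" is two bytes (0xC3 0xA9) and arrives split across chunks
pieces = stream_decode([b"caf", b"\xc3", b"\xa9!"])
# pieces == ["caf", "é!"]
```

One caveat, shared with the original code: text that legitimately ends in U+FFFD is withheld until more bytes arrive. An incremental decoder (`codecs.getincrementaldecoder("utf-8")`) would handle that corner case as well.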
@@ -21,8 +21,9 @@ from nanochat.engine import KVCache
 
 
 def set_impl(impl):
-    """Set the implementation override ('fa3', 'sdpa', or None for auto)."""
+    """Set the implementation override ('fa3', 'sdpa', or None for auto) and re-resolve USE_FA3."""
     fa_module._override_impl = impl
+    fa_module.USE_FA3 = fa_module._resolve_use_fa3()
 
 
 def run_both_impls(fn):
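The fix above addresses a classic stale-cache bug: `USE_FA3` is a module-level flag resolved once at import, so setting the override without re-resolving leaves the cached value out of sync. A minimal self-contained model of that pattern; `HAS_FA3` here is a hard-coded stand-in for the real capability probe in `fa_module`.

```python
# Override + cached-flag pattern: the setter must re-resolve the cached
# module-level flag, otherwise readers of USE_FA3 see a stale value.

HAS_FA3 = False          # stand-in for the real FA3 availability probe
_override_impl = None

def _resolve_use_fa3():
    if _override_impl == "fa3":
        return True
    if _override_impl == "sdpa":
        return False
    return HAS_FA3       # None means auto-detect

USE_FA3 = _resolve_use_fa3()

def set_impl(impl):
    """Set the override ('fa3', 'sdpa', or None) and re-resolve the cached flag."""
    global _override_impl, USE_FA3
    _override_impl = impl
    USE_FA3 = _resolve_use_fa3()

set_impl("fa3")          # without the re-resolve, USE_FA3 would stay False
```

This is also why the tests below now assert on `fa_module.USE_FA3` directly instead of calling `_use_fa3()`: the cached flag is what the rest of the code reads.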
@@ -343,19 +344,19 @@ class TestOverrideMechanism:
     def test_override_fa3(self):
         """Test that override='fa3' uses FA3."""
         set_impl('fa3')
-        assert fa_module._use_fa3() == True
+        assert fa_module.USE_FA3 == True
         set_impl(None)
 
     def test_override_sdpa(self):
         """Test that override='sdpa' uses SDPA."""
         set_impl('sdpa')
-        assert fa_module._use_fa3() == False
+        assert fa_module.USE_FA3 == False
         set_impl(None)
 
     def test_override_auto(self):
         """Test that override=None uses auto-detection."""
         set_impl(None)
-        assert fa_module._use_fa3() == HAS_FA3
+        assert fa_module.USE_FA3 == HAS_FA3
 
 
 if __name__ == "__main__":