Merge branch 'karpathy:master' into master

This commit is contained in:
Dipesh Babu 2026-02-20 02:42:54 -05:00 committed by GitHub
commit c902a8d6bc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 550 additions and 73 deletions

View File

@ -4,6 +4,95 @@ A running summary documenting some experiments and findings. Started ~Jan 7 2026
---
## 2026-02-19: Mixture of Experts (negative)
Implemented a DeepSeekV3-style Mixture of Experts layer as a drop-in replacement for the dense MLP. The MoE branch works and improves per-step validation loss, but is not a net improvement on wall clock time due to MoE overhead (at least for our scale of interest of approx GPT-2 capability).
### Implementation
Follows DeepSeekV3 and using torchtitan as reference:
- **8 routed experts, top-2 routing** with sigmoid gating (not softmax)
- **1 shared expert** (dense MLP processing all tokens, following DeepSeekV3)
- **Auxiliary-loss-free load balancing** (DeepSeekV3's expert bias nudging)
- **Iso-FLOP sizing**: `expert_hidden_dim = round(4 * dim / (top_k + num_shared) / 128) * 128`, so active FLOPs per token match the dense MLP
- **`torch._grouped_mm`** for dispatching tokens to experts in a single kernel (instead of a Python for-loop)
- **3D expert weight tensors** `(num_experts, hidden, dim)` — Muon's Polar Express operates on the last two dims, so each expert is independently orthogonalized
- **Active parameter counting** for scaling laws (only `top_k + shared` experts, not all 8)
### What was easy
- The core MoE forward pass: router, sort tokens by expert, grouped matmul, scatter back. Conceptually clean.
- Shared expert: just an `nn.Linear` MLP that runs on all tokens alongside the routed path.
- 3D expert params + Muon: only required fixing `second_momentum_buffer` shape to preserve leading dims.
- Load balancing: DeepSeekV3's bias nudging is simple and effective (~10 lines).
### What was hard / ugly
- **`torch._grouped_mm` quirks**: requires bf16 (not fp32), column-major right operand, int32 cumulative offsets. The API is undocumented and only discoverable by trial and error.
- **Token count padding**: torchtitan pads each expert's token count to alignment multiples (8 for bf16) for better grouped_mm throughput. We implemented this with both a pure PyTorch approach and a copy of torchtitan's Triton kernel. Both compiled cleanly (0 graph breaks), but with ~65K tokens across 8 experts, each expert already gets ~8K tokens which is well-aligned. The padding overhead (gather/scatter) actually regressed MFU from 35% to 33%. Reverted.
- **FP8 + MoE**: `torch._grouped_mm` does NOT support FP8. There's a separate `torch._scaled_grouped_mm` API that requires per-row scaling (not per-tensor like our `Float8Linear`). The backward pass for weight gradients needs per-group column-wise scaling, which torchao implements with custom Triton kernels. We investigated thoroughly (see `dev/moe_fp8.md`) but did not implement — would require either depending on `torchao.prototype` (unstable) or writing ~200 lines of custom autograd + quantization code. Partial FP8 support exists: the shared expert's `nn.Linear` layers do get converted, but the routed experts (3D `nn.Parameter`) stay in bf16.
### Results
- d18: MFU dropped from ~46% to ~35% (the grouped_mm dispatch + token sorting overhead is significant)
- Per-step improvement in validation loss does not compensate for the throughput hit
- Net negative on wall clock time
### What remains (if revisited)
- **FP8 for routed experts**: Use `torch._scaled_grouped_mm` with a custom `_Float8GroupedMatmul` autograd function, with bf16 fallback for weight gradient (avoiding the per-group column-wise Triton kernels).
What's really needed is a fused "FlashMoE" kernel that handles routing + expert dispatch + matmul in one shot (like FlashAttention did for attention), with all the needed features. This doesn't exist yet. Rawdogging MoE with current PyTorch primitives is painful — lots of sorting, gathering, scattering, and layout wrangling around the actual compute.
### Verdict
MoE is not worth the trouble for nanochat right now. The code bloat is substantial (moe.py, router, shared expert, load balancing, optimizer fixes, FP8 gaps, active param counting) and the performance is worse wall-clock at our scale of interest. The fundamental issue is that the grouped_mm dispatch overhead eats the FLOP savings from sparsity, at least at our model scales and sequence lengths.
---
## 2026-02-17: Pretraining Data: FineWeb (negative)
Tried vanilla fineweb instead of fineweb-edu dataset. Significantly, shockingly worse results:
- d26 (GPT-2): CORE 0.2602 → 0.2241
This is the fifth failed attempt to beat pure FineWeb-EDU on CORE score.
---
## 2026-02-17: Pretraining Data Mixture Experiment (negative)
Tried [hynky/finepdfs_50BT-dclm_30BT-fineweb_edu_20BT](https://huggingface.co/datasets/hynky/finepdfs_50BT-dclm_30BT-fineweb_edu_20BT), a mixture of FinePDFs, DCLM, and FineWeb-EDU. Slightly worse on both model sizes tested:
- d26 (GPT-2): CORE 0.2602 → 0.2549
- d18: CORE 0.199 → 0.192
This is the fourth failed attempt to beat pure FineWeb-EDU on CORE score.
---
## 2026-02-16: SFT Script Upgrades
Brought `chat_sft.py` up to parity with `base_train.py` and tuned settings based on SFT sweeps.
Tuning:
- **Optimizer warm-start** (`--load-optimizer=1`, default on): loads pretrained momentum buffers via new `load_optimizer_state()` in `checkpoint_manager.py`. LRs are reset to fresh SFT values after load. Loading the optimizer works slightly better but not by too much.
- **LR schedule**: replaced "constant 80%, linear to 0" with warmup/constant/warmdown matching `base_train.py` (`--warmup-ratio`, `--warmdown-ratio`, `--init-lr-frac`, `--final-lr-frac`). Similar to pretraining, warmdown ratio of 0.5 worked the best. `--init-lr-frac` changed from 1.0 slightly lower to 0.8.
- **LR tuning**: attempted to tune all the individual LRs (e.g. does SFT prefer lower LR for embeddings? etc.) but all of this produced negative results.
- **Data mixture**: MMLU epochs 1→3, GSM8K epochs 2→4 (confirmed best from sweeps). Epoch counts now configurable via `--mmlu-epochs` / `--gsm8k-epochs`. Might remove these in the future though.
Quality of life, footguns, minor fixes:
- **Hyperparameter inheritance**: SFT now inherits batch sizes and LRs from the pretrained checkpoint metadata by default (CLI overrides still work). Also saved `total_batch_size` to `base_train.py` checkpoint metadata.
- **GC management**: disabled Python GC after step 1 to avoid ~500ms pauses (manual collect every 5000 steps), same as base pretraining.
- **ChatCORE eval**: periodic eval during SFT (`--chatcore-every=200`) across all 6 tasks, logged to wandb.
- **MFU**: uses `get_peak_flops()` for actual GPU instead of hardcoded H100 value.
- Removed `--dry-run` and `--dtype` flags. All ranks now participate in checkpoint save.
---
## 2026-02-05: Auto Batch Size Scaling
### Background

View File

@ -170,3 +170,25 @@ def load_model(source, *args, **kwargs):
base_dir = get_base_dir()
checkpoints_dir = os.path.join(base_dir, model_dir)
return load_model_from_dir(checkpoints_dir, *args, **kwargs)
def load_optimizer_state(source, device, rank, model_tag=None, step=None):
"""Load just the optimizer shard for a given rank, without re-loading the model."""
model_dir = {
"base": "base_checkpoints",
"sft": "chatsft_checkpoints",
"rl": "chatrl_checkpoints",
}[source]
base_dir = get_base_dir()
checkpoints_dir = os.path.join(base_dir, model_dir)
if model_tag is None:
model_tag = find_largest_model(checkpoints_dir)
checkpoint_dir = os.path.join(checkpoints_dir, model_tag)
if step is None:
step = find_last_step(checkpoint_dir)
optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}_rank{rank:d}.pt")
if not os.path.exists(optimizer_path):
log0(f"Optimizer checkpoint not found: {optimizer_path}")
return None
log0(f"Loading optimizer state from {optimizer_path}")
optimizer_data = torch.load(optimizer_path, map_location=device)
return optimizer_data

View File

@ -170,7 +170,7 @@ def compute_init(device_type="cuda"): # cuda|cpu|mps
# Precision
if device_type == "cuda":
torch.backends.fp32_precision = "tf32" # uses tf32 instead of fp32 for matmuls
torch.set_float32_matmul_precision("high") # uses tf32 instead of fp32 for matmuls, see https://docs.pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html
# Distributed setup: Distributed Data Parallel (DDP), optional, and requires CUDA
is_ddp_requested, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()

266
nanochat/fp8.py Normal file
View File

@ -0,0 +1,266 @@
"""Minimal FP8 training for nanochat — tensorwise dynamic scaling only.
Drop-in replacement for torchao's Float8Linear (~2000 lines) with ~150 lines.
We only need the "tensorwise" recipe (one scalar scale per tensor), not the full
generality of torchao (rowwise scaling, FSDP float8 all-gather, DTensor, tensor
subclass dispatch tables, etc.)
How FP8 training works
======================
A standard Linear layer does one matmul in forward and two in backward:
forward: output = input @ weight.T
backward: grad_input = grad_output @ weight
grad_weight= grad_output.T @ input
FP8 training wraps each of these three matmuls with:
1. Compute scale = FP8_MAX / max(|tensor|) for each operand
2. Quantize: fp8_tensor = clamp(tensor * scale, -FP8_MAX, FP8_MAX).to(fp8)
3. Matmul via torch._scaled_mm (cuBLAS FP8 kernel, ~2x faster than bf16)
4. Dequantize: _scaled_mm handles this internally using the inverse scales
The key insight: torch._scaled_mm and the float8 dtypes are PyTorch built-ins.
torchao is just orchestration around these primitives. We can call them directly.
FP8 dtype choice
================
There are two FP8 formats. We use both, following the standard convention:
- float8_e4m3fn: 4-bit exponent, 3-bit mantissa, range [-448, 448]
Higher precision (more mantissa bits), used for input and weight.
- float8_e5m2: 5-bit exponent, 2-bit mantissa, range [-57344, 57344]
Wider range (more exponent bits), used for gradients which can be large.
torch._scaled_mm layout requirements
=====================================
The cuBLAS FP8 kernel requires specific memory layouts:
- First argument (A): must be row-major (contiguous)
- Second argument (B): must be column-major (B.t().contiguous().t())
If B is obtained by transposing a contiguous tensor (e.g. weight.t()), it is
already column-major no copy needed. Otherwise we use _to_col_major().
How this differs from torchao's approach
========================================
torchao uses a "tensor subclass" architecture: Float8TrainingTensor is a subclass
of torch.Tensor that bundles FP8 data + scale + metadata. It implements
__torch_dispatch__ with a dispatch table that intercepts every aten op (mm, t,
reshape, clone, ...) and handles it in FP8-aware fashion. When you call
output = input @ weight.T
the @ operator dispatches to aten.mm, which gets intercepted and routed to
torch._scaled_mm behind the scenes. This is ~2000 lines of code because you need
a handler for every tensor operation that might touch an FP8 tensor.
We take a simpler approach: a single autograd.Function (_Float8Matmul) that takes
full-precision inputs, quantizes to FP8 internally, calls _scaled_mm, and returns
full-precision outputs. Marked @allow_in_graph so torch.compile treats it as one
opaque node rather than trying to trace inside.
The trade-off is in how torch.compile sees the two approaches:
- torchao: compile decomposes the tensor subclass (via __tensor_flatten__) and
sees every individual op (amax, scale, cast, _scaled_mm) as separate graph
nodes. Inductor can fuse these with surrounding operations (e.g. fuse the
amax computation with the preceding layer's activation function).
- ours: compile sees a single opaque call. It can optimize everything around
the FP8 linear (attention, norms, etc.) but cannot fuse across the boundary.
Both call the exact same cuBLAS _scaled_mm kernel the GPU matmul is identical.
The difference is only in the "glue" ops (amax, scale, cast) which are tiny
compared to the matmul. In practice this means our version is slightly faster
(less compilation overhead, no tensor subclass dispatch cost) but can produce
subtly different floating-point rounding paths under torch.compile, since Inductor
generates a different graph. Numerics are bitwise identical in eager mode.
"""
import torch
import torch.nn as nn
# Avoid division by zero when computing scale from an all-zeros tensor
EPS = 1e-12
@torch.no_grad()
def _to_fp8(x, fp8_dtype):
"""Dynamically quantize a tensor to FP8 using tensorwise scaling.
"Tensorwise" means one scalar scale for the entire tensor (as opposed to
"rowwise" which computes a separate scale per row). Tensorwise is faster
because cuBLAS handles the scaling; rowwise needs the CUTLASS kernel.
Returns (fp8_data, inverse_scale) for use with torch._scaled_mm.
"""
fp8_max = torch.finfo(fp8_dtype).max
# Compute the max absolute value across the entire tensor
amax = x.float().abs().max()
# Scale maps [0, amax] -> [0, fp8_max]. Use float64 for the division to
# ensure consistent numerics between torch.compile and eager mode.
# (torchao does the same upcast — without it, compile/eager can diverge)
scale = fp8_max / amax.double().clamp(min=EPS)
scale = scale.float()
# Quantize: scale into FP8 range, saturate (clamp prevents overflow when
# casting — PyTorch's default is to wrap, not saturate), then cast to FP8
x_scaled = x.float() * scale
x_clamped = x_scaled.clamp(-fp8_max, fp8_max)
x_fp8 = x_clamped.to(fp8_dtype)
# _scaled_mm expects the *inverse* of our scale (it multiplies by this to
# convert FP8 values back to the original range during the matmul)
inv_scale = scale.reciprocal()
return x_fp8, inv_scale
def _to_col_major(x):
"""Rearrange a 2D tensor's memory to column-major layout.
torch._scaled_mm requires its second operand in column-major layout.
The trick: transpose -> contiguous (forces a copy in transposed order)
-> transpose back. The result has the same logical shape but column-major
strides, e.g. a [M, N] tensor gets strides (1, M) instead of (N, 1).
"""
return x.t().contiguous().t()
# allow_in_graph tells torch.compile to treat this as an opaque operation —
# dynamo won't try to decompose it into smaller ops. See the module docstring
# for how this differs from torchao's tensor subclass approach.
@torch._dynamo.allow_in_graph
class _Float8Matmul(torch.autograd.Function):
"""Custom autograd for the three FP8 GEMMs of a Linear layer.
The forward quantizes input and weight to FP8 and saves
the quantized tensors + scales for backward.
"""
@staticmethod
def forward(ctx, input_2d, weight):
# Quantize both operands to e4m3 (higher precision format)
input_fp8, input_inv = _to_fp8(input_2d, torch.float8_e4m3fn)
weight_fp8, weight_inv = _to_fp8(weight, torch.float8_e4m3fn)
ctx.save_for_backward(input_fp8, input_inv, weight_fp8, weight_inv)
# output = input @ weight.T
# input_fp8 is [B, K] contiguous = row-major (good for first arg)
# weight_fp8 is [N, K] contiguous, so weight_fp8.t() is [K, N] with
# strides (1, K) = column-major (good for second arg, no copy needed!)
output = torch._scaled_mm(
input_fp8,
weight_fp8.t(),
scale_a=input_inv,
scale_b=weight_inv,
out_dtype=input_2d.dtype,
# use_fast_accum=True accumulates the dot products in lower precision.
# Slightly less accurate but measurably faster. Standard practice for
# the forward pass; we use False in backward for more precise gradients.
use_fast_accum=True,
)
return output
@staticmethod
def backward(ctx, grad_output):
in_fp8, in_inv, w_fp8, w_inv = ctx.saved_tensors
# === GEMM 1: grad_input = grad_output @ weight ===
# Shapes: [B, N] @ [N, K] -> [B, K]
# Gradients use e5m2 (wider range), weights use e4m3 (higher precision)
go_fp8, go_inv = _to_fp8(grad_output, torch.float8_e5m2)
# go_fp8 is [B, N] contiguous = row-major, good for first arg
# w_fp8 is [N, K] contiguous = row-major, need column-major for second arg
w_col = _to_col_major(w_fp8)
grad_input = torch._scaled_mm(
go_fp8,
w_col,
scale_a=go_inv,
scale_b=w_inv,
out_dtype=grad_output.dtype,
use_fast_accum=False,
)
# === GEMM 2: grad_weight = grad_output.T @ input ===
# Shapes: [N, B] @ [B, K] -> [N, K]
# go_fp8 is [B, N] contiguous, we need go.T = [N, B] as first arg.
# Transposing gives column-major, but first arg needs row-major,
# so we must call .contiguous() to physically rearrange the memory.
go_T = go_fp8.t().contiguous() # [N, B] row-major
in_col = _to_col_major(in_fp8) # [B, K] column-major
grad_weight = torch._scaled_mm(
go_T,
in_col,
scale_a=go_inv,
scale_b=in_inv,
out_dtype=grad_output.dtype,
use_fast_accum=False,
)
return grad_input, grad_weight
class Float8Linear(nn.Linear):
"""Drop-in nn.Linear replacement that does FP8 compute.
Weights and biases remain in their original precision (e.g. fp32/bf16).
Only the matmul is performed in FP8 via the _Float8Matmul autograd function.
"""
def forward(self, input):
# Replicate the autocast behavior of F.linear — when autocast is active,
# we need to manually cast input to the autocast dtype (e.g. bf16),
# since we bypass F.linear's built-in autocast handling.
if torch.is_autocast_enabled():
input = input.to(torch.get_autocast_gpu_dtype())
# _scaled_mm only works on 2D tensors, so flatten batch dimensions
orig_shape = input.shape
input_2d = input.reshape(-1, orig_shape[-1])
output = _Float8Matmul.apply(input_2d, self.weight)
output = output.reshape(*orig_shape[:-1], output.shape[-1])
if self.bias is not None:
output = output + self.bias.to(output.dtype)
return output
@classmethod
def from_float(cls, mod):
"""Create Float8Linear from nn.Linear, sharing the same weight and bias.
Uses meta device to avoid allocating a temporary weight tensor we
create the module shell on meta (shapes/dtypes only, no memory), then
point .weight and .bias to the original module's parameters.
"""
with torch.device("meta"):
new_mod = cls(mod.in_features, mod.out_features, bias=False)
new_mod.weight = mod.weight
new_mod.bias = mod.bias
return new_mod
class Float8LinearConfig:
"""Minimal config matching torchao's API. Only tensorwise recipe is supported."""
@staticmethod
def from_recipe_name(recipe_name):
if recipe_name != "tensorwise":
raise ValueError(
f"Only 'tensorwise' recipe is supported, got '{recipe_name}'. "
f"Rowwise/axiswise recipes require the full torchao library."
)
return Float8LinearConfig()
def convert_to_float8_training(module, *, config=None, module_filter_fn=None):
"""Replace nn.Linear layers with Float8Linear throughout a module.
Walks the module tree in post-order (children before parents) and swaps
each nn.Linear that passes the optional filter. The new Float8Linear shares
the original weight and bias tensors no copies, no extra memory.
Args:
module: Root module to convert.
config: Float8LinearConfig (accepted for API compat, only tensorwise supported).
module_filter_fn: Optional filter(module, fqn) -> bool. Only matching Linears
are converted. Common use: skip layers with dims not divisible by 16
(hardware requirement for FP8 matmuls on H100).
"""
def _convert(mod, prefix=""):
for name, child in mod.named_children():
fqn = f"{prefix}.{name}" if prefix else name
_convert(child, fqn)
if isinstance(child, nn.Linear) and not isinstance(child, Float8Linear):
if module_filter_fn is None or module_filter_fn(child, fqn):
setattr(mod, name, Float8Linear.from_float(child))
_convert(module)
return module

View File

@ -20,7 +20,6 @@ dependencies = [
"tiktoken>=0.11.0",
"tokenizers>=0.22.0",
"torch==2.9.1",
"torchao==0.15.0",
"transformers>=4.57.3",
"uvicorn>=0.36.0",
"wandb>=0.21.3",

View File

@ -69,7 +69,7 @@ python -m scripts.tok_eval
echo "Waiting for dataset download to complete..."
wait $DATASET_DOWNLOAD_PID
# d24 model (slightly overtrained is enough to beat GPT-2 => increase data:params ratio from compute optimal 10.5 (default) to 12)
# d26 model (slightly undertrained to beat GPT-2 => decrease data:params ratio from compute optimal 10.5 (default) to 8.25)
torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- --depth=26 --target-param-data-ratio=8.25 --device-batch-size=16 --fp8 --run=$WANDB_RUN
# evaluate the model: CORE metric, BPB on train/val, and draw samples
torchrun --standalone --nproc_per_node=8 -m scripts.base_eval -- --device-batch-size=16

View File

@ -165,23 +165,27 @@ if args.fp8:
if device_type != "cuda":
print0("Warning: FP8 training requires CUDA, ignoring --fp8 flag")
else:
from torchao.float8 import Float8LinearConfig, convert_to_float8_training
# our custom fp8 is simpler than torchao, written for exact API compatibility
from nanochat.fp8 import Float8LinearConfig, convert_to_float8_training
# from torchao.float8 import Float8LinearConfig, convert_to_float8_training
import torch.nn as nn
# Filter: only convert layers with dimensions divisible by 16 (FP8 hardware requirement)
# Filter: dims must be divisible by 16 (FP8 hardware requirement) large enough
def fp8_module_filter(mod: nn.Module, fqn: str) -> bool:
if not isinstance(mod, nn.Linear):
return False
# FP8 requires both in_features and out_features divisible by 16
if mod.in_features % 16 != 0 or mod.out_features % 16 != 0:
return False
if min(mod.in_features, mod.out_features) < 128:
return False
return True
fp8_config = Float8LinearConfig.from_recipe_name(args.fp8_recipe)
num_linear = sum(1 for m in model.modules() if isinstance(m, nn.Linear))
convert_to_float8_training(model, config=fp8_config, module_filter_fn=fp8_module_filter)
num_fp8_layers = sum(1 for m in model.modules() if 'Float8' in type(m).__name__)
num_skipped = sum(1 for m in model.modules() if isinstance(m, nn.Linear)) - num_fp8_layers
print0(f"✓ FP8 training enabled ({args.fp8_recipe} scaling) - converted {num_fp8_layers} layers, skipped {num_skipped} (dims not divisible by 16)")
num_fp8 = sum(1 for m in model.modules() if 'Float8' in type(m).__name__)
num_skipped = num_linear - num_fp8
print0(f"✓ FP8 training enabled ({args.fp8_recipe} scaling) - converted {num_fp8}/{num_linear} linear layers, skipped {num_skipped} (too small)")
# Context manager to temporarily disable FP8 so that model evaluation remains in BF16
@contextmanager
@ -479,6 +483,7 @@ while True:
"user_config": user_config, # inputs to the training script
"device_batch_size": args.device_batch_size,
"max_seq_len": args.max_seq_len,
"total_batch_size": total_batch_size,
"dataloader_state_dict": dataloader_state_dict,
"loop_state": { # all loop state (other than step) so that we can resume training
"min_val_bpb": min_val_bpb,
@ -542,7 +547,7 @@ while True:
else:
eta_str = ""
epoch = dataloader_state_dict["epoch"]
print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | epoch: {epoch} | total time: {total_training_time/60:.2f}m{eta_str}")
print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | bf16_mfu: {mfu:.2f} | epoch: {epoch} | total time: {total_training_time/60:.2f}m{eta_str}")
if step % 100 == 0:
log_data = {
"step": step,

View File

@ -9,6 +9,7 @@ Or torchrun for training:
torchrun --standalone --nproc_per_node=8 -m scripts.chat_sft -- --device-batch-size=16
"""
import gc
import argparse
import os
os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"
@ -16,12 +17,14 @@ import time
import wandb
import torch
from contextlib import nullcontext
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, get_base_dir, autodetect_device_type
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, get_base_dir, autodetect_device_type, get_peak_flops
from nanochat.tokenizer import get_token_bytes
from nanochat.checkpoint_manager import save_checkpoint
from nanochat.checkpoint_manager import save_checkpoint, load_model, load_optimizer_state
from nanochat.loss_eval import evaluate_bpb
from nanochat.checkpoint_manager import load_model
import torch.distributed as dist
from nanochat.flash_attention import HAS_FA3
from nanochat.engine import Engine
from scripts.chat_eval import run_chat_eval
from tasks.common import TaskMixture
from tasks.gsm8k import GSM8K
@ -37,27 +40,33 @@ parser = argparse.ArgumentParser(description="Supervised fine-tuning (SFT) the m
parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)")
# Runtime
parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
parser.add_argument("--dtype", type=str, default="bfloat16", help="float32|bfloat16")
# Model loading
parser.add_argument("--model-tag", type=str, default=None, help="model tag to load from")
parser.add_argument("--model-step", type=int, default=None, help="model step to load from")
parser.add_argument("--load-optimizer", type=int, default=1, help="warm-start optimizer from pretrained checkpoint (0=no, 1=yes)")
# Training horizon
parser.add_argument("--num-iterations", type=int, default=-1, help="number of optimization steps (-1 = full epoch)")
# Batch sizes
parser.add_argument("--max-seq-len", type=int, default=2048, help="max context length")
parser.add_argument("--device-batch-size", type=int, default=32, help="per-device batch size")
parser.add_argument("--total-batch-size", type=int, default=524288, help="total batch size in tokens")
# Optimization
parser.add_argument("--embedding-lr", type=float, default=0.3, help="learning rate for embedding parameters (Adam)")
parser.add_argument("--unembedding-lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)")
parser.add_argument("--matrix-lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)")
parser.add_argument("--weight-decay", type=float, default=0.0, help="weight decay for embedding/unembedding parameters (Adam)")
parser.add_argument("--init-lr-frac", type=float, default=1.0, help="initial LR as fraction of base LR")
# Batch sizes (default: inherit from pretrained checkpoint)
parser.add_argument("--max-seq-len", type=int, default=None, help="max context length (default: inherit from pretrain)")
parser.add_argument("--device-batch-size", type=int, default=None, help="per-device batch size (default: inherit from pretrain)")
parser.add_argument("--total-batch-size", type=int, default=None, help="total batch size in tokens (default: inherit from pretrain)")
# Optimization (default: inherit from pretrained checkpoint)
parser.add_argument("--embedding-lr", type=float, default=None, help="learning rate for embedding parameters (Adam) (default: inherit from pretrain)")
parser.add_argument("--unembedding-lr", type=float, default=None, help="learning rate for unembedding parameters (Adam) (default: inherit from pretrain)")
parser.add_argument("--matrix-lr", type=float, default=None, help="learning rate for matrix parameters (Muon) (default: inherit from pretrain)")
parser.add_argument("--init-lr-frac", type=float, default=0.8, help="initial LR as fraction of base LR")
parser.add_argument("--warmup-ratio", type=float, default=0.0, help="ratio of iterations for LR warmup")
parser.add_argument("--warmdown-ratio", type=float, default=0.5, help="ratio of iterations for LR warmdown")
parser.add_argument("--final-lr-frac", type=float, default=0.0, help="final LR as fraction of initial LR")
# Evaluation
parser.add_argument("--eval-every", type=int, default=150, help="evaluate val bpb every N steps (-1 = disable)")
parser.add_argument("--eval-tokens", type=int, default=20*524288, help="number of tokens to evaluate val loss on")
# Output
parser.add_argument("--dry-run", action="store_true", help="log to wandb but skip checkpoints/report")
parser.add_argument("--eval-every", type=int, default=200, help="evaluate val bpb every N steps (-1 = disable)")
parser.add_argument("--eval-tokens", type=int, default=40*524288, help="number of tokens to evaluate val loss on")
parser.add_argument("--chatcore-every", type=int, default=200, help="evaluate ChatCORE metric every N steps (-1 = disable)")
parser.add_argument("--chatcore-max-cat", type=int, default=-1, help="max problems per categorical task for ChatCORE")
parser.add_argument("--chatcore-max-sample", type=int, default=24, help="max problems per generative task for ChatCORE")
# Data mixture
parser.add_argument("--mmlu-epochs", type=int, default=3, help="number of epochs of MMLU in training mixture (teaches Multiple Choice)")
parser.add_argument("--gsm8k-epochs", type=int, default=4, help="number of epochs of GSM8K in training mixture (teaches Math and Tool Use)")
args = parser.parse_args()
user_config = vars(args).copy()
# -----------------------------------------------------------------------------
@ -66,20 +75,48 @@ user_config = vars(args).copy()
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
master_process = ddp_rank == 0
ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None
get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0
if device_type == "cuda":
gpu_device_name = torch.cuda.get_device_name(0)
gpu_peak_flops = get_peak_flops(gpu_device_name)
print0(f"GPU: {gpu_device_name} | Peak FLOPS (BF16): {gpu_peak_flops:.2e}")
else:
gpu_peak_flops = float('inf') # MFU not meaningful for CPU/MPS
# wandb logging init
use_dummy_wandb = args.run == "dummy" or not master_process
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-sft", name=args.run, config=user_config)
# Flash Attention status
if not HAS_FA3:
print0("WARNING: Flash Attention 3 not available, using PyTorch SDPA fallback. Training will be less efficient.")
# Load the model and tokenizer
model, tokenizer, meta = load_model("base", device, phase="train", model_tag=args.model_tag, step=args.model_step)
pretrain_batch_size = meta.get("device_batch_size", None)
if pretrain_batch_size is not None and args.device_batch_size > pretrain_batch_size:
print0(f"FOOTGUN WARNING: base model training used device_batch_size {pretrain_batch_size}, did you pass in a good --device-batch-size to this script?")
# Inherit training hyperparameters from pretrained checkpoint (None = inherit, explicit value = override)
pretrain_user_config = meta.get("user_config", {})
for name, fallback, source in [
("max_seq_len", 2048, meta),
("device_batch_size", 32, meta),
("total_batch_size", 524288, meta),
("embedding_lr", 0.3, pretrain_user_config),
("unembedding_lr", 0.004, pretrain_user_config),
("matrix_lr", 0.02, pretrain_user_config),
]:
arg_val = getattr(args, name)
pretrain_val = source.get(name)
if arg_val is None:
resolved = pretrain_val if pretrain_val is not None else fallback
setattr(args, name, resolved)
print0(f"Inherited {name}={resolved} from pretrained checkpoint")
elif pretrain_val is not None and arg_val != pretrain_val:
print0(f"NOTE: --{name.replace('_', '-')}={arg_val} overrides pretrained value of {pretrain_val}")
else:
print0(f"Using {name}={arg_val}")
orig_model = model
model = torch.compile(model, dynamic=False)
depth = model.config.n_layer
@ -94,25 +131,44 @@ print0(f"Total batch size {args.total_batch_size:,} => gradient accumulation ste
token_bytes = get_token_bytes(device=device)
# Initialize the Optimizer (combined MuonAdamW: Muon for matrix params, AdamW for rest)
optimizer = model.setup_optimizer(unembedding_lr=args.unembedding_lr, embedding_lr=args.embedding_lr, matrix_lr=args.matrix_lr, weight_decay=args.weight_decay)
# Note that pretraining ramps weight_decay to zero by end of pretraining, so SFT continues with zero
optimizer = model.setup_optimizer(unembedding_lr=args.unembedding_lr, embedding_lr=args.embedding_lr, matrix_lr=args.matrix_lr, weight_decay=0.0)
# Optionally warm-start optimizer from pretrained checkpoint (momentum buffers etc.)
# Note: load_state_dict overwrites param_group metadata (LRs, betas, etc.) with the
# pretrained values. Since pretraining warmdown brings LRs to ~0, we must save and
# restore our fresh SFT LRs after loading.
base_dir = get_base_dir()
if args.load_optimizer:
optimizer_data = load_optimizer_state("base", device, rank=ddp_rank, model_tag=args.model_tag, step=args.model_step)
if optimizer_data is not None:
base_lrs = [group["lr"] for group in optimizer.param_groups]
optimizer.load_state_dict(optimizer_data)
del optimizer_data
for group, base_lr in zip(optimizer.param_groups, base_lrs):
group["lr"] = base_lr
print0("Loaded optimizer state from pretrained checkpoint (momentum buffers only, LRs reset)")
else:
print0("WARNING: optimizer checkpoint not found, starting with fresh optimizer (slightly worse)")
# Override the initial learning rate as a fraction of the base learning rate
for group in optimizer.param_groups:
group["lr"] = group["lr"] * args.init_lr_frac
group["initial_lr"] = group["lr"]
# SFT data mixture and DataLoader
base_dir = get_base_dir()
identity_conversations_filepath = os.path.join(base_dir, "identity_conversations.jsonl")
train_dataset = TaskMixture([
train_tasks = [
SmolTalk(split="train"), # 460K rows of general conversations
MMLU(subset="auxiliary_train", split="train"), # 100K rows of multiple choice problems drawn from ARC, MC_TEST, OBQA, RACE
GSM8K(subset="main", split="train"), # 8K rows teaching simple math and (calculator) tool use
GSM8K(subset="main", split="train"), # 2 epochs of GSM8K
CustomJSON(filepath=identity_conversations_filepath), # 1000 rows of synthetic identity conversations
CustomJSON(filepath=identity_conversations_filepath), # let's do 2 epochs of these
CustomJSON(filepath=identity_conversations_filepath), # 2 epochs of these
*[MMLU(subset="auxiliary_train", split="train") for _ in range(args.mmlu_epochs)], # 100K rows per epoch
*[GSM8K(subset="main", split="train") for _ in range(args.gsm8k_epochs)], # 8K rows per epoch
SimpleSpelling(size=200000, split="train"), # 200K rows of Simple Spelling (e.g. spell the word 'apple')
SpellingBee(size=80000, split="train"), # 80K rows of Spelling Bee (e.g. how many 'r' are in 'strawberry'?)
]) # total: 460K + 100K + 16K + 200K + 80K = 856K rows
]
train_dataset = TaskMixture(train_tasks)
print0(f"Training mixture: {len(train_dataset):,} rows (MMLU x{args.mmlu_epochs}, GSM8K x{args.gsm8k_epochs})")
val_dataset = TaskMixture([
SmolTalk(split="test"), # 24K rows in test set
MMLU(subset="all", split="test", stop=5200), # 14K rows in test set, use only 5.2K to match the train ratios
@ -236,10 +292,17 @@ train_loader = sft_data_generator_bos_bestfit("train")
build_val_loader = lambda: sft_data_generator_bos_bestfit("val")
progress = 0 # will go from 0 to 1 over the course of the epoch
# Learning rate scheduler
# Learning rate schedule (linear warmup, constant, linear warmdown)
# Same shape as base_train but uses progress (0→1) instead of absolute step counts,
# because SFT doesn't always know num_iterations in advance (dataset-driven stopping).
def get_lr_multiplier(progress):
# first 80% of training: no decay, then linearly ramp down to 0.
return 1 if progress < 0.8 else 1 - (progress - 0.8) / 0.2
if progress < args.warmup_ratio:
return (progress + 1e-8) / args.warmup_ratio
elif progress <= 1.0 - args.warmdown_ratio:
return 1.0
else:
decay = (progress - (1.0 - args.warmdown_ratio)) / args.warmdown_ratio
return (1 - decay) * 1.0 + decay * args.final_lr_frac
# Momentum scheduler for Muon optimizer
def get_muon_momentum(it):
@ -282,8 +345,44 @@ while True:
})
model.train()
# save checkpoint at the end of the run (only on master process)
if master_process and last_step and not args.dry_run:
# once in a while: estimate the ChatCORE metric (all ranks participate)
# use the original uncompiled model because the inputs keep changing shape
chatcore_results = {}
if args.chatcore_every > 0 and (last_step or (step > 0 and step % args.chatcore_every == 0)):
model.eval()
engine = Engine(orig_model, tokenizer)
all_tasks = ['ARC-Easy', 'ARC-Challenge', 'MMLU', 'GSM8K', 'HumanEval', 'SpellingBee']
categorical_tasks = {'ARC-Easy', 'ARC-Challenge', 'MMLU'}
baseline_accuracies = {
'ARC-Easy': 0.25, 'ARC-Challenge': 0.25, 'MMLU': 0.25,
'GSM8K': 0.0, 'HumanEval': 0.0, 'SpellingBee': 0.0,
}
task_results = {}
for task_name in all_tasks:
limit = args.chatcore_max_cat if task_name in categorical_tasks else args.chatcore_max_sample
max_problems = None if limit < 0 else limit # -1 means no limit
with autocast_ctx:
acc = run_chat_eval(task_name, orig_model, tokenizer, engine,
batch_size=args.device_batch_size, max_problems=max_problems)
task_results[task_name] = acc
print0(f" {task_name}: {100*acc:.2f}%")
# Compute ChatCORE metrics (mean centered accuracy, ranges from 0=random to 1=perfect)
def centered_mean(tasks):
return sum((task_results[t] - baseline_accuracies[t]) / (1.0 - baseline_accuracies[t]) for t in tasks) / len(tasks)
chatcore = centered_mean(all_tasks)
chatcore_cat = centered_mean(categorical_tasks)
print0(f"Step {step:05d} | ChatCORE: {chatcore:.4f} | ChatCORE_cat: {chatcore_cat:.4f}")
wandb_run.log({
"step": step,
"total_training_flops": flops_so_far,
"chatcore_metric": chatcore,
"chatcore_cat": chatcore_cat,
**{f"chatcore/{task_name}": acc for task_name, acc in task_results.items()},
})
model.train()
# save checkpoint at the end of the run (all ranks participate so each saves its optimizer shard)
if last_step:
output_dirname = args.model_tag if args.model_tag else f"d{depth}" # e.g. d12
checkpoint_dir = os.path.join(base_dir, "chatsft_checkpoints", output_dirname)
save_checkpoint(
@ -304,7 +403,8 @@ while True:
"window_pattern": model.config.window_pattern,
},
"user_config": user_config, # inputs to the training script
}
},
rank=ddp_rank,
)
if last_step:
@ -346,8 +446,7 @@ while True:
pct_done = 100 * progress
tok_per_sec = int(args.total_batch_size / dt)
flops_per_sec = num_flops_per_token * args.total_batch_size / dt
promised_flops_per_sec_h100 = 989e12 * ddp_world_size # bfloat16 H100 SXM and without 2:4 sparsity
mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in %
mfu = 100 * flops_per_sec / (gpu_peak_flops * ddp_world_size)
if step > 10:
total_training_time += dt # only count the time after the first 10 steps
print0(f"step {step:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | epoch: {current_epoch} | total time: {total_training_time/60:.2f}m")
@ -364,24 +463,32 @@ while True:
"train/epoch": current_epoch,
})
# The garbage collector spends ~500ms scanning for cycles quite frequently.
# We manually manage it to avoid these pauses during training.
if step == 1:
gc.collect() # manually collect a lot of garbage from setup
gc.freeze() # freeze all currently surviving objects and exclude them from GC
gc.disable() # disable GC entirely except:
elif step % 5000 == 0: # every 5000 steps...
gc.collect() # manually collect, just to be safe for very long runs
# print a few more stats
print0(f"Peak memory usage: {get_max_memory() / 1024 / 1024:.2f}MiB")
print0(f"Total training time: {total_training_time/60:.2f}m")
print0(f"Minimum validation bpb: {min_val_bpb:.4f}")
# Log to report
if not args.dry_run:
from nanochat.report import get_report
get_report().log(section="SFT", data=[
user_config, # CLI args
{ # stats about the training setup
"Number of iterations": step,
"DDP world size": ddp_world_size,
},
{ # stats about training outcomes
"Minimum validation bpb": min_val_bpb,
}
])
from nanochat.report import get_report
get_report().log(section="SFT", data=[
user_config, # CLI args
{ # stats about the training setup
"Number of iterations": step,
"DDP world size": ddp_world_size,
},
{ # stats about training outcomes
"Minimum validation bpb": min_val_bpb,
}
])
# cleanup
wandb_run.finish() # wandb run finish

View File

@ -31,7 +31,7 @@ class MockModel:
def __init__(self, vocab_size=262): # 256 bytes + 6 special tokens
self.vocab_size = vocab_size
self.config = MockConfig()
self._device = "cpu"
self._device = torch.device("cpu")
def get_device(self):
return self._device

11
uv.lock
View File

@ -1509,7 +1509,6 @@ dependencies = [
{ name = "torch", version = "2.9.1", source = { registry = "https://pypi.org/simple" }, marker = "(extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu')" },
{ name = "torch", version = "2.9.1+cpu", source = { registry = "https://download.pytorch.org/whl/cpu" }, marker = "(sys_platform != 'darwin' and extra == 'extra-8-nanochat-cpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "torch", version = "2.9.1+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "extra == 'extra-8-nanochat-gpu'" },
{ name = "torchao" },
{ name = "transformers" },
{ name = "uvicorn" },
{ name = "wandb" },
@ -1549,7 +1548,6 @@ requires-dist = [
{ name = "torch", specifier = "==2.9.1" },
{ name = "torch", marker = "extra == 'cpu'", specifier = "==2.9.1", index = "https://download.pytorch.org/whl/cpu", conflict = { package = "nanochat", extra = "cpu" } },
{ name = "torch", marker = "extra == 'gpu'", specifier = "==2.9.1", index = "https://download.pytorch.org/whl/cu128", conflict = { package = "nanochat", extra = "gpu" } },
{ name = "torchao", specifier = "==0.15.0" },
{ name = "transformers", specifier = ">=4.57.3" },
{ name = "uvicorn", specifier = ">=0.36.0" },
{ name = "wandb", specifier = ">=0.21.3" },
@ -3184,15 +3182,6 @@ wheels = [
{ url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp314-cp314t-win_amd64.whl", hash = "sha256:0c784b600959ec70ee01cb23e8bc870a0e0475af30378ff5e39f4abed8b7c1cc" },
]
[[package]]
name = "torchao"
version = "0.15.0"
source = { registry = "https://pypi.org/simple" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/57/2d/472b9362dceae05a4599e2b94f86e69a29c0e20964a6af84f34f6ead5938/torchao-0.15.0-cp310-abi3-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:1cbe813201314ba6329a650a76944502f3e8ec4b1b44523f3f48676810d8d1f6", size = 7163930, upload-time = "2025-12-18T23:14:41.876Z" },
{ url = "https://files.pythonhosted.org/packages/f6/3b/6b9d5618720f63dbc2e2509cd6b57aae9c0d61b738d1d2172f4d5d9efaab/torchao-0.15.0-py3-none-any.whl", hash = "sha256:3f3812676048ef8a2a0e9d492d12d8971ba7a7ebb16f54aa56f690414e130d2c", size = 1080679, upload-time = "2025-12-18T23:14:43.807Z" },
]
[[package]]
name = "tornado"
version = "6.5.4"