delete torchao dependency, create our own exact API-matched version of Float8Linear, document it very well. for some poorly understood reason, performance is not just ~identical, it actually runs 3% faster, despite being significantly simpler and much less code. i don't fully understand why/how atm

Andrej Karpathy 2026-02-10 18:46:39 +00:00
parent 1ec0a34779
commit e569b59f92
4 changed files with 275 additions and 13 deletions

nanochat/fp8.py Normal file

@@ -0,0 +1,272 @@
"""Minimal FP8 training for nanochat: tensorwise dynamic scaling only.

Drop-in replacement for torchao's Float8Linear (~2000 lines) with ~150 lines.
We only need the "tensorwise" recipe (one scalar scale per tensor), not the full
generality of torchao (rowwise scaling, FSDP float8 all-gather, DTensor, tensor
subclass dispatch tables, etc.)

How FP8 training works
======================
A standard Linear layer does one matmul in forward and two in backward:
    forward:  output      = input @ weight.T
    backward: grad_input  = grad_output @ weight
              grad_weight = grad_output.T @ input

FP8 training wraps each of these three matmuls with:
    1. Compute scale = FP8_MAX / max(|tensor|) for each operand
    2. Quantize: fp8_tensor = clamp(tensor * scale, -FP8_MAX, FP8_MAX).to(fp8)
    3. Matmul via torch._scaled_mm (cuBLAS FP8 kernel, ~2x faster than bf16)
    4. Dequantize: _scaled_mm handles this internally using the inverse scales

The key insight: torch._scaled_mm and the float8 dtypes are PyTorch built-ins.
torchao is just orchestration around these primitives. We can call them directly.

FP8 dtype choice
================
There are two FP8 formats. We use both, following the standard convention:
    - float8_e4m3fn: 4-bit exponent, 3-bit mantissa, range [-448, 448]
      Higher precision (more mantissa bits), used for input and weight.
    - float8_e5m2: 5-bit exponent, 2-bit mantissa, range [-57344, 57344]
      Wider range (more exponent bits), used for gradients which can be large.

torch._scaled_mm layout requirements
=====================================
The cuBLAS FP8 kernel requires specific memory layouts:
    - First argument (A): must be row-major (contiguous)
    - Second argument (B): must be column-major (B.t().contiguous().t())
If B is obtained by transposing a contiguous tensor (e.g. weight.t()), it is
already column-major, so no copy is needed. Otherwise we use _to_col_major().

How this differs from torchao's approach
========================================
torchao uses a "tensor subclass" architecture: Float8TrainingTensor is a subclass
of torch.Tensor that bundles FP8 data + scale + metadata. It implements
__torch_dispatch__ with a dispatch table that intercepts every aten op (mm, t,
reshape, clone, ...) and handles it in FP8-aware fashion. When you call
    output = input @ weight.T
the @ operator dispatches to aten.mm, which gets intercepted and routed to
torch._scaled_mm behind the scenes. This is ~2000 lines of code because you need
a handler for every tensor operation that might touch an FP8 tensor.

We take a simpler approach: a single autograd.Function (_Float8Matmul) that takes
full-precision inputs, quantizes to FP8 internally, calls _scaled_mm, and returns
full-precision outputs. Marked @allow_in_graph so torch.compile treats it as one
opaque node rather than trying to trace inside.

The trade-off is in how torch.compile sees the two approaches:
    - torchao: compile decomposes the tensor subclass (via __tensor_flatten__) and
      sees every individual op (amax, scale, cast, _scaled_mm) as separate graph
      nodes. Inductor can fuse these with surrounding operations (e.g. fuse the
      amax computation with the preceding layer's activation function).
    - ours: compile sees a single opaque call. It can optimize everything around
      the FP8 linear (attention, norms, etc.) but cannot fuse across the boundary.

Both call the exact same cuBLAS _scaled_mm kernel; the GPU matmul is identical.
The difference is only in the "glue" ops (amax, scale, cast) which are tiny
compared to the matmul. In practice this means our version is slightly faster
(less compilation overhead, no tensor subclass dispatch cost) but can produce
subtly different floating-point rounding paths under torch.compile, since Inductor
generates a different graph. Numerics are bitwise identical in eager mode.
"""

import torch
import torch.nn as nn

# Avoid division by zero when computing scale from an all-zeros tensor
EPS = 1e-12


@torch.no_grad()
def _to_fp8(x, fp8_dtype):
    """Dynamically quantize a tensor to FP8 using tensorwise scaling.

    "Tensorwise" means one scalar scale for the entire tensor (as opposed to
    "rowwise" which computes a separate scale per row). Tensorwise is faster
    because cuBLAS handles the scaling; rowwise needs the CUTLASS kernel.

    Returns (fp8_data, inverse_scale) for use with torch._scaled_mm.
    """
    fp8_max = torch.finfo(fp8_dtype).max
    # Compute the max absolute value across the entire tensor
    amax = x.float().abs().max()
    # Scale maps [0, amax] -> [0, fp8_max]. Use float64 for the division to
    # ensure consistent numerics between torch.compile and eager mode.
    # (torchao does the same upcast; without it, compile/eager can diverge)
    scale = fp8_max / amax.double().clamp(min=EPS)
    scale = scale.float()
    # Quantize: scale into FP8 range, saturate, then cast to FP8. The clamp
    # matters because PyTorch's float8 cast does not saturate: out-of-range
    # values become NaN (e4m3fn) or inf (e5m2) instead of clipping to fp8_max.
    x_scaled = x.float() * scale
    x_clamped = x_scaled.clamp(-fp8_max, fp8_max)
    x_fp8 = x_clamped.to(fp8_dtype)
    # _scaled_mm expects the *inverse* of our scale (it multiplies by this to
    # convert FP8 values back to the original range during the matmul)
    inv_scale = scale.reciprocal()
    return x_fp8, inv_scale


def _to_col_major(x):
    """Rearrange a 2D tensor's memory to column-major layout.

    torch._scaled_mm requires its second operand in column-major layout.
    The trick: transpose -> contiguous (forces a copy in transposed order)
    -> transpose back. The result has the same logical shape but column-major
    strides, e.g. a [M, N] tensor gets strides (1, M) instead of (N, 1).
    """
    return x.t().contiguous().t()


# allow_in_graph tells torch.compile to treat this as an opaque operation:
# dynamo won't try to decompose it into smaller ops. See the module docstring
# for how this differs from torchao's tensor subclass approach.
@torch._dynamo.allow_in_graph
class _Float8Matmul(torch.autograd.Function):
    """Custom autograd for the three FP8 GEMMs of a Linear layer.

    The forward saves input and weight in their original precision for the
    backward pass. Each GEMM independently re-quantizes its operands to FP8.
    (We don't reuse the forward's FP8 tensors in backward: the backward might
    want different precision, and saving FP8 would lose information.)
    """

    @staticmethod
    def forward(ctx, input_2d, weight):
        ctx.save_for_backward(input_2d, weight)

        # Quantize both operands to e4m3 (higher precision format)
        input_fp8, input_inv = _to_fp8(input_2d, torch.float8_e4m3fn)
        weight_fp8, weight_inv = _to_fp8(weight, torch.float8_e4m3fn)

        # output = input @ weight.T
        # input_fp8 is [B, K] contiguous = row-major (good for first arg)
        # weight_fp8 is [N, K] contiguous, so weight_fp8.t() is [K, N] with
        # strides (1, K) = column-major (good for second arg, no copy needed!)
        output = torch._scaled_mm(
            input_fp8,
            weight_fp8.t(),
            scale_a=input_inv,
            scale_b=weight_inv,
            out_dtype=input_2d.dtype,
            # use_fast_accum=True accumulates the dot products in lower precision.
            # Slightly less accurate but measurably faster. Standard practice for
            # the forward pass; we use False in backward for more precise gradients.
            use_fast_accum=True,
        )
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input_2d, weight = ctx.saved_tensors

        # === GEMM 1: grad_input = grad_output @ weight ===
        # Shapes: [B, N] @ [N, K] -> [B, K]
        # Gradients use e5m2 (wider range), weights use e4m3 (higher precision)
        go_fp8, go_inv = _to_fp8(grad_output, torch.float8_e5m2)
        w_fp8, w_inv = _to_fp8(weight, torch.float8_e4m3fn)
        # go_fp8 is [B, N] contiguous = row-major, good for first arg
        # w_fp8 is [N, K] contiguous = row-major, need column-major for second arg
        w_col = _to_col_major(w_fp8)
        grad_input = torch._scaled_mm(
            go_fp8,
            w_col,
            scale_a=go_inv,
            scale_b=w_inv,
            out_dtype=grad_output.dtype,
            use_fast_accum=False,
        )

        # === GEMM 2: grad_weight = grad_output.T @ input ===
        # Shapes: [N, B] @ [B, K] -> [N, K]
        go_fp8_2, go_inv_2 = _to_fp8(grad_output, torch.float8_e5m2)
        in_fp8, in_inv = _to_fp8(input_2d, torch.float8_e4m3fn)
        # go_fp8_2 is [B, N] contiguous, we need go.T = [N, B] as first arg.
        # Transposing gives column-major, but first arg needs row-major,
        # so we must call .contiguous() to physically rearrange the memory.
        go_T = go_fp8_2.t().contiguous()  # [N, B] row-major
        in_col = _to_col_major(in_fp8)  # [B, K] column-major
        grad_weight = torch._scaled_mm(
            go_T,
            in_col,
            scale_a=go_inv_2,
            scale_b=in_inv,
            out_dtype=grad_output.dtype,
            use_fast_accum=False,
        )

        return grad_input, grad_weight


class Float8Linear(nn.Linear):
    """Drop-in nn.Linear replacement that does FP8 compute.

    Weights and biases remain in their original precision (e.g. fp32/bf16).
    Only the matmul is performed in FP8 via the _Float8Matmul autograd function.
    """

    def forward(self, input):
        # Replicate the autocast behavior of F.linear: when autocast is active,
        # we need to manually cast input to the autocast dtype (e.g. bf16),
        # since we bypass F.linear's built-in autocast handling.
        if torch.is_autocast_enabled():
            input = input.to(torch.get_autocast_gpu_dtype())
        # _scaled_mm only works on 2D tensors, so flatten batch dimensions
        orig_shape = input.shape
        input_2d = input.reshape(-1, orig_shape[-1])
        output = _Float8Matmul.apply(input_2d, self.weight)
        output = output.reshape(*orig_shape[:-1], output.shape[-1])
        if self.bias is not None:
            output = output + self.bias.to(output.dtype)
        return output

    @classmethod
    def from_float(cls, mod):
        """Create Float8Linear from nn.Linear, sharing the same weight and bias.

        Uses meta device to avoid allocating a temporary weight tensor: we
        create the module shell on meta (shapes/dtypes only, no memory), then
        point .weight and .bias to the original module's parameters.
        """
        with torch.device("meta"):
            new_mod = cls(mod.in_features, mod.out_features, bias=False)
        new_mod.weight = mod.weight
        new_mod.bias = mod.bias
        return new_mod


class Float8LinearConfig:
    """Minimal config matching torchao's API. Only tensorwise recipe is supported."""

    @staticmethod
    def from_recipe_name(recipe_name):
        if recipe_name != "tensorwise":
            raise ValueError(
                f"Only 'tensorwise' recipe is supported, got '{recipe_name}'. "
                f"Rowwise/axiswise recipes require the full torchao library."
            )
        return Float8LinearConfig()


def convert_to_float8_training(module, *, config=None, module_filter_fn=None):
    """Replace nn.Linear layers with Float8Linear throughout a module.

    Walks the module tree in post-order (children before parents) and swaps
    each nn.Linear that passes the optional filter. The new Float8Linear shares
    the original weight and bias tensors: no copies, no extra memory.

    Args:
        module: Root module to convert.
        config: Float8LinearConfig (accepted for API compat, only tensorwise supported).
        module_filter_fn: Optional filter(module, fqn) -> bool. Only matching Linears
            are converted. Common use: skip layers with dims not divisible by 16
            (hardware requirement for FP8 matmuls on H100).
    """
    def _convert(mod, prefix=""):
        for name, child in mod.named_children():
            fqn = f"{prefix}.{name}" if prefix else name
            _convert(child, fqn)
            if isinstance(child, nn.Linear) and not isinstance(child, Float8Linear):
                if module_filter_fn is None or module_filter_fn(child, fqn):
                    setattr(mod, name, Float8Linear.from_float(child))

    _convert(module)
    return module
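
The four quantization steps listed in the fp8.py module docstring can be tried out without a GPU. What follows is a minimal pure-Python sketch, not code from this commit: it simulates e4m3 rounding by hand rather than calling torch, the helper names `round_to_e4m3` and `quantize_tensorwise` are hypothetical, and it assumes round-to-nearest-even with a smallest step of 2**-9, so details may differ from what `torch._scaled_mm` actually produces.

```python
import math

FP8_E4M3_MAX = 448.0  # largest finite float8_e4m3fn value
EPS = 1e-12           # same guard against all-zero tensors as in fp8.py


def round_to_e4m3(x):
    """Round a float to the nearest e4m3 value (simulated, hypothetical helper)."""
    if x == 0.0:
        return 0.0
    sign = -1.0 if x < 0 else 1.0
    a = min(abs(x), FP8_E4M3_MAX)  # saturate instead of overflowing
    _, e = math.frexp(a)           # a = m * 2**e with 0.5 <= m < 1
    e = max(e, -5)                 # below 2**-6 the step stays fixed (subnormals)
    step = 2.0 ** (e - 4)          # 1 implicit + 3 mantissa bits = 4 significant bits
    return sign * round(a / step) * step


def quantize_tensorwise(values):
    """Return (simulated fp8 values, inverse scale) using one scale for all values."""
    amax = max(abs(v) for v in values)
    scale = FP8_E4M3_MAX / max(amax, EPS)
    quantized = [round_to_e4m3(v * scale) for v in values]
    return quantized, 1.0 / scale


values = [0.003, -0.12, 0.5, -1.7]
quantized, inv_scale = quantize_tensorwise(values)
# Dequantize by multiplying with the inverse scale, as _scaled_mm does internally
dequantized = [q * inv_scale for q in quantized]
for v, q, d in zip(values, quantized, dequantized):
    print(f"{v:+.4f} -> fp8 {q:+8.3f} -> {d:+.4f}")
```

The largest element maps onto the top of the FP8 range and round-trips almost exactly, while smaller elements pick up a relative error of at most about 1/16, which is the price of a single scale per tensor.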

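
The layout rule from the fp8.py docstring (transposing a contiguous tensor already yields column-major strides) can be checked with plain stride arithmetic. This is a small sketch that assumes element strides in the form PyTorch reports them; the helper names are hypothetical and no tensors are allocated.

```python
def row_major_strides(shape):
    """Element strides of a contiguous (row-major) 2D array."""
    rows, cols = shape
    return (cols, 1)


def transpose(shape, strides):
    """A transpose is a view: swap shape and strides, move no data."""
    return (shape[1], shape[0]), (strides[1], strides[0])


def is_col_major(shape, strides):
    """Column-major means walking down a column is the unit-stride direction."""
    rows, cols = shape
    return strides == (1, rows)


# weight is [N, K] contiguous, as in the forward pass
N, K = 4, 8
w_shape, w_strides = (N, K), row_major_strides((N, K))

# weight.t() is [K, N] with strides (1, K): already column-major, no copy
wt_shape, wt_strides = transpose(w_shape, w_strides)
print(wt_shape, wt_strides, is_col_major(wt_shape, wt_strides))

# The backward pass needs weight itself ([N, K]) as the second operand.
# Its strides (K, 1) are row-major, hence the explicit _to_col_major copy,
# which produces strides (1, N) for the same logical shape.
print(is_col_major(w_shape, w_strides), is_col_major(w_shape, (1, N)))
```

This is why the forward pass can pass `weight_fp8.t()` directly, while both backward GEMMs pay for a physical rearrangement of their second operand.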

@@ -20,7 +20,6 @@ dependencies = [
     "tiktoken>=0.11.0",
     "tokenizers>=0.22.0",
     "torch==2.9.1",
-    "torchao==0.15.0",
     "transformers>=4.57.3",
     "uvicorn>=0.36.0",
     "wandb>=0.21.3",


@@ -165,7 +165,9 @@ if args.fp8:
     if device_type != "cuda":
         print0("Warning: FP8 training requires CUDA, ignoring --fp8 flag")
     else:
-        from torchao.float8 import Float8LinearConfig, convert_to_float8_training
+        # our custom fp8 is simpler than torchao, written for exact API compatibility
+        from nanochat.fp8 import Float8LinearConfig, convert_to_float8_training
+        # from torchao.float8 import Float8LinearConfig, convert_to_float8_training
         import torch.nn as nn
         # Filter: only convert layers with dimensions divisible by 16 (FP8 hardware requirement)
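
The hunk is cut off before the filter itself, so the following is only a hypothetical sketch of what a `module_filter_fn(module, fqn) -> bool` of the kind described in that comment could look like. The function name, the layer names, and the stand-in objects are assumptions for illustration, not the script's actual code.

```python
from types import SimpleNamespace


def dims_divisible_by_16(mod, fqn):
    """Hypothetical filter: accept only layers whose dims are multiples of 16."""
    return mod.in_features % 16 == 0 and mod.out_features % 16 == 0


# Stand-ins exposing the same attributes as nn.Linear, so the sketch runs without torch.
# The fully qualified names below are made up for the example.
candidates = {
    "transformer.h.0.mlp.c_fc": SimpleNamespace(in_features=768, out_features=3072),
    "lm_head": SimpleNamespace(in_features=768, out_features=50257),
}
selected = [fqn for fqn, mod in candidates.items() if dims_divisible_by_16(mod, fqn)]
print(selected)
```

A function of this shape would be passed as `module_filter_fn` to `convert_to_float8_training`, which leaves every layer the filter rejects as a plain `nn.Linear`.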

uv.lock

@@ -1509,7 +1509,6 @@ dependencies = [
{ name = "torch", version = "2.9.1", source = { registry = "https://pypi.org/simple" }, marker = "(extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu')" },
{ name = "torch", version = "2.9.1+cpu", source = { registry = "https://download.pytorch.org/whl/cpu" }, marker = "(sys_platform != 'darwin' and extra == 'extra-8-nanochat-cpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "torch", version = "2.9.1+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "extra == 'extra-8-nanochat-gpu'" },
-{ name = "torchao" },
{ name = "transformers" },
{ name = "uvicorn" },
{ name = "wandb" },
@@ -1549,7 +1548,6 @@ requires-dist = [
{ name = "torch", specifier = "==2.9.1" },
{ name = "torch", marker = "extra == 'cpu'", specifier = "==2.9.1", index = "https://download.pytorch.org/whl/cpu", conflict = { package = "nanochat", extra = "cpu" } },
{ name = "torch", marker = "extra == 'gpu'", specifier = "==2.9.1", index = "https://download.pytorch.org/whl/cu128", conflict = { package = "nanochat", extra = "gpu" } },
-{ name = "torchao", specifier = "==0.15.0" },
{ name = "transformers", specifier = ">=4.57.3" },
{ name = "uvicorn", specifier = ">=0.36.0" },
{ name = "wandb", specifier = ">=0.21.3" },
@@ -3184,15 +3182,6 @@ wheels = [
{ url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp314-cp314t-win_amd64.whl", hash = "sha256:0c784b600959ec70ee01cb23e8bc870a0e0475af30378ff5e39f4abed8b7c1cc" },
]
-[[package]]
-name = "torchao"
-version = "0.15.0"
-source = { registry = "https://pypi.org/simple" }
-wheels = [
-{ url = "https://files.pythonhosted.org/packages/57/2d/472b9362dceae05a4599e2b94f86e69a29c0e20964a6af84f34f6ead5938/torchao-0.15.0-cp310-abi3-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:1cbe813201314ba6329a650a76944502f3e8ec4b1b44523f3f48676810d8d1f6", size = 7163930, upload-time = "2025-12-18T23:14:41.876Z" },
-{ url = "https://files.pythonhosted.org/packages/f6/3b/6b9d5618720f63dbc2e2509cd6b57aae9c0d61b738d1d2172f4d5d9efaab/torchao-0.15.0-py3-none-any.whl", hash = "sha256:3f3812676048ef8a2a0e9d492d12d8971ba7a7ebb16f54aa56f690414e130d2c", size = 1080679, upload-time = "2025-12-18T23:14:43.807Z" },
-]
[[package]]
name = "tornado"
version = "6.5.4"