mirror of
https://github.com/karpathy/nanochat.git
synced 2026-03-07 01:40:30 +00:00
delete torchao dependency, create our own exact API-matched version of Float8Linear, document it very well. for some poorly understood reason, the performance is not only ~identical but actually runs 3% faster, despite it being significantly simpler and much less code. i don't fully understand why/how atm
This commit is contained in:
parent
1ec0a34779
commit
e569b59f92
272
nanochat/fp8.py
Normal file
@@ -0,0 +1,272 @@
"""Minimal FP8 training for nanochat — tensorwise dynamic scaling only.
|
||||
|
||||
Drop-in replacement for torchao's Float8Linear (~2000 lines) with ~150 lines.
|
||||
We only need the "tensorwise" recipe (one scalar scale per tensor), not the full
|
||||
generality of torchao (rowwise scaling, FSDP float8 all-gather, DTensor, tensor
|
||||
subclass dispatch tables, etc.)
|
||||
|
||||
How FP8 training works
|
||||
======================
|
||||
A standard Linear layer does one matmul in forward and two in backward:
|
||||
forward: output = input @ weight.T
|
||||
backward: grad_input = grad_output @ weight
|
||||
grad_weight= grad_output.T @ input
|
||||
|
||||
FP8 training wraps each of these three matmuls with:
|
||||
1. Compute scale = FP8_MAX / max(|tensor|) for each operand
|
||||
2. Quantize: fp8_tensor = clamp(tensor * scale, -FP8_MAX, FP8_MAX).to(fp8)
|
||||
3. Matmul via torch._scaled_mm (cuBLAS FP8 kernel, ~2x faster than bf16)
|
||||
4. Dequantize: _scaled_mm handles this internally using the inverse scales
|
||||
|
||||
The key insight: torch._scaled_mm and the float8 dtypes are PyTorch built-ins.
|
||||
torchao is just orchestration around these primitives. We can call them directly.
|
||||
|
||||
FP8 dtype choice
|
||||
================
|
||||
There are two FP8 formats. We use both, following the standard convention:
|
||||
- float8_e4m3fn: 4-bit exponent, 3-bit mantissa, range [-448, 448]
|
||||
Higher precision (more mantissa bits), used for input and weight.
|
||||
- float8_e5m2: 5-bit exponent, 2-bit mantissa, range [-57344, 57344]
|
||||
Wider range (more exponent bits), used for gradients which can be large.
|
||||
|
||||
torch._scaled_mm layout requirements
|
||||
=====================================
|
||||
The cuBLAS FP8 kernel requires specific memory layouts:
|
||||
- First argument (A): must be row-major (contiguous)
|
||||
- Second argument (B): must be column-major (B.t().contiguous().t())
|
||||
If B is obtained by transposing a contiguous tensor (e.g. weight.t()), it is
|
||||
already column-major — no copy needed. Otherwise we use _to_col_major().
|
||||
|
||||
How this differs from torchao's approach
|
||||
========================================
|
||||
torchao uses a "tensor subclass" architecture: Float8TrainingTensor is a subclass
|
||||
of torch.Tensor that bundles FP8 data + scale + metadata. It implements
|
||||
__torch_dispatch__ with a dispatch table that intercepts every aten op (mm, t,
|
||||
reshape, clone, ...) and handles it in FP8-aware fashion. When you call
|
||||
output = input @ weight.T
|
||||
the @ operator dispatches to aten.mm, which gets intercepted and routed to
|
||||
torch._scaled_mm behind the scenes. This is ~2000 lines of code because you need
|
||||
a handler for every tensor operation that might touch an FP8 tensor.
|
||||
|
||||
We take a simpler approach: a single autograd.Function (_Float8Matmul) that takes
|
||||
full-precision inputs, quantizes to FP8 internally, calls _scaled_mm, and returns
|
||||
full-precision outputs. Marked @allow_in_graph so torch.compile treats it as one
|
||||
opaque node rather than trying to trace inside.
|
||||
|
||||
The trade-off is in how torch.compile sees the two approaches:
|
||||
- torchao: compile decomposes the tensor subclass (via __tensor_flatten__) and
|
||||
sees every individual op (amax, scale, cast, _scaled_mm) as separate graph
|
||||
nodes. Inductor can fuse these with surrounding operations (e.g. fuse the
|
||||
amax computation with the preceding layer's activation function).
|
||||
- ours: compile sees a single opaque call. It can optimize everything around
|
||||
the FP8 linear (attention, norms, etc.) but cannot fuse across the boundary.
|
||||
|
||||
Both call the exact same cuBLAS _scaled_mm kernel — the GPU matmul is identical.
|
||||
The difference is only in the "glue" ops (amax, scale, cast) which are tiny
|
||||
compared to the matmul. In practice this means our version is slightly faster
|
||||
(less compilation overhead, no tensor subclass dispatch cost) but can produce
|
||||
subtly different floating-point rounding paths under torch.compile, since Inductor
|
||||
generates a different graph. Numerics are bitwise identical in eager mode.
|
||||
"""

import torch
import torch.nn as nn

# Avoid division by zero when computing scale from an all-zeros tensor
EPS = 1e-12


@torch.no_grad()
def _to_fp8(x, fp8_dtype):
    """Dynamically quantize a tensor to FP8 using tensorwise scaling.

    "Tensorwise" means one scalar scale for the entire tensor (as opposed to
    "rowwise" which computes a separate scale per row). Tensorwise is faster
    because cuBLAS handles the scaling; rowwise needs the CUTLASS kernel.

    Returns (fp8_data, inverse_scale) for use with torch._scaled_mm.
    """
    fp8_max = torch.finfo(fp8_dtype).max
    # Compute the max absolute value across the entire tensor
    amax = x.float().abs().max()
    # Scale maps [0, amax] -> [0, fp8_max]. Use float64 for the division to
    # ensure consistent numerics between torch.compile and eager mode.
    # (torchao does the same upcast — without it, compile/eager can diverge)
    scale = fp8_max / amax.double().clamp(min=EPS)
    scale = scale.float()
    # Quantize: scale into FP8 range, saturate (clamp prevents overflow when
    # casting — PyTorch's default is to wrap, not saturate), then cast to FP8
    x_scaled = x.float() * scale
    x_clamped = x_scaled.clamp(-fp8_max, fp8_max)
    x_fp8 = x_clamped.to(fp8_dtype)
    # _scaled_mm expects the *inverse* of our scale (it multiplies by this to
    # convert FP8 values back to the original range during the matmul)
    inv_scale = scale.reciprocal()
    return x_fp8, inv_scale


def _to_col_major(x):
    """Rearrange a 2D tensor's memory to column-major layout.

    torch._scaled_mm requires its second operand in column-major layout.
    The trick: transpose -> contiguous (forces a copy in transposed order)
    -> transpose back. The result has the same logical shape but column-major
    strides, e.g. a [M, N] tensor gets strides (1, M) instead of (N, 1).
    """
    return x.t().contiguous().t()
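The stride bookkeeping behind `_to_col_major` can be illustrated without torch. The sketch below models a 2D tensor as flat storage plus strides; `flat_index` is a hypothetical helper, not part of nanochat. It shows the same logical element living at different physical offsets under row-major strides (N, 1) versus column-major strides (1, M) — exactly the layout change `x.t().contiguous().t()` produces.

```python
# Illustrative model of memory layout: a 2D "tensor" as flat storage + strides.
def flat_index(i, j, strides):
    # Physical offset of logical element (i, j) given per-dimension strides
    return i * strides[0] + j * strides[1]

M, N = 3, 4
row_major = (N, 1)   # element (i, j) lives at i*N + j
col_major = (1, M)   # element (i, j) lives at i + j*M

# Store the logical matrix [[0..3], [4..7], [8..11]] in both layouts
logical = [[i * N + j for j in range(N)] for i in range(M)]
rm_storage = [0] * (M * N)
cm_storage = [0] * (M * N)
for i in range(M):
    for j in range(N):
        rm_storage[flat_index(i, j, row_major)] = logical[i][j]
        cm_storage[flat_index(i, j, col_major)] = logical[i][j]

# Same logical element (1, 2), different physical locations in each layout
assert rm_storage[flat_index(1, 2, row_major)] == 6
assert cm_storage[flat_index(1, 2, col_major)] == 6
```

This is also why transposing a contiguous [N, K] weight yields a column-major [K, N] view for free: the strides swap, no data moves.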


# allow_in_graph tells torch.compile to treat this as an opaque operation —
# dynamo won't try to decompose it into smaller ops. See the module docstring
# for how this differs from torchao's tensor subclass approach.
@torch._dynamo.allow_in_graph
class _Float8Matmul(torch.autograd.Function):
    """Custom autograd for the three FP8 GEMMs of a Linear layer.

    The forward saves input and weight in their original precision for the
    backward pass. Each GEMM independently re-quantizes its operands to FP8.
    (We don't reuse the forward's FP8 tensors in backward — the backward might
    want different precision, and saving FP8 would lose information.)
    """

    @staticmethod
    def forward(ctx, input_2d, weight):
        ctx.save_for_backward(input_2d, weight)

        # Quantize both operands to e4m3 (higher precision format)
        input_fp8, input_inv = _to_fp8(input_2d, torch.float8_e4m3fn)
        weight_fp8, weight_inv = _to_fp8(weight, torch.float8_e4m3fn)

        # output = input @ weight.T
        # input_fp8 is [B, K] contiguous = row-major (good for first arg)
        # weight_fp8 is [N, K] contiguous, so weight_fp8.t() is [K, N] with
        # strides (1, K) = column-major (good for second arg, no copy needed!)
        output = torch._scaled_mm(
            input_fp8,
            weight_fp8.t(),
            scale_a=input_inv,
            scale_b=weight_inv,
            out_dtype=input_2d.dtype,
            # use_fast_accum=True accumulates the dot products in lower precision.
            # Slightly less accurate but measurably faster. Standard practice for
            # the forward pass; we use False in backward for more precise gradients.
            use_fast_accum=True,
        )
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input_2d, weight = ctx.saved_tensors

        # === GEMM 1: grad_input = grad_output @ weight ===
        # Shapes: [B, N] @ [N, K] -> [B, K]
        # Gradients use e5m2 (wider range), weights use e4m3 (higher precision)
        go_fp8, go_inv = _to_fp8(grad_output, torch.float8_e5m2)
        w_fp8, w_inv = _to_fp8(weight, torch.float8_e4m3fn)
        # go_fp8 is [B, N] contiguous = row-major, good for first arg
        # w_fp8 is [N, K] contiguous = row-major, need column-major for second arg
        w_col = _to_col_major(w_fp8)
        grad_input = torch._scaled_mm(
            go_fp8,
            w_col,
            scale_a=go_inv,
            scale_b=w_inv,
            out_dtype=grad_output.dtype,
            use_fast_accum=False,
        )

        # === GEMM 2: grad_weight = grad_output.T @ input ===
        # Shapes: [N, B] @ [B, K] -> [N, K]
        go_fp8_2, go_inv_2 = _to_fp8(grad_output, torch.float8_e5m2)
        in_fp8, in_inv = _to_fp8(input_2d, torch.float8_e4m3fn)
        # go_fp8_2 is [B, N] contiguous, we need go.T = [N, B] as first arg.
        # Transposing gives column-major, but first arg needs row-major,
        # so we must call .contiguous() to physically rearrange the memory.
        go_T = go_fp8_2.t().contiguous()  # [N, B] row-major
        in_col = _to_col_major(in_fp8)    # [B, K] column-major
        grad_weight = torch._scaled_mm(
            go_T,
            in_col,
            scale_a=go_inv_2,
            scale_b=in_inv,
            out_dtype=grad_output.dtype,
            use_fast_accum=False,
        )

        return grad_input, grad_weight
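The two backward formulas used above can be sanity-checked numerically in plain Python, in full precision and without FP8. This is an illustrative sketch; the list-based `matmul`/`transpose` helpers and variable names are not nanochat code. For loss = sum(input @ weight.T), the gradient of the loss w.r.t. the output is all ones, and the closed-form gradients should match a finite-difference probe.

```python
def matmul(a, b):
    # Naive list-of-lists matrix multiply, [m, k] @ [k, n] -> [m, n]
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]

def transpose(a):
    return [list(col) for col in zip(*a)]

inp = [[0.5, -1.0], [2.0, 0.25]]            # [B=2, K=2]
w = [[1.5, -0.5], [0.75, 2.0], [1.0, 1.0]]  # [N=3, K=2]

def loss(x, wt):
    # Forward of a Linear layer followed by sum(): output = x @ wt.T
    return sum(sum(row) for row in matmul(x, transpose(wt)))

go = [[1.0] * 3 for _ in range(2)]          # d(sum)/d(output): ones, [B, N]
grad_input = matmul(go, w)                  # grad_output @ weight   -> [B, K]
grad_weight = matmul(transpose(go), inp)    # grad_output.T @ input  -> [N, K]

# Finite-difference check on one input entry
eps = 1e-6
bumped = [row[:] for row in inp]
bumped[0][0] += eps
fd = (loss(bumped, w) - loss(inp, w)) / eps
assert abs(fd - grad_input[0][0]) < 1e-4
```

The same probe on a weight entry confirms `grad_weight`; in the FP8 path these two matmuls are simply run through `_scaled_mm` with e5m2 gradients.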


class Float8Linear(nn.Linear):
    """Drop-in nn.Linear replacement that does FP8 compute.

    Weights and biases remain in their original precision (e.g. fp32/bf16).
    Only the matmul is performed in FP8 via the _Float8Matmul autograd function.
    """

    def forward(self, input):
        # Replicate the autocast behavior of F.linear — when autocast is active,
        # we need to manually cast input to the autocast dtype (e.g. bf16),
        # since we bypass F.linear's built-in autocast handling.
        if torch.is_autocast_enabled():
            input = input.to(torch.get_autocast_gpu_dtype())
        # _scaled_mm only works on 2D tensors, so flatten batch dimensions
        orig_shape = input.shape
        input_2d = input.reshape(-1, orig_shape[-1])
        output = _Float8Matmul.apply(input_2d, self.weight)
        output = output.reshape(*orig_shape[:-1], output.shape[-1])
        if self.bias is not None:
            output = output + self.bias.to(output.dtype)
        return output
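The flatten/restore bookkeeping in forward can be sketched with plain shape arithmetic, no torch needed. `linear_shapes` below is a hypothetical helper tracing what the reshapes do to a [B, T, C] activation.

```python
import math

def linear_shapes(orig_shape, in_features, out_features):
    """Trace the reshape round-trip of Float8Linear.forward (illustrative)."""
    assert orig_shape[-1] == in_features
    rows = math.prod(orig_shape[:-1])       # collapse all leading batch dims
    shape_2d = (rows, in_features)          # what _scaled_mm actually sees
    out_2d = (rows, out_features)           # 2D matmul result
    restored = orig_shape[:-1] + (out_features,)  # batch dims put back
    return shape_2d, out_2d, restored

# e.g. a [batch=2, seq=5, d_model=8] activation through an 8 -> 16 linear
shape_2d, out_2d, restored = linear_shapes((2, 5, 8), 8, 16)
```

Only the last dimension changes end to end; everything in between is 2D because cuBLAS FP8 GEMMs operate on matrices.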

    @classmethod
    def from_float(cls, mod):
        """Create Float8Linear from nn.Linear, sharing the same weight and bias.

        Uses meta device to avoid allocating a temporary weight tensor — we
        create the module shell on meta (shapes/dtypes only, no memory), then
        point .weight and .bias to the original module's parameters.
        """
        with torch.device("meta"):
            new_mod = cls(mod.in_features, mod.out_features, bias=False)
        new_mod.weight = mod.weight
        new_mod.bias = mod.bias
        return new_mod


class Float8LinearConfig:
    """Minimal config matching torchao's API. Only tensorwise recipe is supported."""

    @staticmethod
    def from_recipe_name(recipe_name):
        if recipe_name != "tensorwise":
            raise ValueError(
                f"Only 'tensorwise' recipe is supported, got '{recipe_name}'. "
                f"Rowwise/axiswise recipes require the full torchao library."
            )
        return Float8LinearConfig()


def convert_to_float8_training(module, *, config=None, module_filter_fn=None):
    """Replace nn.Linear layers with Float8Linear throughout a module.

    Walks the module tree in post-order (children before parents) and swaps
    each nn.Linear that passes the optional filter. The new Float8Linear shares
    the original weight and bias tensors — no copies, no extra memory.

    Args:
        module: Root module to convert.
        config: Float8LinearConfig (accepted for API compat, only tensorwise supported).
        module_filter_fn: Optional filter(module, fqn) -> bool. Only matching Linears
            are converted. Common use: skip layers with dims not divisible by 16
            (hardware requirement for FP8 matmuls on H100).
    """
    def _convert(mod, prefix=""):
        for name, child in mod.named_children():
            fqn = f"{prefix}.{name}" if prefix else name
            _convert(child, fqn)
            if isinstance(child, nn.Linear) and not isinstance(child, Float8Linear):
                if module_filter_fn is None or module_filter_fn(child, fqn):
                    setattr(mod, name, Float8Linear.from_float(child))

    _convert(module)
    return module
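The post-order walk and filtered swap can be mimicked with toy stand-ins for nn.Module. The `Toy*` classes below are hypothetical, not nanochat code; they only reproduce the `named_children()` recursion and the "children before parents" replacement order.

```python
class ToyModule:
    """Minimal stand-in for nn.Module: named children + attribute access."""
    def __init__(self, **children):
        self._children = dict(children)
        for k, v in children.items():
            setattr(self, k, v)
    def named_children(self):
        return list(self._children.items())
    def replace(self, name, new):
        self._children[name] = new
        setattr(self, name, new)

class ToyLinear(ToyModule):
    def __init__(self, n):
        super().__init__()
        self.n = n  # stands in for in/out features

class ToyFp8Linear(ToyLinear):
    pass

def convert(module, filter_fn=None):
    # Same shape as convert_to_float8_training: post-order walk, filtered swap
    def _convert(mod, prefix=""):
        for name, child in mod.named_children():
            fqn = f"{prefix}.{name}" if prefix else name
            _convert(child, fqn)
            if isinstance(child, ToyLinear) and not isinstance(child, ToyFp8Linear):
                if filter_fn is None or filter_fn(child, fqn):
                    mod.replace(name, ToyFp8Linear(child.n))
    _convert(module)
    return module

# Only convert layers whose dim is divisible by 16 (mirrors the FP8 filter)
model = ToyModule(block=ToyModule(fc=ToyLinear(32), head=ToyLinear(10)))
convert(model, filter_fn=lambda m, fqn: m.n % 16 == 0)
```

After conversion, `model.block.fc` has been swapped while `model.block.head` (dim 10, not divisible by 16) is left untouched, matching the usage described in the docstring.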

@@ -20,7 +20,6 @@ dependencies = [
"tiktoken>=0.11.0",
|
||||
"tokenizers>=0.22.0",
|
||||
"torch==2.9.1",
|
||||
"torchao==0.15.0",
|
||||
"transformers>=4.57.3",
|
||||
"uvicorn>=0.36.0",
|
||||
"wandb>=0.21.3",

@@ -165,7 +165,9 @@ if args.fp8:
     if device_type != "cuda":
         print0("Warning: FP8 training requires CUDA, ignoring --fp8 flag")
     else:
-        from torchao.float8 import Float8LinearConfig, convert_to_float8_training
+        # our custom fp8 is simpler than torchao, written for exact API compatibility
+        from nanochat.fp8 import Float8LinearConfig, convert_to_float8_training
+        # from torchao.float8 import Float8LinearConfig, convert_to_float8_training
         import torch.nn as nn

         # Filter: only convert layers with dimensions divisible by 16 (FP8 hardware requirement)

11
uv.lock
@@ -1509,7 +1509,6 @@ dependencies = [
{ name = "torch", version = "2.9.1", source = { registry = "https://pypi.org/simple" }, marker = "(extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu')" },
|
||||
{ name = "torch", version = "2.9.1+cpu", source = { registry = "https://download.pytorch.org/whl/cpu" }, marker = "(sys_platform != 'darwin' and extra == 'extra-8-nanochat-cpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
|
||||
{ name = "torch", version = "2.9.1+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "extra == 'extra-8-nanochat-gpu'" },
|
||||
{ name = "torchao" },
|
||||
{ name = "transformers" },
|
||||
{ name = "uvicorn" },
|
||||
{ name = "wandb" },

@@ -1549,7 +1548,6 @@ requires-dist = [
{ name = "torch", specifier = "==2.9.1" },
|
||||
{ name = "torch", marker = "extra == 'cpu'", specifier = "==2.9.1", index = "https://download.pytorch.org/whl/cpu", conflict = { package = "nanochat", extra = "cpu" } },
|
||||
{ name = "torch", marker = "extra == 'gpu'", specifier = "==2.9.1", index = "https://download.pytorch.org/whl/cu128", conflict = { package = "nanochat", extra = "gpu" } },
|
||||
{ name = "torchao", specifier = "==0.15.0" },
|
||||
{ name = "transformers", specifier = ">=4.57.3" },
|
||||
{ name = "uvicorn", specifier = ">=0.36.0" },
|
||||
{ name = "wandb", specifier = ">=0.21.3" },

@@ -3184,15 +3182,6 @@ wheels = [
{ url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp314-cp314t-win_amd64.whl", hash = "sha256:0c784b600959ec70ee01cb23e8bc870a0e0475af30378ff5e39f4abed8b7c1cc" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "torchao"
|
||||
version = "0.15.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/57/2d/472b9362dceae05a4599e2b94f86e69a29c0e20964a6af84f34f6ead5938/torchao-0.15.0-cp310-abi3-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:1cbe813201314ba6329a650a76944502f3e8ec4b1b44523f3f48676810d8d1f6", size = 7163930, upload-time = "2025-12-18T23:14:41.876Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/f6/3b/6b9d5618720f63dbc2e2509cd6b57aae9c0d61b738d1d2172f4d5d9efaab/torchao-0.15.0-py3-none-any.whl", hash = "sha256:3f3812676048ef8a2a0e9d492d12d8971ba7a7ebb16f54aa56f690414e130d2c", size = 1080679, upload-time = "2025-12-18T23:14:43.807Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tornado"
|
||||
version = "6.5.4"