From e569b59f92aea06bf8fc1c48489b3cc2e57189f4 Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Tue, 10 Feb 2026 18:46:39 +0000 Subject: [PATCH] delete torchao dependency, create our own exact API-matched version of Float8Linear, document it very well. for some poorly understood reason, the performance is not only ~identical but actually runs 3% faster. despite of it being significantly simpler and much less code. i don't fully understand why/how atm --- nanochat/fp8.py | 272 ++++++++++++++++++++++++++++++++++++++++++ pyproject.toml | 1 - scripts/base_train.py | 4 +- uv.lock | 11 -- 4 files changed, 275 insertions(+), 13 deletions(-) create mode 100644 nanochat/fp8.py diff --git a/nanochat/fp8.py b/nanochat/fp8.py new file mode 100644 index 0000000..9d9e9c3 --- /dev/null +++ b/nanochat/fp8.py @@ -0,0 +1,272 @@ +"""Minimal FP8 training for nanochat — tensorwise dynamic scaling only. + +Drop-in replacement for torchao's Float8Linear (~2000 lines) with ~150 lines. +We only need the "tensorwise" recipe (one scalar scale per tensor), not the full +generality of torchao (rowwise scaling, FSDP float8 all-gather, DTensor, tensor +subclass dispatch tables, etc.) + +How FP8 training works +====================== +A standard Linear layer does one matmul in forward and two in backward: + forward: output = input @ weight.T + backward: grad_input = grad_output @ weight + grad_weight= grad_output.T @ input + +FP8 training wraps each of these three matmuls with: + 1. Compute scale = FP8_MAX / max(|tensor|) for each operand + 2. Quantize: fp8_tensor = clamp(tensor * scale, -FP8_MAX, FP8_MAX).to(fp8) + 3. Matmul via torch._scaled_mm (cuBLAS FP8 kernel, ~2x faster than bf16) + 4. Dequantize: _scaled_mm handles this internally using the inverse scales + +The key insight: torch._scaled_mm and the float8 dtypes are PyTorch built-ins. +torchao is just orchestration around these primitives. We can call them directly. + +FP8 dtype choice +================ +There are two FP8 formats. We use both, following the standard convention: + - float8_e4m3fn: 4-bit exponent, 3-bit mantissa, range [-448, 448] + Higher precision (more mantissa bits), used for input and weight. + - float8_e5m2: 5-bit exponent, 2-bit mantissa, range [-57344, 57344] + Wider range (more exponent bits), used for gradients which can be large. + +torch._scaled_mm layout requirements +===================================== +The cuBLAS FP8 kernel requires specific memory layouts: + - First argument (A): must be row-major (contiguous) + - Second argument (B): must be column-major (B.t().contiguous().t()) +If B is obtained by transposing a contiguous tensor (e.g. weight.t()), it is +already column-major — no copy needed. Otherwise we use _to_col_major(). + +How this differs from torchao's approach +======================================== +torchao uses a "tensor subclass" architecture: Float8TrainingTensor is a subclass +of torch.Tensor that bundles FP8 data + scale + metadata. It implements +__torch_dispatch__ with a dispatch table that intercepts every aten op (mm, t, +reshape, clone, ...) and handles it in FP8-aware fashion. When you call + output = input @ weight.T +the @ operator dispatches to aten.mm, which gets intercepted and routed to +torch._scaled_mm behind the scenes. This is ~2000 lines of code because you need +a handler for every tensor operation that might touch an FP8 tensor. + +We take a simpler approach: a single autograd.Function (_Float8Matmul) that takes +full-precision inputs, quantizes to FP8 internally, calls _scaled_mm, and returns +full-precision outputs. Marked @allow_in_graph so torch.compile treats it as one +opaque node rather than trying to trace inside. + +The trade-off is in how torch.compile sees the two approaches: + - torchao: compile decomposes the tensor subclass (via __tensor_flatten__) and + sees every individual op (amax, scale, cast, _scaled_mm) as separate graph + nodes. Inductor can fuse these with surrounding operations (e.g. fuse the + amax computation with the preceding layer's activation function). + - ours: compile sees a single opaque call. It can optimize everything around + the FP8 linear (attention, norms, etc.) but cannot fuse across the boundary. + +Both call the exact same cuBLAS _scaled_mm kernel — the GPU matmul is identical. +The difference is only in the "glue" ops (amax, scale, cast) which are tiny +compared to the matmul. In practice this means our version is slightly faster +(less compilation overhead, no tensor subclass dispatch cost) but can produce +subtly different floating-point rounding paths under torch.compile, since Inductor +generates a different graph. Numerics are bitwise identical in eager mode. +""" + +import torch +import torch.nn as nn + +# Avoid division by zero when computing scale from an all-zeros tensor +EPS = 1e-12 + + +@torch.no_grad() +def _to_fp8(x, fp8_dtype): + """Dynamically quantize a tensor to FP8 using tensorwise scaling. + + "Tensorwise" means one scalar scale for the entire tensor (as opposed to + "rowwise" which computes a separate scale per row). Tensorwise is faster + because cuBLAS handles the scaling; rowwise needs the CUTLASS kernel. + + Returns (fp8_data, inverse_scale) for use with torch._scaled_mm. + """ + fp8_max = torch.finfo(fp8_dtype).max + # Compute the max absolute value across the entire tensor + amax = x.float().abs().max() + # Scale maps [0, amax] -> [0, fp8_max]. Use float64 for the division to + # ensure consistent numerics between torch.compile and eager mode. + # (torchao does the same upcast — without it, compile/eager can diverge) + scale = fp8_max / amax.double().clamp(min=EPS) + scale = scale.float() + # Quantize: scale into FP8 range, saturate (clamp prevents overflow when + # casting — PyTorch's default is to wrap, not saturate), then cast to FP8 + x_scaled = x.float() * scale + x_clamped = x_scaled.clamp(-fp8_max, fp8_max) + x_fp8 = x_clamped.to(fp8_dtype) + # _scaled_mm expects the *inverse* of our scale (it multiplies by this to + # convert FP8 values back to the original range during the matmul) + inv_scale = scale.reciprocal() + return x_fp8, inv_scale + + +def _to_col_major(x): + """Rearrange a 2D tensor's memory to column-major layout. + + torch._scaled_mm requires its second operand in column-major layout. + The trick: transpose -> contiguous (forces a copy in transposed order) + -> transpose back. The result has the same logical shape but column-major + strides, e.g. a [M, N] tensor gets strides (1, M) instead of (N, 1). + """ + return x.t().contiguous().t() + + +# allow_in_graph tells torch.compile to treat this as an opaque operation — +# dynamo won't try to decompose it into smaller ops. See the module docstring +# for how this differs from torchao's tensor subclass approach. +@torch._dynamo.allow_in_graph +class _Float8Matmul(torch.autograd.Function): + """Custom autograd for the three FP8 GEMMs of a Linear layer. + + The forward saves input and weight in their original precision for the + backward pass. Each GEMM independently re-quantizes its operands to FP8. + (We don't reuse the forward's FP8 tensors in backward — the backward might + want different precision, and saving FP8 would lose information.) + """ + + @staticmethod + def forward(ctx, input_2d, weight): + ctx.save_for_backward(input_2d, weight) + + # Quantize both operands to e4m3 (higher precision format) + input_fp8, input_inv = _to_fp8(input_2d, torch.float8_e4m3fn) + weight_fp8, weight_inv = _to_fp8(weight, torch.float8_e4m3fn) + + # output = input @ weight.T + # input_fp8 is [B, K] contiguous = row-major (good for first arg) + # weight_fp8 is [N, K] contiguous, so weight_fp8.t() is [K, N] with + # strides (1, K) = column-major (good for second arg, no copy needed!) + output = torch._scaled_mm( + input_fp8, + weight_fp8.t(), + scale_a=input_inv, + scale_b=weight_inv, + out_dtype=input_2d.dtype, + # use_fast_accum=True accumulates the dot products in lower precision. + # Slightly less accurate but measurably faster. Standard practice for + # the forward pass; we use False in backward for more precise gradients. + use_fast_accum=True, + ) + return output + + @staticmethod + def backward(ctx, grad_output): + input_2d, weight = ctx.saved_tensors + + # === GEMM 1: grad_input = grad_output @ weight === + # Shapes: [B, N] @ [N, K] -> [B, K] + # Gradients use e5m2 (wider range), weights use e4m3 (higher precision) + go_fp8, go_inv = _to_fp8(grad_output, torch.float8_e5m2) + w_fp8, w_inv = _to_fp8(weight, torch.float8_e4m3fn) + # go_fp8 is [B, N] contiguous = row-major, good for first arg + # w_fp8 is [N, K] contiguous = row-major, need column-major for second arg + w_col = _to_col_major(w_fp8) + grad_input = torch._scaled_mm( + go_fp8, + w_col, + scale_a=go_inv, + scale_b=w_inv, + out_dtype=grad_output.dtype, + use_fast_accum=False, + ) + + # === GEMM 2: grad_weight = grad_output.T @ input === + # Shapes: [N, B] @ [B, K] -> [N, K] + go_fp8_2, go_inv_2 = _to_fp8(grad_output, torch.float8_e5m2) + in_fp8, in_inv = _to_fp8(input_2d, torch.float8_e4m3fn) + # go_fp8_2 is [B, N] contiguous, we need go.T = [N, B] as first arg. + # Transposing gives column-major, but first arg needs row-major, + # so we must call .contiguous() to physically rearrange the memory. + go_T = go_fp8_2.t().contiguous() # [N, B] row-major + in_col = _to_col_major(in_fp8) # [B, K] column-major + grad_weight = torch._scaled_mm( + go_T, + in_col, + scale_a=go_inv_2, + scale_b=in_inv, + out_dtype=grad_output.dtype, + use_fast_accum=False, + ) + + return grad_input, grad_weight + + +class Float8Linear(nn.Linear): + """Drop-in nn.Linear replacement that does FP8 compute. + + Weights and biases remain in their original precision (e.g. fp32/bf16). + Only the matmul is performed in FP8 via the _Float8Matmul autograd function. + """ + + def forward(self, input): + # Replicate the autocast behavior of F.linear — when autocast is active, + # we need to manually cast input to the autocast dtype (e.g. bf16), + # since we bypass F.linear's built-in autocast handling. + if torch.is_autocast_enabled(): + input = input.to(torch.get_autocast_gpu_dtype()) + # _scaled_mm only works on 2D tensors, so flatten batch dimensions + orig_shape = input.shape + input_2d = input.reshape(-1, orig_shape[-1]) + output = _Float8Matmul.apply(input_2d, self.weight) + output = output.reshape(*orig_shape[:-1], output.shape[-1]) + if self.bias is not None: + output = output + self.bias.to(output.dtype) + return output + + @classmethod + def from_float(cls, mod): + """Create Float8Linear from nn.Linear, sharing the same weight and bias. + + Uses meta device to avoid allocating a temporary weight tensor — we + create the module shell on meta (shapes/dtypes only, no memory), then + point .weight and .bias to the original module's parameters. + """ + with torch.device("meta"): + new_mod = cls(mod.in_features, mod.out_features, bias=False) + new_mod.weight = mod.weight + new_mod.bias = mod.bias + return new_mod + + +class Float8LinearConfig: + """Minimal config matching torchao's API. Only tensorwise recipe is supported.""" + + @staticmethod + def from_recipe_name(recipe_name): + if recipe_name != "tensorwise": + raise ValueError( + f"Only 'tensorwise' recipe is supported, got '{recipe_name}'. " + f"Rowwise/axiswise recipes require the full torchao library." + ) + return Float8LinearConfig() + + +def convert_to_float8_training(module, *, config=None, module_filter_fn=None): + """Replace nn.Linear layers with Float8Linear throughout a module. + + Walks the module tree in post-order (children before parents) and swaps + each nn.Linear that passes the optional filter. The new Float8Linear shares + the original weight and bias tensors — no copies, no extra memory. + + Args: + module: Root module to convert. + config: Float8LinearConfig (accepted for API compat, only tensorwise supported). + module_filter_fn: Optional filter(module, fqn) -> bool. Only matching Linears + are converted. Common use: skip layers with dims not divisible by 16 + (hardware requirement for FP8 matmuls on H100). + """ + def _convert(mod, prefix=""): + for name, child in mod.named_children(): + fqn = f"{prefix}.{name}" if prefix else name + _convert(child, fqn) + if isinstance(child, nn.Linear) and not isinstance(child, Float8Linear): + if module_filter_fn is None or module_filter_fn(child, fqn): + setattr(mod, name, Float8Linear.from_float(child)) + + _convert(module) + return module diff --git a/pyproject.toml b/pyproject.toml index bcb674d..8b6fd95 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,7 +20,6 @@ dependencies = [ "tiktoken>=0.11.0", "tokenizers>=0.22.0", "torch==2.9.1", - "torchao==0.15.0", "transformers>=4.57.3", "uvicorn>=0.36.0", "wandb>=0.21.3", diff --git a/scripts/base_train.py b/scripts/base_train.py index ccf35e6..ee53098 100644 --- a/scripts/base_train.py +++ b/scripts/base_train.py @@ -165,7 +165,9 @@ if args.fp8: if device_type != "cuda": print0("Warning: FP8 training requires CUDA, ignoring --fp8 flag") else: - from torchao.float8 import Float8LinearConfig, convert_to_float8_training + # our custom fp8 is simpler than torchao, written for exact API compatibility + from nanochat.fp8 import Float8LinearConfig, convert_to_float8_training + # from torchao.float8 import Float8LinearConfig, convert_to_float8_training import torch.nn as nn # Filter: only convert layers with dimensions divisible by 16 (FP8 hardware requirement) diff --git a/uv.lock b/uv.lock index e5fc97f..bbc9519 100644 --- a/uv.lock +++ b/uv.lock @@ -1509,7 +1509,6 @@ dependencies = [ { name = "torch", version = "2.9.1", source = { registry = "https://pypi.org/simple" }, marker = "(extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu') or (extra != 'extra-8-nanochat-cpu' and extra != 'extra-8-nanochat-gpu')" }, { name = "torch", version = "2.9.1+cpu", source = { registry = "https://download.pytorch.org/whl/cpu" }, marker = "(sys_platform != 'darwin' and extra == 'extra-8-nanochat-cpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" }, { name = "torch", version = "2.9.1+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "extra == 'extra-8-nanochat-gpu'" }, - { name = "torchao" }, { name = "transformers" }, { name = "uvicorn" }, { name = "wandb" }, @@ -1549,7 +1548,6 @@ requires-dist = [ { name = "torch", specifier = "==2.9.1" }, { name = "torch", marker = "extra == 'cpu'", specifier = "==2.9.1", index = "https://download.pytorch.org/whl/cpu", conflict = { package = "nanochat", extra = "cpu" } }, { name = "torch", marker = "extra == 'gpu'", specifier = "==2.9.1", index = "https://download.pytorch.org/whl/cu128", conflict = { package = "nanochat", extra = "gpu" } }, - { name = "torchao", specifier = "==0.15.0" }, { name = "transformers", specifier = ">=4.57.3" }, { name = "uvicorn", specifier = ">=0.36.0" }, { name = "wandb", specifier = ">=0.21.3" }, @@ -3184,15 +3182,6 @@ wheels = [ { url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp314-cp314t-win_amd64.whl", hash = "sha256:0c784b600959ec70ee01cb23e8bc870a0e0475af30378ff5e39f4abed8b7c1cc" }, ] -[[package]] -name = "torchao" -version = "0.15.0" -source = { registry = "https://pypi.org/simple" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/57/2d/472b9362dceae05a4599e2b94f86e69a29c0e20964a6af84f34f6ead5938/torchao-0.15.0-cp310-abi3-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:1cbe813201314ba6329a650a76944502f3e8ec4b1b44523f3f48676810d8d1f6", size = 7163930, upload-time = "2025-12-18T23:14:41.876Z" }, - { url = "https://files.pythonhosted.org/packages/f6/3b/6b9d5618720f63dbc2e2509cd6b57aae9c0d61b738d1d2172f4d5d9efaab/torchao-0.15.0-py3-none-any.whl", hash = "sha256:3f3812676048ef8a2a0e9d492d12d8971ba7a7ebb16f54aa56f690414e130d2c", size = 1080679, upload-time = "2025-12-18T23:14:43.807Z" }, -] - [[package]] name = "tornado" version = "6.5.4"