From d9678ff0f9c5d9967512adce23cb60ea0a5cd3f3 Mon Sep 17 00:00:00 2001
From: Alan
Date: Sun, 15 Feb 2026 14:31:54 +0000
Subject: [PATCH 1/2] Save FP8 tensors in autograd ctx instead of full-precision inputs

Store quantized input/weight and their inverse scales in _Float8Matmul ctx
to avoid re-quantization in backward and reduce saved-activation memory
without changing numerics.
---
 nanochat/fp8.py | 13 ++++---------
 1 file changed, 4 insertions(+), 9 deletions(-)

diff --git a/nanochat/fp8.py b/nanochat/fp8.py
index 9d9e9c3..8649760 100644
--- a/nanochat/fp8.py
+++ b/nanochat/fp8.py
@@ -123,19 +123,16 @@ def _to_col_major(x):
 class _Float8Matmul(torch.autograd.Function):
     """Custom autograd for the three FP8 GEMMs of a Linear layer.
 
-    The forward saves input and weight in their original precision for the
-    backward pass. Each GEMM independently re-quantizes its operands to FP8.
-    (We don't reuse the forward's FP8 tensors in backward — the backward might
-    want different precision, and saving FP8 would lose information.)
+    The forward quantizes input and weight to FP8 and saves
+    the quantized tensors + scales for backward.
     """
 
     @staticmethod
     def forward(ctx, input_2d, weight):
-        ctx.save_for_backward(input_2d, weight)
-
         # Quantize both operands to e4m3 (higher precision format)
         input_fp8, input_inv = _to_fp8(input_2d, torch.float8_e4m3fn)
         weight_fp8, weight_inv = _to_fp8(weight, torch.float8_e4m3fn)
+        ctx.save_for_backward(input_fp8, input_inv, weight_fp8, weight_inv)
 
         # output = input @ weight.T
         # input_fp8 is [B, K] contiguous = row-major (good for first arg)
@@ -156,13 +153,12 @@ class _Float8Matmul(torch.autograd.Function):
 
     @staticmethod
     def backward(ctx, grad_output):
-        input_2d, weight = ctx.saved_tensors
+        in_fp8, in_inv, w_fp8, w_inv = ctx.saved_tensors
 
         # === GEMM 1: grad_input = grad_output @ weight ===
         # Shapes: [B, N] @ [N, K] -> [B, K]
         # Gradients use e5m2 (wider range), weights use e4m3 (higher precision)
         go_fp8, go_inv = _to_fp8(grad_output, torch.float8_e5m2)
-        w_fp8, w_inv = _to_fp8(weight, torch.float8_e4m3fn)
         # go_fp8 is [B, N] contiguous = row-major, good for first arg
         # w_fp8 is [N, K] contiguous = row-major, need column-major for second arg
         w_col = _to_col_major(w_fp8)
@@ -178,7 +174,6 @@ class _Float8Matmul(torch.autograd.Function):
         # === GEMM 2: grad_weight = grad_output.T @ input ===
         # Shapes: [N, B] @ [B, K] -> [N, K]
         go_fp8_2, go_inv_2 = _to_fp8(grad_output, torch.float8_e5m2)
-        in_fp8, in_inv = _to_fp8(input_2d, torch.float8_e4m3fn)
         # go_fp8_2 is [B, N] contiguous, we need go.T = [N, B] as first arg.
         # Transposing gives column-major, but first arg needs row-major,
         # so we must call .contiguous() to physically rearrange the memory.

From 124f49be98e53bf734e2918dc58a580dbf31a80c Mon Sep 17 00:00:00 2001
From: Alan
Date: Sun, 15 Feb 2026 15:41:33 +0000
Subject: [PATCH 2/2] Remove redundant quantization of gradients

---
 nanochat/fp8.py | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/nanochat/fp8.py b/nanochat/fp8.py
index 8649760..3e88285 100644
--- a/nanochat/fp8.py
+++ b/nanochat/fp8.py
@@ -173,16 +173,15 @@ class _Float8Matmul(torch.autograd.Function):
 
         # === GEMM 2: grad_weight = grad_output.T @ input ===
         # Shapes: [N, B] @ [B, K] -> [N, K]
-        go_fp8_2, go_inv_2 = _to_fp8(grad_output, torch.float8_e5m2)
-        # go_fp8_2 is [B, N] contiguous, we need go.T = [N, B] as first arg.
+        # go_fp8 is [B, N] contiguous, we need go.T = [N, B] as first arg.
         # Transposing gives column-major, but first arg needs row-major,
         # so we must call .contiguous() to physically rearrange the memory.
-        go_T = go_fp8_2.t().contiguous()  # [N, B] row-major
+        go_T = go_fp8.t().contiguous()  # [N, B] row-major
         in_col = _to_col_major(in_fp8)  # [B, K] column-major
         grad_weight = torch._scaled_mm(
             go_T, in_col,
-            scale_a=go_inv_2,
+            scale_a=go_inv,
             scale_b=in_inv,
             out_dtype=grad_output.dtype,
             use_fast_accum=False,