This commit is contained in:
Alan 2026-02-17 08:00:46 -08:00 committed by GitHub
commit 6369ba2422
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -123,19 +123,16 @@ def _to_col_major(x):
class _Float8Matmul(torch.autograd.Function):
"""Custom autograd for the three FP8 GEMMs of a Linear layer.
The forward saves input and weight in their original precision for the
backward pass. Each GEMM independently re-quantizes its operands to FP8.
(We don't reuse the forward's FP8 tensors in backward because the backward might
want different precision, and saving FP8 would lose information.)
The forward quantizes input and weight to FP8 and saves
the quantized tensors + scales for backward.
"""
@staticmethod
def forward(ctx, input_2d, weight):
ctx.save_for_backward(input_2d, weight)
# Quantize both operands to e4m3 (higher precision format)
input_fp8, input_inv = _to_fp8(input_2d, torch.float8_e4m3fn)
weight_fp8, weight_inv = _to_fp8(weight, torch.float8_e4m3fn)
ctx.save_for_backward(input_fp8, input_inv, weight_fp8, weight_inv)
# output = input @ weight.T
# input_fp8 is [B, K] contiguous = row-major (good for first arg)
@@ -156,13 +153,12 @@ class _Float8Matmul(torch.autograd.Function):
@staticmethod
def backward(ctx, grad_output):
input_2d, weight = ctx.saved_tensors
in_fp8, in_inv, w_fp8, w_inv = ctx.saved_tensors
# === GEMM 1: grad_input = grad_output @ weight ===
# Shapes: [B, N] @ [N, K] -> [B, K]
# Gradients use e5m2 (wider range), weights use e4m3 (higher precision)
go_fp8, go_inv = _to_fp8(grad_output, torch.float8_e5m2)
w_fp8, w_inv = _to_fp8(weight, torch.float8_e4m3fn)
# go_fp8 is [B, N] contiguous = row-major, good for first arg
# w_fp8 is [N, K] contiguous = row-major, need column-major for second arg
w_col = _to_col_major(w_fp8)
@@ -177,17 +173,15 @@ class _Float8Matmul(torch.autograd.Function):
# === GEMM 2: grad_weight = grad_output.T @ input ===
# Shapes: [N, B] @ [B, K] -> [N, K]
go_fp8_2, go_inv_2 = _to_fp8(grad_output, torch.float8_e5m2)
in_fp8, in_inv = _to_fp8(input_2d, torch.float8_e4m3fn)
# go_fp8_2 is [B, N] contiguous, we need go.T = [N, B] as first arg.
# go_fp8 is [B, N] contiguous, we need go.T = [N, B] as first arg.
# Transposing gives column-major, but first arg needs row-major,
# so we must call .contiguous() to physically rearrange the memory.
go_T = go_fp8_2.t().contiguous() # [N, B] row-major
go_T = go_fp8.t().contiguous() # [N, B] row-major
in_col = _to_col_major(in_fp8) # [B, K] column-major
grad_weight = torch._scaled_mm(
go_T,
in_col,
scale_a=go_inv_2,
scale_a=go_inv,
scale_b=in_inv,
out_dtype=grad_output.dtype,
use_fast_accum=False,