Removed redundant quantization of gradients

Alan 2026-02-15 15:41:33 +00:00 committed by GitHub
parent d9678ff0f9
commit 124f49be98


@@ -173,16 +173,15 @@ class _Float8Matmul(torch.autograd.Function):
 # === GEMM 2: grad_weight = grad_output.T @ input ===
 # Shapes: [N, B] @ [B, K] -> [N, K]
-go_fp8_2, go_inv_2 = _to_fp8(grad_output, torch.float8_e5m2)
-# go_fp8_2 is [B, N] contiguous, we need go.T = [N, B] as first arg.
+# go_fp8 is [B, N] contiguous, we need go.T = [N, B] as first arg.
 # Transposing gives column-major, but first arg needs row-major,
 # so we must call .contiguous() to physically rearrange the memory.
-go_T = go_fp8_2.t().contiguous() # [N, B] row-major
+go_T = go_fp8.t().contiguous() # [N, B] row-major
 in_col = _to_col_major(in_fp8) # [B, K] column-major
 grad_weight = torch._scaled_mm(
     go_T,
     in_col,
-    scale_a=go_inv_2,
+    scale_a=go_inv,
     scale_b=in_inv,
     out_dtype=grad_output.dtype,
     use_fast_accum=False,
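
For context, here is a minimal, self-contained sketch of the pattern this commit converges on: grad_output is quantized to FP8 exactly once, and the resulting tensor and inverse scale (go_fp8, go_inv) are reused by both backward GEMMs, which is why the second _to_fp8 call above was redundant. The helpers _to_fp8 and _to_col_major are hypothetical reconstructions written here for illustration only, since their real bodies are outside this hunk, and torch._scaled_mm itself requires a GPU with FP8 support.

import torch


def _to_fp8(x: torch.Tensor, dtype: torch.dtype):
    """Hypothetical helper: scale x into the FP8 range and cast.

    Returns the FP8 tensor and the inverse scale; torch._scaled_mm
    multiplies the inverse scales back into its accumulated output.
    """
    amax = x.abs().amax().clamp(min=1e-12).float()
    scale = torch.finfo(dtype).max / amax  # quantization scale (fp32 scalar tensor)
    x_fp8 = (x.float() * scale).clamp(
        -torch.finfo(dtype).max, torch.finfo(dtype).max
    ).to(dtype)
    return x_fp8, scale.reciprocal()  # (fp8 tensor, inverse scale)


def _to_col_major(x_fp8: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: column-major copy, the layout _scaled_mm wants for its second operand."""
    return x_fp8.t().contiguous().t()


def fp8_linear_backward(grad_output, in_fp8, in_inv, w_fp8, w_inv):
    """Backward GEMMs for y = x @ w.T with x: [B, K], w: [N, K], grad_output: [B, N].

    in_fp8 / w_fp8 are the forward input and weight already quantized to
    torch.float8_e4m3fn, with in_inv / w_inv their inverse scales.
    """
    # Quantize grad_output once; both GEMMs reuse go_fp8 and go_inv.
    go_fp8, go_inv = _to_fp8(grad_output, torch.float8_e5m2)

    # GEMM 1: grad_input = grad_output @ weight -> [B, N] @ [N, K] = [B, K]
    grad_input = torch._scaled_mm(
        go_fp8,                     # [B, N] row-major
        _to_col_major(w_fp8),       # [N, K] column-major
        scale_a=go_inv,
        scale_b=w_inv,
        out_dtype=grad_output.dtype,
        use_fast_accum=False,
    )

    # GEMM 2: grad_weight = grad_output.T @ input -> [N, B] @ [B, K] = [N, K]
    go_T = go_fp8.t().contiguous()  # [N, B] row-major
    grad_weight = torch._scaled_mm(
        go_T,
        _to_col_major(in_fp8),      # [B, K] column-major
        scale_a=go_inv,
        scale_b=in_inv,
        out_dtype=grad_output.dtype,
        use_fast_accum=False,
    )
    return grad_input, grad_weight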