Mirror of https://github.com/karpathy/nanochat.git, synced 2026-05-10 18:00:17 +00:00
Removed redundant quantization of gradients
This commit is contained in:
parent d9678ff0f9
commit 124f49be98
@@ -173,16 +173,15 @@ class _Float8Matmul(torch.autograd.Function):
         # === GEMM 2: grad_weight = grad_output.T @ input ===
         # Shapes: [N, B] @ [B, K] -> [N, K]
-        go_fp8_2, go_inv_2 = _to_fp8(grad_output, torch.float8_e5m2)
-        # go_fp8_2 is [B, N] contiguous, we need go.T = [N, B] as first arg.
+        # go_fp8 is [B, N] contiguous, we need go.T = [N, B] as first arg.
         # Transposing gives column-major, but first arg needs row-major,
         # so we must call .contiguous() to physically rearrange the memory.
-        go_T = go_fp8_2.t().contiguous() # [N, B] row-major
+        go_T = go_fp8.t().contiguous() # [N, B] row-major
         in_col = _to_col_major(in_fp8) # [B, K] column-major
         grad_weight = torch._scaled_mm(
             go_T,
             in_col,
-            scale_a=go_inv_2,
+            scale_a=go_inv,
             scale_b=in_inv,
             out_dtype=grad_output.dtype,
             use_fast_accum=False,
         )
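For context, the sketch below illustrates the pattern this hunk cleans up: grad_output is quantized to FP8 once, and both backward GEMMs reuse the same quantized tensor and inverse scale, making the second _to_fp8 call redundant. The _to_fp8 and _to_col_major helpers here are illustrative stand-ins for the helpers the diff references, not nanochat's actual implementations, and the GEMM 1 shapes assume the usual y = x @ W.T linear layer. torch._scaled_mm requires a row-major first operand, a column-major second operand, and float32 scales.

# Illustrative sketch only: _to_fp8, _to_col_major, and the GEMM 1 shapes
# are assumptions; nanochat's real code may differ.
import torch

def _to_fp8(x: torch.Tensor, dtype: torch.dtype):
    # Per-tensor scaling: map max |x| onto the format's largest finite value.
    fp8_max = torch.finfo(dtype).max
    scale = fp8_max / x.abs().max().clamp(min=1e-12)
    x_fp8 = (x * scale).clamp(-fp8_max, fp8_max).to(dtype)
    inv_scale = scale.reciprocal().float()  # _scaled_mm wants fp32 scales
    return x_fp8, inv_scale

def _to_col_major(x: torch.Tensor) -> torch.Tensor:
    # Physically transpose, then view back: same shape, column-major strides.
    return x.t().contiguous().t()

def fp8_backward_gemms(grad_output, in_fp8, in_inv, w_fp8, w_inv):
    # Quantize grad_output ONCE; both GEMMs reuse go_fp8 / go_inv below.
    go_fp8, go_inv = _to_fp8(grad_output, torch.float8_e5m2)

    # GEMM 1: grad_input = grad_output @ weight, [B, N] @ [N, K] -> [B, K].
    grad_input = torch._scaled_mm(
        go_fp8,                    # [B, N] row-major
        _to_col_major(w_fp8),      # [N, K] column-major
        scale_a=go_inv,
        scale_b=w_inv,
        out_dtype=grad_output.dtype,
        use_fast_accum=False,
    )

    # GEMM 2: grad_weight = grad_output.T @ input, [N, B] @ [B, K] -> [N, K].
    # .t() alone yields a column-major view; .contiguous() rearranges the
    # memory so the first operand is row-major, as _scaled_mm requires.
    go_T = go_fp8.t().contiguous()  # [N, B] row-major
    in_col = _to_col_major(in_fp8)  # [B, K] column-major
    grad_weight = torch._scaled_mm(
        go_T,
        in_col,
        scale_a=go_inv,             # reused inverse scale, no re-quantization
        scale_b=in_inv,
        out_dtype=grad_output.dtype,
        use_fast_accum=False,
    )
    return grad_input, grad_weight

Since per-tensor quantization of grad_output is deterministic, calling _to_fp8 on it a second time can only reproduce the same FP8 tensor and scale at extra cost, which is why the diff drops the second call and reuses go_fp8 and go_inv.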