Merge branch 'Chetter2-patch-1'

This commit is contained in:
Andrej Karpathy 2026-02-18 23:17:39 +00:00
commit 458555117b
2 changed files with 14 additions and 18 deletions

View File

@ -123,19 +123,16 @@ def _to_col_major(x):
class _Float8Matmul(torch.autograd.Function):
"""Custom autograd for the three FP8 GEMMs of a Linear layer.
The forward saves input and weight in their original precision for the
backward pass. Each GEMM independently re-quantizes its operands to FP8.
(We don't reuse the forward's FP8 tensors in backward: the backward might
want different precision, and saving FP8 would lose information.)
The forward quantizes input and weight to FP8 and saves
the quantized tensors + scales for backward.
"""
@staticmethod
def forward(ctx, input_2d, weight):
ctx.save_for_backward(input_2d, weight)
# Quantize both operands to e4m3 (higher precision format)
input_fp8, input_inv = _to_fp8(input_2d, torch.float8_e4m3fn)
weight_fp8, weight_inv = _to_fp8(weight, torch.float8_e4m3fn)
ctx.save_for_backward(input_fp8, input_inv, weight_fp8, weight_inv)
# output = input @ weight.T
# input_fp8 is [B, K] contiguous = row-major (good for first arg)
@ -156,13 +153,12 @@ class _Float8Matmul(torch.autograd.Function):
@staticmethod
def backward(ctx, grad_output):
input_2d, weight = ctx.saved_tensors
in_fp8, in_inv, w_fp8, w_inv = ctx.saved_tensors
# === GEMM 1: grad_input = grad_output @ weight ===
# Shapes: [B, N] @ [N, K] -> [B, K]
# Gradients use e5m2 (wider range), weights use e4m3 (higher precision)
go_fp8, go_inv = _to_fp8(grad_output, torch.float8_e5m2)
w_fp8, w_inv = _to_fp8(weight, torch.float8_e4m3fn)
# go_fp8 is [B, N] contiguous = row-major, good for first arg
# w_fp8 is [N, K] contiguous = row-major, need column-major for second arg
w_col = _to_col_major(w_fp8)
@ -177,17 +173,15 @@ class _Float8Matmul(torch.autograd.Function):
# === GEMM 2: grad_weight = grad_output.T @ input ===
# Shapes: [N, B] @ [B, K] -> [N, K]
go_fp8_2, go_inv_2 = _to_fp8(grad_output, torch.float8_e5m2)
in_fp8, in_inv = _to_fp8(input_2d, torch.float8_e4m3fn)
# go_fp8_2 is [B, N] contiguous, we need go.T = [N, B] as first arg.
# go_fp8 is [B, N] contiguous, we need go.T = [N, B] as first arg.
# Transposing gives column-major, but first arg needs row-major,
# so we must call .contiguous() to physically rearrange the memory.
go_T = go_fp8_2.t().contiguous() # [N, B] row-major
go_T = go_fp8.t().contiguous() # [N, B] row-major
in_col = _to_col_major(in_fp8) # [B, K] column-major
grad_weight = torch._scaled_mm(
go_T,
in_col,
scale_a=go_inv_2,
scale_a=go_inv,
scale_b=in_inv,
out_dtype=grad_output.dtype,
use_fast_accum=False,

View File

@ -170,20 +170,22 @@ if args.fp8:
# from torchao.float8 import Float8LinearConfig, convert_to_float8_training
import torch.nn as nn
# Filter: only convert layers with dimensions divisible by 16 (FP8 hardware requirement)
# Filter: dims must be divisible by 16 (FP8 hardware requirement) and large enough
def fp8_module_filter(mod: nn.Module, fqn: str) -> bool:
    """Return True when ``mod`` qualifies for FP8 conversion.

    A module qualifies only if it is an ``nn.Linear`` whose ``in_features``
    and ``out_features`` are both multiples of 16 and both at least 128.
    ``fqn`` is part of the filter signature but is not consulted here.
    """
    if not isinstance(mod, nn.Linear):
        return False
    dims = (mod.in_features, mod.out_features)
    # Both dimensions must be divisible by 16 (FP8 hardware requirement).
    if any(d % 16 for d in dims):
        return False
    # Layers whose smaller dimension is below 128 are left unconverted.
    return min(dims) >= 128
fp8_config = Float8LinearConfig.from_recipe_name(args.fp8_recipe)
num_linear = sum(1 for m in model.modules() if isinstance(m, nn.Linear))
convert_to_float8_training(model, config=fp8_config, module_filter_fn=fp8_module_filter)
num_fp8_layers = sum(1 for m in model.modules() if 'Float8' in type(m).__name__)
num_skipped = sum(1 for m in model.modules() if isinstance(m, nn.Linear)) - num_fp8_layers
print0(f"✓ FP8 training enabled ({args.fp8_recipe} scaling) - converted {num_fp8_layers} layers, skipped {num_skipped} (dims not divisible by 16)")
num_fp8 = sum(1 for m in model.modules() if 'Float8' in type(m).__name__)
num_skipped = num_linear - num_fp8
print0(f"✓ FP8 training enabled ({args.fp8_recipe} scaling) - converted {num_fp8}/{num_linear} linear layers, skipped {num_skipped} (too small)")
# Context manager to temporarily disable FP8 so that model evaluation remains in BF16
@contextmanager