mirror of
https://github.com/karpathy/nanochat.git
synced 2026-03-13 00:13:19 +00:00
Merge branch 'Chetter2-patch-1'
This commit is contained in:
commit
458555117b
|
|
@ -123,19 +123,16 @@ def _to_col_major(x):
|
|||
class _Float8Matmul(torch.autograd.Function):
|
||||
"""Custom autograd for the three FP8 GEMMs of a Linear layer.
|
||||
|
||||
The forward saves input and weight in their original precision for the
|
||||
backward pass. Each GEMM independently re-quantizes its operands to FP8.
|
||||
(We don't reuse the forward's FP8 tensors in backward — the backward might
|
||||
want different precision, and saving FP8 would lose information.)
|
||||
The forward quantizes input and weight to FP8 and saves
|
||||
the quantized tensors + scales for backward.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input_2d, weight):
|
||||
ctx.save_for_backward(input_2d, weight)
|
||||
|
||||
# Quantize both operands to e4m3 (higher precision format)
|
||||
input_fp8, input_inv = _to_fp8(input_2d, torch.float8_e4m3fn)
|
||||
weight_fp8, weight_inv = _to_fp8(weight, torch.float8_e4m3fn)
|
||||
ctx.save_for_backward(input_fp8, input_inv, weight_fp8, weight_inv)
|
||||
|
||||
# output = input @ weight.T
|
||||
# input_fp8 is [B, K] contiguous = row-major (good for first arg)
|
||||
|
|
@ -156,13 +153,12 @@ class _Float8Matmul(torch.autograd.Function):
|
|||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
input_2d, weight = ctx.saved_tensors
|
||||
in_fp8, in_inv, w_fp8, w_inv = ctx.saved_tensors
|
||||
|
||||
# === GEMM 1: grad_input = grad_output @ weight ===
|
||||
# Shapes: [B, N] @ [N, K] -> [B, K]
|
||||
# Gradients use e5m2 (wider range), weights use e4m3 (higher precision)
|
||||
go_fp8, go_inv = _to_fp8(grad_output, torch.float8_e5m2)
|
||||
w_fp8, w_inv = _to_fp8(weight, torch.float8_e4m3fn)
|
||||
# go_fp8 is [B, N] contiguous = row-major, good for first arg
|
||||
# w_fp8 is [N, K] contiguous = row-major, need column-major for second arg
|
||||
w_col = _to_col_major(w_fp8)
|
||||
|
|
@ -177,17 +173,15 @@ class _Float8Matmul(torch.autograd.Function):
|
|||
|
||||
# === GEMM 2: grad_weight = grad_output.T @ input ===
|
||||
# Shapes: [N, B] @ [B, K] -> [N, K]
|
||||
go_fp8_2, go_inv_2 = _to_fp8(grad_output, torch.float8_e5m2)
|
||||
in_fp8, in_inv = _to_fp8(input_2d, torch.float8_e4m3fn)
|
||||
# go_fp8_2 is [B, N] contiguous, we need go.T = [N, B] as first arg.
|
||||
# go_fp8 is [B, N] contiguous, we need go.T = [N, B] as first arg.
|
||||
# Transposing gives column-major, but first arg needs row-major,
|
||||
# so we must call .contiguous() to physically rearrange the memory.
|
||||
go_T = go_fp8_2.t().contiguous() # [N, B] row-major
|
||||
go_T = go_fp8.t().contiguous() # [N, B] row-major
|
||||
in_col = _to_col_major(in_fp8) # [B, K] column-major
|
||||
grad_weight = torch._scaled_mm(
|
||||
go_T,
|
||||
in_col,
|
||||
scale_a=go_inv_2,
|
||||
scale_a=go_inv,
|
||||
scale_b=in_inv,
|
||||
out_dtype=grad_output.dtype,
|
||||
use_fast_accum=False,
|
||||
|
|
|
|||
|
|
@ -170,20 +170,22 @@ if args.fp8:
|
|||
# from torchao.float8 import Float8LinearConfig, convert_to_float8_training
|
||||
import torch.nn as nn
|
||||
|
||||
# Filter: only convert layers with dimensions divisible by 16 (FP8 hardware requirement)
|
||||
# Filter: dims must be divisible by 16 (FP8 hardware requirement) large enough
|
||||
def fp8_module_filter(mod: nn.Module, fqn: str) -> bool:
|
||||
if not isinstance(mod, nn.Linear):
|
||||
return False
|
||||
# FP8 requires both in_features and out_features divisible by 16
|
||||
if mod.in_features % 16 != 0 or mod.out_features % 16 != 0:
|
||||
return False
|
||||
if min(mod.in_features, mod.out_features) < 128:
|
||||
return False
|
||||
return True
|
||||
|
||||
fp8_config = Float8LinearConfig.from_recipe_name(args.fp8_recipe)
|
||||
num_linear = sum(1 for m in model.modules() if isinstance(m, nn.Linear))
|
||||
convert_to_float8_training(model, config=fp8_config, module_filter_fn=fp8_module_filter)
|
||||
num_fp8_layers = sum(1 for m in model.modules() if 'Float8' in type(m).__name__)
|
||||
num_skipped = sum(1 for m in model.modules() if isinstance(m, nn.Linear)) - num_fp8_layers
|
||||
print0(f"✓ FP8 training enabled ({args.fp8_recipe} scaling) - converted {num_fp8_layers} layers, skipped {num_skipped} (dims not divisible by 16)")
|
||||
num_fp8 = sum(1 for m in model.modules() if 'Float8' in type(m).__name__)
|
||||
num_skipped = num_linear - num_fp8
|
||||
print0(f"✓ FP8 training enabled ({args.fp8_recipe} scaling) - converted {num_fp8}/{num_linear} linear layers, skipped {num_skipped} (too small)")
|
||||
|
||||
# Context manager to temporarily disable FP8 so that model evaluation remains in BF16
|
||||
@contextmanager
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user