diff --git a/nanochat/fp8.py b/nanochat/fp8.py index 9d9e9c3..3e88285 100644 --- a/nanochat/fp8.py +++ b/nanochat/fp8.py @@ -123,19 +123,16 @@ def _to_col_major(x): class _Float8Matmul(torch.autograd.Function): """Custom autograd for the three FP8 GEMMs of a Linear layer. - The forward saves input and weight in their original precision for the - backward pass. Each GEMM independently re-quantizes its operands to FP8. - (We don't reuse the forward's FP8 tensors in backward — the backward might - want different precision, and saving FP8 would lose information.) + The forward quantizes input and weight to FP8 and saves + the quantized tensors + scales for backward. """ @staticmethod def forward(ctx, input_2d, weight): - ctx.save_for_backward(input_2d, weight) - # Quantize both operands to e4m3 (higher precision format) input_fp8, input_inv = _to_fp8(input_2d, torch.float8_e4m3fn) weight_fp8, weight_inv = _to_fp8(weight, torch.float8_e4m3fn) + ctx.save_for_backward(input_fp8, input_inv, weight_fp8, weight_inv) # output = input @ weight.T # input_fp8 is [B, K] contiguous = row-major (good for first arg) @@ -156,13 +153,12 @@ class _Float8Matmul(torch.autograd.Function): @staticmethod def backward(ctx, grad_output): - input_2d, weight = ctx.saved_tensors + in_fp8, in_inv, w_fp8, w_inv = ctx.saved_tensors # === GEMM 1: grad_input = grad_output @ weight === # Shapes: [B, N] @ [N, K] -> [B, K] # Gradients use e5m2 (wider range), weights use e4m3 (higher precision) go_fp8, go_inv = _to_fp8(grad_output, torch.float8_e5m2) - w_fp8, w_inv = _to_fp8(weight, torch.float8_e4m3fn) # go_fp8 is [B, N] contiguous = row-major, good for first arg # w_fp8 is [N, K] contiguous = row-major, need column-major for second arg w_col = _to_col_major(w_fp8) @@ -177,17 +173,15 @@ class _Float8Matmul(torch.autograd.Function): # === GEMM 2: grad_weight = grad_output.T @ input === # Shapes: [N, B] @ [B, K] -> [N, K] - go_fp8_2, go_inv_2 = _to_fp8(grad_output, torch.float8_e5m2) - in_fp8, in_inv = _to_fp8(input_2d, torch.float8_e4m3fn) - # go_fp8_2 is [B, N] contiguous, we need go.T = [N, B] as first arg. + # go_fp8 is [B, N] contiguous, we need go.T = [N, B] as first arg. # Transposing gives column-major, but first arg needs row-major, # so we must call .contiguous() to physically rearrange the memory. - go_T = go_fp8_2.t().contiguous() # [N, B] row-major + go_T = go_fp8.t().contiguous() # [N, B] row-major in_col = _to_col_major(in_fp8) # [B, K] column-major grad_weight = torch._scaled_mm( go_T, in_col, - scale_a=go_inv_2, + scale_a=go_inv, scale_b=in_inv, out_dtype=grad_output.dtype, use_fast_accum=False, diff --git a/scripts/base_train.py b/scripts/base_train.py index bb76e90..24091b6 100644 --- a/scripts/base_train.py +++ b/scripts/base_train.py @@ -170,20 +170,22 @@ if args.fp8: # from torchao.float8 import Float8LinearConfig, convert_to_float8_training import torch.nn as nn - # Filter: only convert layers with dimensions divisible by 16 (FP8 hardware requirement) + # Filter: dims must be divisible by 16 (FP8 hardware requirement) large enough def fp8_module_filter(mod: nn.Module, fqn: str) -> bool: if not isinstance(mod, nn.Linear): return False - # FP8 requires both in_features and out_features divisible by 16 if mod.in_features % 16 != 0 or mod.out_features % 16 != 0: return False + if min(mod.in_features, mod.out_features) < 128: + return False return True fp8_config = Float8LinearConfig.from_recipe_name(args.fp8_recipe) + num_linear = sum(1 for m in model.modules() if isinstance(m, nn.Linear)) convert_to_float8_training(model, config=fp8_config, module_filter_fn=fp8_module_filter) - num_fp8_layers = sum(1 for m in model.modules() if 'Float8' in type(m).__name__) - num_skipped = sum(1 for m in model.modules() if isinstance(m, nn.Linear)) - num_fp8_layers - print0(f"✓ FP8 training enabled ({args.fp8_recipe} scaling) - converted {num_fp8_layers} layers, skipped {num_skipped} (dims not divisible by 16)") + num_fp8 = sum(1 for m in model.modules() if 'Float8' in type(m).__name__) + num_skipped = num_linear - num_fp8 + print0(f"✓ FP8 training enabled ({args.fp8_recipe} scaling) - converted {num_fp8}/{num_linear} linear layers, skipped {num_skipped} (too small)") # Context manager to temporarily disable FP8 so that model evaluation remains in BF16 @contextmanager