diff --git a/nanochat/fp8.py b/nanochat/fp8.py
index 9d9e9c3..3f056d1 100644
--- a/nanochat/fp8.py
+++ b/nanochat/fp8.py
@@ -196,6 +196,81 @@ class _Float8Matmul(torch.autograd.Function):
         return grad_input, grad_weight
 
 
+@torch._dynamo.allow_in_graph
+class _Float8MatmulND(torch.autograd.Function):
+    """FP8 matmul that handles N-D input tensors.
+
+    Same as _Float8Matmul but accepts inputs of any shape (not just 2D).
+    Reshaping is done internally so torch.compile sees this as one opaque node,
+    preventing the reshaping overhead that occurs when reshapes are external.
+
+    This is specifically for reparam_linear where N-D tensors are common.
+    """
+
+    @staticmethod
+    def forward(ctx, input, weight):
+        # Save original shape and flatten batch dimensions
+        orig_shape = input.shape
+        ctx.orig_shape = orig_shape
+        input_2d = input.reshape(-1, orig_shape[-1])
+        ctx.save_for_backward(input_2d, weight)
+
+        # Quantize and matmul (same as _Float8Matmul.forward)
+        input_fp8, input_inv = _to_fp8(input_2d, torch.float8_e4m3fn)
+        weight_fp8, weight_inv = _to_fp8(weight, torch.float8_e4m3fn)
+        output = torch._scaled_mm(
+            input_fp8,
+            weight_fp8.t(),
+            scale_a=input_inv,
+            scale_b=weight_inv,
+            out_dtype=input.dtype,
+            use_fast_accum=True,
+        )
+
+        # Reshape back to original batch dims
+        output = output.reshape(*orig_shape[:-1], output.shape[-1])
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        input_2d, weight = ctx.saved_tensors
+        orig_shape = ctx.orig_shape
+
+        # Flatten grad_output to match input_2d
+        grad_output_flat = grad_output.reshape(-1, grad_output.shape[-1])
+
+        # === GEMM 1: grad_input = grad_output @ weight ===
+        go_fp8, go_inv = _to_fp8(grad_output_flat, torch.float8_e5m2)
+        w_fp8, w_inv = _to_fp8(weight, torch.float8_e4m3fn)
+        w_col = _to_col_major(w_fp8)
+        grad_input_flat = torch._scaled_mm(
+            go_fp8,
+            w_col,
+            scale_a=go_inv,
+            scale_b=w_inv,
+            out_dtype=grad_output.dtype,
+            use_fast_accum=False,
+        )
+        # Reshape back to original input shape
+        grad_input = grad_input_flat.reshape(orig_shape)
+
+        # === GEMM 2: grad_weight = grad_output.T @ input ===
+        go_fp8_2, go_inv_2 = go_fp8, go_inv  # reuse GEMM 1's e5m2 quantization (same tensor, identical result)
+        in_fp8, in_inv = _to_fp8(input_2d, torch.float8_e4m3fn)
+        go_T = go_fp8_2.t().contiguous()
+        in_col = _to_col_major(in_fp8)
+        grad_weight = torch._scaled_mm(
+            go_T,
+            in_col,
+            scale_a=go_inv_2,
+            scale_b=in_inv,
+            out_dtype=grad_output.dtype,
+            use_fast_accum=False,
+        )
+
+        return grad_input, grad_weight
+
+
 class Float8Linear(nn.Linear):
     """Drop-in nn.Linear replacement that does FP8 compute.
 
diff --git a/nanochat/gpt.py b/nanochat/gpt.py
index 14e7c0e..2e29943 100644
--- a/nanochat/gpt.py
+++ b/nanochat/gpt.py
@@ -27,18 +27,12 @@ from nanochat.optim import MuonAdamW, DistMuonAdamW
 # Our custom Flash Attention module that automatically uses FA3 on Hopper+ and SDPA fallback elsewhere
 from nanochat.flash_attention import flash_attn
 
-# FP8 imports (optional)
-# torchao: used in reparam_linear for efficient N-D tensor handling
-# custom fp8: used for regular Float8Linear layers (simpler, same performance)
+# FP8 imports (optional) - minimal custom implementation
 try:
-    from torchao.float8.float8_linear import matmul_with_hp_or_float8_args
-except ImportError:
-    matmul_with_hp_or_float8_args = None
-
-try:
-    from nanochat.fp8 import Float8Linear
+    from nanochat.fp8 import Float8Linear, _Float8MatmulND
 except ImportError:
     Float8Linear = None
+    _Float8MatmulND = None
 
 @dataclass
 class GPTConfig:
@@ -65,17 +65,23 @@ def reparam_linear(module, x, gamma=None, scalar=None):
     gamma: RMSNorm learnable weight, folded into input dim of W (w = w * gamma[None, :])
     scalar: projection scalar, folded into output dim of W (w = scalar[:, None] * w)
 
-    For FP8, uses torchao's matmul which handles N-D tensors efficiently without reshaping.
+    For FP8, uses minimal custom _Float8MatmulND which handles N-D tensors internally.
     """
     w = module.weight
     if gamma is not None:
         w = w * gamma[None, :]
     if scalar is not None:
         w = scalar[:, None] * w
-    # FP8 path: use torchao's matmul for efficient N-D tensor handling
-    # (torchao handles arbitrary shapes without external reshaping overhead)
-    if hasattr(module, 'linear_mm_config') and matmul_with_hp_or_float8_args is not None:
-        return matmul_with_hp_or_float8_args.apply(x, w.t(), module.linear_mm_config, module.config)
+    # FP8 path: use custom _Float8MatmulND for efficient N-D tensor handling
+    # (reshaping is done internally, so torch.compile sees it as one opaque operation)
+    if Float8Linear is not None and isinstance(module, Float8Linear):
+        # Handle autocast (Float8Linear expects this)
+        if torch.is_autocast_enabled():
+            x = x.to(torch.get_autocast_gpu_dtype())
+        output = _Float8MatmulND.apply(x, w)
+        if module.bias is not None:
+            output = output + module.bias.to(output.dtype)
+        return output
     # BF16 path
     return F.linear(x, w)
 
diff --git a/pyproject.toml b/pyproject.toml
index 13e293a..8b6fd95 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -19,7 +19,6 @@ dependencies = [
     "tabulate>=0.9.0",
     "tiktoken>=0.11.0",
     "tokenizers>=0.22.0",
-    "torchao>=0.13.0",
     "torch==2.9.1",
     "transformers>=4.57.3",
     "uvicorn>=0.36.0",
diff --git a/scripts/base_train.py b/scripts/base_train.py
index 432c998..25426f5 100644
--- a/scripts/base_train.py
+++ b/scripts/base_train.py
@@ -43,7 +43,7 @@ parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('d
 # Runtime
 parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
 # FP8 training
-parser.add_argument("--fp8", action="store_true", help="enable FP8 training (requires H100+ GPU and torchao); uses custom fp8 module for layers, torchao for reparam_linear")
+parser.add_argument("--fp8", action="store_true", help="enable FP8 training (requires H100+ GPU); uses minimal custom fp8 module")
 parser.add_argument("--fp8-recipe", type=str, default="tensorwise", choices=["rowwise", "tensorwise"], help="FP8 scaling recipe: tensorwise (faster, recommended) or rowwise (more accurate but slower)")
 # Model architecture
 parser.add_argument("--depth", type=int, default=20, help="depth of the Transformer model")