From 29487517edb40ea670a89e3a194e37ff27fdb0d6 Mon Sep 17 00:00:00 2001 From: Kaiyue Wen Date: Thu, 12 Feb 2026 16:58:05 -0800 Subject: [PATCH] Revert to torchao for FP8 training to fix MFU regression MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The custom fp8 module had a performance issue in reparam_linear: it was doing reshape→matmul→reshape on every linear layer, and torch.compile couldn't fuse these operations because _Float8Matmul was marked @allow_in_graph (opaque to compiler). torchao's matmul_with_hp_or_float8_args handles N-D tensors directly without external reshaping, allowing better fusion opportunities and higher MFU. Co-Authored-By: Claude Sonnet 4.5 (1M context) --- nanochat/gpt.py | 19 ++++--------------- pyproject.toml | 1 + scripts/base_train.py | 2 +- 3 files changed, 6 insertions(+), 16 deletions(-) diff --git a/nanochat/gpt.py b/nanochat/gpt.py index 1e56c10..8779a85 100644 --- a/nanochat/gpt.py +++ b/nanochat/gpt.py @@ -29,10 +29,9 @@ from nanochat.flash_attention import flash_attn # FP8 imports (optional) — needed by reparam_linear for FP8 path try: - from nanochat.fp8 import Float8Linear, _Float8Matmul + from torchao.float8.float8_linear import matmul_with_hp_or_float8_args except ImportError: - Float8Linear = None - _Float8Matmul = None + pass @dataclass class GPTConfig: @@ -67,18 +66,8 @@ def reparam_linear(module, x, gamma=None, scalar=None): if scalar is not None: w = scalar[:, None] * w # FP8 path: use Float8Linear's internal matmul to preserve FP8 tensor cores - if Float8Linear is not None and isinstance(module, Float8Linear): - # Handle autocast similar to Float8Linear.forward - if torch.is_autocast_enabled(): - x = x.to(torch.get_autocast_gpu_dtype()) - # Flatten batch dimensions for _Float8Matmul - orig_shape = x.shape - input_2d = x.reshape(-1, orig_shape[-1]) - output = _Float8Matmul.apply(input_2d, w) - output = output.reshape(*orig_shape[:-1], output.shape[-1]) - if module.bias is not None: - output = output + module.bias.to(output.dtype) - return output + if hasattr(module, 'linear_mm_config'): + return matmul_with_hp_or_float8_args.apply(x, w.t(), module.linear_mm_config, module.config) # BF16 path return F.linear(x, w) diff --git a/pyproject.toml b/pyproject.toml index 8b6fd95..13e293a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,6 +19,7 @@ dependencies = [ "tabulate>=0.9.0", "tiktoken>=0.11.0", "tokenizers>=0.22.0", + "torchao>=0.13.0", "torch==2.9.1", "transformers>=4.57.3", "uvicorn>=0.36.0", diff --git a/scripts/base_train.py b/scripts/base_train.py index b05e889..f41c3aa 100644 --- a/scripts/base_train.py +++ b/scripts/base_train.py @@ -43,7 +43,7 @@ parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('d # Runtime parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)") # FP8 training -parser.add_argument("--fp8", action="store_true", help="enable FP8 training (requires H100+ GPU)") +parser.add_argument("--fp8", action="store_true", help="enable FP8 training (requires H100+ GPU and torchao)") parser.add_argument("--fp8-recipe", type=str, default="tensorwise", choices=["rowwise", "tensorwise"], help="FP8 scaling recipe: tensorwise (faster, recommended) or rowwise (more accurate but slower)") # Model architecture parser.add_argument("--depth", type=int, default=20, help="depth of the Transformer model")