From 31e5bec402a7f074a24a602cffc62a1954ed370a Mon Sep 17 00:00:00 2001 From: Kaiyue Wen Date: Thu, 12 Feb 2026 16:25:52 -0800 Subject: [PATCH] Replace torchao with custom fp8 module in gpt.py - Update reparam_linear to use nanochat.fp8.Float8Linear instead of torchao - Replace matmul_with_hp_or_float8_args with direct _Float8Matmul.apply call - Remove torchao dependency mention from base_train.py help text - Functionally equivalent: both use torch._scaled_mm, custom version ~3% faster Co-Authored-By: Claude Sonnet 4.5 (1M context) --- nanochat/gpt.py | 19 +++++++++++++++---- scripts/base_train.py | 2 +- 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/nanochat/gpt.py b/nanochat/gpt.py index 8779a85..1e56c10 100644 --- a/nanochat/gpt.py +++ b/nanochat/gpt.py @@ -29,9 +29,10 @@ from nanochat.flash_attention import flash_attn # FP8 imports (optional) — needed by reparam_linear for FP8 path try: - from torchao.float8.float8_linear import matmul_with_hp_or_float8_args + from nanochat.fp8 import Float8Linear, _Float8Matmul except ImportError: - pass + Float8Linear = None + _Float8Matmul = None @dataclass class GPTConfig: @@ -66,8 +67,18 @@ def reparam_linear(module, x, gamma=None, scalar=None): if scalar is not None: w = scalar[:, None] * w # FP8 path: use Float8Linear's internal matmul to preserve FP8 tensor cores - if hasattr(module, 'linear_mm_config'): - return matmul_with_hp_or_float8_args.apply(x, w.t(), module.linear_mm_config, module.config) + if Float8Linear is not None and isinstance(module, Float8Linear): + # Handle autocast similar to Float8Linear.forward + if torch.is_autocast_enabled(): + x = x.to(torch.get_autocast_gpu_dtype()) + # Flatten batch dimensions for _Float8Matmul + orig_shape = x.shape + input_2d = x.reshape(-1, orig_shape[-1]) + output = _Float8Matmul.apply(input_2d, w) + output = output.reshape(*orig_shape[:-1], output.shape[-1]) + if module.bias is not None: + output = output + module.bias.to(output.dtype) + return output # BF16 path return F.linear(x, w) diff --git a/scripts/base_train.py b/scripts/base_train.py index f41c3aa..b05e889 100644 --- a/scripts/base_train.py +++ b/scripts/base_train.py @@ -43,7 +43,7 @@ parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('d # Runtime parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)") # FP8 training -parser.add_argument("--fp8", action="store_true", help="enable FP8 training (requires H100+ GPU and torchao)") +parser.add_argument("--fp8", action="store_true", help="enable FP8 training (requires H100+ GPU)") parser.add_argument("--fp8-recipe", type=str, default="tensorwise", choices=["rowwise", "tensorwise"], help="FP8 scaling recipe: tensorwise (faster, recommended) or rowwise (more accurate but slower)") # Model architecture parser.add_argument("--depth", type=int, default=20, help="depth of the Transformer model")