Replace torchao with custom fp8 module in gpt.py

- Update reparam_linear to use nanochat.fp8.Float8Linear instead of torchao
- Replace matmul_with_hp_or_float8_args with direct _Float8Matmul.apply call
- Remove torchao dependency mention from base_train.py help text
- Functionally equivalent: both use torch._scaled_mm, custom version ~3% faster

Co-Authored-By: Claude Sonnet 4.5 (1M context) <noreply@anthropic.com>
This commit is contained in:
Kaiyue Wen 2026-02-12 16:25:52 -08:00
parent ee04406ebb
commit 31e5bec402
2 changed files with 16 additions and 5 deletions

View File

@ -29,9 +29,10 @@ from nanochat.flash_attention import flash_attn
# FP8 imports (optional) — needed by reparam_linear for FP8 path
try:
from torchao.float8.float8_linear import matmul_with_hp_or_float8_args
from nanochat.fp8 import Float8Linear, _Float8Matmul
except ImportError:
pass
Float8Linear = None
_Float8Matmul = None
@dataclass
class GPTConfig:
@ -66,8 +67,18 @@ def reparam_linear(module, x, gamma=None, scalar=None):
if scalar is not None:
w = scalar[:, None] * w
# FP8 path: use Float8Linear's internal matmul to preserve FP8 tensor cores
if hasattr(module, 'linear_mm_config'):
return matmul_with_hp_or_float8_args.apply(x, w.t(), module.linear_mm_config, module.config)
if Float8Linear is not None and isinstance(module, Float8Linear):
# Handle autocast similar to Float8Linear.forward
if torch.is_autocast_enabled():
x = x.to(torch.get_autocast_gpu_dtype())
# Flatten batch dimensions for _Float8Matmul
orig_shape = x.shape
input_2d = x.reshape(-1, orig_shape[-1])
output = _Float8Matmul.apply(input_2d, w)
output = output.reshape(*orig_shape[:-1], output.shape[-1])
if module.bias is not None:
output = output + module.bias.to(output.dtype)
return output
# BF16 path
return F.linear(x, w)

View File

@ -43,7 +43,7 @@ parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('d
# Runtime
parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
# FP8 training
parser.add_argument("--fp8", action="store_true", help="enable FP8 training (requires H100+ GPU and torchao)")
parser.add_argument("--fp8", action="store_true", help="enable FP8 training (requires H100+ GPU)")
parser.add_argument("--fp8-recipe", type=str, default="tensorwise", choices=["rowwise", "tensorwise"], help="FP8 scaling recipe: tensorwise (faster, recommended) or rowwise (more accurate but slower)")
# Model architecture
parser.add_argument("--depth", type=int, default=20, help="depth of the Transformer model")