Use hybrid FP8 approach: torchao for reparam_linear, custom fp8 for layers

- reparam_linear: uses torchao for efficient N-D tensor handling without reshaping
- Float8Linear layers: uses custom fp8 module (simpler, same performance)
- This gives us the best of both: high MFU and minimal dependencies

Co-Authored-By: Claude Sonnet 4.5 (1M context) <noreply@anthropic.com>
This commit is contained in:
Kaiyue Wen 2026-02-12 16:59:52 -08:00
parent 29487517ed
commit 931d59c515
2 changed files with 14 additions and 6 deletions

View File

@@ -27,11 +27,18 @@ from nanochat.optim import MuonAdamW, DistMuonAdamW
# Our custom Flash Attention module that automatically uses FA3 on Hopper+ and SDPA fallback elsewhere
from nanochat.flash_attention import flash_attn
# FP8 imports (optional) — needed by reparam_linear for FP8 path
# FP8 imports (optional)
# torchao: used in reparam_linear for efficient N-D tensor handling
# custom fp8: used for regular Float8Linear layers (simpler, same performance)
try:
from torchao.float8.float8_linear import matmul_with_hp_or_float8_args
except ImportError:
pass
matmul_with_hp_or_float8_args = None
try:
from nanochat.fp8 import Float8Linear
except ImportError:
Float8Linear = None
@dataclass
class GPTConfig:
@@ -58,15 +65,16 @@ def reparam_linear(module, x, gamma=None, scalar=None):
gamma: RMSNorm learnable weight, folded into input dim of W (w = w * gamma[None, :])
scalar: projection scalar, folded into output dim of W (w = scalar[:, None] * w)
For FP8, dispatches through Float8Linear's internal matmul to preserve FP8 tensor cores.
For FP8, uses torchao's matmul which handles N-D tensors efficiently without reshaping.
"""
w = module.weight
if gamma is not None:
w = w * gamma[None, :]
if scalar is not None:
w = scalar[:, None] * w
# FP8 path: use Float8Linear's internal matmul to preserve FP8 tensor cores
if hasattr(module, 'linear_mm_config'):
# FP8 path: use torchao's matmul for efficient N-D tensor handling
# (torchao handles arbitrary shapes without external reshaping overhead)
if hasattr(module, 'linear_mm_config') and matmul_with_hp_or_float8_args is not None:
return matmul_with_hp_or_float8_args.apply(x, w.t(), module.linear_mm_config, module.config)
# BF16 path
return F.linear(x, w)

View File

@@ -43,7 +43,7 @@ parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('d
# Runtime
parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
# FP8 training
parser.add_argument("--fp8", action="store_true", help="enable FP8 training (requires H100+ GPU and torchao)")
parser.add_argument("--fp8", action="store_true", help="enable FP8 training (requires H100+ GPU and torchao); uses custom fp8 module for layers, torchao for reparam_linear")
parser.add_argument("--fp8-recipe", type=str, default="tensorwise", choices=["rowwise", "tensorwise"], help="FP8 scaling recipe: tensorwise (faster, recommended) or rowwise (more accurate but slower)")
# Model architecture
parser.add_argument("--depth", type=int, default=20, help="depth of the Transformer model")