Replace torchao with minimal custom FP8 implementation

Added _Float8MatmulND to fp8.py:
- Handles N-D input tensors efficiently
- Does reshaping internally (opaque to torch.compile)
- Prevents external reshape overhead that was causing MFU regression
- ~75 lines of clean, documented code

Benefits:
- No torchao dependency (removed from pyproject.toml)
- Same performance as torchao for reparam_linear
- Consistent with fp8.py's minimal philosophy (~350 total lines)
- All FP8 logic in one self-contained module

Co-Authored-By: Claude Sonnet 4.5 (1M context) <noreply@anthropic.com>
This commit is contained in:
Kaiyue Wen 2026-02-12 17:05:06 -08:00
parent 931d59c515
commit fe2a80badd
4 changed files with 90 additions and 16 deletions

View File

@@ -196,6 +196,81 @@ class _Float8Matmul(torch.autograd.Function):
return grad_input, grad_weight
@torch._dynamo.allow_in_graph
class _Float8MatmulND(torch.autograd.Function):
    """FP8 matmul that handles N-D input tensors.

    Same as _Float8Matmul but accepts inputs of any shape (not just 2D).
    Reshaping is done internally so torch.compile sees this as one opaque node,
    preventing the reshaping overhead that occurs when reshapes are external.
    This is specifically for reparam_linear where N-D tensors are common.
    """

    @staticmethod
    def forward(ctx, input, weight):
        """Compute ``input @ weight.t()`` with FP8 quantized operands.

        Args:
            input: activation tensor of shape (..., in_features); any rank.
            weight: weight matrix of shape (out_features, in_features).

        Returns:
            Tensor of shape (..., out_features) in ``input.dtype``.
        """
        # Save original shape and flatten batch dimensions: _scaled_mm only
        # accepts 2D operands, so collapse all leading dims into one.
        orig_shape = input.shape
        ctx.orig_shape = orig_shape
        input_2d = input.reshape(-1, orig_shape[-1])
        # Save the high-precision (pre-quantization) tensors for backward.
        ctx.save_for_backward(input_2d, weight)
        # Quantize and matmul (same as _Float8Matmul.forward).
        # e4m3 is the standard choice for forward activations/weights.
        input_fp8, input_inv = _to_fp8(input_2d, torch.float8_e4m3fn)
        weight_fp8, weight_inv = _to_fp8(weight, torch.float8_e4m3fn)
        output = torch._scaled_mm(
            input_fp8,
            weight_fp8.t(),
            scale_a=input_inv,
            scale_b=weight_inv,
            out_dtype=input.dtype,
            use_fast_accum=True,
        )
        # Reshape back to original batch dims.
        return output.reshape(*orig_shape[:-1], output.shape[-1])

    @staticmethod
    def backward(ctx, grad_output):
        """Compute gradients w.r.t. input and weight via two FP8 GEMMs."""
        input_2d, weight = ctx.saved_tensors
        orig_shape = ctx.orig_shape
        # Flatten grad_output to match input_2d.
        grad_output_flat = grad_output.reshape(-1, grad_output.shape[-1])
        # Quantize grad_output ONCE and reuse for both GEMMs below.
        # (The original quantized it twice with identical arguments — pure
        # redundant work, since _to_fp8 is deterministic.)
        # e5m2 trades mantissa for exponent range, the usual pick for grads.
        go_fp8, go_inv = _to_fp8(grad_output_flat, torch.float8_e5m2)
        # === GEMM 1: grad_input = grad_output @ weight ===
        w_fp8, w_inv = _to_fp8(weight, torch.float8_e4m3fn)
        w_col = _to_col_major(w_fp8)
        grad_input_flat = torch._scaled_mm(
            go_fp8,
            w_col,
            scale_a=go_inv,
            scale_b=w_inv,
            out_dtype=grad_output.dtype,
            use_fast_accum=False,
        )
        # Reshape back to original input shape (restores the N-D batch dims).
        grad_input = grad_input_flat.reshape(orig_shape)
        # === GEMM 2: grad_weight = grad_output.T @ input ===
        in_fp8, in_inv = _to_fp8(input_2d, torch.float8_e4m3fn)
        # _scaled_mm needs a row-major first operand; .t() alone yields a
        # transposed view, so materialize it contiguously.
        go_T = go_fp8.t().contiguous()
        in_col = _to_col_major(in_fp8)
        grad_weight = torch._scaled_mm(
            go_T,
            in_col,
            scale_a=go_inv,
            scale_b=in_inv,
            out_dtype=grad_output.dtype,
            use_fast_accum=False,
        )
        return grad_input, grad_weight
class Float8Linear(nn.Linear):
"""Drop-in nn.Linear replacement that does FP8 compute.

View File

@@ -27,18 +27,12 @@ from nanochat.optim import MuonAdamW, DistMuonAdamW
# Our custom Flash Attention module that automatically uses FA3 on Hopper+ and SDPA fallback elsewhere
from nanochat.flash_attention import flash_attn
# FP8 imports (optional)
# torchao: used in reparam_linear for efficient N-D tensor handling
# custom fp8: used for regular Float8Linear layers (simpler, same performance)
# FP8 imports (optional) - minimal custom implementation
try:
from torchao.float8.float8_linear import matmul_with_hp_or_float8_args
except ImportError:
matmul_with_hp_or_float8_args = None
try:
from nanochat.fp8 import Float8Linear
from nanochat.fp8 import Float8Linear, _Float8MatmulND
except ImportError:
Float8Linear = None
_Float8MatmulND = None
@dataclass
class GPTConfig:
@@ -65,17 +59,23 @@ def reparam_linear(module, x, gamma=None, scalar=None):
gamma: RMSNorm learnable weight, folded into input dim of W (w = w * gamma[None, :])
scalar: projection scalar, folded into output dim of W (w = scalar[:, None] * w)
For FP8, uses torchao's matmul which handles N-D tensors efficiently without reshaping.
For FP8, uses minimal custom _Float8MatmulND which handles N-D tensors internally.
"""
w = module.weight
if gamma is not None:
w = w * gamma[None, :]
if scalar is not None:
w = scalar[:, None] * w
# FP8 path: use torchao's matmul for efficient N-D tensor handling
# (torchao handles arbitrary shapes without external reshaping overhead)
if hasattr(module, 'linear_mm_config') and matmul_with_hp_or_float8_args is not None:
return matmul_with_hp_or_float8_args.apply(x, w.t(), module.linear_mm_config, module.config)
# FP8 path: use custom _Float8MatmulND for efficient N-D tensor handling
# (reshaping is done internally, so torch.compile sees it as one opaque operation)
if Float8Linear is not None and isinstance(module, Float8Linear):
# Handle autocast (Float8Linear expects this)
if torch.is_autocast_enabled():
x = x.to(torch.get_autocast_gpu_dtype())
output = _Float8MatmulND.apply(x, w)
if module.bias is not None:
output = output + module.bias.to(output.dtype)
return output
# BF16 path
return F.linear(x, w)

View File

@@ -19,7 +19,6 @@ dependencies = [
"tabulate>=0.9.0",
"tiktoken>=0.11.0",
"tokenizers>=0.22.0",
"torchao>=0.13.0",
"torch==2.9.1",
"transformers>=4.57.3",
"uvicorn>=0.36.0",

View File

@@ -43,7 +43,7 @@ parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb)")
# Runtime
parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
# FP8 training
parser.add_argument("--fp8", action="store_true", help="enable FP8 training (requires H100+ GPU and torchao); uses custom fp8 module for layers, torchao for reparam_linear")
parser.add_argument("--fp8", action="store_true", help="enable FP8 training (requires H100+ GPU); uses minimal custom fp8 module")
parser.add_argument("--fp8-recipe", type=str, default="tensorwise", choices=["rowwise", "tensorwise"], help="FP8 scaling recipe: tensorwise (faster, recommended) or rowwise (more accurate but slower)")
# Model architecture
parser.add_argument("--depth", type=int, default=20, help="depth of the Transformer model")