mirror of
https://github.com/karpathy/nanochat.git
synced 2026-03-09 18:55:30 +00:00
Replace torchao with minimal custom FP8 implementation
Added _Float8MatmulND to fp8.py:
- Handles N-D input tensors efficiently
- Does reshaping internally (opaque to torch.compile)
- Prevents external reshape overhead that was causing MFU regression
- ~75 lines of clean, documented code

Benefits:
- No torchao dependency (removed from pyproject.toml)
- Same performance as torchao for reparam_linear
- Consistent with fp8.py's minimal philosophy (~350 total lines)
- All FP8 logic in one self-contained module

Co-Authored-By: Claude Sonnet 4.5 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
931d59c515
commit
fe2a80badd
|
|
@ -196,6 +196,81 @@ class _Float8Matmul(torch.autograd.Function):
|
|||
return grad_input, grad_weight
|
||||
|
||||
|
||||
@torch._dynamo.allow_in_graph
class _Float8MatmulND(torch.autograd.Function):
    """FP8 matmul that handles N-D input tensors.

    Same as _Float8Matmul but accepts inputs of any shape (not just 2D).
    Reshaping is done internally so torch.compile sees this as one opaque node,
    preventing the reshaping overhead that occurs when reshapes are external.

    This is specifically for reparam_linear where N-D tensors are common.
    """

    @staticmethod
    def forward(ctx, input, weight):
        """Compute input @ weight.T in FP8.

        input: tensor of shape (..., in_features); any number of leading dims.
        weight: tensor of shape (out_features, in_features).
        Returns a tensor of shape (..., out_features) in input.dtype.
        """
        # Save original shape and flatten batch dimensions so _scaled_mm sees 2D.
        orig_shape = input.shape
        ctx.orig_shape = orig_shape
        input_2d = input.reshape(-1, orig_shape[-1])
        ctx.save_for_backward(input_2d, weight)

        # Quantize and matmul (same as _Float8Matmul.forward).
        # e4m3 for activations/weights in the forward pass.
        input_fp8, input_inv = _to_fp8(input_2d, torch.float8_e4m3fn)
        weight_fp8, weight_inv = _to_fp8(weight, torch.float8_e4m3fn)
        output = torch._scaled_mm(
            input_fp8,
            weight_fp8.t(),
            scale_a=input_inv,
            scale_b=weight_inv,
            out_dtype=input.dtype,
            use_fast_accum=True,
        )

        # Reshape back to original batch dims.
        return output.reshape(*orig_shape[:-1], output.shape[-1])

    @staticmethod
    def backward(ctx, grad_output):
        """Return (grad_input, grad_weight).

        grad_input is reshaped back to the original N-D input shape;
        grad_weight matches weight's 2D shape.
        """
        input_2d, weight = ctx.saved_tensors
        orig_shape = ctx.orig_shape

        # Flatten grad_output to match input_2d.
        grad_output_flat = grad_output.reshape(-1, grad_output.shape[-1])

        # Quantize grad_output once (e5m2 for gradients) and reuse it for both
        # GEMMs below. The previous version quantized grad_output_flat twice
        # with identical arguments, which was redundant work: _to_fp8 is
        # deterministic, so both calls produced the same tensor/scale.
        go_fp8, go_inv = _to_fp8(grad_output_flat, torch.float8_e5m2)

        # === GEMM 1: grad_input = grad_output @ weight ===
        w_fp8, w_inv = _to_fp8(weight, torch.float8_e4m3fn)
        w_col = _to_col_major(w_fp8)  # _scaled_mm wants the B operand column-major
        grad_input_flat = torch._scaled_mm(
            go_fp8,
            w_col,
            scale_a=go_inv,
            scale_b=w_inv,
            out_dtype=grad_output.dtype,
            use_fast_accum=False,  # fast accum only in forward
        )
        # Reshape back to original input shape.
        grad_input = grad_input_flat.reshape(orig_shape)

        # === GEMM 2: grad_weight = grad_output.T @ input ===
        in_fp8, in_inv = _to_fp8(input_2d, torch.float8_e4m3fn)
        go_T = go_fp8.t().contiguous()  # A operand must be row-major
        in_col = _to_col_major(in_fp8)
        grad_weight = torch._scaled_mm(
            go_T,
            in_col,
            scale_a=go_inv,
            scale_b=in_inv,
            out_dtype=grad_output.dtype,
            use_fast_accum=False,
        )

        return grad_input, grad_weight
||||
class Float8Linear(nn.Linear):
|
||||
"""Drop-in nn.Linear replacement that does FP8 compute.
|
||||
|
||||
|
|
|
|||
|
|
@ -27,18 +27,12 @@ from nanochat.optim import MuonAdamW, DistMuonAdamW
|
|||
# Our custom Flash Attention module that automatically uses FA3 on Hopper+ and SDPA fallback elsewhere
|
||||
from nanochat.flash_attention import flash_attn
|
||||
|
||||
# FP8 imports (optional)
|
||||
# torchao: used in reparam_linear for efficient N-D tensor handling
|
||||
# custom fp8: used for regular Float8Linear layers (simpler, same performance)
|
||||
# FP8 imports (optional) - minimal custom implementation
|
||||
try:
|
||||
from torchao.float8.float8_linear import matmul_with_hp_or_float8_args
|
||||
except ImportError:
|
||||
matmul_with_hp_or_float8_args = None
|
||||
|
||||
try:
|
||||
from nanochat.fp8 import Float8Linear
|
||||
from nanochat.fp8 import Float8Linear, _Float8MatmulND
|
||||
except ImportError:
|
||||
Float8Linear = None
|
||||
_Float8MatmulND = None
|
||||
|
||||
@dataclass
|
||||
class GPTConfig:
|
||||
|
|
def reparam_linear(module, x, gamma=None, scalar=None):
    """Linear projection with reparameterization scalars folded into the weight.

    gamma: RMSNorm learnable weight, folded into input dim of W (w = w * gamma[None, :])
    scalar: projection scalar, folded into output dim of W (w = scalar[:, None] * w)

    For FP8, uses minimal custom _Float8MatmulND which handles N-D tensors internally.
    """
    w = module.weight
    if gamma is not None:
        w = w * gamma[None, :]
    if scalar is not None:
        w = scalar[:, None] * w
    # FP8 path: use custom _Float8MatmulND for efficient N-D tensor handling
    # (reshaping is done internally, so torch.compile sees it as one opaque operation)
    if Float8Linear is not None and isinstance(module, Float8Linear):
        # Handle autocast (Float8Linear expects this)
        if torch.is_autocast_enabled():
            x = x.to(torch.get_autocast_gpu_dtype())
        output = _Float8MatmulND.apply(x, w)
        if module.bias is not None:
            output = output + module.bias.to(output.dtype)
        return output
    # BF16 path
    return F.linear(x, w)
|
|
|
|||
|
|
@ -19,7 +19,6 @@ dependencies = [
|
|||
"tabulate>=0.9.0",
|
||||
"tiktoken>=0.11.0",
|
||||
"tokenizers>=0.22.0",
|
||||
"torchao>=0.13.0",
|
||||
"torch==2.9.1",
|
||||
"transformers>=4.57.3",
|
||||
"uvicorn>=0.36.0",
|
||||
|
|
|
|||
|
|
@ -43,7 +43,7 @@ parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('d
|
|||
# Runtime
|
||||
parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
|
||||
# FP8 training
|
||||
parser.add_argument("--fp8", action="store_true", help="enable FP8 training (requires H100+ GPU and torchao); uses custom fp8 module for layers, torchao for reparam_linear")
|
||||
parser.add_argument("--fp8", action="store_true", help="enable FP8 training (requires H100+ GPU); uses minimal custom fp8 module")
|
||||
parser.add_argument("--fp8-recipe", type=str, default="tensorwise", choices=["rowwise", "tensorwise"], help="FP8 scaling recipe: tensorwise (faster, recommended) or rowwise (more accurate but slower)")
|
||||
# Model architecture
|
||||
parser.add_argument("--depth", type=int, default=20, help="depth of the Transformer model")
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user