mirror of
https://github.com/karpathy/nanochat.git
synced 2026-03-10 19:25:31 +00:00
Revert to torchao for FP8 training to fix MFU regression
The custom fp8 module had a performance issue in reparam_linear: it was doing reshape→matmul→reshape on every linear layer, and torch.compile couldn't fuse these operations because _Float8Matmul was marked @allow_in_graph (opaque to compiler). torchao's matmul_with_hp_or_float8_args handles N-D tensors directly without external reshaping, allowing better fusion opportunities and higher MFU. Co-Authored-By: Claude Sonnet 4.5 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
31e5bec402
commit
29487517ed
|
|
@ -29,10 +29,9 @@ from nanochat.flash_attention import flash_attn
|
|||
|
||||
# FP8 imports (optional) — needed by reparam_linear for FP8 path
|
||||
try:
|
||||
from nanochat.fp8 import Float8Linear, _Float8Matmul
|
||||
from torchao.float8.float8_linear import matmul_with_hp_or_float8_args
|
||||
except ImportError:
|
||||
Float8Linear = None
|
||||
_Float8Matmul = None
|
||||
pass
|
||||
|
||||
@dataclass
|
||||
class GPTConfig:
|
||||
|
|
@ -67,18 +66,8 @@ def reparam_linear(module, x, gamma=None, scalar=None):
|
|||
if scalar is not None:
|
||||
w = scalar[:, None] * w
|
||||
# FP8 path: use Float8Linear's internal matmul to preserve FP8 tensor cores
|
||||
if Float8Linear is not None and isinstance(module, Float8Linear):
|
||||
# Handle autocast similar to Float8Linear.forward
|
||||
if torch.is_autocast_enabled():
|
||||
x = x.to(torch.get_autocast_gpu_dtype())
|
||||
# Flatten batch dimensions for _Float8Matmul
|
||||
orig_shape = x.shape
|
||||
input_2d = x.reshape(-1, orig_shape[-1])
|
||||
output = _Float8Matmul.apply(input_2d, w)
|
||||
output = output.reshape(*orig_shape[:-1], output.shape[-1])
|
||||
if module.bias is not None:
|
||||
output = output + module.bias.to(output.dtype)
|
||||
return output
|
||||
if hasattr(module, 'linear_mm_config'):
|
||||
return matmul_with_hp_or_float8_args.apply(x, w.t(), module.linear_mm_config, module.config)
|
||||
# BF16 path
|
||||
return F.linear(x, w)
|
||||
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ dependencies = [
|
|||
"tabulate>=0.9.0",
|
||||
"tiktoken>=0.11.0",
|
||||
"tokenizers>=0.22.0",
|
||||
"torchao>=0.13.0",
|
||||
"torch==2.9.1",
|
||||
"transformers>=4.57.3",
|
||||
"uvicorn>=0.36.0",
|
||||
|
|
|
|||
|
|
@ -43,7 +43,7 @@ parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('d
|
|||
# Runtime
|
||||
parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
|
||||
# FP8 training
|
||||
parser.add_argument("--fp8", action="store_true", help="enable FP8 training (requires H100+ GPU)")
|
||||
parser.add_argument("--fp8", action="store_true", help="enable FP8 training (requires H100+ GPU and torchao)")
|
||||
parser.add_argument("--fp8-recipe", type=str, default="tensorwise", choices=["rowwise", "tensorwise"], help="FP8 scaling recipe: tensorwise (faster, recommended) or rowwise (more accurate but slower)")
|
||||
# Model architecture
|
||||
parser.add_argument("--depth", type=int, default=20, help="depth of the Transformer model")
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user