mirror of
https://github.com/karpathy/nanochat.git
synced 2026-03-20 03:43:20 +00:00
Replace torchao with custom fp8 module in gpt.py
- Update reparam_linear to use nanochat.fp8.Float8Linear instead of torchao - Replace matmul_with_hp_or_float8_args with direct _Float8Matmul.apply call - Remove torchao dependency mention from base_train.py help text - Functionally equivalent: both use torch._scaled_mm, custom version ~3% faster Co-Authored-By: Claude Sonnet 4.5 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
ee04406ebb
commit
31e5bec402
|
|
@ -29,9 +29,10 @@ from nanochat.flash_attention import flash_attn
|
|||
|
||||
# FP8 imports (optional) — needed by reparam_linear for FP8 path
|
||||
try:
|
||||
from torchao.float8.float8_linear import matmul_with_hp_or_float8_args
|
||||
from nanochat.fp8 import Float8Linear, _Float8Matmul
|
||||
except ImportError:
|
||||
pass
|
||||
Float8Linear = None
|
||||
_Float8Matmul = None
|
||||
|
||||
@dataclass
|
||||
class GPTConfig:
|
||||
|
|
@ -66,8 +67,18 @@ def reparam_linear(module, x, gamma=None, scalar=None):
|
|||
if scalar is not None:
|
||||
w = scalar[:, None] * w
|
||||
# FP8 path: use Float8Linear's internal matmul to preserve FP8 tensor cores
|
||||
if hasattr(module, 'linear_mm_config'):
|
||||
return matmul_with_hp_or_float8_args.apply(x, w.t(), module.linear_mm_config, module.config)
|
||||
if Float8Linear is not None and isinstance(module, Float8Linear):
|
||||
# Handle autocast similar to Float8Linear.forward
|
||||
if torch.is_autocast_enabled():
|
||||
x = x.to(torch.get_autocast_gpu_dtype())
|
||||
# Flatten batch dimensions for _Float8Matmul
|
||||
orig_shape = x.shape
|
||||
input_2d = x.reshape(-1, orig_shape[-1])
|
||||
output = _Float8Matmul.apply(input_2d, w)
|
||||
output = output.reshape(*orig_shape[:-1], output.shape[-1])
|
||||
if module.bias is not None:
|
||||
output = output + module.bias.to(output.dtype)
|
||||
return output
|
||||
# BF16 path
|
||||
return F.linear(x, w)
|
||||
|
||||
|
|
|
|||
|
|
@ -43,7 +43,7 @@ parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('d
|
|||
# Runtime
|
||||
parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
|
||||
# FP8 training
|
||||
parser.add_argument("--fp8", action="store_true", help="enable FP8 training (requires H100+ GPU and torchao)")
|
||||
parser.add_argument("--fp8", action="store_true", help="enable FP8 training (requires H100+ GPU)")
|
||||
parser.add_argument("--fp8-recipe", type=str, default="tensorwise", choices=["rowwise", "tensorwise"], help="FP8 scaling recipe: tensorwise (faster, recommended) or rowwise (more accurate but slower)")
|
||||
# Model architecture
|
||||
parser.add_argument("--depth", type=int, default=20, help="depth of the Transformer model")
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user