diff --git a/nanochat/gpt.py b/nanochat/gpt.py
index 8779a85..14e7c0e 100644
--- a/nanochat/gpt.py
+++ b/nanochat/gpt.py
@@ -27,11 +27,18 @@ from nanochat.optim import MuonAdamW, DistMuonAdamW
 # Our custom Flash Attention module that automatically uses FA3 on Hopper+ and SDPA fallback elsewhere
 from nanochat.flash_attention import flash_attn
 
-# FP8 imports (optional) — needed by reparam_linear for FP8 path
+# FP8 imports (optional)
+# torchao: used in reparam_linear for efficient N-D tensor handling
+# custom fp8: used for regular Float8Linear layers (simpler, same performance)
 try:
     from torchao.float8.float8_linear import matmul_with_hp_or_float8_args
 except ImportError:
-    pass
+    matmul_with_hp_or_float8_args = None
+
+try:
+    from nanochat.fp8 import Float8Linear
+except ImportError:
+    Float8Linear = None
 
 @dataclass
 class GPTConfig:
@@ -58,15 +65,16 @@ def reparam_linear(module, x, gamma=None, scalar=None):
     gamma: RMSNorm learnable weight, folded into input dim of W (w = w * gamma[None, :])
     scalar: projection scalar, folded into output dim of W (w = scalar[:, None] * w)
 
-    For FP8, dispatches through Float8Linear's internal matmul to preserve FP8 tensor cores.
+    For FP8, uses torchao's matmul which handles N-D tensors efficiently without reshaping.
     """
     w = module.weight
     if gamma is not None:
         w = w * gamma[None, :]
     if scalar is not None:
         w = scalar[:, None] * w
-    # FP8 path: use Float8Linear's internal matmul to preserve FP8 tensor cores
-    if hasattr(module, 'linear_mm_config'):
+    # FP8 path: use torchao's matmul for efficient N-D tensor handling
+    # (torchao handles arbitrary shapes without external reshaping overhead)
+    if hasattr(module, 'linear_mm_config') and matmul_with_hp_or_float8_args is not None:
         return matmul_with_hp_or_float8_args.apply(x, w.t(), module.linear_mm_config, module.config)
     # BF16 path
     return F.linear(x, w)
diff --git a/scripts/base_train.py b/scripts/base_train.py
index f41c3aa..432c998 100644
--- a/scripts/base_train.py
+++ b/scripts/base_train.py
@@ -43,7 +43,7 @@ parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('d
 # Runtime
 parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
 # FP8 training
-parser.add_argument("--fp8", action="store_true", help="enable FP8 training (requires H100+ GPU and torchao)")
+parser.add_argument("--fp8", action="store_true", help="enable FP8 training (requires H100+ GPU and torchao); uses custom fp8 module for layers, torchao for reparam_linear")
 parser.add_argument("--fp8-recipe", type=str, default="tensorwise", choices=["rowwise", "tensorwise"], help="FP8 scaling recipe: tensorwise (faster, recommended) or rowwise (more accurate but slower)")
 # Model architecture
 parser.add_argument("--depth", type=int, default=20, help="depth of the Transformer model")
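
Aside: the gamma/scalar folding in reparam_linear is algebraically equivalent to scaling the activations around a plain matmul, which is why both the BF16 and FP8 branches can fold everything into the weight and still agree. A minimal self-contained check of that identity (shapes and values below are illustrative, not taken from the repo):

import torch
import torch.nn.functional as F

# Identity behind reparam_linear (nanochat/gpt.py): folding gamma into W's
# input dim and scalar into W's output dim equals scaling x by gamma before
# the matmul and scaling the output by scalar after it.
B, T, C_in, C_out = 2, 3, 8, 16
x = torch.randn(B, T, C_in)
w = torch.randn(C_out, C_in)
gamma = torch.randn(C_in)    # stands in for the RMSNorm gain
scalar = torch.randn(C_out)  # stands in for the projection scalar

folded = F.linear(x, scalar[:, None] * (w * gamma[None, :]))  # what reparam_linear computes
unfolded = scalar * F.linear(x * gamma, w)
torch.testing.assert_close(folded, unfolded)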
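
Usage note: with this change, FP8 training is opted into on the base_train.py command line, e.g. --fp8 --fp8-recipe tensorwise (tensorwise being the faster, recommended recipe; rowwise the more accurate but slower one). Whatever launcher wraps that script (torchrun or similar) is the repo's usual invocation and is assumed here, not shown in this diff.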