From 6e369829781a6a37244e466dfefd06959f77a402 Mon Sep 17 00:00:00 2001 From: Fabian Date: Thu, 30 Apr 2026 17:23:42 +0100 Subject: [PATCH] use BF16 --- scripts/base_train.py | 48 +++++++++++++++++++++++++++---------------- 1 file changed, 30 insertions(+), 18 deletions(-) diff --git a/scripts/base_train.py b/scripts/base_train.py index a161c477..d39af838 100644 --- a/scripts/base_train.py +++ b/scripts/base_train.py @@ -169,27 +169,39 @@ if args.fp8: if device_type != "cuda": print0("Warning: FP8 training requires CUDA, ignoring --fp8 flag") else: + # FP8 path requires Hopper+ kernels. A100 is SM80 and will fail + # during Triton/Inductor compilation with unsupported fp8 dtypes. + major, minor = torch.cuda.get_device_capability(device) + if major < 9: + print0( + f"Warning: --fp8 requested on SM{major}{minor}, but FP8 training " + "requires Hopper (SM90+) in this codepath. Disabling FP8 and " + "continuing with BF16." + ) + args.fp8 = False + user_config["fp8"] = False + else: # our custom fp8 is simpler than torchao, written for exact API compatibility - from nanochat.fp8 import Float8LinearConfig, convert_to_float8_training - # from torchao.float8 import Float8LinearConfig, convert_to_float8_training - import torch.nn as nn + from nanochat.fp8 import Float8LinearConfig, convert_to_float8_training + # from torchao.float8 import Float8LinearConfig, convert_to_float8_training + import torch.nn as nn - # Filter: dims must be divisible by 16 (FP8 hardware requirement) large enough - def fp8_module_filter(mod: nn.Module, fqn: str) -> bool: - if not isinstance(mod, nn.Linear): - return False - if mod.in_features % 16 != 0 or mod.out_features % 16 != 0: - return False - if min(mod.in_features, mod.out_features) < 128: - return False - return True + # Filter: dims must be divisible by 16 (FP8 hardware requirement) large enough + def fp8_module_filter(mod: nn.Module, fqn: str) -> bool: + if not isinstance(mod, nn.Linear): + return False + if mod.in_features % 16 != 0 or mod.out_features % 16 != 0: + return False + if min(mod.in_features, mod.out_features) < 128: + return False + return True - fp8_config = Float8LinearConfig.from_recipe_name(args.fp8_recipe) - num_linear = sum(1 for m in model.modules() if isinstance(m, nn.Linear)) - convert_to_float8_training(model, config=fp8_config, module_filter_fn=fp8_module_filter) - num_fp8 = sum(1 for m in model.modules() if 'Float8' in type(m).__name__) - num_skipped = num_linear - num_fp8 - print0(f"✓ FP8 training enabled ({args.fp8_recipe} scaling) - converted {num_fp8}/{num_linear} linear layers, skipped {num_skipped} (too small)") + fp8_config = Float8LinearConfig.from_recipe_name(args.fp8_recipe) + num_linear = sum(1 for m in model.modules() if isinstance(m, nn.Linear)) + convert_to_float8_training(model, config=fp8_config, module_filter_fn=fp8_module_filter) + num_fp8 = sum(1 for m in model.modules() if 'Float8' in type(m).__name__) + num_skipped = num_linear - num_fp8 + print0(f"✓ FP8 training enabled ({args.fp8_recipe} scaling) - converted {num_fp8}/{num_linear} linear layers, skipped {num_skipped} (too small)") # Context manager to temporarily disable FP8 so that model evaluation remains in BF16 @contextmanager