mirror of
https://github.com/karpathy/nanochat.git
synced 2026-05-28 10:38:37 +00:00
use BF16
This commit is contained in:
parent
6bf12998d8
commit
6e36982978
|
|
@ -169,27 +169,39 @@ if args.fp8:
|
|||
if device_type != "cuda":
|
||||
print0("Warning: FP8 training requires CUDA, ignoring --fp8 flag")
|
||||
else:
|
||||
# FP8 path requires Hopper+ kernels. A100 is SM80 and will fail
|
||||
# during Triton/Inductor compilation with unsupported fp8 dtypes.
|
||||
major, minor = torch.cuda.get_device_capability(device)
|
||||
if major < 9:
|
||||
print0(
|
||||
f"Warning: --fp8 requested on SM{major}{minor}, but FP8 training "
|
||||
"requires Hopper (SM90+) in this codepath. Disabling FP8 and "
|
||||
"continuing with BF16."
|
||||
)
|
||||
args.fp8 = False
|
||||
user_config["fp8"] = False
|
||||
else:
|
||||
# our custom fp8 is simpler than torchao, written for exact API compatibility
|
||||
from nanochat.fp8 import Float8LinearConfig, convert_to_float8_training
|
||||
# from torchao.float8 import Float8LinearConfig, convert_to_float8_training
|
||||
import torch.nn as nn
|
||||
from nanochat.fp8 import Float8LinearConfig, convert_to_float8_training
|
||||
# from torchao.float8 import Float8LinearConfig, convert_to_float8_training
|
||||
import torch.nn as nn
|
||||
|
||||
# Filter: dims must be divisible by 16 (FP8 hardware requirement) large enough
|
||||
def fp8_module_filter(mod: nn.Module, fqn: str) -> bool:
|
||||
if not isinstance(mod, nn.Linear):
|
||||
return False
|
||||
if mod.in_features % 16 != 0 or mod.out_features % 16 != 0:
|
||||
return False
|
||||
if min(mod.in_features, mod.out_features) < 128:
|
||||
return False
|
||||
return True
|
||||
# Filter: dims must be divisible by 16 (FP8 hardware requirement) large enough
|
||||
def fp8_module_filter(mod: nn.Module, fqn: str) -> bool:
|
||||
if not isinstance(mod, nn.Linear):
|
||||
return False
|
||||
if mod.in_features % 16 != 0 or mod.out_features % 16 != 0:
|
||||
return False
|
||||
if min(mod.in_features, mod.out_features) < 128:
|
||||
return False
|
||||
return True
|
||||
|
||||
fp8_config = Float8LinearConfig.from_recipe_name(args.fp8_recipe)
|
||||
num_linear = sum(1 for m in model.modules() if isinstance(m, nn.Linear))
|
||||
convert_to_float8_training(model, config=fp8_config, module_filter_fn=fp8_module_filter)
|
||||
num_fp8 = sum(1 for m in model.modules() if 'Float8' in type(m).__name__)
|
||||
num_skipped = num_linear - num_fp8
|
||||
print0(f"✓ FP8 training enabled ({args.fp8_recipe} scaling) - converted {num_fp8}/{num_linear} linear layers, skipped {num_skipped} (too small)")
|
||||
fp8_config = Float8LinearConfig.from_recipe_name(args.fp8_recipe)
|
||||
num_linear = sum(1 for m in model.modules() if isinstance(m, nn.Linear))
|
||||
convert_to_float8_training(model, config=fp8_config, module_filter_fn=fp8_module_filter)
|
||||
num_fp8 = sum(1 for m in model.modules() if 'Float8' in type(m).__name__)
|
||||
num_skipped = num_linear - num_fp8
|
||||
print0(f"✓ FP8 training enabled ({args.fp8_recipe} scaling) - converted {num_fp8}/{num_linear} linear layers, skipped {num_skipped} (too small)")
|
||||
|
||||
# Context manager to temporarily disable FP8 so that model evaluation remains in BF16
|
||||
@contextmanager
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user