Mirror of https://github.com/karpathy/nanochat.git (synced 2026-03-07 09:50:28 +00:00)
fix minor bug in fp8 application to skip tiny matmuls
This commit is contained in:
parent ad55575326
commit bac5a35dd7
@@ -170,20 +170,22 @@ if args.fp8:
     from torchao.float8 import Float8LinearConfig, convert_to_float8_training
     import torch.nn as nn

-    # Filter: only convert layers with dimensions divisible by 16 (FP8 hardware requirement)
+    # Filter: dims must be divisible by 16 (FP8 hardware requirement) and large enough
     def fp8_module_filter(mod: nn.Module, fqn: str) -> bool:
         if not isinstance(mod, nn.Linear):
             return False
         # FP8 requires both in_features and out_features divisible by 16
         if mod.in_features % 16 != 0 or mod.out_features % 16 != 0:
             return False
+        if min(mod.in_features, mod.out_features) < 128:
+            return False
         return True

     fp8_config = Float8LinearConfig.from_recipe_name(args.fp8_recipe)
+    num_linear = sum(1 for m in model.modules() if isinstance(m, nn.Linear))
     convert_to_float8_training(model, config=fp8_config, module_filter_fn=fp8_module_filter)
-    num_fp8_layers = sum(1 for m in model.modules() if 'Float8' in type(m).__name__)
-    num_skipped = sum(1 for m in model.modules() if isinstance(m, nn.Linear)) - num_fp8_layers
-    print0(f"✓ FP8 training enabled ({args.fp8_recipe} scaling) - converted {num_fp8_layers} layers, skipped {num_skipped} (dims not divisible by 16)")
+    num_fp8 = sum(1 for m in model.modules() if 'Float8' in type(m).__name__)
+    num_skipped = num_linear - num_fp8
+    print0(f"✓ FP8 training enabled ({args.fp8_recipe} scaling) - converted {num_fp8}/{num_linear} linear layers, skipped {num_skipped} (too small)")

     # Context manager to temporarily disable FP8 so that model evaluation remains in BF16
     @contextmanager
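To sanity-check the new filter outside the training script, here is a minimal, self-contained sketch of the same logic (plain PyTorch, no torchao needed; the layer shapes and fqn strings are made up for illustration):

import torch.nn as nn

# Same rules as fp8_module_filter in the diff above: convert a Linear layer
# only if both dims are divisible by 16 (FP8 hardware requirement) and the
# smaller dim is at least 128 (tiny matmuls aren't worth the FP8 overhead).
def fp8_module_filter(mod: nn.Module, fqn: str) -> bool:
    if not isinstance(mod, nn.Linear):
        return False
    if mod.in_features % 16 != 0 or mod.out_features % 16 != 0:
        return False
    if min(mod.in_features, mod.out_features) < 128:
        return False
    return True

# Hypothetical shapes, one per branch of the filter:
print(fp8_module_filter(nn.Linear(768, 768), "mlp.c_fc"))  # True: aligned and large
print(fp8_module_filter(nn.Linear(768, 50), "head"))       # False: 50 % 16 != 0
print(fp8_module_filter(nn.Linear(64, 64), "tiny"))        # False: min dim < 128

The 128 threshold is the point of this commit: converting very small linears to FP8 adds per-layer scaling/casting overhead that can outweigh any matmul speedup, so those layers stay in BF16.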
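For the end-to-end conversion, a hedged sketch of how the filter plugs into torchao (assumes torchao is installed; the toy model and the "tensorwise" recipe name are illustrative stand-ins, not nanochat's actual model or default, and actually training in FP8 still requires supporting GPU hardware):

import torch.nn as nn
from torchao.float8 import Float8LinearConfig, convert_to_float8_training

def fp8_module_filter(mod: nn.Module, fqn: str) -> bool:
    # Compact form of the filter from the diff above.
    return (isinstance(mod, nn.Linear)
            and mod.in_features % 16 == 0 and mod.out_features % 16 == 0
            and min(mod.in_features, mod.out_features) >= 128)

# Toy model with one layer per outcome (hypothetical shapes).
model = nn.Sequential(
    nn.Linear(1024, 4096),  # aligned and large -> converted
    nn.GELU(),
    nn.Linear(4096, 1024),  # aligned and large -> converted
    nn.Linear(1024, 65),    # 65 % 16 != 0 -> skipped
)

config = Float8LinearConfig.from_recipe_name("tensorwise")
num_linear = sum(1 for m in model.modules() if isinstance(m, nn.Linear))  # count before the swap
convert_to_float8_training(model, config=config, module_filter_fn=fp8_module_filter)
num_fp8 = sum(1 for m in model.modules() if "Float8" in type(m).__name__)
print(f"converted {num_fp8}/{num_linear} linear layers")  # expect 2/3 here

As in the diff, num_linear is counted before the swap so the log can report a converted/total ratio.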