diff --git a/scripts/base_train.py b/scripts/base_train.py index c7683c9..a161c47 100644 --- a/scripts/base_train.py +++ b/scripts/base_train.py @@ -218,12 +218,13 @@ def disable_fp8(model): return # Swap Float8Linear -> Linear (our custom class that casts weights to match input dtype) + # Use device="meta" to avoid VRAM spike - the weight tensor will be swapped in afterwards for parent, attr_name, fp8_module in fp8_locations: linear = Linear( fp8_module.in_features, fp8_module.out_features, bias=fp8_module.bias is not None, - device=fp8_module.weight.device, + device="meta", dtype=fp8_module.weight.dtype, ) linear.weight = fp8_module.weight # share, don't copy