From 47e983eea7513d545fb6becc8b32756b6c43d06b Mon Sep 17 00:00:00 2001 From: RoomWithOutRoof <166608075+Jah-yee@users.noreply.github.com> Date: Thu, 26 Mar 2026 05:24:57 +0800 Subject: [PATCH] fix: use meta device in disable_fp8 to avoid VRAM spike (#616) When swapping Float8Linear to Linear in disable_fp8 context manager, using device=fp8_module.weight.device directly allocates new tensors on GPU, causing unnecessary VRAM spike (~1GB for large models). This fix uses device='meta' to avoid physical memory allocation, then swaps in the weight tensor reference. This eliminates the unnecessary VRAM spike during evaluation phase. Fixes issue #592 Co-authored-by: RoomWithOutRoof --- scripts/base_train.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/scripts/base_train.py b/scripts/base_train.py index c7683c9..a161c47 100644 --- a/scripts/base_train.py +++ b/scripts/base_train.py @@ -218,12 +218,13 @@ def disable_fp8(model): return # Swap Float8Linear -> Linear (our custom class that casts weights to match input dtype) + # Use device="meta" to avoid VRAM spike - the weight tensor will be swapped in afterwards for parent, attr_name, fp8_module in fp8_locations: linear = Linear( fp8_module.in_features, fp8_module.out_features, bias=fp8_module.bias is not None, - device=fp8_module.weight.device, + device="meta", # Use meta device to avoid unnecessary VRAM allocation dtype=fp8_module.weight.dtype, ) linear.weight = fp8_module.weight # share, don't copy