fix: use meta device in disable_fp8 to avoid VRAM spike (#616)

When swapping Float8Linear for Linear in the disable_fp8 context manager,
passing device=fp8_module.weight.device allocates new tensors directly
on the GPU, causing an unnecessary VRAM spike (~1GB for large models).

This fix constructs the replacement module with device='meta' so no
physical memory is allocated, then swaps in the existing weight tensor
by reference. This eliminates the VRAM spike during the evaluation phase.

Fixes issue #592

Co-authored-by: RoomWithOutRoof <roomwithoutroof@sparklab.ai>
commit 47e983eea7 (parent c0dbf1f3ff)
Author: RoomWithOutRoof
Date:   2026-03-26 05:24:57 +08:00 (committed by GitHub)

@@ -218,12 +218,13 @@ def disable_fp8(model):
         return
     # Swap Float8Linear -> Linear (our custom class that casts weights to match input dtype)
+    # Use device="meta" to avoid VRAM spike - the weight tensor will be swapped in afterwards
     for parent, attr_name, fp8_module in fp8_locations:
         linear = Linear(
             fp8_module.in_features,
             fp8_module.out_features,
             bias=fp8_module.bias is not None,
-            device=fp8_module.weight.device,
+            device="meta",  # Use meta device to avoid unnecessary VRAM allocation
             dtype=fp8_module.weight.dtype,
         )
         linear.weight = fp8_module.weight  # share, don't copy
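The pattern in the hunk can be sketched in isolation. This is a minimal, hedged sketch assuming plain PyTorch: a stock nn.Linear stands in for both Float8Linear and the repo's custom Linear subclass (neither is shown in the diff), but the meta-device construction and weight swap work the same way:

```python
import torch
import torch.nn as nn

# Stand-in for the existing Float8Linear module (hypothetical shapes).
fp8_module = nn.Linear(4, 8, bias=False, dtype=torch.bfloat16)

# Build the replacement on the meta device: tensors carry shape/dtype
# metadata only, so no GPU (or CPU) storage is allocated here.
linear = nn.Linear(
    fp8_module.in_features,
    fp8_module.out_features,
    bias=fp8_module.bias is not None,
    device="meta",
    dtype=fp8_module.weight.dtype,
)
assert linear.weight.is_meta  # nothing was materialized

# Re-point the replacement's weight at the existing parameter: this is a
# reference assignment (nn.Module registers the Parameter), not a copy.
linear.weight = fp8_module.weight
```

After the swap, `linear.weight` and `fp8_module.weight` are the same tensor object, so the replacement costs no extra memory beyond the module wrapper itself.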