mirror of
https://github.com/karpathy/nanochat.git
synced 2026-04-02 05:35:19 +00:00
Merge 7640a1781f into 1b1cc3c599
This commit is contained in:
commit
71d90265f5
|
|
@ -218,12 +218,13 @@ def disable_fp8(model):
|
|||
return
|
||||
|
||||
# Swap Float8Linear -> Linear (our custom class that casts weights to match input dtype)
|
||||
# Use device="meta" to avoid VRAM spike - the weight tensor will be swapped in afterwards
|
||||
for parent, attr_name, fp8_module in fp8_locations:
|
||||
linear = Linear(
|
||||
fp8_module.in_features,
|
||||
fp8_module.out_features,
|
||||
bias=fp8_module.bias is not None,
|
||||
device=fp8_module.weight.device,
|
||||
device="meta", # Use meta device to avoid unnecessary VRAM allocation
|
||||
dtype=fp8_module.weight.dtype,
|
||||
)
|
||||
linear.weight = fp8_module.weight # share, don't copy
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user