when loading models on CPU, convert tensors from bfloat16 to float

This commit is contained in:
Andrej 2025-11-02 07:58:56 -08:00 committed by GitHub
commit d1ac0b2d07
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -65,6 +65,12 @@ def build_model(checkpoint_dir, step, device, phase):
"""
assert phase in ["train", "eval"], f"Invalid phase: {phase}"
model_data, optimizer_data, meta_data = load_checkpoint(checkpoint_dir, step, device, load_optimizer=False)
if device.type == "cpu":
# Convert bfloat16 tensors to float for CPU inference
model_data = {
k: v.float() if v.dtype == torch.bfloat16 else v
for k, v in model_data.items()
}
# Hack: fix torch compile issue, which prepends all keys with _orig_mod.
model_data = {k.lstrip("_orig_mod."): v for k, v in model_data.items()}
model_config_kwargs = meta_data["model_config"]