handle bf16 on MPS by casting to fp32 during load checkpoint

Commit d1558c7873 by Andrej, 2025-11-04 09:42:50 -08:00 (committed by GitHub)


@@ -65,7 +65,7 @@ def build_model(checkpoint_dir, step, device, phase):
     """
     assert phase in ["train", "eval"], f"Invalid phase: {phase}"
     model_data, optimizer_data, meta_data = load_checkpoint(checkpoint_dir, step, device, load_optimizer=False)
-    if device.type == "cpu":
+    if device.type in {"cpu", "mps"}:
         # Convert bfloat16 tensors to float for CPU inference
         model_data = {
             k: v.float() if v.dtype == torch.bfloat16 else v
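
For context, a minimal standalone sketch of the same dtype handling (the helper name, checkpoint path, and usage below are illustrative and not part of this commit): bfloat16 weights are upcast to float32 whenever the target device is CPU or MPS, since those backends have limited bfloat16 support in many PyTorch builds.

import torch

def cast_bf16_to_fp32(state_dict, device):
    # Upcast bfloat16 tensors to float32 on backends with limited bf16 support.
    if device.type in {"cpu", "mps"}:
        return {k: v.float() if v.dtype == torch.bfloat16 else v
                for k, v in state_dict.items()}
    return state_dict

# Illustrative usage: load a bf16 checkpoint and prepare it for an Apple Silicon Mac.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
state_dict = torch.load("checkpoint.pt", map_location="cpu")  # path is hypothetical
state_dict = cast_bf16_to_fp32(state_dict, device)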