fix: cast bf16 to fp32 on MPS (like CPU) to avoid dtype issues

This commit is contained in:
Dipesh Babu 2025-11-03 16:00:56 -05:00
parent a83646e098
commit 7a40ee77b4


@@ -65,7 +65,7 @@ def build_model(checkpoint_dir, step, device, phase):
     """
     assert phase in ["train", "eval"], f"Invalid phase: {phase}"
     model_data, optimizer_data, meta_data = load_checkpoint(checkpoint_dir, step, device, load_optimizer=False)
-    if device.type == "cpu":
+    if device.type in {"cpu", "mps"}:
         # Convert bfloat16 tensors to float for CPU inference
         model_data = {
             k: v.float() if v.dtype == torch.bfloat16 else v
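
For context, a minimal standalone sketch of the same upcast, assuming a hypothetical cast_bf16_state_dict helper and a plain state-dict checkpoint (names not from the repository): MPS, like CPU, lacks full bfloat16 kernel coverage, so bf16 tensors are upcast to fp32 before inference on those backends.

# Minimal sketch (hypothetical helper, not the repository's code): upcast any
# bfloat16 tensors in a loaded state dict to float32 when targeting CPU or MPS.
import torch

def cast_bf16_state_dict(state_dict: dict, device: torch.device) -> dict:
    if device.type in {"cpu", "mps"}:
        return {
            k: v.float() if isinstance(v, torch.Tensor) and v.dtype == torch.bfloat16 else v
            for k, v in state_dict.items()
        }
    return state_dict

# Usage (hypothetical checkpoint path):
# model_data = torch.load("checkpoint.pt", map_location="cpu")
# model_data = cast_bf16_state_dict(model_data, torch.device("mps"))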