From 7a40ee77b4695ccb7350a679230eb6a7f8a6ae29 Mon Sep 17 00:00:00 2001
From: Dipesh Babu
Date: Mon, 3 Nov 2025 16:00:56 -0500
Subject: [PATCH] fix: cast bf16 to fp32 on MPS (like CPU) to avoid dtype issues

---
 nanochat/checkpoint_manager.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/nanochat/checkpoint_manager.py b/nanochat/checkpoint_manager.py
index a9327c4..2fcb01b 100644
--- a/nanochat/checkpoint_manager.py
+++ b/nanochat/checkpoint_manager.py
@@ -65,7 +65,7 @@ def build_model(checkpoint_dir, step, device, phase):
     """
     assert phase in ["train", "eval"], f"Invalid phase: {phase}"
     model_data, optimizer_data, meta_data = load_checkpoint(checkpoint_dir, step, device, load_optimizer=False)
-    if device.type == "cpu":
+    if device.type in {"cpu", "mps"}:
         # Convert bfloat16 tensors to float for CPU inference
         model_data = {
             k: v.float() if v.dtype == torch.bfloat16 else v
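
Note for reviewers: below is a minimal standalone sketch of the behavior this patch extends, so it can be tried without a nanochat checkpoint. The helper name maybe_upcast_bf16 is hypothetical and not part of nanochat's API; it mirrors the patched branch of build_model() under the assumption (stated in the subject line) that bfloat16 tensors should be upcast to float32 on both the cpu and mps backends.

import torch

def maybe_upcast_bf16(model_data, device):
    # Hypothetical helper mirroring the patched branch in build_model():
    # on the cpu and mps device types, upcast bfloat16 tensors to float32
    # before inference; on other backends (e.g. cuda), keep dtypes as-is.
    if device.type in {"cpu", "mps"}:
        return {
            k: v.float() if v.dtype == torch.bfloat16 else v
            for k, v in model_data.items()
        }
    return model_data

# Usage: a bf16 tensor is upcast on cpu/mps and left untouched elsewhere.
state = {"w": torch.zeros(2, 2, dtype=torch.bfloat16)}
print(maybe_upcast_bf16(state, torch.device("cpu"))["w"].dtype)  # torch.float32

Treating mps like cpu here means the MPS path takes the same float32 conversion the CPU path already used, which is how the patch avoids the dtype issues its subject line refers to.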