diff --git a/nanochat/checkpoint_manager.py b/nanochat/checkpoint_manager.py index f400d47..262ff97 100644 --- a/nanochat/checkpoint_manager.py +++ b/nanochat/checkpoint_manager.py @@ -65,6 +65,12 @@ def build_model(checkpoint_dir, step, device, phase): """ assert phase in ["train", "eval"], f"Invalid phase: {phase}" model_data, optimizer_data, meta_data = load_checkpoint(checkpoint_dir, step, device, load_optimizer=False) + if device.type == "cpu": + # Convert bfloat16 tensors to float for CPU inference + model_data = { + k: v.float() if v.dtype == torch.bfloat16 else v + for k, v in model_data.items() + } # Hack: fix torch compile issue, which prepends all keys with _orig_mod. model_data = {k.lstrip("_orig_mod."): v for k, v in model_data.items()} model_config_kwargs = meta_data["model_config"]