mirror of
https://github.com/karpathy/nanochat.git
synced 2025-12-06 04:12:13 +00:00
handle bf16 on MPS by casting to fp32 during load checkpoint
This commit is contained in:
commit
d1558c7873
|
|
@ -65,7 +65,7 @@ def build_model(checkpoint_dir, step, device, phase):
|
||||||
"""
|
"""
|
||||||
assert phase in ["train", "eval"], f"Invalid phase: {phase}"
|
assert phase in ["train", "eval"], f"Invalid phase: {phase}"
|
||||||
model_data, optimizer_data, meta_data = load_checkpoint(checkpoint_dir, step, device, load_optimizer=False)
|
model_data, optimizer_data, meta_data = load_checkpoint(checkpoint_dir, step, device, load_optimizer=False)
|
||||||
if device.type == "cpu":
|
if device.type in {"cpu", "mps"}:
|
||||||
# Convert bfloat16 tensors to float for CPU inference
|
# Convert bfloat16 tensors to float for CPU inference
|
||||||
model_data = {
|
model_data = {
|
||||||
k: v.float() if v.dtype == torch.bfloat16 else v
|
k: v.float() if v.dtype == torch.bfloat16 else v
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user