mirror of
https://github.com/karpathy/nanochat.git
synced 2025-12-06 04:12:13 +00:00
when loading models on CPU, convert tensors from bfloat16 to float
This commit is contained in:
commit
d1ac0b2d07
|
|
@ -65,6 +65,12 @@ def build_model(checkpoint_dir, step, device, phase):
|
|||
"""
|
||||
# --- body fragment of build_model(checkpoint_dir, step, device, phase) ---
# Loads model weights from a checkpoint and normalizes them for the target device.
assert phase in ["train", "eval"], f"Invalid phase: {phase}"
model_data, optimizer_data, meta_data = load_checkpoint(checkpoint_dir, step, device, load_optimizer=False)
if device.type == "cpu":
    # bfloat16 has poor/absent kernel support on CPU; upcast those tensors to
    # float32 so CPU inference works. Non-bf16 tensors pass through untouched.
    model_data = {
        k: v.float() if v.dtype == torch.bfloat16 else v
        for k, v in model_data.items()
    }
# Hack: torch.compile wraps the model and prefixes every state-dict key with
# "_orig_mod."; strip that exact prefix so keys match the uncompiled model.
# NOTE: must be str.removeprefix, NOT str.lstrip — lstrip("_orig_mod.") strips
# any run of the characters {_, o, r, i, g, m, d, .} from the left and can
# silently corrupt keys (e.g. a key continuing with one of those letters).
model_data = {k.removeprefix("_orig_mod."): v for k, v in model_data.items()}
model_config_kwargs = meta_data["model_config"]
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user