CPU Support, as bfloat16 params breaks inference

This commit is contained in:
Manuel Saelices 2025-11-01 23:38:50 +01:00
parent dfc88334b6
commit d54c9cbf8c

View File

@ -1,6 +1,7 @@
"""
Utilities for saving and loading model/optim/state checkpoints.
"""
import os
import re
import glob
@ -16,12 +17,15 @@ from nanochat.common import setup_default_logging
# Set up logging
setup_default_logging()
logger = logging.getLogger(__name__)
def log0(message):
if int(os.environ.get('RANK', 0)) == 0:
if int(os.environ.get("RANK", 0)) == 0:
logger.info(message)
def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data):
assert int(os.environ.get('RANK', 0)) == 0 # prevent footguns for now
assert int(os.environ.get("RANK", 0)) == 0 # prevent footguns for now
os.makedirs(checkpoint_dir, exist_ok=True)
# Save the model state (parameters)
model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.pt")
@ -64,7 +68,15 @@ def build_model(checkpoint_dir, step, device, phase):
- meta data saved during base model training
"""
assert phase in ["train", "eval"], f"Invalid phase: {phase}"
model_data, optimizer_data, meta_data = load_checkpoint(checkpoint_dir, step, device, load_optimizer=False)
model_data, optimizer_data, meta_data = load_checkpoint(
checkpoint_dir, step, device, load_optimizer=False
)
if device.type == "cpu":
# Convert bfloat16 tensors to float for CPU inference
model_data = {
k: v.float() if v.dtype == torch.bfloat16 else v
for k, v in model_data.items()
}
# Hack: fix torch compile issue, which prepends all keys with _orig_mod.
model_data = {k.lstrip("_orig_mod."): v for k, v in model_data.items()}
model_config_kwargs = meta_data["model_config"]
@ -90,7 +102,11 @@ def build_model(checkpoint_dir, step, device, phase):
def find_largest_model(checkpoint_dir):
# attempt to guess the model tag: take the biggest model available
model_tags = [f for f in os.listdir(checkpoint_dir) if os.path.isdir(os.path.join(checkpoint_dir, f))]
model_tags = [
f
for f in os.listdir(checkpoint_dir)
if os.path.isdir(os.path.join(checkpoint_dir, f))
]
if not model_tags:
raise FileNotFoundError(f"No checkpoints found in {checkpoint_dir}")
# 1) normally all model tags are of the form d<number>, try that first:
@ -104,7 +120,9 @@ def find_largest_model(checkpoint_dir):
candidates.sort(key=lambda x: x[0], reverse=True)
return candidates[0][1]
# 2) if that failed, take the most recently updated model:
model_tags.sort(key=lambda x: os.path.getmtime(os.path.join(checkpoint_dir, x)), reverse=True)
model_tags.sort(
key=lambda x: os.path.getmtime(os.path.join(checkpoint_dir, x)), reverse=True
)
return model_tags[0]
@ -113,12 +131,16 @@ def find_last_step(checkpoint_dir):
checkpoint_files = glob.glob(os.path.join(checkpoint_dir, "model_*.pt"))
if not checkpoint_files:
raise FileNotFoundError(f"No checkpoints found in {checkpoint_dir}")
last_step = int(max(os.path.basename(f).split("_")[-1].split(".")[0] for f in checkpoint_files))
last_step = int(
max(os.path.basename(f).split("_")[-1].split(".")[0] for f in checkpoint_files)
)
return last_step
# -----------------------------------------------------------------------------
# convenience functions that take into account nanochat's directory structure
def load_model_from_dir(checkpoints_dir, device, phase, model_tag=None, step=None):
if model_tag is None:
# guess the model tag by defaulting to the largest model
@ -134,6 +156,7 @@ def load_model_from_dir(checkpoints_dir, device, phase, model_tag=None, step=Non
model, tokenizer, meta_data = build_model(checkpoint_dir, step, device, phase)
return model, tokenizer, meta_data
def load_model(source, *args, **kwargs):
model_dir = {
"base": "base_checkpoints",