CPU support, as bfloat16 params break inference

Manuel Saelices 2025-11-01 23:38:50 +01:00
parent dfc88334b6
commit d54c9cbf8c


@@ -1,6 +1,7 @@
 """
 Utilities for saving and loading model/optim/state checkpoints.
 """
 import os
 import re
 import glob
@@ -16,12 +17,15 @@ from nanochat.common import setup_default_logging
 # Set up logging
 setup_default_logging()
 logger = logging.getLogger(__name__)
 def log0(message):
-    if int(os.environ.get('RANK', 0)) == 0:
+    if int(os.environ.get("RANK", 0)) == 0:
         logger.info(message)
 def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data):
-    assert int(os.environ.get('RANK', 0)) == 0 # prevent footguns for now
+    assert int(os.environ.get("RANK", 0)) == 0 # prevent footguns for now
     os.makedirs(checkpoint_dir, exist_ok=True)
     # Save the model state (parameters)
     model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.pt")
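
Both log0 and save_checkpoint gate on the RANK environment variable, which distributed launchers such as torchrun export per worker; a plain single-process run has no RANK, so the default of 0 makes it behave as rank 0. A minimal standalone sketch of the pattern (illustration only, not part of this diff):

import os

# torchrun sets RANK for each worker; without it we default to rank 0,
# so single-process runs still log and save checkpoints.
if int(os.environ.get("RANK", 0)) == 0:
    print("rank 0: safe to log and write checkpoints")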
@@ -64,7 +68,15 @@ def build_model(checkpoint_dir, step, device, phase):
     - meta data saved during base model training
     """
     assert phase in ["train", "eval"], f"Invalid phase: {phase}"
-    model_data, optimizer_data, meta_data = load_checkpoint(checkpoint_dir, step, device, load_optimizer=False)
+    model_data, optimizer_data, meta_data = load_checkpoint(
+        checkpoint_dir, step, device, load_optimizer=False
+    )
+    if device.type == "cpu":
+        # Convert bfloat16 tensors to float for CPU inference
+        model_data = {
+            k: v.float() if v.dtype == torch.bfloat16 else v
+            for k, v in model_data.items()
+        }
     # Hack: fix torch compile issue, which prepends all keys with _orig_mod.
     model_data = {k.lstrip("_orig_mod."): v for k, v in model_data.items()}
     model_config_kwargs = meta_data["model_config"]
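
The device.type == "cpu" branch is the substance of this commit: per the commit message, bfloat16 parameters break inference on CPU, so the state dict is widened to float32 before use. A standalone sketch of the same normalization, using a hypothetical key name:

import torch

def widen_bf16_for_cpu(state_dict):
    # Widen bfloat16 tensors to float32; everything else passes through.
    return {k: v.float() if v.dtype == torch.bfloat16 else v
            for k, v in state_dict.items()}

sd = {"wte.weight": torch.randn(4, 8, dtype=torch.bfloat16),  # hypothetical key
      "step": torch.tensor(100)}
sd = widen_bf16_for_cpu(sd)
assert sd["wte.weight"].dtype == torch.float32
assert sd["step"].dtype == torch.int64  # non-bfloat16 entries are untouched

One caveat on the adjacent, pre-existing hack: str.lstrip strips a character set, not a prefix, so it only works here by luck of the key names; str.removeprefix (Python 3.9+) is the prefix-safe spelling:

# lstrip treats "_orig_mod." as the character set {_, o, r, i, g, m, d, .}
# and keeps eating matching characters, mangling this (hypothetical) key:
assert "_orig_mod.digits.weight".lstrip("_orig_mod.") == "ts.weight"
assert "_orig_mod.digits.weight".removeprefix("_orig_mod.") == "digits.weight"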
@@ -90,7 +102,11 @@ def build_model(checkpoint_dir, step, device, phase):

 def find_largest_model(checkpoint_dir):
     # attempt to guess the model tag: take the biggest model available
-    model_tags = [f for f in os.listdir(checkpoint_dir) if os.path.isdir(os.path.join(checkpoint_dir, f))]
+    model_tags = [
+        f
+        for f in os.listdir(checkpoint_dir)
+        if os.path.isdir(os.path.join(checkpoint_dir, f))
+    ]
     if not model_tags:
         raise FileNotFoundError(f"No checkpoints found in {checkpoint_dir}")
     # 1) normally all model tags are of the form d<number>, try that first:
@@ -104,7 +120,9 @@ def find_largest_model(checkpoint_dir):
         candidates.sort(key=lambda x: x[0], reverse=True)
         return candidates[0][1]
     # 2) if that failed, take the most recently updated model:
-    model_tags.sort(key=lambda x: os.path.getmtime(os.path.join(checkpoint_dir, x)), reverse=True)
+    model_tags.sort(
+        key=lambda x: os.path.getmtime(os.path.join(checkpoint_dir, x)), reverse=True
+    )
     return model_tags[0]
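
For reference, step 1 above prefers the deepest model by parsing the d<number> tag. A condensed restatement with hypothetical tag names (the exact regex is an assumption; this hunk only shows the sort and return):

import re

tags = ["d12", "d26", "scratch"]  # hypothetical checkpoint directories
candidates = [(int(m.group(1)), t) for t in tags if (m := re.match(r"d(\d+)", t))]
candidates.sort(key=lambda x: x[0], reverse=True)
assert candidates[0][1] == "d26"  # deepest d-tag wins
# "scratch" matches no d-tag, so it would only surface via the mtime fallback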
@@ -113,12 +131,16 @@ def find_last_step(checkpoint_dir):
     checkpoint_files = glob.glob(os.path.join(checkpoint_dir, "model_*.pt"))
     if not checkpoint_files:
         raise FileNotFoundError(f"No checkpoints found in {checkpoint_dir}")
-    last_step = int(max(os.path.basename(f).split("_")[-1].split(".")[0] for f in checkpoint_files))
+    last_step = int(
+        max(os.path.basename(f).split("_")[-1].split(".")[0] for f in checkpoint_files)
+    )
     return last_step

 # -----------------------------------------------------------------------------
 # convenience functions that take into account nanochat's directory structure

 def load_model_from_dir(checkpoints_dir, device, phase, model_tag=None, step=None):
     if model_tag is None:
         # guess the model tag by defaulting to the largest model
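
Note that max here compares the step substrings as strings, which is only correct because save_checkpoint zero-pads steps to six digits (model_{step:06d}.pt above). A quick standalone check with made-up filenames:

import os

files = ["ckpt/model_000100.pt", "ckpt/model_000250.pt"]  # hypothetical paths
last_step = int(max(os.path.basename(f).split("_")[-1].split(".")[0] for f in files))
assert last_step == 250  # string max works because of the fixed-width padding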
@@ -134,6 +156,7 @@ def load_model_from_dir(checkpoints_dir, device, phase, model_tag=None, step=None):
     model, tokenizer, meta_data = build_model(checkpoint_dir, step, device, phase)
     return model, tokenizer, meta_data

 def load_model(source, *args, **kwargs):
     model_dir = {
         "base": "base_checkpoints",