diff --git a/nanochat/checkpoint_manager.py b/nanochat/checkpoint_manager.py index f400d47..26fdb0d 100644 --- a/nanochat/checkpoint_manager.py +++ b/nanochat/checkpoint_manager.py @@ -1,6 +1,7 @@ """ Utilities for saving and loading model/optim/state checkpoints. """ + import os import re import glob @@ -16,12 +17,15 @@ from nanochat.common import setup_default_logging # Set up logging setup_default_logging() logger = logging.getLogger(__name__) + + def log0(message): - if int(os.environ.get('RANK', 0)) == 0: + if int(os.environ.get("RANK", 0)) == 0: logger.info(message) + def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data): - assert int(os.environ.get('RANK', 0)) == 0 # prevent footguns for now + assert int(os.environ.get("RANK", 0)) == 0 # prevent footguns for now os.makedirs(checkpoint_dir, exist_ok=True) # Save the model state (parameters) model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.pt") @@ -64,7 +68,15 @@ def build_model(checkpoint_dir, step, device, phase): - meta data saved during base model training """ assert phase in ["train", "eval"], f"Invalid phase: {phase}" - model_data, optimizer_data, meta_data = load_checkpoint(checkpoint_dir, step, device, load_optimizer=False) + model_data, optimizer_data, meta_data = load_checkpoint( + checkpoint_dir, step, device, load_optimizer=False + ) + if device.type == "cpu": + # Convert bfloat16 tensors to float for CPU inference + model_data = { + k: v.float() if v.dtype == torch.bfloat16 else v + for k, v in model_data.items() + } # Hack: fix torch compile issue, which prepends all keys with _orig_mod. model_data = {k.lstrip("_orig_mod."): v for k, v in model_data.items()} model_config_kwargs = meta_data["model_config"] @@ -74,7 +86,7 @@ def build_model(checkpoint_dir, step, device, phase): model = GPT(model_config) # Load the model state model.to_empty(device=device) - model.init_weights() # note: this is dumb, but we need to init the rotary embeddings. TODO: fix model re-init + model.init_weights() # note: this is dumb, but we need to init the rotary embeddings. TODO: fix model re-init model.load_state_dict(model_data, strict=True, assign=True) # Put the model in the right training phase / mode if phase == "eval": @@ -90,7 +102,11 @@ def build_model(checkpoint_dir, step, device, phase): def find_largest_model(checkpoint_dir): # attempt to guess the model tag: take the biggest model available - model_tags = [f for f in os.listdir(checkpoint_dir) if os.path.isdir(os.path.join(checkpoint_dir, f))] + model_tags = [ + f + for f in os.listdir(checkpoint_dir) + if os.path.isdir(os.path.join(checkpoint_dir, f)) + ] if not model_tags: raise FileNotFoundError(f"No checkpoints found in {checkpoint_dir}") # 1) normally all model tags are of the form d, try that first: @@ -104,7 +120,9 @@ def find_largest_model(checkpoint_dir): candidates.sort(key=lambda x: x[0], reverse=True) return candidates[0][1] # 2) if that failed, take the most recently updated model: - model_tags.sort(key=lambda x: os.path.getmtime(os.path.join(checkpoint_dir, x)), reverse=True) + model_tags.sort( + key=lambda x: os.path.getmtime(os.path.join(checkpoint_dir, x)), reverse=True + ) return model_tags[0] @@ -113,12 +131,16 @@ def find_last_step(checkpoint_dir): checkpoint_files = glob.glob(os.path.join(checkpoint_dir, "model_*.pt")) if not checkpoint_files: raise FileNotFoundError(f"No checkpoints found in {checkpoint_dir}") - last_step = int(max(os.path.basename(f).split("_")[-1].split(".")[0] for f in checkpoint_files)) + last_step = int( + max(os.path.basename(f).split("_")[-1].split(".")[0] for f in checkpoint_files) + ) return last_step + # ----------------------------------------------------------------------------- # convenience functions that take into account nanochat's directory structure + def load_model_from_dir(checkpoints_dir, device, phase, model_tag=None, step=None): if model_tag is None: # guess the model tag by defaulting to the largest model @@ -134,6 +156,7 @@ def load_model_from_dir(checkpoints_dir, device, phase, model_tag=None, step=Non model, tokenizer, meta_data = build_model(checkpoint_dir, step, device, phase) return model, tokenizer, meta_data + def load_model(source, *args, **kwargs): model_dir = { "base": "base_checkpoints",