diff --git a/nanochat/checkpoint_manager.py b/nanochat/checkpoint_manager.py
index 26fdb0d..a1120cb 100644
--- a/nanochat/checkpoint_manager.py
+++ b/nanochat/checkpoint_manager.py
@@ -1,7 +1,6 @@
 """
 Utilities for saving and loading model/optim/state checkpoints.
 """
-
 import os
 import re
 import glob
@@ -17,15 +16,13 @@ from nanochat.common import setup_default_logging
 
 # Set up logging
 setup_default_logging()
 logger = logging.getLogger(__name__)
-
-
 def log0(message):
-    if int(os.environ.get("RANK", 0)) == 0:
+    if int(os.environ.get('RANK', 0)) == 0:
         logger.info(message)
 
 def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data):
-    assert int(os.environ.get("RANK", 0)) == 0 # prevent footguns for now
+    assert int(os.environ.get('RANK', 0)) == 0 # prevent footguns for now
     os.makedirs(checkpoint_dir, exist_ok=True)
     # Save the model state (parameters)
     model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.pt")
@@ -68,9 +65,7 @@ def build_model(checkpoint_dir, step, device, phase):
     - meta data saved during base model training
     """
     assert phase in ["train", "eval"], f"Invalid phase: {phase}"
-    model_data, optimizer_data, meta_data = load_checkpoint(
-        checkpoint_dir, step, device, load_optimizer=False
-    )
+    model_data, optimizer_data, meta_data = load_checkpoint(checkpoint_dir, step, device, load_optimizer=False)
     if device.type == "cpu":
         # Convert bfloat16 tensors to float for CPU inference
         model_data = {
@@ -86,7 +81,7 @@ def build_model(checkpoint_dir, step, device, phase):
     model = GPT(model_config)
     # Load the model state
     model.to_empty(device=device)
-    model.init_weights()  # note: this is dumb, but we need to init the rotary embeddings. TODO: fix model re-init
+    model.init_weights() # note: this is dumb, but we need to init the rotary embeddings. TODO: fix model re-init
     model.load_state_dict(model_data, strict=True, assign=True)
     # Put the model in the right training phase / mode
     if phase == "eval":
@@ -102,11 +97,7 @@ def build_model(checkpoint_dir, step, device, phase):
 
 def find_largest_model(checkpoint_dir):
     # attempt to guess the model tag: take the biggest model available
-    model_tags = [
-        f
-        for f in os.listdir(checkpoint_dir)
-        if os.path.isdir(os.path.join(checkpoint_dir, f))
-    ]
+    model_tags = [f for f in os.listdir(checkpoint_dir) if os.path.isdir(os.path.join(checkpoint_dir, f))]
     if not model_tags:
         raise FileNotFoundError(f"No checkpoints found in {checkpoint_dir}")
     # 1) normally all model tags are of the form d<number>, try that first:
@@ -120,9 +111,7 @@ def find_largest_model(checkpoint_dir):
         candidates.sort(key=lambda x: x[0], reverse=True)
         return candidates[0][1]
     # 2) if that failed, take the most recently updated model:
-    model_tags.sort(
-        key=lambda x: os.path.getmtime(os.path.join(checkpoint_dir, x)), reverse=True
-    )
+    model_tags.sort(key=lambda x: os.path.getmtime(os.path.join(checkpoint_dir, x)), reverse=True)
     return model_tags[0]
 
 
@@ -131,16 +120,12 @@ def find_last_step(checkpoint_dir):
     checkpoint_files = glob.glob(os.path.join(checkpoint_dir, "model_*.pt"))
     if not checkpoint_files:
         raise FileNotFoundError(f"No checkpoints found in {checkpoint_dir}")
-    last_step = int(
-        max(os.path.basename(f).split("_")[-1].split(".")[0] for f in checkpoint_files)
-    )
+    last_step = int(max(os.path.basename(f).split("_")[-1].split(".")[0] for f in checkpoint_files))
     return last_step
 
-
 # -----------------------------------------------------------------------------
 # convenience functions that take into account nanochat's directory structure
-
 def load_model_from_dir(checkpoints_dir, device, phase, model_tag=None, step=None):
     if model_tag is None:
         # guess the model tag by defaulting to the largest model
         model_tag = find_largest_model(checkpoints_dir)
@@ -156,7 +141,6 @@ def load_model_from_dir(checkpoints_dir, device, phase, model_tag=None, step=Non
     model, tokenizer, meta_data = build_model(checkpoint_dir, step, device, phase)
     return model, tokenizer, meta_data
 
-
 def load_model(source, *args, **kwargs):
     model_dir = {
         "base": "base_checkpoints",
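
For reference, a minimal usage sketch of the API this patch reformats (not part of the patch itself). The function names and signatures are taken from the diff above; the "d20" model tag, the placeholder payloads, and the assumption that step=None falls back to the last saved step are illustrative only.

import os
import torch
from nanochat.checkpoint_manager import save_checkpoint, find_last_step, load_model_from_dir

checkpoints_dir = "base_checkpoints"                   # matches load_model's "base" entry
checkpoint_dir = os.path.join(checkpoints_dir, "d20")  # hypothetical model tag

# Save (rank 0 only: save_checkpoint asserts RANK == 0, per its "prevent
# footguns" comment). The payloads below are placeholders; in real training
# they would be the model/optimizer state dicts and run metadata.
save_checkpoint(
    checkpoint_dir,
    step=1000,
    model_data={"dummy.weight": torch.zeros(1)},  # placeholder state_dict
    optimizer_data=None,                          # assumed optional here
    meta_data={"note": "illustrative only"},
)

# Load: model_tag=None defers to find_largest_model; step=None presumably
# resolves to the last saved step, which find_last_step derives from the
# checkpoint filenames (e.g. model_001000.pt -> 1000).
last_step = find_last_step(checkpoint_dir)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model, tokenizer, meta_data = load_model_from_dir(checkpoints_dir, device, phase="eval", step=last_step)

Keeping save_checkpoint rank-0-only makes sense for distributed runs: only one process writes the model_*.pt files, so DDP ranks never race on the same paths.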