""" Utilities for saving and loading model/optim/state checkpoints. """ import os import re import glob import json import logging import torch import io from google.cloud import storage from nanochat.common import get_base_dir from nanochat.gpt import GPT, GPTConfig from nanochat.tokenizer import get_tokenizer from nanochat.common import setup_default_logging # Set up logging setup_default_logging() logger = logging.getLogger(__name__) def log0(message): if int(os.environ.get('RANK', 0)) == 0: logger.info(message) def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data): assert int(os.environ.get('RANK', 0)) == 0 # prevent footguns for now if checkpoint_dir.startswith("gs://"): storage_client = storage.Client() bucket_name, prefix = checkpoint_dir[5:].split("/", 1) bucket = storage_client.bucket(bucket_name) # Save model data model_blob = bucket.blob(f"{prefix}/model_{step:06d}.pt") with io.BytesIO() as buffer: torch.save(model_data, buffer) buffer.seek(0) model_blob.upload_from_file(buffer) log0(f"Saved model file to: gs://{bucket_name}/{prefix}/model_{step:06d}.pt") # Save optimizer data if optimizer_data is not None: optimizer_blob = bucket.blob(f"{prefix}/optim_{step:06d}.pt") with io.BytesIO() as buffer: torch.save(optimizer_data, buffer) buffer.seek(0) optimizer_blob.upload_from_file(buffer) log0(f"Saved optimizer file to: gs://{bucket_name}/{prefix}/optim_{step:06d}.pt") # Save metadata meta_blob = bucket.blob(f"{prefix}/meta_{step:06d}.json") meta_blob.upload_from_string(json.dumps(meta_data, indent=2)) log0(f"Saved metadata file to: gs://{bucket_name}/{prefix}/meta_{step:06d}.json") else: os.makedirs(checkpoint_dir, exist_ok=True) # Save the model state (parameters) model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.pt") torch.save(model_data, model_path) log0(f"Saved model file to: {model_path}") # Save the optimizer state (useful for SFT or any other fine-tuning) if optimizer_data is not None: optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}.pt") torch.save(optimizer_data, optimizer_path) log0(f"Saved optimizer file to: {optimizer_path}") # Save the metadata dict as json meta_path = os.path.join(checkpoint_dir, f"meta_{step:06d}.json") with open(meta_path, "w") as f: json.dump(meta_data, f, indent=2) log0(f"Saved metadata file to: {meta_path}") def load_checkpoint(checkpoint_dir, step, device, load_optimizer=False): if checkpoint_dir.startswith("gs://"): storage_client = storage.Client() bucket_name, prefix = checkpoint_dir[5:].split("/", 1) bucket = storage_client.bucket(bucket_name) # Load model data model_blob = bucket.blob(f"{prefix}/model_{step:06d}.pt") with io.BytesIO() as buffer: model_blob.download_to_file(buffer) buffer.seek(0) model_data = torch.load(buffer, map_location=device) # Load optimizer data optimizer_data = None if load_optimizer: optimizer_blob = bucket.blob(f"{prefix}/optim_{step:06d}.pt") with io.BytesIO() as buffer: optimizer_blob.download_to_file(buffer) buffer.seek(0) optimizer_data = torch.load(buffer, map_location=device) # Load metadata meta_blob = bucket.blob(f"{prefix}/meta_{step:06d}.json") meta_data = json.loads(meta_blob.download_as_string()) else: # Load the model state model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.pt") model_data = torch.load(model_path, map_location=device) # Load the optimizer state if requested optimizer_data = None if load_optimizer: optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}.pt") optimizer_data = torch.load(optimizer_path, 
def load_checkpoint(checkpoint_dir, step, device, load_optimizer=False):
    if checkpoint_dir.startswith("gs://"):
        storage_client = storage.Client()
        bucket_name, prefix = checkpoint_dir[5:].split("/", 1)
        bucket = storage_client.bucket(bucket_name)
        # Load model data
        model_blob = bucket.blob(f"{prefix}/model_{step:06d}.pt")
        with io.BytesIO() as buffer:
            model_blob.download_to_file(buffer)
            buffer.seek(0)
            model_data = torch.load(buffer, map_location=device)
        # Load optimizer data
        optimizer_data = None
        if load_optimizer:
            optimizer_blob = bucket.blob(f"{prefix}/optim_{step:06d}.pt")
            with io.BytesIO() as buffer:
                optimizer_blob.download_to_file(buffer)
                buffer.seek(0)
                optimizer_data = torch.load(buffer, map_location=device)
        # Load metadata
        meta_blob = bucket.blob(f"{prefix}/meta_{step:06d}.json")
        meta_data = json.loads(meta_blob.download_as_string())
    else:
        # Load the model state
        model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.pt")
        model_data = torch.load(model_path, map_location=device)
        # Load the optimizer state if requested
        optimizer_data = None
        if load_optimizer:
            optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}.pt")
            optimizer_data = torch.load(optimizer_path, map_location=device)
        # Load the metadata
        meta_path = os.path.join(checkpoint_dir, f"meta_{step:06d}.json")
        with open(meta_path, "r") as f:
            meta_data = json.load(f)
    return model_data, optimizer_data, meta_data

def build_model(checkpoint_dir, step, device, phase):
    """
    A bunch of repetitive code to build a model from a given checkpoint.
    Returns:
    - base model - uncompiled, not wrapped in DDP
    - tokenizer
    - meta data saved during base model training
    """
    assert phase in ["train", "eval"], f"Invalid phase: {phase}"
    model_data, optimizer_data, meta_data = load_checkpoint(checkpoint_dir, step, device, load_optimizer=False)
    if device.type == "cpu":
        # Convert bfloat16 tensors to float for CPU inference
        model_data = {
            k: v.float() if v.dtype == torch.bfloat16 else v
            for k, v in model_data.items()
        }
    # Hack: fix torch compile issue, which prepends all keys with _orig_mod.
    model_data = {k.removeprefix("_orig_mod."): v for k, v in model_data.items()}
    model_config_kwargs = meta_data["model_config"]
    log0(f"Building model with config: {model_config_kwargs}")
    model_config = GPTConfig(**model_config_kwargs)
    with torch.device("meta"):
        model = GPT(model_config)
    # Load the model state
    model.to_empty(device=device)
    model.init_weights() # note: this is dumb, but we need to init the rotary embeddings. TODO: fix model re-init
    model.load_state_dict(model_data, strict=True, assign=True)
    # Put the model in the right training phase / mode
    if phase == "eval":
        model.eval()
    else:
        model.train()
    # Load the Tokenizer
    tokenizer = get_tokenizer()
    # Sanity check: compatibility between model and tokenizer
    assert tokenizer.get_vocab_size() == model_config_kwargs["vocab_size"]
    return model, tokenizer, meta_data

def find_largest_model(checkpoint_dir):
    if checkpoint_dir.startswith("gs://"):
        storage_client = storage.Client()
        bucket_name, prefix = checkpoint_dir[5:].split("/", 1)
        bucket = storage_client.bucket(bucket_name)
        if not prefix.endswith("/"):
            prefix += "/"
        blobs = bucket.list_blobs(prefix=prefix, delimiter='/')
        list(blobs) # Iterate to populate prefixes
        log0(f"DEBUG: prefix={prefix}")
        log0(f"DEBUG: blobs.prefixes={list(blobs.prefixes)}")
        model_tags = [p.split('/')[-2] for p in blobs.prefixes]
        log0(f"DEBUG: model_tags={model_tags}")
    else:
        # attempt to guess the model tag: take the biggest model available
        model_tags = [f for f in os.listdir(checkpoint_dir) if os.path.isdir(os.path.join(checkpoint_dir, f))]
    if not model_tags:
        raise FileNotFoundError(f"No checkpoints found in {checkpoint_dir}")
    # 1) normally all model tags are of the form d<depth>, try that first:
    candidates = []
    for model_tag in model_tags:
        match = re.match(r"d(\d+)", model_tag)
        if match:
            model_depth = int(match.group(1))
            candidates.append((model_depth, model_tag))
    if candidates:
        candidates.sort(key=lambda x: x[0], reverse=True)
        return candidates[0][1]
    # 2) if that failed, take the most recently updated model:
    if not checkpoint_dir.startswith("gs://"):
        model_tags.sort(key=lambda x: os.path.getmtime(os.path.join(checkpoint_dir, x)), reverse=True)
    return model_tags[0]

def find_last_step(checkpoint_dir):
    if checkpoint_dir.startswith("gs://"):
        storage_client = storage.Client()
        bucket_name, prefix = checkpoint_dir[5:].split("/", 1)
        bucket = storage_client.bucket(bucket_name)
        blobs = bucket.list_blobs(prefix=f"{prefix}/model_")
        checkpoint_files = [blob.name for blob in blobs]
    else:
        # Look into checkpoint_dir and find model_<step>.pt with the highest step
        checkpoint_files = glob.glob(os.path.join(checkpoint_dir, "model_*.pt"))
    if not checkpoint_files:
        raise FileNotFoundError(f"No checkpoints found in {checkpoint_dir}")
    last_step = max(int(os.path.basename(f).split("_")[-1].split(".")[0]) for f in checkpoint_files)
    return last_step
# -----------------------------------------------------------------------------
# convenience functions that take into account nanochat's directory structure

def load_model_from_dir(checkpoints_dir, device, phase, model_tag=None, step=None):
    if model_tag is None:
        # guess the model tag by defaulting to the largest model
        model_tag = find_largest_model(checkpoints_dir)
        log0(f"No model tag provided, guessing model tag: {model_tag}")
    checkpoint_dir = os.path.join(checkpoints_dir, model_tag)
    if step is None:
        # guess the step by defaulting to the last step
        step = find_last_step(checkpoint_dir)
    assert step is not None, f"No checkpoints found in {checkpoint_dir}"
    # build the model
    log0(f"Loading model from {checkpoint_dir} with step {step}")
    model, tokenizer, meta_data = build_model(checkpoint_dir, step, device, phase)
    return model, tokenizer, meta_data

def load_model(source, *args, **kwargs):
    model_dir = {
        "base": "base_checkpoints",
        "mid": "mid_checkpoints",
        "sft": "chatsft_checkpoints",
        "rl": "chatrl_checkpoints",
    }[source]
    # Check if running in Vertex AI with GCS data directory
    data_dir = os.environ.get("NANOCHAT_DATA_DIR", "")
    if data_dir.startswith("gs://"):
        # Use GCS checkpoint directory
        checkpoints_dir = data_dir.replace("/base_data", f"/{model_dir}")
    else:
        # Use local checkpoint directory
        base_dir = get_base_dir()
        checkpoints_dir = os.path.join(base_dir, model_dir)
    return load_model_from_dir(checkpoints_dir, *args, **kwargs)
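
# Minimal smoke test (a sketch, assuming a checkpoint already exists under the
# resolved base_checkpoints directory and that NANOCHAT_DATA_DIR points somewhere
# valid or is unset; the device choice and the "base"/"eval" arguments are just
# example values, not the only supported ones).
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model, tokenizer, meta_data = load_model("base", device, phase="eval")
    log0(f"Loaded model with config: {meta_data['model_config']}")
    log0(f"Tokenizer vocab size: {tokenizer.get_vocab_size()}")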