"""
|
|
Utilities for saving and loading model/optim/state checkpoints.
|
|
"""
|
|
import os
|
|
import re
|
|
import glob
|
|
import json
|
|
import logging
|
|
import torch
|
|
import io
|
|
from google.cloud import storage
|
|
|
|
from nanochat.common import get_base_dir
|
|
from nanochat.gpt import GPT, GPTConfig
|
|
from nanochat.tokenizer import get_tokenizer
|
|
from nanochat.common import setup_default_logging
|
|
|
|
# Set up logging
|
|
setup_default_logging()
|
|
logger = logging.getLogger(__name__)

def log0(message):
    # only log on rank 0 to avoid duplicate lines in multi-process (DDP) runs
    if int(os.environ.get('RANK', 0)) == 0:
        logger.info(message)

def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data):
    assert int(os.environ.get('RANK', 0)) == 0  # prevent footguns for now
    if checkpoint_dir.startswith("gs://"):
        storage_client = storage.Client()
        bucket_name, prefix = checkpoint_dir[5:].split("/", 1)
        bucket = storage_client.bucket(bucket_name)

        # Save model data
        model_blob = bucket.blob(f"{prefix}/model_{step:06d}.pt")
        with io.BytesIO() as buffer:
            torch.save(model_data, buffer)
            buffer.seek(0)
            model_blob.upload_from_file(buffer)
        log0(f"Saved model file to: gs://{bucket_name}/{prefix}/model_{step:06d}.pt")

        # Save optimizer data
        if optimizer_data is not None:
            optimizer_blob = bucket.blob(f"{prefix}/optim_{step:06d}.pt")
            with io.BytesIO() as buffer:
                torch.save(optimizer_data, buffer)
                buffer.seek(0)
                optimizer_blob.upload_from_file(buffer)
            log0(f"Saved optimizer file to: gs://{bucket_name}/{prefix}/optim_{step:06d}.pt")

        # Save metadata
        meta_blob = bucket.blob(f"{prefix}/meta_{step:06d}.json")
        meta_blob.upload_from_string(json.dumps(meta_data, indent=2))
        log0(f"Saved metadata file to: gs://{bucket_name}/{prefix}/meta_{step:06d}.json")
    else:
        os.makedirs(checkpoint_dir, exist_ok=True)
        # Save the model state (parameters)
        model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.pt")
        torch.save(model_data, model_path)
        log0(f"Saved model file to: {model_path}")
        # Save the optimizer state (useful for SFT or any other fine-tuning)
        if optimizer_data is not None:
            optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}.pt")
            torch.save(optimizer_data, optimizer_path)
            log0(f"Saved optimizer file to: {optimizer_path}")
        # Save the metadata dict as json
        meta_path = os.path.join(checkpoint_dir, f"meta_{step:06d}.json")
        with open(meta_path, "w") as f:
            json.dump(meta_data, f, indent=2)
        log0(f"Saved metadata file to: {meta_path}")

def load_checkpoint(checkpoint_dir, step, device, load_optimizer=False):
    if checkpoint_dir.startswith("gs://"):
        storage_client = storage.Client()
        bucket_name, prefix = checkpoint_dir[5:].split("/", 1)
        bucket = storage_client.bucket(bucket_name)

        # Load model data
        model_blob = bucket.blob(f"{prefix}/model_{step:06d}.pt")
        with io.BytesIO() as buffer:
            model_blob.download_to_file(buffer)
            buffer.seek(0)
            model_data = torch.load(buffer, map_location=device)

        # Load optimizer data
        optimizer_data = None
        if load_optimizer:
            optimizer_blob = bucket.blob(f"{prefix}/optim_{step:06d}.pt")
            with io.BytesIO() as buffer:
                optimizer_blob.download_to_file(buffer)
                buffer.seek(0)
                optimizer_data = torch.load(buffer, map_location=device)

        # Load metadata
        meta_blob = bucket.blob(f"{prefix}/meta_{step:06d}.json")
        meta_data = json.loads(meta_blob.download_as_string())
    else:
        # Load the model state
        model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.pt")
        model_data = torch.load(model_path, map_location=device)
        # Load the optimizer state if requested
        optimizer_data = None
        if load_optimizer:
            optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}.pt")
            optimizer_data = torch.load(optimizer_path, map_location=device)
        # Load the metadata
        meta_path = os.path.join(checkpoint_dir, f"meta_{step:06d}.json")
        with open(meta_path, "r") as f:
            meta_data = json.load(f)
    return model_data, optimizer_data, meta_data
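
# Illustrative round trip (the path and step are hypothetical):
#
#   device = torch.device("cuda")
#   model_data, optimizer_data, meta_data = load_checkpoint(
#       "/path/to/base_checkpoints/d20", 1000, device, load_optimizer=True)
#
# model_data is the state dict saved by save_checkpoint above, with all tensors
# mapped onto `device`; optimizer_data stays None unless load_optimizer is set.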

def build_model(checkpoint_dir, step, device, phase):
    """
    A bunch of repetitive code to build a model from a given checkpoint.
    Returns:
    - base model - uncompiled, not wrapped in DDP
    - tokenizer
    - meta data saved during base model training
    """
    assert phase in ["train", "eval"], f"Invalid phase: {phase}"
    model_data, optimizer_data, meta_data = load_checkpoint(checkpoint_dir, step, device, load_optimizer=False)
    if device.type == "cpu":
        # Convert bfloat16 tensors to float for CPU inference
        model_data = {
            k: v.float() if v.dtype == torch.bfloat16 else v
            for k, v in model_data.items()
        }
    # Hack: fix torch compile issue, which prepends all keys with _orig_mod.
    model_data = {k.removeprefix("_orig_mod."): v for k, v in model_data.items()}
    model_config_kwargs = meta_data["model_config"]
    log0(f"Building model with config: {model_config_kwargs}")
    model_config = GPTConfig(**model_config_kwargs)
    # Create the model skeleton on the meta device (no memory is allocated),
    # then materialize it on the target device and load the real weights
    with torch.device("meta"):
        model = GPT(model_config)
    model.to_empty(device=device)
    model.init_weights()  # note: this is dumb, but we need to init the rotary embeddings. TODO: fix model re-init
    model.load_state_dict(model_data, strict=True, assign=True)
    # Put the model in the right training phase / mode
    if phase == "eval":
        model.eval()
    else:
        model.train()
    # Load the Tokenizer
    tokenizer = get_tokenizer()
    # Sanity check: compatibility between model and tokenizer
    assert tokenizer.get_vocab_size() == model_config_kwargs["vocab_size"]
    return model, tokenizer, meta_data
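
# Illustrative usage (the path, model tag, and step are hypothetical and assume
# a checkpoint exists there):
#
#   device = torch.device("cuda")
#   model, tokenizer, meta_data = build_model(
#       "/path/to/base_checkpoints/d20", 5000, device, phase="eval")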

def find_largest_model(checkpoint_dir):
    if checkpoint_dir.startswith("gs://"):
        storage_client = storage.Client()
        bucket_name, prefix = checkpoint_dir[5:].split("/", 1)
        bucket = storage_client.bucket(bucket_name)
        if not prefix.endswith("/"):
            prefix += "/"
        # list with a delimiter so "subdirectories" show up as prefixes
        blobs = bucket.list_blobs(prefix=prefix, delimiter='/')
        list(blobs)  # iterate to populate blobs.prefixes
        model_tags = [p.split('/')[-2] for p in blobs.prefixes]
    else:
        # attempt to guess the model tag: take the biggest model available
        model_tags = [f for f in os.listdir(checkpoint_dir) if os.path.isdir(os.path.join(checkpoint_dir, f))]

    if not model_tags:
        raise FileNotFoundError(f"No checkpoints found in {checkpoint_dir}")
    # 1) normally all model tags are of the form d<number>, try that first:
    candidates = []
    for model_tag in model_tags:
        match = re.match(r"d(\d+)", model_tag)
        if match:
            model_depth = int(match.group(1))
            candidates.append((model_depth, model_tag))
    if candidates:
        candidates.sort(key=lambda x: x[0], reverse=True)
        return candidates[0][1]
    # 2) if that failed, take the most recently updated model. This only works
    # locally, since the GCS listing above does not give us modification times.
    if not checkpoint_dir.startswith("gs://"):
        model_tags.sort(key=lambda x: os.path.getmtime(os.path.join(checkpoint_dir, x)), reverse=True)
        return model_tags[0]
    raise FileNotFoundError(f"Could not pick a model tag among {model_tags} in {checkpoint_dir}")
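
# For example (hypothetical tags): among subdirectories ["d12", "d20", "d26"]
# this returns "d26", the deepest model; tags that don't match d<number> fall
# back to the most recently modified directory when stored on local disk.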

def find_last_step(checkpoint_dir):
    if checkpoint_dir.startswith("gs://"):
        storage_client = storage.Client()
        bucket_name, prefix = checkpoint_dir[5:].split("/", 1)
        bucket = storage_client.bucket(bucket_name)
        blobs = bucket.list_blobs(prefix=f"{prefix}/model_")
        checkpoint_files = [blob.name for blob in blobs]
    else:
        # Look into checkpoint_dir and find model_<step>.pt with the highest step
        checkpoint_files = glob.glob(os.path.join(checkpoint_dir, "model_*.pt"))

    if not checkpoint_files:
        raise FileNotFoundError(f"No checkpoints found in {checkpoint_dir}")
    # compare steps numerically rather than lexicographically, in case the
    # zero-padded width is ever exceeded
    last_step = max(int(os.path.basename(f).split("_")[-1].split(".")[0]) for f in checkpoint_files)
    return last_step
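
# For example (hypothetical files): with model_000500.pt and model_001000.pt
# present in checkpoint_dir, find_last_step returns 1000.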

# -----------------------------------------------------------------------------
# convenience functions that take into account nanochat's directory structure

def load_model_from_dir(checkpoints_dir, device, phase, model_tag=None, step=None):
    if model_tag is None:
        # guess the model tag by defaulting to the largest model
        model_tag = find_largest_model(checkpoints_dir)
        log0(f"No model tag provided, guessing model tag: {model_tag}")
    checkpoint_dir = os.path.join(checkpoints_dir, model_tag)
    if step is None:
        # guess the step by defaulting to the last step
        step = find_last_step(checkpoint_dir)
    assert step is not None, f"No checkpoints found in {checkpoint_dir}"
    # build the model
    log0(f"Loading model from {checkpoint_dir} with step {step}")
    model, tokenizer, meta_data = build_model(checkpoint_dir, step, device, phase)
    return model, tokenizer, meta_data

def load_model(source, *args, **kwargs):
    model_dir = {
        "base": "base_checkpoints",
        "mid": "mid_checkpoints",
        "sft": "chatsft_checkpoints",
        "rl": "chatrl_checkpoints",
    }[source]

    # Check if running in Vertex AI with GCS data directory
    data_dir = os.environ.get("NANOCHAT_DATA_DIR", "")
    if data_dir.startswith("gs://"):
        # Use the GCS checkpoint directory, which lives alongside the base data
        checkpoints_dir = data_dir.replace("/base_data", f"/{model_dir}")
    else:
        # Use the local checkpoint directory
        base_dir = get_base_dir()
        checkpoints_dir = os.path.join(base_dir, model_dir)

    return load_model_from_dir(checkpoints_dir, *args, **kwargs)
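
# Illustrative end-to-end usage from an inference/eval script (the device
# choice here is just an example; "sft" maps to chatsft_checkpoints above):
#
#   device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#   model, tokenizer, meta_data = load_model("sft", device, phase="eval")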