Mirror of https://github.com/karpathy/nanochat.git, synced 2026-05-04 15:00:28 +00:00
Resolves all crashes and silent errors when running on Apple Silicon (MPS) or CPU after merging upstream/master. Tested on torch 2.2.2 and 2.9.1.

nanochat/engine.py
- Replace the signal.SIGALRM timeout with a concurrent.futures.ThreadPoolExecutor so use_calculator() works from FastAPI worker threads (SIGALRM is Unix main-thread only and silently broken in any threaded web server); see the timeout sketch below
- Guard torch.cuda.synchronize() behind device_type == 'cuda'

nanochat/gpt.py
- Extract init_rotary_embeddings() from init_weights() so checkpoint loading can restore the non-persistent cos/sin buffers without re-randomising all weights
- Cast rotary cos/sin to bfloat16 on CUDA only (bfloat16 on MPS requires torch>=2.4; float32 is used on MPS/CPU)
- Update the forward() dtype assertion to match the device
- Add an F.rms_norm fallback for torch<2.4 (rms_norm was added in 2.4); see the fallback sketch below

nanochat/optim.py
- _cuda_compile(): skip torch.compile(fullgraph=True) on MPS/CPU and return the function unchanged so eager execution is used
- adamw_step_fused / muon_step_fused: move 0-D CPU scalar tensors to the parameter device at the start of the function (cross-device ops crash in eager mode on MPS)
- muon_step_fused: use bfloat16 in the polar express on CUDA only; fall back to float32 on MPS/CPU
- _step_muon: replace torch._foreach_copy_() with a p.copy_(s) loop on non-CUDA devices (_foreach_copy_ is not implemented on MPS in torch<2.4)

nanochat/flash_attention.py
- Probe SDPA for enable_gqa support at import time (added in torch 2.5; inspect.signature raises on C builtins in older Python/torch); see the probe sketch below
- Fall back to manual KV head repetition via repeat_interleave when enable_gqa is unavailable

nanochat/checkpoint_manager.py
- Call model.init_rotary_embeddings() instead of model.init_weights() after load_state_dict() to restore the non-persistent rotary buffers without clobbering the loaded weights

scripts/base_train.py
- Guard torch.compile(model) behind device_type == 'cuda'
- Set mfu = None on non-CUDA instead of computing 0/inf = 0.00%
- Handle mfu is None in the end-of-run report

tests/test_mps_compat.py (new)
- 16 tests covering every fix; all pass on MPS (torch 2.2.2 and 2.9.1)
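For illustration, a minimal sketch of the thread-safe timeout pattern described above for use_calculator(); the helper name run_with_timeout and the default deadline are illustrative, not the actual engine.py API:

import concurrent.futures

def run_with_timeout(fn, args=(), timeout_s=3.0):
    """Bound fn(*args) by wall-clock time without relying on SIGALRM."""
    # A dedicated worker thread plus a timed result wait works from any
    # thread (e.g. a FastAPI worker), unlike signal.SIGALRM.
    pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
    future = pool.submit(fn, *args)
    try:
        return future.result(timeout=timeout_s)
    except concurrent.futures.TimeoutError:
        return None  # stop waiting; the stray worker finishes on its own
    finally:
        # wait=False so a hung evaluation does not block the caller on shutdown
        pool.shutdown(wait=False)

# e.g. run_with_timeout(eval, ("2 + 2",), timeout_s=1.0) -> 4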
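The F.rms_norm fallback for torch<2.4 could look roughly like this (a sketch only; the eps default is an assumption and the real gpt.py may differ):

import torch
import torch.nn.functional as F

if hasattr(F, "rms_norm"):
    rms_norm = F.rms_norm  # torch >= 2.4 ships rms_norm directly
else:
    def rms_norm(x, normalized_shape, weight=None, eps=1e-6):
        # Manual RMSNorm: scale x by the reciprocal RMS over the trailing dims.
        dims = tuple(range(-len(normalized_shape), 0))
        x = x * torch.rsqrt(x.pow(2).mean(dims, keepdim=True) + eps)
        return x if weight is None else x * weight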
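And the flash_attention.py probe could be done along these lines (a sketch under the stated assumptions; names such as HAS_ENABLE_GQA and gqa_sdpa are illustrative):

import torch
import torch.nn.functional as F

def _probe_enable_gqa():
    # Call SDPA once on a tiny input: torch >= 2.5 accepts enable_gqa,
    # older builds raise TypeError. inspect.signature() is unreliable here
    # because SDPA is a C builtin.
    try:
        q = torch.zeros(1, 2, 1, 8)   # 2 query heads
        kv = torch.zeros(1, 1, 1, 8)  # 1 shared KV head
        F.scaled_dot_product_attention(q, kv, kv, enable_gqa=True)
        return True
    except TypeError:
        return False

HAS_ENABLE_GQA = _probe_enable_gqa()

def gqa_sdpa(q, k, v):
    if HAS_ENABLE_GQA:
        return F.scaled_dot_product_attention(q, k, v, enable_gqa=True)
    # Fallback: repeat each KV head so every query head has a matching KV head.
    n_rep = q.size(1) // k.size(1)
    return F.scaled_dot_product_attention(
        q, k.repeat_interleave(n_rep, dim=1), v.repeat_interleave(n_rep, dim=1)
    )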
nanochat/checkpoint_manager.py · 197 lines · 8.6 KiB · Python
"""
|
|
Utilities for saving and loading model/optim/state checkpoints.
|
|
"""
|
|
import os
|
|
import re
|
|
import glob
|
|
import json
|
|
import logging
|
|
import torch
|
|
|
|
from nanochat.common import get_base_dir
|
|
from nanochat.gpt import GPT, GPTConfig
|
|
from nanochat.tokenizer import get_tokenizer
|
|
from nanochat.common import setup_default_logging
|
|
|
|
# Set up logging
|
|
setup_default_logging()
|
|
logger = logging.getLogger(__name__)
|
|
def log0(message):
    if int(os.environ.get('RANK', 0)) == 0:
        logger.info(message)

def _patch_missing_config_keys(model_config_kwargs):
    """Add default values for new config keys missing in old checkpoints."""
    # Old models were trained with full context (no sliding window)
    if "window_pattern" not in model_config_kwargs:
        model_config_kwargs["window_pattern"] = "L"
        log0(f"Patching missing window_pattern in model config to 'L'")

def _patch_missing_keys(model_data, model_config):
    """Add default values for new parameters that may be missing in old checkpoints."""
    n_layer = model_config.n_layer
    # resid_lambdas defaults to 1.0 (identity scaling)
    if "resid_lambdas" not in model_data:
        model_data["resid_lambdas"] = torch.ones(n_layer)
        log0(f"Patching missing resid_lambdas in model data to 1.0")
    # x0_lambdas defaults to 0.0 (disabled)
    if "x0_lambdas" not in model_data:
        model_data["x0_lambdas"] = torch.zeros(n_layer)
        log0(f"Patching missing x0_lambdas in model data to 0.0")

def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data, rank=0):
    if rank == 0:
        os.makedirs(checkpoint_dir, exist_ok=True)
        # Save the model state parameters
        model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.pt")
        torch.save(model_data, model_path)
        logger.info(f"Saved model parameters to: {model_path}")
        # Save the metadata dict as json
        meta_path = os.path.join(checkpoint_dir, f"meta_{step:06d}.json")
        with open(meta_path, "w", encoding="utf-8") as f:
            json.dump(meta_data, f, indent=2)
        logger.info(f"Saved metadata to: {meta_path}")
    # Note that optimizer state is sharded across ranks, so each rank must save its own.
    if optimizer_data is not None:
        os.makedirs(checkpoint_dir, exist_ok=True)
        optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}_rank{rank:d}.pt")
        torch.save(optimizer_data, optimizer_path)
        logger.info(f"Saved optimizer state to: {optimizer_path}")

def load_checkpoint(checkpoint_dir, step, device, load_optimizer=False, rank=0):
    # Load the model state
    model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.pt")
    model_data = torch.load(model_path, map_location=device)
    # Load the optimizer state if requested
    optimizer_data = None
    if load_optimizer:
        optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}_rank{rank:d}.pt")
        optimizer_data = torch.load(optimizer_path, map_location=device)
    # Load the metadata
    meta_path = os.path.join(checkpoint_dir, f"meta_{step:06d}.json")
    with open(meta_path, "r", encoding="utf-8") as f:
        meta_data = json.load(f)
    return model_data, optimizer_data, meta_data

def build_model(checkpoint_dir, step, device, phase):
    """
    A bunch of repetitive code to build a model from a given checkpoint.
    Returns:
    - base model - uncompiled, not wrapped in DDP
    - tokenizer
    - meta data saved during base model training
    """
    assert phase in ["train", "eval"], f"Invalid phase: {phase}"
    model_data, optimizer_data, meta_data = load_checkpoint(checkpoint_dir, step, device, load_optimizer=False)
    if device.type in {"cpu", "mps"}:
        # Convert bfloat16 tensors to float for CPU inference
        model_data = {
            k: v.float() if v.dtype == torch.bfloat16 else v
            for k, v in model_data.items()
        }
    # Hack: fix torch compile issue, which prepends all keys with _orig_mod.
    model_data = {k.removeprefix("_orig_mod."): v for k, v in model_data.items()}
    model_config_kwargs = meta_data["model_config"]
    _patch_missing_config_keys(model_config_kwargs)
    log0(f"Building model with config: {model_config_kwargs}")
    model_config = GPTConfig(**model_config_kwargs)
    _patch_missing_keys(model_data, model_config)
    with torch.device("meta"):
        model = GPT(model_config)
    # Load the model state
    # Only init the rotary embedding buffers (persistent=False, absent from checkpoint),
    # not the full weights; everything else comes from load_state_dict.
    model.to_empty(device=device)
    model.init_rotary_embeddings()
    model.load_state_dict(model_data, strict=True, assign=True)
    # Put the model in the right training phase / mode
    if phase == "eval":
        model.eval()
    else:
        model.train()
    # Load the Tokenizer
    tokenizer = get_tokenizer()
    # Sanity check: compatibility between model and tokenizer
    assert tokenizer.get_vocab_size() == model_config_kwargs["vocab_size"], f"Tokenizer vocab size {tokenizer.get_vocab_size()} does not match model config vocab size {model_config_kwargs['vocab_size']}"
    return model, tokenizer, meta_data

def find_largest_model(checkpoints_dir):
    # attempt to guess the model tag: take the biggest model available
    model_tags = [f for f in os.listdir(checkpoints_dir) if os.path.isdir(os.path.join(checkpoints_dir, f))]
    if not model_tags:
        raise FileNotFoundError(f"No checkpoints found in {checkpoints_dir}")
    # 1) normally all model tags are of the form d<number>, try that first:
    candidates = []
    for model_tag in model_tags:
        match = re.match(r"d(\d+)", model_tag)
        if match:
            model_depth = int(match.group(1))
            candidates.append((model_depth, model_tag))
    if candidates:
        candidates.sort(key=lambda x: x[0], reverse=True)
        return candidates[0][1]
    # 2) if that failed, take the most recently updated model:
    model_tags.sort(key=lambda x: os.path.getmtime(os.path.join(checkpoints_dir, x)), reverse=True)
    return model_tags[0]


def find_last_step(checkpoint_dir):
    # Look into checkpoint_dir and find model_<step>.pt with the highest step
    checkpoint_files = glob.glob(os.path.join(checkpoint_dir, "model_*.pt"))
    if not checkpoint_files:
        raise FileNotFoundError(f"No checkpoints found in {checkpoint_dir}")
    last_step = int(max(os.path.basename(f).split("_")[-1].split(".")[0] for f in checkpoint_files))
    return last_step

# -----------------------------------------------------------------------------
# convenience functions that take into account nanochat's directory structure

def load_model_from_dir(checkpoints_dir, device, phase, model_tag=None, step=None):
    if model_tag is None:
        # guess the model tag by defaulting to the largest model
        model_tag = find_largest_model(checkpoints_dir)
        log0(f"No model tag provided, guessing model tag: {model_tag}")
    checkpoint_dir = os.path.join(checkpoints_dir, model_tag)
    if step is None:
        # guess the step by defaulting to the last step
        step = find_last_step(checkpoint_dir)
    assert step is not None, f"No checkpoints found in {checkpoint_dir}"
    # build the model
    log0(f"Loading model from {checkpoint_dir} with step {step}")
    model, tokenizer, meta_data = build_model(checkpoint_dir, step, device, phase)
    return model, tokenizer, meta_data

def load_model(source, *args, **kwargs):
    model_dir = {
        "base": "base_checkpoints",
        "sft": "chatsft_checkpoints",
        "rl": "chatrl_checkpoints",
    }[source]
    base_dir = get_base_dir()
    checkpoints_dir = os.path.join(base_dir, model_dir)
    return load_model_from_dir(checkpoints_dir, *args, **kwargs)

def load_optimizer_state(source, device, rank, model_tag=None, step=None):
    """Load just the optimizer shard for a given rank, without re-loading the model."""
    model_dir = {
        "base": "base_checkpoints",
        "sft": "chatsft_checkpoints",
        "rl": "chatrl_checkpoints",
    }[source]
    base_dir = get_base_dir()
    checkpoints_dir = os.path.join(base_dir, model_dir)
    if model_tag is None:
        model_tag = find_largest_model(checkpoints_dir)
    checkpoint_dir = os.path.join(checkpoints_dir, model_tag)
    if step is None:
        step = find_last_step(checkpoint_dir)
    optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}_rank{rank:d}.pt")
    if not os.path.exists(optimizer_path):
        log0(f"Optimizer checkpoint not found: {optimizer_path}")
        return None
    log0(f"Loading optimizer state from {optimizer_path}")
    optimizer_data = torch.load(optimizer_path, map_location=device)
    return optimizer_data
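For reference, a typical call into the convenience loaders above might look like this (illustrative only; it assumes the module path nanochat.checkpoint_manager and that a base checkpoint already exists under the nanochat base directory):

import torch
from nanochat.checkpoint_manager import load_model, load_optimizer_state

device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
# Picks the largest base model and its last saved step automatically.
model, tokenizer, meta = load_model("base", device, phase="eval")
# Optimizer shards are per-rank; returns None if the shard was never saved.
optim_state = load_optimizer_state("base", device, rank=0)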