revert formatting changes to facilitate review

This commit is contained in:
svlandeg 2025-11-02 14:16:43 +01:00
parent d54c9cbf8c
commit 036a3c5881

View File

@ -1,7 +1,6 @@
"""
Utilities for saving and loading model/optim/state checkpoints.
"""
import os
import re
import glob
@ -17,15 +16,13 @@ from nanochat.common import setup_default_logging
# Set up logging
setup_default_logging()
logger = logging.getLogger(__name__)
def log0(message):
if int(os.environ.get("RANK", 0)) == 0:
if int(os.environ.get('RANK', 0)) == 0:
logger.info(message)
def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data):
assert int(os.environ.get("RANK", 0)) == 0 # prevent footguns for now
assert int(os.environ.get('RANK', 0)) == 0 # prevent footguns for now
os.makedirs(checkpoint_dir, exist_ok=True)
# Save the model state (parameters)
model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.pt")
@ -68,9 +65,7 @@ def build_model(checkpoint_dir, step, device, phase):
- meta data saved during base model training
"""
assert phase in ["train", "eval"], f"Invalid phase: {phase}"
model_data, optimizer_data, meta_data = load_checkpoint(
checkpoint_dir, step, device, load_optimizer=False
)
model_data, optimizer_data, meta_data = load_checkpoint(checkpoint_dir, step, device, load_optimizer=False)
if device.type == "cpu":
# Convert bfloat16 tensors to float for CPU inference
model_data = {
@ -86,7 +81,7 @@ def build_model(checkpoint_dir, step, device, phase):
model = GPT(model_config)
# Load the model state
model.to_empty(device=device)
model.init_weights() # note: this is dumb, but we need to init the rotary embeddings. TODO: fix model re-init
model.init_weights() # note: this is dumb, but we need to init the rotary embeddings. TODO: fix model re-init
model.load_state_dict(model_data, strict=True, assign=True)
# Put the model in the right training phase / mode
if phase == "eval":
@ -102,11 +97,7 @@ def build_model(checkpoint_dir, step, device, phase):
def find_largest_model(checkpoint_dir):
# attempt to guess the model tag: take the biggest model available
model_tags = [
f
for f in os.listdir(checkpoint_dir)
if os.path.isdir(os.path.join(checkpoint_dir, f))
]
model_tags = [f for f in os.listdir(checkpoint_dir) if os.path.isdir(os.path.join(checkpoint_dir, f))]
if not model_tags:
raise FileNotFoundError(f"No checkpoints found in {checkpoint_dir}")
# 1) normally all model tags are of the form d<number>, try that first:
@ -120,9 +111,7 @@ def find_largest_model(checkpoint_dir):
candidates.sort(key=lambda x: x[0], reverse=True)
return candidates[0][1]
# 2) if that failed, take the most recently updated model:
model_tags.sort(
key=lambda x: os.path.getmtime(os.path.join(checkpoint_dir, x)), reverse=True
)
model_tags.sort(key=lambda x: os.path.getmtime(os.path.join(checkpoint_dir, x)), reverse=True)
return model_tags[0]
@ -131,16 +120,12 @@ def find_last_step(checkpoint_dir):
checkpoint_files = glob.glob(os.path.join(checkpoint_dir, "model_*.pt"))
if not checkpoint_files:
raise FileNotFoundError(f"No checkpoints found in {checkpoint_dir}")
last_step = int(
max(os.path.basename(f).split("_")[-1].split(".")[0] for f in checkpoint_files)
)
last_step = int(max(os.path.basename(f).split("_")[-1].split(".")[0] for f in checkpoint_files))
return last_step
# -----------------------------------------------------------------------------
# convenience functions that take into account nanochat's directory structure
def load_model_from_dir(checkpoints_dir, device, phase, model_tag=None, step=None):
if model_tag is None:
# guess the model tag by defaulting to the largest model
@ -156,7 +141,6 @@ def load_model_from_dir(checkpoints_dir, device, phase, model_tag=None, step=Non
model, tokenizer, meta_data = build_model(checkpoint_dir, step, device, phase)
return model, tokenizer, meta_data
def load_model(source, *args, **kwargs):
model_dir = {
"base": "base_checkpoints",