Mirror of https://github.com/karpathy/nanochat.git, synced 2025-12-06 12:22:18 +00:00

CPU Support, as bfloat16 params break inference

This commit is contained in:
parent dfc88334b6
commit d54c9cbf8c
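The change below guards CPU inference: checkpoints store bfloat16 parameters, which (per the commit message) break inference on CPU, so they are upcast to float32 before loading. A minimal standalone sketch of the same idea, where the filename is hypothetical but follows the model_{step:06d}.pt pattern used in this file:

import torch

# Hypothetical checkpoint file, named after the model_{step:06d}.pt pattern below.
state_dict = torch.load("model_000600.pt", map_location="cpu")

# Upcast bfloat16 tensors to float32; other dtypes pass through unchanged.
state_dict = {k: v.float() if v.dtype == torch.bfloat16 else v for k, v in state_dict.items()}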
@@ -1,6 +1,7 @@
 """
 Utilities for saving and loading model/optim/state checkpoints.
 """
+
 import os
 import re
 import glob
@@ -16,12 +17,15 @@ from nanochat.common import setup_default_logging
 # Set up logging
 setup_default_logging()
 logger = logging.getLogger(__name__)
 
+
 def log0(message):
-    if int(os.environ.get('RANK', 0)) == 0:
+    if int(os.environ.get("RANK", 0)) == 0:
         logger.info(message)
 
+
 def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data):
-    assert int(os.environ.get('RANK', 0)) == 0 # prevent footguns for now
+    assert int(os.environ.get("RANK", 0)) == 0  # prevent footguns for now
     os.makedirs(checkpoint_dir, exist_ok=True)
     # Save the model state (parameters)
     model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.pt")
@@ -64,7 +68,15 @@ def build_model(checkpoint_dir, step, device, phase):
     - meta data saved during base model training
     """
     assert phase in ["train", "eval"], f"Invalid phase: {phase}"
-    model_data, optimizer_data, meta_data = load_checkpoint(checkpoint_dir, step, device, load_optimizer=False)
+    model_data, optimizer_data, meta_data = load_checkpoint(
+        checkpoint_dir, step, device, load_optimizer=False
+    )
+    if device.type == "cpu":
+        # Convert bfloat16 tensors to float for CPU inference
+        model_data = {
+            k: v.float() if v.dtype == torch.bfloat16 else v
+            for k, v in model_data.items()
+        }
     # Hack: fix torch compile issue, which prepends all keys with _orig_mod.
     model_data = {k.lstrip("_orig_mod."): v for k, v in model_data.items()}
     model_config_kwargs = meta_data["model_config"]
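With this hunk applied, a CPU evaluation load might look like the following sketch; build_model and its return values are as defined in this file, while the checkpoint directory, tag, and step are hypothetical:

import torch

checkpoint_dir = "base_checkpoints/d20"  # hypothetical model tag directory
device = torch.device("cpu")

# build_model now upcasts bfloat16 params to float32 when device.type == "cpu".
model, tokenizer, meta_data = build_model(checkpoint_dir, step=600, device=device, phase="eval")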
@@ -74,7 +86,7 @@ def build_model(checkpoint_dir, step, device, phase):
     model = GPT(model_config)
     # Load the model state
     model.to_empty(device=device)
-    model.init_weights() # note: this is dumb, but we need to init the rotary embeddings. TODO: fix model re-init
+    model.init_weights()  # note: this is dumb, but we need to init the rotary embeddings. TODO: fix model re-init
     model.load_state_dict(model_data, strict=True, assign=True)
     # Put the model in the right training phase / mode
     if phase == "eval":
@@ -90,7 +102,11 @@ def build_model(checkpoint_dir, step, device, phase):
 
 def find_largest_model(checkpoint_dir):
     # attempt to guess the model tag: take the biggest model available
-    model_tags = [f for f in os.listdir(checkpoint_dir) if os.path.isdir(os.path.join(checkpoint_dir, f))]
+    model_tags = [
+        f
+        for f in os.listdir(checkpoint_dir)
+        if os.path.isdir(os.path.join(checkpoint_dir, f))
+    ]
     if not model_tags:
         raise FileNotFoundError(f"No checkpoints found in {checkpoint_dir}")
     # 1) normally all model tags are of the form d<number>, try that first:
@@ -104,7 +120,9 @@ def find_largest_model(checkpoint_dir):
         candidates.sort(key=lambda x: x[0], reverse=True)
         return candidates[0][1]
     # 2) if that failed, take the most recently updated model:
-    model_tags.sort(key=lambda x: os.path.getmtime(os.path.join(checkpoint_dir, x)), reverse=True)
+    model_tags.sort(
+        key=lambda x: os.path.getmtime(os.path.join(checkpoint_dir, x)), reverse=True
+    )
     return model_tags[0]
 
 
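Since model tags follow the d<number> convention noted above, picking the largest model reduces to a numeric sort on the parsed suffix. A quick illustration with hypothetical tags; the exact regex in find_largest_model is outside this hunk, so an equivalent is assumed:

import re

tags = ["d12", "d26", "d20"]  # hypothetical model tag directories
candidates = sorted(
    ((int(m.group(1)), t) for t in tags if (m := re.match(r"d(\d+)$", t)) is not None),
    reverse=True,
)
print(candidates[0][1])  # d26, i.e. the biggest model available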
@@ -113,12 +131,16 @@ def find_last_step(checkpoint_dir):
     checkpoint_files = glob.glob(os.path.join(checkpoint_dir, "model_*.pt"))
     if not checkpoint_files:
         raise FileNotFoundError(f"No checkpoints found in {checkpoint_dir}")
-    last_step = int(max(os.path.basename(f).split("_")[-1].split(".")[0] for f in checkpoint_files))
+    last_step = int(
+        max(os.path.basename(f).split("_")[-1].split(".")[0] for f in checkpoint_files)
+    )
     return last_step
 
+
 # -----------------------------------------------------------------------------
 # convenience functions that take into account nanochat's directory structure
 
+
 def load_model_from_dir(checkpoints_dir, device, phase, model_tag=None, step=None):
     if model_tag is None:
         # guess the model tag by defaulting to the largest model
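find_last_step recovers the step from the checkpoint filename; because save_checkpoint zero-pads steps to six digits (model_{step:06d}.pt), taking max() of the raw digit strings before int() gives the numeric maximum. For example, with hypothetical filenames:

files = ["model_000500.pt", "model_000600.pt"]  # hypothetical checkpoints
last_step = int(max(f.split("_")[-1].split(".")[0] for f in files))
print(last_step)  # 600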
@@ -134,6 +156,7 @@ def load_model_from_dir(checkpoints_dir, device, phase, model_tag=None, step=None):
     model, tokenizer, meta_data = build_model(checkpoint_dir, step, device, phase)
     return model, tokenizer, meta_data
 
+
 def load_model(source, *args, **kwargs):
     model_dir = {
         "base": "base_checkpoints",
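The diff is cut off inside the model_dir mapping, but the visible pieces suggest usage along these lines; a sketch only, where the "base" source name comes from the hunk above and the forwarded arguments are assumed to match load_model_from_dir:

import torch

# load_model dispatches on the source name, e.g. "base" -> "base_checkpoints",
# and forwards the remaining arguments to load_model_from_dir.
device = torch.device("cpu")
model, tokenizer, meta_data = load_model("base", device, phase="eval")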