From d54c9cbf8c3ec7e4436bef404d605700f661f12c Mon Sep 17 00:00:00 2001 From: Manuel Saelices Date: Sat, 1 Nov 2025 23:38:50 +0100 Subject: [PATCH 1/3] CPU Support, as bfloat16 params breaks inference --- nanochat/checkpoint_manager.py | 37 +++++++++++++++++++++++++++------- 1 file changed, 30 insertions(+), 7 deletions(-) diff --git a/nanochat/checkpoint_manager.py b/nanochat/checkpoint_manager.py index f400d47..26fdb0d 100644 --- a/nanochat/checkpoint_manager.py +++ b/nanochat/checkpoint_manager.py @@ -1,6 +1,7 @@ """ Utilities for saving and loading model/optim/state checkpoints. """ + import os import re import glob @@ -16,12 +17,15 @@ from nanochat.common import setup_default_logging # Set up logging setup_default_logging() logger = logging.getLogger(__name__) + + def log0(message): - if int(os.environ.get('RANK', 0)) == 0: + if int(os.environ.get("RANK", 0)) == 0: logger.info(message) + def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data): - assert int(os.environ.get('RANK', 0)) == 0 # prevent footguns for now + assert int(os.environ.get("RANK", 0)) == 0 # prevent footguns for now os.makedirs(checkpoint_dir, exist_ok=True) # Save the model state (parameters) model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.pt") @@ -64,7 +68,15 @@ def build_model(checkpoint_dir, step, device, phase): - meta data saved during base model training """ assert phase in ["train", "eval"], f"Invalid phase: {phase}" - model_data, optimizer_data, meta_data = load_checkpoint(checkpoint_dir, step, device, load_optimizer=False) + model_data, optimizer_data, meta_data = load_checkpoint( + checkpoint_dir, step, device, load_optimizer=False + ) + if device.type == "cpu": + # Convert bfloat16 tensors to float for CPU inference + model_data = { + k: v.float() if v.dtype == torch.bfloat16 else v + for k, v in model_data.items() + } # Hack: fix torch compile issue, which prepends all keys with _orig_mod. model_data = {k.lstrip("_orig_mod."): v for k, v in model_data.items()} model_config_kwargs = meta_data["model_config"] @@ -74,7 +86,7 @@ def build_model(checkpoint_dir, step, device, phase): model = GPT(model_config) # Load the model state model.to_empty(device=device) - model.init_weights() # note: this is dumb, but we need to init the rotary embeddings. TODO: fix model re-init + model.init_weights() # note: this is dumb, but we need to init the rotary embeddings. TODO: fix model re-init model.load_state_dict(model_data, strict=True, assign=True) # Put the model in the right training phase / mode if phase == "eval": @@ -90,7 +102,11 @@ def build_model(checkpoint_dir, step, device, phase): def find_largest_model(checkpoint_dir): # attempt to guess the model tag: take the biggest model available - model_tags = [f for f in os.listdir(checkpoint_dir) if os.path.isdir(os.path.join(checkpoint_dir, f))] + model_tags = [ + f + for f in os.listdir(checkpoint_dir) + if os.path.isdir(os.path.join(checkpoint_dir, f)) + ] if not model_tags: raise FileNotFoundError(f"No checkpoints found in {checkpoint_dir}") # 1) normally all model tags are of the form d, try that first: @@ -104,7 +120,9 @@ def find_largest_model(checkpoint_dir): candidates.sort(key=lambda x: x[0], reverse=True) return candidates[0][1] # 2) if that failed, take the most recently updated model: - model_tags.sort(key=lambda x: os.path.getmtime(os.path.join(checkpoint_dir, x)), reverse=True) + model_tags.sort( + key=lambda x: os.path.getmtime(os.path.join(checkpoint_dir, x)), reverse=True + ) return model_tags[0] @@ -113,12 +131,16 @@ def find_last_step(checkpoint_dir): checkpoint_files = glob.glob(os.path.join(checkpoint_dir, "model_*.pt")) if not checkpoint_files: raise FileNotFoundError(f"No checkpoints found in {checkpoint_dir}") - last_step = int(max(os.path.basename(f).split("_")[-1].split(".")[0] for f in checkpoint_files)) + last_step = int( + max(os.path.basename(f).split("_")[-1].split(".")[0] for f in checkpoint_files) + ) return last_step + # ----------------------------------------------------------------------------- # convenience functions that take into account nanochat's directory structure + def load_model_from_dir(checkpoints_dir, device, phase, model_tag=None, step=None): if model_tag is None: # guess the model tag by defaulting to the largest model @@ -134,6 +156,7 @@ def load_model_from_dir(checkpoints_dir, device, phase, model_tag=None, step=Non model, tokenizer, meta_data = build_model(checkpoint_dir, step, device, phase) return model, tokenizer, meta_data + def load_model(source, *args, **kwargs): model_dir = { "base": "base_checkpoints", From 036a3c5881c7e6430d5565bf8f1224fef54cdc82 Mon Sep 17 00:00:00 2001 From: svlandeg Date: Sun, 2 Nov 2025 14:16:43 +0100 Subject: [PATCH 2/3] revert formatting changes to facilitate review --- nanochat/checkpoint_manager.py | 30 +++++++----------------------- 1 file changed, 7 insertions(+), 23 deletions(-) diff --git a/nanochat/checkpoint_manager.py b/nanochat/checkpoint_manager.py index 26fdb0d..a1120cb 100644 --- a/nanochat/checkpoint_manager.py +++ b/nanochat/checkpoint_manager.py @@ -1,7 +1,6 @@ """ Utilities for saving and loading model/optim/state checkpoints. """ - import os import re import glob @@ -17,15 +16,13 @@ from nanochat.common import setup_default_logging # Set up logging setup_default_logging() logger = logging.getLogger(__name__) - - def log0(message): - if int(os.environ.get("RANK", 0)) == 0: + if int(os.environ.get('RANK', 0)) == 0: logger.info(message) def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data): - assert int(os.environ.get("RANK", 0)) == 0 # prevent footguns for now + assert int(os.environ.get('RANK', 0)) == 0 # prevent footguns for now os.makedirs(checkpoint_dir, exist_ok=True) # Save the model state (parameters) model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.pt") @@ -68,9 +65,7 @@ def build_model(checkpoint_dir, step, device, phase): - meta data saved during base model training """ assert phase in ["train", "eval"], f"Invalid phase: {phase}" - model_data, optimizer_data, meta_data = load_checkpoint( - checkpoint_dir, step, device, load_optimizer=False - ) + model_data, optimizer_data, meta_data = load_checkpoint(checkpoint_dir, step, device, load_optimizer=False) if device.type == "cpu": # Convert bfloat16 tensors to float for CPU inference model_data = { @@ -86,7 +81,7 @@ def build_model(checkpoint_dir, step, device, phase): model = GPT(model_config) # Load the model state model.to_empty(device=device) - model.init_weights() # note: this is dumb, but we need to init the rotary embeddings. TODO: fix model re-init + model.init_weights() # note: this is dumb, but we need to init the rotary embeddings. TODO: fix model re-init model.load_state_dict(model_data, strict=True, assign=True) # Put the model in the right training phase / mode if phase == "eval": @@ -102,11 +97,7 @@ def build_model(checkpoint_dir, step, device, phase): def find_largest_model(checkpoint_dir): # attempt to guess the model tag: take the biggest model available - model_tags = [ - f - for f in os.listdir(checkpoint_dir) - if os.path.isdir(os.path.join(checkpoint_dir, f)) - ] + model_tags = [f for f in os.listdir(checkpoint_dir) if os.path.isdir(os.path.join(checkpoint_dir, f))] if not model_tags: raise FileNotFoundError(f"No checkpoints found in {checkpoint_dir}") # 1) normally all model tags are of the form d, try that first: @@ -120,9 +111,7 @@ def find_largest_model(checkpoint_dir): candidates.sort(key=lambda x: x[0], reverse=True) return candidates[0][1] # 2) if that failed, take the most recently updated model: - model_tags.sort( - key=lambda x: os.path.getmtime(os.path.join(checkpoint_dir, x)), reverse=True - ) + model_tags.sort(key=lambda x: os.path.getmtime(os.path.join(checkpoint_dir, x)), reverse=True) return model_tags[0] @@ -131,16 +120,12 @@ def find_last_step(checkpoint_dir): checkpoint_files = glob.glob(os.path.join(checkpoint_dir, "model_*.pt")) if not checkpoint_files: raise FileNotFoundError(f"No checkpoints found in {checkpoint_dir}") - last_step = int( - max(os.path.basename(f).split("_")[-1].split(".")[0] for f in checkpoint_files) - ) + last_step = int(max(os.path.basename(f).split("_")[-1].split(".")[0] for f in checkpoint_files)) return last_step - # ----------------------------------------------------------------------------- # convenience functions that take into account nanochat's directory structure - def load_model_from_dir(checkpoints_dir, device, phase, model_tag=None, step=None): if model_tag is None: # guess the model tag by defaulting to the largest model @@ -156,7 +141,6 @@ def load_model_from_dir(checkpoints_dir, device, phase, model_tag=None, step=Non model, tokenizer, meta_data = build_model(checkpoint_dir, step, device, phase) return model, tokenizer, meta_data - def load_model(source, *args, **kwargs): model_dir = { "base": "base_checkpoints", From 5bfcd31b7311036a647b0677d2638046ef05f252 Mon Sep 17 00:00:00 2001 From: svlandeg Date: Sun, 2 Nov 2025 14:17:10 +0100 Subject: [PATCH 3/3] revert more formatting changes --- nanochat/checkpoint_manager.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/nanochat/checkpoint_manager.py b/nanochat/checkpoint_manager.py index a1120cb..262ff97 100644 --- a/nanochat/checkpoint_manager.py +++ b/nanochat/checkpoint_manager.py @@ -20,9 +20,8 @@ def log0(message): if int(os.environ.get('RANK', 0)) == 0: logger.info(message) - def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data): - assert int(os.environ.get('RANK', 0)) == 0 # prevent footguns for now + assert int(os.environ.get('RANK', 0)) == 0 # prevent footguns for now os.makedirs(checkpoint_dir, exist_ok=True) # Save the model state (parameters) model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.pt")