diff --git a/nanochat/checkpoint_manager.py b/nanochat/checkpoint_manager.py
index 63f257f..671bbac 100644
--- a/nanochat/checkpoint_manager.py
+++ b/nanochat/checkpoint_manager.py
@@ -34,6 +34,7 @@ def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data,
         logger.info(f"Saved metadata to: {meta_path}")
     # Note that optimizer state is sharded across ranks, so each rank must save its own.
     if optimizer_data is not None:
+        os.makedirs(checkpoint_dir, exist_ok=True)
         optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}_rank{rank:d}.pt")
         torch.save(optimizer_data, optimizer_path)
         logger.info(f"Saved optimizer state to: {optimizer_path}")
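
Note on the change: since optimizer state is sharded and every rank writes its own `optim_*_rank*.pt` file, the diff has each rank call `os.makedirs(checkpoint_dir, exist_ok=True)` before saving, so no rank depends on the checkpoint directory already having been created elsewhere (e.g. by rank 0). Below is a minimal, self-contained sketch of that pattern; `save_optimizer_shard` is a hypothetical helper for illustration, not a function from the repo.

```python
import os
import torch

def save_optimizer_shard(checkpoint_dir, step, rank, optimizer_data):
    # Each rank ensures the directory exists before writing its own shard.
    # exist_ok=True makes the call idempotent, so concurrent ranks are fine.
    os.makedirs(checkpoint_dir, exist_ok=True)
    optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}_rank{rank:d}.pt")
    torch.save(optimizer_data, optimizer_path)
    return optimizer_path

# Example usage (single process standing in for one rank):
if __name__ == "__main__":
    path = save_optimizer_shard("checkpoints/run1", step=34, rank=0,
                                optimizer_data={"state": {}, "param_groups": []})
    print(f"Saved optimizer state to: {path}")
```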