From b0778933ee152bcfd9024fa6d19799669a14f4dc Mon Sep 17 00:00:00 2001 From: Evgenii Zheltonozhskii Date: Thu, 5 Mar 2026 11:50:17 +0200 Subject: [PATCH] Add upload to hf hub --- nanochat/checkpoint_manager.py | 64 +++++++++++++++++++++++++++++----- pyproject.toml | 2 ++ scripts/base_train.py | 8 ++++- scripts/chat_sft.py | 8 ++++- 4 files changed, 72 insertions(+), 10 deletions(-) diff --git a/nanochat/checkpoint_manager.py b/nanochat/checkpoint_manager.py index f71524e..876cffe 100644 --- a/nanochat/checkpoint_manager.py +++ b/nanochat/checkpoint_manager.py @@ -7,6 +7,7 @@ import glob import json import logging import torch +import safetensors.torch from nanochat.common import get_base_dir from nanochat.gpt import GPT, GPTConfig @@ -42,9 +43,9 @@ def _patch_missing_keys(model_data, model_config): def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data, rank=0): if rank == 0: os.makedirs(checkpoint_dir, exist_ok=True) - # Save the model state parameters - model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.pt") - torch.save(model_data, model_path) + # Save the model state parameters in safetensors format + model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.safetensors") + safetensors.torch.save_file(model_data, model_path) logger.info(f"Saved model parameters to: {model_path}") # Save the metadata dict as json meta_path = os.path.join(checkpoint_dir, f"meta_{step:06d}.json") @@ -59,9 +60,9 @@ def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data, logger.info(f"Saved optimizer state to: {optimizer_path}") def load_checkpoint(checkpoint_dir, step, device, load_optimizer=False, rank=0): - # Load the model state - model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.pt") - model_data = torch.load(model_path, map_location=device) + # Load the model state in safetensors format + model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.safetensors") + model_data = 
safetensors.torch.load_file(model_path, device=str(device)) # Load the optimizer state if requested optimizer_data = None if load_optimizer: @@ -136,8 +137,8 @@ def find_largest_model(checkpoints_dir): def find_last_step(checkpoint_dir): - # Look into checkpoint_dir and find model_.pt with the highest step - checkpoint_files = glob.glob(os.path.join(checkpoint_dir, "model_*.pt")) + # Look into checkpoint_dir and find model_.safetensors with the highest step + checkpoint_files = glob.glob(os.path.join(checkpoint_dir, "model_*.safetensors")) if not checkpoint_files: raise FileNotFoundError(f"No checkpoints found in {checkpoint_dir}") last_step = int(max(os.path.basename(f).split("_")[-1].split(".")[0] for f in checkpoint_files)) @@ -192,3 +193,50 @@ def load_optimizer_state(source, device, rank, model_tag=None, step=None): log0(f"Loading optimizer state from {optimizer_path}") optimizer_data = torch.load(optimizer_path, map_location=device) return optimizer_data + +def upload_to_hf(checkpoint_dir, step, repo_id, token=None, with_optimizer=False): + """Upload the model, metadata, and tokenizer to Hugging Face Hub.""" + try: + from huggingface_hub import HfApi + except ImportError: + log0("Error: huggingface_hub not installed. Run 'pip install huggingface_hub'") + return + + log0(f"Uploading model to Hugging Face Hub: {repo_id}") + api = HfApi(token=token) + api.create_repo(repo_id=repo_id, exist_ok=True) + + # 1. 
Upload model and meta using standard names + model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.safetensors") + meta_path = os.path.join(checkpoint_dir, f"meta_{step:06d}.json") + + if os.path.exists(model_path): + api.upload_file(path_or_fileobj=model_path, path_in_repo="model.safetensors", repo_id=repo_id) + else: + log0(f"Warning: model file not found at {model_path}") + + if os.path.exists(meta_path): + api.upload_file(path_or_fileobj=meta_path, path_in_repo="config.json", repo_id=repo_id) + else: + log0(f"Warning: metadata file not found at {meta_path}") + + # 2. Upload optimizer state if requested + if with_optimizer: + log0("Uploading optimizer state...") + # Find all optimizer shards for this step + optim_shards = glob.glob(os.path.join(checkpoint_dir, f"optim_{step:06d}_rank*.pt")) + for shard in optim_shards: + filename = os.path.basename(shard) + api.upload_file(path_or_fileobj=shard, path_in_repo=f"optimizer/{filename}", repo_id=repo_id) + + # 3. Upload tokenizer files + base_dir = get_base_dir() + tokenizer_dir = os.path.join(base_dir, "tokenizer") + if os.path.exists(tokenizer_dir): + for f in os.listdir(tokenizer_dir): + full_path = os.path.join(tokenizer_dir, f) + if os.path.isfile(full_path): + # Upload tokenizer files to the root for better compatibility + api.upload_file(path_or_fileobj=full_path, path_in_repo=f, repo_id=repo_id) + + log0(f"Successfully uploaded model to https://huggingface.co/{repo_id}") diff --git a/pyproject.toml b/pyproject.toml index 8b6fd95..eca44e7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,6 +7,7 @@ requires-python = ">=3.10" dependencies = [ "datasets>=4.0.0", "fastapi>=0.117.1", + "huggingface-hub>=1.5.0", "ipykernel>=7.1.0", "kernels>=0.11.7", "matplotlib>=3.10.8", @@ -14,6 +15,7 @@ dependencies = [ "python-dotenv>=1.2.1", "regex>=2025.9.1", "rustbpe>=0.1.0", + "safetensors>=0.7.0", "scipy>=1.15.3", "setuptools>=80.9.0", "tabulate>=0.9.0", diff --git a/scripts/base_train.py b/scripts/base_train.py 
index 4bf7959..5c24c0b 100644 --- a/scripts/base_train.py +++ b/scripts/base_train.py @@ -29,7 +29,7 @@ from nanochat.gpt import GPT, GPTConfig, Linear from nanochat.dataloader import tokenizing_distributed_data_loader_bos_bestfit, tokenizing_distributed_data_loader_with_state_bos_bestfit from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir, autodetect_device_type, get_peak_flops, COMPUTE_DTYPE, COMPUTE_DTYPE_REASON, is_ddp_initialized from nanochat.tokenizer import get_tokenizer, get_token_bytes -from nanochat.checkpoint_manager import save_checkpoint, load_checkpoint +from nanochat.checkpoint_manager import save_checkpoint, load_checkpoint, upload_to_hf from nanochat.loss_eval import evaluate_bpb from nanochat.engine import Engine from nanochat.flash_attention import HAS_FA3 @@ -79,6 +79,8 @@ parser.add_argument("--sample-every", type=int, default=2000, help="sample from parser.add_argument("--save-every", type=int, default=-1, help="save checkpoints every N steps (-1 = only at end)") # Output parser.add_argument("--model-tag", type=str, default=None, help="override model tag for checkpoint directory name") +parser.add_argument("--hf-repo", type=str, default=None, help="Hugging Face repo ID to upload final model") +parser.add_argument("--hf-upload-optim", action="store_true", help="upload optimizer state to Hugging Face") args = parser.parse_args() user_config = vars(args).copy() # for logging # ----------------------------------------------------------------------------- @@ -620,6 +622,10 @@ get_report().log(section="Base model training", data=[ } ]) +# upload to huggingface +if args.hf_repo and master_process: + upload_to_hf(checkpoint_dir, step, args.hf_repo, with_optimizer=args.hf_upload_optim) + # cleanup wandb_run.finish() # wandb run finish compute_cleanup() diff --git a/scripts/chat_sft.py b/scripts/chat_sft.py index c1adbb6..813d454 100644 --- a/scripts/chat_sft.py +++ b/scripts/chat_sft.py @@ -18,7 
+18,7 @@ import wandb import torch from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, get_base_dir, autodetect_device_type, get_peak_flops, COMPUTE_DTYPE, COMPUTE_DTYPE_REASON, is_ddp_initialized from nanochat.tokenizer import get_token_bytes -from nanochat.checkpoint_manager import save_checkpoint, load_model, load_optimizer_state +from nanochat.checkpoint_manager import save_checkpoint, load_model, load_optimizer_state, upload_to_hf from nanochat.loss_eval import evaluate_bpb import torch.distributed as dist from nanochat.flash_attention import HAS_FA3 @@ -66,6 +66,8 @@ parser.add_argument("--chatcore-max-sample", type=int, default=24, help="max pro # Data mixture parser.add_argument("--mmlu-epochs", type=int, default=3, help="number of epochs of MMLU in training mixture (teaches Multiple Choice)") parser.add_argument("--gsm8k-epochs", type=int, default=4, help="number of epochs of GSM8K in training mixture (teaches Math and Tool Use)") +parser.add_argument("--hf-repo", type=str, default=None, help="Hugging Face repo ID to upload final model") +parser.add_argument("--hf-upload-optim", action="store_true", help="upload optimizer state to Hugging Face") args = parser.parse_args() user_config = vars(args).copy() # ----------------------------------------------------------------------------- @@ -514,6 +516,10 @@ get_report().log(section="SFT", data=[ } ]) +# upload to huggingface +if args.hf_repo and master_process: + upload_to_hf(checkpoint_dir, step, args.hf_repo, with_optimizer=args.hf_upload_optim) + # cleanup wandb_run.finish() # wandb run finish compute_cleanup()