This commit is contained in:
Evgenii Zheltonozhskii 2026-03-18 11:17:04 +01:00 committed by GitHub
commit 7bc3284b46
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 72 additions and 10 deletions

View File

@ -7,6 +7,7 @@ import glob
import json
import logging
import torch
import safetensors.torch
from nanochat.common import get_base_dir
from nanochat.gpt import GPT, GPTConfig
@ -42,9 +43,9 @@ def _patch_missing_keys(model_data, model_config):
def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data, rank=0):
if rank == 0:
os.makedirs(checkpoint_dir, exist_ok=True)
# Save the model state parameters
model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.pt")
torch.save(model_data, model_path)
# Save the model state parameters in safetensors format
model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.safetensors")
safetensors.torch.save_file(model_data, model_path)
logger.info(f"Saved model parameters to: {model_path}")
# Save the metadata dict as json
meta_path = os.path.join(checkpoint_dir, f"meta_{step:06d}.json")
@ -59,9 +60,9 @@ def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data,
logger.info(f"Saved optimizer state to: {optimizer_path}")
def load_checkpoint(checkpoint_dir, step, device, load_optimizer=False, rank=0):
# Load the model state
model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.pt")
model_data = torch.load(model_path, map_location=device)
# Load the model state in safetensors format
model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.safetensors")
model_data = safetensors.torch.load_file(model_path, device=str(device))
# Load the optimizer state if requested
optimizer_data = None
if load_optimizer:
@ -136,8 +137,8 @@ def find_largest_model(checkpoints_dir):
def find_last_step(checkpoint_dir):
# Look into checkpoint_dir and find model_<step>.pt with the highest step
checkpoint_files = glob.glob(os.path.join(checkpoint_dir, "model_*.pt"))
# Look into checkpoint_dir and find model_<step>.safetensors with the highest step
checkpoint_files = glob.glob(os.path.join(checkpoint_dir, "model_*.safetensors"))
if not checkpoint_files:
raise FileNotFoundError(f"No checkpoints found in {checkpoint_dir}")
last_step = int(max(os.path.basename(f).split("_")[-1].split(".")[0] for f in checkpoint_files))
@ -192,3 +193,50 @@ def load_optimizer_state(source, device, rank, model_tag=None, step=None):
log0(f"Loading optimizer state from {optimizer_path}")
optimizer_data = torch.load(optimizer_path, map_location=device)
return optimizer_data
def upload_to_hf(checkpoint_dir, step, repo_id, token=None, with_optimizer=False):
    """Upload the model, metadata, and tokenizer to Hugging Face Hub.

    Args:
        checkpoint_dir: directory holding model_<step>.safetensors, meta_<step>.json,
            and (optionally) optim_<step>_rank*.pt shards.
        step: training step whose checkpoint files should be uploaded.
        repo_id: target Hub repo id (e.g. "user/model"); created if missing.
        token: optional HF API token; falls back to the ambient login when None.
        with_optimizer: when True, also upload the per-rank optimizer shards.
    """
    try:
        from huggingface_hub import HfApi
    except ImportError:
        log0("Error: huggingface_hub not installed. Run 'pip install huggingface_hub'")
        return
    log0(f"Uploading model to Hugging Face Hub: {repo_id}")
    api = HfApi(token=token)
    api.create_repo(repo_id=repo_id, exist_ok=True)
    # 1. Upload model and meta using standard names
    model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.safetensors")
    meta_path = os.path.join(checkpoint_dir, f"meta_{step:06d}.json")
    if os.path.exists(model_path):
        api.upload_file(path_or_fileobj=model_path, path_in_repo="model.safetensors", repo_id=repo_id)
    else:
        log0(f"Warning: model file not found at {model_path}")
    if os.path.exists(meta_path):
        api.upload_file(path_or_fileobj=meta_path, path_in_repo="config.json", repo_id=repo_id)
    else:
        log0(f"Warning: metadata file not found at {meta_path}")
    # 2. Upload optimizer state if requested
    if with_optimizer:
        log0("Uploading optimizer state...")
        # Find all optimizer shards for this step
        optim_shards = glob.glob(os.path.join(checkpoint_dir, f"optim_{step:06d}_rank*.pt"))
        for shard in optim_shards:
            filename = os.path.basename(shard)
            # BUGFIX: interpolate the shard's own name so each rank's file keeps a
            # distinct path under optimizer/ (was the literal string "(unknown)",
            # which would make every shard overwrite the same repo file)
            api.upload_file(path_or_fileobj=shard, path_in_repo=f"optimizer/{filename}", repo_id=repo_id)
    # 3. Upload tokenizer files
    base_dir = get_base_dir()
    tokenizer_dir = os.path.join(base_dir, "tokenizer")
    if os.path.exists(tokenizer_dir):
        for f in os.listdir(tokenizer_dir):
            full_path = os.path.join(tokenizer_dir, f)
            if os.path.isfile(full_path):
                # Upload tokenizer files to the root for better compatibility
                api.upload_file(path_or_fileobj=full_path, path_in_repo=f, repo_id=repo_id)
    log0(f"Successfully uploaded model to https://huggingface.co/{repo_id}")

View File

@ -7,6 +7,7 @@ requires-python = ">=3.10"
dependencies = [
"datasets>=4.0.0",
"fastapi>=0.117.1",
"huggingface-hub>=1.5.0",
"ipykernel>=7.1.0",
"kernels>=0.11.7",
"matplotlib>=3.10.8",
@ -14,6 +15,7 @@ dependencies = [
"python-dotenv>=1.2.1",
"regex>=2025.9.1",
"rustbpe>=0.1.0",
"safetensors>=0.7.0",
"scipy>=1.15.3",
"setuptools>=80.9.0",
"tabulate>=0.9.0",

View File

@ -29,7 +29,7 @@ from nanochat.gpt import GPT, GPTConfig, Linear
from nanochat.dataloader import tokenizing_distributed_data_loader_bos_bestfit, tokenizing_distributed_data_loader_with_state_bos_bestfit
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir, autodetect_device_type, get_peak_flops, COMPUTE_DTYPE, COMPUTE_DTYPE_REASON, is_ddp_initialized
from nanochat.tokenizer import get_tokenizer, get_token_bytes
from nanochat.checkpoint_manager import save_checkpoint, load_checkpoint
from nanochat.checkpoint_manager import save_checkpoint, load_checkpoint, upload_to_hf
from nanochat.loss_eval import evaluate_bpb
from nanochat.engine import Engine
from nanochat.flash_attention import HAS_FA3
@ -77,6 +77,8 @@ parser.add_argument("--sample-every", type=int, default=2000, help="sample from
parser.add_argument("--save-every", type=int, default=-1, help="save checkpoints every N steps (-1 = only at end)")
# Output
parser.add_argument("--model-tag", type=str, default=None, help="override model tag for checkpoint directory name")
parser.add_argument("--hf-repo", type=str, default=None, help="Hugging Face repo ID to upload final model")
parser.add_argument("--hf-upload-optim", action="store_true", help="upload optimizer state to Hugging Face")
args = parser.parse_args()
user_config = vars(args).copy() # for logging
# -----------------------------------------------------------------------------
@ -624,6 +626,10 @@ get_report().log(section="Base model training", data=[
}
])
# upload to huggingface
if args.hf_repo and master_process:
upload_to_hf(checkpoint_dir, step, args.hf_repo, with_optimizer=args.hf_upload_optim)
# cleanup
wandb_run.finish() # wandb run finish
compute_cleanup()

View File

@ -18,7 +18,7 @@ import wandb
import torch
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, get_base_dir, autodetect_device_type, get_peak_flops, COMPUTE_DTYPE, COMPUTE_DTYPE_REASON, is_ddp_initialized
from nanochat.tokenizer import get_token_bytes
from nanochat.checkpoint_manager import save_checkpoint, load_model, load_optimizer_state
from nanochat.checkpoint_manager import save_checkpoint, load_model, load_optimizer_state, upload_to_hf
from nanochat.loss_eval import evaluate_bpb
import torch.distributed as dist
from nanochat.flash_attention import HAS_FA3
@ -66,6 +66,8 @@ parser.add_argument("--chatcore-max-sample", type=int, default=24, help="max pro
# Data mixture
parser.add_argument("--mmlu-epochs", type=int, default=3, help="number of epochs of MMLU in training mixture (teaches Multiple Choice)")
parser.add_argument("--gsm8k-epochs", type=int, default=4, help="number of epochs of GSM8K in training mixture (teaches Math and Tool Use)")
parser.add_argument("--hf-repo", type=str, default=None, help="Hugging Face repo ID to upload final model")
parser.add_argument("--hf-upload-optim", action="store_true", help="upload optimizer state to Hugging Face")
args = parser.parse_args()
user_config = vars(args).copy()
# -----------------------------------------------------------------------------
@ -514,6 +516,10 @@ get_report().log(section="SFT", data=[
}
])
# upload to huggingface
if args.hf_repo and master_process:
upload_to_hf(checkpoint_dir, step, args.hf_repo, with_optimizer=args.hf_upload_optim)
# cleanup
wandb_run.finish() # wandb run finish
compute_cleanup()