mirror of
https://github.com/karpathy/nanochat.git
synced 2026-04-02 05:35:19 +00:00
Merge b0778933ee into 5019accc5b
This commit is contained in:
commit
7bc3284b46
|
|
@ -7,6 +7,7 @@ import glob
|
|||
import json
|
||||
import logging
|
||||
import torch
|
||||
import safetensors.torch
|
||||
|
||||
from nanochat.common import get_base_dir
|
||||
from nanochat.gpt import GPT, GPTConfig
|
||||
|
|
@ -42,9 +43,9 @@ def _patch_missing_keys(model_data, model_config):
|
|||
def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data, rank=0):
|
||||
if rank == 0:
|
||||
os.makedirs(checkpoint_dir, exist_ok=True)
|
||||
# Save the model state parameters
|
||||
model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.pt")
|
||||
torch.save(model_data, model_path)
|
||||
# Save the model state parameters in safetensors format
|
||||
model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.safetensors")
|
||||
safetensors.torch.save_file(model_data, model_path)
|
||||
logger.info(f"Saved model parameters to: {model_path}")
|
||||
# Save the metadata dict as json
|
||||
meta_path = os.path.join(checkpoint_dir, f"meta_{step:06d}.json")
|
||||
|
|
@ -59,9 +60,9 @@ def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data,
|
|||
logger.info(f"Saved optimizer state to: {optimizer_path}")
|
||||
|
||||
def load_checkpoint(checkpoint_dir, step, device, load_optimizer=False, rank=0):
|
||||
# Load the model state
|
||||
model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.pt")
|
||||
model_data = torch.load(model_path, map_location=device)
|
||||
# Load the model state in safetensors format
|
||||
model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.safetensors")
|
||||
model_data = safetensors.torch.load_file(model_path, device=str(device))
|
||||
# Load the optimizer state if requested
|
||||
optimizer_data = None
|
||||
if load_optimizer:
|
||||
|
|
@ -136,8 +137,8 @@ def find_largest_model(checkpoints_dir):
|
|||
|
||||
|
||||
def find_last_step(checkpoint_dir):
|
||||
# Look into checkpoint_dir and find model_<step>.pt with the highest step
|
||||
checkpoint_files = glob.glob(os.path.join(checkpoint_dir, "model_*.pt"))
|
||||
# Look into checkpoint_dir and find model_<step>.safetensors with the highest step (legacy model_<step>.pt files are not matched by the glob below)
|
||||
checkpoint_files = glob.glob(os.path.join(checkpoint_dir, "model_*.safetensors"))
|
||||
if not checkpoint_files:
|
||||
raise FileNotFoundError(f"No checkpoints found in {checkpoint_dir}")
|
||||
last_step = int(max(os.path.basename(f).split("_")[-1].split(".")[0] for f in checkpoint_files))
|
||||
|
|
@ -192,3 +193,50 @@ def load_optimizer_state(source, device, rank, model_tag=None, step=None):
|
|||
log0(f"Loading optimizer state from {optimizer_path}")
|
||||
optimizer_data = torch.load(optimizer_path, map_location=device)
|
||||
return optimizer_data
|
||||
|
||||
def upload_to_hf(checkpoint_dir, step, repo_id, token=None, with_optimizer=False):
    """Upload the model, metadata, and tokenizer to Hugging Face Hub.

    Files are looked up by step inside ``checkpoint_dir`` and uploaded one at
    a time. Missing files are reported through ``log0`` as warnings rather
    than raised, so a partial upload does not abort the caller.

    Args:
        checkpoint_dir: Directory containing ``model_<step>.safetensors``,
            ``meta_<step>.json`` and (optionally) ``optim_<step>_rank*.pt``.
        step: Integer checkpoint step; formatted as a zero-padded 6-digit
            number to build the file names.
        repo_id: Target Hub repository; created if it does not exist yet.
        token: Forwarded to ``HfApi(token=...)``.
        with_optimizer: When true, also upload every optimizer shard found
            for this step under ``optimizer/`` in the repo.

    Returns:
        None. If ``huggingface_hub`` cannot be imported, an error is logged
        and the function returns without uploading anything.
    """
    try:
        from huggingface_hub import HfApi
    except ImportError:
        log0("Error: huggingface_hub not installed. Run 'pip install huggingface_hub'")
        return

    log0(f"Uploading model to Hugging Face Hub: {repo_id}")
    api = HfApi(token=token)
    api.create_repo(repo_id=repo_id, exist_ok=True)

    # 1. Upload model and meta using standard names
    model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.safetensors")
    meta_path = os.path.join(checkpoint_dir, f"meta_{step:06d}.json")

    # Remember whether the weights actually went up so the final log line
    # does not claim success for a repo that has no model in it.
    model_uploaded = os.path.exists(model_path)
    if model_uploaded:
        api.upload_file(path_or_fileobj=model_path, path_in_repo="model.safetensors", repo_id=repo_id)
    else:
        log0(f"Warning: model file not found at {model_path}")

    if os.path.exists(meta_path):
        api.upload_file(path_or_fileobj=meta_path, path_in_repo="config.json", repo_id=repo_id)
    else:
        log0(f"Warning: metadata file not found at {meta_path}")

    # 2. Upload optimizer state if requested
    if with_optimizer:
        log0("Uploading optimizer state...")
        # Find all optimizer shards for this step; sorted so the upload order
        # is deterministic (glob returns files in arbitrary order).
        optim_shards = sorted(glob.glob(os.path.join(checkpoint_dir, f"optim_{step:06d}_rank*.pt")))
        if not optim_shards:
            log0(f"Warning: no optimizer shards found for step {step} in {checkpoint_dir}")
        for shard in optim_shards:
            filename = os.path.basename(shard)
            api.upload_file(path_or_fileobj=shard, path_in_repo=f"optimizer/{filename}", repo_id=repo_id)

    # 3. Upload tokenizer files
    base_dir = get_base_dir()
    tokenizer_dir = os.path.join(base_dir, "tokenizer")
    if os.path.exists(tokenizer_dir):
        for f in sorted(os.listdir(tokenizer_dir)):
            full_path = os.path.join(tokenizer_dir, f)
            # Only regular files are uploaded; sub-directories are skipped.
            if os.path.isfile(full_path):
                # Upload tokenizer files to the root for better compatibility
                api.upload_file(path_or_fileobj=full_path, path_in_repo=f, repo_id=repo_id)

    if model_uploaded:
        log0(f"Successfully uploaded model to https://huggingface.co/{repo_id}")
    else:
        log0(f"Warning: upload to https://huggingface.co/{repo_id} finished without a model file")
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ requires-python = ">=3.10"
|
|||
dependencies = [
|
||||
"datasets>=4.0.0",
|
||||
"fastapi>=0.117.1",
|
||||
"huggingface-hub>=1.5.0",
|
||||
"ipykernel>=7.1.0",
|
||||
"kernels>=0.11.7",
|
||||
"matplotlib>=3.10.8",
|
||||
|
|
@ -14,6 +15,7 @@ dependencies = [
|
|||
"python-dotenv>=1.2.1",
|
||||
"regex>=2025.9.1",
|
||||
"rustbpe>=0.1.0",
|
||||
"safetensors>=0.7.0",
|
||||
"scipy>=1.15.3",
|
||||
"setuptools>=80.9.0",
|
||||
"tabulate>=0.9.0",
|
||||
|
|
|
|||
|
|
@ -29,7 +29,7 @@ from nanochat.gpt import GPT, GPTConfig, Linear
|
|||
from nanochat.dataloader import tokenizing_distributed_data_loader_bos_bestfit, tokenizing_distributed_data_loader_with_state_bos_bestfit
|
||||
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir, autodetect_device_type, get_peak_flops, COMPUTE_DTYPE, COMPUTE_DTYPE_REASON, is_ddp_initialized
|
||||
from nanochat.tokenizer import get_tokenizer, get_token_bytes
|
||||
from nanochat.checkpoint_manager import save_checkpoint, load_checkpoint
|
||||
from nanochat.checkpoint_manager import save_checkpoint, load_checkpoint, upload_to_hf
|
||||
from nanochat.loss_eval import evaluate_bpb
|
||||
from nanochat.engine import Engine
|
||||
from nanochat.flash_attention import HAS_FA3
|
||||
|
|
@ -77,6 +77,8 @@ parser.add_argument("--sample-every", type=int, default=2000, help="sample from
|
|||
parser.add_argument("--save-every", type=int, default=-1, help="save checkpoints every N steps (-1 = only at end)")
|
||||
# Output
|
||||
parser.add_argument("--model-tag", type=str, default=None, help="override model tag for checkpoint directory name")
|
||||
parser.add_argument("--hf-repo", type=str, default=None, help="Hugging Face repo ID to upload final model")
|
||||
parser.add_argument("--hf-upload-optim", action="store_true", help="upload optimizer state to Hugging Face")
|
||||
args = parser.parse_args()
|
||||
user_config = vars(args).copy() # for logging
|
||||
# -----------------------------------------------------------------------------
|
||||
|
|
@ -624,6 +626,10 @@ get_report().log(section="Base model training", data=[
|
|||
}
|
||||
])
|
||||
|
||||
# upload to huggingface
|
||||
if args.hf_repo and master_process:
|
||||
upload_to_hf(checkpoint_dir, step, args.hf_repo, with_optimizer=args.hf_upload_optim)
|
||||
|
||||
# cleanup
|
||||
wandb_run.finish() # wandb run finish
|
||||
compute_cleanup()
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ import wandb
|
|||
import torch
|
||||
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, get_base_dir, autodetect_device_type, get_peak_flops, COMPUTE_DTYPE, COMPUTE_DTYPE_REASON, is_ddp_initialized
|
||||
from nanochat.tokenizer import get_token_bytes
|
||||
from nanochat.checkpoint_manager import save_checkpoint, load_model, load_optimizer_state
|
||||
from nanochat.checkpoint_manager import save_checkpoint, load_model, load_optimizer_state, upload_to_hf
|
||||
from nanochat.loss_eval import evaluate_bpb
|
||||
import torch.distributed as dist
|
||||
from nanochat.flash_attention import HAS_FA3
|
||||
|
|
@ -66,6 +66,8 @@ parser.add_argument("--chatcore-max-sample", type=int, default=24, help="max pro
|
|||
# Data mixture
|
||||
parser.add_argument("--mmlu-epochs", type=int, default=3, help="number of epochs of MMLU in training mixture (teaches Multiple Choice)")
|
||||
parser.add_argument("--gsm8k-epochs", type=int, default=4, help="number of epochs of GSM8K in training mixture (teaches Math and Tool Use)")
|
||||
parser.add_argument("--hf-repo", type=str, default=None, help="Hugging Face repo ID to upload final model")
|
||||
parser.add_argument("--hf-upload-optim", action="store_true", help="upload optimizer state to Hugging Face")
|
||||
args = parser.parse_args()
|
||||
user_config = vars(args).copy()
|
||||
# -----------------------------------------------------------------------------
|
||||
|
|
@ -514,6 +516,10 @@ get_report().log(section="SFT", data=[
|
|||
}
|
||||
])
|
||||
|
||||
# upload to huggingface
|
||||
if args.hf_repo and master_process:
|
||||
upload_to_hf(checkpoint_dir, step, args.hf_repo, with_optimizer=args.hf_upload_optim)
|
||||
|
||||
# cleanup
|
||||
wandb_run.finish() # wandb run finish
|
||||
compute_cleanup()
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user