# nanochat/scripts/import_hf_checkpoint.py
"""
Import a Hugging Face model repo into nanochat's native checkpoint format.
This is intended for base-model continuation before multi-stage nanochat runs.
Examples:
python -m scripts.import_hf_checkpoint --repo-id ManmohanSharma/nanochat-d24 --model-tag d24_hf_import
python -m scripts.import_hf_checkpoint --repo-id ManmohanSharma/nanochat-d24 --local-dir /path/to/snapshot
"""

import argparse
import os
from dataclasses import asdict

import torch
from huggingface_hub import snapshot_download
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

from nanochat.checkpoint_manager import save_checkpoint
from nanochat.common import get_base_dir
from nanochat.gpt import GPT, GPTConfig
from nanochat.tokenizer import get_tokenizer
from nanochat.tools import DEFAULT_TOOL_SCHEMA


def normalize_hf_state_dict_keys(state_dict):
    """Strip wrapper prefixes (torch.compile, DDP, HF) so keys match nanochat's GPT."""
    normalized = {}
    prefixes = ("_orig_mod.", "module.", "model.")
    for key, value in state_dict.items():
        normalized_key = key
        for prefix in prefixes:
            if normalized_key.startswith(prefix):
                normalized_key = normalized_key[len(prefix):]
        normalized[normalized_key] = value
    return normalized


def infer_gpt_config(hf_config):
    """Map HF config attribute names onto nanochat's GPTConfig fields."""
    kwargs = {
        # Prefer nanochat's own names, then fall back to common HF equivalents.
        "sequence_len": getattr(
            hf_config,
            "sequence_len",
            getattr(hf_config, "max_position_embeddings", getattr(hf_config, "n_positions", 2048)),
        ),
        "vocab_size": getattr(hf_config, "vocab_size", None),
        "n_layer": getattr(
            hf_config,
            "n_layer",
            getattr(hf_config, "num_hidden_layers", getattr(hf_config, "num_layers", None)),
        ),
        "n_head": getattr(
            hf_config,
            "n_head",
            getattr(hf_config, "num_attention_heads", None),
        ),
        "n_kv_head": getattr(
            hf_config,
            "n_kv_head",
            getattr(hf_config, "num_key_value_heads", getattr(hf_config, "num_attention_heads", None)),
        ),
        "n_embd": getattr(
            hf_config,
            "n_embd",
            getattr(hf_config, "hidden_size", getattr(hf_config, "d_model", None)),
        ),
        "window_pattern": getattr(hf_config, "window_pattern", "L"),
    }
    missing = [key for key, value in kwargs.items() if value is None]
    if missing:
        raise ValueError(f"Could not infer nanochat GPTConfig fields from HF config: {missing}")
    return GPTConfig(**kwargs)
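
# For instance, a Llama-style config (hidden_size=2048, num_hidden_layers=24,
# num_attention_heads=16) maps to GPTConfig(n_embd=2048, n_layer=24, n_head=16, ...),
# while GPT-2-style configs already carry the n_* names and pass through unchanged.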


def verify_tokenizer_compatibility(hf_tokenizer, nanochat_tokenizer):
    """Refuse the import if the HF and local nanochat tokenizers disagree on vocab size."""
    hf_vocab = hf_tokenizer.vocab_size
    local_vocab = nanochat_tokenizer.get_vocab_size()
    if hf_vocab != local_vocab:
        raise ValueError(
            f"Tokenizer vocab mismatch: HF repo has vocab_size={hf_vocab}, "
            f"local nanochat tokenizer has vocab_size={local_vocab}"
        )


def load_hf_snapshot(repo_id, revision, token, local_dir):
    """Return a local snapshot path, downloading from the Hub unless one was supplied."""
    if local_dir is not None:
        return local_dir
    return snapshot_download(
        repo_id=repo_id,
        revision=revision,
        token=token,
        repo_type="model",
    )
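
# Note: snapshot_download returns the local cache directory holding config.json,
# tokenizer files, and weight shards; --local-dir simply bypasses the download.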


def main():
    parser = argparse.ArgumentParser(description="Import HF repo into native nanochat checkpoints")
    parser.add_argument("--repo-id", required=True, help="HF model repo id, e.g. ManmohanSharma/nanochat-d24")
    parser.add_argument("--revision", default=None, help="Optional HF revision")
    parser.add_argument("--local-dir", default=None, help="Use an already-downloaded HF snapshot instead of downloading")
    parser.add_argument("--token-env", default="HF_TOKEN", help="Environment variable containing the HF token")
    parser.add_argument("--model-tag", default=None, help="Destination model tag. Defaults to repo name slug")
    parser.add_argument("--step", type=int, default=0, help="Checkpoint step number to write")
    parser.add_argument("--source", choices=["base", "sft", "rl"], default="base", help="Destination checkpoint phase")
    parser.add_argument("--trust-remote-code", type=int, default=1, help="Pass trust_remote_code to Transformers loaders")
    args = parser.parse_args()

    # Resolve the HF snapshot and validate config/tokenizer before touching weights.
    token = os.environ.get(args.token_env)
    snapshot_path = load_hf_snapshot(args.repo_id, args.revision, token, args.local_dir)
    trust_remote_code = bool(args.trust_remote_code)
    hf_config = AutoConfig.from_pretrained(snapshot_path, token=token, trust_remote_code=trust_remote_code)
    hf_tokenizer = AutoTokenizer.from_pretrained(snapshot_path, token=token, trust_remote_code=trust_remote_code)
    nanochat_tokenizer = get_tokenizer()
    verify_tokenizer_compatibility(hf_tokenizer, nanochat_tokenizer)
    local_config = infer_gpt_config(hf_config)

    # Load the HF weights and strip wrapper prefixes from the state dict keys.
    with torch.no_grad():
        model = AutoModelForCausalLM.from_pretrained(
            snapshot_path,
            token=token,
            trust_remote_code=trust_remote_code,
        )
    hf_state_dict = normalize_hf_state_dict_keys(model.state_dict())

    # Instantiate the nanochat model on the meta device (shapes only, no memory)
    # purely to enumerate the parameter names we expect.
    with torch.device("meta"):
        local_model = GPT(local_config)
    expected_keys = set(local_model.state_dict().keys())
    provided_keys = set(hf_state_dict.keys())
    missing = sorted(expected_keys - provided_keys)
    extra = sorted(provided_keys - expected_keys)
    if missing or extra:
        message = [
            "HF checkpoint keys do not match native nanochat keys after normalization.",
            f"Missing keys: {missing[:12]}",
            f"Extra keys: {extra[:12]}",
        ]
        raise ValueError("\n".join(message))

    # Materialize tensors on CPU and write the native nanochat checkpoint.
    model_data = {key: value.detach().cpu() for key, value in hf_state_dict.items()}
    meta_data = {
        "model_config": asdict(local_config),
        "imported_from_hf": True,
        "source_hf_repo": args.repo_id,
        "source_hf_revision": args.revision,
        "tool_schema": DEFAULT_TOOL_SCHEMA,
        "tokenizer_vocab_size": nanochat_tokenizer.get_vocab_size(),
    }
    model_tag = args.model_tag or args.repo_id.split("/")[-1].replace("-", "_")
    base_dir = get_base_dir()
    phase_dir = {
        "base": "base_checkpoints",
        "sft": "chatsft_checkpoints",
        "rl": "chatrl_checkpoints",
    }[args.source]
    checkpoint_dir = os.path.join(base_dir, phase_dir, model_tag)
    save_checkpoint(checkpoint_dir, args.step, model_data, optimizer_data=None, meta_data=meta_data, rank=0)
    print(f"Imported {args.repo_id} into {checkpoint_dir} at step {args.step}")


if __name__ == "__main__":
    main()
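
# Typical invocation (mirrors the docstring examples; set HF_TOKEN for private repos):
#   HF_TOKEN=... python -m scripts.import_hf_checkpoint \
#       --repo-id ManmohanSharma/nanochat-d24 --model-tag d24_hf_import --source base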