nanochat/scripts_moe/quick_infer.py

"""
Quick local inference for nanochat-MoE checkpoints (pretrain-style, plain text).
Example:
uv run python scripts_moe/quick_infer.py --model-tag d20 --prompt "what's 1+1 equal to?"
uv run python scripts_moe/quick_infer.py --hf-path hf-export/moe_gpt2 --prompt "what's 1+1 equal to?"
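
    Sampling instead of greedy decoding (temperature/top-k flags, illustrative values):
    uv run python scripts_moe/quick_infer.py --model-tag d20 --temperature 0.8 --top-k 50 --prompt "tell me a story"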
"""
import argparse
import json
import os
import sys

import torch

# Make the repo root importable when this script is run directly.
REPO_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if REPO_ROOT not in sys.path:
    sys.path.insert(0, REPO_ROOT)

# Expose nanochat_moe.manager under the bare name "manager" so references to
# the unqualified module (e.g. when unpickling older checkpoints) still resolve.
import nanochat_moe.manager as _moe_manager
sys.modules.setdefault("manager", _moe_manager)

from nanochat_moe.checkpoint_manager import find_last_step, find_largest_model
from nanochat_moe.common import get_base_dir
from nanochat_moe.standard import GPT, GPTConfig
from nanochat_moe.tokenizer import RustBPETokenizer, get_tokenizer


def pad_vocab_weights(state_dict, target_vocab):
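    """Zero-pad wte/lm_head weight rows in-place so they reach target_vocab."""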
    for key in ("transformer.wte.weight", "lm_head.weight"):
        if key not in state_dict:
            continue
        tensor = state_dict[key]
        if tensor.size(0) >= target_vocab:
            continue
        pad_rows = target_vocab - tensor.size(0)
        pad = tensor.new_zeros((pad_rows, tensor.size(1)))
        state_dict[key] = torch.cat([tensor, pad], dim=0)


def load_tokenizer(mode):
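    """Return the GPT-2 BPE tokenizer or the locally cached nanochat tokenizer."""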
if mode == "gpt2":
return RustBPETokenizer.from_pretrained("gpt2")
if mode == "cache":
return get_tokenizer()
raise ValueError(f"Unknown tokenizer mode: {mode}")


def load_hf_model(hf_path, device):
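    """Load a HF-exported model; returns (model, tokenizer, vocab_limit or None)."""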
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model = AutoModelForCausalLM.from_pretrained(hf_path, trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained(hf_path, trust_remote_code=True)
    model.to(device)
    model.eval()
    vocab_limit = None
    if tokenizer.vocab_size < model.config.vocab_size:
        vocab_limit = tokenizer.vocab_size
        print(
            f"Warning: tokenizer vocab {tokenizer.vocab_size} < model vocab {model.config.vocab_size}; "
            "clamping logits to tokenizer vocab."
        )
    return model, tokenizer, vocab_limit


def load_standard_model(source, device, model_tag, step, tokenizer):
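    """Locate and load a checkpoint under <base_dir>/{source}_checkpoints,
    reconciling tokenizer and model vocab sizes."""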
    base_dir = get_base_dir()
    checkpoint_root = os.path.join(base_dir, f"{source}_checkpoints")
    if model_tag is None:
        model_tag = find_largest_model(checkpoint_root)
    checkpoint_dir = os.path.join(checkpoint_root, model_tag)
    if step is None:
        step = find_last_step(checkpoint_dir)
    model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.pt")
    meta_path = os.path.join(checkpoint_dir, f"meta_{step:06d}.json")
    model_data = torch.load(model_path, map_location=device)
    # bfloat16 tensors are slow (or unsupported) on cpu/mps; upcast to float32.
    if device.type in {"cpu", "mps"}:
        model_data = {
            k: (v.float() if isinstance(v, torch.Tensor) and v.dtype == torch.bfloat16 else v)
            for k, v in model_data.items()
        }
    # Strip the "_orig_mod." prefix that torch.compile adds to parameter names.
    model_data = {k.removeprefix("_orig_mod."): v for k, v in model_data.items()}
    with open(meta_path, "r", encoding="utf-8") as f:
        meta = json.load(f)
    cfg_kwargs = meta["model_config"]
    tok_vocab = tokenizer.get_vocab_size()
    cfg_vocab = cfg_kwargs.get("vocab_size")
    vocab_limit = None
    if tok_vocab > cfg_vocab:
        print(
            f"Warning: tokenizer vocab {tok_vocab} > model vocab {cfg_vocab}; "
            "padding embeddings to match."
        )
        cfg_kwargs = dict(cfg_kwargs)
        cfg_kwargs["vocab_size"] = tok_vocab
        pad_vocab_weights(model_data, tok_vocab)
    elif tok_vocab < cfg_vocab:
        vocab_limit = tok_vocab
        print(
            f"Warning: tokenizer vocab {tok_vocab} < model vocab {cfg_vocab}; "
            "clamping logits to tokenizer vocab."
        )
    model = GPT(GPTConfig(**cfg_kwargs))
    model.load_state_dict(model_data, strict=True)
    model.to(device)
    model.eval()
    return model, tokenizer, meta, model_tag, step, vocab_limit


def build_prompt_tokens(tokenizer, question: str):
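    """Encode the Question/Answer prompt with the nanochat tokenizer, prepending BOS."""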
prompt = f"Question: {question}\nAnswer:"
bos = tokenizer.get_bos_token_id()
return tokenizer.encode(prompt, prepend=bos), prompt


def build_prompt_tokens_hf(tokenizer, question: str):
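    """Encode the Question/Answer prompt with a HF tokenizer, prepending BOS if set."""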
prompt = f"Question: {question}\nAnswer:"
ids = tokenizer.encode(prompt, add_special_tokens=False)
bos_id = getattr(tokenizer, "bos_token_id", None)
if bos_id is not None:
ids = [bos_id] + ids
return ids, prompt


def sample_next_token(logits, temperature, top_k):
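    """Pick one next-token id per row from (B, vocab) logits; greedy when temperature <= 0."""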
    if top_k is not None:
        # Mask everything below the k-th largest logit in each row.
        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
        logits = torch.where(
            logits < v[:, [-1]],
            torch.full_like(logits, -float("inf")),
            logits,
        )
    if temperature <= 0:
        return torch.argmax(logits, dim=-1, keepdim=True)
    logits = logits / temperature
    probs = torch.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1)


@torch.inference_mode()
def generate_tokens(model, input_ids, max_new_tokens, temperature, top_k, vocab_limit):
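    """Decode autoregressively, recomputing the full context each step (no KV cache)."""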
    for _ in range(max_new_tokens):
        # Crop the context to the last block_size tokens if it has grown too long.
        idx_cond = (
            input_ids
            if input_ids.size(1) <= model.config.block_size
            else input_ids[:, -model.config.block_size :]
        )
        logits, _ = model(idx_cond)
        logits = logits[:, -1, :]
        if vocab_limit is not None and vocab_limit < logits.size(-1):
            logits = logits[:, :vocab_limit]
        next_token = sample_next_token(logits, temperature, top_k)
        input_ids = torch.cat((input_ids, next_token), dim=1)
    return input_ids


@torch.inference_mode()
def generate_tokens_hf(model, input_ids, max_new_tokens, temperature, top_k, vocab_limit):
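    """Same decoding loop for a HF causal LM; note the context is never cropped."""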
    for _ in range(max_new_tokens):
        logits = model(input_ids).logits
        logits = logits[:, -1, :]
        if vocab_limit is not None and vocab_limit < logits.size(-1):
            logits = logits[:, :vocab_limit]
        next_token = sample_next_token(logits, temperature, top_k)
        input_ids = torch.cat((input_ids, next_token), dim=1)
    return input_ids


@torch.inference_mode()
def main():
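    """Parse CLI args, load either a HF export or a nanochat checkpoint, and generate."""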
    parser = argparse.ArgumentParser(description="Run a quick nanochat-MoE inference")
    parser.add_argument("--source", type=str, default="base", choices=["base", "mid", "sft", "rl"])
    parser.add_argument("--model-tag", type=str, default="d20")
    parser.add_argument("--step", type=int, default=None)
    parser.add_argument("--hf-path", type=str, default=None)
    parser.add_argument("--prompt", type=str, default="what's 1+1 equal to?")
    parser.add_argument("--max-tokens", type=int, default=64)
    parser.add_argument("--temperature", type=float, default=0.0)
    parser.add_argument("--top-k", type=int, default=None)
    parser.add_argument("--tokenizer", type=str, default="gpt2", choices=["gpt2", "cache"])
    args = parser.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    if args.hf_path:
        model, tokenizer, vocab_limit = load_hf_model(args.hf_path, device)
        prompt_tokens, prompt_text = build_prompt_tokens_hf(tokenizer, args.prompt)
        input_ids = torch.tensor([prompt_tokens], dtype=torch.long, device=device)
        output_ids = generate_tokens_hf(
            model,
            input_ids,
            max_new_tokens=args.max_tokens,
            temperature=args.temperature,
            top_k=args.top_k,
            vocab_limit=vocab_limit,
        )
        generated_tokens = output_ids[0].tolist()[len(prompt_tokens) :]
        output_text = tokenizer.decode(generated_tokens)
        print(f"Loaded HF model: {args.hf_path}")
    else:
        tokenizer = load_tokenizer(args.tokenizer)
        model, tokenizer, _, resolved_tag, resolved_step, vocab_limit = load_standard_model(
            args.source, device, args.model_tag, args.step, tokenizer
        )
        prompt_tokens, prompt_text = build_prompt_tokens(tokenizer, args.prompt)
        if len(prompt_tokens) > model.config.block_size:
            print(
                f"Warning: prompt length {len(prompt_tokens)} exceeds block_size "
                f"{model.config.block_size}; context will be truncated."
            )
        input_ids = torch.tensor([prompt_tokens], dtype=torch.long, device=device)
        output_ids = generate_tokens(
            model,
            input_ids,
            max_new_tokens=args.max_tokens,
            temperature=args.temperature,
            top_k=args.top_k,
            vocab_limit=vocab_limit,
        )
        generated_tokens = output_ids[0].tolist()[len(prompt_tokens) :]
        output_text = tokenizer.decode(generated_tokens)
        print(f"Loaded: {args.source}/{resolved_tag} step {resolved_step}")

    print("===============Prompt===============")
    print(prompt_text)
    print("===============Output===============")
    print(output_text)


if __name__ == "__main__":
    main()