allow base_loss to report the loss of any arbitrary huggingface model similar to base_eval. had to change dataloader to be a lot better and just take tokenizer, not load the nanochat one. much better this way anyway

This commit is contained in:
Andrej Karpathy 2026-01-12 03:10:13 +00:00
parent aa95fb2e03
commit 21608ec51e
4 changed files with 73 additions and 15 deletions

View File

@ -5,9 +5,8 @@ import pyarrow.parquet as pq
from nanochat.common import get_dist_info
from nanochat.dataset import list_parquet_files
from nanochat.tokenizer import get_tokenizer
def tokenizing_distributed_data_loader_with_state(B, T, split, tokenizer_threads=4, tokenizer_batch_size=128, device="cuda", resume_state_dict=None):
def tokenizing_distributed_data_loader_with_state(tokenizer, B, T, split, tokenizer_threads=4, tokenizer_batch_size=128, device="cuda", resume_state_dict=None):
"""
Stream pretraining text from parquet files, tokenize, yield training batches.
@ -62,8 +61,6 @@ def tokenizing_distributed_data_loader_with_state(B, T, split, tokenizer_threads
# Now emit batches of tokens.
needed_tokens = B * T + 1 # +1 is because we also need the target at the last token
# get the tokenizer and the bos token
tokenizer = get_tokenizer()
bos_token = tokenizer.get_bos_token_id()
# scratch buffer holds the tokens for one iteration
token_buffer = deque() # we stream tokens on the right and pop from the left

View File

@ -103,9 +103,10 @@ class HuggingFaceTokenizer:
def id_to_token(self, id):
return self.tokenizer.id_to_token(id)
def _encode_one(self, text, prepend=None, append=None):
def _encode_one(self, text, prepend=None, append=None, num_threads=None):
# encode a single string
# prepend/append can be either a string of a special token or a token id directly.
# num_threads is ignored (only used by the nanochat Tokenizer for parallel encoding)
assert isinstance(text, str)
ids = []
if prepend is not None:

View File

@ -5,6 +5,9 @@ Loads a checkpoint, and:
Example run as:
torchrun --standalone --nproc_per_node=8 -m scripts.base_loss
To evaluate a HuggingFace model:
python -m scripts.base_loss --hf_path openai-community/gpt2
"""
import argparse
from contextlib import nullcontext
@ -12,42 +15,98 @@ import torch
from nanochat.checkpoint_manager import load_model
from nanochat.common import compute_init, print0, compute_cleanup, autodetect_device_type
from nanochat.dataloader import tokenizing_distributed_data_loader
from nanochat.tokenizer import get_token_bytes
from nanochat.tokenizer import get_token_bytes, HuggingFaceTokenizer
from nanochat.loss_eval import evaluate_bpb
from nanochat.engine import Engine
# -----------------------------------------------------------------------------
# HuggingFace loading utilities, making the APIs match up to those of nanochat
class ModelWrapper:
"""Lightweight wrapper for a HuggingFace model"""
def __init__(self, model, max_seq_len=None):
self.model = model
self.max_seq_len = max_seq_len
def __call__(self, input_ids, targets=None, loss_reduction='mean'):
logits = self.model(input_ids).logits
if targets is None:
return logits
else:
loss = torch.nn.functional.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1, reduction=loss_reduction)
return loss
def get_device(self):
return next(self.model.parameters()).device
def load_hf_model(hf_path: str, device):
print0(f"Loading model from: {hf_path}")
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained(hf_path)
model.to(device)
model.eval()
max_seq_len = 1024 if "openai-community/gpt2" in hf_path else None
model = ModelWrapper(model, max_seq_len=max_seq_len)
tokenizer = HuggingFaceTokenizer.from_pretrained(hf_path)
return model, tokenizer
def get_hf_token_bytes(tokenizer, device="cpu"):
"""Compute token_bytes tensor for a HuggingFace tokenizer."""
vocab_size = tokenizer.tokenizer.get_vocab_size()
token_bytes = torch.zeros(vocab_size, dtype=torch.int64, device=device)
for token_id in range(vocab_size):
token_str = tokenizer.tokenizer.decode([token_id])
token_bytes[token_id] = len(token_str.encode('utf-8')) # Count UTF-8 bytes
return token_bytes
# CLI arguments
parser = argparse.ArgumentParser(description="Evaluate loss on train/val splits and sample from model")
parser.add_argument("--device_batch_size", type=int, default=32, help="per-device batch size")
parser.add_argument("--split_tokens", type=int, default=20*524288, help="number of tokens to evaluate per split")
parser.add_argument("--split_tokens", type=int, default=40*524288, help="number of tokens to evaluate per split")
parser.add_argument("--model_tag", type=str, default=None, help="model tag for checkpoint directory")
parser.add_argument("--model_step", type=int, default=None, help="model step to load")
parser.add_argument("--device_type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
parser.add_argument("--hf_path", type=str, default=None, help="HuggingFace model path (e.g. openai-community/gpt2)")
args = parser.parse_args()
# Load the base model and the tokenizer
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
model, tokenizer, meta = load_model("base", device, phase="eval", model_tag=args.model_tag, step=args.model_step)
sequence_len = meta["model_config"]["sequence_len"] # could be arbitrary really
print0(f"Device: {device} | DDP rank: {ddp_rank} | DDP local rank: {ddp_local_rank} | DDP world size: {ddp_world_size}")
if args.hf_path is not None:
# Load HuggingFace model
model, tokenizer = load_hf_model(args.hf_path, device)
sequence_len = model.max_seq_len if model.max_seq_len else 1024
token_bytes = get_hf_token_bytes(tokenizer, device=device)
model_name = args.hf_path
else:
# Load local nanochat model
model, tokenizer, meta = load_model("base", device, phase="eval", model_tag=args.model_tag, step=args.model_step)
sequence_len = meta["model_config"]["sequence_len"]
token_bytes = get_token_bytes(device=device)
model_name = f"base_model (step {meta['step']})"
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
print0(f"Evaluating model: {model_name}")
# Evaluate the loss on each split
tokens_per_step = args.device_batch_size * sequence_len * ddp_world_size
assert args.split_tokens % tokens_per_step == 0, "split_tokens must be divisible by tokens_per_step"
steps = args.split_tokens // tokens_per_step
token_bytes = get_token_bytes(device=device)
bpb_results = {}
for split_name in ["train", "val"]:
loader = tokenizing_distributed_data_loader(args.device_batch_size, sequence_len, split_name, device=device)
loader = tokenizing_distributed_data_loader(tokenizer, args.device_batch_size, sequence_len, split_name, device=device)
with autocast_ctx:
bpb = evaluate_bpb(model, loader, steps, token_bytes)
print0(f"{split_name} bpb: {bpb:.4f}")
bpb_results[split_name] = bpb
print0(f"Model: {model_name}, {split_name} bpb: {bpb:.6f}")
# Master process also samples from the model
# Master process also samples from the model (only for nanochat models)
samples = []
if ddp_rank == 0:
if ddp_rank == 0 and args.hf_path is None:
prompts = [
"The capital of France is",
"The chemical symbol of gold is",
@ -70,6 +129,7 @@ if ddp_rank == 0:
from nanochat.report import get_report
get_report().log(section="Base model loss", data=[
{
"model": model_name,
"train bpb": bpb_results["train"],
"val bpb": bpb_results["val"],
},

View File

@ -210,8 +210,8 @@ if resuming:
# Initialize the DataLoaders for train/val
tokens_dir = os.path.join(base_dir, "tokenized_data")
dataloader_resume_state_dict = None if not resuming else meta_data["dataloader_state_dict"]
train_loader = tokenizing_distributed_data_loader_with_state(args.device_batch_size, args.max_seq_len, split="train", device=device, resume_state_dict=dataloader_resume_state_dict)
build_val_loader = lambda: tokenizing_distributed_data_loader(args.device_batch_size, args.max_seq_len, split="val", device=device)
train_loader = tokenizing_distributed_data_loader_with_state(tokenizer, args.device_batch_size, args.max_seq_len, split="train", device=device, resume_state_dict=dataloader_resume_state_dict)
build_val_loader = lambda: tokenizing_distributed_data_loader(tokenizer, args.device_batch_size, args.max_seq_len, split="val", device=device)
x, y, dataloader_state_dict = next(train_loader) # kick off load of the very first batch of data
# -----------------------------------------------------------------------------