nanochat/scripts/base_eval.py
Unsal Gokdag 4f79e750e7 CORE eval: batched forwarding by default, per-example mode for verification
Switch cached eval path to batched=True (forwards full collated batches)
      for ~5-7x speedup over sequential per-example evaluation. Add per-example
      forwarding mode (batched=False) that trims collation padding to recover
      exact per-example tensor shapes, guaranteeing identical results to the
      old sequential path. The bench script uses batched=True for speed sweeps and
      per-example mode to verify correctness against the old sequential path.
2026-02-13 08:42:45 +00:00

478 lines
21 KiB
Python

"""
Unified evaluation script for base models.
Supports three evaluation modes (comma-separated):
--eval core : CORE metric (accuracy on ICL tasks)
--eval bpb : Bits per byte on train/val splits
--eval sample : Generate samples from the model
Default is all three: --eval core,bpb,sample
Examples:
# Evaluate a HuggingFace model (e.g. GPT-2 124M) using 8 GPUs
torchrun --nproc_per_node=8 -m scripts.base_eval --hf-path openai-community/gpt2
# Evaluate a nanochat model (e.g. d24) using 8 GPUs
torchrun --nproc_per_node=8 -m scripts.base_eval --model-tag d24 --device-batch-size=16
# Quick/approximate evaluation using a single GPU
python -m scripts.base_eval --model-tag d24 --device-batch-size=16 --max-per-task=100 --split-tokens=524288
"""
import os
import csv
import time
import json
import yaml
import shutil
import random
import hashlib
import zipfile
import tempfile
import argparse
from contextlib import nullcontext
from tqdm import tqdm
import torch
from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, autodetect_device_type, download_file_with_lock
from nanochat.tokenizer import HuggingFaceTokenizer, get_token_bytes
from nanochat.checkpoint_manager import load_model
from nanochat.core_eval import evaluate_task, prepare_task_data, _forward_all_cached
from nanochat.dataloader import tokenizing_distributed_data_loader_bos_bestfit
from nanochat.loss_eval import evaluate_bpb
from nanochat.engine import Engine
# -----------------------------------------------------------------------------
# HuggingFace loading utilities
class ModelWrapper:
    """Adapt a HuggingFace causal LM to the call signature nanochat's eval code uses.

    Calling the wrapper with only ``input_ids`` returns the model's logits; when
    ``targets`` is also given it returns a cross-entropy loss in which target
    positions equal to -1 are ignored.
    """

    def __init__(self, model, max_seq_len=None):
        # The wrapped HuggingFace model; its forward output must expose `.logits`.
        self.model = model
        # Context limit advertised to callers, or None when not known.
        self.max_seq_len = max_seq_len

    def __call__(self, input_ids, targets=None, loss_reduction='mean'):
        """Run the wrapped model.

        Returns the logits when ``targets`` is None; otherwise the cross-entropy
        over all flattened positions, with ``loss_reduction`` passed through as
        ``cross_entropy``'s ``reduction`` argument.
        """
        output_logits = self.model(input_ids).logits
        if targets is not None:
            vocab = output_logits.size(-1)
            return torch.nn.functional.cross_entropy(
                output_logits.view(-1, vocab),
                targets.view(-1),
                ignore_index=-1,
                reduction=loss_reduction,
            )
        return output_logits

    def get_device(self):
        """Return the device holding the wrapped model's first parameter."""
        first_param = next(iter(self.model.parameters()))
        return first_param.device
def load_hf_model(hf_path: str, device):
    """Load a HuggingFace causal LM and its tokenizer, ready for evaluation.

    The model is moved to ``device``, switched to eval mode and wrapped in
    ``ModelWrapper``. Paths containing "gpt2" get a 1024-token context limit;
    any other path leaves ``max_seq_len`` as None.

    Returns:
        (model, tokenizer): the wrapped model and a HuggingFaceTokenizer.
    """
    print0(f"Loading HuggingFace model from: {hf_path}")
    from transformers import AutoModelForCausalLM
    hf_model = AutoModelForCausalLM.from_pretrained(hf_path)
    hf_model.to(device)
    hf_model.eval()
    # Only GPT-2 style checkpoints get an explicit context limit here.
    context_limit = None
    if "gpt2" in hf_path:
        context_limit = 1024
    wrapped = ModelWrapper(hf_model, max_seq_len=context_limit)
    hf_tokenizer = HuggingFaceTokenizer.from_pretrained(hf_path)
    return wrapped, hf_tokenizer
def get_hf_token_bytes(tokenizer, device="cpu"):
    """Compute the UTF-8 byte length of every token of a HuggingFace tokenizer.

    Each token id is decoded on its own, and the length of the UTF-8 encoding of
    the resulting string is recorded. Tokens that decode to "" count as 0 bytes.

    NOTE(review): a token that is not valid UTF-8 on its own presumably decodes to
    a replacement character, so its count may differ from its raw byte length.

    Args:
        tokenizer: wrapper exposing the underlying tokenizer as ``.tokenizer``
            (with ``get_vocab_size()`` and ``decode(ids)``).
        device: device on which to create the returned tensor.

    Returns:
        int64 tensor of shape (vocab_size,) on ``device``.
    """
    inner = tokenizer.tokenizer
    vocab_size = inner.get_vocab_size()
    # Gather the lengths in Python and build the tensor in one go. Assigning element
    # by element into a device tensor issues one tiny copy per token, which means
    # tens of thousands of host->device transfers on a GPU.
    lengths = [len(inner.decode([token_id]).encode('utf-8')) for token_id in range(vocab_size)]
    return torch.tensor(lengths, dtype=torch.int64, device=device)
# -----------------------------------------------------------------------------
# CORE evaluation
# Download location of eval_bundle.zip (CORE task config and data files). It is
# fetched when <base_dir>/eval_bundle does not exist yet and unpacked by
# place_eval_bundle.
EVAL_BUNDLE_URL = "https://karpathy-public.s3.us-west-2.amazonaws.com/eval_bundle.zip"
def place_eval_bundle(file_path):
    """Extract a downloaded eval_bundle.zip into ``<base_dir>/eval_bundle``.

    Used as the post-processing hook for the bundle download. The archive is
    unpacked into a temporary directory, and its top-level ``eval_bundle`` folder
    is moved into the base directory.
    """
    target_dir = os.path.join(get_base_dir(), "eval_bundle")
    with tempfile.TemporaryDirectory() as scratch:
        with zipfile.ZipFile(file_path, 'r') as archive:
            archive.extractall(scratch)
        shutil.move(os.path.join(scratch, "eval_bundle"), target_dir)
    print0(f"Placed eval_bundle directory at {target_dir}")
# Module-level state shared across evaluate_model() calls within one process.
_eval_data_cache = None  # (task_inputs, random_baselines, w_label, w_shot, w_type); populated on first use
_batch_cache = {}  # {label: collated_batches}: filled on the first run, reused on later runs
_batch_cache_key = None  # (max_per_task, max_seq_len) the batch cache was built for; a change clears _batch_cache
_prev_centered = {}  # {label: centered_result} from the previous call, used for delta display
_prev_core = None  # core_metric from the previous call, used for delta display
def _get_disk_cache_dir(max_per_task):
    """Return the on-disk cache directory for CORE collated batches, or None.

    The directory is ``<base_dir>/core_token_cache/<hash>_n<max_per_task>``. Here
    ``<hash>`` is the first 16 hex characters of the SHA-256 of the first local
    tokenizer file found: ``<base_dir>/tokenizer/tokenizer.pkl``, then
    ``tokenizer.json``. None is returned only when neither file exists.

    NOTE: the hash identifies the *local* tokenizer file, not the tokenizer the
    caller actually evaluates with. The directory is therefore only meaningful
    when those are the same; for example, it is not meaningful for HuggingFace
    models.
    """
    base_dir = get_base_dir()
    candidates = [os.path.join(base_dir, "tokenizer", name) for name in ("tokenizer.pkl", "tokenizer.json")]
    tokenizer_file = next((p for p in candidates if os.path.exists(p)), None)
    if tokenizer_file is None:
        return None
    digest = hashlib.sha256()
    with open(tokenizer_file, "rb") as fh:
        # Stream the file in 8 KiB chunks so large tokenizer files are not read at once.
        while True:
            block = fh.read(8192)
            if not block:
                break
            digest.update(block)
    return os.path.join(base_dir, "core_token_cache", f"{digest.hexdigest()[:16]}_n{max_per_task}")
def _load_core_task_data(max_per_task):
    """Read the CORE task list, per-task data and random baselines from the eval bundle.

    Downloads and unpacks the bundle into <base_dir>/eval_bundle if it is missing.
    Each task's examples are shuffled with a fixed seed (1337) and, when
    max_per_task > 0, cropped to the first max_per_task examples.

    Returns:
        (task_inputs, random_baselines, w_label, w_shot, w_type), where task_inputs
        is a list of (label, task_meta, data) and the w_* values are the column
        widths used when printing per-task results.
    """
    base_dir = get_base_dir()
    eval_bundle_dir = os.path.join(base_dir, "eval_bundle")
    if not os.path.exists(eval_bundle_dir):
        download_file_with_lock(EVAL_BUNDLE_URL, "eval_bundle.zip", postprocess_fn=place_eval_bundle)
    config_path = os.path.join(eval_bundle_dir, "core.yaml")
    data_base_path = os.path.join(eval_bundle_dir, "eval_data")
    eval_meta_data = os.path.join(eval_bundle_dir, "eval_meta_data.csv")
    with open(config_path, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)
    tasks = config['icl_tasks']
    random_baselines = {}
    with open(eval_meta_data, 'r', encoding='utf-8') as f:
        for row in csv.DictReader(f):
            random_baselines[row['Eval Task']] = float(row['Random baseline'])
    task_inputs = []
    for task in tasks:
        label = task['label']
        task_meta = {
            'task_type': task['icl_task_type'],
            'dataset_uri': task['dataset_uri'],
            'num_fewshot': task['num_fewshot'][0],  # only the first configured shot count is used
            'continuation_delimiter': task.get('continuation_delimiter', ' '),
        }
        data_path = os.path.join(data_base_path, task_meta['dataset_uri'])
        with open(data_path, 'r', encoding='utf-8') as f:
            # Skip blank lines (e.g. a trailing empty line) instead of failing in json.loads.
            data = [json.loads(line) for line in f if line.strip()]
        # Fixed-seed shuffle, so cropping with max_per_task picks the same subset every run.
        shuffle_rng = random.Random(1337)
        shuffle_rng.shuffle(data)
        if max_per_task > 0:
            data = data[:max_per_task]
        task_inputs.append((label, task_meta, data))
    w_label = max(len(t[0]) for t in task_inputs)
    w_shot = max(len(f"{t[1]['num_fewshot']}-shot") for t in task_inputs)
    w_type = max(len(t[1]['task_type']) for t in task_inputs)
    return task_inputs, random_baselines, w_label, w_shot, w_type


def evaluate_model(model, tokenizer, device, max_per_task=-1):
    """
    Evaluate a base model on the CORE benchmark.

    - max_per_task: crop the data to this many examples per task for testing (-1 = disable)

    Task data is read from the eval bundle once and reused across calls. It is
    re-read when max_per_task changes, because the cropping is baked into it.
    Collated batches are cached in memory, keyed on (max_per_task, max_seq_len).
    For nanochat models they are also cached on disk. Second and later runs
    therefore skip prepare and collation entirely and only do GPU forward passes.

    Returns a dict with per-task accuracies ("results"), baseline-centered scores
    ("centered_results") and their mean ("core_metric").
    """
    global _eval_data_cache, _batch_cache, _batch_cache_key, _prev_centered, _prev_core
    from concurrent.futures import ThreadPoolExecutor
    import torch.distributed as dist
    max_seq_len = getattr(model, 'max_seq_len', None)
    cache_key = (max_per_task, max_seq_len)
    # Invalidate caches if parameters changed
    if cache_key != _batch_cache_key:
        # The cached task data was cropped with the previous max_per_task, so reload it too.
        if _batch_cache_key is not None and _batch_cache_key[0] != max_per_task:
            _eval_data_cache = None
        _batch_cache.clear()
        _batch_cache_key = cache_key
    # Load and cache task data + baselines (read from disk only when not cached)
    if _eval_data_cache is None:
        _eval_data_cache = _load_core_task_data(max_per_task)
    task_inputs, random_baselines, w_label, w_shot, w_type = _eval_data_cache

    # First run: eagerly prepare next task while evaluating current, cache collated batches.
    # Cached runs: pass collated batches directly (no threads, no prepare, no collation).
    results = {}
    centered_results = {}
    cached_run = all(label in _batch_cache for label, _, _ in task_inputs)
    # The on-disk cache is named after the *local* nanochat tokenizer file, which says
    # nothing about the tokenizer of a HuggingFace model (ModelWrapper). Sharing it would
    # feed token ids from one vocabulary to a model using another, and scores would be
    # silently wrong. HF models therefore only use the in-memory cache.
    if isinstance(model, ModelWrapper):
        disk_cache_dir = None
    else:
        disk_cache_dir = _get_disk_cache_dir(max_per_task)
        if disk_cache_dir is not None and max_seq_len is not None:
            # Truncation to max_seq_len changes the collated batches; keep such caches apart.
            disk_cache_dir = f"{disk_cache_dir}_L{max_seq_len}"
    # Try loading from disk cache if in-memory cache is empty
    if not cached_run and disk_cache_dir is not None:
        all_on_disk = os.path.isdir(disk_cache_dir) and all(
            os.path.exists(os.path.join(disk_cache_dir, f"{label}.pt"))
            for label, _, _ in task_inputs
        )
        if all_on_disk:
            for label, _, _ in task_inputs:
                # weights_only=False unpickles arbitrary objects. This is acceptable only
                # because these files are written by this function into the local base dir.
                d = torch.load(os.path.join(disk_cache_dir, f"{label}.pt"), weights_only=False)
                _batch_cache[label] = d['collated']
            cached_run = True
            print0(" (loaded collated batches from disk cache)")
    first_run = not cached_run  # True when batches get prepared+collated below (then saved to disk)
    rank = dist.get_rank() if dist.is_initialized() else 0
    world_size = dist.get_world_size() if dist.is_initialized() else 1
    total_examples = sum(len(data) for _, _, data in task_inputs)
    task_labels = [label for label, _, _ in task_inputs]
    pbar = tqdm(total=total_examples, leave=False, disable=(rank != 0))

    def _print_task_result(tidx, accuracy):
        """Print one task's result. Updates results/centered_results dicts."""
        label, task_meta, _ = task_inputs[tidx]
        random_baseline = random_baselines[label]
        # Rescale so the random-guess baseline (given in percent) maps to 0 and perfect to 1.
        centered_result = (accuracy - 0.01 * random_baseline) / (1.0 - 0.01 * random_baseline)
        results[label] = accuracy
        centered_results[label] = centered_result
        shot_str = f"{task_meta['num_fewshot']}-shot"
        prefix = f" {label:<{w_label}} {shot_str:<{w_shot}} {task_meta['task_type']:<{w_type}}"
        delta_str = ""
        if label in _prev_centered:
            d = centered_result - _prev_centered[label]
            arrow = "\u2191" if d > 0 else "\u2193" if d < 0 else "="
            delta_str = f" {arrow}{d:+.4f}"
        if rank == 0:
            pbar.write(f"{prefix} acc: {accuracy:.4f} centered: {centered_result:>7.4f}{delta_str}")

    def _on_task_done(tidx, correct):
        """Callback for _forward_all_cached: convert tensor to accuracy and print."""
        _print_task_result(tidx, correct.mean().item())

    if cached_run:
        # Continuous pipeline: all tasks in one GPU stream, results printed per task as they
        # complete. batched=True forwards whole collated batches, which is much faster. Its
        # results may differ slightly from per-example evaluation, because different batch
        # shapes can select different kernels with different FP rounding. batched=False
        # trims collation padding to reproduce per-example results exactly.
        t0 = time.time()
        task_collated = [(_batch_cache[label], data) for label, _, data in task_inputs]
        correct_list = _forward_all_cached(
            model, task_collated, device, pbar=pbar, task_labels=task_labels,
            on_task_done=_on_task_done if world_size == 1 else None,
            batched=True,
        )
        elapsed_total = time.time() - t0
        pbar.close()
        # DDP: all_reduce + print (single-GPU already handled by on_task_done above)
        if world_size > 1:
            for tidx, (_, correct) in enumerate(zip(task_inputs, correct_list)):
                dist.barrier()
                dist.all_reduce(correct, op=dist.ReduceOp.SUM)
                _print_task_result(tidx, correct.mean().item())
        print0(f" (all tasks: {elapsed_total:.2f}s)")
    else:
        t0 = time.time()
        executor = ThreadPoolExecutor(max_workers=1)
        try:
            first_uncached = next(i for i, (lbl, _, _) in enumerate(task_inputs) if lbl not in _batch_cache)
            _, first_meta, first_data = task_inputs[first_uncached]
            next_future = executor.submit(prepare_task_data, tokenizer, first_data, first_meta, max_seq_len)
            for i, (label, task_meta, data) in enumerate(task_inputs):
                pbar.set_description(f"{label:<{w_label}}")
                if label in _batch_cache:
                    accuracy, collated = evaluate_task(model, data, device, collated=_batch_cache[label], pbar=pbar)
                else:
                    prepared = next_future.result()
                    # Kick off prepare for the next uncached task while this one runs on the GPU
                    for j in range(i + 1, len(task_inputs)):
                        next_label, next_meta, next_data = task_inputs[j]
                        if next_label not in _batch_cache:
                            next_future = executor.submit(prepare_task_data, tokenizer, next_data, next_meta, max_seq_len)
                            break
                    accuracy, collated = evaluate_task(model, data, device, prepared=prepared, pbar=pbar)
                    _batch_cache[label] = collated
                _print_task_result(i, accuracy)
        finally:
            # Release the worker even if evaluation raised; never block on a pending prepare.
            executor.shutdown(wait=False)
        elapsed_total = time.time() - t0
        pbar.close()
        print0(f" (all tasks: {elapsed_total:.2f}s)")

    # Save collated batches to disk after first run (so bench/future runs skip prepare+collate).
    # Only rank 0 writes; otherwise every rank would write the same files concurrently.
    if first_run and disk_cache_dir is not None and rank == 0:
        pad_id = tokenizer.get_bos_token_id()
        os.makedirs(disk_cache_dir, exist_ok=True)
        for label, _, _ in task_inputs:
            if label in _batch_cache:
                final_path = os.path.join(disk_cache_dir, f"{label}.pt")
                # Write to a temp file and rename, so an interrupted save never leaves a
                # truncated .pt that a later run would treat as a valid cache entry.
                tmp_path = final_path + ".tmp"
                torch.save({'collated': _batch_cache[label], 'pad_token_id': pad_id}, tmp_path)
                os.replace(tmp_path, final_path)
        print0(f" (saved collated batches to {disk_cache_dir})")

    core_metric = sum(centered_results.values()) / len(centered_results)
    if _prev_core is not None:
        d = core_metric - _prev_core
        arrow = "\u2191" if d > 0 else "\u2193" if d < 0 else "="
        print0(f"CORE: {core_metric:.4f} {arrow}{d:+.4f}")
    else:
        print0(f"CORE: {core_metric:.4f}")
    _prev_centered = dict(centered_results)
    _prev_core = core_metric
    out = {
        "results": results,
        "centered_results": centered_results,
        "core_metric": core_metric
    }
    return out
# -----------------------------------------------------------------------------
# Main
def main():
    """Command-line entry point for base-model evaluation.

    Parses flags and loads either a HuggingFace model (--hf-path) or a nanochat
    checkpoint. It then runs the evaluations requested via --eval: sampling
    (nanochat models only), bits-per-byte on the train/val splits, and CORE. On
    rank 0 the CORE table is written to <base_dir>/base_eval/<model_slug>.csv.
    Collected results are logged to the report.
    """
    parser = argparse.ArgumentParser(description="Base model evaluation")
    parser.add_argument('--eval', type=str, default='core,bpb,sample', help='Comma-separated evaluations to run: core,bpb,sample (default: all)')
    parser.add_argument('--hf-path', type=str, default=None, help='HuggingFace model path (e.g. openai-community/gpt2-xl)')
    parser.add_argument('--model-tag', type=str, default=None, help='nanochat model tag to identify the checkpoint directory')
    parser.add_argument('--step', type=int, default=None, help='Model step to load (default = last)')
    parser.add_argument('--max-per-task', type=int, default=-1, help='Max examples per CORE task (-1 = all)')
    parser.add_argument('--device-batch-size', type=int, default=32, help='Per-device batch size for BPB evaluation')
    parser.add_argument('--split-tokens', type=int, default=40*524288, help='Number of tokens to evaluate per split for BPB')
    parser.add_argument('--device-type', type=str, default='', help='cuda|cpu|mps (empty = autodetect)')
    args = parser.parse_args()

    # Parse evaluation modes; empty entries (e.g. from a trailing comma) are ignored.
    eval_modes = {mode.strip() for mode in args.eval.split(',') if mode.strip()}
    valid_modes = {'core', 'bpb', 'sample'}
    if not eval_modes:
        parser.error(f"No eval modes given. Valid: {valid_modes}")
    invalid = eval_modes - valid_modes
    if invalid:
        parser.error(f"Invalid eval modes: {invalid}. Valid: {valid_modes}")
    if 'bpb' in eval_modes and args.split_tokens <= 0:
        parser.error("--split-tokens must be positive")

    # Distributed / precision setup
    device_type = autodetect_device_type() if args.device_type == '' else args.device_type
    ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
    autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()

    # Load model and tokenizer
    is_hf_model = args.hf_path is not None
    if is_hf_model:
        model, tokenizer = load_hf_model(args.hf_path, device)
        sequence_len = model.max_seq_len or 1024
        token_bytes = get_hf_token_bytes(tokenizer, device=device)
        model_name = args.hf_path
        model_slug = args.hf_path.replace("/", "-")
    else:
        model, tokenizer, meta = load_model("base", device, phase="eval", model_tag=args.model_tag, step=args.step)
        sequence_len = meta["model_config"]["sequence_len"]
        token_bytes = get_token_bytes(device=device)
        model_name = f"base_model (step {meta['step']})"
        model_slug = f"base_model_{meta['step']:06d}"
    print0(f"Evaluating model: {model_name}")
    print0(f"Eval modes: {', '.join(sorted(eval_modes))}")

    # Results to log
    core_results = None
    bpb_results = {}
    samples = []
    unconditioned_samples = []

    # --- Sampling ---
    if 'sample' in eval_modes and not is_hf_model:
        print0("\n" + "="*80)
        print0("Model Samples")
        print0("="*80)
        if ddp_rank == 0:
            prompts = [
                "The capital of France is",
                "The chemical symbol of gold is",
                "If yesterday was Friday, then tomorrow will be",
                "The opposite of hot is",
                "The planets of the solar system are:",
                "My favorite color is",
                "If 5*x + 3 = 13, then x is",
            ]
            engine = Engine(model, tokenizer)
            print0("\nConditioned samples:")
            for prompt in prompts:
                tokens = tokenizer(prompt, prepend="<|bos|>")
                with autocast_ctx:
                    sample, _ = engine.generate_batch(tokens, num_samples=1, max_tokens=16, temperature=0)
                sample_str = tokenizer.decode(sample[0])
                print0("-" * 80)
                print0(sample_str)
                samples.append(sample_str)
            print0("\nUnconditioned samples:")
            tokens = tokenizer("", prepend="<|bos|>")
            with autocast_ctx:
                uncond, _ = engine.generate_batch(tokens, num_samples=8, max_tokens=128, temperature=1.0)
            for sample in uncond:
                sample_str = tokenizer.decode(sample)
                print0("-" * 80)
                print0(sample_str)
                unconditioned_samples.append(sample_str)
    elif 'sample' in eval_modes and is_hf_model:
        print0("\nSkipping sampling for HuggingFace models (not supported)")

    # --- BPB evaluation ---
    if 'bpb' in eval_modes:
        print0("\n" + "="*80)
        print0("BPB Evaluation")
        print0("="*80)
        tokens_per_step = args.device_batch_size * sequence_len * ddp_world_size
        if args.split_tokens % tokens_per_step != 0:
            # Round down to a multiple of tokens_per_step, but keep at least one step. A split
            # smaller than one step would otherwise become 0 tokens and evaluate nothing.
            args.split_tokens = max(1, args.split_tokens // tokens_per_step) * tokens_per_step
            print0(f"Adjusted split_tokens to {args.split_tokens} (must be divisible by {tokens_per_step})")
        steps = args.split_tokens // tokens_per_step
        for split_name in ["train", "val"]:
            loader = tokenizing_distributed_data_loader_bos_bestfit(tokenizer, args.device_batch_size, sequence_len, split_name, device=device)
            with autocast_ctx:
                bpb = evaluate_bpb(model, loader, steps, token_bytes)
            bpb_results[split_name] = bpb
            print0(f"{split_name} bpb: {bpb:.6f}")

    # --- CORE evaluation ---
    if 'core' in eval_modes:
        print0("\n" + "="*80)
        print0("CORE Evaluation")
        print0("="*80)
        with autocast_ctx:
            core_results = evaluate_model(model, tokenizer, device, max_per_task=args.max_per_task)
        # Write CSV output
        if ddp_rank == 0:
            base_dir = get_base_dir()
            output_csv_path = os.path.join(base_dir, "base_eval", f"{model_slug}.csv")
            os.makedirs(os.path.dirname(output_csv_path), exist_ok=True)
            with open(output_csv_path, 'w', encoding='utf-8', newline='') as f:
                f.write(f"{'Task':<35}, {'Accuracy':<10}, {'Centered':<10}\n")
                for label in core_results["results"]:
                    acc = core_results["results"][label]
                    centered = core_results["centered_results"][label]
                    f.write(f"{label:<35}, {acc:<10.6f}, {centered:<10.6f}\n")
                f.write(f"{'CORE':<35}, {'':<10}, {core_results['core_metric']:<10.6f}\n")
            print0(f"\nResults written to: {output_csv_path}")
        print0(f"CORE metric: {core_results['core_metric']:.4f}")

    # --- Log to report ---
    from nanochat.report import get_report
    report_data = [{"model": model_name}]
    if core_results:
        report_data[0]["CORE metric"] = core_results["core_metric"]
        report_data.append(core_results["centered_results"])
    if bpb_results:
        report_data[0]["train bpb"] = bpb_results.get("train")
        report_data[0]["val bpb"] = bpb_results.get("val")
    if samples:
        report_data.append({f"sample {i}": s for i, s in enumerate(samples)})
    if unconditioned_samples:
        report_data.append({f"unconditioned {i}": s for i, s in enumerate(unconditioned_samples)})
    get_report().log(section="Base model evaluation", data=report_data)
    compute_cleanup()
# Script entry point (e.g. `python -m scripts.base_eval ...` or via torchrun, as in the module docstring).
if __name__ == "__main__":
    main()