This commit is contained in:
Unsal 2026-02-24 00:13:26 -05:00 committed by GitHub
commit 122d408c15
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 1101 additions and 93 deletions

View File

def batch_sequences_lm(tokenizer, prompts):
    """Tokenize the (without, with) LM prompt pair into a batch of size 1.

    Args:
        tokenizer: callable tokenizer; also provides get_bos_token_id().
        prompts: two strings — the prompt without and with the continuation.

    Returns:
        ([tokens_with], [start_idx], [end_idx]) where [start_idx, end_idx)
        spans the continuation tokens inside the with-continuation sequence.
    """
    # In LM tasks, we have two prompts: without and with continuation
    tokens = tokenizer(prompts, prepend=tokenizer.get_bos_token_id())
    tokens_without, tokens_with = tokens
    end_idx = len(tokens_with)
    # Find longest common prefix — greedy trie tokenizers are not always
    # prefix-stable, so we can't assume an exact prefix match.
    start_idx = 0
    for i in range(min(len(tokens_without), len(tokens_with))):
        if tokens_without[i] != tokens_with[i]:
            break
        start_idx = i + 1
    assert start_idx < end_idx, "continuation must produce additional tokens"
    # we only need the with continuation prompt in the LM task, i.e. batch size of 1
    return [tokens_with], [start_idx], [end_idx]
@torch.no_grad()
@torch.compiler.disable
def forward_model(model, input_ids):
"""
Take BxT tensor of token ids, return BxT tensor of losses and argmax predictions.
@ -164,9 +171,8 @@ def forward_model(model, input_ids):
return losses, predictions
@torch.no_grad()
def evaluate_example(idx, model, tokenizer, data, device, task_meta):
"""Evaluate a single example, return True if correct, False otherwise"""
def prepare_example(idx, tokenizer, data, task_meta, max_seq_len=None):
"""CPU-only: render prompts, tokenize, stack into tensors. Returns a dict."""
item = data[idx]
task_type = task_meta['task_type']
num_fewshot = task_meta['num_fewshot']
@ -193,70 +199,401 @@ def evaluate_example(idx, model, tokenizer, data, device, task_meta):
else:
raise ValueError(f"Unsupported task type: {task_type}")
# Some models can't forward sequences beyond a certain length (e.g. GPT-2)
# In these cases, we have to truncate sequences to max length and adjust the indices
if hasattr(model, 'max_seq_len') and model.max_seq_len is not None:
max_tokens = model.max_seq_len
# Truncate sequences for models with a max length (e.g. GPT-2)
if max_seq_len is not None:
new_tokens, new_start_idxs, new_end_idxs = [], [], []
for t, s, e in zip(tokens, start_idxs, end_idxs):
if len(t) > max_tokens:
num_to_crop = len(t) - max_tokens
new_tokens.append(t[-max_tokens:]) # take the last max_tokens tokens
new_start_idxs.append(s - num_to_crop) # shift the indices down
if len(t) > max_seq_len:
num_to_crop = len(t) - max_seq_len
new_tokens.append(t[-max_seq_len:])
new_start_idxs.append(s - num_to_crop)
new_end_idxs.append(e - num_to_crop)
assert s - num_to_crop >= 0, "this should never happen right?"
assert e - num_to_crop >= 0, "this should never happen right?"
else:
new_tokens.append(t) # keep unchanged
new_tokens.append(t)
new_start_idxs.append(s)
new_end_idxs.append(e)
tokens, start_idxs, end_idxs = new_tokens, new_start_idxs, new_end_idxs
# Stack up all the sequences into a batch
pad_token_id = tokenizer.get_bos_token_id() # use BOS as pad token is ok
input_ids = stack_sequences(tokens, pad_token_id)
input_ids = input_ids.to(device)
pad_token_id = tokenizer.get_bos_token_id()
input_ids = stack_sequences(tokens, pad_token_id) # (num_options, seq_len)
# Forward the model, get the autoregressive loss and argmax prediction at each token
losses, predictions = forward_model(model, input_ids)
return {
'input_ids': input_ids,
'start_idxs': start_idxs,
'end_idxs': end_idxs,
'gold': item.get('gold', None),
'task_type': task_type,
'num_options': input_ids.size(0),
'seq_len': input_ids.size(1),
'pad_token_id': pad_token_id,
}
# See if the losses/predictions come out correctly
def check_result(losses, predictions, input_ids, start_idxs, end_idxs, gold, task_type):
    """Analyze forward pass outputs for one example, return True if correct.

    Args:
        losses: (num_options, T) per-token autoregressive losses.
        predictions: (num_options, T) argmax token predictions; predictions[i]
            predicts input_ids[i+1] autoregressively.
        input_ids: (num_options, T) input token ids.
        start_idxs, end_idxs: per-option continuation spans [start, end).
        gold: index of the correct option (MC/schema); unused for LM.
        task_type: 'language_modeling', 'multiple_choice', or 'schema'.

    Raises:
        ValueError: on an unsupported task_type.
    """
    if task_type == 'language_modeling':
        # language modeling task is currently always batch size 1
        si, ei = start_idxs[0], end_idxs[0]
        # predictions[i] predict input_ids[i+1] autoregressively, hence the -1 shift
        predicted_tokens = predictions[0, si-1:ei-1]
        actual_tokens = input_ids[0, si:ei]
        return torch.all(predicted_tokens == actual_tokens).item()
    elif task_type in ['multiple_choice', 'schema']:
        # For MC/schema: find the option with lowest average loss
        mean_losses = [losses[i, si-1:ei-1].mean().item()
                       for i, (si, ei) in enumerate(zip(start_idxs, end_idxs))]
        return mean_losses.index(min(mean_losses)) == gold
    else:
        raise ValueError(f"Unsupported task type: {task_type}")
def _collate_batches(prepared, batch_size, queue):
"""Background thread: collate batches on CPU and push to queue."""
for batch_start in range(0, len(prepared), batch_size):
batch = prepared[batch_start:batch_start + batch_size]
batch_preps = [p for _, p in batch]
max_len = max(p['seq_len'] for p in batch_preps)
total_rows = sum(p['num_options'] for p in batch_preps)
pad_id = batch_preps[0]['pad_token_id']
combined_ids = torch.full((total_rows, max_len), pad_id, dtype=torch.long)
batch_meta = []
offset = 0
for idx, p in batch:
n, sl = p['num_options'], p['seq_len']
combined_ids[offset:offset+n, :sl] = p['input_ids']
batch_meta.append((idx, n, p['start_idxs'], p['end_idxs'], p['gold'], p['task_type']))
offset += n
queue.put((combined_ids, batch_meta))
queue.put(None) # sentinel
def evaluate_task(model, tokenizer, data, device, task_meta):
def prepare_task_data(tokenizer, data, task_meta, max_seq_len=None):
    """CPU-only: prepare and sort all examples for a task. Can run on a background thread."""
    # Stride examples across ranks so each rank prepares a disjoint subset.
    rank = dist.get_rank() if dist.is_initialized() else 0
    world_size = dist.get_world_size() if dist.is_initialized() else 1
    indices = list(range(rank, len(data), world_size))
    # Each entry is (original_index, prepared_dict); the index is kept so
    # results can be scattered back into a full-size correctness tensor.
    prepared = [(idx, prepare_example(idx, tokenizer, data, task_meta, max_seq_len)) for idx in indices]
    # Sort by sequence length so collation batches similar lengths together,
    # minimizing padding waste.
    prepared.sort(key=lambda x: x[1]['seq_len'])
    return prepared
def _prefetch_to_device(tensor, device):
    """Pin and async-transfer a CPU tensor to GPU, overlapping with current GPU work."""
    # pin_memory() puts the tensor in page-locked host memory so the H2D copy
    # can be DMA'd; non_blocking=True enqueues the copy without blocking the host.
    # NOTE(review): callers gate this on torch.cuda.is_available() — pin_memory
    # presumes a CUDA-capable runtime; confirm before calling elsewhere.
    return tensor.pin_memory().to(device, non_blocking=True)
def _forward_batches(model, collated, data, device, pbar=None):
    """Run GPU forward passes on pre-collated batches, return per-example correctness tensor.

    Uses double-buffered prefetching on CUDA: while the GPU processes batch N,
    batch N+1 is pinned and DMA-transferred asynchronously, keeping the GPU fed.

    Args:
        model: model to evaluate (passed through to forward_model).
        collated: list of (combined_ids, batch_meta) CPU batches.
        data: full task dataset; only len(data) is used, to size the result.
        device: target device for tensors and the result.
        pbar: optional progress bar, updated once per example.

    Returns:
        float32 tensor of shape (len(data),): 1.0 at indices evaluated correctly
        by this rank, 0.0 elsewhere.
    """
    correct = torch.zeros(len(data), dtype=torch.float32, device=device)
    if not collated:
        return correct
    use_prefetch = torch.cuda.is_available() and 'cuda' in str(device)
    # Prefetch first batch
    if use_prefetch:
        next_ids = _prefetch_to_device(collated[0][0], device)
    else:
        next_ids = collated[0][0].to(device)
    for i, (_, batch_meta) in enumerate(collated):
        combined_ids = next_ids
        # Start async transfer of next batch while GPU computes on current
        if i + 1 < len(collated):
            if use_prefetch:
                next_ids = _prefetch_to_device(collated[i + 1][0], device)
            else:
                next_ids = collated[i + 1][0].to(device)
        losses, predictions = forward_model(model, combined_ids)
        # Slice the batch back into per-example row groups and score each.
        offset = 0
        for idx, n, start_idxs, end_idxs, gold, task_type in batch_meta:
            is_correct = check_result(
                losses[offset:offset+n], predictions[offset:offset+n],
                combined_ids[offset:offset+n],
                start_idxs, end_idxs, gold, task_type,
            )
            correct[idx] = float(is_correct)
            offset += n
        if pbar is not None:
            pbar.update(len(batch_meta))
    return correct
def _forward_all_cached(model, task_collated, device, pbar=None, task_labels=None,
                        on_task_done=None, batched=False, merge=1, split=1, pad_token_id=0):
    """Run all tasks' cached batches through the model in one pass.
    All batch tensors are moved to device upfront (~144MB for full CORE eval).
    If tensors are already on device (caller preloaded), .to() is a no-op.
    Default mode (batched=False): forwards each example individually, trimming
    collation padding to recover the exact per-example tensor shape. This
    guarantees identical results to sequential per-example evaluation.
    Batched mode (batched=True): forwards collated batches with optional GPU
    composition. Faster but may produce tiny FP differences vs sequential eval
    due to different cuBLAS kernel paths for different matrix dimensions.
    - merge > 1: pad+cat consecutive base batches on GPU before forwarding.
    - split > 1: slice each group into chunks by example boundaries.
    Args:
        task_collated: list of (collated_batches, data) per task
        pbar: optional progress bar, updated per example (or per batch chunk)
        task_labels: optional list of task names for pbar description updates
        on_task_done: optional callback(task_idx, correct_tensor) fired when a task completes
        batched: if True, forward whole batches (faster, approximate). Default False (exact).
        merge/split/pad_token_id: only used when batched=True
    Returns:
        list of correct tensors (one per task, on device)
    """
    # Flatten all batches and move to device upfront (no-op if already there)
    flat_stream = []  # (gpu_ids, batch_meta, task_idx)
    correct = []  # one (len(data),) float32 tensor per task
    task_batch_counts = []  # batches per task, used to detect task completion
    for task_idx, (collated, data) in enumerate(task_collated):
        correct.append(torch.zeros(len(data), dtype=torch.float32, device=device))
        task_batch_counts.append(len(collated))
        for combined_ids, batch_meta in collated:
            flat_stream.append((combined_ids.to(device), batch_meta, task_idx))
    if not flat_stream:
        return correct
    # Countdown per task; hitting zero means every batch of that task was scored.
    task_batches_remaining = list(task_batch_counts)
    current_task = -1
    if not batched:
        # Per-example forwarding: identical results to sequential evaluation.
        # Each example's rows are trimmed to their original seq_len (= max(end_idxs)),
        # removing collation padding so forward_model sees the same tensor shape as
        # the sequential path.
        for combined_ids, batch_meta, task_idx in flat_stream:
            if task_idx != current_task:
                current_task = task_idx
                if pbar is not None and task_labels is not None:
                    pbar.set_description(task_labels[task_idx])
            offset = 0
            for idx, n, start_idxs, end_idxs, gold, task_type in batch_meta:
                # max(end_idxs) recovers the example's pre-collation width.
                seq_len = max(end_idxs)
                example_ids = combined_ids[offset:offset+n, :seq_len]
                losses, predictions = forward_model(model, example_ids)
                is_correct = check_result(
                    losses, predictions, example_ids,
                    start_idxs, end_idxs, gold, task_type,
                )
                correct[task_idx][idx] = float(is_correct)
                offset += n
            if pbar is not None:
                pbar.update(len(batch_meta))
            if on_task_done is not None:
                task_batches_remaining[task_idx] -= 1
                if task_batches_remaining[task_idx] == 0:
                    on_task_done(task_idx, correct[task_idx])
    else:
        # Batched forwarding with optional merge/split composition.
        buffer_ids = []  # accumulates up to `merge` base batches
        buffer_info = []  # parallel (batch_meta, task_idx) per buffered batch
        for i, (combined_ids, batch_meta, task_idx) in enumerate(flat_stream):
            if task_idx != current_task:
                current_task = task_idx
                if pbar is not None and task_labels is not None:
                    pbar.set_description(task_labels[task_idx])
            buffer_ids.append(combined_ids)
            buffer_info.append((batch_meta, task_idx))
            # Keep buffering until `merge` batches collected (or stream exhausted).
            if len(buffer_ids) < merge and i < len(flat_stream) - 1:
                continue
            # GPU compose: pad+cat if multiple batches, otherwise use as-is
            if len(buffer_ids) == 1:
                mega_ids = buffer_ids[0]
            else:
                max_len = max(t.shape[1] for t in buffer_ids)
                parts = []
                for t in buffer_ids:
                    if t.shape[1] < max_len:
                        pad = torch.full((t.shape[0], max_len - t.shape[1]), pad_token_id,
                                         dtype=t.dtype, device=t.device)
                        t = torch.cat([t, pad], dim=1)
                    parts.append(t)
                mega_ids = torch.cat(parts, dim=0)
            # Flatten buffered metadata and record per-example row boundaries
            # so chunks can be sliced without cutting an example in half.
            examples = []
            row_bounds = [0]
            for bm, tidx in buffer_info:
                for idx, n, start_idxs, end_idxs, gold, task_type in bm:
                    examples.append((idx, n, start_idxs, end_idxs, gold, task_type, tidx))
                    row_bounds.append(row_bounds[-1] + n)
            n_ex = len(examples)
            chunk_size = -(-n_ex // split)  # ceil division
            for cs in range(0, n_ex, chunk_size):
                ce = min(cs + chunk_size, n_ex)
                chunk = examples[cs:ce]
                chunk_ids = mega_ids[row_bounds[cs]:row_bounds[ce]]
                losses, predictions = forward_model(model, chunk_ids)
                offset = 0
                for idx, n, start_idxs, end_idxs, gold, task_type, tidx in chunk:
                    is_correct = check_result(
                        losses[offset:offset+n], predictions[offset:offset+n],
                        chunk_ids[offset:offset+n],
                        start_idxs, end_idxs, gold, task_type,
                    )
                    correct[tidx][idx] = float(is_correct)
                    offset += n
                if pbar is not None:
                    pbar.update(len(chunk))
            if on_task_done is not None:
                for bm, tidx in buffer_info:
                    task_batches_remaining[tidx] -= 1
                    if task_batches_remaining[tidx] == 0:
                        on_task_done(tidx, correct[tidx])
            buffer_ids.clear()
            buffer_info.clear()
    return correct
def compose_collated(base_collated, target_batch_size, base_batch_size=4, pad_token_id=0):
    """Compose base-sized collated batches into target-sized batches.

    Supports both merging (target > base) by concatenating consecutive groups,
    and splitting (target < base) by slicing along example boundaries.
    Examples are sorted by seq_len within each base batch, so splitting can
    trim trailing padding columns for efficiency.

    Args:
        base_collated: list of (combined_ids, batch_meta) batches; batch_meta
            entries are (idx, n, start_idxs, end_idxs, gold, task_type).
        target_batch_size: desired examples per batch.
        base_batch_size: examples per batch in base_collated.
        pad_token_id: token id used to pad merged batches to a common width.

    Returns:
        list of (combined_ids, batch_meta) at the target batch size
        (base_collated itself when target == base).
    """
    if target_batch_size == base_batch_size:
        return base_collated
    elif target_batch_size > base_batch_size:
        # Merge consecutive base batches
        n_merge = target_batch_size // base_batch_size
        composed = []
        for i in range(0, len(base_collated), n_merge):
            group = base_collated[i:i + n_merge]
            if len(group) == 1:
                composed.append(group[0])
                continue
            max_len = max(ids.shape[1] for ids, _ in group)
            parts = []
            merged_meta = []
            for ids, meta in group:
                if ids.shape[1] < max_len:
                    pad = torch.full((ids.shape[0], max_len - ids.shape[1]), pad_token_id, dtype=ids.dtype)
                    ids = torch.cat([ids, pad], dim=1)
                parts.append(ids)
                merged_meta.extend(meta)
            composed.append((torch.cat(parts, dim=0), merged_meta))
        return composed
    else:
        # Split base batches into smaller chunks
        composed = []
        for combined_ids, batch_meta in base_collated:
            row_start = 0  # running row offset (avoids re-summing the prefix per chunk)
            for chunk_start in range(0, len(batch_meta), target_batch_size):
                chunk_meta = batch_meta[chunk_start:chunk_start + target_batch_size]
                num_rows = sum(m[1] for m in chunk_meta)
                chunk_ids = combined_ids[row_start:row_start + num_rows]
                row_start += num_rows
                # Trim trailing padding columns. The needed width is derived
                # from metadata (max end_idx across the chunk) rather than by
                # scanning for pad-valued columns: the pad id is the BOS token,
                # which also appears as a *real* token, so value-based
                # detection could mis-trim live columns.
                need_cols = max(max(m[3]) for m in chunk_meta)
                if need_cols < chunk_ids.shape[1]:
                    chunk_ids = chunk_ids[:, :need_cols].contiguous()
                composed.append((chunk_ids, chunk_meta))
        return composed
def evaluate_task(model, data, device, batch_size=4, queue_size=2, prepared=None,
collated=None, tokenizer=None, task_meta=None, pbar=None):
"""
Evaluate one task across many examples with batched GPU forward passes.
Examples are sorted by sequence length so similar-length sequences are batched
together, minimizing padding waste and increasing GPU utilization.
Three modes (checked in order):
- collated: skip prepare + collation, go straight to GPU forward passes.
- prepared: skip prepare, collation runs on a background thread pipelined with GPU.
- neither: full pipeline (prepare + collate + forward).
Returns (accuracy, collated_batches) so the caller can cache collated batches.
"""
rank = dist.get_rank() if dist.is_initialized() else 0
world_size = dist.get_world_size() if dist.is_initialized() else 1
correct = torch.zeros(len(data), dtype=torch.float32, device=device)
# stride the examples to each rank
for idx in range(rank, len(data), world_size):
is_correct = evaluate_example(idx, model, tokenizer, data, device, task_meta)
correct[idx] = float(is_correct)
if collated is not None:
# Fast path: just GPU forward passes, no threads
correct = _forward_batches(model, collated, data, device, pbar=pbar)
else:
from queue import Queue
from threading import Thread
if prepared is None:
max_seq_len = getattr(model, 'max_seq_len', None)
prepared = prepare_task_data(tokenizer, data, task_meta, max_seq_len)
# Collation thread pipelined with GPU forward passes.
# Double-buffered: while GPU processes batch N, batch N+1 is
# pin_memory()'d and DMA-transferred asynchronously.
queue = Queue(maxsize=queue_size)
collator = Thread(target=_collate_batches, args=(prepared, batch_size, queue), daemon=True)
collator.start()
use_prefetch = torch.cuda.is_available() and 'cuda' in str(device)
def transfer(tensor):
return _prefetch_to_device(tensor, device) if use_prefetch else tensor.to(device)
collated = []
correct = torch.zeros(len(data), dtype=torch.float32, device=device)
# Prime: get first batch and start its transfer
item = queue.get()
if item is not None:
next_ids = transfer(item[0])
while item is not None:
collated.append(item)
combined_ids = next_ids
_, batch_meta = item
# Start async transfer of next batch (overlaps with forward pass below)
item = queue.get()
if item is not None:
next_ids = transfer(item[0])
losses, predictions = forward_model(model, combined_ids)
offset = 0
for idx, n, start_idxs, end_idxs, gold, task_type in batch_meta:
is_correct = check_result(
losses[offset:offset+n], predictions[offset:offset+n],
combined_ids[offset:offset+n],
start_idxs, end_idxs, gold, task_type,
)
correct[idx] = float(is_correct)
offset += n
if pbar is not None:
pbar.update(len(batch_meta))
collator.join()
del prepared
# sync results across all the processes if running distributed
if world_size > 1:
dist.barrier()
dist.all_reduce(correct, op=dist.ReduceOp.SUM)
# compute the mean
mean_correct = correct.mean().item()
return mean_correct
return correct.mean().item(), collated

View File

@ -26,17 +26,19 @@ import json
import yaml
import shutil
import random
import hashlib
import zipfile
import tempfile
import argparse
from contextlib import nullcontext
from tqdm import tqdm
import torch
from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, autodetect_device_type, download_file_with_lock
from nanochat.tokenizer import HuggingFaceTokenizer, get_token_bytes
from nanochat.checkpoint_manager import load_model
from nanochat.core_eval import evaluate_task
from nanochat.core_eval import evaluate_task, prepare_task_data, _forward_all_cached
from nanochat.dataloader import tokenizing_distributed_data_loader_bos_bestfit
from nanochat.loss_eval import evaluate_bpb
from nanochat.engine import Engine
@ -106,67 +108,211 @@ def place_eval_bundle(file_path):
print0(f"Placed eval_bundle directory at {eval_bundle_dir}")
def evaluate_core(model, tokenizer, device, max_per_task=-1):
_eval_data_cache = None # (task_inputs, random_baselines, w_label, w_shot, w_type)
_batch_cache = {} # {label: collated_batches} — cached after first run
_batch_cache_key = None # (max_per_task, max_seq_len) — invalidate if these change
_prev_centered = {} # {label: centered_result} — previous run for delta display
_prev_core = None # previous core_metric
def _get_disk_cache_dir(max_per_task):
    """Get disk cache dir for base-4 collated batches, keyed by tokenizer file hash.
    Returns None if no local tokenizer file is found (e.g. HuggingFace models)."""
    base_dir = get_base_dir()
    for fname in ("tokenizer.pkl", "tokenizer.json"):
        path = os.path.join(base_dir, "tokenizer", fname)
        if not os.path.exists(path):
            continue
        # Hash the tokenizer file so a tokenizer change invalidates the cache.
        digest = hashlib.sha256()
        with open(path, "rb") as f:
            while chunk := f.read(8192):
                digest.update(chunk)
        tok_hash = digest.hexdigest()[:16]
        return os.path.join(base_dir, "core_token_cache", f"{tok_hash}_n{max_per_task}")
    return None
def evaluate_model(model, tokenizer, device, max_per_task=-1):
"""
Evaluate a base model on the CORE benchmark.
Returns dict with results, centered_results, and core_metric.
- max_per_task: crop the data to this many examples per task for testing (-1 = disable)
Collated batches are cached across calls since the tokenizer is fixed.
Second+ runs skip prepare and collation entirely — just GPU forward passes.
"""
base_dir = get_base_dir()
eval_bundle_dir = os.path.join(base_dir, "eval_bundle")
# Download the eval bundle if needed
if not os.path.exists(eval_bundle_dir):
download_file_with_lock(EVAL_BUNDLE_URL, "eval_bundle.zip", postprocess_fn=place_eval_bundle)
global _eval_data_cache, _batch_cache, _batch_cache_key, _prev_centered, _prev_core
from concurrent.futures import ThreadPoolExecutor
config_path = os.path.join(eval_bundle_dir, "core.yaml")
data_base_path = os.path.join(eval_bundle_dir, "eval_data")
eval_meta_data = os.path.join(eval_bundle_dir, "eval_meta_data.csv")
max_seq_len = getattr(model, 'max_seq_len', None)
cache_key = (max_per_task, max_seq_len)
with open(config_path, 'r', encoding='utf-8') as f:
config = yaml.safe_load(f)
tasks = config['icl_tasks']
# Invalidate batch cache if parameters changed
if cache_key != _batch_cache_key:
_batch_cache.clear()
_batch_cache_key = cache_key
# Load random baseline values
random_baselines = {}
with open(eval_meta_data, 'r', encoding='utf-8') as f:
reader = csv.DictReader(f)
for row in reader:
task_name = row['Eval Task']
random_baseline = row['Random baseline']
random_baselines[task_name] = float(random_baseline)
# Load and cache task data + baselines (only read from disk once)
if _eval_data_cache is None:
base_dir = get_base_dir()
eval_bundle_dir = os.path.join(base_dir, "eval_bundle")
if not os.path.exists(eval_bundle_dir):
download_file_with_lock(EVAL_BUNDLE_URL, "eval_bundle.zip", postprocess_fn=place_eval_bundle)
config_path = os.path.join(eval_bundle_dir, "core.yaml")
data_base_path = os.path.join(eval_bundle_dir, "eval_data")
eval_meta_data = os.path.join(eval_bundle_dir, "eval_meta_data.csv")
with open(config_path, 'r', encoding='utf-8') as f:
config = yaml.safe_load(f)
tasks = config['icl_tasks']
# Evaluate each task
random_baselines = {}
with open(eval_meta_data, 'r', encoding='utf-8') as f:
reader = csv.DictReader(f)
for row in reader:
random_baselines[row['Eval Task']] = float(row['Random baseline'])
task_inputs = []
for task in tasks:
label = task['label']
task_meta = {
'task_type': task['icl_task_type'],
'dataset_uri': task['dataset_uri'],
'num_fewshot': task['num_fewshot'][0],
'continuation_delimiter': task.get('continuation_delimiter', ' ')
}
data_path = os.path.join(data_base_path, task_meta['dataset_uri'])
with open(data_path, 'r', encoding='utf-8') as f:
data = [json.loads(line.strip()) for line in f]
shuffle_rng = random.Random(1337)
shuffle_rng.shuffle(data)
if max_per_task > 0:
data = data[:max_per_task]
task_inputs.append((label, task_meta, data))
w_label = max(len(t[0]) for t in task_inputs)
w_shot = max(len(f"{t[1]['num_fewshot']}-shot") for t in task_inputs)
w_type = max(len(t[1]['task_type']) for t in task_inputs)
_eval_data_cache = (task_inputs, random_baselines, w_label, w_shot, w_type)
task_inputs, random_baselines, w_label, w_shot, w_type = _eval_data_cache
# First run: eagerly prepare next task while evaluating current, cache collated batches.
# Cached runs: pass collated batches directly — no threads, no prepare, no collation.
results = {}
centered_results = {}
for task in tasks:
start_time = time.time()
label = task['label']
task_meta = {
'task_type': task['icl_task_type'],
'dataset_uri': task['dataset_uri'],
'num_fewshot': task['num_fewshot'][0],
'continuation_delimiter': task.get('continuation_delimiter', ' ')
}
print0(f"Evaluating: {label} ({task_meta['num_fewshot']}-shot, type: {task_meta['task_type']})... ", end='')
cached_run = all(label in _batch_cache for label, _, _ in task_inputs)
disk_cache_dir = _get_disk_cache_dir(max_per_task)
data_path = os.path.join(data_base_path, task_meta['dataset_uri'])
with open(data_path, 'r', encoding='utf-8') as f:
data = [json.loads(line.strip()) for line in f]
# Try loading from disk cache if in-memory cache is empty
if not cached_run and disk_cache_dir is not None:
all_on_disk = os.path.isdir(disk_cache_dir) and all(
os.path.exists(os.path.join(disk_cache_dir, f"{label}.pt"))
for label, _, _ in task_inputs
)
if all_on_disk:
for label, _, _ in task_inputs:
d = torch.load(os.path.join(disk_cache_dir, f"{label}.pt"), weights_only=False)
_batch_cache[label] = d['collated']
cached_run = True
print0(" (loaded collated batches from disk cache)")
# Shuffle for consistent subsampling when using max_per_task
shuffle_rng = random.Random(1337)
shuffle_rng.shuffle(data)
if max_per_task > 0:
data = data[:max_per_task]
first_run = not cached_run # track whether we did prepare+collate (for disk save)
accuracy = evaluate_task(model, tokenizer, data, device, task_meta)
results[label] = accuracy
import torch.distributed as dist
rank = dist.get_rank() if dist.is_initialized() else 0
world_size = dist.get_world_size() if dist.is_initialized() else 1
total_examples = sum(len(data) for _, _, data in task_inputs)
task_labels = [label for label, _, _ in task_inputs]
pbar = tqdm(total=total_examples, leave=False, disable=(rank != 0))
def _print_task_result(tidx, accuracy):
"""Print one task's result. Updates results/centered_results dicts."""
label, task_meta, _ = task_inputs[tidx]
random_baseline = random_baselines[label]
centered_result = (accuracy - 0.01 * random_baseline) / (1.0 - 0.01 * random_baseline)
results[label] = accuracy
centered_results[label] = centered_result
elapsed = time.time() - start_time
print0(f"accuracy: {accuracy:.4f} | centered: {centered_result:.4f} | time: {elapsed:.2f}s")
shot_str = f"{task_meta['num_fewshot']}-shot"
prefix = f" {label:<{w_label}} {shot_str:<{w_shot}} {task_meta['task_type']:<{w_type}}"
delta_str = ""
if label in _prev_centered:
d = centered_result - _prev_centered[label]
arrow = "\u2191" if d > 0 else "\u2193" if d < 0 else "="
delta_str = f" {arrow}{d:+.4f}"
if rank == 0:
pbar.write(f"{prefix} acc: {accuracy:.4f} centered: {centered_result:>7.4f}{delta_str}")
def _on_task_done(tidx, correct):
"""Callback for _forward_all_cached: convert tensor to accuracy and print."""
_print_task_result(tidx, correct.mean().item())
if cached_run:
# Continuous pipeline: all tasks in one GPU stream, results printed per-task as they complete.
# Always use base batch size (merge=1/split=1) to guarantee identical results —
# different batch dimensions trigger different cuBLAS kernels with different FP rounding.
t0 = time.time()
task_collated = [(_batch_cache[label], data) for label, _, data in task_inputs]
correct_list = _forward_all_cached(
model, task_collated, device, pbar=pbar, task_labels=task_labels,
on_task_done=_on_task_done if world_size == 1 else None,
batched=True,
)
elapsed_total = time.time() - t0
pbar.close()
# DDP: all_reduce + print (single-GPU already handled by on_task_done above)
if world_size > 1:
for tidx, ((label, task_meta, data), correct) in enumerate(zip(task_inputs, correct_list)):
dist.barrier()
dist.all_reduce(correct, op=dist.ReduceOp.SUM)
_print_task_result(tidx, correct.mean().item())
print0(f" (all tasks: {elapsed_total:.2f}s)")
else:
t0 = time.time()
executor = ThreadPoolExecutor(max_workers=1)
first_uncached = next(i for i, (l, _, _) in enumerate(task_inputs) if l not in _batch_cache)
_, first_meta, first_data = task_inputs[first_uncached]
next_future = executor.submit(prepare_task_data, tokenizer, first_data, first_meta, max_seq_len)
for i, (label, task_meta, data) in enumerate(task_inputs):
pbar.set_description(f"{label:<{w_label}}")
if label in _batch_cache:
accuracy, collated = evaluate_task(model, data, device, collated=_batch_cache[label], pbar=pbar)
else:
prepared = next_future.result()
# Kick off prepare for the next uncached task
for j in range(i + 1, len(task_inputs)):
next_label, next_meta, next_data = task_inputs[j]
if next_label not in _batch_cache:
next_future = executor.submit(prepare_task_data, tokenizer, next_data, next_meta, max_seq_len)
break
accuracy, collated = evaluate_task(model, data, device, prepared=prepared, pbar=pbar)
_batch_cache[label] = collated
_print_task_result(i, accuracy)
elapsed_total = time.time() - t0
pbar.close()
executor.shutdown(wait=False)
print0(f" (all tasks: {elapsed_total:.2f}s)")
# Save collated batches to disk after first run (so bench/future runs skip prepare+collate)
if first_run and disk_cache_dir is not None:
pad_id = tokenizer.get_bos_token_id()
os.makedirs(disk_cache_dir, exist_ok=True)
for label, _, _ in task_inputs:
if label in _batch_cache:
torch.save({'collated': _batch_cache[label], 'pad_token_id': pad_id},
os.path.join(disk_cache_dir, f"{label}.pt"))
print0(f" (saved collated batches to {disk_cache_dir})")
core_metric = sum(centered_results.values()) / len(centered_results)
if _prev_core is not None:
d = core_metric - _prev_core
arrow = "\u2191" if d > 0 else "\u2193" if d < 0 else "="
print0(f"CORE: {core_metric:.4f} {arrow}{d:+.4f}")
else:
print0(f"CORE: {core_metric:.4f}")
_prev_centered = dict(centered_results)
_prev_core = core_metric
out = {
"results": results,
"centered_results": centered_results,
@ -288,7 +434,7 @@ def main():
print0("CORE Evaluation")
print0("="*80)
with autocast_ctx:
core_results = evaluate_core(model, tokenizer, device, max_per_task=args.max_per_task)
core_results = evaluate_model(model, tokenizer, device, max_per_task=args.max_per_task)
# Write CSV output
if ddp_rank == 0:

View File

@ -32,7 +32,7 @@ from nanochat.checkpoint_manager import save_checkpoint, load_checkpoint
from nanochat.loss_eval import evaluate_bpb
from nanochat.engine import Engine
from nanochat.flash_attention import HAS_FA3
from scripts.base_eval import evaluate_core
from scripts.base_eval import evaluate_model
print_banner()
# -----------------------------------------------------------------------------
@ -425,7 +425,7 @@ while True:
if args.core_metric_every > 0 and (last_step or (step > 0 and step % args.core_metric_every == 0)):
model.eval()
with disable_fp8(orig_model), autocast_ctx:
results = evaluate_core(orig_model, tokenizer, device, max_per_task=args.core_metric_max_per_task)
results = evaluate_model(orig_model, tokenizer, device, max_per_task=args.core_metric_max_per_task)
print0(f"Step {step:05d} | CORE metric: {results['core_metric']:.4f}")
wandb_run.log({
"step": step,

525
scripts/bench_core_eval.py Normal file
View File

@ -0,0 +1,525 @@
"""
Benchmark the CORE evaluation pipeline.
Compares three modes:
1. Old sequential (per-example) evaluation
2. New batched evaluation (first run includes prepare + collate + forward)
3. New batched evaluation (cached run forward only)
Also sweeps batch_size and queue_size to find optimal hyperparameters.
Usage:
python -m scripts.bench_core_eval
python -m scripts.bench_core_eval --max-per-task 100 # quick test
python -m scripts.bench_core_eval --hf-path openai-community/gpt2
"""
import os
import csv
import json
import time
import yaml
import random
import shutil
import hashlib
import zipfile
import tempfile
import argparse
from contextlib import nullcontext
from tqdm import tqdm
import torch
import torch.distributed as dist
from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, autodetect_device_type, download_file_with_lock
from nanochat.tokenizer import HuggingFaceTokenizer
from nanochat.checkpoint_manager import load_model
from nanochat.core_eval import (
forward_model, prepare_example, check_result, stack_sequences,
prepare_task_data, _collate_batches, _forward_all_cached,
evaluate_task,
render_prompts_mc, render_prompts_schema, render_prompts_lm,
batch_sequences_mc, batch_sequences_schema, batch_sequences_lm,
)
# ---- eval bundle loading (reused from base_eval) ----
torch.backends.fp32_precision = "tf32"
EVAL_BUNDLE_URL = "https://karpathy-public.s3.us-west-2.amazonaws.com/eval_bundle.zip"
def place_eval_bundle(file_path):
    """Extract the downloaded eval bundle zip into <base_dir>/eval_bundle."""
    base_dir = get_base_dir()
    eval_bundle_dir = os.path.join(base_dir, "eval_bundle")
    # Extract into a temp dir first, then move into place, so a partially
    # extracted bundle never appears at the final path.
    with tempfile.TemporaryDirectory() as tmpdir:
        with zipfile.ZipFile(file_path, 'r') as zip_ref:
            zip_ref.extractall(tmpdir)
        shutil.move(os.path.join(tmpdir, "eval_bundle"), eval_bundle_dir)
    print0(f"Placed eval_bundle at {eval_bundle_dir}")
def load_tasks(max_per_task=-1):
    """Load the CORE task suite from the eval bundle.

    Downloads the bundle on first use. Returns a list of
    (label, task_meta, data) tuples, where data is a list of JSON rows.
    max_per_task > 0 truncates each task's data after a deterministic shuffle.
    """
    bundle_dir = os.path.join(get_base_dir(), "eval_bundle")
    if not os.path.exists(bundle_dir):
        download_file_with_lock(EVAL_BUNDLE_URL, "eval_bundle.zip", postprocess_fn=place_eval_bundle)
    with open(os.path.join(bundle_dir, "core.yaml"), 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)
    data_root = os.path.join(bundle_dir, "eval_data")
    tasks = []
    for task in config['icl_tasks']:
        meta = {
            'task_type': task['icl_task_type'],
            'dataset_uri': task['dataset_uri'],
            'num_fewshot': task['num_fewshot'][0],
            'continuation_delimiter': task.get('continuation_delimiter', ' '),
        }
        with open(os.path.join(data_root, meta['dataset_uri']), 'r', encoding='utf-8') as f:
            rows = [json.loads(line.strip()) for line in f]
        # Fixed-seed shuffle so truncation selects the same subset every run.
        random.Random(1337).shuffle(rows)
        if max_per_task > 0:
            rows = rows[:max_per_task]
        tasks.append((task['label'], meta, rows))
    return tasks
# Batch size of the on-disk collated cache; the cached-run sweep composes
# other batch sizes from these base batches by merging/splitting on GPU.
BASE_BATCH_SIZE = 4
def file_checksum(path):
    """Return the SHA-256 of the file at *path*, truncated to 16 hex chars.

    Reads in 8 KiB chunks so arbitrarily large files never load fully into memory.
    """
    digest = hashlib.sha256()
    with open(path, "rb") as fh:
        while chunk := fh.read(8192):
            digest.update(chunk)
    return digest.hexdigest()[:16]
def collate_prepared(prepared, batch_size):
    """Collate prepared examples into padded batches (non-threaded).

    prepared is a list of (example_idx, prep_dict) pairs; each prep_dict
    carries 'input_ids' of shape (num_options, seq_len) plus scoring metadata.
    Returns (collated, pad_token_id) where collated is a list of
    (combined_ids, batch_meta) tuples, combined_ids right-padded to the
    longest sequence in its batch.
    """
    pad_token_id = prepared[0][1]['pad_token_id']
    batches = []
    for start in range(0, len(prepared), batch_size):
        chunk = prepared[start:start + batch_size]
        preps = [prep for _, prep in chunk]
        width = max(prep['seq_len'] for prep in preps)
        n_rows = sum(prep['num_options'] for prep in preps)
        combined = torch.full((n_rows, width), pad_token_id, dtype=torch.long)
        meta = []
        row = 0
        for example_idx, prep in chunk:
            n_opts, seq_len = prep['num_options'], prep['seq_len']
            combined[row:row + n_opts, :seq_len] = prep['input_ids']
            meta.append((example_idx, n_opts, prep['start_idxs'], prep['end_idxs'],
                         prep['gold'], prep['task_type']))
            row += n_opts
        batches.append((combined, meta))
    return batches, pad_token_id
def build_or_load_base_collated(tok_hash, tokenizer, task_inputs, max_seq_len, max_per_task):
    """Build or load base-batch collated data for all tasks, with disk caching.

    Cache key is (tokenizer hash, max_per_task) so a different tokenizer or
    subset size never reuses stale tensors. Returns {label: (collated, pad_token_id)}.
    """
    cache_dir = os.path.join(get_base_dir(), "core_token_cache", f"{tok_hash}_n{max_per_task}")

    def cache_path(label):
        # One .pt file per task label inside the cache directory.
        return os.path.join(cache_dir, f"{label}.pt")

    have_all = os.path.isdir(cache_dir) and all(
        os.path.exists(cache_path(label)) for label, _, _ in task_inputs
    )
    base_cache = {}  # label -> (collated, pad_token_id)
    if have_all:
        print0(f"Loading base-{BASE_BATCH_SIZE} collated cache from {cache_dir}")
        for label, _, _ in tqdm(task_inputs, desc="Loading cache", leave=False):
            payload = torch.load(cache_path(label), weights_only=False)
            base_cache[label] = (payload['collated'], payload['pad_token_id'])
    else:
        print0(f"Building base-{BASE_BATCH_SIZE} collated cache (saving to {cache_dir})")
        os.makedirs(cache_dir, exist_ok=True)
        for label, task_meta, data in tqdm(task_inputs, desc="Preparing tasks", leave=False):
            prepared = prepare_task_data(tokenizer, data, task_meta, max_seq_len)
            collated, pad_id = collate_prepared(prepared, BASE_BATCH_SIZE)
            base_cache[label] = (collated, pad_id)
            torch.save({'collated': collated, 'pad_token_id': pad_id}, cache_path(label))
            # Free the prepared tensors eagerly; only the collated form is kept.
            del prepared
        print0(f"Saved {len(base_cache)} task caches to {cache_dir}")
    return base_cache
# ---- old sequential evaluation (baseline) ----
@torch.no_grad()
def evaluate_example_old(idx, model, tokenizer, data, device, task_meta):
    """Original per-example sequential evaluation (the old code).

    Kept verbatim so the benchmark baseline measures exactly what the old
    pipeline did. Returns the truthy/falsy correctness verdict from check_result.
    """
    item = data[idx]
    task_type = task_meta['task_type']
    num_fewshot = task_meta['num_fewshot']
    continuation_delimiter = task_meta['continuation_delimiter']
    fewshot_examples = []
    if num_fewshot > 0:
        # Per-example seeded RNG makes few-shot selection reproducible,
        # and excluding idx avoids leaking the example into its own context.
        rng = random.Random(1234 + idx)
        available_indices = [i for i in range(len(data)) if i != idx]
        fewshot_indices = rng.sample(available_indices, num_fewshot)
        fewshot_examples = [data[i] for i in fewshot_indices]
    # Render + tokenize per task family (multiple choice / schema / LM).
    if task_type == 'multiple_choice':
        prompts = render_prompts_mc(item, continuation_delimiter, fewshot_examples)
        tokens, start_idxs, end_idxs = batch_sequences_mc(tokenizer, prompts)
    elif task_type == 'schema':
        prompts = render_prompts_schema(item, continuation_delimiter, fewshot_examples)
        tokens, start_idxs, end_idxs = batch_sequences_schema(tokenizer, prompts)
    elif task_type == 'language_modeling':
        prompts = render_prompts_lm(item, continuation_delimiter, fewshot_examples)
        tokens, start_idxs, end_idxs = batch_sequences_lm(tokenizer, prompts)
    else:
        raise ValueError(f"Unsupported task type: {task_type}")
    # Crop sequences that exceed the model's context window: keep the tail
    # (where the scored continuation lives) and shift the scored index range left.
    if hasattr(model, 'max_seq_len') and model.max_seq_len is not None:
        max_tokens = model.max_seq_len
        new_tokens, new_start_idxs, new_end_idxs = [], [], []
        for t, s, e in zip(tokens, start_idxs, end_idxs):
            if len(t) > max_tokens:
                num_to_crop = len(t) - max_tokens
                new_tokens.append(t[-max_tokens:])
                new_start_idxs.append(s - num_to_crop)
                new_end_idxs.append(e - num_to_crop)
            else:
                new_tokens.append(t)
                new_start_idxs.append(s)
                new_end_idxs.append(e)
        tokens, start_idxs, end_idxs = new_tokens, new_start_idxs, new_end_idxs
    # The BOS token id doubles as the padding id when stacking sequences.
    pad_token_id = tokenizer.get_bos_token_id()
    input_ids = stack_sequences(tokens, pad_token_id).to(device)
    losses, predictions = forward_model(model, input_ids)
    return check_result(losses, predictions, input_ids, start_idxs, end_idxs, item.get('gold', None), task_type)
def evaluate_task_old(model, tokenizer, data, device, task_meta, pbar=None):
    """Original sequential evaluate_task.

    Kept verbatim as the benchmark baseline. Each rank evaluates a strided
    slice of the data (idx = rank, rank + W, ...); returns mean accuracy.
    """
    rank = dist.get_rank() if dist.is_initialized() else 0
    world_size = dist.get_world_size() if dist.is_initialized() else 1
    # One correctness slot per example; ranks fill disjoint strided slices.
    correct = torch.zeros(len(data), dtype=torch.float32, device=device)
    for idx in range(rank, len(data), world_size):
        is_correct = evaluate_example_old(idx, model, tokenizer, data, device, task_meta)
        correct[idx] = float(is_correct)
        if pbar is not None:
            pbar.update(1)
    if world_size > 1:
        # Sum across ranks so every rank ends up with the full vector.
        dist.barrier()
        dist.all_reduce(correct, op=dist.ReduceOp.SUM)
    return correct.mean().item()
# ---- HuggingFace model wrapper ----
class ModelWrapper:
    """Adapt a HuggingFace-style model to the bare `logits = model(ids)`
    calling convention used by the eval code.

    max_seq_len optionally records a context-length cap for callers that
    need to crop long sequences.
    """

    def __init__(self, model, max_seq_len=None):
        self.model = model
        self.max_seq_len = max_seq_len

    def __call__(self, input_ids):
        # HF models return an output object; unwrap to the raw logits tensor.
        output = self.model(input_ids)
        return output.logits
def load_hf_model(hf_path, device):
    """Load a HuggingFace causal LM in eval mode, wrapped for the eval loop.

    Returns (ModelWrapper, tokenizer). GPT-2 checkpoints get a 1024-token
    context cap; other models are left uncapped.
    """
    from transformers import AutoModelForCausalLM
    hf_model = AutoModelForCausalLM.from_pretrained(hf_path)
    hf_model.to(device)
    hf_model.eval()
    context_cap = 1024 if "gpt2" in hf_path else None
    wrapped = ModelWrapper(hf_model, max_seq_len=context_cap)
    return wrapped, HuggingFaceTokenizer.from_pretrained(hf_path)
# ---- benchmark helpers ----
def sync_cuda():
    """Wait for all queued CUDA work to finish so wall-clock timings are honest.

    No-op when CUDA is unavailable.
    """
    if not torch.cuda.is_available():
        return
    torch.cuda.synchronize()
def bench_old(model, tokenizer, task_inputs, device):
    """Time the old sequential evaluation across all tasks.

    Returns (elapsed_seconds, {label: accuracy}).
    """
    sync_cuda()
    start = time.time()
    accuracies = {}
    n_examples = sum(len(data) for _, _, data in task_inputs)
    label_width = max(len(label) for label, _, _ in task_inputs)
    progress = tqdm(total=n_examples, desc="Sequential", leave=False)
    for label, task_meta, data in task_inputs:
        progress.set_description(f"Sequential: {label:<{label_width}}")
        accuracies[label] = evaluate_task_old(model, tokenizer, data, device, task_meta, pbar=progress)
    progress.close()
    sync_cuda()
    return time.time() - start, accuracies
def bench_new_first(model, tokenizer, task_inputs, device, batch_size, queue_size, pbar=None):
    """Time the new batched evaluation end-to-end (prepare + collate + forward).

    Returns (elapsed_seconds, {label: accuracy}, {label: collated_batches});
    the collated batches can be reused for cached (forward-only) runs.
    """
    sync_cuda()
    start = time.time()
    accuracies = {}
    collated_by_label = {}
    seq_cap = getattr(model, 'max_seq_len', None)
    label_width = max(len(label) for label, _, _ in task_inputs)
    for label, task_meta, data in task_inputs:
        if pbar is not None:
            pbar.set_description(f"{label:<{label_width}}")
        prepared = prepare_task_data(tokenizer, data, task_meta, seq_cap)
        acc, collated = evaluate_task(model, data, device, batch_size=batch_size,
                                      queue_size=queue_size, prepared=prepared, pbar=pbar)
        accuracies[label] = acc
        collated_by_label[label] = collated
    sync_cuda()
    return time.time() - start, accuracies, collated_by_label
def bench_new_cached(model, task_inputs, device, collated_cache, pbar=None,
                     batched=False, merge=1, split=1, pad_token_id=0):
    """Benchmark new batched evaluation (cached run, forward only).

    Uses a continuous pipeline across all tasks to eliminate inter-task stalls.
    batched=False (default): per-example forwarding, identical to sequential.
    batched=True: GPU composition with merge/split for speed experiments.

    Returns (elapsed_seconds, {label: accuracy}).
    """
    # FIX: `dist` is already imported at module level; the original shadowed it
    # with a redundant local `import torch.distributed as dist`.
    world_size = dist.get_world_size() if dist.is_initialized() else 1
    sync_cuda()
    t0 = time.time()
    task_collated = [(collated_cache[label], data) for label, _, data in task_inputs]
    correct_list = _forward_all_cached(model, task_collated, device, pbar=pbar,
                                       batched=batched, merge=merge, split=split,
                                       pad_token_id=pad_token_id)
    sync_cuda()
    elapsed = time.time() - t0
    # Cross-rank reduction sits outside the timed window: timing stops before
    # the collectives, so `elapsed` measures only the forward pipeline.
    results = {}
    for (label, _, data), correct in zip(task_inputs, correct_list):
        if world_size > 1:
            dist.barrier()
            dist.all_reduce(correct, op=dist.ReduceOp.SUM)
        results[label] = correct.mean().item()
    return elapsed, results
def verify_results(old_results, new_results, label="new"):
    """Check that old and new evaluation modes produced the same accuracies.

    Prints any task whose accuracy differs by more than 1e-6; otherwise
    confirms the match. Purely informational — does not raise.
    """
    mismatches = [
        (task, old_results[task], new_results[task])
        for task in old_results
        if task in new_results and abs(old_results[task] - new_results[task]) > 1e-6
    ]
    if not mismatches:
        print0(f" {label} results match old (verified)")
        return
    print0(f" WARNING: {label} mismatches vs old:")
    for task, old, new in mismatches:
        print0(f" {task}: old={old:.6f} {label}={new:.6f}")
# ---- main ----
def main():
    """Run the CORE eval benchmark.

    Phases: (1) old sequential baseline, (2) first-run batch/queue sweep,
    (3) build/load the base-batch collated cache, (4) cached forward-only
    sweep, then a summary with speedups. Accuracy of every mode is verified
    against the baseline when available.
    """
    parser = argparse.ArgumentParser(description="Benchmark CORE eval pipeline")
    parser.add_argument('--hf-path', type=str, default=None, help='HuggingFace model path')
    parser.add_argument('--model-tag', type=str, default=None, help='nanochat model tag')
    parser.add_argument('--step', type=int, default=None, help='Model step to load')
    parser.add_argument('--max-per-task', type=int, default=500, help='Max examples per task')
    parser.add_argument('--device-type', type=str, default='', help='cuda|cpu|mps')
    parser.add_argument('--batch-sizes', type=str, default='1,2,4,8,16,32,64', help='Comma-separated batch sizes to sweep')
    parser.add_argument('--queue-sizes', type=str, default='2,4,8,16', help='Comma-separated queue sizes to sweep')
    parser.add_argument('--skip-old', action='store_true', help='Skip old sequential baseline (slow)')
    parser.add_argument('--skip-first', action='store_true', help='Skip first-run sweep (requires cached collated data)')
    args = parser.parse_args()
    batch_sizes = [int(x) for x in args.batch_sizes.split(',')]
    queue_sizes = [int(x) for x in args.queue_sizes.split(',')]
    device_type = autodetect_device_type() if args.device_type == '' else args.device_type
    ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
    # bfloat16 autocast only on CUDA; elsewhere run at full precision.
    autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
    # Load model
    if args.hf_path is not None:
        model, tokenizer = load_hf_model(args.hf_path, device)
        model_name = args.hf_path
    else:
        model, tokenizer, meta = load_model("base", device, phase="eval", model_tag=args.model_tag, step=args.step)
        model_name = f"base_model (step {meta['step']})"
    print0(f"Model: {model_name}")
    print0(f"Max per task: {args.max_per_task}")
    print0(f"Device: {device}")
    print0("")
    # Load tasks
    task_inputs = load_tasks(max_per_task=args.max_per_task)
    total_examples = sum(len(data) for _, _, data in task_inputs)
    print0(f"Loaded {len(task_inputs)} tasks, {total_examples} total examples")
    print0("")
    # ---- 1. Old sequential baseline ----
    old_results = None
    if not args.skip_old:
        print0("=" * 80)
        print0("OLD: Sequential per-example evaluation")
        print0("=" * 80)
        with autocast_ctx:
            old_time, old_results = bench_old(model, tokenizer, task_inputs, device)
        print0(f" Time: {old_time:.2f}s ({total_examples / old_time:.1f} examples/s)")
        print0("")
    # ---- 2. Sweep batch_size x queue_size for first run ----
    # Compute tok_hash early — needed for both skip-first check and cache loading
    max_seq_len = getattr(model, 'max_seq_len', None)
    if args.hf_path is not None:
        tok_hash = hashlib.sha256(args.hf_path.encode()).hexdigest()[:16]
    else:
        base_dir = get_base_dir()
        for fname in ("tokenizer.pkl", "tokenizer.json"):
            tok_path = os.path.join(base_dir, "tokenizer", fname)
            if os.path.exists(tok_path):
                tok_hash = file_checksum(tok_path)
                break
        else:
            # FIX: previously tok_hash was left unbound when no tokenizer file
            # existed, causing a confusing NameError later; fail fast instead.
            raise FileNotFoundError(
                f"No tokenizer file found under {os.path.join(base_dir, 'tokenizer')}"
            )
    # Check if we can skip the first-run sweep
    best_time = None
    best_params = None
    if args.skip_first:
        cache_dir = os.path.join(get_base_dir(), "core_token_cache", f"{tok_hash}_n{args.max_per_task}")
        has_cache = os.path.isdir(cache_dir) and all(
            os.path.exists(os.path.join(cache_dir, f"{label}.pt"))
            for label, _, _ in task_inputs
        )
        if has_cache:
            print0("Skipping first-run sweep (--skip-first, cache exists)")
            print0("")
        else:
            # No cache yet: run one eval at the base batch size just to build it.
            print0(f"--skip-first: cache missing, running single bs={BASE_BATCH_SIZE} eval to create it...")
            pbar = tqdm(total=total_examples, desc="Creating cache", leave=False)
            with autocast_ctx:
                _, _, collated_cache = bench_new_first(model, tokenizer, task_inputs, device,
                                                       batch_size=BASE_BATCH_SIZE, queue_size=2, pbar=pbar)
            pbar.close()
            pad_id = tokenizer.get_bos_token_id()
            os.makedirs(cache_dir, exist_ok=True)
            for label in collated_cache:
                torch.save({'collated': collated_cache[label], 'pad_token_id': pad_id},
                           os.path.join(cache_dir, f"{label}.pt"))
            print0(f"Cache created ({len(collated_cache)} tasks)")
            print0("")
    if not args.skip_first:
        print0("=" * 80)
        print0("NEW: Batched evaluation — hyperparameter sweep (first run)")
        print0("=" * 80)
        print0("")
        # Header
        qs_header = "".join(f"{'q=' + str(q):>10}" for q in queue_sizes)
        print0(f" {'batch_size':>10}{qs_header}")
        print0(f" {'':>10}" + "-" * (10 * len(queue_sizes)))
        best_time = float('inf')
        best_params = None
        best_collated = None
        sweep_results = {}
        total_combos = len(batch_sizes) * len(queue_sizes)
        outer_pbar = tqdm(total=total_combos, desc="First-run sweep", leave=False, position=0)
        inner_pbar = tqdm(total=total_examples, desc="", leave=False, position=1)
        for bs in batch_sizes:
            row = f" {bs:>10}"
            for qs in queue_sizes:
                outer_pbar.set_description(f"First-run: bs={bs} qs={qs}")
                inner_pbar.reset()
                with autocast_ctx:
                    t, results, collated_cache = bench_new_first(model, tokenizer, task_inputs, device, bs, qs, pbar=inner_pbar)
                sweep_results[(bs, qs)] = t
                row += f"{t:>9.2f}s"
                if t < best_time:
                    best_time = t
                    best_params = (bs, qs)
                    best_collated = collated_cache
                    best_first_results = results
                outer_pbar.update(1)
            outer_pbar.write(row)
        inner_pbar.close()
        outer_pbar.close()
        print0("")
        print0(f" Best: batch_size={best_params[0]}, queue_size={best_params[1]} -> {best_time:.2f}s ({total_examples / best_time:.1f} examples/s)")
        # Verify correctness
        if old_results is not None:
            verify_results(old_results, best_first_results, label="new-first")
        print0("")
    # ---- 3. Build/load base-4 collated cache ----
    base_cache = build_or_load_base_collated(tok_hash, tokenizer, task_inputs, max_seq_len, args.max_per_task)
    # ---- 4. Cached run sweep (forward only, composed from base-4) ----
    print0("=" * 80)
    print0(f"NEW: Cached run (forward only, composed from base-{BASE_BATCH_SIZE})")
    print0("=" * 80)
    print0("")
    best_cached_time = float('inf')
    best_cached_params = None
    pad_id = next(iter(base_cache.values()))[1]
    # Preload ALL base-4 batches to GPU once (~144MB for full CORE eval).
    # All batch-size sweeps compose from these GPU-resident tensors — zero CPU→GPU transfers.
    gpu_collated = {}
    for label, (collated, _) in base_cache.items():
        gpu_collated[label] = [(ids.to(device), meta) for ids, meta in collated]
    outer_pbar = tqdm(total=len(batch_sizes), desc="Cached sweep", leave=False, position=0)
    inner_pbar = tqdm(total=total_examples, desc="", leave=False, position=1)
    for bs in batch_sizes:
        outer_pbar.set_description(f"Cached: bs={bs}")
        inner_pbar.reset()
        # All composition happens on GPU: merge for bs >= base, split for bs < base
        if bs >= BASE_BATCH_SIZE:
            merge, split = bs // BASE_BATCH_SIZE, 1
        else:
            merge, split = 1, BASE_BATCH_SIZE // bs
        with autocast_ctx:
            t, cached_results = bench_new_cached(model, task_inputs, device, gpu_collated,
                                                 pbar=inner_pbar, batched=True,
                                                 merge=merge, split=split, pad_token_id=pad_id)
        outer_pbar.write(f" batch_size={bs:>3}: {t:.2f}s ({total_examples / t:.1f} examples/s)")
        outer_pbar.update(1)
        if t < best_cached_time:
            best_cached_time = t
            best_cached_params = bs
            best_cached_results = cached_results
    inner_pbar.close()
    outer_pbar.close()
    print0("")
    print0(f" Best: batch_size={best_cached_params} -> {best_cached_time:.2f}s ({total_examples / best_cached_time:.1f} examples/s)")
    # Verify with per-example forwarding (identical to sequential — must match old)
    inner_pbar = tqdm(total=total_examples, desc="Verifying", leave=False)
    with autocast_ctx:
        _, exact_results = bench_new_cached(model, task_inputs, device, gpu_collated, pbar=inner_pbar)
    inner_pbar.close()
    ref_results = old_results or (best_first_results if best_time is not None else None)
    if ref_results is not None:
        verify_results(ref_results, exact_results, label="new-cached(per-example)")
    print0("")
    # ---- Summary ----
    print0("=" * 80)
    print0("SUMMARY")
    print0("=" * 80)
    if old_results is not None:
        print0(f" Old (sequential): {old_time:>8.2f}s")
    if best_time is not None:
        print0(f" New (first run): {best_time:>8.2f}s batch_size={best_params[0]}, queue_size={best_params[1]}")
    print0(f" New (cached): {best_cached_time:>8.2f}s batch_size={best_cached_params}")
    if old_results is not None:
        if best_time is not None:
            print0(f" Speedup (first): {old_time / best_time:>8.2f}x")
        print0(f" Speedup (cached): {old_time / best_cached_time:>8.2f}x")
    compute_cleanup()
# Standard script entry-point guard: run the benchmark only when executed directly.
if __name__ == "__main__":
    main()