mirror of
https://github.com/karpathy/nanochat.git
synced 2026-04-03 14:15:26 +00:00
Merge 4f79e750e7 into 83dccc20ae
This commit is contained in:
commit
a168cb52fd
|
|
@ -134,14 +134,21 @@ def batch_sequences_lm(tokenizer, prompts):
|
|||
# In LM tasks, we have two prompts: without and with continuation
|
||||
tokens = tokenizer(prompts, prepend=tokenizer.get_bos_token_id())
|
||||
tokens_without, tokens_with = tokens
|
||||
start_idx, end_idx = len(tokens_without), len(tokens_with)
|
||||
assert start_idx < end_idx, "prompt without is supposed to be a prefix of prompt with"
|
||||
assert tokens_without == tokens_with[:start_idx], "prompt without is supposed to be a prefix of prompt with"
|
||||
end_idx = len(tokens_with)
|
||||
# Find longest common prefix — greedy trie tokenizers are not always
|
||||
# prefix-stable, so we can't assume an exact prefix match.
|
||||
start_idx = 0
|
||||
for i in range(min(len(tokens_without), len(tokens_with))):
|
||||
if tokens_without[i] != tokens_with[i]:
|
||||
break
|
||||
start_idx = i + 1
|
||||
assert start_idx < end_idx, "continuation must produce additional tokens"
|
||||
# we only need the with continuation prompt in the LM task, i.e. batch size of 1
|
||||
return [tokens_with], [start_idx], [end_idx]
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@torch.compiler.disable
|
||||
def forward_model(model, input_ids):
|
||||
"""
|
||||
Take BxT tensor of token ids, return BxT tensor of losses and argmax predictions.
|
||||
|
|
@ -164,9 +171,8 @@ def forward_model(model, input_ids):
|
|||
return losses, predictions
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def evaluate_example(idx, model, tokenizer, data, device, task_meta):
|
||||
"""Evaluate a single example, return True if correct, False otherwise"""
|
||||
def prepare_example(idx, tokenizer, data, task_meta, max_seq_len=None):
|
||||
"""CPU-only: render prompts, tokenize, stack into tensors. Returns a dict."""
|
||||
item = data[idx]
|
||||
task_type = task_meta['task_type']
|
||||
num_fewshot = task_meta['num_fewshot']
|
||||
|
|
@ -193,70 +199,401 @@ def evaluate_example(idx, model, tokenizer, data, device, task_meta):
|
|||
else:
|
||||
raise ValueError(f"Unsupported task type: {task_type}")
|
||||
|
||||
# Some models can't forward sequences beyond a certain length (e.g. GPT-2)
|
||||
# In these cases, we have to truncate sequences to max length and adjust the indices
|
||||
if hasattr(model, 'max_seq_len') and model.max_seq_len is not None:
|
||||
max_tokens = model.max_seq_len
|
||||
# Truncate sequences for models with a max length (e.g. GPT-2)
|
||||
if max_seq_len is not None:
|
||||
new_tokens, new_start_idxs, new_end_idxs = [], [], []
|
||||
for t, s, e in zip(tokens, start_idxs, end_idxs):
|
||||
if len(t) > max_tokens:
|
||||
num_to_crop = len(t) - max_tokens
|
||||
new_tokens.append(t[-max_tokens:]) # take the last max_tokens tokens
|
||||
new_start_idxs.append(s - num_to_crop) # shift the indices down
|
||||
if len(t) > max_seq_len:
|
||||
num_to_crop = len(t) - max_seq_len
|
||||
new_tokens.append(t[-max_seq_len:])
|
||||
new_start_idxs.append(s - num_to_crop)
|
||||
new_end_idxs.append(e - num_to_crop)
|
||||
assert s - num_to_crop >= 0, "this should never happen right?"
|
||||
assert e - num_to_crop >= 0, "this should never happen right?"
|
||||
else:
|
||||
new_tokens.append(t) # keep unchanged
|
||||
new_tokens.append(t)
|
||||
new_start_idxs.append(s)
|
||||
new_end_idxs.append(e)
|
||||
tokens, start_idxs, end_idxs = new_tokens, new_start_idxs, new_end_idxs
|
||||
|
||||
# Stack up all the sequences into a batch
|
||||
pad_token_id = tokenizer.get_bos_token_id() # use BOS as pad token is ok
|
||||
input_ids = stack_sequences(tokens, pad_token_id)
|
||||
input_ids = input_ids.to(device)
|
||||
pad_token_id = tokenizer.get_bos_token_id()
|
||||
input_ids = stack_sequences(tokens, pad_token_id) # (num_options, seq_len)
|
||||
|
||||
# Forward the model, get the autoregressive loss and argmax prediction at each token
|
||||
losses, predictions = forward_model(model, input_ids)
|
||||
return {
|
||||
'input_ids': input_ids,
|
||||
'start_idxs': start_idxs,
|
||||
'end_idxs': end_idxs,
|
||||
'gold': item.get('gold', None),
|
||||
'task_type': task_type,
|
||||
'num_options': input_ids.size(0),
|
||||
'seq_len': input_ids.size(1),
|
||||
'pad_token_id': pad_token_id,
|
||||
}
|
||||
|
||||
# See if the losses/predictions come out correctly
|
||||
|
||||
def check_result(losses, predictions, input_ids, start_idxs, end_idxs, gold, task_type):
|
||||
"""Analyze forward pass outputs for one example, return True if correct."""
|
||||
if task_type == 'language_modeling':
|
||||
# language modeling task is currently always batch size 1
|
||||
si = start_idxs[0]
|
||||
ei = end_idxs[0]
|
||||
# predictions[i] predict input_ids[i+1] autoregressively
|
||||
si, ei = start_idxs[0], end_idxs[0]
|
||||
predicted_tokens = predictions[0, si-1:ei-1]
|
||||
actual_tokens = input_ids[0, si:ei]
|
||||
is_correct = torch.all(predicted_tokens == actual_tokens).item()
|
||||
return torch.all(predicted_tokens == actual_tokens).item()
|
||||
elif task_type in ['multiple_choice', 'schema']:
|
||||
# For MC/schema: find the option with lowest average loss
|
||||
mean_losses = [losses[i, si-1:ei-1].mean().item()
|
||||
for i, (si, ei) in enumerate(zip(start_idxs, end_idxs))]
|
||||
pred_idx = mean_losses.index(min(mean_losses))
|
||||
is_correct = pred_idx == item['gold']
|
||||
for i, (si, ei) in enumerate(zip(start_idxs, end_idxs))]
|
||||
return mean_losses.index(min(mean_losses)) == gold
|
||||
else:
|
||||
raise ValueError(f"Unsupported task type: {task_type}")
|
||||
|
||||
return is_correct
|
||||
|
||||
def _collate_batches(prepared, batch_size, queue):
|
||||
"""Background thread: collate batches on CPU and push to queue."""
|
||||
for batch_start in range(0, len(prepared), batch_size):
|
||||
batch = prepared[batch_start:batch_start + batch_size]
|
||||
batch_preps = [p for _, p in batch]
|
||||
max_len = max(p['seq_len'] for p in batch_preps)
|
||||
total_rows = sum(p['num_options'] for p in batch_preps)
|
||||
pad_id = batch_preps[0]['pad_token_id']
|
||||
|
||||
combined_ids = torch.full((total_rows, max_len), pad_id, dtype=torch.long)
|
||||
|
||||
batch_meta = []
|
||||
offset = 0
|
||||
for idx, p in batch:
|
||||
n, sl = p['num_options'], p['seq_len']
|
||||
combined_ids[offset:offset+n, :sl] = p['input_ids']
|
||||
batch_meta.append((idx, n, p['start_idxs'], p['end_idxs'], p['gold'], p['task_type']))
|
||||
offset += n
|
||||
|
||||
queue.put((combined_ids, batch_meta))
|
||||
queue.put(None) # sentinel
|
||||
|
||||
|
||||
def evaluate_task(model, tokenizer, data, device, task_meta):
|
||||
def prepare_task_data(tokenizer, data, task_meta, max_seq_len=None):
|
||||
"""CPU-only: prepare and sort all examples for a task. Can run on a background thread."""
|
||||
rank = dist.get_rank() if dist.is_initialized() else 0
|
||||
world_size = dist.get_world_size() if dist.is_initialized() else 1
|
||||
indices = list(range(rank, len(data), world_size))
|
||||
prepared = [(idx, prepare_example(idx, tokenizer, data, task_meta, max_seq_len)) for idx in indices]
|
||||
prepared.sort(key=lambda x: x[1]['seq_len'])
|
||||
return prepared
|
||||
|
||||
|
||||
def _prefetch_to_device(tensor, device):
|
||||
"""Pin and async-transfer a CPU tensor to GPU, overlapping with current GPU work."""
|
||||
return tensor.pin_memory().to(device, non_blocking=True)
|
||||
|
||||
|
||||
def _forward_batches(model, collated, data, device, pbar=None):
|
||||
"""Run GPU forward passes on pre-collated batches, return per-example correctness tensor.
|
||||
|
||||
Uses double-buffered prefetching on CUDA: while the GPU processes batch N,
|
||||
batch N+1 is pinned and DMA-transferred asynchronously, keeping the GPU fed.
|
||||
"""
|
||||
This function is responsible for evaluating one task across many examples.
|
||||
It also handles dispatch to all processes if the script is run with torchrun.
|
||||
correct = torch.zeros(len(data), dtype=torch.float32, device=device)
|
||||
if not collated:
|
||||
return correct
|
||||
|
||||
use_prefetch = torch.cuda.is_available() and 'cuda' in str(device)
|
||||
|
||||
# Prefetch first batch
|
||||
if use_prefetch:
|
||||
next_ids = _prefetch_to_device(collated[0][0], device)
|
||||
else:
|
||||
next_ids = collated[0][0].to(device)
|
||||
|
||||
for i, (_, batch_meta) in enumerate(collated):
|
||||
combined_ids = next_ids
|
||||
# Start async transfer of next batch while GPU computes on current
|
||||
if i + 1 < len(collated):
|
||||
if use_prefetch:
|
||||
next_ids = _prefetch_to_device(collated[i + 1][0], device)
|
||||
else:
|
||||
next_ids = collated[i + 1][0].to(device)
|
||||
|
||||
losses, predictions = forward_model(model, combined_ids)
|
||||
|
||||
offset = 0
|
||||
for idx, n, start_idxs, end_idxs, gold, task_type in batch_meta:
|
||||
is_correct = check_result(
|
||||
losses[offset:offset+n], predictions[offset:offset+n],
|
||||
combined_ids[offset:offset+n],
|
||||
start_idxs, end_idxs, gold, task_type,
|
||||
)
|
||||
correct[idx] = float(is_correct)
|
||||
offset += n
|
||||
if pbar is not None:
|
||||
pbar.update(len(batch_meta))
|
||||
return correct
|
||||
|
||||
|
||||
def _forward_all_cached(model, task_collated, device, pbar=None, task_labels=None,
|
||||
on_task_done=None, batched=False, merge=1, split=1, pad_token_id=0):
|
||||
"""Run all tasks' cached batches through the model in one pass.
|
||||
|
||||
All batch tensors are moved to device upfront (~144MB for full CORE eval).
|
||||
If tensors are already on device (caller preloaded), .to() is a no-op.
|
||||
|
||||
Default mode (batched=False): forwards each example individually, trimming
|
||||
collation padding to recover the exact per-example tensor shape. This
|
||||
guarantees identical results to sequential per-example evaluation.
|
||||
|
||||
Batched mode (batched=True): forwards collated batches with optional GPU
|
||||
composition. Faster but may produce tiny FP differences vs sequential eval
|
||||
due to different cuBLAS kernel paths for different matrix dimensions.
|
||||
- merge > 1: pad+cat consecutive base batches on GPU before forwarding.
|
||||
- split > 1: slice each group into chunks by example boundaries.
|
||||
|
||||
Args:
|
||||
task_collated: list of (collated_batches, data) per task
|
||||
pbar: optional progress bar, updated per example (or per batch chunk)
|
||||
task_labels: optional list of task names for pbar description updates
|
||||
on_task_done: optional callback(task_idx, correct_tensor) fired when a task completes
|
||||
batched: if True, forward whole batches (faster, approximate). Default False (exact).
|
||||
merge/split/pad_token_id: only used when batched=True
|
||||
Returns:
|
||||
list of correct tensors (one per task, on device)
|
||||
"""
|
||||
# Flatten all batches and move to device upfront (no-op if already there)
|
||||
flat_stream = [] # (gpu_ids, batch_meta, task_idx)
|
||||
correct = []
|
||||
task_batch_counts = []
|
||||
for task_idx, (collated, data) in enumerate(task_collated):
|
||||
correct.append(torch.zeros(len(data), dtype=torch.float32, device=device))
|
||||
task_batch_counts.append(len(collated))
|
||||
for combined_ids, batch_meta in collated:
|
||||
flat_stream.append((combined_ids.to(device), batch_meta, task_idx))
|
||||
|
||||
if not flat_stream:
|
||||
return correct
|
||||
|
||||
task_batches_remaining = list(task_batch_counts)
|
||||
current_task = -1
|
||||
|
||||
if not batched:
|
||||
# Per-example forwarding: identical results to sequential evaluation.
|
||||
# Each example's rows are trimmed to their original seq_len (= max(end_idxs)),
|
||||
# removing collation padding so forward_model sees the same tensor shape as
|
||||
# the sequential path.
|
||||
for combined_ids, batch_meta, task_idx in flat_stream:
|
||||
if task_idx != current_task:
|
||||
current_task = task_idx
|
||||
if pbar is not None and task_labels is not None:
|
||||
pbar.set_description(task_labels[task_idx])
|
||||
|
||||
offset = 0
|
||||
for idx, n, start_idxs, end_idxs, gold, task_type in batch_meta:
|
||||
seq_len = max(end_idxs)
|
||||
example_ids = combined_ids[offset:offset+n, :seq_len]
|
||||
losses, predictions = forward_model(model, example_ids)
|
||||
is_correct = check_result(
|
||||
losses, predictions, example_ids,
|
||||
start_idxs, end_idxs, gold, task_type,
|
||||
)
|
||||
correct[task_idx][idx] = float(is_correct)
|
||||
offset += n
|
||||
|
||||
if pbar is not None:
|
||||
pbar.update(len(batch_meta))
|
||||
if on_task_done is not None:
|
||||
task_batches_remaining[task_idx] -= 1
|
||||
if task_batches_remaining[task_idx] == 0:
|
||||
on_task_done(task_idx, correct[task_idx])
|
||||
else:
|
||||
# Batched forwarding with optional merge/split composition.
|
||||
buffer_ids = []
|
||||
buffer_info = []
|
||||
|
||||
for i, (combined_ids, batch_meta, task_idx) in enumerate(flat_stream):
|
||||
if task_idx != current_task:
|
||||
current_task = task_idx
|
||||
if pbar is not None and task_labels is not None:
|
||||
pbar.set_description(task_labels[task_idx])
|
||||
buffer_ids.append(combined_ids)
|
||||
buffer_info.append((batch_meta, task_idx))
|
||||
|
||||
if len(buffer_ids) < merge and i < len(flat_stream) - 1:
|
||||
continue
|
||||
|
||||
# GPU compose: pad+cat if multiple batches, otherwise use as-is
|
||||
if len(buffer_ids) == 1:
|
||||
mega_ids = buffer_ids[0]
|
||||
else:
|
||||
max_len = max(t.shape[1] for t in buffer_ids)
|
||||
parts = []
|
||||
for t in buffer_ids:
|
||||
if t.shape[1] < max_len:
|
||||
pad = torch.full((t.shape[0], max_len - t.shape[1]), pad_token_id,
|
||||
dtype=t.dtype, device=t.device)
|
||||
t = torch.cat([t, pad], dim=1)
|
||||
parts.append(t)
|
||||
mega_ids = torch.cat(parts, dim=0)
|
||||
|
||||
examples = []
|
||||
row_bounds = [0]
|
||||
for bm, tidx in buffer_info:
|
||||
for idx, n, start_idxs, end_idxs, gold, task_type in bm:
|
||||
examples.append((idx, n, start_idxs, end_idxs, gold, task_type, tidx))
|
||||
row_bounds.append(row_bounds[-1] + n)
|
||||
|
||||
n_ex = len(examples)
|
||||
chunk_size = -(-n_ex // split)
|
||||
|
||||
for cs in range(0, n_ex, chunk_size):
|
||||
ce = min(cs + chunk_size, n_ex)
|
||||
chunk = examples[cs:ce]
|
||||
chunk_ids = mega_ids[row_bounds[cs]:row_bounds[ce]]
|
||||
|
||||
losses, predictions = forward_model(model, chunk_ids)
|
||||
|
||||
offset = 0
|
||||
for idx, n, start_idxs, end_idxs, gold, task_type, tidx in chunk:
|
||||
is_correct = check_result(
|
||||
losses[offset:offset+n], predictions[offset:offset+n],
|
||||
chunk_ids[offset:offset+n],
|
||||
start_idxs, end_idxs, gold, task_type,
|
||||
)
|
||||
correct[tidx][idx] = float(is_correct)
|
||||
offset += n
|
||||
if pbar is not None:
|
||||
pbar.update(len(chunk))
|
||||
|
||||
if on_task_done is not None:
|
||||
for bm, tidx in buffer_info:
|
||||
task_batches_remaining[tidx] -= 1
|
||||
if task_batches_remaining[tidx] == 0:
|
||||
on_task_done(tidx, correct[tidx])
|
||||
|
||||
buffer_ids.clear()
|
||||
buffer_info.clear()
|
||||
|
||||
return correct
|
||||
|
||||
|
||||
def compose_collated(base_collated, target_batch_size, base_batch_size=4, pad_token_id=0):
|
||||
"""Compose base-sized collated batches into target-sized batches.
|
||||
|
||||
Supports both merging (target > base) by concatenating consecutive groups,
|
||||
and splitting (target < base) by slicing along example boundaries.
|
||||
Examples are sorted by seq_len within each base batch, so splitting can
|
||||
trim trailing padding columns for efficiency.
|
||||
"""
|
||||
if target_batch_size == base_batch_size:
|
||||
return base_collated
|
||||
elif target_batch_size > base_batch_size:
|
||||
# Merge consecutive base batches
|
||||
n_merge = target_batch_size // base_batch_size
|
||||
composed = []
|
||||
for i in range(0, len(base_collated), n_merge):
|
||||
group = base_collated[i:i + n_merge]
|
||||
if len(group) == 1:
|
||||
composed.append(group[0])
|
||||
continue
|
||||
max_len = max(ids.shape[1] for ids, _ in group)
|
||||
parts = []
|
||||
merged_meta = []
|
||||
for ids, meta in group:
|
||||
if ids.shape[1] < max_len:
|
||||
pad = torch.full((ids.shape[0], max_len - ids.shape[1]), pad_token_id, dtype=ids.dtype)
|
||||
ids = torch.cat([ids, pad], dim=1)
|
||||
parts.append(ids)
|
||||
merged_meta.extend(meta)
|
||||
composed.append((torch.cat(parts, dim=0), merged_meta))
|
||||
return composed
|
||||
else:
|
||||
# Split base batches into smaller chunks
|
||||
composed = []
|
||||
for combined_ids, batch_meta in base_collated:
|
||||
for chunk_start in range(0, len(batch_meta), target_batch_size):
|
||||
chunk_meta = batch_meta[chunk_start:chunk_start + target_batch_size]
|
||||
row_start = sum(m[1] for m in batch_meta[:chunk_start])
|
||||
row_end = row_start + sum(m[1] for m in chunk_meta)
|
||||
chunk_ids = combined_ids[row_start:row_end]
|
||||
# Trim trailing padding (examples sorted by seq_len, so chunks
|
||||
# near the start of a base batch may need fewer columns)
|
||||
non_pad = (chunk_ids != pad_token_id)
|
||||
if non_pad.any():
|
||||
last_col = non_pad.any(dim=0).nonzero()[-1].item() + 1
|
||||
if last_col < chunk_ids.shape[1]:
|
||||
chunk_ids = chunk_ids[:, :last_col].contiguous()
|
||||
composed.append((chunk_ids, chunk_meta))
|
||||
return composed
|
||||
|
||||
|
||||
def evaluate_task(model, data, device, batch_size=4, queue_size=2, prepared=None,
|
||||
collated=None, tokenizer=None, task_meta=None, pbar=None):
|
||||
"""
|
||||
Evaluate one task across many examples with batched GPU forward passes.
|
||||
Examples are sorted by sequence length so similar-length sequences are batched
|
||||
together, minimizing padding waste and increasing GPU utilization.
|
||||
|
||||
Three modes (checked in order):
|
||||
- collated: skip prepare + collation, go straight to GPU forward passes.
|
||||
- prepared: skip prepare, collation runs on a background thread pipelined with GPU.
|
||||
- neither: full pipeline (prepare + collate + forward).
|
||||
|
||||
Returns (accuracy, collated_batches) so the caller can cache collated batches.
|
||||
"""
|
||||
rank = dist.get_rank() if dist.is_initialized() else 0
|
||||
world_size = dist.get_world_size() if dist.is_initialized() else 1
|
||||
correct = torch.zeros(len(data), dtype=torch.float32, device=device)
|
||||
# stride the examples to each rank
|
||||
for idx in range(rank, len(data), world_size):
|
||||
is_correct = evaluate_example(idx, model, tokenizer, data, device, task_meta)
|
||||
correct[idx] = float(is_correct)
|
||||
|
||||
if collated is not None:
|
||||
# Fast path: just GPU forward passes, no threads
|
||||
correct = _forward_batches(model, collated, data, device, pbar=pbar)
|
||||
else:
|
||||
from queue import Queue
|
||||
from threading import Thread
|
||||
|
||||
if prepared is None:
|
||||
max_seq_len = getattr(model, 'max_seq_len', None)
|
||||
prepared = prepare_task_data(tokenizer, data, task_meta, max_seq_len)
|
||||
|
||||
# Collation thread pipelined with GPU forward passes.
|
||||
# Double-buffered: while GPU processes batch N, batch N+1 is
|
||||
# pin_memory()'d and DMA-transferred asynchronously.
|
||||
queue = Queue(maxsize=queue_size)
|
||||
collator = Thread(target=_collate_batches, args=(prepared, batch_size, queue), daemon=True)
|
||||
collator.start()
|
||||
|
||||
use_prefetch = torch.cuda.is_available() and 'cuda' in str(device)
|
||||
def transfer(tensor):
|
||||
return _prefetch_to_device(tensor, device) if use_prefetch else tensor.to(device)
|
||||
|
||||
collated = []
|
||||
correct = torch.zeros(len(data), dtype=torch.float32, device=device)
|
||||
|
||||
# Prime: get first batch and start its transfer
|
||||
item = queue.get()
|
||||
if item is not None:
|
||||
next_ids = transfer(item[0])
|
||||
|
||||
while item is not None:
|
||||
collated.append(item)
|
||||
combined_ids = next_ids
|
||||
_, batch_meta = item
|
||||
|
||||
# Start async transfer of next batch (overlaps with forward pass below)
|
||||
item = queue.get()
|
||||
if item is not None:
|
||||
next_ids = transfer(item[0])
|
||||
|
||||
losses, predictions = forward_model(model, combined_ids)
|
||||
|
||||
offset = 0
|
||||
for idx, n, start_idxs, end_idxs, gold, task_type in batch_meta:
|
||||
is_correct = check_result(
|
||||
losses[offset:offset+n], predictions[offset:offset+n],
|
||||
combined_ids[offset:offset+n],
|
||||
start_idxs, end_idxs, gold, task_type,
|
||||
)
|
||||
correct[idx] = float(is_correct)
|
||||
offset += n
|
||||
if pbar is not None:
|
||||
pbar.update(len(batch_meta))
|
||||
|
||||
collator.join()
|
||||
del prepared
|
||||
|
||||
# sync results across all the processes if running distributed
|
||||
if world_size > 1:
|
||||
dist.barrier()
|
||||
dist.all_reduce(correct, op=dist.ReduceOp.SUM)
|
||||
# compute the mean
|
||||
mean_correct = correct.mean().item()
|
||||
return mean_correct
|
||||
return correct.mean().item(), collated
|
||||
|
|
|
|||
|
|
@ -26,17 +26,19 @@ import json
|
|||
import yaml
|
||||
import shutil
|
||||
import random
|
||||
import hashlib
|
||||
import zipfile
|
||||
import tempfile
|
||||
import argparse
|
||||
from contextlib import nullcontext
|
||||
from tqdm import tqdm
|
||||
|
||||
import torch
|
||||
|
||||
from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, autodetect_device_type, download_file_with_lock
|
||||
from nanochat.tokenizer import HuggingFaceTokenizer, get_token_bytes
|
||||
from nanochat.checkpoint_manager import load_model
|
||||
from nanochat.core_eval import evaluate_task
|
||||
from nanochat.core_eval import evaluate_task, prepare_task_data, _forward_all_cached
|
||||
from nanochat.dataloader import tokenizing_distributed_data_loader_bos_bestfit
|
||||
from nanochat.loss_eval import evaluate_bpb
|
||||
from nanochat.engine import Engine
|
||||
|
|
@ -106,67 +108,211 @@ def place_eval_bundle(file_path):
|
|||
print0(f"Placed eval_bundle directory at {eval_bundle_dir}")
|
||||
|
||||
|
||||
def evaluate_core(model, tokenizer, device, max_per_task=-1):
|
||||
_eval_data_cache = None # (task_inputs, random_baselines, w_label, w_shot, w_type)
|
||||
_batch_cache = {} # {label: collated_batches} — cached after first run
|
||||
_batch_cache_key = None # (max_per_task, max_seq_len) — invalidate if these change
|
||||
_prev_centered = {} # {label: centered_result} — previous run for delta display
|
||||
_prev_core = None # previous core_metric
|
||||
|
||||
|
||||
def _get_disk_cache_dir(max_per_task):
|
||||
"""Get disk cache dir for base-4 collated batches, keyed by tokenizer file hash.
|
||||
Returns None if no local tokenizer file is found (e.g. HuggingFace models)."""
|
||||
base_dir = get_base_dir()
|
||||
for fname in ("tokenizer.pkl", "tokenizer.json"):
|
||||
path = os.path.join(base_dir, "tokenizer", fname)
|
||||
if os.path.exists(path):
|
||||
h = hashlib.sha256()
|
||||
with open(path, "rb") as f:
|
||||
for chunk in iter(lambda: f.read(8192), b""):
|
||||
h.update(chunk)
|
||||
tok_hash = h.hexdigest()[:16]
|
||||
return os.path.join(base_dir, "core_token_cache", f"{tok_hash}_n{max_per_task}")
|
||||
return None
|
||||
|
||||
|
||||
def evaluate_model(model, tokenizer, device, max_per_task=-1):
|
||||
"""
|
||||
Evaluate a base model on the CORE benchmark.
|
||||
Returns dict with results, centered_results, and core_metric.
|
||||
- max_per_task: crop the data to this many examples per task for testing (-1 = disable)
|
||||
Collated batches are cached across calls since the tokenizer is fixed.
|
||||
Second+ runs skip prepare and collation entirely — just GPU forward passes.
|
||||
"""
|
||||
base_dir = get_base_dir()
|
||||
eval_bundle_dir = os.path.join(base_dir, "eval_bundle")
|
||||
# Download the eval bundle if needed
|
||||
if not os.path.exists(eval_bundle_dir):
|
||||
download_file_with_lock(EVAL_BUNDLE_URL, "eval_bundle.zip", postprocess_fn=place_eval_bundle)
|
||||
global _eval_data_cache, _batch_cache, _batch_cache_key, _prev_centered, _prev_core
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
config_path = os.path.join(eval_bundle_dir, "core.yaml")
|
||||
data_base_path = os.path.join(eval_bundle_dir, "eval_data")
|
||||
eval_meta_data = os.path.join(eval_bundle_dir, "eval_meta_data.csv")
|
||||
max_seq_len = getattr(model, 'max_seq_len', None)
|
||||
cache_key = (max_per_task, max_seq_len)
|
||||
|
||||
with open(config_path, 'r', encoding='utf-8') as f:
|
||||
config = yaml.safe_load(f)
|
||||
tasks = config['icl_tasks']
|
||||
# Invalidate batch cache if parameters changed
|
||||
if cache_key != _batch_cache_key:
|
||||
_batch_cache.clear()
|
||||
_batch_cache_key = cache_key
|
||||
|
||||
# Load random baseline values
|
||||
random_baselines = {}
|
||||
with open(eval_meta_data, 'r', encoding='utf-8') as f:
|
||||
reader = csv.DictReader(f)
|
||||
for row in reader:
|
||||
task_name = row['Eval Task']
|
||||
random_baseline = row['Random baseline']
|
||||
random_baselines[task_name] = float(random_baseline)
|
||||
# Load and cache task data + baselines (only read from disk once)
|
||||
if _eval_data_cache is None:
|
||||
base_dir = get_base_dir()
|
||||
eval_bundle_dir = os.path.join(base_dir, "eval_bundle")
|
||||
if not os.path.exists(eval_bundle_dir):
|
||||
download_file_with_lock(EVAL_BUNDLE_URL, "eval_bundle.zip", postprocess_fn=place_eval_bundle)
|
||||
config_path = os.path.join(eval_bundle_dir, "core.yaml")
|
||||
data_base_path = os.path.join(eval_bundle_dir, "eval_data")
|
||||
eval_meta_data = os.path.join(eval_bundle_dir, "eval_meta_data.csv")
|
||||
with open(config_path, 'r', encoding='utf-8') as f:
|
||||
config = yaml.safe_load(f)
|
||||
tasks = config['icl_tasks']
|
||||
|
||||
# Evaluate each task
|
||||
random_baselines = {}
|
||||
with open(eval_meta_data, 'r', encoding='utf-8') as f:
|
||||
reader = csv.DictReader(f)
|
||||
for row in reader:
|
||||
random_baselines[row['Eval Task']] = float(row['Random baseline'])
|
||||
|
||||
task_inputs = []
|
||||
for task in tasks:
|
||||
label = task['label']
|
||||
task_meta = {
|
||||
'task_type': task['icl_task_type'],
|
||||
'dataset_uri': task['dataset_uri'],
|
||||
'num_fewshot': task['num_fewshot'][0],
|
||||
'continuation_delimiter': task.get('continuation_delimiter', ' ')
|
||||
}
|
||||
data_path = os.path.join(data_base_path, task_meta['dataset_uri'])
|
||||
with open(data_path, 'r', encoding='utf-8') as f:
|
||||
data = [json.loads(line.strip()) for line in f]
|
||||
shuffle_rng = random.Random(1337)
|
||||
shuffle_rng.shuffle(data)
|
||||
if max_per_task > 0:
|
||||
data = data[:max_per_task]
|
||||
task_inputs.append((label, task_meta, data))
|
||||
|
||||
w_label = max(len(t[0]) for t in task_inputs)
|
||||
w_shot = max(len(f"{t[1]['num_fewshot']}-shot") for t in task_inputs)
|
||||
w_type = max(len(t[1]['task_type']) for t in task_inputs)
|
||||
_eval_data_cache = (task_inputs, random_baselines, w_label, w_shot, w_type)
|
||||
|
||||
task_inputs, random_baselines, w_label, w_shot, w_type = _eval_data_cache
|
||||
|
||||
# First run: eagerly prepare next task while evaluating current, cache collated batches.
|
||||
# Cached runs: pass collated batches directly — no threads, no prepare, no collation.
|
||||
results = {}
|
||||
centered_results = {}
|
||||
for task in tasks:
|
||||
start_time = time.time()
|
||||
label = task['label']
|
||||
task_meta = {
|
||||
'task_type': task['icl_task_type'],
|
||||
'dataset_uri': task['dataset_uri'],
|
||||
'num_fewshot': task['num_fewshot'][0],
|
||||
'continuation_delimiter': task.get('continuation_delimiter', ' ')
|
||||
}
|
||||
print0(f"Evaluating: {label} ({task_meta['num_fewshot']}-shot, type: {task_meta['task_type']})... ", end='')
|
||||
cached_run = all(label in _batch_cache for label, _, _ in task_inputs)
|
||||
disk_cache_dir = _get_disk_cache_dir(max_per_task)
|
||||
|
||||
data_path = os.path.join(data_base_path, task_meta['dataset_uri'])
|
||||
with open(data_path, 'r', encoding='utf-8') as f:
|
||||
data = [json.loads(line.strip()) for line in f]
|
||||
# Try loading from disk cache if in-memory cache is empty
|
||||
if not cached_run and disk_cache_dir is not None:
|
||||
all_on_disk = os.path.isdir(disk_cache_dir) and all(
|
||||
os.path.exists(os.path.join(disk_cache_dir, f"{label}.pt"))
|
||||
for label, _, _ in task_inputs
|
||||
)
|
||||
if all_on_disk:
|
||||
for label, _, _ in task_inputs:
|
||||
d = torch.load(os.path.join(disk_cache_dir, f"{label}.pt"), weights_only=False)
|
||||
_batch_cache[label] = d['collated']
|
||||
cached_run = True
|
||||
print0(" (loaded collated batches from disk cache)")
|
||||
|
||||
# Shuffle for consistent subsampling when using max_per_task
|
||||
shuffle_rng = random.Random(1337)
|
||||
shuffle_rng.shuffle(data)
|
||||
if max_per_task > 0:
|
||||
data = data[:max_per_task]
|
||||
first_run = not cached_run # track whether we did prepare+collate (for disk save)
|
||||
|
||||
accuracy = evaluate_task(model, tokenizer, data, device, task_meta)
|
||||
results[label] = accuracy
|
||||
import torch.distributed as dist
|
||||
rank = dist.get_rank() if dist.is_initialized() else 0
|
||||
world_size = dist.get_world_size() if dist.is_initialized() else 1
|
||||
total_examples = sum(len(data) for _, _, data in task_inputs)
|
||||
task_labels = [label for label, _, _ in task_inputs]
|
||||
pbar = tqdm(total=total_examples, leave=False, disable=(rank != 0))
|
||||
|
||||
def _print_task_result(tidx, accuracy):
|
||||
"""Print one task's result. Updates results/centered_results dicts."""
|
||||
label, task_meta, _ = task_inputs[tidx]
|
||||
random_baseline = random_baselines[label]
|
||||
centered_result = (accuracy - 0.01 * random_baseline) / (1.0 - 0.01 * random_baseline)
|
||||
results[label] = accuracy
|
||||
centered_results[label] = centered_result
|
||||
elapsed = time.time() - start_time
|
||||
print0(f"accuracy: {accuracy:.4f} | centered: {centered_result:.4f} | time: {elapsed:.2f}s")
|
||||
shot_str = f"{task_meta['num_fewshot']}-shot"
|
||||
prefix = f" {label:<{w_label}} {shot_str:<{w_shot}} {task_meta['task_type']:<{w_type}}"
|
||||
delta_str = ""
|
||||
if label in _prev_centered:
|
||||
d = centered_result - _prev_centered[label]
|
||||
arrow = "\u2191" if d > 0 else "\u2193" if d < 0 else "="
|
||||
delta_str = f" {arrow}{d:+.4f}"
|
||||
if rank == 0:
|
||||
pbar.write(f"{prefix} acc: {accuracy:.4f} centered: {centered_result:>7.4f}{delta_str}")
|
||||
|
||||
def _on_task_done(tidx, correct):
|
||||
"""Callback for _forward_all_cached: convert tensor to accuracy and print."""
|
||||
_print_task_result(tidx, correct.mean().item())
|
||||
|
||||
if cached_run:
|
||||
# Continuous pipeline: all tasks in one GPU stream, results printed per-task as they complete.
|
||||
# Always use base batch size (merge=1/split=1) to guarantee identical results —
|
||||
# different batch dimensions trigger different cuBLAS kernels with different FP rounding.
|
||||
t0 = time.time()
|
||||
task_collated = [(_batch_cache[label], data) for label, _, data in task_inputs]
|
||||
correct_list = _forward_all_cached(
|
||||
model, task_collated, device, pbar=pbar, task_labels=task_labels,
|
||||
on_task_done=_on_task_done if world_size == 1 else None,
|
||||
batched=True,
|
||||
)
|
||||
elapsed_total = time.time() - t0
|
||||
pbar.close()
|
||||
|
||||
# DDP: all_reduce + print (single-GPU already handled by on_task_done above)
|
||||
if world_size > 1:
|
||||
for tidx, ((label, task_meta, data), correct) in enumerate(zip(task_inputs, correct_list)):
|
||||
dist.barrier()
|
||||
dist.all_reduce(correct, op=dist.ReduceOp.SUM)
|
||||
_print_task_result(tidx, correct.mean().item())
|
||||
print0(f" (all tasks: {elapsed_total:.2f}s)")
|
||||
else:
|
||||
t0 = time.time()
|
||||
executor = ThreadPoolExecutor(max_workers=1)
|
||||
first_uncached = next(i for i, (l, _, _) in enumerate(task_inputs) if l not in _batch_cache)
|
||||
_, first_meta, first_data = task_inputs[first_uncached]
|
||||
next_future = executor.submit(prepare_task_data, tokenizer, first_data, first_meta, max_seq_len)
|
||||
|
||||
for i, (label, task_meta, data) in enumerate(task_inputs):
|
||||
pbar.set_description(f"{label:<{w_label}}")
|
||||
|
||||
if label in _batch_cache:
|
||||
accuracy, collated = evaluate_task(model, data, device, collated=_batch_cache[label], pbar=pbar)
|
||||
else:
|
||||
prepared = next_future.result()
|
||||
# Kick off prepare for the next uncached task
|
||||
for j in range(i + 1, len(task_inputs)):
|
||||
next_label, next_meta, next_data = task_inputs[j]
|
||||
if next_label not in _batch_cache:
|
||||
next_future = executor.submit(prepare_task_data, tokenizer, next_data, next_meta, max_seq_len)
|
||||
break
|
||||
accuracy, collated = evaluate_task(model, data, device, prepared=prepared, pbar=pbar)
|
||||
_batch_cache[label] = collated
|
||||
|
||||
_print_task_result(i, accuracy)
|
||||
|
||||
elapsed_total = time.time() - t0
|
||||
pbar.close()
|
||||
executor.shutdown(wait=False)
|
||||
print0(f" (all tasks: {elapsed_total:.2f}s)")
|
||||
|
||||
# Save collated batches to disk after first run (so bench/future runs skip prepare+collate)
|
||||
if first_run and disk_cache_dir is not None:
|
||||
pad_id = tokenizer.get_bos_token_id()
|
||||
os.makedirs(disk_cache_dir, exist_ok=True)
|
||||
for label, _, _ in task_inputs:
|
||||
if label in _batch_cache:
|
||||
torch.save({'collated': _batch_cache[label], 'pad_token_id': pad_id},
|
||||
os.path.join(disk_cache_dir, f"{label}.pt"))
|
||||
print0(f" (saved collated batches to {disk_cache_dir})")
|
||||
|
||||
core_metric = sum(centered_results.values()) / len(centered_results)
|
||||
if _prev_core is not None:
|
||||
d = core_metric - _prev_core
|
||||
arrow = "\u2191" if d > 0 else "\u2193" if d < 0 else "="
|
||||
print0(f"CORE: {core_metric:.4f} {arrow}{d:+.4f}")
|
||||
else:
|
||||
print0(f"CORE: {core_metric:.4f}")
|
||||
_prev_centered = dict(centered_results)
|
||||
_prev_core = core_metric
|
||||
out = {
|
||||
"results": results,
|
||||
"centered_results": centered_results,
|
||||
|
|
@ -288,7 +434,7 @@ def main():
|
|||
print0("CORE Evaluation")
|
||||
print0("="*80)
|
||||
with autocast_ctx:
|
||||
core_results = evaluate_core(model, tokenizer, device, max_per_task=args.max_per_task)
|
||||
core_results = evaluate_model(model, tokenizer, device, max_per_task=args.max_per_task)
|
||||
|
||||
# Write CSV output
|
||||
if ddp_rank == 0:
|
||||
|
|
|
|||
|
|
@ -32,7 +32,7 @@ from nanochat.checkpoint_manager import save_checkpoint, load_checkpoint
|
|||
from nanochat.loss_eval import evaluate_bpb
|
||||
from nanochat.engine import Engine
|
||||
from nanochat.flash_attention import HAS_FA3
|
||||
from scripts.base_eval import evaluate_core
|
||||
from scripts.base_eval import evaluate_model
|
||||
print_banner()
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
|
|
@ -425,7 +425,7 @@ while True:
|
|||
if args.core_metric_every > 0 and (last_step or (step > 0 and step % args.core_metric_every == 0)):
|
||||
model.eval()
|
||||
with disable_fp8(orig_model), autocast_ctx:
|
||||
results = evaluate_core(orig_model, tokenizer, device, max_per_task=args.core_metric_max_per_task)
|
||||
results = evaluate_model(orig_model, tokenizer, device, max_per_task=args.core_metric_max_per_task)
|
||||
print0(f"Step {step:05d} | CORE metric: {results['core_metric']:.4f}")
|
||||
wandb_run.log({
|
||||
"step": step,
|
||||
|
|
|
|||
525
scripts/bench_core_eval.py
Normal file
525
scripts/bench_core_eval.py
Normal file
|
|
@ -0,0 +1,525 @@
|
|||
"""
|
||||
Benchmark the CORE evaluation pipeline.
|
||||
|
||||
Compares three modes:
|
||||
1. Old sequential (per-example) evaluation
|
||||
2. New batched evaluation (first run — includes prepare + collate + forward)
|
||||
3. New batched evaluation (cached run — forward only)
|
||||
|
||||
Also sweeps batch_size and queue_size to find optimal hyperparameters.
|
||||
|
||||
Usage:
|
||||
python -m scripts.bench_core_eval
|
||||
python -m scripts.bench_core_eval --max-per-task 100 # quick test
|
||||
python -m scripts.bench_core_eval --hf-path openai-community/gpt2
|
||||
"""
|
||||
import os
|
||||
import csv
|
||||
import json
|
||||
import time
|
||||
import yaml
|
||||
import random
|
||||
import shutil
|
||||
import hashlib
|
||||
import zipfile
|
||||
import tempfile
|
||||
import argparse
|
||||
from contextlib import nullcontext
|
||||
from tqdm import tqdm
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, autodetect_device_type, download_file_with_lock
|
||||
from nanochat.tokenizer import HuggingFaceTokenizer
|
||||
from nanochat.checkpoint_manager import load_model
|
||||
from nanochat.core_eval import (
|
||||
forward_model, prepare_example, check_result, stack_sequences,
|
||||
prepare_task_data, _collate_batches, _forward_all_cached,
|
||||
evaluate_task,
|
||||
render_prompts_mc, render_prompts_schema, render_prompts_lm,
|
||||
batch_sequences_mc, batch_sequences_schema, batch_sequences_lm,
|
||||
)
|
||||
|
||||
# ---- eval bundle loading (reused from base_eval) ----
|
||||
# Allow TF32 math for fp32 operations (faster on Ampere+ GPUs, slightly lower precision).
torch.backends.fp32_precision = "tf32"
# Zip archive containing the CORE eval task configs and datasets (same bundle base_eval uses).
EVAL_BUNDLE_URL = "https://karpathy-public.s3.us-west-2.amazonaws.com/eval_bundle.zip"
|
||||
|
||||
def place_eval_bundle(file_path):
    """Extract the downloaded eval bundle zip and install it under the base dir."""
    eval_bundle_dir = os.path.join(get_base_dir(), "eval_bundle")
    with tempfile.TemporaryDirectory() as scratch:
        with zipfile.ZipFile(file_path, 'r') as archive:
            archive.extractall(scratch)
        shutil.move(os.path.join(scratch, "eval_bundle"), eval_bundle_dir)
    print0(f"Placed eval_bundle at {eval_bundle_dir}")
|
||||
|
||||
def load_tasks(max_per_task=-1):
    """Load the CORE task configs and datasets from the eval bundle.

    Downloads the bundle on first use. Returns a list of
    (label, task_meta, data) tuples; each task's examples are shuffled with a
    fixed seed so max_per_task subsampling is deterministic across runs.
    """
    base_dir = get_base_dir()
    eval_bundle_dir = os.path.join(base_dir, "eval_bundle")
    if not os.path.exists(eval_bundle_dir):
        download_file_with_lock(EVAL_BUNDLE_URL, "eval_bundle.zip", postprocess_fn=place_eval_bundle)
    config_path = os.path.join(eval_bundle_dir, "core.yaml")
    data_base_path = os.path.join(eval_bundle_dir, "eval_data")
    with open(config_path, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)
    task_inputs = []
    for task in config['icl_tasks']:
        task_meta = {
            'task_type': task['icl_task_type'],
            'dataset_uri': task['dataset_uri'],
            'num_fewshot': task['num_fewshot'][0],
            'continuation_delimiter': task.get('continuation_delimiter', ' ')
        }
        data_path = os.path.join(data_base_path, task_meta['dataset_uri'])
        with open(data_path, 'r', encoding='utf-8') as f:
            data = [json.loads(line.strip()) for line in f]
        # Fixed-seed shuffle for reproducible subsampling
        random.Random(1337).shuffle(data)
        if max_per_task > 0:
            data = data[:max_per_task]
        task_inputs.append((task['label'], task_meta, data))
    return task_inputs
|
||||
|
||||
# Batch size used to build the on-disk collated cache; sweep batch sizes are
# composed from these base batches via merge (larger) or split (smaller).
BASE_BATCH_SIZE = 4
|
||||
|
||||
def file_checksum(path, ndigits=16):
    """Return the SHA-256 checksum of the file at *path*.

    Reads the file in fixed-size chunks so arbitrarily large files are hashed
    without loading them fully into memory.

    Args:
        path: filesystem path of the file to hash.
        ndigits: number of leading hex characters of the digest to keep
            (default 16, preserving the original behavior; pass 64 for the
            full SHA-256 digest).

    Returns:
        The first *ndigits* hex characters of the SHA-256 digest.
    """
    h = hashlib.sha256()
    with open(path, "rb") as f:
        # iter() with a sentinel yields chunks until f.read() returns b""
        for chunk in iter(lambda: f.read(8192), b""):
            h.update(chunk)
    return h.hexdigest()[:ndigits]
|
||||
|
||||
|
||||
def collate_prepared(prepared, batch_size):
    """Group prepared examples into right-padded batches (single-threaded).

    Returns (collated, pad_token_id), where collated is a list of
    (combined_ids, batch_meta) pairs: combined_ids stacks every option row of
    every example in the batch into one LongTensor padded with pad_token_id,
    and batch_meta records per-example bookkeeping for scoring.
    """
    pad_id = prepared[0][1]['pad_token_id']
    collated = []
    for start in range(0, len(prepared), batch_size):
        chunk = prepared[start:start + batch_size]
        width = max(p['seq_len'] for _, p in chunk)
        rows = sum(p['num_options'] for _, p in chunk)
        combined_ids = torch.full((rows, width), pad_id, dtype=torch.long)
        batch_meta = []
        row = 0
        for idx, p in chunk:
            n = p['num_options']
            combined_ids[row:row + n, :p['seq_len']] = p['input_ids']
            batch_meta.append((idx, n, p['start_idxs'], p['end_idxs'], p['gold'], p['task_type']))
            row += n
        collated.append((combined_ids, batch_meta))
    return collated, pad_id
|
||||
|
||||
|
||||
def build_or_load_base_collated(tok_hash, tokenizer, task_inputs, max_seq_len, max_per_task):
    """Build or load base-batch collated data for all tasks, with disk caching.

    The cache directory is keyed by tokenizer hash and max_per_task so stale
    data is never reused across tokenizer or subsampling changes. Returns a
    dict: label -> (collated, pad_token_id).
    """
    cache_dir = os.path.join(get_base_dir(), "core_token_cache", f"{tok_hash}_n{max_per_task}")

    have_all = os.path.isdir(cache_dir) and all(
        os.path.exists(os.path.join(cache_dir, f"{label}.pt"))
        for label, _, _ in task_inputs
    )

    base_cache = {}  # label -> (collated, pad_token_id)
    if have_all:
        print0(f"Loading base-{BASE_BATCH_SIZE} collated cache from {cache_dir}")
        for label, _, _ in tqdm(task_inputs, desc="Loading cache", leave=False):
            payload = torch.load(os.path.join(cache_dir, f"{label}.pt"), weights_only=False)
            base_cache[label] = (payload['collated'], payload['pad_token_id'])
        return base_cache

    print0(f"Building base-{BASE_BATCH_SIZE} collated cache (saving to {cache_dir})")
    os.makedirs(cache_dir, exist_ok=True)
    for label, task_meta, data in tqdm(task_inputs, desc="Preparing tasks", leave=False):
        prepared = prepare_task_data(tokenizer, data, task_meta, max_seq_len)
        collated, pad_id = collate_prepared(prepared, BASE_BATCH_SIZE)
        del prepared  # release prepared tensors before the next task
        base_cache[label] = (collated, pad_id)
        torch.save({'collated': collated, 'pad_token_id': pad_id},
                   os.path.join(cache_dir, f"{label}.pt"))
    print0(f"Saved {len(base_cache)} task caches to {cache_dir}")
    return base_cache
|
||||
|
||||
|
||||
# ---- old sequential evaluation (baseline) ----
|
||||
|
||||
def _crop_to_max_len(tokens, start_idxs, end_idxs, max_tokens):
    """Left-crop sequences longer than max_tokens, shifting start/end indices to match."""
    cropped_tokens, cropped_starts, cropped_ends = [], [], []
    for t, s, e in zip(tokens, start_idxs, end_idxs):
        overflow = len(t) - max_tokens
        if overflow > 0:
            # Keep the tail (the continuation); indices shift left by the cropped amount
            cropped_tokens.append(t[-max_tokens:])
            cropped_starts.append(s - overflow)
            cropped_ends.append(e - overflow)
        else:
            cropped_tokens.append(t)
            cropped_starts.append(s)
            cropped_ends.append(e)
    return cropped_tokens, cropped_starts, cropped_ends


@torch.no_grad()
def evaluate_example_old(idx, model, tokenizer, data, device, task_meta):
    """Original per-example sequential evaluation (the old code).

    Renders prompts for example `idx` (with deterministic few-shot sampling),
    tokenizes per task type, left-crops to the model's context window if one is
    declared, runs a single forward pass, and returns True/False correctness.
    """
    item = data[idx]
    task_type = task_meta['task_type']
    num_fewshot = task_meta['num_fewshot']
    continuation_delimiter = task_meta['continuation_delimiter']

    # Deterministic few-shot sampling: seed depends only on the example index,
    # and the example itself is excluded from its own few-shot pool.
    fewshot_examples = []
    if num_fewshot > 0:
        rng = random.Random(1234 + idx)
        available_indices = [i for i in range(len(data)) if i != idx]
        fewshot_indices = rng.sample(available_indices, num_fewshot)
        fewshot_examples = [data[i] for i in fewshot_indices]

    # Dispatch to the task-type-specific render + tokenize pair
    if task_type == 'multiple_choice':
        prompts = render_prompts_mc(item, continuation_delimiter, fewshot_examples)
        tokens, start_idxs, end_idxs = batch_sequences_mc(tokenizer, prompts)
    elif task_type == 'schema':
        prompts = render_prompts_schema(item, continuation_delimiter, fewshot_examples)
        tokens, start_idxs, end_idxs = batch_sequences_schema(tokenizer, prompts)
    elif task_type == 'language_modeling':
        prompts = render_prompts_lm(item, continuation_delimiter, fewshot_examples)
        tokens, start_idxs, end_idxs = batch_sequences_lm(tokenizer, prompts)
    else:
        raise ValueError(f"Unsupported task type: {task_type}")

    # Respect the model's context window (e.g. GPT-2's 1024 tokens)
    if hasattr(model, 'max_seq_len') and model.max_seq_len is not None:
        tokens, start_idxs, end_idxs = _crop_to_max_len(tokens, start_idxs, end_idxs, model.max_seq_len)

    pad_token_id = tokenizer.get_bos_token_id()
    input_ids = stack_sequences(tokens, pad_token_id).to(device)
    losses, predictions = forward_model(model, input_ids)
    return check_result(losses, predictions, input_ids, start_idxs, end_idxs, item.get('gold', None), task_type)
|
||||
|
||||
|
||||
def evaluate_task_old(model, tokenizer, data, device, task_meta, pbar=None):
    """Original sequential evaluate_task: one example per forward pass.

    Examples are strided across DDP ranks; per-rank results are summed via
    all_reduce before averaging. Returns the task accuracy as a float.
    """
    distributed = dist.is_initialized()
    rank = dist.get_rank() if distributed else 0
    world_size = dist.get_world_size() if distributed else 1
    correct = torch.zeros(len(data), dtype=torch.float32, device=device)
    for idx in range(rank, len(data), world_size):
        correct[idx] = float(evaluate_example_old(idx, model, tokenizer, data, device, task_meta))
        if pbar is not None:
            pbar.update(1)
    if world_size > 1:
        dist.barrier()
        dist.all_reduce(correct, op=dist.ReduceOp.SUM)
    return correct.mean().item()
|
||||
|
||||
# ---- HuggingFace model wrapper ----
|
||||
|
||||
class ModelWrapper:
    """Adapt a HuggingFace causal LM to the nanochat calling convention:
    callable on input_ids, returning raw logits. Also carries an optional
    context-window limit used for left-cropping long prompts."""

    def __init__(self, model, max_seq_len=None):
        self.model = model  # the wrapped model
        self.max_seq_len = max_seq_len  # context limit, or None for unlimited

    def __call__(self, input_ids):
        # HF models return an output object; unwrap to bare logits
        return self.model(input_ids).logits
|
||||
|
||||
def load_hf_model(hf_path, device):
    """Load a HuggingFace causal LM plus its tokenizer, wrapped for this pipeline."""
    from transformers import AutoModelForCausalLM
    model = AutoModelForCausalLM.from_pretrained(hf_path)
    model.to(device)
    model.eval()
    # GPT-2 checkpoints have a fixed 1024-token context; others declare none here
    context_limit = 1024 if "gpt2" in hf_path else None
    wrapped = ModelWrapper(model, max_seq_len=context_limit)
    return wrapped, HuggingFaceTokenizer.from_pretrained(hf_path)
|
||||
|
||||
# ---- benchmark helpers ----
|
||||
|
||||
def sync_cuda():
    """Block until all pending CUDA work completes; no-op without CUDA."""
    if not torch.cuda.is_available():
        return
    torch.cuda.synchronize()
|
||||
|
||||
def bench_old(model, tokenizer, task_inputs, device):
    """Time the old sequential evaluation over every task.

    Returns (elapsed_seconds, results) where results maps label -> accuracy.
    """
    sync_cuda()
    start = time.time()
    results = {}
    total = sum(len(data) for _, _, data in task_inputs)
    max_label_len = max(len(label) for label, _, _ in task_inputs)
    pbar = tqdm(total=total, desc="Sequential", leave=False)
    for label, task_meta, data in task_inputs:
        pbar.set_description(f"Sequential: {label:<{max_label_len}}")
        results[label] = evaluate_task_old(model, tokenizer, data, device, task_meta, pbar=pbar)
    pbar.close()
    sync_cuda()
    return time.time() - start, results
|
||||
|
||||
|
||||
def bench_new_first(model, tokenizer, task_inputs, device, batch_size, queue_size, pbar=None):
    """Time the new batched evaluation from scratch (prepare + collate + forward).

    Returns (elapsed_seconds, results, collated_cache); the collated cache lets
    subsequent cached runs skip the prepare/collate stages entirely.
    """
    sync_cuda()
    start = time.time()
    results, collated_cache = {}, {}
    max_seq_len = getattr(model, 'max_seq_len', None)
    max_label_len = max(len(label) for label, _, _ in task_inputs)
    for label, task_meta, data in task_inputs:
        if pbar is not None:
            pbar.set_description(f"{label:<{max_label_len}}")
        prepared = prepare_task_data(tokenizer, data, task_meta, max_seq_len)
        acc, collated = evaluate_task(model, data, device, batch_size=batch_size,
                                      queue_size=queue_size, prepared=prepared, pbar=pbar)
        results[label] = acc
        collated_cache[label] = collated
    sync_cuda()
    return time.time() - start, results, collated_cache
|
||||
|
||||
|
||||
def bench_new_cached(model, task_inputs, device, collated_cache, pbar=None,
                     batched=False, merge=1, split=1, pad_token_id=0):
    """Time the new batched evaluation with pre-collated batches (forward only).

    Uses a continuous pipeline across all tasks to eliminate inter-task stalls.

    Args:
        batched: False (default) forwards per-example, identical to sequential;
            True uses GPU composition with merge/split for speed experiments.
        merge, split: compose larger/smaller effective batch sizes from the
            cached base batches (only meaningful with batched=True).
        pad_token_id: pad id used when re-padding during composition.

    Returns (elapsed_seconds, results) where results maps label -> accuracy.
    """
    # `dist` is already imported at module scope; the former function-local
    # `import torch.distributed as dist` was redundant and has been removed.
    world_size = dist.get_world_size() if dist.is_initialized() else 1
    sync_cuda()
    t0 = time.time()
    task_collated = [(collated_cache[label], data) for label, _, data in task_inputs]
    correct_list = _forward_all_cached(model, task_collated, device, pbar=pbar,
                                       batched=batched, merge=merge, split=split,
                                       pad_token_id=pad_token_id)
    sync_cuda()
    elapsed = time.time() - t0
    results = {}
    for (label, _, data), correct in zip(task_inputs, correct_list):
        if world_size > 1:
            dist.barrier()
            dist.all_reduce(correct, op=dist.ReduceOp.SUM)
        results[label] = correct.mean().item()
    return elapsed, results
|
||||
|
||||
|
||||
def verify_results(old_results, new_results, label="new"):
    """Check that old and new accuracies agree per task (tolerance 1e-6); print any mismatches."""
    mismatches = [
        (task, old_results[task], new_results[task])
        for task in old_results
        if task in new_results and abs(old_results[task] - new_results[task]) > 1e-6
    ]
    if not mismatches:
        print0(f" {label} results match old (verified)")
        return
    print0(f" WARNING: {label} mismatches vs old:")
    for task, old, new in mismatches:
        print0(f" {task}: old={old:.6f} {label}={new:.6f}")
|
||||
|
||||
|
||||
# ---- main ----
|
||||
|
||||
def main():
    """Benchmark driver: compares old sequential CORE eval vs the new batched
    pipeline (first run and cached run), sweeping batch_size/queue_size, and
    verifies that all modes produce identical accuracies."""
    parser = argparse.ArgumentParser(description="Benchmark CORE eval pipeline")
    parser.add_argument('--hf-path', type=str, default=None, help='HuggingFace model path')
    parser.add_argument('--model-tag', type=str, default=None, help='nanochat model tag')
    parser.add_argument('--step', type=int, default=None, help='Model step to load')
    parser.add_argument('--max-per-task', type=int, default=500, help='Max examples per task')
    parser.add_argument('--device-type', type=str, default='', help='cuda|cpu|mps')
    parser.add_argument('--batch-sizes', type=str, default='1,2,4,8,16,32,64', help='Comma-separated batch sizes to sweep')
    parser.add_argument('--queue-sizes', type=str, default='2,4,8,16', help='Comma-separated queue sizes to sweep')
    parser.add_argument('--skip-old', action='store_true', help='Skip old sequential baseline (slow)')
    parser.add_argument('--skip-first', action='store_true', help='Skip first-run sweep (requires cached collated data)')
    args = parser.parse_args()

    batch_sizes = [int(x) for x in args.batch_sizes.split(',')]
    queue_sizes = [int(x) for x in args.queue_sizes.split(',')]

    # Compute/device setup; bf16 autocast only makes sense on CUDA
    device_type = autodetect_device_type() if args.device_type == '' else args.device_type
    ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
    autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()

    # Load model: either a HuggingFace checkpoint or a local nanochat base model
    if args.hf_path is not None:
        model, tokenizer = load_hf_model(args.hf_path, device)
        model_name = args.hf_path
    else:
        model, tokenizer, meta = load_model("base", device, phase="eval", model_tag=args.model_tag, step=args.step)
        model_name = f"base_model (step {meta['step']})"

    print0(f"Model: {model_name}")
    print0(f"Max per task: {args.max_per_task}")
    print0(f"Device: {device}")
    print0("")

    # Load tasks
    task_inputs = load_tasks(max_per_task=args.max_per_task)
    total_examples = sum(len(data) for _, _, data in task_inputs)
    print0(f"Loaded {len(task_inputs)} tasks, {total_examples} total examples")
    print0("")

    # ---- 1. Old sequential baseline ----
    old_results = None
    if not args.skip_old:
        print0("=" * 80)
        print0("OLD: Sequential per-example evaluation")
        print0("=" * 80)
        with autocast_ctx:
            old_time, old_results = bench_old(model, tokenizer, task_inputs, device)
        print0(f" Time: {old_time:.2f}s ({total_examples / old_time:.1f} examples/s)")
        print0("")

    # ---- 2. Sweep batch_size x queue_size for first run ----
    # Compute tok_hash early — needed for both skip-first check and cache loading
    max_seq_len = getattr(model, 'max_seq_len', None)
    if args.hf_path is not None:
        # For HF models, hash the model path itself (no local tokenizer file)
        tok_hash = hashlib.sha256(args.hf_path.encode()).hexdigest()[:16]
    else:
        # NOTE(review): if neither tokenizer file exists, tok_hash stays unbound
        # and the later f-string raises NameError — confirm this is unreachable.
        base_dir = get_base_dir()
        for fname in ("tokenizer.pkl", "tokenizer.json"):
            tok_path = os.path.join(base_dir, "tokenizer", fname)
            if os.path.exists(tok_path):
                tok_hash = file_checksum(tok_path)
                break

    # Check if we can skip the first-run sweep
    best_time = None
    best_params = None
    if args.skip_first:
        cache_dir = os.path.join(get_base_dir(), "core_token_cache", f"{tok_hash}_n{args.max_per_task}")
        has_cache = os.path.isdir(cache_dir) and all(
            os.path.exists(os.path.join(cache_dir, f"{label}.pt"))
            for label, _, _ in task_inputs
        )
        if has_cache:
            print0("Skipping first-run sweep (--skip-first, cache exists)")
            print0("")
        else:
            # Cache missing: do one base-batch-size run purely to populate it
            print0(f"--skip-first: cache missing, running single bs={BASE_BATCH_SIZE} eval to create it...")
            pbar = tqdm(total=total_examples, desc="Creating cache", leave=False)
            with autocast_ctx:
                _, _, collated_cache = bench_new_first(model, tokenizer, task_inputs, device,
                                                       batch_size=BASE_BATCH_SIZE, queue_size=2, pbar=pbar)
            pbar.close()
            pad_id = tokenizer.get_bos_token_id()
            os.makedirs(cache_dir, exist_ok=True)
            for label in collated_cache:
                torch.save({'collated': collated_cache[label], 'pad_token_id': pad_id},
                           os.path.join(cache_dir, f"{label}.pt"))
            print0(f"Cache created ({len(collated_cache)} tasks)")
            print0("")

    if not args.skip_first:
        print0("=" * 80)
        print0("NEW: Batched evaluation — hyperparameter sweep (first run)")
        print0("=" * 80)
        print0("")

        # Header
        qs_header = "".join(f"{'q=' + str(q):>10}" for q in queue_sizes)
        print0(f" {'batch_size':>10}{qs_header}")
        print0(f" {'':>10}" + "-" * (10 * len(queue_sizes)))

        best_time = float('inf')
        best_params = None
        best_collated = None
        sweep_results = {}
        total_combos = len(batch_sizes) * len(queue_sizes)
        outer_pbar = tqdm(total=total_combos, desc="First-run sweep", leave=False, position=0)
        inner_pbar = tqdm(total=total_examples, desc="", leave=False, position=1)

        # Full grid: every (batch_size, queue_size) combo, tracking the fastest
        for bs in batch_sizes:
            row = f" {bs:>10}"
            for qs in queue_sizes:
                outer_pbar.set_description(f"First-run: bs={bs} qs={qs}")
                inner_pbar.reset()
                with autocast_ctx:
                    t, results, collated_cache = bench_new_first(model, tokenizer, task_inputs, device, bs, qs, pbar=inner_pbar)
                sweep_results[(bs, qs)] = t
                row += f"{t:>9.2f}s"
                if t < best_time:
                    best_time = t
                    best_params = (bs, qs)
                    best_collated = collated_cache
                    best_first_results = results
                outer_pbar.update(1)
            outer_pbar.write(row)
        inner_pbar.close()
        outer_pbar.close()

        print0("")
        print0(f" Best: batch_size={best_params[0]}, queue_size={best_params[1]} -> {best_time:.2f}s ({total_examples / best_time:.1f} examples/s)")

        # Verify correctness
        if old_results is not None:
            verify_results(old_results, best_first_results, label="new-first")
        print0("")

    # ---- 3. Build/load base-4 collated cache ----
    base_cache = build_or_load_base_collated(tok_hash, tokenizer, task_inputs, max_seq_len, args.max_per_task)

    # ---- 4. Cached run sweep (forward only, composed from base-4) ----
    print0("=" * 80)
    print0(f"NEW: Cached run (forward only, composed from base-{BASE_BATCH_SIZE})")
    print0("=" * 80)
    print0("")

    best_cached_time = float('inf')
    best_cached_params = None
    pad_id = next(iter(base_cache.values()))[1]
    # Preload ALL base-4 batches to GPU once (~144MB for full CORE eval).
    # All batch-size sweeps compose from these GPU-resident tensors — zero CPU→GPU transfers.
    gpu_collated = {}
    for label, (collated, _) in base_cache.items():
        gpu_collated[label] = [(ids.to(device), meta) for ids, meta in collated]
    outer_pbar = tqdm(total=len(batch_sizes), desc="Cached sweep", leave=False, position=0)
    inner_pbar = tqdm(total=total_examples, desc="", leave=False, position=1)

    for bs in batch_sizes:
        outer_pbar.set_description(f"Cached: bs={bs}")
        inner_pbar.reset()

        # All composition happens on GPU: merge for bs >= base, split for bs < base
        if bs >= BASE_BATCH_SIZE:
            merge, split = bs // BASE_BATCH_SIZE, 1
        else:
            merge, split = 1, BASE_BATCH_SIZE // bs

        with autocast_ctx:
            t, cached_results = bench_new_cached(model, task_inputs, device, gpu_collated,
                                                 pbar=inner_pbar, batched=True,
                                                 merge=merge, split=split, pad_token_id=pad_id)

        outer_pbar.write(f" batch_size={bs:>3}: {t:.2f}s ({total_examples / t:.1f} examples/s)")
        outer_pbar.update(1)

        if t < best_cached_time:
            best_cached_time = t
            best_cached_params = bs
            best_cached_results = cached_results
    inner_pbar.close()
    outer_pbar.close()

    print0("")
    print0(f" Best: batch_size={best_cached_params} -> {best_cached_time:.2f}s ({total_examples / best_cached_time:.1f} examples/s)")

    # Verify with per-example forwarding (identical to sequential — must match old)
    inner_pbar = tqdm(total=total_examples, desc="Verifying", leave=False)
    with autocast_ctx:
        _, exact_results = bench_new_cached(model, task_inputs, device, gpu_collated, pbar=inner_pbar)
    inner_pbar.close()
    # Prefer the old baseline as reference; fall back to the best first-run sweep
    ref_results = old_results or (best_first_results if best_time is not None else None)
    if ref_results is not None:
        verify_results(ref_results, exact_results, label="new-cached(per-example)")
    print0("")

    # ---- Summary ----
    print0("=" * 80)
    print0("SUMMARY")
    print0("=" * 80)
    if old_results is not None:
        print0(f" Old (sequential): {old_time:>8.2f}s")
    if best_time is not None:
        print0(f" New (first run): {best_time:>8.2f}s batch_size={best_params[0]}, queue_size={best_params[1]}")
    print0(f" New (cached): {best_cached_time:>8.2f}s batch_size={best_cached_params}")
    if old_results is not None:
        if best_time is not None:
            print0(f" Speedup (first): {old_time / best_time:>8.2f}x")
        print0(f" Speedup (cached): {old_time / best_cached_time:>8.2f}x")

    compute_cleanup()


if __name__ == "__main__":
    main()
|
||||
Loading…
Reference in New Issue
Block a user