CORE eval: GPU-resident data, continuous pipeline, per-task progress bars

three independent improvements to the cached CORE evaluation path:

   1. GPU-resident data: all base-4 collated batches (~144MB for full CORE eval)
      are moved to GPU upfront via .to(device). eliminates all CPU→GPU transfers
      from the forward loop. _forward_all_cached replaces double-buffered prefetch
      with a simple upfront bulk transfer — .to() is a no-op when the caller has
      already preloaded tensors to GPU (as bench_core_eval now does).

   2. continuous cross-task pipeline: _forward_all_cached flattens all tasks'
      batches into one stream. the last batch of task N flows directly into the
      first batch of task N+1 with no pipeline restart. GPU-side composition via
      merge (pad+cat for bs > base) and split (row-slice for bs < base) avoids
      the CPU-side compose_collated bottleneck that made bs=8 slower than bs=4.

   3. progress bars + per-task result printing: both cached and first-run paths
      in evaluate_model now show a tqdm progress bar with the current task label.
      on_task_done callback in _forward_all_cached prints each task's accuracy
      as soon as its last batch is processed (single-GPU). DDP falls back to
      printing after all_reduce. both paths print total elapsed time at the end.

   bench_core_eval: preloads ALL base-4 batches to GPU once before the batch-size
   sweep. all sweep iterations compose from GPU-resident tensors with zero
   CPU→GPU transfers in the hot loop.
This commit is contained in:
Unsal Gokdag 2026-02-13 07:54:53 +00:00
parent 7fa30f5ee3
commit c3f234cfca
3 changed files with 220 additions and 54 deletions

View File

@ -325,6 +325,117 @@ def _forward_batches(model, collated, data, device, pbar=None):
return correct
def _forward_all_cached(model, task_collated, device, pbar=None, task_labels=None,
on_task_done=None, merge=1, split=1, pad_token_id=0):
"""Run all tasks' cached batches through the model in one pass.
All batch tensors are moved to device upfront (~144MB for full CORE eval).
If tensors are already on device (caller preloaded), .to() is a no-op.
Composition (merge/split) happens entirely on device:
- merge > 1: pad+cat consecutive base batches on GPU before forwarding.
- split > 1: slice each group into chunks by example boundaries,
forward each chunk separately.
Args:
task_collated: list of (collated_batches, data) per task
pbar: optional progress bar, updated per forward pass (by number of examples)
task_labels: optional list of task names for pbar description updates
on_task_done: optional callback(task_idx, correct_tensor) fired when a task completes
merge: number of consecutive base batches to compose per group (>= 1)
split: number of forward passes to split each group into (>= 1)
pad_token_id: token id used for padding when merging batches of different lengths
Returns:
list of correct tensors (one per task, on device)
"""
# Flatten all batches and move to device upfront (no-op if already there)
flat_stream = [] # (gpu_ids, batch_meta, task_idx)
correct = []
task_batch_counts = []
for task_idx, (collated, data) in enumerate(task_collated):
correct.append(torch.zeros(len(data), dtype=torch.float32, device=device))
task_batch_counts.append(len(collated))
for combined_ids, batch_meta in collated:
flat_stream.append((combined_ids.to(device), batch_meta, task_idx))
if not flat_stream:
return correct
task_batches_remaining = list(task_batch_counts)
current_task = -1
buffer_ids = []
buffer_info = []
for i, (combined_ids, batch_meta, task_idx) in enumerate(flat_stream):
# Update pbar description on task transition
if task_idx != current_task:
current_task = task_idx
if pbar is not None and task_labels is not None:
pbar.set_description(task_labels[task_idx])
buffer_ids.append(combined_ids)
buffer_info.append((batch_meta, task_idx))
# Accumulate until we have `merge` batches (or hit the end)
if len(buffer_ids) < merge and i < len(flat_stream) - 1:
continue
# GPU compose: pad+cat if multiple batches, otherwise use as-is
if len(buffer_ids) == 1:
mega_ids = buffer_ids[0]
else:
max_len = max(t.shape[1] for t in buffer_ids)
parts = []
for t in buffer_ids:
if t.shape[1] < max_len:
pad = torch.full((t.shape[0], max_len - t.shape[1]), pad_token_id,
dtype=t.dtype, device=t.device)
t = torch.cat([t, pad], dim=1)
parts.append(t)
mega_ids = torch.cat(parts, dim=0)
# Flatten examples with row boundaries (for splitting)
examples = []
row_bounds = [0]
for bm, tidx in buffer_info:
for idx, n, start_idxs, end_idxs, gold, task_type in bm:
examples.append((idx, n, start_idxs, end_idxs, gold, task_type, tidx))
row_bounds.append(row_bounds[-1] + n)
# Forward + score (with optional GPU split)
n_ex = len(examples)
chunk_size = -(-n_ex // split) # ceiling division
for cs in range(0, n_ex, chunk_size):
ce = min(cs + chunk_size, n_ex)
chunk = examples[cs:ce]
chunk_ids = mega_ids[row_bounds[cs]:row_bounds[ce]]
losses, predictions = forward_model(model, chunk_ids)
offset = 0
for idx, n, start_idxs, end_idxs, gold, task_type, tidx in chunk:
is_correct = check_result(
losses[offset:offset+n], predictions[offset:offset+n],
chunk_ids[offset:offset+n],
start_idxs, end_idxs, gold, task_type,
)
correct[tidx][idx] = float(is_correct)
offset += n
if pbar is not None:
pbar.update(len(chunk))
# Fire callback for any tasks that just completed all their batches
if on_task_done is not None:
for bm, tidx in buffer_info:
task_batches_remaining[tidx] -= 1
if task_batches_remaining[tidx] == 0:
on_task_done(tidx, correct[tidx])
buffer_ids.clear()
buffer_info.clear()
return correct
def compose_collated(base_collated, target_batch_size, base_batch_size=4, pad_token_id=0):
"""Compose base-sized collated batches into target-sized batches.

View File

@ -31,13 +31,14 @@ import zipfile
import tempfile
import argparse
from contextlib import nullcontext
from tqdm import tqdm
import torch
from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, autodetect_device_type, download_file_with_lock
from nanochat.tokenizer import HuggingFaceTokenizer, get_token_bytes
from nanochat.checkpoint_manager import load_model
from nanochat.core_eval import evaluate_task, prepare_task_data
from nanochat.core_eval import evaluate_task, prepare_task_data, _forward_all_cached
from nanochat.dataloader import tokenizing_distributed_data_loader_bos_bestfit
from nanochat.loss_eval import evaluate_bpb
from nanochat.engine import Engine
@ -214,55 +215,91 @@ def evaluate_model(model, tokenizer, device, max_per_task=-1):
first_run = not cached_run # track whether we did prepare+collate (for disk save)
if not cached_run:
executor = ThreadPoolExecutor(max_workers=1)
first_uncached = next(i for i, (l, _, _) in enumerate(task_inputs) if l not in _batch_cache)
_, first_meta, first_data = task_inputs[first_uncached]
next_future = executor.submit(prepare_task_data, tokenizer, first_data, first_meta, max_seq_len)
import torch.distributed as dist
rank = dist.get_rank() if dist.is_initialized() else 0
world_size = dist.get_world_size() if dist.is_initialized() else 1
total_examples = sum(len(data) for _, _, data in task_inputs)
task_labels = [label for label, _, _ in task_inputs]
pbar = tqdm(total=total_examples, leave=False, disable=(rank != 0))
for i, (label, task_meta, data) in enumerate(task_inputs):
shot_str = f"{task_meta['num_fewshot']}-shot"
prefix = f" {label:<{w_label}} {shot_str:<{w_shot}} {task_meta['task_type']:<{w_type}}"
print0(f"{prefix} ...", end="", flush=True)
t0 = time.time()
if label in _batch_cache:
accuracy, collated = evaluate_task(model, data, device, collated=_batch_cache[label])
else:
prepared = next_future.result()
# Kick off prepare for the next uncached task
for j in range(i + 1, len(task_inputs)):
next_label, next_meta, next_data = task_inputs[j]
if next_label not in _batch_cache:
next_future = executor.submit(prepare_task_data, tokenizer, next_data, next_meta, max_seq_len)
break
accuracy, collated = evaluate_task(model, data, device, prepared=prepared)
_batch_cache[label] = collated
elapsed = time.time() - t0
def _print_task_result(tidx, accuracy):
"""Print one task's result. Updates results/centered_results dicts."""
label, task_meta, _ = task_inputs[tidx]
random_baseline = random_baselines[label]
centered_result = (accuracy - 0.01 * random_baseline) / (1.0 - 0.01 * random_baseline)
results[label] = accuracy
centered_results[label] = centered_result
shot_str = f"{task_meta['num_fewshot']}-shot"
prefix = f" {label:<{w_label}} {shot_str:<{w_shot}} {task_meta['task_type']:<{w_type}}"
delta_str = ""
if label in _prev_centered:
d = centered_result - _prev_centered[label]
arrow = "\u2191" if d > 0 else "\u2193" if d < 0 else "="
delta_str = f" {arrow}{d:+.4f}"
print0(f"\r{prefix} acc: {accuracy:.4f} centered: {centered_result:>7.4f}{delta_str} time: {elapsed:.2f}s")
if rank == 0:
pbar.write(f"{prefix} acc: {accuracy:.4f} centered: {centered_result:>7.4f}{delta_str}")
if not cached_run:
executor.shutdown(wait=False)
def _on_task_done(tidx, correct):
"""Callback for _forward_all_cached: convert tensor to accuracy and print."""
_print_task_result(tidx, correct.mean().item())
if cached_run:
# Continuous pipeline: all tasks in one GPU stream, results printed per-task as they complete
t0 = time.time()
task_collated = [(_batch_cache[label], data) for label, _, data in task_inputs]
correct_list = _forward_all_cached(
model, task_collated, device, pbar=pbar, task_labels=task_labels,
on_task_done=_on_task_done if world_size == 1 else None,
)
elapsed_total = time.time() - t0
pbar.close()
# DDP: all_reduce + print (single-GPU already handled by on_task_done above)
if world_size > 1:
for tidx, ((label, task_meta, data), correct) in enumerate(zip(task_inputs, correct_list)):
dist.barrier()
dist.all_reduce(correct, op=dist.ReduceOp.SUM)
_print_task_result(tidx, correct.mean().item())
print0(f" (all tasks: {elapsed_total:.2f}s)")
else:
t0 = time.time()
executor = ThreadPoolExecutor(max_workers=1)
first_uncached = next(i for i, (l, _, _) in enumerate(task_inputs) if l not in _batch_cache)
_, first_meta, first_data = task_inputs[first_uncached]
next_future = executor.submit(prepare_task_data, tokenizer, first_data, first_meta, max_seq_len)
for i, (label, task_meta, data) in enumerate(task_inputs):
pbar.set_description(f"{label:<{w_label}}")
# Save collated batches to disk after first run (so bench/future runs skip prepare+collate)
if first_run and disk_cache_dir is not None:
pad_id = tokenizer.get_bos_token_id()
os.makedirs(disk_cache_dir, exist_ok=True)
for label, _, _ in task_inputs:
if label in _batch_cache:
torch.save({'collated': _batch_cache[label], 'pad_token_id': pad_id},
os.path.join(disk_cache_dir, f"{label}.pt"))
print0(f" (saved collated batches to {disk_cache_dir})")
accuracy, collated = evaluate_task(model, data, device, collated=_batch_cache[label], pbar=pbar)
else:
prepared = next_future.result()
# Kick off prepare for the next uncached task
for j in range(i + 1, len(task_inputs)):
next_label, next_meta, next_data = task_inputs[j]
if next_label not in _batch_cache:
next_future = executor.submit(prepare_task_data, tokenizer, next_data, next_meta, max_seq_len)
break
accuracy, collated = evaluate_task(model, data, device, prepared=prepared, pbar=pbar)
_batch_cache[label] = collated
_print_task_result(i, accuracy)
elapsed_total = time.time() - t0
pbar.close()
executor.shutdown(wait=False)
print0(f" (all tasks: {elapsed_total:.2f}s)")
# Save collated batches to disk after first run (so bench/future runs skip prepare+collate)
if first_run and disk_cache_dir is not None:
pad_id = tokenizer.get_bos_token_id()
os.makedirs(disk_cache_dir, exist_ok=True)
for label, _, _ in task_inputs:
if label in _batch_cache:
torch.save({'collated': _batch_cache[label], 'pad_token_id': pad_id},
os.path.join(disk_cache_dir, f"{label}.pt"))
print0(f" (saved collated batches to {disk_cache_dir})")
core_metric = sum(centered_results.values()) / len(centered_results)
if _prev_core is not None:

View File

@ -35,8 +35,8 @@ from nanochat.tokenizer import HuggingFaceTokenizer
from nanochat.checkpoint_manager import load_model
from nanochat.core_eval import (
forward_model, prepare_example, check_result, stack_sequences,
prepare_task_data, _collate_batches, _forward_batches, evaluate_task,
compose_collated,
prepare_task_data, _collate_batches, _forward_all_cached,
evaluate_task,
render_prompts_mc, render_prompts_schema, render_prompts_lm,
batch_sequences_mc, batch_sequences_schema, batch_sequences_lm,
)
@ -270,19 +270,27 @@ def bench_new_first(model, tokenizer, task_inputs, device, batch_size, queue_siz
return time.time() - t0, results, collated_cache
def bench_new_cached(model, task_inputs, device, collated_cache, pbar=None):
"""Benchmark new batched evaluation (cached run, forward only)."""
def bench_new_cached(model, task_inputs, device, collated_cache, pbar=None,
merge=1, split=1, pad_token_id=0):
"""Benchmark new batched evaluation (cached run, forward only).
Uses continuous pipeline across all tasks to eliminate inter-task stalls.
merge/split control GPU-side composition: merge > 1 cats batches, split > 1 slices them."""
import torch.distributed as dist
world_size = dist.get_world_size() if dist.is_initialized() else 1
sync_cuda()
t0 = time.time()
results = {}
max_label_len = max(len(label) for label, _, _ in task_inputs)
for label, task_meta, data in task_inputs:
if pbar is not None:
pbar.set_description(f"{label:<{max_label_len}}")
acc, _ = evaluate_task(model, data, device, collated=collated_cache[label], pbar=pbar)
results[label] = acc
task_collated = [(collated_cache[label], data) for label, _, data in task_inputs]
correct_list = _forward_all_cached(model, task_collated, device, pbar=pbar,
merge=merge, split=split, pad_token_id=pad_token_id)
sync_cuda()
return time.time() - t0, results
elapsed = time.time() - t0
results = {}
for (label, _, data), correct in zip(task_inputs, correct_list):
if world_size > 1:
dist.barrier()
dist.all_reduce(correct, op=dist.ReduceOp.SUM)
results[label] = correct.mean().item()
return elapsed, results
def verify_results(old_results, new_results, label="new"):
@ -448,19 +456,29 @@ def main():
best_cached_time = float('inf')
best_cached_params = None
pad_id = next(iter(base_cache.values()))[1]
# Preload ALL base-4 batches to GPU once (~144MB for full CORE eval).
# All batch-size sweeps compose from these GPU-resident tensors — zero CPU→GPU transfers.
gpu_collated = {}
for label, (collated, _) in base_cache.items():
gpu_collated[label] = [(ids.to(device), meta) for ids, meta in collated]
outer_pbar = tqdm(total=len(batch_sizes), desc="Cached sweep", leave=False, position=0)
inner_pbar = tqdm(total=total_examples, desc="", leave=False, position=1)
for bs in batch_sizes:
outer_pbar.set_description(f"Cached: bs={bs}")
inner_pbar.reset()
# Compose from base-4 to target batch_size (merge or split)
composed_cache = {}
for label, (collated, pad_id) in base_cache.items():
composed_cache[label] = compose_collated(collated, bs, BASE_BATCH_SIZE, pad_id)
# All composition happens on GPU: merge for bs >= base, split for bs < base
if bs >= BASE_BATCH_SIZE:
merge, split = bs // BASE_BATCH_SIZE, 1
else:
merge, split = 1, BASE_BATCH_SIZE // bs
with autocast_ctx:
t, cached_results = bench_new_cached(model, task_inputs, device, composed_cache, pbar=inner_pbar)
t, cached_results = bench_new_cached(model, task_inputs, device, gpu_collated,
pbar=inner_pbar, merge=merge, split=split,
pad_token_id=pad_id)
outer_pbar.write(f" batch_size={bs:>3}: {t:.2f}s ({total_examples / t:.1f} examples/s)")
outer_pbar.update(1)