mirror of
https://github.com/karpathy/nanochat.git
synced 2026-04-03 22:25:27 +00:00
CORE eval: GPU-resident data, continuous pipeline, per-task progress bars
three independent improvements to the cached CORE evaluation path:
1. GPU-resident data: all base-4 collated batches (~144MB for full CORE eval)
are moved to GPU upfront via .to(device). eliminates all CPU→GPU transfers
from the forward loop. _forward_all_cached replaces double-buffered prefetch
with a simple upfront bulk transfer — .to() is a no-op when the caller has
already preloaded tensors to GPU (as bench_core_eval now does).
2. continuous cross-task pipeline: _forward_all_cached flattens all tasks'
batches into one stream. the last batch of task N flows directly into the
first batch of task N+1 with no pipeline restart. GPU-side composition via
merge (pad+cat for bs > base) and split (row-slice for bs < base) avoids
the CPU-side compose_collated bottleneck that made bs=8 slower than bs=4.
3. progress bars + per-task result printing: both cached and first-run paths
in evaluate_model now show a tqdm progress bar with the current task label.
on_task_done callback in _forward_all_cached prints each task's accuracy
as soon as its last batch is processed (single-GPU). DDP falls back to
printing after all_reduce. both paths print total elapsed time at the end.
bench_core_eval: preloads ALL base-4 batches to GPU once before the batch-size
sweep. all sweep iterations compose from GPU-resident tensors with zero
CPU→GPU transfers in the hot loop.
This commit is contained in:
parent
7fa30f5ee3
commit
c3f234cfca
|
|
@ -325,6 +325,117 @@ def _forward_batches(model, collated, data, device, pbar=None):
|
|||
return correct
|
||||
|
||||
|
||||
def _forward_all_cached(model, task_collated, device, pbar=None, task_labels=None,
                        on_task_done=None, merge=1, split=1, pad_token_id=0):
    """Run all tasks' cached batches through the model in one continuous pass.

    All batch tensors are moved to device upfront (~144MB for full CORE eval).
    If tensors are already on device (caller preloaded), .to() is a no-op.
    Batches from every task are flattened into a single stream, so the last
    batch of one task flows directly into the first batch of the next.
    Composition (merge/split) happens entirely on device:
    - merge > 1: pad+cat `merge` consecutive base batches on GPU before
      forwarding (shorter batches are right-padded with pad_token_id).
    - split > 1: slice each composed group into chunks along example
      boundaries, forward each chunk separately.

    Args:
        model: model handed through to forward_model for each chunk.
        task_collated: list of (collated_batches, data) per task; each
            collated batch is a (combined_ids, batch_meta) pair whose
            batch_meta rows are (idx, n, start_idxs, end_idxs, gold,
            task_type) tuples — `n` is the number of rows the example
            occupies in combined_ids.
        device: target device for all forward passes.
        pbar: optional progress bar, updated per forward pass (by number of examples)
        task_labels: optional list of task names for pbar description updates
        on_task_done: optional callback(task_idx, correct_tensor) fired when a task completes
        merge: number of consecutive base batches to compose per group (>= 1)
        split: number of forward passes to split each group into (>= 1)
        pad_token_id: token id used for padding when merging batches of different lengths

    Returns:
        list of correct tensors (one per task, float32 on device; one entry
        per example, set to 1.0/0.0 as examples are scored)
    """
    # Flatten all batches and move to device upfront (no-op if already there)
    flat_stream = []  # (gpu_ids, batch_meta, task_idx)
    correct = []
    task_batch_counts = []
    for task_idx, (collated, data) in enumerate(task_collated):
        correct.append(torch.zeros(len(data), dtype=torch.float32, device=device))
        task_batch_counts.append(len(collated))
        for combined_ids, batch_meta in collated:
            flat_stream.append((combined_ids.to(device), batch_meta, task_idx))

    # No batches at all: return the (empty or all-zero) per-task tensors as-is.
    if not flat_stream:
        return correct

    task_batches_remaining = list(task_batch_counts)
    current_task = -1
    buffer_ids = []   # base batch tensors accumulated for the current merge group
    buffer_info = []  # (batch_meta, task_idx) parallel to buffer_ids

    for i, (combined_ids, batch_meta, task_idx) in enumerate(flat_stream):
        # Update pbar description on task transition
        if task_idx != current_task:
            current_task = task_idx
            if pbar is not None and task_labels is not None:
                pbar.set_description(task_labels[task_idx])
        buffer_ids.append(combined_ids)
        buffer_info.append((batch_meta, task_idx))

        # Accumulate until we have `merge` batches (or hit the end, which
        # flushes a partial group). NOTE(review): a group may straddle a task
        # boundary; the per-example task_idx in buffer_info keeps scoring
        # attributed to the right task.
        if len(buffer_ids) < merge and i < len(flat_stream) - 1:
            continue

        # GPU compose: pad+cat if multiple batches, otherwise use as-is
        if len(buffer_ids) == 1:
            mega_ids = buffer_ids[0]
        else:
            max_len = max(t.shape[1] for t in buffer_ids)
            parts = []
            for t in buffer_ids:
                if t.shape[1] < max_len:
                    # Right-pad to the group's longest sequence length.
                    pad = torch.full((t.shape[0], max_len - t.shape[1]), pad_token_id,
                                     dtype=t.dtype, device=t.device)
                    t = torch.cat([t, pad], dim=1)
                parts.append(t)
            mega_ids = torch.cat(parts, dim=0)

        # Flatten examples with row boundaries (for splitting); example k
        # owns rows [row_bounds[k], row_bounds[k+1]) of mega_ids.
        examples = []
        row_bounds = [0]
        for bm, tidx in buffer_info:
            for idx, n, start_idxs, end_idxs, gold, task_type in bm:
                examples.append((idx, n, start_idxs, end_idxs, gold, task_type, tidx))
                row_bounds.append(row_bounds[-1] + n)

        # Forward + score (with optional GPU split)
        n_ex = len(examples)
        chunk_size = -(-n_ex // split)  # ceiling division

        for cs in range(0, n_ex, chunk_size):
            ce = min(cs + chunk_size, n_ex)
            chunk = examples[cs:ce]
            # Row-slice view: splitting never copies data on device.
            chunk_ids = mega_ids[row_bounds[cs]:row_bounds[ce]]

            losses, predictions = forward_model(model, chunk_ids)

            # Walk the chunk's rows example by example and score each one.
            offset = 0
            for idx, n, start_idxs, end_idxs, gold, task_type, tidx in chunk:
                is_correct = check_result(
                    losses[offset:offset+n], predictions[offset:offset+n],
                    chunk_ids[offset:offset+n],
                    start_idxs, end_idxs, gold, task_type,
                )
                correct[tidx][idx] = float(is_correct)
                offset += n
            if pbar is not None:
                pbar.update(len(chunk))

        # Fire callback for any tasks that just completed all their batches
        if on_task_done is not None:
            for bm, tidx in buffer_info:
                task_batches_remaining[tidx] -= 1
                if task_batches_remaining[tidx] == 0:
                    on_task_done(tidx, correct[tidx])

        buffer_ids.clear()
        buffer_info.clear()

    return correct
|
||||
|
||||
|
||||
def compose_collated(base_collated, target_batch_size, base_batch_size=4, pad_token_id=0):
|
||||
"""Compose base-sized collated batches into target-sized batches.
|
||||
|
||||
|
|
|
|||
|
|
@ -31,13 +31,14 @@ import zipfile
|
|||
import tempfile
|
||||
import argparse
|
||||
from contextlib import nullcontext
|
||||
from tqdm import tqdm
|
||||
|
||||
import torch
|
||||
|
||||
from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, autodetect_device_type, download_file_with_lock
|
||||
from nanochat.tokenizer import HuggingFaceTokenizer, get_token_bytes
|
||||
from nanochat.checkpoint_manager import load_model
|
||||
from nanochat.core_eval import evaluate_task, prepare_task_data
|
||||
from nanochat.core_eval import evaluate_task, prepare_task_data, _forward_all_cached
|
||||
from nanochat.dataloader import tokenizing_distributed_data_loader_bos_bestfit
|
||||
from nanochat.loss_eval import evaluate_bpb
|
||||
from nanochat.engine import Engine
|
||||
|
|
@ -214,55 +215,91 @@ def evaluate_model(model, tokenizer, device, max_per_task=-1):
|
|||
|
||||
first_run = not cached_run # track whether we did prepare+collate (for disk save)
|
||||
|
||||
if not cached_run:
|
||||
executor = ThreadPoolExecutor(max_workers=1)
|
||||
first_uncached = next(i for i, (l, _, _) in enumerate(task_inputs) if l not in _batch_cache)
|
||||
_, first_meta, first_data = task_inputs[first_uncached]
|
||||
next_future = executor.submit(prepare_task_data, tokenizer, first_data, first_meta, max_seq_len)
|
||||
import torch.distributed as dist
|
||||
rank = dist.get_rank() if dist.is_initialized() else 0
|
||||
world_size = dist.get_world_size() if dist.is_initialized() else 1
|
||||
total_examples = sum(len(data) for _, _, data in task_inputs)
|
||||
task_labels = [label for label, _, _ in task_inputs]
|
||||
pbar = tqdm(total=total_examples, leave=False, disable=(rank != 0))
|
||||
|
||||
for i, (label, task_meta, data) in enumerate(task_inputs):
|
||||
shot_str = f"{task_meta['num_fewshot']}-shot"
|
||||
prefix = f" {label:<{w_label}} {shot_str:<{w_shot}} {task_meta['task_type']:<{w_type}}"
|
||||
print0(f"{prefix} ...", end="", flush=True)
|
||||
t0 = time.time()
|
||||
|
||||
if label in _batch_cache:
|
||||
accuracy, collated = evaluate_task(model, data, device, collated=_batch_cache[label])
|
||||
else:
|
||||
prepared = next_future.result()
|
||||
# Kick off prepare for the next uncached task
|
||||
for j in range(i + 1, len(task_inputs)):
|
||||
next_label, next_meta, next_data = task_inputs[j]
|
||||
if next_label not in _batch_cache:
|
||||
next_future = executor.submit(prepare_task_data, tokenizer, next_data, next_meta, max_seq_len)
|
||||
break
|
||||
accuracy, collated = evaluate_task(model, data, device, prepared=prepared)
|
||||
_batch_cache[label] = collated
|
||||
|
||||
elapsed = time.time() - t0
|
||||
def _print_task_result(tidx, accuracy):
|
||||
"""Print one task's result. Updates results/centered_results dicts."""
|
||||
label, task_meta, _ = task_inputs[tidx]
|
||||
random_baseline = random_baselines[label]
|
||||
centered_result = (accuracy - 0.01 * random_baseline) / (1.0 - 0.01 * random_baseline)
|
||||
results[label] = accuracy
|
||||
centered_results[label] = centered_result
|
||||
shot_str = f"{task_meta['num_fewshot']}-shot"
|
||||
prefix = f" {label:<{w_label}} {shot_str:<{w_shot}} {task_meta['task_type']:<{w_type}}"
|
||||
delta_str = ""
|
||||
if label in _prev_centered:
|
||||
d = centered_result - _prev_centered[label]
|
||||
arrow = "\u2191" if d > 0 else "\u2193" if d < 0 else "="
|
||||
delta_str = f" {arrow}{d:+.4f}"
|
||||
print0(f"\r{prefix} acc: {accuracy:.4f} centered: {centered_result:>7.4f}{delta_str} time: {elapsed:.2f}s")
|
||||
if rank == 0:
|
||||
pbar.write(f"{prefix} acc: {accuracy:.4f} centered: {centered_result:>7.4f}{delta_str}")
|
||||
|
||||
if not cached_run:
|
||||
executor.shutdown(wait=False)
|
||||
def _on_task_done(tidx, correct):
|
||||
"""Callback for _forward_all_cached: convert tensor to accuracy and print."""
|
||||
_print_task_result(tidx, correct.mean().item())
|
||||
|
||||
if cached_run:
|
||||
# Continuous pipeline: all tasks in one GPU stream, results printed per-task as they complete
|
||||
t0 = time.time()
|
||||
task_collated = [(_batch_cache[label], data) for label, _, data in task_inputs]
|
||||
correct_list = _forward_all_cached(
|
||||
model, task_collated, device, pbar=pbar, task_labels=task_labels,
|
||||
on_task_done=_on_task_done if world_size == 1 else None,
|
||||
)
|
||||
elapsed_total = time.time() - t0
|
||||
pbar.close()
|
||||
|
||||
# DDP: all_reduce + print (single-GPU already handled by on_task_done above)
|
||||
if world_size > 1:
|
||||
for tidx, ((label, task_meta, data), correct) in enumerate(zip(task_inputs, correct_list)):
|
||||
dist.barrier()
|
||||
dist.all_reduce(correct, op=dist.ReduceOp.SUM)
|
||||
_print_task_result(tidx, correct.mean().item())
|
||||
print0(f" (all tasks: {elapsed_total:.2f}s)")
|
||||
else:
|
||||
t0 = time.time()
|
||||
executor = ThreadPoolExecutor(max_workers=1)
|
||||
first_uncached = next(i for i, (l, _, _) in enumerate(task_inputs) if l not in _batch_cache)
|
||||
_, first_meta, first_data = task_inputs[first_uncached]
|
||||
next_future = executor.submit(prepare_task_data, tokenizer, first_data, first_meta, max_seq_len)
|
||||
|
||||
for i, (label, task_meta, data) in enumerate(task_inputs):
|
||||
pbar.set_description(f"{label:<{w_label}}")
|
||||
|
||||
# Save collated batches to disk after first run (so bench/future runs skip prepare+collate)
|
||||
if first_run and disk_cache_dir is not None:
|
||||
pad_id = tokenizer.get_bos_token_id()
|
||||
os.makedirs(disk_cache_dir, exist_ok=True)
|
||||
for label, _, _ in task_inputs:
|
||||
if label in _batch_cache:
|
||||
torch.save({'collated': _batch_cache[label], 'pad_token_id': pad_id},
|
||||
os.path.join(disk_cache_dir, f"{label}.pt"))
|
||||
print0(f" (saved collated batches to {disk_cache_dir})")
|
||||
accuracy, collated = evaluate_task(model, data, device, collated=_batch_cache[label], pbar=pbar)
|
||||
else:
|
||||
prepared = next_future.result()
|
||||
# Kick off prepare for the next uncached task
|
||||
for j in range(i + 1, len(task_inputs)):
|
||||
next_label, next_meta, next_data = task_inputs[j]
|
||||
if next_label not in _batch_cache:
|
||||
next_future = executor.submit(prepare_task_data, tokenizer, next_data, next_meta, max_seq_len)
|
||||
break
|
||||
accuracy, collated = evaluate_task(model, data, device, prepared=prepared, pbar=pbar)
|
||||
_batch_cache[label] = collated
|
||||
|
||||
_print_task_result(i, accuracy)
|
||||
|
||||
elapsed_total = time.time() - t0
|
||||
pbar.close()
|
||||
executor.shutdown(wait=False)
|
||||
print0(f" (all tasks: {elapsed_total:.2f}s)")
|
||||
|
||||
# Save collated batches to disk after first run (so bench/future runs skip prepare+collate)
|
||||
if first_run and disk_cache_dir is not None:
|
||||
pad_id = tokenizer.get_bos_token_id()
|
||||
os.makedirs(disk_cache_dir, exist_ok=True)
|
||||
for label, _, _ in task_inputs:
|
||||
if label in _batch_cache:
|
||||
torch.save({'collated': _batch_cache[label], 'pad_token_id': pad_id},
|
||||
os.path.join(disk_cache_dir, f"{label}.pt"))
|
||||
print0(f" (saved collated batches to {disk_cache_dir})")
|
||||
|
||||
core_metric = sum(centered_results.values()) / len(centered_results)
|
||||
if _prev_core is not None:
|
||||
|
|
|
|||
|
|
@ -35,8 +35,8 @@ from nanochat.tokenizer import HuggingFaceTokenizer
|
|||
from nanochat.checkpoint_manager import load_model
|
||||
from nanochat.core_eval import (
|
||||
forward_model, prepare_example, check_result, stack_sequences,
|
||||
prepare_task_data, _collate_batches, _forward_batches, evaluate_task,
|
||||
compose_collated,
|
||||
prepare_task_data, _collate_batches, _forward_all_cached,
|
||||
evaluate_task,
|
||||
render_prompts_mc, render_prompts_schema, render_prompts_lm,
|
||||
batch_sequences_mc, batch_sequences_schema, batch_sequences_lm,
|
||||
)
|
||||
|
|
@ -270,19 +270,27 @@ def bench_new_first(model, tokenizer, task_inputs, device, batch_size, queue_siz
|
|||
return time.time() - t0, results, collated_cache
|
||||
|
||||
|
||||
def bench_new_cached(model, task_inputs, device, collated_cache, pbar=None,
                     merge=1, split=1, pad_token_id=0):
    """Benchmark new batched evaluation (cached run, forward only).

    Uses continuous pipeline across all tasks to eliminate inter-task stalls.
    merge/split control GPU-side composition: merge > 1 cats batches, split > 1 slices them.

    Args:
        model: model under benchmark.
        task_inputs: list of (label, task_meta, data) triples.
        device: device the forward passes run on.
        collated_cache: dict label -> collated batches (as produced by the
            first-run path); tensors may already live on `device`.
        pbar: optional progress bar passed through to _forward_all_cached.
        merge: consecutive base batches to compose per forward group (>= 1).
        split: forward passes to split each group into (>= 1).
        pad_token_id: pad id used when merging batches of different lengths.

    Returns:
        (elapsed_seconds, results) where results maps label -> accuracy.
        Timing covers only the forward stream (sync'd before and after),
        not the DDP reduction/accuracy bookkeeping below.
    """
    import torch.distributed as dist
    world_size = dist.get_world_size() if dist.is_initialized() else 1
    sync_cuda()
    t0 = time.time()
    task_collated = [(collated_cache[label], data) for label, _, data in task_inputs]
    correct_list = _forward_all_cached(model, task_collated, device, pbar=pbar,
                                       merge=merge, split=split, pad_token_id=pad_token_id)
    sync_cuda()
    elapsed = time.time() - t0
    results = {}
    for (label, _, data), correct in zip(task_inputs, correct_list):
        if world_size > 1:
            # DDP: each rank scored a shard; sum the 0/1 tensors across ranks
            # before computing the mean accuracy.
            dist.barrier()
            dist.all_reduce(correct, op=dist.ReduceOp.SUM)
        results[label] = correct.mean().item()
    return elapsed, results
|
||||
|
||||
|
||||
def verify_results(old_results, new_results, label="new"):
|
||||
|
|
@ -448,19 +456,29 @@ def main():
|
|||
|
||||
best_cached_time = float('inf')
|
||||
best_cached_params = None
|
||||
pad_id = next(iter(base_cache.values()))[1]
|
||||
# Preload ALL base-4 batches to GPU once (~144MB for full CORE eval).
|
||||
# All batch-size sweeps compose from these GPU-resident tensors — zero CPU→GPU transfers.
|
||||
gpu_collated = {}
|
||||
for label, (collated, _) in base_cache.items():
|
||||
gpu_collated[label] = [(ids.to(device), meta) for ids, meta in collated]
|
||||
outer_pbar = tqdm(total=len(batch_sizes), desc="Cached sweep", leave=False, position=0)
|
||||
inner_pbar = tqdm(total=total_examples, desc="", leave=False, position=1)
|
||||
|
||||
for bs in batch_sizes:
|
||||
outer_pbar.set_description(f"Cached: bs={bs}")
|
||||
inner_pbar.reset()
|
||||
# Compose from base-4 to target batch_size (merge or split)
|
||||
composed_cache = {}
|
||||
for label, (collated, pad_id) in base_cache.items():
|
||||
composed_cache[label] = compose_collated(collated, bs, BASE_BATCH_SIZE, pad_id)
|
||||
|
||||
# All composition happens on GPU: merge for bs >= base, split for bs < base
|
||||
if bs >= BASE_BATCH_SIZE:
|
||||
merge, split = bs // BASE_BATCH_SIZE, 1
|
||||
else:
|
||||
merge, split = 1, BASE_BATCH_SIZE // bs
|
||||
|
||||
with autocast_ctx:
|
||||
t, cached_results = bench_new_cached(model, task_inputs, device, composed_cache, pbar=inner_pbar)
|
||||
t, cached_results = bench_new_cached(model, task_inputs, device, gpu_collated,
|
||||
pbar=inner_pbar, merge=merge, split=split,
|
||||
pad_token_id=pad_id)
|
||||
|
||||
outer_pbar.write(f" batch_size={bs:>3}: {t:.2f}s ({total_examples / t:.1f} examples/s)")
|
||||
outer_pbar.update(1)
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user