From c3f234cfca4f11a747baaf8d072eab41a132e141 Mon Sep 17 00:00:00 2001 From: Unsal Gokdag Date: Fri, 13 Feb 2026 07:54:53 +0000 Subject: [PATCH] CORE eval: GPU-resident data, continuous pipeline, per-task progress bars MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit three independent improvements to the cached CORE evaluation path: 1. GPU-resident data: all base-4 collated batches (~144MB for full CORE eval) are moved to GPU upfront via .to(device). eliminates all CPU→GPU transfers from the forward loop. _forward_all_cached replaces double-buffered prefetch with a simple upfront bulk transfer — .to() is a no-op when the caller has already preloaded tensors to GPU (as bench_core_eval now does). 2. continuous cross-task pipeline: _forward_all_cached flattens all tasks' batches into one stream. the last batch of task N flows directly into the first batch of task N+1 with no pipeline restart. GPU-side composition via merge (pad+cat for bs > base) and split (row-slice for bs < base) avoids the CPU-side compose_collated bottleneck that made bs=8 slower than bs=4. 3. progress bars + per-task result printing: both cached and first-run paths in evaluate_model now show a tqdm progress bar with the current task label. on_task_done callback in _forward_all_cached prints each task's accuracy as soon as its last batch is processed (single-GPU). DDP falls back to printing after all_reduce. both paths print total elapsed time at the end. bench_core_eval: preloads ALL base-4 batches to GPU once before the batch-size sweep. all sweep iterations compose from GPU-resident tensors with zero CPU→GPU transfers in the hot loop. --- nanochat/core_eval.py | 111 +++++++++++++++++++++++++++++++++++++ scripts/base_eval.py | 111 ++++++++++++++++++++++++------------- scripts/bench_core_eval.py | 52 +++++++++++------ 3 files changed, 220 insertions(+), 54 deletions(-) diff --git a/nanochat/core_eval.py b/nanochat/core_eval.py index 9231520..28f50f5 100644 --- a/nanochat/core_eval.py +++ b/nanochat/core_eval.py @@ -325,6 +325,117 @@ def _forward_batches(model, collated, data, device, pbar=None): return correct +def _forward_all_cached(model, task_collated, device, pbar=None, task_labels=None, + on_task_done=None, merge=1, split=1, pad_token_id=0): + """Run all tasks' cached batches through the model in one pass. + + All batch tensors are moved to device upfront (~144MB for full CORE eval). + If tensors are already on device (caller preloaded), .to() is a no-op. + Composition (merge/split) happens entirely on device: + - merge > 1: pad+cat consecutive base batches on GPU before forwarding. + - split > 1: slice each group into chunks by example boundaries, + forward each chunk separately. + + Args: + task_collated: list of (collated_batches, data) per task + pbar: optional progress bar, updated per forward pass (by number of examples) + task_labels: optional list of task names for pbar description updates + on_task_done: optional callback(task_idx, correct_tensor) fired when a task completes + merge: number of consecutive base batches to compose per group (>= 1) + split: number of forward passes to split each group into (>= 1) + pad_token_id: token id used for padding when merging batches of different lengths + Returns: + list of correct tensors (one per task, on device) + """ + # Flatten all batches and move to device upfront (no-op if already there) + flat_stream = [] # (gpu_ids, batch_meta, task_idx) + correct = [] + task_batch_counts = [] + for task_idx, (collated, data) in enumerate(task_collated): + correct.append(torch.zeros(len(data), dtype=torch.float32, device=device)) + task_batch_counts.append(len(collated)) + for combined_ids, batch_meta in collated: + flat_stream.append((combined_ids.to(device), batch_meta, task_idx)) + + if not flat_stream: + return correct + + task_batches_remaining = list(task_batch_counts) + current_task = -1 + buffer_ids = [] + buffer_info = [] + + for i, (combined_ids, batch_meta, task_idx) in enumerate(flat_stream): + # Update pbar description on task transition + if task_idx != current_task: + current_task = task_idx + if pbar is not None and task_labels is not None: + pbar.set_description(task_labels[task_idx]) + buffer_ids.append(combined_ids) + buffer_info.append((batch_meta, task_idx)) + + # Accumulate until we have `merge` batches (or hit the end) + if len(buffer_ids) < merge and i < len(flat_stream) - 1: + continue + + # GPU compose: pad+cat if multiple batches, otherwise use as-is + if len(buffer_ids) == 1: + mega_ids = buffer_ids[0] + else: + max_len = max(t.shape[1] for t in buffer_ids) + parts = [] + for t in buffer_ids: + if t.shape[1] < max_len: + pad = torch.full((t.shape[0], max_len - t.shape[1]), pad_token_id, + dtype=t.dtype, device=t.device) + t = torch.cat([t, pad], dim=1) + parts.append(t) + mega_ids = torch.cat(parts, dim=0) + + # Flatten examples with row boundaries (for splitting) + examples = [] + row_bounds = [0] + for bm, tidx in buffer_info: + for idx, n, start_idxs, end_idxs, gold, task_type in bm: + examples.append((idx, n, start_idxs, end_idxs, gold, task_type, tidx)) + row_bounds.append(row_bounds[-1] + n) + + # Forward + score (with optional GPU split) + n_ex = len(examples) + chunk_size = -(-n_ex // split) # ceiling division + + for cs in range(0, n_ex, chunk_size): + ce = min(cs + chunk_size, n_ex) + chunk = examples[cs:ce] + chunk_ids = mega_ids[row_bounds[cs]:row_bounds[ce]] + + losses, predictions = forward_model(model, chunk_ids) + + offset = 0 + for idx, n, start_idxs, end_idxs, gold, task_type, tidx in chunk: + is_correct = check_result( + losses[offset:offset+n], predictions[offset:offset+n], + chunk_ids[offset:offset+n], + start_idxs, end_idxs, gold, task_type, + ) + correct[tidx][idx] = float(is_correct) + offset += n + if pbar is not None: + pbar.update(len(chunk)) + + # Fire callback for any tasks that just completed all their batches + if on_task_done is not None: + for bm, tidx in buffer_info: + task_batches_remaining[tidx] -= 1 + if task_batches_remaining[tidx] == 0: + on_task_done(tidx, correct[tidx]) + + buffer_ids.clear() + buffer_info.clear() + + return correct + + def compose_collated(base_collated, target_batch_size, base_batch_size=4, pad_token_id=0): """Compose base-sized collated batches into target-sized batches. diff --git a/scripts/base_eval.py b/scripts/base_eval.py index d1751fc..dcee8da 100644 --- a/scripts/base_eval.py +++ b/scripts/base_eval.py @@ -31,13 +31,14 @@ import zipfile import tempfile import argparse from contextlib import nullcontext +from tqdm import tqdm import torch from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, autodetect_device_type, download_file_with_lock from nanochat.tokenizer import HuggingFaceTokenizer, get_token_bytes from nanochat.checkpoint_manager import load_model -from nanochat.core_eval import evaluate_task, prepare_task_data +from nanochat.core_eval import evaluate_task, prepare_task_data, _forward_all_cached from nanochat.dataloader import tokenizing_distributed_data_loader_bos_bestfit from nanochat.loss_eval import evaluate_bpb from nanochat.engine import Engine @@ -214,55 +215,91 @@ def evaluate_model(model, tokenizer, device, max_per_task=-1): first_run = not cached_run # track whether we did prepare+collate (for disk save) - if not cached_run: - executor = ThreadPoolExecutor(max_workers=1) - first_uncached = next(i for i, (l, _, _) in enumerate(task_inputs) if l not in _batch_cache) - _, first_meta, first_data = task_inputs[first_uncached] - next_future = executor.submit(prepare_task_data, tokenizer, first_data, first_meta, max_seq_len) + import torch.distributed as dist + rank = dist.get_rank() if dist.is_initialized() else 0 + world_size = dist.get_world_size() if dist.is_initialized() else 1 + total_examples = sum(len(data) for _, _, data in task_inputs) + task_labels = [label for label, _, _ in task_inputs] + pbar = tqdm(total=total_examples, leave=False, disable=(rank != 0)) - for i, (label, task_meta, data) in enumerate(task_inputs): - shot_str = f"{task_meta['num_fewshot']}-shot" - prefix = f" {label:<{w_label}} {shot_str:<{w_shot}} {task_meta['task_type']:<{w_type}}" - print0(f"{prefix} ...", end="", flush=True) - t0 = time.time() - - if label in _batch_cache: - accuracy, collated = evaluate_task(model, data, device, collated=_batch_cache[label]) - else: - prepared = next_future.result() - # Kick off prepare for the next uncached task - for j in range(i + 1, len(task_inputs)): - next_label, next_meta, next_data = task_inputs[j] - if next_label not in _batch_cache: - next_future = executor.submit(prepare_task_data, tokenizer, next_data, next_meta, max_seq_len) - break - accuracy, collated = evaluate_task(model, data, device, prepared=prepared) - _batch_cache[label] = collated - - elapsed = time.time() - t0 + def _print_task_result(tidx, accuracy): + """Print one task's result. Updates results/centered_results dicts.""" + label, task_meta, _ = task_inputs[tidx] random_baseline = random_baselines[label] centered_result = (accuracy - 0.01 * random_baseline) / (1.0 - 0.01 * random_baseline) results[label] = accuracy centered_results[label] = centered_result + shot_str = f"{task_meta['num_fewshot']}-shot" + prefix = f" {label:<{w_label}} {shot_str:<{w_shot}} {task_meta['task_type']:<{w_type}}" delta_str = "" if label in _prev_centered: d = centered_result - _prev_centered[label] arrow = "\u2191" if d > 0 else "\u2193" if d < 0 else "=" delta_str = f" {arrow}{d:+.4f}" - print0(f"\r{prefix} acc: {accuracy:.4f} centered: {centered_result:>7.4f}{delta_str} time: {elapsed:.2f}s") + if rank == 0: + pbar.write(f"{prefix} acc: {accuracy:.4f} centered: {centered_result:>7.4f}{delta_str}") - if not cached_run: - executor.shutdown(wait=False) + def _on_task_done(tidx, correct): + """Callback for _forward_all_cached: convert tensor to accuracy and print.""" + _print_task_result(tidx, correct.mean().item()) + + if cached_run: + # Continuous pipeline: all tasks in one GPU stream, results printed per-task as they complete + t0 = time.time() + task_collated = [(_batch_cache[label], data) for label, _, data in task_inputs] + correct_list = _forward_all_cached( + model, task_collated, device, pbar=pbar, task_labels=task_labels, + on_task_done=_on_task_done if world_size == 1 else None, + ) + elapsed_total = time.time() - t0 + pbar.close() + + # DDP: all_reduce + print (single-GPU already handled by on_task_done above) + if world_size > 1: + for tidx, ((label, task_meta, data), correct) in enumerate(zip(task_inputs, correct_list)): + dist.barrier() + dist.all_reduce(correct, op=dist.ReduceOp.SUM) + _print_task_result(tidx, correct.mean().item()) + print0(f" (all tasks: {elapsed_total:.2f}s)") + else: + t0 = time.time() + executor = ThreadPoolExecutor(max_workers=1) + first_uncached = next(i for i, (l, _, _) in enumerate(task_inputs) if l not in _batch_cache) + _, first_meta, first_data = task_inputs[first_uncached] + next_future = executor.submit(prepare_task_data, tokenizer, first_data, first_meta, max_seq_len) + + for i, (label, task_meta, data) in enumerate(task_inputs): + pbar.set_description(f"{label:<{w_label}}") - # Save collated batches to disk after first run (so bench/future runs skip prepare+collate) - if first_run and disk_cache_dir is not None: - pad_id = tokenizer.get_bos_token_id() - os.makedirs(disk_cache_dir, exist_ok=True) - for label, _, _ in task_inputs: if label in _batch_cache: - torch.save({'collated': _batch_cache[label], 'pad_token_id': pad_id}, - os.path.join(disk_cache_dir, f"{label}.pt")) - print0(f" (saved collated batches to {disk_cache_dir})") + accuracy, collated = evaluate_task(model, data, device, collated=_batch_cache[label], pbar=pbar) + else: + prepared = next_future.result() + # Kick off prepare for the next uncached task + for j in range(i + 1, len(task_inputs)): + next_label, next_meta, next_data = task_inputs[j] + if next_label not in _batch_cache: + next_future = executor.submit(prepare_task_data, tokenizer, next_data, next_meta, max_seq_len) + break + accuracy, collated = evaluate_task(model, data, device, prepared=prepared, pbar=pbar) + _batch_cache[label] = collated + + _print_task_result(i, accuracy) + + elapsed_total = time.time() - t0 + pbar.close() + executor.shutdown(wait=False) + print0(f" (all tasks: {elapsed_total:.2f}s)") + + # Save collated batches to disk after first run (so bench/future runs skip prepare+collate) + if first_run and disk_cache_dir is not None: + pad_id = tokenizer.get_bos_token_id() + os.makedirs(disk_cache_dir, exist_ok=True) + for label, _, _ in task_inputs: + if label in _batch_cache: + torch.save({'collated': _batch_cache[label], 'pad_token_id': pad_id}, + os.path.join(disk_cache_dir, f"{label}.pt")) + print0(f" (saved collated batches to {disk_cache_dir})") core_metric = sum(centered_results.values()) / len(centered_results) if _prev_core is not None: diff --git a/scripts/bench_core_eval.py b/scripts/bench_core_eval.py index ae7294e..1003aee 100644 --- a/scripts/bench_core_eval.py +++ b/scripts/bench_core_eval.py @@ -35,8 +35,8 @@ from nanochat.tokenizer import HuggingFaceTokenizer from nanochat.checkpoint_manager import load_model from nanochat.core_eval import ( forward_model, prepare_example, check_result, stack_sequences, - prepare_task_data, _collate_batches, _forward_batches, evaluate_task, - compose_collated, + prepare_task_data, _collate_batches, _forward_all_cached, + evaluate_task, render_prompts_mc, render_prompts_schema, render_prompts_lm, batch_sequences_mc, batch_sequences_schema, batch_sequences_lm, ) @@ -270,19 +270,27 @@ def bench_new_first(model, tokenizer, task_inputs, device, batch_size, queue_siz return time.time() - t0, results, collated_cache -def bench_new_cached(model, task_inputs, device, collated_cache, pbar=None): - """Benchmark new batched evaluation (cached run, forward only).""" +def bench_new_cached(model, task_inputs, device, collated_cache, pbar=None, + merge=1, split=1, pad_token_id=0): + """Benchmark new batched evaluation (cached run, forward only). + Uses continuous pipeline across all tasks to eliminate inter-task stalls. + merge/split control GPU-side composition: merge > 1 cats batches, split > 1 slices them.""" + import torch.distributed as dist + world_size = dist.get_world_size() if dist.is_initialized() else 1 sync_cuda() t0 = time.time() - results = {} - max_label_len = max(len(label) for label, _, _ in task_inputs) - for label, task_meta, data in task_inputs: - if pbar is not None: - pbar.set_description(f"{label:<{max_label_len}}") - acc, _ = evaluate_task(model, data, device, collated=collated_cache[label], pbar=pbar) - results[label] = acc + task_collated = [(collated_cache[label], data) for label, _, data in task_inputs] + correct_list = _forward_all_cached(model, task_collated, device, pbar=pbar, + merge=merge, split=split, pad_token_id=pad_token_id) sync_cuda() - return time.time() - t0, results + elapsed = time.time() - t0 + results = {} + for (label, _, data), correct in zip(task_inputs, correct_list): + if world_size > 1: + dist.barrier() + dist.all_reduce(correct, op=dist.ReduceOp.SUM) + results[label] = correct.mean().item() + return elapsed, results def verify_results(old_results, new_results, label="new"): @@ -448,19 +456,29 @@ def main(): best_cached_time = float('inf') best_cached_params = None + pad_id = next(iter(base_cache.values()))[1] + # Preload ALL base-4 batches to GPU once (~144MB for full CORE eval). + # All batch-size sweeps compose from these GPU-resident tensors — zero CPU→GPU transfers. + gpu_collated = {} + for label, (collated, _) in base_cache.items(): + gpu_collated[label] = [(ids.to(device), meta) for ids, meta in collated] outer_pbar = tqdm(total=len(batch_sizes), desc="Cached sweep", leave=False, position=0) inner_pbar = tqdm(total=total_examples, desc="", leave=False, position=1) for bs in batch_sizes: outer_pbar.set_description(f"Cached: bs={bs}") inner_pbar.reset() - # Compose from base-4 to target batch_size (merge or split) - composed_cache = {} - for label, (collated, pad_id) in base_cache.items(): - composed_cache[label] = compose_collated(collated, bs, BASE_BATCH_SIZE, pad_id) + + # All composition happens on GPU: merge for bs >= base, split for bs < base + if bs >= BASE_BATCH_SIZE: + merge, split = bs // BASE_BATCH_SIZE, 1 + else: + merge, split = 1, BASE_BATCH_SIZE // bs with autocast_ctx: - t, cached_results = bench_new_cached(model, task_inputs, device, composed_cache, pbar=inner_pbar) + t, cached_results = bench_new_cached(model, task_inputs, device, gpu_collated, + pbar=inner_pbar, merge=merge, split=split, + pad_token_id=pad_id) outer_pbar.write(f" batch_size={bs:>3}: {t:.2f}s ({total_examples / t:.1f} examples/s)") outer_pbar.update(1)