From 8695280566f564b53759120f1c4ae89743c44952 Mon Sep 17 00:00:00 2001 From: unsalgokdag Date: Thu, 12 Feb 2026 18:13:56 +0100 Subject: [PATCH 1/4] speed up CORE metric evaluation: batched GPU forward passes, threaded CPU prep, cross-call caching. first eval pipelines tokenization on a background thread while GPU processes the previous batch. second+ evals skip tokenization and collation entirely, only GPU forward passes remain. Also adds a benchmark script to sweep batch_size and queue_size hyperparameters. --- nanochat/core_eval.py | 191 ++++++++++++++++----- scripts/base_eval.py | 155 ++++++++++++----- scripts/base_train.py | 4 +- scripts/bench_core_eval.py | 338 +++++++++++++++++++++++++++++++++++++ 4 files changed, 595 insertions(+), 93 deletions(-) create mode 100644 scripts/bench_core_eval.py diff --git a/nanochat/core_eval.py b/nanochat/core_eval.py index f3c9a9f..a81c810 100644 --- a/nanochat/core_eval.py +++ b/nanochat/core_eval.py @@ -134,14 +134,21 @@ def batch_sequences_lm(tokenizer, prompts): # In LM tasks, we have two prompts: without and with continuation tokens = tokenizer(prompts, prepend=tokenizer.get_bos_token_id()) tokens_without, tokens_with = tokens - start_idx, end_idx = len(tokens_without), len(tokens_with) - assert start_idx < end_idx, "prompt without is supposed to be a prefix of prompt with" - assert tokens_without == tokens_with[:start_idx], "prompt without is supposed to be a prefix of prompt with" + end_idx = len(tokens_with) + # Find longest common prefix — greedy trie tokenizers are not always + # prefix-stable, so we can't assume an exact prefix match. + start_idx = 0 + for i in range(min(len(tokens_without), len(tokens_with))): + if tokens_without[i] != tokens_with[i]: + break + start_idx = i + 1 + assert start_idx < end_idx, "continuation must produce additional tokens" # we only need the with continuation prompt in the LM task, i.e. 
batch size of 1 return [tokens_with], [start_idx], [end_idx] @torch.no_grad() +@torch.compiler.disable def forward_model(model, input_ids): """ Take BxT tensor of token ids, return BxT tensor of losses and argmax predictions. @@ -164,9 +171,8 @@ def forward_model(model, input_ids): return losses, predictions -@torch.no_grad() -def evaluate_example(idx, model, tokenizer, data, device, task_meta): - """Evaluate a single example, return True if correct, False otherwise""" +def prepare_example(idx, tokenizer, data, task_meta, max_seq_len=None): + """CPU-only: render prompts, tokenize, stack into tensors. Returns a dict.""" item = data[idx] task_type = task_meta['task_type'] num_fewshot = task_meta['num_fewshot'] @@ -193,70 +199,165 @@ def evaluate_example(idx, model, tokenizer, data, device, task_meta): else: raise ValueError(f"Unsupported task type: {task_type}") - # Some models can't forward sequences beyond a certain length (e.g. GPT-2) - # In these cases, we have to truncate sequences to max length and adjust the indices - if hasattr(model, 'max_seq_len') and model.max_seq_len is not None: - max_tokens = model.max_seq_len + # Truncate sequences for models with a max length (e.g. GPT-2) + if max_seq_len is not None: new_tokens, new_start_idxs, new_end_idxs = [], [], [] for t, s, e in zip(tokens, start_idxs, end_idxs): - if len(t) > max_tokens: - num_to_crop = len(t) - max_tokens - new_tokens.append(t[-max_tokens:]) # take the last max_tokens tokens - new_start_idxs.append(s - num_to_crop) # shift the indices down + if len(t) > max_seq_len: + num_to_crop = len(t) - max_seq_len + new_tokens.append(t[-max_seq_len:]) + new_start_idxs.append(s - num_to_crop) new_end_idxs.append(e - num_to_crop) - assert s - num_to_crop >= 0, "this should never happen right?" - assert e - num_to_crop >= 0, "this should never happen right?" 
else: - new_tokens.append(t) # keep unchanged + new_tokens.append(t) new_start_idxs.append(s) new_end_idxs.append(e) tokens, start_idxs, end_idxs = new_tokens, new_start_idxs, new_end_idxs - # Stack up all the sequences into a batch - pad_token_id = tokenizer.get_bos_token_id() # use BOS as pad token is ok - input_ids = stack_sequences(tokens, pad_token_id) - input_ids = input_ids.to(device) + pad_token_id = tokenizer.get_bos_token_id() + input_ids = stack_sequences(tokens, pad_token_id) # (num_options, seq_len) - # Forward the model, get the autoregressive loss and argmax prediction at each token - losses, predictions = forward_model(model, input_ids) + return { + 'input_ids': input_ids, + 'start_idxs': start_idxs, + 'end_idxs': end_idxs, + 'gold': item.get('gold', None), + 'task_type': task_type, + 'num_options': input_ids.size(0), + 'seq_len': input_ids.size(1), + 'pad_token_id': pad_token_id, + } - # See if the losses/predictions come out correctly + +def check_result(losses, predictions, input_ids, start_idxs, end_idxs, gold, task_type): + """Analyze forward pass outputs for one example, return True if correct.""" if task_type == 'language_modeling': - # language modeling task is currently always batch size 1 - si = start_idxs[0] - ei = end_idxs[0] - # predictions[i] predict input_ids[i+1] autoregressively + si, ei = start_idxs[0], end_idxs[0] predicted_tokens = predictions[0, si-1:ei-1] actual_tokens = input_ids[0, si:ei] - is_correct = torch.all(predicted_tokens == actual_tokens).item() + return torch.all(predicted_tokens == actual_tokens).item() elif task_type in ['multiple_choice', 'schema']: - # For MC/schema: find the option with lowest average loss mean_losses = [losses[i, si-1:ei-1].mean().item() - for i, (si, ei) in enumerate(zip(start_idxs, end_idxs))] - pred_idx = mean_losses.index(min(mean_losses)) - is_correct = pred_idx == item['gold'] + for i, (si, ei) in enumerate(zip(start_idxs, end_idxs))] + return mean_losses.index(min(mean_losses)) == gold 
else: raise ValueError(f"Unsupported task type: {task_type}") - return is_correct + +def _collate_batches(prepared, batch_size, queue): + """Background thread: collate batches on CPU and push to queue.""" + for batch_start in range(0, len(prepared), batch_size): + batch = prepared[batch_start:batch_start + batch_size] + batch_preps = [p for _, p in batch] + max_len = max(p['seq_len'] for p in batch_preps) + total_rows = sum(p['num_options'] for p in batch_preps) + pad_id = batch_preps[0]['pad_token_id'] + + combined_ids = torch.full((total_rows, max_len), pad_id, dtype=torch.long) + + batch_meta = [] + offset = 0 + for idx, p in batch: + n, sl = p['num_options'], p['seq_len'] + combined_ids[offset:offset+n, :sl] = p['input_ids'] + batch_meta.append((idx, n, p['start_idxs'], p['end_idxs'], p['gold'], p['task_type'])) + offset += n + + queue.put((combined_ids, batch_meta)) + queue.put(None) # sentinel -def evaluate_task(model, tokenizer, data, device, task_meta): +def prepare_task_data(tokenizer, data, task_meta, max_seq_len=None): + """CPU-only: prepare and sort all examples for a task. 
Can run on a background thread.""" + rank = dist.get_rank() if dist.is_initialized() else 0 + world_size = dist.get_world_size() if dist.is_initialized() else 1 + indices = list(range(rank, len(data), world_size)) + prepared = [(idx, prepare_example(idx, tokenizer, data, task_meta, max_seq_len)) for idx in indices] + prepared.sort(key=lambda x: x[1]['seq_len']) + return prepared + + +def _forward_batches(model, collated, data, device): + """Run GPU forward passes on pre-collated batches, return per-example correctness tensor.""" + correct = torch.zeros(len(data), dtype=torch.float32, device=device) + for combined_ids, batch_meta in collated: + combined_ids = combined_ids.to(device) + + losses, predictions = forward_model(model, combined_ids) + + offset = 0 + for idx, n, start_idxs, end_idxs, gold, task_type in batch_meta: + is_correct = check_result( + losses[offset:offset+n], predictions[offset:offset+n], + combined_ids[offset:offset+n], + start_idxs, end_idxs, gold, task_type, + ) + correct[idx] = float(is_correct) + offset += n + return correct + + +def evaluate_task(model, data, device, batch_size=4, queue_size=2, prepared=None, + collated=None, tokenizer=None, task_meta=None): """ - This function is responsible for evaluating one task across many examples. - It also handles dispatch to all processes if the script is run with torchrun. + Evaluate one task across many examples with batched GPU forward passes. + Examples are sorted by sequence length so similar-length sequences are batched + together, minimizing padding waste and increasing GPU utilization. + + Three modes (checked in order): + - collated: skip prepare + collation, go straight to GPU forward passes. + - prepared: skip prepare, collation runs on a background thread pipelined with GPU. + - neither: full pipeline (prepare + collate + forward). + + Returns (accuracy, collated_batches) so the caller can cache collated batches. 
""" rank = dist.get_rank() if dist.is_initialized() else 0 world_size = dist.get_world_size() if dist.is_initialized() else 1 - correct = torch.zeros(len(data), dtype=torch.float32, device=device) - # stride the examples to each rank - for idx in range(rank, len(data), world_size): - is_correct = evaluate_example(idx, model, tokenizer, data, device, task_meta) - correct[idx] = float(is_correct) + + if collated is not None: + # Fast path: just GPU forward passes, no threads + correct = _forward_batches(model, collated, data, device) + else: + from queue import Queue + from threading import Thread + + if prepared is None: + max_seq_len = getattr(model, 'max_seq_len', None) + prepared = prepare_task_data(tokenizer, data, task_meta, max_seq_len) + + # Collation thread pipelined with GPU forward passes + queue = Queue(maxsize=queue_size) + collator = Thread(target=_collate_batches, args=(prepared, batch_size, queue), daemon=True) + collator.start() + + collated = [] + correct = torch.zeros(len(data), dtype=torch.float32, device=device) + while True: + item = queue.get() + if item is None: + break + collated.append(item) + combined_ids, batch_meta = item + + combined_ids = combined_ids.to(device) + + losses, predictions = forward_model(model, combined_ids) + + offset = 0 + for idx, n, start_idxs, end_idxs, gold, task_type in batch_meta: + is_correct = check_result( + losses[offset:offset+n], predictions[offset:offset+n], + combined_ids[offset:offset+n], + start_idxs, end_idxs, gold, task_type, + ) + correct[idx] = float(is_correct) + offset += n + + collator.join() + del prepared + # sync results across all the processes if running distributed if world_size > 1: dist.barrier() dist.all_reduce(correct, op=dist.ReduceOp.SUM) - # compute the mean - mean_correct = correct.mean().item() - return mean_correct + return correct.mean().item(), collated diff --git a/scripts/base_eval.py b/scripts/base_eval.py index e45ae43..b4aaf56 100644 --- a/scripts/base_eval.py +++ 
b/scripts/base_eval.py @@ -36,7 +36,7 @@ import torch from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, autodetect_device_type, download_file_with_lock from nanochat.tokenizer import HuggingFaceTokenizer, get_token_bytes from nanochat.checkpoint_manager import load_model -from nanochat.core_eval import evaluate_task +from nanochat.core_eval import evaluate_task, prepare_task_data from nanochat.dataloader import tokenizing_distributed_data_loader_bos_bestfit from nanochat.loss_eval import evaluate_bpb from nanochat.engine import Engine @@ -106,67 +106,130 @@ def place_eval_bundle(file_path): print0(f"Placed eval_bundle directory at {eval_bundle_dir}") -def evaluate_core(model, tokenizer, device, max_per_task=-1): +_eval_data_cache = None # (task_inputs, random_baselines, w_label, w_shot, w_type) +_batch_cache = {} # {label: collated_batches} — cached after first run +_batch_cache_key = None # (max_per_task, max_seq_len) — invalidate if these change +_prev_centered = {} # {label: centered_result} — previous run for delta display +_prev_core = None # previous core_metric + + +def evaluate_model(model, tokenizer, device, max_per_task=-1): """ Evaluate a base model on the CORE benchmark. - Returns dict with results, centered_results, and core_metric. + - max_per_task: crop the data to this many examples per task for testing (-1 = disable) + Collated batches are cached across calls since the tokenizer is fixed. + Second+ runs skip prepare and collation entirely — just GPU forward passes. 
""" - base_dir = get_base_dir() - eval_bundle_dir = os.path.join(base_dir, "eval_bundle") - # Download the eval bundle if needed - if not os.path.exists(eval_bundle_dir): - download_file_with_lock(EVAL_BUNDLE_URL, "eval_bundle.zip", postprocess_fn=place_eval_bundle) + global _eval_data_cache, _batch_cache, _batch_cache_key, _prev_centered, _prev_core + from concurrent.futures import ThreadPoolExecutor - config_path = os.path.join(eval_bundle_dir, "core.yaml") - data_base_path = os.path.join(eval_bundle_dir, "eval_data") - eval_meta_data = os.path.join(eval_bundle_dir, "eval_meta_data.csv") + max_seq_len = getattr(model, 'max_seq_len', None) + cache_key = (max_per_task, max_seq_len) - with open(config_path, 'r', encoding='utf-8') as f: - config = yaml.safe_load(f) - tasks = config['icl_tasks'] + # Invalidate batch cache if parameters changed + if cache_key != _batch_cache_key: + _batch_cache.clear() + _batch_cache_key = cache_key - # Load random baseline values - random_baselines = {} - with open(eval_meta_data, 'r', encoding='utf-8') as f: - reader = csv.DictReader(f) - for row in reader: - task_name = row['Eval Task'] - random_baseline = row['Random baseline'] - random_baselines[task_name] = float(random_baseline) + # Load and cache task data + baselines (only read from disk once) + if _eval_data_cache is None: + base_dir = get_base_dir() + eval_bundle_dir = os.path.join(base_dir, "eval_bundle") + if not os.path.exists(eval_bundle_dir): + download_file_with_lock(EVAL_BUNDLE_URL, "eval_bundle.zip", postprocess_fn=place_eval_bundle) + config_path = os.path.join(eval_bundle_dir, "core.yaml") + data_base_path = os.path.join(eval_bundle_dir, "eval_data") + eval_meta_data = os.path.join(eval_bundle_dir, "eval_meta_data.csv") + with open(config_path, 'r', encoding='utf-8') as f: + config = yaml.safe_load(f) + tasks = config['icl_tasks'] - # Evaluate each task + random_baselines = {} + with open(eval_meta_data, 'r', encoding='utf-8') as f: + reader = csv.DictReader(f) + 
for row in reader: + random_baselines[row['Eval Task']] = float(row['Random baseline']) + + task_inputs = [] + for task in tasks: + label = task['label'] + task_meta = { + 'task_type': task['icl_task_type'], + 'dataset_uri': task['dataset_uri'], + 'num_fewshot': task['num_fewshot'][0], + 'continuation_delimiter': task.get('continuation_delimiter', ' ') + } + data_path = os.path.join(data_base_path, task_meta['dataset_uri']) + with open(data_path, 'r', encoding='utf-8') as f: + data = [json.loads(line.strip()) for line in f] + shuffle_rng = random.Random(1337) + shuffle_rng.shuffle(data) + if max_per_task > 0: + data = data[:max_per_task] + task_inputs.append((label, task_meta, data)) + + w_label = max(len(t[0]) for t in task_inputs) + w_shot = max(len(f"{t[1]['num_fewshot']}-shot") for t in task_inputs) + w_type = max(len(t[1]['task_type']) for t in task_inputs) + _eval_data_cache = (task_inputs, random_baselines, w_label, w_shot, w_type) + + task_inputs, random_baselines, w_label, w_shot, w_type = _eval_data_cache + + # First run: eagerly prepare next task while evaluating current, cache collated batches. + # Cached runs: pass collated batches directly — no threads, no prepare, no collation. results = {} centered_results = {} - for task in tasks: - start_time = time.time() - label = task['label'] - task_meta = { - 'task_type': task['icl_task_type'], - 'dataset_uri': task['dataset_uri'], - 'num_fewshot': task['num_fewshot'][0], - 'continuation_delimiter': task.get('continuation_delimiter', ' ') - } - print0(f"Evaluating: {label} ({task_meta['num_fewshot']}-shot, type: {task_meta['task_type']})... 
", end='') + cached_run = all(label in _batch_cache for label, _, _ in task_inputs) - data_path = os.path.join(data_base_path, task_meta['dataset_uri']) - with open(data_path, 'r', encoding='utf-8') as f: - data = [json.loads(line.strip()) for line in f] + if not cached_run: + executor = ThreadPoolExecutor(max_workers=1) + first_uncached = next(i for i, (l, _, _) in enumerate(task_inputs) if l not in _batch_cache) + _, first_meta, first_data = task_inputs[first_uncached] + next_future = executor.submit(prepare_task_data, tokenizer, first_data, first_meta, max_seq_len) - # Shuffle for consistent subsampling when using max_per_task - shuffle_rng = random.Random(1337) - shuffle_rng.shuffle(data) - if max_per_task > 0: - data = data[:max_per_task] + for i, (label, task_meta, data) in enumerate(task_inputs): + shot_str = f"{task_meta['num_fewshot']}-shot" + prefix = f" {label:<{w_label}} {shot_str:<{w_shot}} {task_meta['task_type']:<{w_type}}" + print0(f"{prefix} ...", end="", flush=True) + t0 = time.time() - accuracy = evaluate_task(model, tokenizer, data, device, task_meta) - results[label] = accuracy + if label in _batch_cache: + accuracy, collated = evaluate_task(model, data, device, collated=_batch_cache[label]) + else: + prepared = next_future.result() + # Kick off prepare for the next uncached task + for j in range(i + 1, len(task_inputs)): + next_label, next_meta, next_data = task_inputs[j] + if next_label not in _batch_cache: + next_future = executor.submit(prepare_task_data, tokenizer, next_data, next_meta, max_seq_len) + break + accuracy, collated = evaluate_task(model, data, device, prepared=prepared) + _batch_cache[label] = collated + + elapsed = time.time() - t0 random_baseline = random_baselines[label] centered_result = (accuracy - 0.01 * random_baseline) / (1.0 - 0.01 * random_baseline) + results[label] = accuracy centered_results[label] = centered_result - elapsed = time.time() - start_time - print0(f"accuracy: {accuracy:.4f} | centered: 
{centered_result:.4f} | time: {elapsed:.2f}s") + delta_str = "" + if label in _prev_centered: + d = centered_result - _prev_centered[label] + arrow = "\u2191" if d > 0 else "\u2193" if d < 0 else "=" + delta_str = f" {arrow}{d:+.4f}" + print0(f"\r{prefix} acc: {accuracy:.4f} centered: {centered_result:>7.4f}{delta_str} time: {elapsed:.2f}s") + + if not cached_run: + executor.shutdown(wait=False) core_metric = sum(centered_results.values()) / len(centered_results) + if _prev_core is not None: + d = core_metric - _prev_core + arrow = "\u2191" if d > 0 else "\u2193" if d < 0 else "=" + print0(f"CORE: {core_metric:.4f} {arrow}{d:+.4f}") + else: + print0(f"CORE: {core_metric:.4f}") + _prev_centered = dict(centered_results) + _prev_core = core_metric out = { "results": results, "centered_results": centered_results, @@ -288,7 +351,7 @@ def main(): print0("CORE Evaluation") print0("="*80) with autocast_ctx: - core_results = evaluate_core(model, tokenizer, device, max_per_task=args.max_per_task) + core_results = evaluate_model(model, tokenizer, device, max_per_task=args.max_per_task) # Write CSV output if ddp_rank == 0: diff --git a/scripts/base_train.py b/scripts/base_train.py index 996b2ba..813196b 100644 --- a/scripts/base_train.py +++ b/scripts/base_train.py @@ -32,7 +32,7 @@ from nanochat.checkpoint_manager import save_checkpoint, load_checkpoint from nanochat.loss_eval import evaluate_bpb from nanochat.engine import Engine from nanochat.flash_attention import HAS_FA3 -from scripts.base_eval import evaluate_core +from scripts.base_eval import evaluate_model print_banner() # ----------------------------------------------------------------------------- @@ -423,7 +423,7 @@ while True: if args.core_metric_every > 0 and (last_step or (step > 0 and step % args.core_metric_every == 0)): model.eval() with disable_fp8(orig_model), autocast_ctx: - results = evaluate_core(orig_model, tokenizer, device, max_per_task=args.core_metric_max_per_task) + results = 
evaluate_model(orig_model, tokenizer, device, max_per_task=args.core_metric_max_per_task) print0(f"Step {step:05d} | CORE metric: {results['core_metric']:.4f}") wandb_run.log({ "step": step, diff --git a/scripts/bench_core_eval.py b/scripts/bench_core_eval.py new file mode 100644 index 0000000..d11bc40 --- /dev/null +++ b/scripts/bench_core_eval.py @@ -0,0 +1,338 @@ +""" +Benchmark the CORE evaluation pipeline. + +Compares three modes: + 1. Old sequential (per-example) evaluation + 2. New batched evaluation (first run — includes prepare + collate + forward) + 3. New batched evaluation (cached run — forward only) + +Also sweeps batch_size and queue_size to find optimal hyperparameters. + +Usage: + python -m scripts.bench_core_eval + python -m scripts.bench_core_eval --max-per-task 100 # quick test + python -m scripts.bench_core_eval --hf-path openai-community/gpt2 +""" +import os +import csv +import json +import time +import yaml +import random +import shutil +import zipfile +import tempfile +import argparse +from contextlib import nullcontext + +import torch +import torch.distributed as dist + +from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, autodetect_device_type, download_file_with_lock +from nanochat.tokenizer import HuggingFaceTokenizer +from nanochat.checkpoint_manager import load_model +from nanochat.core_eval import ( + forward_model, prepare_example, check_result, stack_sequences, + prepare_task_data, _collate_batches, _forward_batches, evaluate_task, + render_prompts_mc, render_prompts_schema, render_prompts_lm, + batch_sequences_mc, batch_sequences_schema, batch_sequences_lm, +) + +# ---- eval bundle loading (reused from base_eval) ---- + +EVAL_BUNDLE_URL = "https://karpathy-public.s3.us-west-2.amazonaws.com/eval_bundle.zip" + +def place_eval_bundle(file_path): + base_dir = get_base_dir() + eval_bundle_dir = os.path.join(base_dir, "eval_bundle") + with tempfile.TemporaryDirectory() as tmpdir: + with 
zipfile.ZipFile(file_path, 'r') as zip_ref: + zip_ref.extractall(tmpdir) + shutil.move(os.path.join(tmpdir, "eval_bundle"), eval_bundle_dir) + print0(f"Placed eval_bundle at {eval_bundle_dir}") + +def load_tasks(max_per_task=-1): + base_dir = get_base_dir() + eval_bundle_dir = os.path.join(base_dir, "eval_bundle") + if not os.path.exists(eval_bundle_dir): + download_file_with_lock(EVAL_BUNDLE_URL, "eval_bundle.zip", postprocess_fn=place_eval_bundle) + config_path = os.path.join(eval_bundle_dir, "core.yaml") + data_base_path = os.path.join(eval_bundle_dir, "eval_data") + with open(config_path, 'r', encoding='utf-8') as f: + config = yaml.safe_load(f) + task_inputs = [] + for task in config['icl_tasks']: + label = task['label'] + task_meta = { + 'task_type': task['icl_task_type'], + 'dataset_uri': task['dataset_uri'], + 'num_fewshot': task['num_fewshot'][0], + 'continuation_delimiter': task.get('continuation_delimiter', ' ') + } + data_path = os.path.join(data_base_path, task_meta['dataset_uri']) + with open(data_path, 'r', encoding='utf-8') as f: + data = [json.loads(line.strip()) for line in f] + shuffle_rng = random.Random(1337) + shuffle_rng.shuffle(data) + if max_per_task > 0: + data = data[:max_per_task] + task_inputs.append((label, task_meta, data)) + return task_inputs + +# ---- old sequential evaluation (baseline) ---- + +@torch.no_grad() +def evaluate_example_old(idx, model, tokenizer, data, device, task_meta): + """Original per-example sequential evaluation (the old code).""" + item = data[idx] + task_type = task_meta['task_type'] + num_fewshot = task_meta['num_fewshot'] + continuation_delimiter = task_meta['continuation_delimiter'] + + fewshot_examples = [] + if num_fewshot > 0: + rng = random.Random(1234 + idx) + available_indices = [i for i in range(len(data)) if i != idx] + fewshot_indices = rng.sample(available_indices, num_fewshot) + fewshot_examples = [data[i] for i in fewshot_indices] + + if task_type == 'multiple_choice': + prompts = 
render_prompts_mc(item, continuation_delimiter, fewshot_examples) + tokens, start_idxs, end_idxs = batch_sequences_mc(tokenizer, prompts) + elif task_type == 'schema': + prompts = render_prompts_schema(item, continuation_delimiter, fewshot_examples) + tokens, start_idxs, end_idxs = batch_sequences_schema(tokenizer, prompts) + elif task_type == 'language_modeling': + prompts = render_prompts_lm(item, continuation_delimiter, fewshot_examples) + tokens, start_idxs, end_idxs = batch_sequences_lm(tokenizer, prompts) + else: + raise ValueError(f"Unsupported task type: {task_type}") + + if hasattr(model, 'max_seq_len') and model.max_seq_len is not None: + max_tokens = model.max_seq_len + new_tokens, new_start_idxs, new_end_idxs = [], [], [] + for t, s, e in zip(tokens, start_idxs, end_idxs): + if len(t) > max_tokens: + num_to_crop = len(t) - max_tokens + new_tokens.append(t[-max_tokens:]) + new_start_idxs.append(s - num_to_crop) + new_end_idxs.append(e - num_to_crop) + else: + new_tokens.append(t) + new_start_idxs.append(s) + new_end_idxs.append(e) + tokens, start_idxs, end_idxs = new_tokens, new_start_idxs, new_end_idxs + + pad_token_id = tokenizer.get_bos_token_id() + input_ids = stack_sequences(tokens, pad_token_id).to(device) + losses, predictions = forward_model(model, input_ids) + return check_result(losses, predictions, input_ids, start_idxs, end_idxs, item.get('gold', None), task_type) + + +def evaluate_task_old(model, tokenizer, data, device, task_meta): + """Original sequential evaluate_task.""" + rank = dist.get_rank() if dist.is_initialized() else 0 + world_size = dist.get_world_size() if dist.is_initialized() else 1 + correct = torch.zeros(len(data), dtype=torch.float32, device=device) + for idx in range(rank, len(data), world_size): + is_correct = evaluate_example_old(idx, model, tokenizer, data, device, task_meta) + correct[idx] = float(is_correct) + if world_size > 1: + dist.barrier() + dist.all_reduce(correct, op=dist.ReduceOp.SUM) + return 
correct.mean().item() + +# ---- HuggingFace model wrapper ---- + +class ModelWrapper: + def __init__(self, model, max_seq_len=None): + self.model = model + self.max_seq_len = max_seq_len + def __call__(self, input_ids): + return self.model(input_ids).logits + +def load_hf_model(hf_path, device): + from transformers import AutoModelForCausalLM + model = AutoModelForCausalLM.from_pretrained(hf_path) + model.to(device) + model.eval() + max_seq_len = 1024 if "gpt2" in hf_path else None + return ModelWrapper(model, max_seq_len=max_seq_len), HuggingFaceTokenizer.from_pretrained(hf_path) + +# ---- benchmark helpers ---- + +def sync_cuda(): + if torch.cuda.is_available(): + torch.cuda.synchronize() + +def bench_old(model, tokenizer, task_inputs, device): + """Benchmark old sequential evaluation across all tasks.""" + sync_cuda() + t0 = time.time() + results = {} + for label, task_meta, data in task_inputs: + acc = evaluate_task_old(model, tokenizer, data, device, task_meta) + results[label] = acc + sync_cuda() + return time.time() - t0, results + + +def bench_new_first(model, tokenizer, task_inputs, device, batch_size, queue_size): + """Benchmark new batched evaluation (first run, no cache).""" + sync_cuda() + t0 = time.time() + results = {} + collated_cache = {} + max_seq_len = getattr(model, 'max_seq_len', None) + for label, task_meta, data in task_inputs: + prepared = prepare_task_data(tokenizer, data, task_meta, max_seq_len) + acc, collated = evaluate_task(model, data, device, batch_size=batch_size, + queue_size=queue_size, prepared=prepared) + results[label] = acc + collated_cache[label] = collated + sync_cuda() + return time.time() - t0, results, collated_cache + + +def bench_new_cached(model, task_inputs, device, collated_cache): + """Benchmark new batched evaluation (cached run, forward only).""" + sync_cuda() + t0 = time.time() + results = {} + for label, task_meta, data in task_inputs: + acc, _ = evaluate_task(model, data, device, collated=collated_cache[label]) 
+ results[label] = acc + sync_cuda() + return time.time() - t0, results + + +def verify_results(old_results, new_results, label="new"): + """Check that old and new produce the same accuracies.""" + mismatches = [] + for task in old_results: + if task in new_results and abs(old_results[task] - new_results[task]) > 1e-6: + mismatches.append((task, old_results[task], new_results[task])) + if mismatches: + print0(f" WARNING: {label} mismatches vs old:") + for task, old, new in mismatches: + print0(f" {task}: old={old:.6f} {label}={new:.6f}") + else: + print0(f" {label} results match old (verified)") + + +# ---- main ---- + +def main(): + parser = argparse.ArgumentParser(description="Benchmark CORE eval pipeline") + parser.add_argument('--hf-path', type=str, default=None, help='HuggingFace model path') + parser.add_argument('--model-tag', type=str, default=None, help='nanochat model tag') + parser.add_argument('--step', type=int, default=None, help='Model step to load') + parser.add_argument('--max-per-task', type=int, default=500, help='Max examples per task') + parser.add_argument('--device-type', type=str, default='', help='cuda|cpu|mps') + parser.add_argument('--batch-sizes', type=str, default='1,2,4,8,16,32,64', help='Comma-separated batch sizes to sweep') + parser.add_argument('--queue-sizes', type=str, default='2,4,8,16', help='Comma-separated queue sizes to sweep') + parser.add_argument('--skip-old', action='store_true', help='Skip old sequential baseline (slow)') + args = parser.parse_args() + + batch_sizes = [int(x) for x in args.batch_sizes.split(',')] + queue_sizes = [int(x) for x in args.queue_sizes.split(',')] + + device_type = autodetect_device_type() if args.device_type == '' else args.device_type + ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type) + autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext() + + # Load model + if args.hf_path is not None: + 
model, tokenizer = load_hf_model(args.hf_path, device) + model_name = args.hf_path + else: + model, tokenizer, meta = load_model("base", device, phase="eval", model_tag=args.model_tag, step=args.step) + model_name = f"base_model (step {meta['step']})" + + print0(f"Model: {model_name}") + print0(f"Max per task: {args.max_per_task}") + print0(f"Device: {device}") + print0("") + + # Load tasks + task_inputs = load_tasks(max_per_task=args.max_per_task) + total_examples = sum(len(data) for _, _, data in task_inputs) + print0(f"Loaded {len(task_inputs)} tasks, {total_examples} total examples") + print0("") + + # ---- 1. Old sequential baseline ---- + old_results = None + if not args.skip_old: + print0("=" * 80) + print0("OLD: Sequential per-example evaluation") + print0("=" * 80) + with autocast_ctx: + old_time, old_results = bench_old(model, tokenizer, task_inputs, device) + print0(f" Time: {old_time:.2f}s ({total_examples / old_time:.1f} examples/s)") + print0("") + + # ---- 2. Sweep batch_size x queue_size for first run ---- + print0("=" * 80) + print0("NEW: Batched evaluation — hyperparameter sweep (first run)") + print0("=" * 80) + print0("") + + # Header + qs_header = "".join(f"{'q=' + str(q):>10}" for q in queue_sizes) + print0(f" {'batch_size':>10}{qs_header}") + print0(f" {'':>10}" + "-" * (10 * len(queue_sizes))) + + best_time = float('inf') + best_params = None + best_collated = None + sweep_results = {} + + for bs in batch_sizes: + row = f" {bs:>10}" + for qs in queue_sizes: + with autocast_ctx: + t, results, collated_cache = bench_new_first(model, tokenizer, task_inputs, device, bs, qs) + sweep_results[(bs, qs)] = t + row += f"{t:>9.2f}s" + if t < best_time: + best_time = t + best_params = (bs, qs) + best_collated = collated_cache + best_first_results = results + print0(row) + + print0("") + print0(f" Best: batch_size={best_params[0]}, queue_size={best_params[1]} -> {best_time:.2f}s ({total_examples / best_time:.1f} examples/s)") + + # Verify correctness + 
if old_results is not None: + verify_results(old_results, best_first_results, label="new-first") + print0("") + + # ---- 3. Cached run (forward only) ---- + print0("=" * 80) + print0("NEW: Cached run (forward only, using best params)") + print0("=" * 80) + with autocast_ctx: + cached_time, cached_results = bench_new_cached(model, task_inputs, device, best_collated) + print0(f" Time: {cached_time:.2f}s ({total_examples / cached_time:.1f} examples/s)") + if old_results is not None: + verify_results(old_results, cached_results, label="new-cached") + print0("") + + # ---- Summary ---- + print0("=" * 80) + print0("SUMMARY") + print0("=" * 80) + if old_results is not None: + print0(f" Old (sequential): {old_time:>8.2f}s") + print0(f" New (first run): {best_time:>8.2f}s batch_size={best_params[0]}, queue_size={best_params[1]}") + print0(f" New (cached): {cached_time:>8.2f}s") + if old_results is not None: + print0(f" Speedup (first): {old_time / best_time:>8.2f}x") + print0(f" Speedup (cached): {old_time / cached_time:>8.2f}x") + + compute_cleanup() + +if __name__ == "__main__": + main() From 7fa30f5ee3944d71e9922c8614e612436cac0159 Mon Sep 17 00:00:00 2001 From: Unsal Gokdag Date: Thu, 12 Feb 2026 22:34:23 +0000 Subject: [PATCH 2/4] CORE eval: disk-cached tokenized batches, double-buffered GPU transfers, batch composition, benchmark improvements MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit the main idea: tokenization + collation for CORE eval only needs to happen once per tokenizer. collated batches at base batch_size=4 are saved to disk (core_token_cache/), keyed by SHA-256 of the tokenizer file. any batch_size can be served from these base-4 batches: larger sizes merge consecutive batches (right-pad shorter ones, cat along dim=0), smaller sizes split along example boundaries (trim trailing padding). this means prepare_task_data is truly a one-time cost. 
core_eval.py: - double-buffered CPU->GPU transfers in both forward paths (_forward_batches and evaluate_task's pipelined path). while GPU runs forward_model on batch N, batch N+1 is pin_memory()'d and DMA-transferred via non_blocking=True. the DMA engine and GPU compute units are separate hardware so they overlap. previously GPU idled during every transfer. - compose_collated(): merge base batches for larger batch_size (cat after right-padding to max_len), or split for smaller batch_size (slice along row boundaries from batch_meta, trim trailing padding via vectorized non_pad.any(dim=0)). works because examples are sorted by seq_len, so consecutive base batches have monotonically increasing lengths. - evaluate_task and _forward_batches accept optional pbar for progress tracking. base_eval.py: - evaluate_model now has 3-tier caching: in-memory (_batch_cache, across calls within same process), disk load (core_token_cache/, on first call when in-memory is empty), disk save (after first run's prepare+collate+forward, writes collated batches so future training runs and the benchmark skip tokenization entirely). keyed by tokenizer file hash + max_per_task. bench_core_eval.py: - cached sweep no longer re-runs the full first-run sweep to build collated data (was 2x the work for no reason). instead loads/builds base-4 cache once, then compose_collated serves any target batch_size. cached sweep only varies batch_size (no queue_size — no collation thread). - --skip-first: skip the first-run sweep entirely if disk cache exists. if cache is missing, runs a single bs=4 eval in minimal time to create it, then proceeds to cached sweep. - tqdm progress bars everywhere: old sequential baseline (per-example with task name), first-run sweep (double bar: outer=combo progress, inner=per-example), cache building (per-task), cached sweep (double bar). task names left-aligned (padded on the right) to max label length so the bar doesn't shift.
- tokenizer identity via file_checksum (SHA-256 of tokenizer.pkl/tokenizer.json on disk), not encode-output hashing. HF models fall back to hashing the repo name. --- nanochat/core_eval.py | 121 ++++++++++++++++-- scripts/base_eval.py | 43 +++++++ scripts/bench_core_eval.py | 255 ++++++++++++++++++++++++++++++------- 3 files changed, 358 insertions(+), 61 deletions(-) diff --git a/nanochat/core_eval.py b/nanochat/core_eval.py index a81c810..9231520 100644 --- a/nanochat/core_eval.py +++ b/nanochat/core_eval.py @@ -277,11 +277,37 @@ def prepare_task_data(tokenizer, data, task_meta, max_seq_len=None): return prepared -def _forward_batches(model, collated, data, device): - """Run GPU forward passes on pre-collated batches, return per-example correctness tensor.""" +def _prefetch_to_device(tensor, device): + """Pin and async-transfer a CPU tensor to GPU, overlapping with current GPU work.""" + return tensor.pin_memory().to(device, non_blocking=True) + + +def _forward_batches(model, collated, data, device, pbar=None): + """Run GPU forward passes on pre-collated batches, return per-example correctness tensor. + + Uses double-buffered prefetching on CUDA: while the GPU processes batch N, + batch N+1 is pinned and DMA-transferred asynchronously, keeping the GPU fed. 
+ """ correct = torch.zeros(len(data), dtype=torch.float32, device=device) - for combined_ids, batch_meta in collated: - combined_ids = combined_ids.to(device) + if not collated: + return correct + + use_prefetch = torch.cuda.is_available() and 'cuda' in str(device) + + # Prefetch first batch + if use_prefetch: + next_ids = _prefetch_to_device(collated[0][0], device) + else: + next_ids = collated[0][0].to(device) + + for i, (_, batch_meta) in enumerate(collated): + combined_ids = next_ids + # Start async transfer of next batch while GPU computes on current + if i + 1 < len(collated): + if use_prefetch: + next_ids = _prefetch_to_device(collated[i + 1][0], device) + else: + next_ids = collated[i + 1][0].to(device) losses, predictions = forward_model(model, combined_ids) @@ -294,11 +320,63 @@ def _forward_batches(model, collated, data, device): ) correct[idx] = float(is_correct) offset += n + if pbar is not None: + pbar.update(len(batch_meta)) return correct +def compose_collated(base_collated, target_batch_size, base_batch_size=4, pad_token_id=0): + """Compose base-sized collated batches into target-sized batches. + + Supports both merging (target > base) by concatenating consecutive groups, + and splitting (target < base) by slicing along example boundaries. + Examples are sorted by seq_len within each base batch, so splitting can + trim trailing padding columns for efficiency. 
+ """ + if target_batch_size == base_batch_size: + return base_collated + elif target_batch_size > base_batch_size: + # Merge consecutive base batches + n_merge = target_batch_size // base_batch_size + composed = [] + for i in range(0, len(base_collated), n_merge): + group = base_collated[i:i + n_merge] + if len(group) == 1: + composed.append(group[0]) + continue + max_len = max(ids.shape[1] for ids, _ in group) + parts = [] + merged_meta = [] + for ids, meta in group: + if ids.shape[1] < max_len: + pad = torch.full((ids.shape[0], max_len - ids.shape[1]), pad_token_id, dtype=ids.dtype) + ids = torch.cat([ids, pad], dim=1) + parts.append(ids) + merged_meta.extend(meta) + composed.append((torch.cat(parts, dim=0), merged_meta)) + return composed + else: + # Split base batches into smaller chunks + composed = [] + for combined_ids, batch_meta in base_collated: + for chunk_start in range(0, len(batch_meta), target_batch_size): + chunk_meta = batch_meta[chunk_start:chunk_start + target_batch_size] + row_start = sum(m[1] for m in batch_meta[:chunk_start]) + row_end = row_start + sum(m[1] for m in chunk_meta) + chunk_ids = combined_ids[row_start:row_end] + # Trim trailing padding (examples sorted by seq_len, so chunks + # near the start of a base batch may need fewer columns) + non_pad = (chunk_ids != pad_token_id) + if non_pad.any(): + last_col = non_pad.any(dim=0).nonzero()[-1].item() + 1 + if last_col < chunk_ids.shape[1]: + chunk_ids = chunk_ids[:, :last_col].contiguous() + composed.append((chunk_ids, chunk_meta)) + return composed + + def evaluate_task(model, data, device, batch_size=4, queue_size=2, prepared=None, - collated=None, tokenizer=None, task_meta=None): + collated=None, tokenizer=None, task_meta=None, pbar=None): """ Evaluate one task across many examples with batched GPU forward passes. 
Examples are sorted by sequence length so similar-length sequences are batched @@ -316,7 +394,7 @@ def evaluate_task(model, data, device, batch_size=4, queue_size=2, prepared=None if collated is not None: # Fast path: just GPU forward passes, no threads - correct = _forward_batches(model, collated, data, device) + correct = _forward_batches(model, collated, data, device, pbar=pbar) else: from queue import Queue from threading import Thread @@ -325,21 +403,34 @@ def evaluate_task(model, data, device, batch_size=4, queue_size=2, prepared=None max_seq_len = getattr(model, 'max_seq_len', None) prepared = prepare_task_data(tokenizer, data, task_meta, max_seq_len) - # Collation thread pipelined with GPU forward passes + # Collation thread pipelined with GPU forward passes. + # Double-buffered: while GPU processes batch N, batch N+1 is + # pin_memory()'d and DMA-transferred asynchronously. queue = Queue(maxsize=queue_size) collator = Thread(target=_collate_batches, args=(prepared, batch_size, queue), daemon=True) collator.start() + use_prefetch = torch.cuda.is_available() and 'cuda' in str(device) + def transfer(tensor): + return _prefetch_to_device(tensor, device) if use_prefetch else tensor.to(device) + collated = [] correct = torch.zeros(len(data), dtype=torch.float32, device=device) - while True: - item = queue.get() - if item is None: - break - collated.append(item) - combined_ids, batch_meta = item - combined_ids = combined_ids.to(device) + # Prime: get first batch and start its transfer + item = queue.get() + if item is not None: + next_ids = transfer(item[0]) + + while item is not None: + collated.append(item) + combined_ids = next_ids + _, batch_meta = item + + # Start async transfer of next batch (overlaps with forward pass below) + item = queue.get() + if item is not None: + next_ids = transfer(item[0]) losses, predictions = forward_model(model, combined_ids) @@ -352,6 +443,8 @@ def evaluate_task(model, data, device, batch_size=4, queue_size=2, prepared=None ) 
correct[idx] = float(is_correct) offset += n + if pbar is not None: + pbar.update(len(batch_meta)) collator.join() del prepared diff --git a/scripts/base_eval.py b/scripts/base_eval.py index b4aaf56..d1751fc 100644 --- a/scripts/base_eval.py +++ b/scripts/base_eval.py @@ -26,6 +26,7 @@ import json import yaml import shutil import random +import hashlib import zipfile import tempfile import argparse @@ -113,6 +114,22 @@ _prev_centered = {} # {label: centered_result} — previous run for delta d _prev_core = None # previous core_metric +def _get_disk_cache_dir(max_per_task): + """Get disk cache dir for base-4 collated batches, keyed by tokenizer file hash. + Returns None if no local tokenizer file is found (e.g. HuggingFace models).""" + base_dir = get_base_dir() + for fname in ("tokenizer.pkl", "tokenizer.json"): + path = os.path.join(base_dir, "tokenizer", fname) + if os.path.exists(path): + h = hashlib.sha256() + with open(path, "rb") as f: + for chunk in iter(lambda: f.read(8192), b""): + h.update(chunk) + tok_hash = h.hexdigest()[:16] + return os.path.join(base_dir, "core_token_cache", f"{tok_hash}_n{max_per_task}") + return None + + def evaluate_model(model, tokenizer, device, max_per_task=-1): """ Evaluate a base model on the CORE benchmark. 
@@ -180,6 +197,22 @@ def evaluate_model(model, tokenizer, device, max_per_task=-1): results = {} centered_results = {} cached_run = all(label in _batch_cache for label, _, _ in task_inputs) + disk_cache_dir = _get_disk_cache_dir(max_per_task) + + # Try loading from disk cache if in-memory cache is empty + if not cached_run and disk_cache_dir is not None: + all_on_disk = os.path.isdir(disk_cache_dir) and all( + os.path.exists(os.path.join(disk_cache_dir, f"{label}.pt")) + for label, _, _ in task_inputs + ) + if all_on_disk: + for label, _, _ in task_inputs: + d = torch.load(os.path.join(disk_cache_dir, f"{label}.pt"), weights_only=False) + _batch_cache[label] = d['collated'] + cached_run = True + print0(" (loaded collated batches from disk cache)") + + first_run = not cached_run # track whether we did prepare+collate (for disk save) if not cached_run: executor = ThreadPoolExecutor(max_workers=1) @@ -221,6 +254,16 @@ def evaluate_model(model, tokenizer, device, max_per_task=-1): if not cached_run: executor.shutdown(wait=False) + # Save collated batches to disk after first run (so bench/future runs skip prepare+collate) + if first_run and disk_cache_dir is not None: + pad_id = tokenizer.get_bos_token_id() + os.makedirs(disk_cache_dir, exist_ok=True) + for label, _, _ in task_inputs: + if label in _batch_cache: + torch.save({'collated': _batch_cache[label], 'pad_token_id': pad_id}, + os.path.join(disk_cache_dir, f"{label}.pt")) + print0(f" (saved collated batches to {disk_cache_dir})") + core_metric = sum(centered_results.values()) / len(centered_results) if _prev_core is not None: d = core_metric - _prev_core diff --git a/scripts/bench_core_eval.py b/scripts/bench_core_eval.py index d11bc40..ae7294e 100644 --- a/scripts/bench_core_eval.py +++ b/scripts/bench_core_eval.py @@ -20,10 +20,12 @@ import time import yaml import random import shutil +import hashlib import zipfile import tempfile import argparse from contextlib import nullcontext +from tqdm import tqdm import 
torch import torch.distributed as dist @@ -34,12 +36,13 @@ from nanochat.checkpoint_manager import load_model from nanochat.core_eval import ( forward_model, prepare_example, check_result, stack_sequences, prepare_task_data, _collate_batches, _forward_batches, evaluate_task, + compose_collated, render_prompts_mc, render_prompts_schema, render_prompts_lm, batch_sequences_mc, batch_sequences_schema, batch_sequences_lm, ) # ---- eval bundle loading (reused from base_eval) ---- - +torch.backends.fp32_precision = "tf32" EVAL_BUNDLE_URL = "https://karpathy-public.s3.us-west-2.amazonaws.com/eval_bundle.zip" def place_eval_bundle(file_path): @@ -79,6 +82,69 @@ def load_tasks(max_per_task=-1): task_inputs.append((label, task_meta, data)) return task_inputs +BASE_BATCH_SIZE = 4 + +def file_checksum(path): + """SHA-256 checksum of a file, truncated to 16 hex chars.""" + h = hashlib.sha256() + with open(path, "rb") as f: + for chunk in iter(lambda: f.read(8192), b""): + h.update(chunk) + return h.hexdigest()[:16] + + +def collate_prepared(prepared, batch_size): + """Collate prepared examples into batches (non-threaded). 
Returns (collated, pad_token_id).""" + pad_id = prepared[0][1]['pad_token_id'] + collated = [] + for batch_start in range(0, len(prepared), batch_size): + batch = prepared[batch_start:batch_start + batch_size] + batch_preps = [p for _, p in batch] + max_len = max(p['seq_len'] for p in batch_preps) + total_rows = sum(p['num_options'] for p in batch_preps) + combined_ids = torch.full((total_rows, max_len), pad_id, dtype=torch.long) + batch_meta = [] + offset = 0 + for idx, p in batch: + n, sl = p['num_options'], p['seq_len'] + combined_ids[offset:offset+n, :sl] = p['input_ids'] + batch_meta.append((idx, n, p['start_idxs'], p['end_idxs'], p['gold'], p['task_type'])) + offset += n + collated.append((combined_ids, batch_meta)) + return collated, pad_id + + +def build_or_load_base_collated(tok_hash, tokenizer, task_inputs, max_seq_len, max_per_task): + """Build or load base-4 collated data for all tasks, with disk caching.""" + base_dir = get_base_dir() + cache_dir = os.path.join(base_dir, "core_token_cache", f"{tok_hash}_n{max_per_task}") + + all_cached = os.path.isdir(cache_dir) and all( + os.path.exists(os.path.join(cache_dir, f"{label}.pt")) + for label, _, _ in task_inputs + ) + + base_cache = {} # label -> (collated, pad_token_id) + if all_cached: + print0(f"Loading base-{BASE_BATCH_SIZE} collated cache from {cache_dir}") + for label, _, _ in tqdm(task_inputs, desc="Loading cache", leave=False): + data = torch.load(os.path.join(cache_dir, f"{label}.pt"), weights_only=False) + base_cache[label] = (data['collated'], data['pad_token_id']) + else: + print0(f"Building base-{BASE_BATCH_SIZE} collated cache (saving to {cache_dir})") + os.makedirs(cache_dir, exist_ok=True) + for label, task_meta, data in tqdm(task_inputs, desc="Preparing tasks", leave=False): + prepared = prepare_task_data(tokenizer, data, task_meta, max_seq_len) + collated, pad_id = collate_prepared(prepared, BASE_BATCH_SIZE) + base_cache[label] = (collated, pad_id) + torch.save({'collated': collated, 
'pad_token_id': pad_id}, + os.path.join(cache_dir, f"{label}.pt")) + del prepared + print0(f"Saved {len(base_cache)} task caches to {cache_dir}") + + return base_cache + + # ---- old sequential evaluation (baseline) ---- @torch.no_grad() @@ -129,7 +195,7 @@ def evaluate_example_old(idx, model, tokenizer, data, device, task_meta): return check_result(losses, predictions, input_ids, start_idxs, end_idxs, item.get('gold', None), task_type) -def evaluate_task_old(model, tokenizer, data, device, task_meta): +def evaluate_task_old(model, tokenizer, data, device, task_meta, pbar=None): """Original sequential evaluate_task.""" rank = dist.get_rank() if dist.is_initialized() else 0 world_size = dist.get_world_size() if dist.is_initialized() else 1 @@ -137,6 +203,8 @@ def evaluate_task_old(model, tokenizer, data, device, task_meta): for idx in range(rank, len(data), world_size): is_correct = evaluate_example_old(idx, model, tokenizer, data, device, task_meta) correct[idx] = float(is_correct) + if pbar is not None: + pbar.update(1) if world_size > 1: dist.barrier() dist.all_reduce(correct, op=dist.ReduceOp.SUM) @@ -170,37 +238,48 @@ def bench_old(model, tokenizer, task_inputs, device): sync_cuda() t0 = time.time() results = {} + total = sum(len(data) for _, _, data in task_inputs) + max_label_len = max(len(label) for label, _, _ in task_inputs) + pbar = tqdm(total=total, desc="Sequential", leave=False) for label, task_meta, data in task_inputs: - acc = evaluate_task_old(model, tokenizer, data, device, task_meta) + pbar.set_description(f"Sequential: {label:<{max_label_len}}") + acc = evaluate_task_old(model, tokenizer, data, device, task_meta, pbar=pbar) results[label] = acc + pbar.close() sync_cuda() return time.time() - t0, results -def bench_new_first(model, tokenizer, task_inputs, device, batch_size, queue_size): +def bench_new_first(model, tokenizer, task_inputs, device, batch_size, queue_size, pbar=None): """Benchmark new batched evaluation (first run, no cache).""" 
sync_cuda() t0 = time.time() results = {} collated_cache = {} max_seq_len = getattr(model, 'max_seq_len', None) + max_label_len = max(len(label) for label, _, _ in task_inputs) for label, task_meta, data in task_inputs: + if pbar is not None: + pbar.set_description(f"{label:<{max_label_len}}") prepared = prepare_task_data(tokenizer, data, task_meta, max_seq_len) acc, collated = evaluate_task(model, data, device, batch_size=batch_size, - queue_size=queue_size, prepared=prepared) + queue_size=queue_size, prepared=prepared, pbar=pbar) results[label] = acc collated_cache[label] = collated sync_cuda() return time.time() - t0, results, collated_cache -def bench_new_cached(model, task_inputs, device, collated_cache): +def bench_new_cached(model, task_inputs, device, collated_cache, pbar=None): """Benchmark new batched evaluation (cached run, forward only).""" sync_cuda() t0 = time.time() results = {} + max_label_len = max(len(label) for label, _, _ in task_inputs) for label, task_meta, data in task_inputs: - acc, _ = evaluate_task(model, data, device, collated=collated_cache[label]) + if pbar is not None: + pbar.set_description(f"{label:<{max_label_len}}") + acc, _ = evaluate_task(model, data, device, collated=collated_cache[label], pbar=pbar) results[label] = acc sync_cuda() return time.time() - t0, results @@ -232,6 +311,7 @@ def main(): parser.add_argument('--batch-sizes', type=str, default='1,2,4,8,16,32,64', help='Comma-separated batch sizes to sweep') parser.add_argument('--queue-sizes', type=str, default='2,4,8,16', help='Comma-separated queue sizes to sweep') parser.add_argument('--skip-old', action='store_true', help='Skip old sequential baseline (slow)') + parser.add_argument('--skip-first', action='store_true', help='Skip first-run sweep (requires cached collated data)') args = parser.parse_args() batch_sizes = [int(x) for x in args.batch_sizes.split(',')] @@ -272,52 +352,131 @@ def main(): print0("") # ---- 2. 
Sweep batch_size x queue_size for first run ---- + # Compute tok_hash early — needed for both skip-first check and cache loading + max_seq_len = getattr(model, 'max_seq_len', None) + if args.hf_path is not None: + tok_hash = hashlib.sha256(args.hf_path.encode()).hexdigest()[:16] + else: + base_dir = get_base_dir() + for fname in ("tokenizer.pkl", "tokenizer.json"): + tok_path = os.path.join(base_dir, "tokenizer", fname) + if os.path.exists(tok_path): + tok_hash = file_checksum(tok_path) + break + + # Check if we can skip the first-run sweep + best_time = None + best_params = None + if args.skip_first: + cache_dir = os.path.join(get_base_dir(), "core_token_cache", f"{tok_hash}_n{args.max_per_task}") + has_cache = os.path.isdir(cache_dir) and all( + os.path.exists(os.path.join(cache_dir, f"{label}.pt")) + for label, _, _ in task_inputs + ) + if has_cache: + print0("Skipping first-run sweep (--skip-first, cache exists)") + print0("") + else: + print0(f"--skip-first: cache missing, running single bs={BASE_BATCH_SIZE} eval to create it...") + pbar = tqdm(total=total_examples, desc="Creating cache", leave=False) + with autocast_ctx: + _, _, collated_cache = bench_new_first(model, tokenizer, task_inputs, device, + batch_size=BASE_BATCH_SIZE, queue_size=2, pbar=pbar) + pbar.close() + pad_id = tokenizer.get_bos_token_id() + os.makedirs(cache_dir, exist_ok=True) + for label in collated_cache: + torch.save({'collated': collated_cache[label], 'pad_token_id': pad_id}, + os.path.join(cache_dir, f"{label}.pt")) + print0(f"Cache created ({len(collated_cache)} tasks)") + print0("") + + if not args.skip_first: + print0("=" * 80) + print0("NEW: Batched evaluation — hyperparameter sweep (first run)") + print0("=" * 80) + print0("") + + # Header + qs_header = "".join(f"{'q=' + str(q):>10}" for q in queue_sizes) + print0(f" {'batch_size':>10}{qs_header}") + print0(f" {'':>10}" + "-" * (10 * len(queue_sizes))) + + best_time = float('inf') + best_params = None + best_collated = None + 
sweep_results = {} + total_combos = len(batch_sizes) * len(queue_sizes) + outer_pbar = tqdm(total=total_combos, desc="First-run sweep", leave=False, position=0) + inner_pbar = tqdm(total=total_examples, desc="", leave=False, position=1) + + for bs in batch_sizes: + row = f" {bs:>10}" + for qs in queue_sizes: + outer_pbar.set_description(f"First-run: bs={bs} qs={qs}") + inner_pbar.reset() + with autocast_ctx: + t, results, collated_cache = bench_new_first(model, tokenizer, task_inputs, device, bs, qs, pbar=inner_pbar) + sweep_results[(bs, qs)] = t + row += f"{t:>9.2f}s" + if t < best_time: + best_time = t + best_params = (bs, qs) + best_collated = collated_cache + best_first_results = results + outer_pbar.update(1) + outer_pbar.write(row) + inner_pbar.close() + outer_pbar.close() + + print0("") + print0(f" Best: batch_size={best_params[0]}, queue_size={best_params[1]} -> {best_time:.2f}s ({total_examples / best_time:.1f} examples/s)") + + # Verify correctness + if old_results is not None: + verify_results(old_results, best_first_results, label="new-first") + print0("") + + # ---- 3. Build/load base-4 collated cache ---- + base_cache = build_or_load_base_collated(tok_hash, tokenizer, task_inputs, max_seq_len, args.max_per_task) + + # ---- 4. 
Cached run sweep (forward only, composed from base-4) ---- print0("=" * 80) - print0("NEW: Batched evaluation — hyperparameter sweep (first run)") + print0(f"NEW: Cached run (forward only, composed from base-{BASE_BATCH_SIZE})") print0("=" * 80) print0("") - # Header - qs_header = "".join(f"{'q=' + str(q):>10}" for q in queue_sizes) - print0(f" {'batch_size':>10}{qs_header}") - print0(f" {'':>10}" + "-" * (10 * len(queue_sizes))) - - best_time = float('inf') - best_params = None - best_collated = None - sweep_results = {} + best_cached_time = float('inf') + best_cached_params = None + outer_pbar = tqdm(total=len(batch_sizes), desc="Cached sweep", leave=False, position=0) + inner_pbar = tqdm(total=total_examples, desc="", leave=False, position=1) for bs in batch_sizes: - row = f" {bs:>10}" - for qs in queue_sizes: - with autocast_ctx: - t, results, collated_cache = bench_new_first(model, tokenizer, task_inputs, device, bs, qs) - sweep_results[(bs, qs)] = t - row += f"{t:>9.2f}s" - if t < best_time: - best_time = t - best_params = (bs, qs) - best_collated = collated_cache - best_first_results = results - print0(row) + outer_pbar.set_description(f"Cached: bs={bs}") + inner_pbar.reset() + # Compose from base-4 to target batch_size (merge or split) + composed_cache = {} + for label, (collated, pad_id) in base_cache.items(): + composed_cache[label] = compose_collated(collated, bs, BASE_BATCH_SIZE, pad_id) + + with autocast_ctx: + t, cached_results = bench_new_cached(model, task_inputs, device, composed_cache, pbar=inner_pbar) + + outer_pbar.write(f" batch_size={bs:>3}: {t:.2f}s ({total_examples / t:.1f} examples/s)") + outer_pbar.update(1) + + if t < best_cached_time: + best_cached_time = t + best_cached_params = bs + best_cached_results = cached_results + inner_pbar.close() + outer_pbar.close() print0("") - print0(f" Best: batch_size={best_params[0]}, queue_size={best_params[1]} -> {best_time:.2f}s ({total_examples / best_time:.1f} examples/s)") + print0(f" Best: 
batch_size={best_cached_params} -> {best_cached_time:.2f}s ({total_examples / best_cached_time:.1f} examples/s)") - # Verify correctness if old_results is not None: - verify_results(old_results, best_first_results, label="new-first") - print0("") - - # ---- 3. Cached run (forward only) ---- - print0("=" * 80) - print0("NEW: Cached run (forward only, using best params)") - print0("=" * 80) - with autocast_ctx: - cached_time, cached_results = bench_new_cached(model, task_inputs, device, best_collated) - print0(f" Time: {cached_time:.2f}s ({total_examples / cached_time:.1f} examples/s)") - if old_results is not None: - verify_results(old_results, cached_results, label="new-cached") + verify_results(old_results, best_cached_results, label="new-cached") print0("") # ---- Summary ---- @@ -326,11 +485,13 @@ def main(): print0("=" * 80) if old_results is not None: print0(f" Old (sequential): {old_time:>8.2f}s") - print0(f" New (first run): {best_time:>8.2f}s batch_size={best_params[0]}, queue_size={best_params[1]}") - print0(f" New (cached): {cached_time:>8.2f}s") + if best_time is not None: + print0(f" New (first run): {best_time:>8.2f}s batch_size={best_params[0]}, queue_size={best_params[1]}") + print0(f" New (cached): {best_cached_time:>8.2f}s batch_size={best_cached_params}") if old_results is not None: - print0(f" Speedup (first): {old_time / best_time:>8.2f}x") - print0(f" Speedup (cached): {old_time / cached_time:>8.2f}x") + if best_time is not None: + print0(f" Speedup (first): {old_time / best_time:>8.2f}x") + print0(f" Speedup (cached): {old_time / best_cached_time:>8.2f}x") compute_cleanup() From c3f234cfca4f11a747baaf8d072eab41a132e141 Mon Sep 17 00:00:00 2001 From: Unsal Gokdag Date: Fri, 13 Feb 2026 07:54:53 +0000 Subject: [PATCH 3/4] CORE eval: GPU-resident data, continuous pipeline, per-task progress bars MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit three independent improvements to the cached CORE evaluation 
path: 1. GPU-resident data: all base-4 collated batches (~144MB for full CORE eval) are moved to GPU upfront via .to(device). eliminates all CPU→GPU transfers from the forward loop. _forward_all_cached replaces double-buffered prefetch with a simple upfront bulk transfer — .to() is a no-op when the caller has already preloaded tensors to GPU (as bench_core_eval now does). 2. continuous cross-task pipeline: _forward_all_cached flattens all tasks' batches into one stream. the last batch of task N flows directly into the first batch of task N+1 with no pipeline restart. GPU-side composition via merge (pad+cat for bs > base) and split (row-slice for bs < base) avoids the CPU-side compose_collated bottleneck that made bs=8 slower than bs=4. 3. progress bars + per-task result printing: both cached and first-run paths in evaluate_model now show a tqdm progress bar with the current task label. on_task_done callback in _forward_all_cached prints each task's accuracy as soon as its last batch is processed (single-GPU). DDP falls back to printing after all_reduce. both paths print total elapsed time at the end. bench_core_eval: preloads ALL base-4 batches to GPU once before the batch-size sweep. all sweep iterations compose from GPU-resident tensors with zero CPU→GPU transfers in the hot loop. --- nanochat/core_eval.py | 111 +++++++++++++++++++++++++++++++++++++ scripts/base_eval.py | 111 ++++++++++++++++++++++++------------- scripts/bench_core_eval.py | 52 +++++++++++------ 3 files changed, 220 insertions(+), 54 deletions(-) diff --git a/nanochat/core_eval.py b/nanochat/core_eval.py index 9231520..28f50f5 100644 --- a/nanochat/core_eval.py +++ b/nanochat/core_eval.py @@ -325,6 +325,117 @@ def _forward_batches(model, collated, data, device, pbar=None): return correct +def _forward_all_cached(model, task_collated, device, pbar=None, task_labels=None, + on_task_done=None, merge=1, split=1, pad_token_id=0): + """Run all tasks' cached batches through the model in one pass. 
+ + All batch tensors are moved to device upfront (~144MB for full CORE eval). + If tensors are already on device (caller preloaded), .to() is a no-op. + Composition (merge/split) happens entirely on device: + - merge > 1: pad+cat consecutive base batches on GPU before forwarding. + - split > 1: slice each group into chunks by example boundaries, + forward each chunk separately. + + Args: + task_collated: list of (collated_batches, data) per task + pbar: optional progress bar, updated per forward pass (by number of examples) + task_labels: optional list of task names for pbar description updates + on_task_done: optional callback(task_idx, correct_tensor) fired when a task completes + merge: number of consecutive base batches to compose per group (>= 1) + split: number of forward passes to split each group into (>= 1) + pad_token_id: token id used for padding when merging batches of different lengths + Returns: + list of correct tensors (one per task, on device) + """ + # Flatten all batches and move to device upfront (no-op if already there) + flat_stream = [] # (gpu_ids, batch_meta, task_idx) + correct = [] + task_batch_counts = [] + for task_idx, (collated, data) in enumerate(task_collated): + correct.append(torch.zeros(len(data), dtype=torch.float32, device=device)) + task_batch_counts.append(len(collated)) + for combined_ids, batch_meta in collated: + flat_stream.append((combined_ids.to(device), batch_meta, task_idx)) + + if not flat_stream: + return correct + + task_batches_remaining = list(task_batch_counts) + current_task = -1 + buffer_ids = [] + buffer_info = [] + + for i, (combined_ids, batch_meta, task_idx) in enumerate(flat_stream): + # Update pbar description on task transition + if task_idx != current_task: + current_task = task_idx + if pbar is not None and task_labels is not None: + pbar.set_description(task_labels[task_idx]) + buffer_ids.append(combined_ids) + buffer_info.append((batch_meta, task_idx)) + + # Accumulate until we have `merge` batches 
(or hit the end) + if len(buffer_ids) < merge and i < len(flat_stream) - 1: + continue + + # GPU compose: pad+cat if multiple batches, otherwise use as-is + if len(buffer_ids) == 1: + mega_ids = buffer_ids[0] + else: + max_len = max(t.shape[1] for t in buffer_ids) + parts = [] + for t in buffer_ids: + if t.shape[1] < max_len: + pad = torch.full((t.shape[0], max_len - t.shape[1]), pad_token_id, + dtype=t.dtype, device=t.device) + t = torch.cat([t, pad], dim=1) + parts.append(t) + mega_ids = torch.cat(parts, dim=0) + + # Flatten examples with row boundaries (for splitting) + examples = [] + row_bounds = [0] + for bm, tidx in buffer_info: + for idx, n, start_idxs, end_idxs, gold, task_type in bm: + examples.append((idx, n, start_idxs, end_idxs, gold, task_type, tidx)) + row_bounds.append(row_bounds[-1] + n) + + # Forward + score (with optional GPU split) + n_ex = len(examples) + chunk_size = -(-n_ex // split) # ceiling division + + for cs in range(0, n_ex, chunk_size): + ce = min(cs + chunk_size, n_ex) + chunk = examples[cs:ce] + chunk_ids = mega_ids[row_bounds[cs]:row_bounds[ce]] + + losses, predictions = forward_model(model, chunk_ids) + + offset = 0 + for idx, n, start_idxs, end_idxs, gold, task_type, tidx in chunk: + is_correct = check_result( + losses[offset:offset+n], predictions[offset:offset+n], + chunk_ids[offset:offset+n], + start_idxs, end_idxs, gold, task_type, + ) + correct[tidx][idx] = float(is_correct) + offset += n + if pbar is not None: + pbar.update(len(chunk)) + + # Fire callback for any tasks that just completed all their batches + if on_task_done is not None: + for bm, tidx in buffer_info: + task_batches_remaining[tidx] -= 1 + if task_batches_remaining[tidx] == 0: + on_task_done(tidx, correct[tidx]) + + buffer_ids.clear() + buffer_info.clear() + + return correct + + def compose_collated(base_collated, target_batch_size, base_batch_size=4, pad_token_id=0): """Compose base-sized collated batches into target-sized batches. 
diff --git a/scripts/base_eval.py b/scripts/base_eval.py index d1751fc..dcee8da 100644 --- a/scripts/base_eval.py +++ b/scripts/base_eval.py @@ -31,13 +31,14 @@ import zipfile import tempfile import argparse from contextlib import nullcontext +from tqdm import tqdm import torch from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, autodetect_device_type, download_file_with_lock from nanochat.tokenizer import HuggingFaceTokenizer, get_token_bytes from nanochat.checkpoint_manager import load_model -from nanochat.core_eval import evaluate_task, prepare_task_data +from nanochat.core_eval import evaluate_task, prepare_task_data, _forward_all_cached from nanochat.dataloader import tokenizing_distributed_data_loader_bos_bestfit from nanochat.loss_eval import evaluate_bpb from nanochat.engine import Engine @@ -214,55 +215,91 @@ def evaluate_model(model, tokenizer, device, max_per_task=-1): first_run = not cached_run # track whether we did prepare+collate (for disk save) - if not cached_run: - executor = ThreadPoolExecutor(max_workers=1) - first_uncached = next(i for i, (l, _, _) in enumerate(task_inputs) if l not in _batch_cache) - _, first_meta, first_data = task_inputs[first_uncached] - next_future = executor.submit(prepare_task_data, tokenizer, first_data, first_meta, max_seq_len) + import torch.distributed as dist + rank = dist.get_rank() if dist.is_initialized() else 0 + world_size = dist.get_world_size() if dist.is_initialized() else 1 + total_examples = sum(len(data) for _, _, data in task_inputs) + task_labels = [label for label, _, _ in task_inputs] + pbar = tqdm(total=total_examples, leave=False, disable=(rank != 0)) - for i, (label, task_meta, data) in enumerate(task_inputs): - shot_str = f"{task_meta['num_fewshot']}-shot" - prefix = f" {label:<{w_label}} {shot_str:<{w_shot}} {task_meta['task_type']:<{w_type}}" - print0(f"{prefix} ...", end="", flush=True) - t0 = time.time() - - if label in _batch_cache: - accuracy, collated = 
evaluate_task(model, data, device, collated=_batch_cache[label]) - else: - prepared = next_future.result() - # Kick off prepare for the next uncached task - for j in range(i + 1, len(task_inputs)): - next_label, next_meta, next_data = task_inputs[j] - if next_label not in _batch_cache: - next_future = executor.submit(prepare_task_data, tokenizer, next_data, next_meta, max_seq_len) - break - accuracy, collated = evaluate_task(model, data, device, prepared=prepared) - _batch_cache[label] = collated - - elapsed = time.time() - t0 + def _print_task_result(tidx, accuracy): + """Print one task's result. Updates results/centered_results dicts.""" + label, task_meta, _ = task_inputs[tidx] random_baseline = random_baselines[label] centered_result = (accuracy - 0.01 * random_baseline) / (1.0 - 0.01 * random_baseline) results[label] = accuracy centered_results[label] = centered_result + shot_str = f"{task_meta['num_fewshot']}-shot" + prefix = f" {label:<{w_label}} {shot_str:<{w_shot}} {task_meta['task_type']:<{w_type}}" delta_str = "" if label in _prev_centered: d = centered_result - _prev_centered[label] arrow = "\u2191" if d > 0 else "\u2193" if d < 0 else "=" delta_str = f" {arrow}{d:+.4f}" - print0(f"\r{prefix} acc: {accuracy:.4f} centered: {centered_result:>7.4f}{delta_str} time: {elapsed:.2f}s") + if rank == 0: + pbar.write(f"{prefix} acc: {accuracy:.4f} centered: {centered_result:>7.4f}{delta_str}") - if not cached_run: - executor.shutdown(wait=False) + def _on_task_done(tidx, correct): + """Callback for _forward_all_cached: convert tensor to accuracy and print.""" + _print_task_result(tidx, correct.mean().item()) + + if cached_run: + # Continuous pipeline: all tasks in one GPU stream, results printed per-task as they complete + t0 = time.time() + task_collated = [(_batch_cache[label], data) for label, _, data in task_inputs] + correct_list = _forward_all_cached( + model, task_collated, device, pbar=pbar, task_labels=task_labels, + on_task_done=_on_task_done if 
world_size == 1 else None, + ) + elapsed_total = time.time() - t0 + pbar.close() + + # DDP: all_reduce + print (single-GPU already handled by on_task_done above) + if world_size > 1: + for tidx, ((label, task_meta, data), correct) in enumerate(zip(task_inputs, correct_list)): + dist.barrier() + dist.all_reduce(correct, op=dist.ReduceOp.SUM) + _print_task_result(tidx, correct.mean().item()) + print0(f" (all tasks: {elapsed_total:.2f}s)") + else: + t0 = time.time() + executor = ThreadPoolExecutor(max_workers=1) + first_uncached = next(i for i, (l, _, _) in enumerate(task_inputs) if l not in _batch_cache) + _, first_meta, first_data = task_inputs[first_uncached] + next_future = executor.submit(prepare_task_data, tokenizer, first_data, first_meta, max_seq_len) + + for i, (label, task_meta, data) in enumerate(task_inputs): + pbar.set_description(f"{label:<{w_label}}") - # Save collated batches to disk after first run (so bench/future runs skip prepare+collate) - if first_run and disk_cache_dir is not None: - pad_id = tokenizer.get_bos_token_id() - os.makedirs(disk_cache_dir, exist_ok=True) - for label, _, _ in task_inputs: if label in _batch_cache: - torch.save({'collated': _batch_cache[label], 'pad_token_id': pad_id}, - os.path.join(disk_cache_dir, f"{label}.pt")) - print0(f" (saved collated batches to {disk_cache_dir})") + accuracy, collated = evaluate_task(model, data, device, collated=_batch_cache[label], pbar=pbar) + else: + prepared = next_future.result() + # Kick off prepare for the next uncached task + for j in range(i + 1, len(task_inputs)): + next_label, next_meta, next_data = task_inputs[j] + if next_label not in _batch_cache: + next_future = executor.submit(prepare_task_data, tokenizer, next_data, next_meta, max_seq_len) + break + accuracy, collated = evaluate_task(model, data, device, prepared=prepared, pbar=pbar) + _batch_cache[label] = collated + + _print_task_result(i, accuracy) + + elapsed_total = time.time() - t0 + pbar.close() + 
executor.shutdown(wait=False) + print0(f" (all tasks: {elapsed_total:.2f}s)") + + # Save collated batches to disk after first run (so bench/future runs skip prepare+collate) + if first_run and disk_cache_dir is not None: + pad_id = tokenizer.get_bos_token_id() + os.makedirs(disk_cache_dir, exist_ok=True) + for label, _, _ in task_inputs: + if label in _batch_cache: + torch.save({'collated': _batch_cache[label], 'pad_token_id': pad_id}, + os.path.join(disk_cache_dir, f"{label}.pt")) + print0(f" (saved collated batches to {disk_cache_dir})") core_metric = sum(centered_results.values()) / len(centered_results) if _prev_core is not None: diff --git a/scripts/bench_core_eval.py b/scripts/bench_core_eval.py index ae7294e..1003aee 100644 --- a/scripts/bench_core_eval.py +++ b/scripts/bench_core_eval.py @@ -35,8 +35,8 @@ from nanochat.tokenizer import HuggingFaceTokenizer from nanochat.checkpoint_manager import load_model from nanochat.core_eval import ( forward_model, prepare_example, check_result, stack_sequences, - prepare_task_data, _collate_batches, _forward_batches, evaluate_task, - compose_collated, + prepare_task_data, _collate_batches, _forward_all_cached, + evaluate_task, render_prompts_mc, render_prompts_schema, render_prompts_lm, batch_sequences_mc, batch_sequences_schema, batch_sequences_lm, ) @@ -270,19 +270,27 @@ def bench_new_first(model, tokenizer, task_inputs, device, batch_size, queue_siz return time.time() - t0, results, collated_cache -def bench_new_cached(model, task_inputs, device, collated_cache, pbar=None): - """Benchmark new batched evaluation (cached run, forward only).""" +def bench_new_cached(model, task_inputs, device, collated_cache, pbar=None, + merge=1, split=1, pad_token_id=0): + """Benchmark new batched evaluation (cached run, forward only). + Uses continuous pipeline across all tasks to eliminate inter-task stalls. 
+ merge/split control GPU-side composition: merge > 1 cats batches, split > 1 slices them.""" + import torch.distributed as dist + world_size = dist.get_world_size() if dist.is_initialized() else 1 sync_cuda() t0 = time.time() - results = {} - max_label_len = max(len(label) for label, _, _ in task_inputs) - for label, task_meta, data in task_inputs: - if pbar is not None: - pbar.set_description(f"{label:<{max_label_len}}") - acc, _ = evaluate_task(model, data, device, collated=collated_cache[label], pbar=pbar) - results[label] = acc + task_collated = [(collated_cache[label], data) for label, _, data in task_inputs] + correct_list = _forward_all_cached(model, task_collated, device, pbar=pbar, + merge=merge, split=split, pad_token_id=pad_token_id) sync_cuda() - return time.time() - t0, results + elapsed = time.time() - t0 + results = {} + for (label, _, data), correct in zip(task_inputs, correct_list): + if world_size > 1: + dist.barrier() + dist.all_reduce(correct, op=dist.ReduceOp.SUM) + results[label] = correct.mean().item() + return elapsed, results def verify_results(old_results, new_results, label="new"): @@ -448,19 +456,29 @@ def main(): best_cached_time = float('inf') best_cached_params = None + pad_id = next(iter(base_cache.values()))[1] + # Preload ALL base-4 batches to GPU once (~144MB for full CORE eval). + # All batch-size sweeps compose from these GPU-resident tensors — zero CPU→GPU transfers. 
+ gpu_collated = {} + for label, (collated, _) in base_cache.items(): + gpu_collated[label] = [(ids.to(device), meta) for ids, meta in collated] outer_pbar = tqdm(total=len(batch_sizes), desc="Cached sweep", leave=False, position=0) inner_pbar = tqdm(total=total_examples, desc="", leave=False, position=1) for bs in batch_sizes: outer_pbar.set_description(f"Cached: bs={bs}") inner_pbar.reset() - # Compose from base-4 to target batch_size (merge or split) - composed_cache = {} - for label, (collated, pad_id) in base_cache.items(): - composed_cache[label] = compose_collated(collated, bs, BASE_BATCH_SIZE, pad_id) + + # All composition happens on GPU: merge for bs >= base, split for bs < base + if bs >= BASE_BATCH_SIZE: + merge, split = bs // BASE_BATCH_SIZE, 1 + else: + merge, split = 1, BASE_BATCH_SIZE // bs with autocast_ctx: - t, cached_results = bench_new_cached(model, task_inputs, device, composed_cache, pbar=inner_pbar) + t, cached_results = bench_new_cached(model, task_inputs, device, gpu_collated, + pbar=inner_pbar, merge=merge, split=split, + pad_token_id=pad_id) outer_pbar.write(f" batch_size={bs:>3}: {t:.2f}s ({total_examples / t:.1f} examples/s)") outer_pbar.update(1) From 4f79e750e7b727860f0270a8435ca4320750c932 Mon Sep 17 00:00:00 2001 From: Unsal Gokdag Date: Fri, 13 Feb 2026 08:42:45 +0000 Subject: [PATCH 4/4] CORE eval: batched forwarding by default, per-example mode for verification Switch cached eval path to batched=True (forwards full collated batches) for ~5-7x speedup over sequential per-example evaluation. Add per-example forwarding mode (batched=False) that trims collation padding to recover exact per-example tensor shapes, guaranteeing identical results to the old sequential path. Bench script uses batched=True for speed sweeps and per-example mode for correctness verification against old. 
--- nanochat/core_eval.py | 166 ++++++++++++++++++++++--------------- scripts/base_eval.py | 5 +- scripts/bench_core_eval.py | 22 +++-- 3 files changed, 118 insertions(+), 75 deletions(-) diff --git a/nanochat/core_eval.py b/nanochat/core_eval.py index 28f50f5..0d6134f 100644 --- a/nanochat/core_eval.py +++ b/nanochat/core_eval.py @@ -326,24 +326,29 @@ def _forward_batches(model, collated, data, device, pbar=None): def _forward_all_cached(model, task_collated, device, pbar=None, task_labels=None, - on_task_done=None, merge=1, split=1, pad_token_id=0): + on_task_done=None, batched=False, merge=1, split=1, pad_token_id=0): """Run all tasks' cached batches through the model in one pass. All batch tensors are moved to device upfront (~144MB for full CORE eval). If tensors are already on device (caller preloaded), .to() is a no-op. - Composition (merge/split) happens entirely on device: + + Default mode (batched=False): forwards each example individually, trimming + collation padding to recover the exact per-example tensor shape. This + guarantees identical results to sequential per-example evaluation. + + Batched mode (batched=True): forwards collated batches with optional GPU + composition. Faster but may produce tiny FP differences vs sequential eval + due to different cuBLAS kernel paths for different matrix dimensions. - merge > 1: pad+cat consecutive base batches on GPU before forwarding. - - split > 1: slice each group into chunks by example boundaries, - forward each chunk separately. + - split > 1: slice each group into chunks by example boundaries. 
Args: task_collated: list of (collated_batches, data) per task - pbar: optional progress bar, updated per forward pass (by number of examples) + pbar: optional progress bar, updated per example (or per batch chunk) task_labels: optional list of task names for pbar description updates on_task_done: optional callback(task_idx, correct_tensor) fired when a task completes - merge: number of consecutive base batches to compose per group (>= 1) - split: number of forward passes to split each group into (>= 1) - pad_token_id: token id used for padding when merging batches of different lengths + batched: if True, forward whole batches (faster, approximate). Default False (exact). + merge/split/pad_token_id: only used when batched=True Returns: list of correct tensors (one per task, on device) """ @@ -362,76 +367,103 @@ def _forward_all_cached(model, task_collated, device, pbar=None, task_labels=Non task_batches_remaining = list(task_batch_counts) current_task = -1 - buffer_ids = [] - buffer_info = [] - for i, (combined_ids, batch_meta, task_idx) in enumerate(flat_stream): - # Update pbar description on task transition - if task_idx != current_task: - current_task = task_idx - if pbar is not None and task_labels is not None: - pbar.set_description(task_labels[task_idx]) - buffer_ids.append(combined_ids) - buffer_info.append((batch_meta, task_idx)) - - # Accumulate until we have `merge` batches (or hit the end) - if len(buffer_ids) < merge and i < len(flat_stream) - 1: - continue - - # GPU compose: pad+cat if multiple batches, otherwise use as-is - if len(buffer_ids) == 1: - mega_ids = buffer_ids[0] - else: - max_len = max(t.shape[1] for t in buffer_ids) - parts = [] - for t in buffer_ids: - if t.shape[1] < max_len: - pad = torch.full((t.shape[0], max_len - t.shape[1]), pad_token_id, - dtype=t.dtype, device=t.device) - t = torch.cat([t, pad], dim=1) - parts.append(t) - mega_ids = torch.cat(parts, dim=0) - - # Flatten examples with row boundaries (for splitting) - examples = 
[] - row_bounds = [0] - for bm, tidx in buffer_info: - for idx, n, start_idxs, end_idxs, gold, task_type in bm: - examples.append((idx, n, start_idxs, end_idxs, gold, task_type, tidx)) - row_bounds.append(row_bounds[-1] + n) - - # Forward + score (with optional GPU split) - n_ex = len(examples) - chunk_size = -(-n_ex // split) # ceiling division - - for cs in range(0, n_ex, chunk_size): - ce = min(cs + chunk_size, n_ex) - chunk = examples[cs:ce] - chunk_ids = mega_ids[row_bounds[cs]:row_bounds[ce]] - - losses, predictions = forward_model(model, chunk_ids) + if not batched: + # Per-example forwarding: identical results to sequential evaluation. + # Each example's rows are trimmed to their original seq_len (= max(end_idxs)), + # removing collation padding so forward_model sees the same tensor shape as + # the sequential path. + for combined_ids, batch_meta, task_idx in flat_stream: + if task_idx != current_task: + current_task = task_idx + if pbar is not None and task_labels is not None: + pbar.set_description(task_labels[task_idx]) offset = 0 - for idx, n, start_idxs, end_idxs, gold, task_type, tidx in chunk: + for idx, n, start_idxs, end_idxs, gold, task_type in batch_meta: + seq_len = max(end_idxs) + example_ids = combined_ids[offset:offset+n, :seq_len] + losses, predictions = forward_model(model, example_ids) is_correct = check_result( - losses[offset:offset+n], predictions[offset:offset+n], - chunk_ids[offset:offset+n], + losses, predictions, example_ids, start_idxs, end_idxs, gold, task_type, ) - correct[tidx][idx] = float(is_correct) + correct[task_idx][idx] = float(is_correct) offset += n + if pbar is not None: - pbar.update(len(chunk)) + pbar.update(len(batch_meta)) + if on_task_done is not None: + task_batches_remaining[task_idx] -= 1 + if task_batches_remaining[task_idx] == 0: + on_task_done(task_idx, correct[task_idx]) + else: + # Batched forwarding with optional merge/split composition. 
+ buffer_ids = [] + buffer_info = [] - # Fire callback for any tasks that just completed all their batches - if on_task_done is not None: + for i, (combined_ids, batch_meta, task_idx) in enumerate(flat_stream): + if task_idx != current_task: + current_task = task_idx + if pbar is not None and task_labels is not None: + pbar.set_description(task_labels[task_idx]) + buffer_ids.append(combined_ids) + buffer_info.append((batch_meta, task_idx)) + + if len(buffer_ids) < merge and i < len(flat_stream) - 1: + continue + + # GPU compose: pad+cat if multiple batches, otherwise use as-is + if len(buffer_ids) == 1: + mega_ids = buffer_ids[0] + else: + max_len = max(t.shape[1] for t in buffer_ids) + parts = [] + for t in buffer_ids: + if t.shape[1] < max_len: + pad = torch.full((t.shape[0], max_len - t.shape[1]), pad_token_id, + dtype=t.dtype, device=t.device) + t = torch.cat([t, pad], dim=1) + parts.append(t) + mega_ids = torch.cat(parts, dim=0) + + examples = [] + row_bounds = [0] for bm, tidx in buffer_info: - task_batches_remaining[tidx] -= 1 - if task_batches_remaining[tidx] == 0: - on_task_done(tidx, correct[tidx]) + for idx, n, start_idxs, end_idxs, gold, task_type in bm: + examples.append((idx, n, start_idxs, end_idxs, gold, task_type, tidx)) + row_bounds.append(row_bounds[-1] + n) - buffer_ids.clear() - buffer_info.clear() + n_ex = len(examples) + chunk_size = -(-n_ex // split) + + for cs in range(0, n_ex, chunk_size): + ce = min(cs + chunk_size, n_ex) + chunk = examples[cs:ce] + chunk_ids = mega_ids[row_bounds[cs]:row_bounds[ce]] + + losses, predictions = forward_model(model, chunk_ids) + + offset = 0 + for idx, n, start_idxs, end_idxs, gold, task_type, tidx in chunk: + is_correct = check_result( + losses[offset:offset+n], predictions[offset:offset+n], + chunk_ids[offset:offset+n], + start_idxs, end_idxs, gold, task_type, + ) + correct[tidx][idx] = float(is_correct) + offset += n + if pbar is not None: + pbar.update(len(chunk)) + + if on_task_done is not None: + for 
bm, tidx in buffer_info: + task_batches_remaining[tidx] -= 1 + if task_batches_remaining[tidx] == 0: + on_task_done(tidx, correct[tidx]) + + buffer_ids.clear() + buffer_info.clear() return correct diff --git a/scripts/base_eval.py b/scripts/base_eval.py index dcee8da..19b5ee4 100644 --- a/scripts/base_eval.py +++ b/scripts/base_eval.py @@ -244,12 +244,15 @@ def evaluate_model(model, tokenizer, device, max_per_task=-1): _print_task_result(tidx, correct.mean().item()) if cached_run: - # Continuous pipeline: all tasks in one GPU stream, results printed per-task as they complete + # Continuous pipeline: all tasks in one GPU stream, results printed per-task as they complete. + # Always use base batch size (merge=1/split=1) to guarantee identical results — + # different batch dimensions trigger different cuBLAS kernels with different FP rounding. t0 = time.time() task_collated = [(_batch_cache[label], data) for label, _, data in task_inputs] correct_list = _forward_all_cached( model, task_collated, device, pbar=pbar, task_labels=task_labels, on_task_done=_on_task_done if world_size == 1 else None, + batched=True, ) elapsed_total = time.time() - t0 pbar.close() diff --git a/scripts/bench_core_eval.py b/scripts/bench_core_eval.py index 1003aee..a69425b 100644 --- a/scripts/bench_core_eval.py +++ b/scripts/bench_core_eval.py @@ -271,17 +271,19 @@ def bench_new_first(model, tokenizer, task_inputs, device, batch_size, queue_siz def bench_new_cached(model, task_inputs, device, collated_cache, pbar=None, - merge=1, split=1, pad_token_id=0): + batched=False, merge=1, split=1, pad_token_id=0): """Benchmark new batched evaluation (cached run, forward only). Uses continuous pipeline across all tasks to eliminate inter-task stalls. - merge/split control GPU-side composition: merge > 1 cats batches, split > 1 slices them.""" + batched=False (default): per-example forwarding, identical to sequential. 
+ batched=True: GPU composition with merge/split for speed experiments.""" import torch.distributed as dist world_size = dist.get_world_size() if dist.is_initialized() else 1 sync_cuda() t0 = time.time() task_collated = [(collated_cache[label], data) for label, _, data in task_inputs] correct_list = _forward_all_cached(model, task_collated, device, pbar=pbar, - merge=merge, split=split, pad_token_id=pad_token_id) + batched=batched, merge=merge, split=split, + pad_token_id=pad_token_id) sync_cuda() elapsed = time.time() - t0 results = {} @@ -477,8 +479,8 @@ def main(): with autocast_ctx: t, cached_results = bench_new_cached(model, task_inputs, device, gpu_collated, - pbar=inner_pbar, merge=merge, split=split, - pad_token_id=pad_id) + pbar=inner_pbar, batched=True, + merge=merge, split=split, pad_token_id=pad_id) outer_pbar.write(f" batch_size={bs:>3}: {t:.2f}s ({total_examples / t:.1f} examples/s)") outer_pbar.update(1) @@ -493,8 +495,14 @@ def main(): print0("") print0(f" Best: batch_size={best_cached_params} -> {best_cached_time:.2f}s ({total_examples / best_cached_time:.1f} examples/s)") - if old_results is not None: - verify_results(old_results, best_cached_results, label="new-cached") + # Verify with per-example forwarding (identical to sequential — must match old) + inner_pbar = tqdm(total=total_examples, desc="Verifying", leave=False) + with autocast_ctx: + _, exact_results = bench_new_cached(model, task_inputs, device, gpu_collated, pbar=inner_pbar) + inner_pbar.close() + ref_results = old_results or (best_first_results if best_time is not None else None) + if ref_results is not None: + verify_results(ref_results, exact_results, label="new-cached(per-example)") print0("") # ---- Summary ----