diff --git a/nanochat/core_eval.py b/nanochat/core_eval.py index a81c810..9231520 100644 --- a/nanochat/core_eval.py +++ b/nanochat/core_eval.py @@ -277,11 +277,37 @@ def prepare_task_data(tokenizer, data, task_meta, max_seq_len=None): return prepared -def _forward_batches(model, collated, data, device): - """Run GPU forward passes on pre-collated batches, return per-example correctness tensor.""" +def _prefetch_to_device(tensor, device): + """Pin and async-transfer a CPU tensor to GPU, overlapping with current GPU work.""" + return tensor.pin_memory().to(device, non_blocking=True) + + +def _forward_batches(model, collated, data, device, pbar=None): + """Run GPU forward passes on pre-collated batches, return per-example correctness tensor. + + Uses double-buffered prefetching on CUDA: while the GPU processes batch N, + batch N+1 is pinned and DMA-transferred asynchronously, keeping the GPU fed. + """ correct = torch.zeros(len(data), dtype=torch.float32, device=device) - for combined_ids, batch_meta in collated: - combined_ids = combined_ids.to(device) + if not collated: + return correct + + use_prefetch = torch.cuda.is_available() and 'cuda' in str(device) + + # Prefetch first batch + if use_prefetch: + next_ids = _prefetch_to_device(collated[0][0], device) + else: + next_ids = collated[0][0].to(device) + + for i, (_, batch_meta) in enumerate(collated): + combined_ids = next_ids + # Start async transfer of next batch while GPU computes on current + if i + 1 < len(collated): + if use_prefetch: + next_ids = _prefetch_to_device(collated[i + 1][0], device) + else: + next_ids = collated[i + 1][0].to(device) losses, predictions = forward_model(model, combined_ids) @@ -294,11 +320,63 @@ def _forward_batches(model, collated, data, device): ) correct[idx] = float(is_correct) offset += n + if pbar is not None: + pbar.update(len(batch_meta)) return correct +def compose_collated(base_collated, target_batch_size, base_batch_size=4, pad_token_id=0): + """Compose base-sized collated batches into target-sized batches. + + Supports both merging (target > base) by concatenating consecutive groups, + and splitting (target < base) by slicing along example boundaries. + Examples are sorted by seq_len within each base batch, so splitting can + trim trailing padding columns for efficiency. 
+ """ + if target_batch_size == base_batch_size: + return base_collated + elif target_batch_size > base_batch_size: + # Merge consecutive base batches + n_merge = target_batch_size // base_batch_size + composed = [] + for i in range(0, len(base_collated), n_merge): + group = base_collated[i:i + n_merge] + if len(group) == 1: + composed.append(group[0]) + continue + max_len = max(ids.shape[1] for ids, _ in group) + parts = [] + merged_meta = [] + for ids, meta in group: + if ids.shape[1] < max_len: + pad = torch.full((ids.shape[0], max_len - ids.shape[1]), pad_token_id, dtype=ids.dtype) + ids = torch.cat([ids, pad], dim=1) + parts.append(ids) + merged_meta.extend(meta) + composed.append((torch.cat(parts, dim=0), merged_meta)) + return composed + else: + # Split base batches into smaller chunks + composed = [] + for combined_ids, batch_meta in base_collated: + for chunk_start in range(0, len(batch_meta), target_batch_size): + chunk_meta = batch_meta[chunk_start:chunk_start + target_batch_size] + row_start = sum(m[1] for m in batch_meta[:chunk_start]) + row_end = row_start + sum(m[1] for m in chunk_meta) + chunk_ids = combined_ids[row_start:row_end] + # Trim trailing padding (examples sorted by seq_len, so chunks + # near the start of a base batch may need fewer columns) + non_pad = (chunk_ids != pad_token_id) + if non_pad.any(): + last_col = non_pad.any(dim=0).nonzero()[-1].item() + 1 + if last_col < chunk_ids.shape[1]: + chunk_ids = chunk_ids[:, :last_col].contiguous() + composed.append((chunk_ids, chunk_meta)) + return composed + + def evaluate_task(model, data, device, batch_size=4, queue_size=2, prepared=None, - collated=None, tokenizer=None, task_meta=None): + collated=None, tokenizer=None, task_meta=None, pbar=None): """ Evaluate one task across many examples with batched GPU forward passes. Examples are sorted by sequence length so similar-length sequences are batched @@ -316,7 +394,7 @@ def evaluate_task(model, data, device, batch_size=4, queue_size=2, prepared=None if collated is not None: # Fast path: just GPU forward passes, no threads - correct = _forward_batches(model, collated, data, device) + correct = _forward_batches(model, collated, data, device, pbar=pbar) else: from queue import Queue from threading import Thread @@ -325,21 +403,34 @@ def evaluate_task(model, data, device, batch_size=4, queue_size=2, prepared=None max_seq_len = getattr(model, 'max_seq_len', None) prepared = prepare_task_data(tokenizer, data, task_meta, max_seq_len) - # Collation thread pipelined with GPU forward passes + # Collation thread pipelined with GPU forward passes. + # Double-buffered: while GPU processes batch N, batch N+1 is + # pin_memory()'d and DMA-transferred asynchronously. 
queue = Queue(maxsize=queue_size) collator = Thread(target=_collate_batches, args=(prepared, batch_size, queue), daemon=True) collator.start() + use_prefetch = torch.cuda.is_available() and 'cuda' in str(device) + def transfer(tensor): + return _prefetch_to_device(tensor, device) if use_prefetch else tensor.to(device) + collated = [] correct = torch.zeros(len(data), dtype=torch.float32, device=device) - while True: - item = queue.get() - if item is None: - break - collated.append(item) - combined_ids, batch_meta = item - combined_ids = combined_ids.to(device) + # Prime: get first batch and start its transfer + item = queue.get() + if item is not None: + next_ids = transfer(item[0]) + + while item is not None: + collated.append(item) + combined_ids = next_ids + _, batch_meta = item + + # Start async transfer of next batch (overlaps with forward pass below) + item = queue.get() + if item is not None: + next_ids = transfer(item[0]) losses, predictions = forward_model(model, combined_ids) @@ -352,6 +443,8 @@ def evaluate_task(model, data, device, batch_size=4, queue_size=2, prepared=None ) correct[idx] = float(is_correct) offset += n + if pbar is not None: + pbar.update(len(batch_meta)) collator.join() del prepared diff --git a/scripts/base_eval.py b/scripts/base_eval.py index b4aaf56..d1751fc 100644 --- a/scripts/base_eval.py +++ b/scripts/base_eval.py @@ -26,6 +26,7 @@ import json import yaml import shutil import random +import hashlib import zipfile import tempfile import argparse @@ -113,6 +114,22 @@ _prev_centered = {} # {label: centered_result} — previous run for delta d _prev_core = None # previous core_metric +def _get_disk_cache_dir(max_per_task): + """Get disk cache dir for base-4 collated batches, keyed by tokenizer file hash. + Returns None if no local tokenizer file is found (e.g. HuggingFace models).""" + base_dir = get_base_dir() + for fname in ("tokenizer.pkl", "tokenizer.json"): + path = os.path.join(base_dir, "tokenizer", fname) + if os.path.exists(path): + h = hashlib.sha256() + with open(path, "rb") as f: + for chunk in iter(lambda: f.read(8192), b""): + h.update(chunk) + tok_hash = h.hexdigest()[:16] + return os.path.join(base_dir, "core_token_cache", f"{tok_hash}_n{max_per_task}") + return None + + def evaluate_model(model, tokenizer, device, max_per_task=-1): """ Evaluate a base model on the CORE benchmark. 
@@ -180,6 +197,22 @@ def evaluate_model(model, tokenizer, device, max_per_task=-1): results = {} centered_results = {} cached_run = all(label in _batch_cache for label, _, _ in task_inputs) + disk_cache_dir = _get_disk_cache_dir(max_per_task) + + # Try loading from disk cache if in-memory cache is empty + if not cached_run and disk_cache_dir is not None: + all_on_disk = os.path.isdir(disk_cache_dir) and all( + os.path.exists(os.path.join(disk_cache_dir, f"{label}.pt")) + for label, _, _ in task_inputs + ) + if all_on_disk: + for label, _, _ in task_inputs: + d = torch.load(os.path.join(disk_cache_dir, f"{label}.pt"), weights_only=False) + _batch_cache[label] = d['collated'] + cached_run = True + print0(" (loaded collated batches from disk cache)") + + first_run = not cached_run # track whether we did prepare+collate (for disk save) if not cached_run: executor = ThreadPoolExecutor(max_workers=1) @@ -221,6 +254,16 @@ def evaluate_model(model, tokenizer, device, max_per_task=-1): if not cached_run: executor.shutdown(wait=False) + # Save collated batches to disk after first run (so bench/future runs skip prepare+collate) + if first_run and disk_cache_dir is not None: + pad_id = tokenizer.get_bos_token_id() + os.makedirs(disk_cache_dir, exist_ok=True) + for label, _, _ in task_inputs: + if label in _batch_cache: + torch.save({'collated': _batch_cache[label], 'pad_token_id': pad_id}, + os.path.join(disk_cache_dir, f"{label}.pt")) + print0(f" (saved collated batches to {disk_cache_dir})") + core_metric = sum(centered_results.values()) / len(centered_results) if _prev_core is not None: d = core_metric - _prev_core diff --git a/scripts/bench_core_eval.py b/scripts/bench_core_eval.py index d11bc40..ae7294e 100644 --- a/scripts/bench_core_eval.py +++ b/scripts/bench_core_eval.py @@ -20,10 +20,12 @@ import time import yaml import random import shutil +import hashlib import zipfile import tempfile import argparse from contextlib import nullcontext +from tqdm import tqdm import torch import torch.distributed as dist @@ -34,12 +36,13 @@ from nanochat.checkpoint_manager import load_model from nanochat.core_eval import ( forward_model, prepare_example, check_result, stack_sequences, prepare_task_data, _collate_batches, _forward_batches, evaluate_task, + compose_collated, render_prompts_mc, render_prompts_schema, render_prompts_lm, batch_sequences_mc, batch_sequences_schema, batch_sequences_lm, ) # ---- eval bundle loading (reused from base_eval) ---- - +torch.backends.fp32_precision = "tf32" EVAL_BUNDLE_URL = "https://karpathy-public.s3.us-west-2.amazonaws.com/eval_bundle.zip" def place_eval_bundle(file_path): @@ -79,6 +82,69 @@ def load_tasks(max_per_task=-1): task_inputs.append((label, task_meta, data)) return task_inputs +BASE_BATCH_SIZE = 4 + +def file_checksum(path): + """SHA-256 checksum of a file, truncated to 16 hex chars.""" + h = hashlib.sha256() + with open(path, "rb") as f: + for chunk in iter(lambda: f.read(8192), b""): + h.update(chunk) + return h.hexdigest()[:16] + + +def collate_prepared(prepared, batch_size): + """Collate prepared examples into batches (non-threaded). 
Returns (collated, pad_token_id).""" + pad_id = prepared[0][1]['pad_token_id'] + collated = [] + for batch_start in range(0, len(prepared), batch_size): + batch = prepared[batch_start:batch_start + batch_size] + batch_preps = [p for _, p in batch] + max_len = max(p['seq_len'] for p in batch_preps) + total_rows = sum(p['num_options'] for p in batch_preps) + combined_ids = torch.full((total_rows, max_len), pad_id, dtype=torch.long) + batch_meta = [] + offset = 0 + for idx, p in batch: + n, sl = p['num_options'], p['seq_len'] + combined_ids[offset:offset+n, :sl] = p['input_ids'] + batch_meta.append((idx, n, p['start_idxs'], p['end_idxs'], p['gold'], p['task_type'])) + offset += n + collated.append((combined_ids, batch_meta)) + return collated, pad_id + + +def build_or_load_base_collated(tok_hash, tokenizer, task_inputs, max_seq_len, max_per_task): + """Build or load base-4 collated data for all tasks, with disk caching.""" + base_dir = get_base_dir() + cache_dir = os.path.join(base_dir, "core_token_cache", f"{tok_hash}_n{max_per_task}") + + all_cached = os.path.isdir(cache_dir) and all( + os.path.exists(os.path.join(cache_dir, f"{label}.pt")) + for label, _, _ in task_inputs + ) + + base_cache = {} # label -> (collated, pad_token_id) + if all_cached: + print0(f"Loading base-{BASE_BATCH_SIZE} collated cache from {cache_dir}") + for label, _, _ in tqdm(task_inputs, desc="Loading cache", leave=False): + data = torch.load(os.path.join(cache_dir, f"{label}.pt"), weights_only=False) + base_cache[label] = (data['collated'], data['pad_token_id']) + else: + print0(f"Building base-{BASE_BATCH_SIZE} collated cache (saving to {cache_dir})") + os.makedirs(cache_dir, exist_ok=True) + for label, task_meta, data in tqdm(task_inputs, desc="Preparing tasks", leave=False): + prepared = prepare_task_data(tokenizer, data, task_meta, max_seq_len) + collated, pad_id = collate_prepared(prepared, BASE_BATCH_SIZE) + base_cache[label] = (collated, pad_id) + torch.save({'collated': collated, 'pad_token_id': pad_id}, + os.path.join(cache_dir, f"{label}.pt")) + del prepared + print0(f"Saved {len(base_cache)} task caches to {cache_dir}") + + return base_cache + + # ---- old sequential evaluation (baseline) ---- @torch.no_grad() @@ -129,7 +195,7 @@ def evaluate_example_old(idx, model, tokenizer, data, device, task_meta): return check_result(losses, predictions, input_ids, start_idxs, end_idxs, item.get('gold', None), task_type) -def evaluate_task_old(model, tokenizer, data, device, task_meta): +def evaluate_task_old(model, tokenizer, data, device, task_meta, pbar=None): """Original sequential evaluate_task.""" rank = dist.get_rank() if dist.is_initialized() else 0 world_size = dist.get_world_size() if dist.is_initialized() else 1 @@ -137,6 +203,8 @@ def evaluate_task_old(model, tokenizer, data, device, task_meta): for idx in range(rank, len(data), world_size): is_correct = evaluate_example_old(idx, model, tokenizer, data, device, task_meta) correct[idx] = float(is_correct) + if pbar is not None: + pbar.update(1) if world_size > 1: dist.barrier() dist.all_reduce(correct, op=dist.ReduceOp.SUM) @@ -170,37 +238,48 @@ def bench_old(model, tokenizer, task_inputs, device): sync_cuda() t0 = time.time() results = {} + total = sum(len(data) for _, _, data in task_inputs) + max_label_len = max(len(label) for label, _, _ in task_inputs) + pbar = tqdm(total=total, desc="Sequential", leave=False) for label, task_meta, data in task_inputs: - acc = evaluate_task_old(model, tokenizer, data, device, task_meta) + 
pbar.set_description(f"Sequential: {label:<{max_label_len}}") + acc = evaluate_task_old(model, tokenizer, data, device, task_meta, pbar=pbar) results[label] = acc + pbar.close() sync_cuda() return time.time() - t0, results -def bench_new_first(model, tokenizer, task_inputs, device, batch_size, queue_size): +def bench_new_first(model, tokenizer, task_inputs, device, batch_size, queue_size, pbar=None): """Benchmark new batched evaluation (first run, no cache).""" sync_cuda() t0 = time.time() results = {} collated_cache = {} max_seq_len = getattr(model, 'max_seq_len', None) + max_label_len = max(len(label) for label, _, _ in task_inputs) for label, task_meta, data in task_inputs: + if pbar is not None: + pbar.set_description(f"{label:<{max_label_len}}") prepared = prepare_task_data(tokenizer, data, task_meta, max_seq_len) acc, collated = evaluate_task(model, data, device, batch_size=batch_size, - queue_size=queue_size, prepared=prepared) + queue_size=queue_size, prepared=prepared, pbar=pbar) results[label] = acc collated_cache[label] = collated sync_cuda() return time.time() - t0, results, collated_cache -def bench_new_cached(model, task_inputs, device, collated_cache): +def bench_new_cached(model, task_inputs, device, collated_cache, pbar=None): """Benchmark new batched evaluation (cached run, forward only).""" sync_cuda() t0 = time.time() results = {} + max_label_len = max(len(label) for label, _, _ in task_inputs) for label, task_meta, data in task_inputs: - acc, _ = evaluate_task(model, data, device, collated=collated_cache[label]) + if pbar is not None: + pbar.set_description(f"{label:<{max_label_len}}") + acc, _ = evaluate_task(model, data, device, collated=collated_cache[label], pbar=pbar) results[label] = acc sync_cuda() return time.time() - t0, results @@ -232,6 +311,7 @@ def main(): parser.add_argument('--batch-sizes', type=str, default='1,2,4,8,16,32,64', help='Comma-separated batch sizes to sweep') parser.add_argument('--queue-sizes', type=str, default='2,4,8,16', help='Comma-separated queue sizes to sweep') parser.add_argument('--skip-old', action='store_true', help='Skip old sequential baseline (slow)') + parser.add_argument('--skip-first', action='store_true', help='Skip first-run sweep (requires cached collated data)') args = parser.parse_args() batch_sizes = [int(x) for x in args.batch_sizes.split(',')] @@ -272,52 +352,131 @@ def main(): print0("") # ---- 2. 
Sweep batch_size x queue_size for first run ---- + # Compute tok_hash early — needed for both skip-first check and cache loading + max_seq_len = getattr(model, 'max_seq_len', None) + if args.hf_path is not None: + tok_hash = hashlib.sha256(args.hf_path.encode()).hexdigest()[:16] + else: + base_dir = get_base_dir() + for fname in ("tokenizer.pkl", "tokenizer.json"): + tok_path = os.path.join(base_dir, "tokenizer", fname) + if os.path.exists(tok_path): + tok_hash = file_checksum(tok_path) + break + + # Check if we can skip the first-run sweep + best_time = None + best_params = None + if args.skip_first: + cache_dir = os.path.join(get_base_dir(), "core_token_cache", f"{tok_hash}_n{args.max_per_task}") + has_cache = os.path.isdir(cache_dir) and all( + os.path.exists(os.path.join(cache_dir, f"{label}.pt")) + for label, _, _ in task_inputs + ) + if has_cache: + print0("Skipping first-run sweep (--skip-first, cache exists)") + print0("") + else: + print0(f"--skip-first: cache missing, running single bs={BASE_BATCH_SIZE} eval to create it...") + pbar = tqdm(total=total_examples, desc="Creating cache", leave=False) + with autocast_ctx: + _, _, collated_cache = bench_new_first(model, tokenizer, task_inputs, device, + batch_size=BASE_BATCH_SIZE, queue_size=2, pbar=pbar) + pbar.close() + pad_id = tokenizer.get_bos_token_id() + os.makedirs(cache_dir, exist_ok=True) + for label in collated_cache: + torch.save({'collated': collated_cache[label], 'pad_token_id': pad_id}, + os.path.join(cache_dir, f"{label}.pt")) + print0(f"Cache created ({len(collated_cache)} tasks)") + print0("") + + if not args.skip_first: + print0("=" * 80) + print0("NEW: Batched evaluation — hyperparameter sweep (first run)") + print0("=" * 80) + print0("") + + # Header + qs_header = "".join(f"{'q=' + str(q):>10}" for q in queue_sizes) + print0(f" {'batch_size':>10}{qs_header}") + print0(f" {'':>10}" + "-" * (10 * len(queue_sizes))) + + best_time = float('inf') + best_params = None + best_collated = None + sweep_results = {} + total_combos = len(batch_sizes) * len(queue_sizes) + outer_pbar = tqdm(total=total_combos, desc="First-run sweep", leave=False, position=0) + inner_pbar = tqdm(total=total_examples, desc="", leave=False, position=1) + + for bs in batch_sizes: + row = f" {bs:>10}" + for qs in queue_sizes: + outer_pbar.set_description(f"First-run: bs={bs} qs={qs}") + inner_pbar.reset() + with autocast_ctx: + t, results, collated_cache = bench_new_first(model, tokenizer, task_inputs, device, bs, qs, pbar=inner_pbar) + sweep_results[(bs, qs)] = t + row += f"{t:>9.2f}s" + if t < best_time: + best_time = t + best_params = (bs, qs) + best_collated = collated_cache + best_first_results = results + outer_pbar.update(1) + outer_pbar.write(row) + inner_pbar.close() + outer_pbar.close() + + print0("") + print0(f" Best: batch_size={best_params[0]}, queue_size={best_params[1]} -> {best_time:.2f}s ({total_examples / best_time:.1f} examples/s)") + + # Verify correctness + if old_results is not None: + verify_results(old_results, best_first_results, label="new-first") + print0("") + + # ---- 3. Build/load base-4 collated cache ---- + base_cache = build_or_load_base_collated(tok_hash, tokenizer, task_inputs, max_seq_len, args.max_per_task) + + # ---- 4. 
Cached run sweep (forward only, composed from base-4) ---- print0("=" * 80) - print0("NEW: Batched evaluation — hyperparameter sweep (first run)") + print0(f"NEW: Cached run (forward only, composed from base-{BASE_BATCH_SIZE})") print0("=" * 80) print0("") - # Header - qs_header = "".join(f"{'q=' + str(q):>10}" for q in queue_sizes) - print0(f" {'batch_size':>10}{qs_header}") - print0(f" {'':>10}" + "-" * (10 * len(queue_sizes))) - - best_time = float('inf') - best_params = None - best_collated = None - sweep_results = {} + best_cached_time = float('inf') + best_cached_params = None + outer_pbar = tqdm(total=len(batch_sizes), desc="Cached sweep", leave=False, position=0) + inner_pbar = tqdm(total=total_examples, desc="", leave=False, position=1) for bs in batch_sizes: - row = f" {bs:>10}" - for qs in queue_sizes: - with autocast_ctx: - t, results, collated_cache = bench_new_first(model, tokenizer, task_inputs, device, bs, qs) - sweep_results[(bs, qs)] = t - row += f"{t:>9.2f}s" - if t < best_time: - best_time = t - best_params = (bs, qs) - best_collated = collated_cache - best_first_results = results - print0(row) + outer_pbar.set_description(f"Cached: bs={bs}") + inner_pbar.reset() + # Compose from base-4 to target batch_size (merge or split) + composed_cache = {} + for label, (collated, pad_id) in base_cache.items(): + composed_cache[label] = compose_collated(collated, bs, BASE_BATCH_SIZE, pad_id) + + with autocast_ctx: + t, cached_results = bench_new_cached(model, task_inputs, device, composed_cache, pbar=inner_pbar) + + outer_pbar.write(f" batch_size={bs:>3}: {t:.2f}s ({total_examples / t:.1f} examples/s)") + outer_pbar.update(1) + + if t < best_cached_time: + best_cached_time = t + best_cached_params = bs + best_cached_results = cached_results + inner_pbar.close() + outer_pbar.close() print0("") - print0(f" Best: batch_size={best_params[0]}, queue_size={best_params[1]} -> {best_time:.2f}s ({total_examples / best_time:.1f} examples/s)") + print0(f" Best: batch_size={best_cached_params} -> {best_cached_time:.2f}s ({total_examples / best_cached_time:.1f} examples/s)") - # Verify correctness if old_results is not None: - verify_results(old_results, best_first_results, label="new-first") - print0("") - - # ---- 3. 
Cached run (forward only) ---- - print0("=" * 80) - print0("NEW: Cached run (forward only, using best params)") - print0("=" * 80) - with autocast_ctx: - cached_time, cached_results = bench_new_cached(model, task_inputs, device, best_collated) - print0(f" Time: {cached_time:.2f}s ({total_examples / cached_time:.1f} examples/s)") - if old_results is not None: - verify_results(old_results, cached_results, label="new-cached") + verify_results(old_results, best_cached_results, label="new-cached") print0("") # ---- Summary ---- @@ -326,11 +485,13 @@ def main(): print0("=" * 80) if old_results is not None: print0(f" Old (sequential): {old_time:>8.2f}s") - print0(f" New (first run): {best_time:>8.2f}s batch_size={best_params[0]}, queue_size={best_params[1]}") - print0(f" New (cached): {cached_time:>8.2f}s") + if best_time is not None: + print0(f" New (first run): {best_time:>8.2f}s batch_size={best_params[0]}, queue_size={best_params[1]}") + print0(f" New (cached): {best_cached_time:>8.2f}s batch_size={best_cached_params}") if old_results is not None: - print0(f" Speedup (first): {old_time / best_time:>8.2f}x") - print0(f" Speedup (cached): {old_time / cached_time:>8.2f}x") + if best_time is not None: + print0(f" Speedup (first): {old_time / best_time:>8.2f}x") + print0(f" Speedup (cached): {old_time / best_cached_time:>8.2f}x") compute_cleanup()
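
For reference, a minimal sketch of how the new compose_collated helper behaves. The toy tensors and meta tuples below are made up for illustration (the real ones come from collate_prepared), and only the num_options field of each meta tuple drives the row arithmetic; the start/end indices, gold label, and task type are placeholders here. Merging concatenates consecutive base-4 batches after right-padding to the longest sequence; splitting slices along example boundaries and trims trailing pad columns.

    # Hypothetical toy data, not from the repo. Meta tuples mirror the
    # (idx, num_options, start_idxs, end_idxs, gold, task_type) layout built in
    # collate_prepared(); compose_collated only reads num_options (m[1]).
    import torch
    from nanochat.core_eval import compose_collated

    PAD = 0

    def toy_base_batch(first_idx, seq_len, n_examples=4, n_options=2):
        rows, metas = [], []
        for j in range(n_examples):
            rows.append(torch.randint(1, 100, (n_options, seq_len)))  # token ids, no pad values
            metas.append((first_idx + j, n_options, [0], [seq_len - 1], 0, 'mc'))  # placeholder fields
        return torch.cat(rows, dim=0), metas  # (n_examples * n_options, seq_len) ids + meta list

    base = [toy_base_batch(0, 12), toy_base_batch(4, 16)]  # two base-4 batches of different lengths

    merged = compose_collated(base, target_batch_size=8, base_batch_size=4, pad_token_id=PAD)
    print(len(merged), merged[0][0].shape)  # 1 batch; shorter batch padded 12 -> 16 cols, ids (16, 16)

    split = compose_collated(base, target_batch_size=2, base_batch_size=4, pad_token_id=PAD)
    print(len(split), [ids.shape for ids, _ in split])  # 4 batches of 2 examples (4 rows) each

Caching at the base batch size of 4 is what makes this work for the sweep: the cached batches can be re-composed into any target size without re-tokenizing, since multiples of 4 merge cleanly and sizes 1 and 2 divide 4 evenly.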