mirror of
https://github.com/karpathy/nanochat.git
synced 2026-04-03 22:25:27 +00:00
CORE eval: disk-cached tokenized batches, double-buffered GPU transfers, batch composition, benchmark improvements
the main idea: tokenization + collation for CORE eval only needs to happen once per tokenizer.
collated batches at base batch_size=4 are saved to disk (core_token_cache/), keyed by SHA-256
of the tokenizer file. any batch_size can be served from these base-4 batches: larger sizes merge
consecutive batches (right-pad shorter ones, cat along dim=0), smaller sizes split along example
boundaries (trim trailing padding). this means prepare_task_data is truly a one-time cost.
core_eval.py:
- double-buffered CPU->GPU transfers in both forward paths (_forward_batches and evaluate_task's
pipelined path). while GPU runs forward_model on batch N, batch N+1 is pin_memory()'d and
DMA-transferred via non_blocking=True. the DMA engine and GPU compute units are separate
hardware so they overlap. previously GPU idled during every transfer.
- compose_collated(): merge base batches for larger batch_size (cat after right-padding to
max_len), or split for smaller batch_size (slice along row boundaries from batch_meta,
trim trailing padding via vectorized non_pad.any(dim=0)). works because examples are sorted
by seq_len, so consecutive base batches have monotonically increasing lengths.
- evaluate_task and _forward_batches accept optional pbar for progress tracking.
base_eval.py:
- evaluate_model now has 3-tier caching: in-memory (_batch_cache, across calls within same
process), disk load (core_token_cache/, on first call when in-memory is empty), disk save
(after first run's prepare+collate+forward, writes collated batches so future training runs
and the benchmark skip tokenization entirely). keyed by tokenizer file hash + max_per_task.
bench_core_eval.py:
- cached sweep no longer re-runs the full first-run sweep to build collated data (was 2x the
work for no reason). instead loads/builds base-4 cache once, then compose_collated serves
any target batch_size. cached sweep only varies batch_size (no queue_size — no collation thread).
- --skip-first: skip the first-run sweep entirely if disk cache exists. if cache is missing,
runs a single bs=4 eval in minimal time to create it, then proceeds to cached sweep.
- tqdm progress bars everywhere: old sequential baseline (per-example with task name),
first-run sweep (double bar: outer=combo progress, inner=per-example), cache building
(per-task), cached sweep (double bar). task names left-padded to max label length so the
bar doesn't shift.
- tokenizer identity via file_checksum (SHA-256 of tokenizer.pkl/tokenizer.json on disk),
not encode-output hashing. HF models fall back to hashing the repo name.
This commit is contained in:
parent
8695280566
commit
7fa30f5ee3
|
|
@ -277,11 +277,37 @@ def prepare_task_data(tokenizer, data, task_meta, max_seq_len=None):
|
|||
return prepared
|
||||
|
||||
|
||||
def _forward_batches(model, collated, data, device):
|
||||
"""Run GPU forward passes on pre-collated batches, return per-example correctness tensor."""
|
||||
def _prefetch_to_device(tensor, device):
    """Copy a CPU tensor to *device* asynchronously.

    The tensor is first moved into pinned (page-locked) host memory so the
    subsequent non-blocking copy can run on the DMA engine while the GPU
    keeps computing on the current batch.
    """
    pinned = tensor.pin_memory()
    return pinned.to(device, non_blocking=True)
|
||||
|
||||
|
||||
def _forward_batches(model, collated, data, device, pbar=None):
|
||||
"""Run GPU forward passes on pre-collated batches, return per-example correctness tensor.
|
||||
|
||||
Uses double-buffered prefetching on CUDA: while the GPU processes batch N,
|
||||
batch N+1 is pinned and DMA-transferred asynchronously, keeping the GPU fed.
|
||||
"""
|
||||
correct = torch.zeros(len(data), dtype=torch.float32, device=device)
|
||||
for combined_ids, batch_meta in collated:
|
||||
combined_ids = combined_ids.to(device)
|
||||
if not collated:
|
||||
return correct
|
||||
|
||||
use_prefetch = torch.cuda.is_available() and 'cuda' in str(device)
|
||||
|
||||
# Prefetch first batch
|
||||
if use_prefetch:
|
||||
next_ids = _prefetch_to_device(collated[0][0], device)
|
||||
else:
|
||||
next_ids = collated[0][0].to(device)
|
||||
|
||||
for i, (_, batch_meta) in enumerate(collated):
|
||||
combined_ids = next_ids
|
||||
# Start async transfer of next batch while GPU computes on current
|
||||
if i + 1 < len(collated):
|
||||
if use_prefetch:
|
||||
next_ids = _prefetch_to_device(collated[i + 1][0], device)
|
||||
else:
|
||||
next_ids = collated[i + 1][0].to(device)
|
||||
|
||||
losses, predictions = forward_model(model, combined_ids)
|
||||
|
||||
|
|
@ -294,11 +320,63 @@ def _forward_batches(model, collated, data, device):
|
|||
)
|
||||
correct[idx] = float(is_correct)
|
||||
offset += n
|
||||
if pbar is not None:
|
||||
pbar.update(len(batch_meta))
|
||||
return correct
|
||||
|
||||
|
||||
def compose_collated(base_collated, target_batch_size, base_batch_size=4, pad_token_id=0):
    """Compose base-sized collated batches into target-sized batches.

    Args:
        base_collated: list of (combined_ids, batch_meta) tuples. combined_ids
            is a 2D tensor of token ids (rows are option sequences, right-padded
            with pad_token_id); batch_meta is a list of per-example tuples whose
            second field is that example's row count.
        target_batch_size: desired number of examples per output batch.
        base_batch_size: number of examples per batch in base_collated.
        pad_token_id: id used for right-padding (and recognized when trimming).

    Returns:
        A list of (combined_ids, batch_meta) batches at the target size.
        Returns base_collated unchanged when the sizes already match.

    Merging (target > base) concatenates consecutive groups of base batches,
    right-padding shorter ones to the group's max length. Splitting
    (target < base) slices along example boundaries; examples are sorted by
    seq_len within each base batch, so early chunks can trim trailing
    all-padding columns.
    """
    if target_batch_size == base_batch_size:
        return base_collated

    if target_batch_size > base_batch_size:
        # Merge consecutive base batches into one larger batch.
        n_merge = target_batch_size // base_batch_size
        composed = []
        for i in range(0, len(base_collated), n_merge):
            group = base_collated[i:i + n_merge]
            if len(group) == 1:
                composed.append(group[0])
                continue
            max_len = max(ids.shape[1] for ids, _ in group)
            parts = []
            merged_meta = []
            for ids, meta in group:
                if ids.shape[1] < max_len:
                    # Right-pad to the group's max length so rows can be stacked.
                    pad = torch.full((ids.shape[0], max_len - ids.shape[1]),
                                     pad_token_id, dtype=ids.dtype)
                    ids = torch.cat([ids, pad], dim=1)
                parts.append(ids)
                merged_meta.extend(meta)
            composed.append((torch.cat(parts, dim=0), merged_meta))
        return composed

    # Split base batches into smaller chunks along example boundaries.
    composed = []
    for combined_ids, batch_meta in base_collated:
        # Running row offset: O(examples) total instead of re-summing the
        # meta prefix for every chunk (which was O(examples^2) per base batch).
        row_start = 0
        for chunk_start in range(0, len(batch_meta), target_batch_size):
            chunk_meta = batch_meta[chunk_start:chunk_start + target_batch_size]
            row_end = row_start + sum(m[1] for m in chunk_meta)
            chunk_ids = combined_ids[row_start:row_end]
            row_start = row_end
            # Trim trailing padding columns (examples sorted by seq_len, so
            # chunks near the start of a base batch may need fewer columns).
            non_pad = (chunk_ids != pad_token_id)
            if non_pad.any():
                last_col = non_pad.any(dim=0).nonzero()[-1].item() + 1
                if last_col < chunk_ids.shape[1]:
                    chunk_ids = chunk_ids[:, :last_col].contiguous()
            composed.append((chunk_ids, chunk_meta))
    return composed
|
||||
|
||||
|
||||
def evaluate_task(model, data, device, batch_size=4, queue_size=2, prepared=None,
|
||||
collated=None, tokenizer=None, task_meta=None):
|
||||
collated=None, tokenizer=None, task_meta=None, pbar=None):
|
||||
"""
|
||||
Evaluate one task across many examples with batched GPU forward passes.
|
||||
Examples are sorted by sequence length so similar-length sequences are batched
|
||||
|
|
@ -316,7 +394,7 @@ def evaluate_task(model, data, device, batch_size=4, queue_size=2, prepared=None
|
|||
|
||||
if collated is not None:
|
||||
# Fast path: just GPU forward passes, no threads
|
||||
correct = _forward_batches(model, collated, data, device)
|
||||
correct = _forward_batches(model, collated, data, device, pbar=pbar)
|
||||
else:
|
||||
from queue import Queue
|
||||
from threading import Thread
|
||||
|
|
@ -325,21 +403,34 @@ def evaluate_task(model, data, device, batch_size=4, queue_size=2, prepared=None
|
|||
max_seq_len = getattr(model, 'max_seq_len', None)
|
||||
prepared = prepare_task_data(tokenizer, data, task_meta, max_seq_len)
|
||||
|
||||
# Collation thread pipelined with GPU forward passes
|
||||
# Collation thread pipelined with GPU forward passes.
|
||||
# Double-buffered: while GPU processes batch N, batch N+1 is
|
||||
# pin_memory()'d and DMA-transferred asynchronously.
|
||||
queue = Queue(maxsize=queue_size)
|
||||
collator = Thread(target=_collate_batches, args=(prepared, batch_size, queue), daemon=True)
|
||||
collator.start()
|
||||
|
||||
use_prefetch = torch.cuda.is_available() and 'cuda' in str(device)
|
||||
def transfer(tensor):
|
||||
return _prefetch_to_device(tensor, device) if use_prefetch else tensor.to(device)
|
||||
|
||||
collated = []
|
||||
correct = torch.zeros(len(data), dtype=torch.float32, device=device)
|
||||
while True:
|
||||
item = queue.get()
|
||||
if item is None:
|
||||
break
|
||||
collated.append(item)
|
||||
combined_ids, batch_meta = item
|
||||
|
||||
combined_ids = combined_ids.to(device)
|
||||
# Prime: get first batch and start its transfer
|
||||
item = queue.get()
|
||||
if item is not None:
|
||||
next_ids = transfer(item[0])
|
||||
|
||||
while item is not None:
|
||||
collated.append(item)
|
||||
combined_ids = next_ids
|
||||
_, batch_meta = item
|
||||
|
||||
# Start async transfer of next batch (overlaps with forward pass below)
|
||||
item = queue.get()
|
||||
if item is not None:
|
||||
next_ids = transfer(item[0])
|
||||
|
||||
losses, predictions = forward_model(model, combined_ids)
|
||||
|
||||
|
|
@ -352,6 +443,8 @@ def evaluate_task(model, data, device, batch_size=4, queue_size=2, prepared=None
|
|||
)
|
||||
correct[idx] = float(is_correct)
|
||||
offset += n
|
||||
if pbar is not None:
|
||||
pbar.update(len(batch_meta))
|
||||
|
||||
collator.join()
|
||||
del prepared
|
||||
|
|
|
|||
|
|
@ -26,6 +26,7 @@ import json
|
|||
import yaml
|
||||
import shutil
|
||||
import random
|
||||
import hashlib
|
||||
import zipfile
|
||||
import tempfile
|
||||
import argparse
|
||||
|
|
@ -113,6 +114,22 @@ _prev_centered = {} # {label: centered_result} — previous run for delta d
|
|||
_prev_core = None # previous core_metric
|
||||
|
||||
|
||||
def _get_disk_cache_dir(max_per_task):
    """Return the disk cache dir for base-4 collated batches, or None.

    The directory is keyed by the SHA-256 hash of the local tokenizer file
    (truncated to 16 hex chars) plus max_per_task. Returns None when no local
    tokenizer file is found (e.g. HuggingFace models).
    """
    base_dir = get_base_dir()
    candidates = (
        os.path.join(base_dir, "tokenizer", name)
        for name in ("tokenizer.pkl", "tokenizer.json")
    )
    for path in candidates:
        if not os.path.exists(path):
            continue
        digest = hashlib.sha256()
        with open(path, "rb") as f:
            while chunk := f.read(8192):
                digest.update(chunk)
        tok_hash = digest.hexdigest()[:16]
        return os.path.join(base_dir, "core_token_cache", f"{tok_hash}_n{max_per_task}")
    return None
|
||||
|
||||
|
||||
def evaluate_model(model, tokenizer, device, max_per_task=-1):
|
||||
"""
|
||||
Evaluate a base model on the CORE benchmark.
|
||||
|
|
@ -180,6 +197,22 @@ def evaluate_model(model, tokenizer, device, max_per_task=-1):
|
|||
results = {}
|
||||
centered_results = {}
|
||||
cached_run = all(label in _batch_cache for label, _, _ in task_inputs)
|
||||
disk_cache_dir = _get_disk_cache_dir(max_per_task)
|
||||
|
||||
# Try loading from disk cache if in-memory cache is empty
|
||||
if not cached_run and disk_cache_dir is not None:
|
||||
all_on_disk = os.path.isdir(disk_cache_dir) and all(
|
||||
os.path.exists(os.path.join(disk_cache_dir, f"{label}.pt"))
|
||||
for label, _, _ in task_inputs
|
||||
)
|
||||
if all_on_disk:
|
||||
for label, _, _ in task_inputs:
|
||||
d = torch.load(os.path.join(disk_cache_dir, f"{label}.pt"), weights_only=False)
|
||||
_batch_cache[label] = d['collated']
|
||||
cached_run = True
|
||||
print0(" (loaded collated batches from disk cache)")
|
||||
|
||||
first_run = not cached_run # track whether we did prepare+collate (for disk save)
|
||||
|
||||
if not cached_run:
|
||||
executor = ThreadPoolExecutor(max_workers=1)
|
||||
|
|
@ -221,6 +254,16 @@ def evaluate_model(model, tokenizer, device, max_per_task=-1):
|
|||
if not cached_run:
|
||||
executor.shutdown(wait=False)
|
||||
|
||||
# Save collated batches to disk after first run (so bench/future runs skip prepare+collate)
|
||||
if first_run and disk_cache_dir is not None:
|
||||
pad_id = tokenizer.get_bos_token_id()
|
||||
os.makedirs(disk_cache_dir, exist_ok=True)
|
||||
for label, _, _ in task_inputs:
|
||||
if label in _batch_cache:
|
||||
torch.save({'collated': _batch_cache[label], 'pad_token_id': pad_id},
|
||||
os.path.join(disk_cache_dir, f"{label}.pt"))
|
||||
print0(f" (saved collated batches to {disk_cache_dir})")
|
||||
|
||||
core_metric = sum(centered_results.values()) / len(centered_results)
|
||||
if _prev_core is not None:
|
||||
d = core_metric - _prev_core
|
||||
|
|
|
|||
|
|
@ -20,10 +20,12 @@ import time
|
|||
import yaml
|
||||
import random
|
||||
import shutil
|
||||
import hashlib
|
||||
import zipfile
|
||||
import tempfile
|
||||
import argparse
|
||||
from contextlib import nullcontext
|
||||
from tqdm import tqdm
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
|
@ -34,12 +36,13 @@ from nanochat.checkpoint_manager import load_model
|
|||
from nanochat.core_eval import (
|
||||
forward_model, prepare_example, check_result, stack_sequences,
|
||||
prepare_task_data, _collate_batches, _forward_batches, evaluate_task,
|
||||
compose_collated,
|
||||
render_prompts_mc, render_prompts_schema, render_prompts_lm,
|
||||
batch_sequences_mc, batch_sequences_schema, batch_sequences_lm,
|
||||
)
|
||||
|
||||
# ---- eval bundle loading (reused from base_eval) ----
|
||||
|
||||
torch.backends.fp32_precision = "tf32"
|
||||
EVAL_BUNDLE_URL = "https://karpathy-public.s3.us-west-2.amazonaws.com/eval_bundle.zip"
|
||||
|
||||
def place_eval_bundle(file_path):
|
||||
|
|
@ -79,6 +82,69 @@ def load_tasks(max_per_task=-1):
|
|||
task_inputs.append((label, task_meta, data))
|
||||
return task_inputs
|
||||
|
||||
BASE_BATCH_SIZE = 4
|
||||
|
||||
def file_checksum(path):
    """SHA-256 checksum of a file, truncated to 16 hex chars."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        # Stream in fixed-size chunks so large files don't load into memory.
        while chunk := f.read(8192):
            digest.update(chunk)
    return digest.hexdigest()[:16]
|
||||
|
||||
|
||||
def collate_prepared(prepared, batch_size):
    """Collate prepared examples into batches (non-threaded). Returns (collated, pad_token_id)."""
    pad_id = prepared[0][1]['pad_token_id']
    collated = []
    for start in range(0, len(prepared), batch_size):
        chunk = prepared[start:start + batch_size]
        preps = [prep for _, prep in chunk]
        # One right-padded tensor holds every option row of every example.
        width = max(prep['seq_len'] for prep in preps)
        rows = sum(prep['num_options'] for prep in preps)
        combined_ids = torch.full((rows, width), pad_id, dtype=torch.long)
        batch_meta = []
        row = 0
        for idx, prep in chunk:
            n_opts, length = prep['num_options'], prep['seq_len']
            combined_ids[row:row + n_opts, :length] = prep['input_ids']
            batch_meta.append((idx, n_opts, prep['start_idxs'], prep['end_idxs'],
                               prep['gold'], prep['task_type']))
            row += n_opts
        collated.append((combined_ids, batch_meta))
    return collated, pad_id
|
||||
|
||||
|
||||
def build_or_load_base_collated(tok_hash, tokenizer, task_inputs, max_seq_len, max_per_task):
    """Build or load base-4 collated data for all tasks, with disk caching."""
    cache_dir = os.path.join(get_base_dir(), "core_token_cache", f"{tok_hash}_n{max_per_task}")

    def task_path(label):
        return os.path.join(cache_dir, f"{label}.pt")

    all_cached = os.path.isdir(cache_dir) and all(
        os.path.exists(task_path(label)) for label, _, _ in task_inputs
    )

    base_cache = {}  # label -> (collated, pad_token_id)
    if all_cached:
        # Fast path: every task already collated on disk.
        print0(f"Loading base-{BASE_BATCH_SIZE} collated cache from {cache_dir}")
        for label, _, _ in tqdm(task_inputs, desc="Loading cache", leave=False):
            payload = torch.load(task_path(label), weights_only=False)
            base_cache[label] = (payload['collated'], payload['pad_token_id'])
        return base_cache

    # Slow path: tokenize + collate each task, persisting as we go.
    print0(f"Building base-{BASE_BATCH_SIZE} collated cache (saving to {cache_dir})")
    os.makedirs(cache_dir, exist_ok=True)
    for label, task_meta, data in tqdm(task_inputs, desc="Preparing tasks", leave=False):
        prepared = prepare_task_data(tokenizer, data, task_meta, max_seq_len)
        collated, pad_id = collate_prepared(prepared, BASE_BATCH_SIZE)
        base_cache[label] = (collated, pad_id)
        torch.save({'collated': collated, 'pad_token_id': pad_id}, task_path(label))
        del prepared  # free tokenized examples before the next task
    print0(f"Saved {len(base_cache)} task caches to {cache_dir}")

    return base_cache
|
||||
|
||||
|
||||
# ---- old sequential evaluation (baseline) ----
|
||||
|
||||
@torch.no_grad()
|
||||
|
|
@ -129,7 +195,7 @@ def evaluate_example_old(idx, model, tokenizer, data, device, task_meta):
|
|||
return check_result(losses, predictions, input_ids, start_idxs, end_idxs, item.get('gold', None), task_type)
|
||||
|
||||
|
||||
def evaluate_task_old(model, tokenizer, data, device, task_meta):
|
||||
def evaluate_task_old(model, tokenizer, data, device, task_meta, pbar=None):
|
||||
"""Original sequential evaluate_task."""
|
||||
rank = dist.get_rank() if dist.is_initialized() else 0
|
||||
world_size = dist.get_world_size() if dist.is_initialized() else 1
|
||||
|
|
@ -137,6 +203,8 @@ def evaluate_task_old(model, tokenizer, data, device, task_meta):
|
|||
for idx in range(rank, len(data), world_size):
|
||||
is_correct = evaluate_example_old(idx, model, tokenizer, data, device, task_meta)
|
||||
correct[idx] = float(is_correct)
|
||||
if pbar is not None:
|
||||
pbar.update(1)
|
||||
if world_size > 1:
|
||||
dist.barrier()
|
||||
dist.all_reduce(correct, op=dist.ReduceOp.SUM)
|
||||
|
|
@ -170,37 +238,48 @@ def bench_old(model, tokenizer, task_inputs, device):
|
|||
sync_cuda()
|
||||
t0 = time.time()
|
||||
results = {}
|
||||
total = sum(len(data) for _, _, data in task_inputs)
|
||||
max_label_len = max(len(label) for label, _, _ in task_inputs)
|
||||
pbar = tqdm(total=total, desc="Sequential", leave=False)
|
||||
for label, task_meta, data in task_inputs:
|
||||
acc = evaluate_task_old(model, tokenizer, data, device, task_meta)
|
||||
pbar.set_description(f"Sequential: {label:<{max_label_len}}")
|
||||
acc = evaluate_task_old(model, tokenizer, data, device, task_meta, pbar=pbar)
|
||||
results[label] = acc
|
||||
pbar.close()
|
||||
sync_cuda()
|
||||
return time.time() - t0, results
|
||||
|
||||
|
||||
def bench_new_first(model, tokenizer, task_inputs, device, batch_size, queue_size):
|
||||
def bench_new_first(model, tokenizer, task_inputs, device, batch_size, queue_size, pbar=None):
|
||||
"""Benchmark new batched evaluation (first run, no cache)."""
|
||||
sync_cuda()
|
||||
t0 = time.time()
|
||||
results = {}
|
||||
collated_cache = {}
|
||||
max_seq_len = getattr(model, 'max_seq_len', None)
|
||||
max_label_len = max(len(label) for label, _, _ in task_inputs)
|
||||
for label, task_meta, data in task_inputs:
|
||||
if pbar is not None:
|
||||
pbar.set_description(f"{label:<{max_label_len}}")
|
||||
prepared = prepare_task_data(tokenizer, data, task_meta, max_seq_len)
|
||||
acc, collated = evaluate_task(model, data, device, batch_size=batch_size,
|
||||
queue_size=queue_size, prepared=prepared)
|
||||
queue_size=queue_size, prepared=prepared, pbar=pbar)
|
||||
results[label] = acc
|
||||
collated_cache[label] = collated
|
||||
sync_cuda()
|
||||
return time.time() - t0, results, collated_cache
|
||||
|
||||
|
||||
def bench_new_cached(model, task_inputs, device, collated_cache):
|
||||
def bench_new_cached(model, task_inputs, device, collated_cache, pbar=None):
|
||||
"""Benchmark new batched evaluation (cached run, forward only)."""
|
||||
sync_cuda()
|
||||
t0 = time.time()
|
||||
results = {}
|
||||
max_label_len = max(len(label) for label, _, _ in task_inputs)
|
||||
for label, task_meta, data in task_inputs:
|
||||
acc, _ = evaluate_task(model, data, device, collated=collated_cache[label])
|
||||
if pbar is not None:
|
||||
pbar.set_description(f"{label:<{max_label_len}}")
|
||||
acc, _ = evaluate_task(model, data, device, collated=collated_cache[label], pbar=pbar)
|
||||
results[label] = acc
|
||||
sync_cuda()
|
||||
return time.time() - t0, results
|
||||
|
|
@ -232,6 +311,7 @@ def main():
|
|||
parser.add_argument('--batch-sizes', type=str, default='1,2,4,8,16,32,64', help='Comma-separated batch sizes to sweep')
|
||||
parser.add_argument('--queue-sizes', type=str, default='2,4,8,16', help='Comma-separated queue sizes to sweep')
|
||||
parser.add_argument('--skip-old', action='store_true', help='Skip old sequential baseline (slow)')
|
||||
parser.add_argument('--skip-first', action='store_true', help='Skip first-run sweep (requires cached collated data)')
|
||||
args = parser.parse_args()
|
||||
|
||||
batch_sizes = [int(x) for x in args.batch_sizes.split(',')]
|
||||
|
|
@ -272,52 +352,131 @@ def main():
|
|||
print0("")
|
||||
|
||||
# ---- 2. Sweep batch_size x queue_size for first run ----
|
||||
# Compute tok_hash early — needed for both skip-first check and cache loading
|
||||
max_seq_len = getattr(model, 'max_seq_len', None)
|
||||
if args.hf_path is not None:
|
||||
tok_hash = hashlib.sha256(args.hf_path.encode()).hexdigest()[:16]
|
||||
else:
|
||||
base_dir = get_base_dir()
|
||||
for fname in ("tokenizer.pkl", "tokenizer.json"):
|
||||
tok_path = os.path.join(base_dir, "tokenizer", fname)
|
||||
if os.path.exists(tok_path):
|
||||
tok_hash = file_checksum(tok_path)
|
||||
break
|
||||
|
||||
# Check if we can skip the first-run sweep
|
||||
best_time = None
|
||||
best_params = None
|
||||
if args.skip_first:
|
||||
cache_dir = os.path.join(get_base_dir(), "core_token_cache", f"{tok_hash}_n{args.max_per_task}")
|
||||
has_cache = os.path.isdir(cache_dir) and all(
|
||||
os.path.exists(os.path.join(cache_dir, f"{label}.pt"))
|
||||
for label, _, _ in task_inputs
|
||||
)
|
||||
if has_cache:
|
||||
print0("Skipping first-run sweep (--skip-first, cache exists)")
|
||||
print0("")
|
||||
else:
|
||||
print0(f"--skip-first: cache missing, running single bs={BASE_BATCH_SIZE} eval to create it...")
|
||||
pbar = tqdm(total=total_examples, desc="Creating cache", leave=False)
|
||||
with autocast_ctx:
|
||||
_, _, collated_cache = bench_new_first(model, tokenizer, task_inputs, device,
|
||||
batch_size=BASE_BATCH_SIZE, queue_size=2, pbar=pbar)
|
||||
pbar.close()
|
||||
pad_id = tokenizer.get_bos_token_id()
|
||||
os.makedirs(cache_dir, exist_ok=True)
|
||||
for label in collated_cache:
|
||||
torch.save({'collated': collated_cache[label], 'pad_token_id': pad_id},
|
||||
os.path.join(cache_dir, f"{label}.pt"))
|
||||
print0(f"Cache created ({len(collated_cache)} tasks)")
|
||||
print0("")
|
||||
|
||||
if not args.skip_first:
|
||||
print0("=" * 80)
|
||||
print0("NEW: Batched evaluation — hyperparameter sweep (first run)")
|
||||
print0("=" * 80)
|
||||
print0("")
|
||||
|
||||
# Header
|
||||
qs_header = "".join(f"{'q=' + str(q):>10}" for q in queue_sizes)
|
||||
print0(f" {'batch_size':>10}{qs_header}")
|
||||
print0(f" {'':>10}" + "-" * (10 * len(queue_sizes)))
|
||||
|
||||
best_time = float('inf')
|
||||
best_params = None
|
||||
best_collated = None
|
||||
sweep_results = {}
|
||||
total_combos = len(batch_sizes) * len(queue_sizes)
|
||||
outer_pbar = tqdm(total=total_combos, desc="First-run sweep", leave=False, position=0)
|
||||
inner_pbar = tqdm(total=total_examples, desc="", leave=False, position=1)
|
||||
|
||||
for bs in batch_sizes:
|
||||
row = f" {bs:>10}"
|
||||
for qs in queue_sizes:
|
||||
outer_pbar.set_description(f"First-run: bs={bs} qs={qs}")
|
||||
inner_pbar.reset()
|
||||
with autocast_ctx:
|
||||
t, results, collated_cache = bench_new_first(model, tokenizer, task_inputs, device, bs, qs, pbar=inner_pbar)
|
||||
sweep_results[(bs, qs)] = t
|
||||
row += f"{t:>9.2f}s"
|
||||
if t < best_time:
|
||||
best_time = t
|
||||
best_params = (bs, qs)
|
||||
best_collated = collated_cache
|
||||
best_first_results = results
|
||||
outer_pbar.update(1)
|
||||
outer_pbar.write(row)
|
||||
inner_pbar.close()
|
||||
outer_pbar.close()
|
||||
|
||||
print0("")
|
||||
print0(f" Best: batch_size={best_params[0]}, queue_size={best_params[1]} -> {best_time:.2f}s ({total_examples / best_time:.1f} examples/s)")
|
||||
|
||||
# Verify correctness
|
||||
if old_results is not None:
|
||||
verify_results(old_results, best_first_results, label="new-first")
|
||||
print0("")
|
||||
|
||||
# ---- 3. Build/load base-4 collated cache ----
|
||||
base_cache = build_or_load_base_collated(tok_hash, tokenizer, task_inputs, max_seq_len, args.max_per_task)
|
||||
|
||||
# ---- 4. Cached run sweep (forward only, composed from base-4) ----
|
||||
print0("=" * 80)
|
||||
print0("NEW: Batched evaluation — hyperparameter sweep (first run)")
|
||||
print0(f"NEW: Cached run (forward only, composed from base-{BASE_BATCH_SIZE})")
|
||||
print0("=" * 80)
|
||||
print0("")
|
||||
|
||||
# Header
|
||||
qs_header = "".join(f"{'q=' + str(q):>10}" for q in queue_sizes)
|
||||
print0(f" {'batch_size':>10}{qs_header}")
|
||||
print0(f" {'':>10}" + "-" * (10 * len(queue_sizes)))
|
||||
|
||||
best_time = float('inf')
|
||||
best_params = None
|
||||
best_collated = None
|
||||
sweep_results = {}
|
||||
best_cached_time = float('inf')
|
||||
best_cached_params = None
|
||||
outer_pbar = tqdm(total=len(batch_sizes), desc="Cached sweep", leave=False, position=0)
|
||||
inner_pbar = tqdm(total=total_examples, desc="", leave=False, position=1)
|
||||
|
||||
for bs in batch_sizes:
|
||||
row = f" {bs:>10}"
|
||||
for qs in queue_sizes:
|
||||
with autocast_ctx:
|
||||
t, results, collated_cache = bench_new_first(model, tokenizer, task_inputs, device, bs, qs)
|
||||
sweep_results[(bs, qs)] = t
|
||||
row += f"{t:>9.2f}s"
|
||||
if t < best_time:
|
||||
best_time = t
|
||||
best_params = (bs, qs)
|
||||
best_collated = collated_cache
|
||||
best_first_results = results
|
||||
print0(row)
|
||||
outer_pbar.set_description(f"Cached: bs={bs}")
|
||||
inner_pbar.reset()
|
||||
# Compose from base-4 to target batch_size (merge or split)
|
||||
composed_cache = {}
|
||||
for label, (collated, pad_id) in base_cache.items():
|
||||
composed_cache[label] = compose_collated(collated, bs, BASE_BATCH_SIZE, pad_id)
|
||||
|
||||
with autocast_ctx:
|
||||
t, cached_results = bench_new_cached(model, task_inputs, device, composed_cache, pbar=inner_pbar)
|
||||
|
||||
outer_pbar.write(f" batch_size={bs:>3}: {t:.2f}s ({total_examples / t:.1f} examples/s)")
|
||||
outer_pbar.update(1)
|
||||
|
||||
if t < best_cached_time:
|
||||
best_cached_time = t
|
||||
best_cached_params = bs
|
||||
best_cached_results = cached_results
|
||||
inner_pbar.close()
|
||||
outer_pbar.close()
|
||||
|
||||
print0("")
|
||||
print0(f" Best: batch_size={best_params[0]}, queue_size={best_params[1]} -> {best_time:.2f}s ({total_examples / best_time:.1f} examples/s)")
|
||||
print0(f" Best: batch_size={best_cached_params} -> {best_cached_time:.2f}s ({total_examples / best_cached_time:.1f} examples/s)")
|
||||
|
||||
# Verify correctness
|
||||
if old_results is not None:
|
||||
verify_results(old_results, best_first_results, label="new-first")
|
||||
print0("")
|
||||
|
||||
# ---- 3. Cached run (forward only) ----
|
||||
print0("=" * 80)
|
||||
print0("NEW: Cached run (forward only, using best params)")
|
||||
print0("=" * 80)
|
||||
with autocast_ctx:
|
||||
cached_time, cached_results = bench_new_cached(model, task_inputs, device, best_collated)
|
||||
print0(f" Time: {cached_time:.2f}s ({total_examples / cached_time:.1f} examples/s)")
|
||||
if old_results is not None:
|
||||
verify_results(old_results, cached_results, label="new-cached")
|
||||
verify_results(old_results, best_cached_results, label="new-cached")
|
||||
print0("")
|
||||
|
||||
# ---- Summary ----
|
||||
|
|
@ -326,11 +485,13 @@ def main():
|
|||
print0("=" * 80)
|
||||
if old_results is not None:
|
||||
print0(f" Old (sequential): {old_time:>8.2f}s")
|
||||
print0(f" New (first run): {best_time:>8.2f}s batch_size={best_params[0]}, queue_size={best_params[1]}")
|
||||
print0(f" New (cached): {cached_time:>8.2f}s")
|
||||
if best_time is not None:
|
||||
print0(f" New (first run): {best_time:>8.2f}s batch_size={best_params[0]}, queue_size={best_params[1]}")
|
||||
print0(f" New (cached): {best_cached_time:>8.2f}s batch_size={best_cached_params}")
|
||||
if old_results is not None:
|
||||
print0(f" Speedup (first): {old_time / best_time:>8.2f}x")
|
||||
print0(f" Speedup (cached): {old_time / cached_time:>8.2f}x")
|
||||
if best_time is not None:
|
||||
print0(f" Speedup (first): {old_time / best_time:>8.2f}x")
|
||||
print0(f" Speedup (cached): {old_time / best_cached_time:>8.2f}x")
|
||||
|
||||
compute_cleanup()
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user