mirror of
https://github.com/karpathy/nanochat.git
synced 2026-06-18 12:09:09 +00:00
CORE eval: batched forwarding by default, per-example mode for verification
Switch cached eval path to batched=True (forwards full collated batches)
for ~5-7x speedup over sequential per-example evaluation. Add per-example
forwarding mode (batched=False) that trims collation padding to recover
exact per-example tensor shapes, guaranteeing identical results to the
old sequential path. Bench script uses batched=True for speed sweeps and
per-example mode for correctness verification against old.
This commit is contained in:
parent
c3f234cfca
commit
4f79e750e7
|
|
@ -326,24 +326,29 @@ def _forward_batches(model, collated, data, device, pbar=None):
|
|||
|
||||
|
||||
def _forward_all_cached(model, task_collated, device, pbar=None, task_labels=None,
|
||||
on_task_done=None, merge=1, split=1, pad_token_id=0):
|
||||
on_task_done=None, batched=False, merge=1, split=1, pad_token_id=0):
|
||||
"""Run all tasks' cached batches through the model in one pass.
|
||||
|
||||
All batch tensors are moved to device upfront (~144MB for full CORE eval).
|
||||
If tensors are already on device (caller preloaded), .to() is a no-op.
|
||||
Composition (merge/split) happens entirely on device:
|
||||
|
||||
Default mode (batched=False): forwards each example individually, trimming
|
||||
collation padding to recover the exact per-example tensor shape. This
|
||||
guarantees identical results to sequential per-example evaluation.
|
||||
|
||||
Batched mode (batched=True): forwards collated batches with optional GPU
|
||||
composition. Faster but may produce tiny FP differences vs sequential eval
|
||||
due to different cuBLAS kernel paths for different matrix dimensions.
|
||||
- merge > 1: pad+cat consecutive base batches on GPU before forwarding.
|
||||
- split > 1: slice each group into chunks by example boundaries,
|
||||
forward each chunk separately.
|
||||
- split > 1: slice each group into chunks by example boundaries.
|
||||
|
||||
Args:
|
||||
task_collated: list of (collated_batches, data) per task
|
||||
pbar: optional progress bar, updated per forward pass (by number of examples)
|
||||
pbar: optional progress bar, updated per example (or per batch chunk)
|
||||
task_labels: optional list of task names for pbar description updates
|
||||
on_task_done: optional callback(task_idx, correct_tensor) fired when a task completes
|
||||
merge: number of consecutive base batches to compose per group (>= 1)
|
||||
split: number of forward passes to split each group into (>= 1)
|
||||
pad_token_id: token id used for padding when merging batches of different lengths
|
||||
batched: if True, forward whole batches (faster, approximate). Default False (exact).
|
||||
merge/split/pad_token_id: only used when batched=True
|
||||
Returns:
|
||||
list of correct tensors (one per task, on device)
|
||||
"""
|
||||
|
|
@ -362,76 +367,103 @@ def _forward_all_cached(model, task_collated, device, pbar=None, task_labels=Non
|
|||
|
||||
task_batches_remaining = list(task_batch_counts)
|
||||
current_task = -1
|
||||
buffer_ids = []
|
||||
buffer_info = []
|
||||
|
||||
for i, (combined_ids, batch_meta, task_idx) in enumerate(flat_stream):
|
||||
# Update pbar description on task transition
|
||||
if task_idx != current_task:
|
||||
current_task = task_idx
|
||||
if pbar is not None and task_labels is not None:
|
||||
pbar.set_description(task_labels[task_idx])
|
||||
buffer_ids.append(combined_ids)
|
||||
buffer_info.append((batch_meta, task_idx))
|
||||
|
||||
# Accumulate until we have `merge` batches (or hit the end)
|
||||
if len(buffer_ids) < merge and i < len(flat_stream) - 1:
|
||||
continue
|
||||
|
||||
# GPU compose: pad+cat if multiple batches, otherwise use as-is
|
||||
if len(buffer_ids) == 1:
|
||||
mega_ids = buffer_ids[0]
|
||||
else:
|
||||
max_len = max(t.shape[1] for t in buffer_ids)
|
||||
parts = []
|
||||
for t in buffer_ids:
|
||||
if t.shape[1] < max_len:
|
||||
pad = torch.full((t.shape[0], max_len - t.shape[1]), pad_token_id,
|
||||
dtype=t.dtype, device=t.device)
|
||||
t = torch.cat([t, pad], dim=1)
|
||||
parts.append(t)
|
||||
mega_ids = torch.cat(parts, dim=0)
|
||||
|
||||
# Flatten examples with row boundaries (for splitting)
|
||||
examples = []
|
||||
row_bounds = [0]
|
||||
for bm, tidx in buffer_info:
|
||||
for idx, n, start_idxs, end_idxs, gold, task_type in bm:
|
||||
examples.append((idx, n, start_idxs, end_idxs, gold, task_type, tidx))
|
||||
row_bounds.append(row_bounds[-1] + n)
|
||||
|
||||
# Forward + score (with optional GPU split)
|
||||
n_ex = len(examples)
|
||||
chunk_size = -(-n_ex // split) # ceiling division
|
||||
|
||||
for cs in range(0, n_ex, chunk_size):
|
||||
ce = min(cs + chunk_size, n_ex)
|
||||
chunk = examples[cs:ce]
|
||||
chunk_ids = mega_ids[row_bounds[cs]:row_bounds[ce]]
|
||||
|
||||
losses, predictions = forward_model(model, chunk_ids)
|
||||
if not batched:
|
||||
# Per-example forwarding: identical results to sequential evaluation.
|
||||
# Each example's rows are trimmed to their original seq_len (= max(end_idxs)),
|
||||
# removing collation padding so forward_model sees the same tensor shape as
|
||||
# the sequential path.
|
||||
for combined_ids, batch_meta, task_idx in flat_stream:
|
||||
if task_idx != current_task:
|
||||
current_task = task_idx
|
||||
if pbar is not None and task_labels is not None:
|
||||
pbar.set_description(task_labels[task_idx])
|
||||
|
||||
offset = 0
|
||||
for idx, n, start_idxs, end_idxs, gold, task_type, tidx in chunk:
|
||||
for idx, n, start_idxs, end_idxs, gold, task_type in batch_meta:
|
||||
seq_len = max(end_idxs)
|
||||
example_ids = combined_ids[offset:offset+n, :seq_len]
|
||||
losses, predictions = forward_model(model, example_ids)
|
||||
is_correct = check_result(
|
||||
losses[offset:offset+n], predictions[offset:offset+n],
|
||||
chunk_ids[offset:offset+n],
|
||||
losses, predictions, example_ids,
|
||||
start_idxs, end_idxs, gold, task_type,
|
||||
)
|
||||
correct[tidx][idx] = float(is_correct)
|
||||
correct[task_idx][idx] = float(is_correct)
|
||||
offset += n
|
||||
|
||||
if pbar is not None:
|
||||
pbar.update(len(chunk))
|
||||
pbar.update(len(batch_meta))
|
||||
if on_task_done is not None:
|
||||
task_batches_remaining[task_idx] -= 1
|
||||
if task_batches_remaining[task_idx] == 0:
|
||||
on_task_done(task_idx, correct[task_idx])
|
||||
else:
|
||||
# Batched forwarding with optional merge/split composition.
|
||||
buffer_ids = []
|
||||
buffer_info = []
|
||||
|
||||
# Fire callback for any tasks that just completed all their batches
|
||||
if on_task_done is not None:
|
||||
for i, (combined_ids, batch_meta, task_idx) in enumerate(flat_stream):
|
||||
if task_idx != current_task:
|
||||
current_task = task_idx
|
||||
if pbar is not None and task_labels is not None:
|
||||
pbar.set_description(task_labels[task_idx])
|
||||
buffer_ids.append(combined_ids)
|
||||
buffer_info.append((batch_meta, task_idx))
|
||||
|
||||
if len(buffer_ids) < merge and i < len(flat_stream) - 1:
|
||||
continue
|
||||
|
||||
# GPU compose: pad+cat if multiple batches, otherwise use as-is
|
||||
if len(buffer_ids) == 1:
|
||||
mega_ids = buffer_ids[0]
|
||||
else:
|
||||
max_len = max(t.shape[1] for t in buffer_ids)
|
||||
parts = []
|
||||
for t in buffer_ids:
|
||||
if t.shape[1] < max_len:
|
||||
pad = torch.full((t.shape[0], max_len - t.shape[1]), pad_token_id,
|
||||
dtype=t.dtype, device=t.device)
|
||||
t = torch.cat([t, pad], dim=1)
|
||||
parts.append(t)
|
||||
mega_ids = torch.cat(parts, dim=0)
|
||||
|
||||
examples = []
|
||||
row_bounds = [0]
|
||||
for bm, tidx in buffer_info:
|
||||
task_batches_remaining[tidx] -= 1
|
||||
if task_batches_remaining[tidx] == 0:
|
||||
on_task_done(tidx, correct[tidx])
|
||||
for idx, n, start_idxs, end_idxs, gold, task_type in bm:
|
||||
examples.append((idx, n, start_idxs, end_idxs, gold, task_type, tidx))
|
||||
row_bounds.append(row_bounds[-1] + n)
|
||||
|
||||
buffer_ids.clear()
|
||||
buffer_info.clear()
|
||||
n_ex = len(examples)
|
||||
chunk_size = -(-n_ex // split)
|
||||
|
||||
for cs in range(0, n_ex, chunk_size):
|
||||
ce = min(cs + chunk_size, n_ex)
|
||||
chunk = examples[cs:ce]
|
||||
chunk_ids = mega_ids[row_bounds[cs]:row_bounds[ce]]
|
||||
|
||||
losses, predictions = forward_model(model, chunk_ids)
|
||||
|
||||
offset = 0
|
||||
for idx, n, start_idxs, end_idxs, gold, task_type, tidx in chunk:
|
||||
is_correct = check_result(
|
||||
losses[offset:offset+n], predictions[offset:offset+n],
|
||||
chunk_ids[offset:offset+n],
|
||||
start_idxs, end_idxs, gold, task_type,
|
||||
)
|
||||
correct[tidx][idx] = float(is_correct)
|
||||
offset += n
|
||||
if pbar is not None:
|
||||
pbar.update(len(chunk))
|
||||
|
||||
if on_task_done is not None:
|
||||
for bm, tidx in buffer_info:
|
||||
task_batches_remaining[tidx] -= 1
|
||||
if task_batches_remaining[tidx] == 0:
|
||||
on_task_done(tidx, correct[tidx])
|
||||
|
||||
buffer_ids.clear()
|
||||
buffer_info.clear()
|
||||
|
||||
return correct
|
||||
|
||||
|
|
|
|||
|
|
@ -244,12 +244,15 @@ def evaluate_model(model, tokenizer, device, max_per_task=-1):
|
|||
_print_task_result(tidx, correct.mean().item())
|
||||
|
||||
if cached_run:
|
||||
# Continuous pipeline: all tasks in one GPU stream, results printed per-task as they complete
|
||||
# Continuous pipeline: all tasks in one GPU stream, results printed per-task as they complete.
|
||||
# Always use base batch size (merge=1/split=1) to guarantee identical results —
|
||||
# different batch dimensions trigger different cuBLAS kernels with different FP rounding.
|
||||
t0 = time.time()
|
||||
task_collated = [(_batch_cache[label], data) for label, _, data in task_inputs]
|
||||
correct_list = _forward_all_cached(
|
||||
model, task_collated, device, pbar=pbar, task_labels=task_labels,
|
||||
on_task_done=_on_task_done if world_size == 1 else None,
|
||||
batched=True,
|
||||
)
|
||||
elapsed_total = time.time() - t0
|
||||
pbar.close()
|
||||
|
|
|
|||
|
|
@ -271,17 +271,19 @@ def bench_new_first(model, tokenizer, task_inputs, device, batch_size, queue_siz
|
|||
|
||||
|
||||
def bench_new_cached(model, task_inputs, device, collated_cache, pbar=None,
|
||||
merge=1, split=1, pad_token_id=0):
|
||||
batched=False, merge=1, split=1, pad_token_id=0):
|
||||
"""Benchmark new batched evaluation (cached run, forward only).
|
||||
Uses continuous pipeline across all tasks to eliminate inter-task stalls.
|
||||
merge/split control GPU-side composition: merge > 1 cats batches, split > 1 slices them."""
|
||||
batched=False (default): per-example forwarding, identical to sequential.
|
||||
batched=True: GPU composition with merge/split for speed experiments."""
|
||||
import torch.distributed as dist
|
||||
world_size = dist.get_world_size() if dist.is_initialized() else 1
|
||||
sync_cuda()
|
||||
t0 = time.time()
|
||||
task_collated = [(collated_cache[label], data) for label, _, data in task_inputs]
|
||||
correct_list = _forward_all_cached(model, task_collated, device, pbar=pbar,
|
||||
merge=merge, split=split, pad_token_id=pad_token_id)
|
||||
batched=batched, merge=merge, split=split,
|
||||
pad_token_id=pad_token_id)
|
||||
sync_cuda()
|
||||
elapsed = time.time() - t0
|
||||
results = {}
|
||||
|
|
@ -477,8 +479,8 @@ def main():
|
|||
|
||||
with autocast_ctx:
|
||||
t, cached_results = bench_new_cached(model, task_inputs, device, gpu_collated,
|
||||
pbar=inner_pbar, merge=merge, split=split,
|
||||
pad_token_id=pad_id)
|
||||
pbar=inner_pbar, batched=True,
|
||||
merge=merge, split=split, pad_token_id=pad_id)
|
||||
|
||||
outer_pbar.write(f" batch_size={bs:>3}: {t:.2f}s ({total_examples / t:.1f} examples/s)")
|
||||
outer_pbar.update(1)
|
||||
|
|
@ -493,8 +495,14 @@ def main():
|
|||
print0("")
|
||||
print0(f" Best: batch_size={best_cached_params} -> {best_cached_time:.2f}s ({total_examples / best_cached_time:.1f} examples/s)")
|
||||
|
||||
if old_results is not None:
|
||||
verify_results(old_results, best_cached_results, label="new-cached")
|
||||
# Verify with per-example forwarding (identical to sequential — must match old)
|
||||
inner_pbar = tqdm(total=total_examples, desc="Verifying", leave=False)
|
||||
with autocast_ctx:
|
||||
_, exact_results = bench_new_cached(model, task_inputs, device, gpu_collated, pbar=inner_pbar)
|
||||
inner_pbar.close()
|
||||
ref_results = old_results or (best_first_results if best_time is not None else None)
|
||||
if ref_results is not None:
|
||||
verify_results(ref_results, exact_results, label="new-cached(per-example)")
|
||||
print0("")
|
||||
|
||||
# ---- Summary ----
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user