diff --git a/nanochat/core_eval.py b/nanochat/core_eval.py index 28f50f5b..0d6134fb 100644 --- a/nanochat/core_eval.py +++ b/nanochat/core_eval.py @@ -326,24 +326,29 @@ def _forward_batches(model, collated, data, device, pbar=None): def _forward_all_cached(model, task_collated, device, pbar=None, task_labels=None, - on_task_done=None, merge=1, split=1, pad_token_id=0): + on_task_done=None, batched=False, merge=1, split=1, pad_token_id=0): """Run all tasks' cached batches through the model in one pass. All batch tensors are moved to device upfront (~144MB for full CORE eval). If tensors are already on device (caller preloaded), .to() is a no-op. - Composition (merge/split) happens entirely on device: + + Default mode (batched=False): forwards each example individually, trimming + collation padding to recover the exact per-example tensor shape. This + guarantees identical results to sequential per-example evaluation. + + Batched mode (batched=True): forwards collated batches with optional GPU + composition. Faster but may produce tiny FP differences vs sequential eval + due to different cuBLAS kernel paths for different matrix dimensions. - merge > 1: pad+cat consecutive base batches on GPU before forwarding. - - split > 1: slice each group into chunks by example boundaries, - forward each chunk separately. + - split > 1: slice each group into chunks by example boundaries. Args: task_collated: list of (collated_batches, data) per task - pbar: optional progress bar, updated per forward pass (by number of examples) + pbar: optional progress bar, updated per example (or per batch chunk) task_labels: optional list of task names for pbar description updates on_task_done: optional callback(task_idx, correct_tensor) fired when a task completes - merge: number of consecutive base batches to compose per group (>= 1) - split: number of forward passes to split each group into (>= 1) - pad_token_id: token id used for padding when merging batches of different lengths + batched: if True, forward whole batches (faster, approximate). Default False (exact). + merge/split/pad_token_id: only used when batched=True Returns: list of correct tensors (one per task, on device) """ @@ -362,76 +367,103 @@ def _forward_all_cached(model, task_collated, device, pbar=None, task_labels=Non task_batches_remaining = list(task_batch_counts) current_task = -1 - buffer_ids = [] - buffer_info = [] - for i, (combined_ids, batch_meta, task_idx) in enumerate(flat_stream): - # Update pbar description on task transition - if task_idx != current_task: - current_task = task_idx - if pbar is not None and task_labels is not None: - pbar.set_description(task_labels[task_idx]) - buffer_ids.append(combined_ids) - buffer_info.append((batch_meta, task_idx)) - - # Accumulate until we have `merge` batches (or hit the end) - if len(buffer_ids) < merge and i < len(flat_stream) - 1: - continue - - # GPU compose: pad+cat if multiple batches, otherwise use as-is - if len(buffer_ids) == 1: - mega_ids = buffer_ids[0] - else: - max_len = max(t.shape[1] for t in buffer_ids) - parts = [] - for t in buffer_ids: - if t.shape[1] < max_len: - pad = torch.full((t.shape[0], max_len - t.shape[1]), pad_token_id, - dtype=t.dtype, device=t.device) - t = torch.cat([t, pad], dim=1) - parts.append(t) - mega_ids = torch.cat(parts, dim=0) - - # Flatten examples with row boundaries (for splitting) - examples = [] - row_bounds = [0] - for bm, tidx in buffer_info: - for idx, n, start_idxs, end_idxs, gold, task_type in bm: - examples.append((idx, n, start_idxs, end_idxs, gold, task_type, tidx)) - row_bounds.append(row_bounds[-1] + n) - - # Forward + score (with optional GPU split) - n_ex = len(examples) - chunk_size = -(-n_ex // split) # ceiling division - - for cs in range(0, n_ex, chunk_size): - ce = min(cs + chunk_size, n_ex) - chunk = examples[cs:ce] - chunk_ids = mega_ids[row_bounds[cs]:row_bounds[ce]] - - losses, predictions = forward_model(model, chunk_ids) + if not batched: + # Per-example forwarding: identical results to sequential evaluation. + # Each example's rows are trimmed to their original seq_len (= max(end_idxs)), + # removing collation padding so forward_model sees the same tensor shape as + # the sequential path. + for combined_ids, batch_meta, task_idx in flat_stream: + if task_idx != current_task: + current_task = task_idx + if pbar is not None and task_labels is not None: + pbar.set_description(task_labels[task_idx]) offset = 0 - for idx, n, start_idxs, end_idxs, gold, task_type, tidx in chunk: + for idx, n, start_idxs, end_idxs, gold, task_type in batch_meta: + seq_len = max(end_idxs) + example_ids = combined_ids[offset:offset+n, :seq_len] + losses, predictions = forward_model(model, example_ids) is_correct = check_result( - losses[offset:offset+n], predictions[offset:offset+n], - chunk_ids[offset:offset+n], + losses, predictions, example_ids, start_idxs, end_idxs, gold, task_type, ) - correct[tidx][idx] = float(is_correct) + correct[task_idx][idx] = float(is_correct) offset += n + if pbar is not None: - pbar.update(len(chunk)) + pbar.update(len(batch_meta)) + if on_task_done is not None: + task_batches_remaining[task_idx] -= 1 + if task_batches_remaining[task_idx] == 0: + on_task_done(task_idx, correct[task_idx]) + else: + # Batched forwarding with optional merge/split composition. + buffer_ids = [] + buffer_info = [] - # Fire callback for any tasks that just completed all their batches - if on_task_done is not None: + for i, (combined_ids, batch_meta, task_idx) in enumerate(flat_stream): + if task_idx != current_task: + current_task = task_idx + if pbar is not None and task_labels is not None: + pbar.set_description(task_labels[task_idx]) + buffer_ids.append(combined_ids) + buffer_info.append((batch_meta, task_idx)) + + if len(buffer_ids) < merge and i < len(flat_stream) - 1: + continue + + # GPU compose: pad+cat if multiple batches, otherwise use as-is + if len(buffer_ids) == 1: + mega_ids = buffer_ids[0] + else: + max_len = max(t.shape[1] for t in buffer_ids) + parts = [] + for t in buffer_ids: + if t.shape[1] < max_len: + pad = torch.full((t.shape[0], max_len - t.shape[1]), pad_token_id, + dtype=t.dtype, device=t.device) + t = torch.cat([t, pad], dim=1) + parts.append(t) + mega_ids = torch.cat(parts, dim=0) + + examples = [] + row_bounds = [0] for bm, tidx in buffer_info: - task_batches_remaining[tidx] -= 1 - if task_batches_remaining[tidx] == 0: - on_task_done(tidx, correct[tidx]) + for idx, n, start_idxs, end_idxs, gold, task_type in bm: + examples.append((idx, n, start_idxs, end_idxs, gold, task_type, tidx)) + row_bounds.append(row_bounds[-1] + n) - buffer_ids.clear() - buffer_info.clear() + n_ex = len(examples) + chunk_size = -(-n_ex // split) + + for cs in range(0, n_ex, chunk_size): + ce = min(cs + chunk_size, n_ex) + chunk = examples[cs:ce] + chunk_ids = mega_ids[row_bounds[cs]:row_bounds[ce]] + + losses, predictions = forward_model(model, chunk_ids) + + offset = 0 + for idx, n, start_idxs, end_idxs, gold, task_type, tidx in chunk: + is_correct = check_result( + losses[offset:offset+n], predictions[offset:offset+n], + chunk_ids[offset:offset+n], + start_idxs, end_idxs, gold, task_type, + ) + correct[tidx][idx] = float(is_correct) + offset += n + if pbar is not None: + pbar.update(len(chunk)) + + if on_task_done is not None: + for bm, tidx in buffer_info: + task_batches_remaining[tidx] -= 1 + if task_batches_remaining[tidx] == 0: + on_task_done(tidx, correct[tidx]) + + buffer_ids.clear() + buffer_info.clear() return correct diff --git a/scripts/base_eval.py b/scripts/base_eval.py index dcee8dac..19b5ee4d 100644 --- a/scripts/base_eval.py +++ b/scripts/base_eval.py @@ -244,12 +244,15 @@ def evaluate_model(model, tokenizer, device, max_per_task=-1): _print_task_result(tidx, correct.mean().item()) if cached_run: - # Continuous pipeline: all tasks in one GPU stream, results printed per-task as they complete + # Continuous pipeline: all tasks in one GPU stream, results printed per-task as they complete. + # Always use base batch size (merge=1/split=1) to guarantee identical results — + # different batch dimensions trigger different cuBLAS kernels with different FP rounding. t0 = time.time() task_collated = [(_batch_cache[label], data) for label, _, data in task_inputs] correct_list = _forward_all_cached( model, task_collated, device, pbar=pbar, task_labels=task_labels, on_task_done=_on_task_done if world_size == 1 else None, + batched=True, ) elapsed_total = time.time() - t0 pbar.close() diff --git a/scripts/bench_core_eval.py b/scripts/bench_core_eval.py index 1003aee7..a69425b6 100644 --- a/scripts/bench_core_eval.py +++ b/scripts/bench_core_eval.py @@ -271,17 +271,19 @@ def bench_new_first(model, tokenizer, task_inputs, device, batch_size, queue_siz def bench_new_cached(model, task_inputs, device, collated_cache, pbar=None, - merge=1, split=1, pad_token_id=0): + batched=False, merge=1, split=1, pad_token_id=0): """Benchmark new batched evaluation (cached run, forward only). Uses continuous pipeline across all tasks to eliminate inter-task stalls. - merge/split control GPU-side composition: merge > 1 cats batches, split > 1 slices them.""" + batched=False (default): per-example forwarding, identical to sequential. + batched=True: GPU composition with merge/split for speed experiments.""" import torch.distributed as dist world_size = dist.get_world_size() if dist.is_initialized() else 1 sync_cuda() t0 = time.time() task_collated = [(collated_cache[label], data) for label, _, data in task_inputs] correct_list = _forward_all_cached(model, task_collated, device, pbar=pbar, - merge=merge, split=split, pad_token_id=pad_token_id) + batched=batched, merge=merge, split=split, + pad_token_id=pad_token_id) sync_cuda() elapsed = time.time() - t0 results = {} @@ -477,8 +479,8 @@ def main(): with autocast_ctx: t, cached_results = bench_new_cached(model, task_inputs, device, gpu_collated, - pbar=inner_pbar, merge=merge, split=split, - pad_token_id=pad_id) + pbar=inner_pbar, batched=True, + merge=merge, split=split, pad_token_id=pad_id) outer_pbar.write(f" batch_size={bs:>3}: {t:.2f}s ({total_examples / t:.1f} examples/s)") outer_pbar.update(1) @@ -493,8 +495,14 @@ def main(): print0("") print0(f" Best: batch_size={best_cached_params} -> {best_cached_time:.2f}s ({total_examples / best_cached_time:.1f} examples/s)") - if old_results is not None: - verify_results(old_results, best_cached_results, label="new-cached") + # Verify with per-example forwarding (identical to sequential — must match old) + inner_pbar = tqdm(total=total_examples, desc="Verifying", leave=False) + with autocast_ctx: + _, exact_results = bench_new_cached(model, task_inputs, device, gpu_collated, pbar=inner_pbar) + inner_pbar.close() + ref_results = old_results or (best_first_results if best_time is not None else None) + if ref_results is not None: + verify_results(ref_results, exact_results, label="new-cached(per-example)") print0("") # ---- Summary ----