From 790f3be65cdf34acf7ced3b2bdd9a259e974c52d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Bar=C4=B1=C5=9F=20=C3=96zmen?=
Date: Thu, 18 Dec 2025 19:17:59 +0300
Subject: [PATCH 01/10] add rust batch encode as a faster option over encode

---
 rustbpe/src/lib.rs    | 16 +++++++++
 tests/test_rustbpe.py | 81 +++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 97 insertions(+)

diff --git a/rustbpe/src/lib.rs b/rustbpe/src/lib.rs
index 273d7f2..f9c8494 100644
--- a/rustbpe/src/lib.rs
+++ b/rustbpe/src/lib.rs
@@ -465,6 +465,22 @@ impl Tokenizer {
 
         all_ids
     }
+
+    /// Encode multiple texts in parallel using rayon.
+    /// Returns a list of token ID vectors, one per input text.
+    #[pyo3(signature = (texts))]
+    #[pyo3(text_signature = "(self, texts)")]
+    pub fn batch_encode(&self, py: Python<'_>, texts: Vec<String>) -> PyResult<Vec<Vec<u32>>> {
+        // Release Python GIL and encode in parallel using rayon
+        let results = py.allow_threads(|| {
+            texts
+                .par_iter()
+                .map(|text| self.encode(text))
+                .collect::<Vec<Vec<u32>>>()
+        });
+
+        Ok(results)
+    }
 }
 
 #[pymodule]
diff --git a/tests/test_rustbpe.py b/tests/test_rustbpe.py
index aca67fc..482ea20 100644
--- a/tests/test_rustbpe.py
+++ b/tests/test_rustbpe.py
@@ -633,3 +633,84 @@ def test_interface(enwik8_small):
     ids_reloaded = tok_reloaded.encode(encode_text)
     assert ids_reloaded == ids, "Reloaded tokenizer should produce same results"
     print("✅ Save/load through temporary directory OK")
+
+
+def test_batch_encode_correctness(enwik8_small):
+    """Quick correctness test for batch_encode()"""
+    text = enwik8_small
+    vocab_size = 512
+
+    tokenizer = rustbpe.Tokenizer()
+    tokenizer.train_from_iterator([text], vocab_size)
+
+    # Test with various batch sizes and edge cases
+    test_texts = [
+        "Hello world",
+        "The quick brown fox",
+        "jumps over the lazy dog",
+        "",  # empty string
+        "a",  # single char
+    ]
+
+    # Compare batch vs individual encoding
+    individual = [tokenizer.encode(t) for t in test_texts]
+    batched = tokenizer.batch_encode(test_texts)
+
+    assert individual == batched, "Batch encoding should match individual encoding"
+    print("✅ batch_encode() correctness verified")
+
+
+@pytest.mark.slow
+def test_batch_encode_performance(enwik8_large):
+    """
+    Benchmark batch_encode() vs sequential encode() loop.
+    Demonstrates parallelization speedup.
+ """ + # Setup + text = enwik8_large # 10MB dataset + vocab_size = 2048 + + # Train tokenizer + print("\nTraining tokenizer...") + tokenizer = rustbpe.Tokenizer() + tokenizer.train_from_iterator([text], vocab_size) + + # Create test batch: split text into chunks + chunk_size = 50_000 # ~50KB per chunk + chunks = [text[i:i+chunk_size] for i in range(0, len(text), chunk_size)] + chunks = chunks[:20] # Use first 20 chunks (~1MB total) + + print(f"\nBatch encoding benchmark:") + print(f" Number of texts: {len(chunks)}") + print(f" Avg text length: {sum(len(c) for c in chunks) / len(chunks):.0f} chars") + + # Benchmark 1: Sequential encoding (baseline) + print("\n [1/3] Sequential encode() loop...") + sequential_results, sequential_time = time_function( + lambda: [tokenizer.encode(chunk) for chunk in chunks] + ) + print(f" Time: {sequential_time:.4f}s") + + # Benchmark 2: Parallel batch_encode() + print(" [2/3] Parallel batch_encode()...") + batch_results, batch_time = time_function( + tokenizer.batch_encode, chunks + ) + print(f" Time: {batch_time:.4f}s") + + # Verify correctness + print(" [3/3] Verifying correctness...") + assert len(batch_results) == len(sequential_results), "Result count mismatch" + for i, (seq, batch) in enumerate(zip(sequential_results, batch_results)): + assert seq == batch, f"Mismatch at index {i}" + print(" ✓ All results match") + + # Report speedup + speedup = sequential_time / batch_time + print(f"\n Performance Results:") + print(f" Sequential: {sequential_time:.4f}s") + print(f" Batch: {batch_time:.4f}s") + print(f" Speedup: {speedup:.2f}x") + + # Assert meaningful speedup (at least 1.5x on multi-core) + assert speedup > 1.5, f"Expected >1.5x speedup, got {speedup:.2f}x" From 92c6654b9573362777daa44236b2269a75acf9b4 Mon Sep 17 00:00:00 2001 From: duwenjie Date: Sun, 21 Dec 2025 15:07:04 +0800 Subject: [PATCH 02/10] bugfix save and load ckpt from model_tag dir --- scripts/base_eval.py | 21 +++++++++++---------- scripts/chat_sft.py | 4 ++-- scripts/mid_train.py | 2 +- 3 files changed, 14 insertions(+), 13 deletions(-) diff --git a/scripts/base_eval.py b/scripts/base_eval.py index 3663538..f6070c4 100644 --- a/scripts/base_eval.py +++ b/scripts/base_eval.py @@ -27,6 +27,14 @@ from nanochat.tokenizer import HuggingFaceTokenizer from nanochat.checkpoint_manager import load_model from nanochat.core_eval import evaluate_task +# Configuration +hf_path = None # optional HuggingFace model path to evaluate +max_per_task = -1 # max examples per task to evaluate (-1 = disable) +model_tag = None # optional model tag for the output directory name +model_step = None # optional model step for the output directory name +device_type = "" # cuda|cpu|mps (empty => autodetect) +exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file + # ----------------------------------------------------------------------------- # nanochat specific function dealing with I/O etc. 
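
The configuration block added in the hunk above follows nanochat's configurator.py convention: plain module-level defaults that the exec'd configurator overrides from --key=value flags on the command line or from a config file. A hypothetical invocation in that style (the tag, step, and model path below are illustrative values, not taken from this patch):

    # evaluate a locally saved base checkpoint under a specific tag and step
    python -m scripts.base_eval --model_tag=d20 --model_step=21400
    # or evaluate a HuggingFace model instead of a local checkpoint
    python -m scripts.base_eval --hf_path=openai-community/gpt2
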
@@ -145,34 +153,27 @@ def load_hf_model(hf_path: str, device): # ----------------------------------------------------------------------------- def main(): - import argparse - parser = argparse.ArgumentParser() - parser.add_argument('--hf-path', type=str, default=None, help='HuggingFace model path to evaluate') - parser.add_argument('--max-per-task', type=int, default=-1, help='Max examples per task to evaluate (-1 = disable)') - args = parser.parse_args() - # distributed / precision setup device_type = autodetect_device_type() ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type) autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext() # Load model and tokenizer from command line or from file system - if args.hf_path is not None: + if hf_path is not None: # atm assume that if a path is given, it's a huggingface model path - hf_path = args.hf_path print0(f"Loading huggingface model from: {hf_path}") model, tokenizer = load_hf_model(hf_path, device) model_name = hf_path # just for logging model_slug = hf_path.replace("/", "-") # for the output csv file else: # load a local model from the file system - model, tokenizer, meta = load_model("base", device, phase="eval") + model, tokenizer, meta = load_model("base", device, phase="eval", model_tag=model_tag, step=model_step) model_name = f"base_model (step {meta['step']})" # just for logging model_slug = f"base_model_{meta['step']:06d}" # for the output csv file # Evaluate the model with autocast_ctx: - out = evaluate_model(model, tokenizer, device, max_per_task=args.max_per_task) + out = evaluate_model(model, tokenizer, device, max_per_task=max_per_task) # Write out the results to a csv file core_metric = None diff --git a/scripts/chat_sft.py b/scripts/chat_sft.py index f93a6e6..bb455a8 100644 --- a/scripts/chat_sft.py +++ b/scripts/chat_sft.py @@ -250,8 +250,8 @@ for step in range(num_iterations): if master_process: base_dir = get_base_dir() depth = model.config.n_layer - model_tag = f"d{depth}" # base the model tag on the depth of the base model - checkpoint_dir = os.path.join(base_dir, "chatsft_checkpoints", model_tag) + output_dirname = model_tag if model_tag else f"d{depth}" # e.g. d12 + checkpoint_dir = os.path.join(base_dir, "chatsft_checkpoints", output_dirname) model_config_kwargs = model.config.__dict__ # slightly naughty, abusing the simplicity of GPTConfig, TODO nicer save_checkpoint( checkpoint_dir, diff --git a/scripts/mid_train.py b/scripts/mid_train.py index dd0768c..d817a40 100644 --- a/scripts/mid_train.py +++ b/scripts/mid_train.py @@ -207,7 +207,7 @@ while True: # save checkpoint at the end of the run (only on master process) if master_process and last_step and not dry_run: - output_dirname = f"d{depth}" # e.g. d12 + output_dirname = model_tag if model_tag else f"d{depth}" # e.g. 
d12 checkpoint_dir = os.path.join(base_dir, "mid_checkpoints", output_dirname) save_checkpoint( checkpoint_dir, From 78400491899569dfaa82dc8bd6a1c9299abc2bae Mon Sep 17 00:00:00 2001 From: DU Wenjie Date: Fri, 26 Dec 2025 17:29:08 +0800 Subject: [PATCH 03/10] bugfix keep same args style in scripts/base_eval.py --- scripts/base_eval.py | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/scripts/base_eval.py b/scripts/base_eval.py index f6070c4..1d680a0 100644 --- a/scripts/base_eval.py +++ b/scripts/base_eval.py @@ -27,14 +27,6 @@ from nanochat.tokenizer import HuggingFaceTokenizer from nanochat.checkpoint_manager import load_model from nanochat.core_eval import evaluate_task -# Configuration -hf_path = None # optional HuggingFace model path to evaluate -max_per_task = -1 # max examples per task to evaluate (-1 = disable) -model_tag = None # optional model tag for the output directory name -model_step = None # optional model step for the output directory name -device_type = "" # cuda|cpu|mps (empty => autodetect) -exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file - # ----------------------------------------------------------------------------- # nanochat specific function dealing with I/O etc. @@ -153,27 +145,36 @@ def load_hf_model(hf_path: str, device): # ----------------------------------------------------------------------------- def main(): + import argparse + parser = argparse.ArgumentParser() + parser.add_argument('--hf-path', type=str, default=None, help='HuggingFace model path to evaluate') + parser.add_argument('--max-per-task', type=int, default=-1, help='Max examples per task to evaluate (-1 = disable)') + parser.add_argument('--model_tag', type=str, default=None, help='optional model tag for the output directory name') + parser.add_argument('--model_step', type=str, default=None, help='optional model step for the output directory name') + args = parser.parse_args() + # distributed / precision setup device_type = autodetect_device_type() ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type) autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext() # Load model and tokenizer from command line or from file system - if hf_path is not None: + if args.hf_path is not None: # atm assume that if a path is given, it's a huggingface model path + hf_path = args.hf_path print0(f"Loading huggingface model from: {hf_path}") model, tokenizer = load_hf_model(hf_path, device) model_name = hf_path # just for logging model_slug = hf_path.replace("/", "-") # for the output csv file else: # load a local model from the file system - model, tokenizer, meta = load_model("base", device, phase="eval", model_tag=model_tag, step=model_step) + model, tokenizer, meta = load_model("base", device, phase="eval", model_tag=args.model_tag, step=args.model_step) model_name = f"base_model (step {meta['step']})" # just for logging model_slug = f"base_model_{meta['step']:06d}" # for the output csv file # Evaluate the model with autocast_ctx: - out = evaluate_model(model, tokenizer, device, max_per_task=max_per_task) + out = evaluate_model(model, tokenizer, device, max_per_task=args.max_per_task) # Write out the results to a csv file core_metric = None From ea4229851b6109b5b47aa7cfd467d20815947453 Mon Sep 17 00:00:00 2001 From: DU Wenjie Date: Fri, 26 Dec 2025 17:41:57 +0800 Subject: [PATCH 04/10] bugfix --- scripts/base_eval.py 
| 6 +++--- scripts/chat_rl.py | 8 +++++--- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/scripts/base_eval.py b/scripts/base_eval.py index 1d680a0..bd83ff3 100644 --- a/scripts/base_eval.py +++ b/scripts/base_eval.py @@ -149,8 +149,8 @@ def main(): parser = argparse.ArgumentParser() parser.add_argument('--hf-path', type=str, default=None, help='HuggingFace model path to evaluate') parser.add_argument('--max-per-task', type=int, default=-1, help='Max examples per task to evaluate (-1 = disable)') - parser.add_argument('--model_tag', type=str, default=None, help='optional model tag for the output directory name') - parser.add_argument('--model_step', type=str, default=None, help='optional model step for the output directory name') + parser.add_argument('--model-tag', type=str, default=None, help='optional model tag for the output directory name') + parser.add_argument('--step', type=str, default=None, help='optional model step for the output directory name') args = parser.parse_args() # distributed / precision setup @@ -168,7 +168,7 @@ def main(): model_slug = hf_path.replace("/", "-") # for the output csv file else: # load a local model from the file system - model, tokenizer, meta = load_model("base", device, phase="eval", model_tag=args.model_tag, step=args.model_step) + model, tokenizer, meta = load_model("base", device, phase="eval", model_tag=args.model_tag, step=args.step) model_name = f"base_model (step {meta['step']})" # just for logging model_slug = f"base_model_{meta['step']:06d}" # for the output csv file diff --git a/scripts/chat_rl.py b/scripts/chat_rl.py index bc78e79..e5c8d3f 100644 --- a/scripts/chat_rl.py +++ b/scripts/chat_rl.py @@ -31,6 +31,8 @@ from tasks.gsm8k import GSM8K # RL hyperparameters run = "dummy" # wandb run name source = "sft" # mid|sft +model_tag = None # model tag to load the model from (base model or midtrained model) +step = None # step to load the model from (base model or midtrained model) dtype = "bfloat16" device_batch_size = 8 # no forward pass will go above this to not OOM examples_per_step = 16 # in total and across all ranks (note: examples, not samples/completions!) 
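
The two hyperparameters added above let the RL script load a specific checkpoint by tag and step rather than the default one, in the same spirit as the base_eval change. A hypothetical launch in the style of the other nanochat scripts (the tag and step values are illustrative only):

    # run RL against a midtrained checkpoint saved under tag d26 at step 9000
    torchrun --standalone --nproc_per_node=8 -m scripts.chat_rl -- --source=mid --model_tag=d26 --step=9000
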
@@ -64,7 +66,7 @@ use_dummy_wandb = run == "dummy" or not master_process wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-rl", name=run, config=user_config) # Init model and tokenizer -model, tokenizer, meta = load_model(source, device, phase="eval") +model, tokenizer, meta = load_model(source, device, phase="eval", model_tag=model_tag, step=step) engine = Engine(model, tokenizer) # for sampling rollouts # ----------------------------------------------------------------------------- @@ -307,8 +309,8 @@ for step in range(num_steps): if master_process and ((step > 0 and step % save_every == 0) or step == num_steps - 1): base_dir = get_base_dir() depth = model.config.n_layer - model_tag = f"d{depth}" # base the model tag on the depth of the base model - checkpoint_dir = os.path.join(base_dir, "chatrl_checkpoints", model_tag) + output_dirname = model_tag if model_tag else f"d{depth}" # base the model tag on the depth of the base model + checkpoint_dir = os.path.join(base_dir, "chatrl_checkpoints", output_dirname) model_config_kwargs = model.config.__dict__ # slightly naughty, abusing the simplicity of GPTConfig, TODO nicer save_checkpoint( checkpoint_dir, From 49389ecaa88d419c955aff05f725da0e9f70a7fb Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Sat, 27 Dec 2025 22:03:06 +0000 Subject: [PATCH 05/10] fix tf32 warning for deprecated api use --- nanochat/common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nanochat/common.py b/nanochat/common.py index 8f36f94..ad0fb69 100644 --- a/nanochat/common.py +++ b/nanochat/common.py @@ -158,7 +158,7 @@ def compute_init(device_type="cuda"): # cuda|cpu|mps # Precision if device_type == "cuda": - torch.set_float32_matmul_precision("high") # uses tf32 instead of fp32 for matmuls + torch.backends.cuda.matmul.fp32_precision = "tf32" # uses tf32 instead of fp32 for matmuls # Distributed setup: Distributed Data Parallel (DDP), optional, and requires CUDA ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info() From e1770a3061df8064b95422febb8deff2b75c419a Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Sat, 27 Dec 2025 23:07:48 +0000 Subject: [PATCH 06/10] remove spurious cast, gets compiled away anyway but it's confusing people --- nanochat/gpt.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/nanochat/gpt.py b/nanochat/gpt.py index 9a80c7c..69899ee 100644 --- a/nanochat/gpt.py +++ b/nanochat/gpt.py @@ -41,12 +41,10 @@ def norm(x): def apply_rotary_emb(x, cos, sin): assert x.ndim == 4 # multihead attention d = x.shape[3] // 2 - x1, x2 = x[..., :d], x[..., d:] # split up last time into two halves + x1, x2 = x[..., :d], x[..., d:] # split up last dim into two halves y1 = x1 * cos + x2 * sin # rotate pairs of dims y2 = x1 * (-sin) + x2 * cos - out = torch.cat([y1, y2], 3) # re-assemble - out = out.to(x.dtype) # ensure input/output dtypes match - return out + return torch.cat([y1, y2], 3) class CausalSelfAttention(nn.Module): def __init__(self, config, layer_idx): From 2874eda59accec0cec8cd7da368db6906a0041f7 Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Sun, 28 Dec 2025 03:32:46 +0000 Subject: [PATCH 07/10] update to new os env var to get rid of deprecation warning --- scripts/base_train.py | 2 +- scripts/chat_sft.py | 2 +- scripts/mid_train.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/scripts/base_train.py b/scripts/base_train.py index 72ee147..afa3b7a 100644 --- a/scripts/base_train.py +++ b/scripts/base_train.py @@ -12,7 +12,7 @@ 
python -m scripts.base_train --depth=4 --max_seq_len=512 --device_batch_size=1 - """ import os -os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True" import time from contextlib import nullcontext diff --git a/scripts/chat_sft.py b/scripts/chat_sft.py index f93a6e6..1d14187 100644 --- a/scripts/chat_sft.py +++ b/scripts/chat_sft.py @@ -10,7 +10,7 @@ torchrun --standalone --nproc_per_node=8 -m scripts.chat_sft """ import os -os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True" import wandb import torch diff --git a/scripts/mid_train.py b/scripts/mid_train.py index dd0768c..848c7e7 100644 --- a/scripts/mid_train.py +++ b/scripts/mid_train.py @@ -11,7 +11,7 @@ torchrun --standalone --nproc_per_node=8 -m scripts.mid_train -- --device_batch_ from collections import deque import os -os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True" import time import wandb import torch From 91d76cc690ac35a253651e886a4f0b34d745a232 Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Sun, 28 Dec 2025 04:10:49 +0000 Subject: [PATCH 08/10] Replace speedup assertion with warning in batch_encode test MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Performance varies by machine and load, making hard assertions flaky. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- tests/test_rustbpe.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/test_rustbpe.py b/tests/test_rustbpe.py index 482ea20..437134f 100644 --- a/tests/test_rustbpe.py +++ b/tests/test_rustbpe.py @@ -21,6 +21,7 @@ python -m pytest tests/test_rustbpe.py -v -s import regex as re from collections import Counter, defaultdict import time +import warnings import rustbpe import tiktoken import pytest @@ -712,5 +713,6 @@ def test_batch_encode_performance(enwik8_large): print(f" Batch: {batch_time:.4f}s") print(f" Speedup: {speedup:.2f}x") - # Assert meaningful speedup (at least 1.5x on multi-core) - assert speedup > 1.5, f"Expected >1.5x speedup, got {speedup:.2f}x" + # Warn if speedup is low (can vary by machine/load) + if speedup < 1.5: + warnings.warn(f"batch_encode() speedup was only {speedup:.2f}x (expected >1.5x)") From 2f2d7ab80cd07a8b5fab9feebcb185fd3ca37339 Mon Sep 17 00:00:00 2001 From: Dipesh Babu <59379458+dipeshbabu@users.noreply.github.com> Date: Sat, 27 Dec 2025 23:27:40 -0500 Subject: [PATCH 09/10] fix: safe DDP cleanup (check initialized PG, not just env) (#256) --- nanochat/common.py | 28 ++++++++++++++++++++-------- 1 file changed, 20 insertions(+), 8 deletions(-) diff --git a/nanochat/common.py b/nanochat/common.py index ad0fb69..22559ce 100644 --- a/nanochat/common.py +++ b/nanochat/common.py @@ -113,12 +113,24 @@ def print_banner(): """ print0(banner) -def is_ddp(): - # TODO is there a proper way - return int(os.environ.get('RANK', -1)) != -1 +def is_ddp_requested() -> bool: + """ + True if launched by torchrun (env present), even before init. + Used to decide whether we *should* initialize a PG. + """ + return all(k in os.environ for k in ("RANK", "LOCAL_RANK", "WORLD_SIZE")) + +def is_ddp_initialized() -> bool: + """ + True if torch.distributed is available and the process group is initialized. + Used at cleanup to avoid destroying a non-existent PG. 
+ """ + return dist.is_available() and dist.is_initialized() def get_dist_info(): - if is_ddp(): + if is_ddp_requested(): + # We rely on torchrun's env to decide if we SHOULD init. + # (Initialization itself happens in compute init.) assert all(var in os.environ for var in ['RANK', 'LOCAL_RANK', 'WORLD_SIZE']) ddp_rank = int(os.environ['RANK']) ddp_local_rank = int(os.environ['LOCAL_RANK']) @@ -161,8 +173,8 @@ def compute_init(device_type="cuda"): # cuda|cpu|mps torch.backends.cuda.matmul.fp32_precision = "tf32" # uses tf32 instead of fp32 for matmuls # Distributed setup: Distributed Data Parallel (DDP), optional, and requires CUDA - ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info() - if ddp and device_type == "cuda": + is_ddp_requested, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info() + if is_ddp_requested and device_type == "cuda": device = torch.device("cuda", ddp_local_rank) torch.cuda.set_device(device) # make "cuda" default to this device dist.init_process_group(backend="nccl", device_id=device) @@ -173,11 +185,11 @@ def compute_init(device_type="cuda"): # cuda|cpu|mps if ddp_rank == 0: logger.info(f"Distributed world size: {ddp_world_size}") - return ddp, ddp_rank, ddp_local_rank, ddp_world_size, device + return is_ddp_requested, ddp_rank, ddp_local_rank, ddp_world_size, device def compute_cleanup(): """Companion function to compute_init, to clean things up before script exit""" - if is_ddp(): + if is_ddp_initialized(): dist.destroy_process_group() class DummyWandb: From 8f979a8bdab491c4c152ce5c87f90c2ec31d0845 Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Sun, 28 Dec 2025 04:52:13 +0000 Subject: [PATCH 10/10] fix: sample first token independently for each row in multi-sample generation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Previously, when generating multiple samples (num_samples > 1), the first token after prefill was sampled once and broadcast to all rows, causing all samples to start identically. Now the prefill logits are expanded to num_samples and sampled independently for each row. Also simplified the generation loop by moving the forward pass to the end of the loop, eliminating the first_iteration flag and if/else branching. 
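
The gist of the fix, as a minimal standalone sketch; the plain softmax/multinomial draw here is a simplified stand-in for the engine's sample_next_token helper, not the actual engine code:

    import torch

    def sample_first_tokens(prefill_logits, num_samples, temperature=1.0):
        # prefill_logits: (1, vocab_size) logits at the last prefill position
        logits = prefill_logits.expand(num_samples, -1)      # (num_samples, vocab_size) view, no copy
        probs = torch.softmax(logits / temperature, dim=-1)
        next_ids = torch.multinomial(probs, num_samples=1)   # one independent draw per row
        return next_ids[:, 0].tolist()

    # Before the fix, a single draw was broadcast to every row, so all samples began with
    # the same token; after the fix, each of the num_samples rows gets its own draw.
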
🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- nanochat/engine.py | 36 +++++-------- tests/test_engine.py | 123 ++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 135 insertions(+), 24 deletions(-) diff --git a/nanochat/engine.py b/nanochat/engine.py index dc43faf..49b10b1 100644 --- a/nanochat/engine.py +++ b/nanochat/engine.py @@ -19,7 +19,7 @@ from contextlib import contextmanager from collections import deque from nanochat.common import compute_init, autodetect_device_type from nanochat.checkpoint_manager import load_model -from contextlib import nullcontext +from contextlib import nullcontext # ----------------------------------------------------------------------------- # Calculator tool helpers @@ -107,23 +107,23 @@ class KVCache: # 1) validate the shapes assert self.kv_cache is None, "Cannot prefill a non-empty KV cache" assert other.kv_cache is not None, "Cannot prefill with a None KV cache" - + # Extract dimensions explicitly self_layers, self_kv, self_batch, self_heads, self_seq, self_head_dim = self.kv_shape other_layers, other_kv, other_batch, other_heads, other_seq, other_head_dim = other.kv_shape - + # Validate dimensions assert self_layers == other_layers, f"Layer count mismatch: {self_layers} != {other_layers}" assert self_kv == other_kv, f"K/V dimension mismatch: {self_kv} != {other_kv}" assert self_heads == other_heads, f"Head count mismatch: {self_heads} != {other_heads}" assert self_head_dim == other_head_dim, f"Head dim mismatch: {self_head_dim} != {other_head_dim}" - + # Batch size can be expanded (other can be 1, self can be larger) assert self_batch == other_batch or other_batch == 1, f"Batch size mismatch: {self_batch} vs {other_batch} (other must be 1 or equal)" - + # Sequence length: self must be longer than other assert self_seq >= other_seq, f"Sequence length mismatch: {self_seq} < {other_seq}" - + # 2) initialize the cache dtype, device = other.kv_cache.dtype, other.kv_cache.device self.kv_cache = torch.empty(self.kv_shape, dtype=dtype, device=device) @@ -223,9 +223,7 @@ class Engine: ) ids = torch.tensor([tokens], dtype=torch.long, device=device) logits = self.model.forward(ids, kv_cache=kv_cache_prefill) - logits = logits[:, -1, :] - next_ids = sample_next_token(logits, rng, temperature, top_k) # (B, 1) - sampled_tokens = next_ids[:, 0].tolist() + logits = logits[:, -1, :].expand(num_samples, -1) # (num_samples, vocab_size) # 2) Replicate the KV cache for each sample/row kv_length_hint = (len(tokens) + max_tokens) if max_tokens is not None else self.model.config.sequence_len @@ -242,7 +240,6 @@ class Engine: # 4) Main generation loop num_generated = 0 - first_iteration = True while True: # Stop condition: we've reached max tokens if max_tokens is not None and num_generated >= max_tokens: @@ -251,18 +248,9 @@ class Engine: if all(state.completed for state in row_states): break - # Get sampled tokens - either from prefill or from forward pass - if first_iteration: - # Use the tokens we already sampled from prefill - sampled_tokens = [sampled_tokens[0]] * num_samples # Broadcast first token to all rows - # TODO: we should sample a token for each row instead of broadcasting - first_iteration = False - else: - # Forward the model and get the next token for each row - logits = self.model.forward(ids, kv_cache=kv_cache_decode) # (B, T, vocab_size) - logits = logits[:, -1, :] # (B, vocab_size) at last time step - next_ids = sample_next_token(logits, rng, temperature, top_k) # (B, 1) - sampled_tokens = 
next_ids[:, 0].tolist() + # Sample the next token for each row + next_ids = sample_next_token(logits, rng, temperature, top_k) # (B, 1) + sampled_tokens = next_ids[:, 0].tolist() # Process each row: choose the next token, update state, optional tool use token_column = [] # contains the next token id along each row @@ -299,8 +287,10 @@ class Engine: # Yield the token column yield token_column, token_masks num_generated += 1 - # Prepare ids for next iteration + + # Prepare logits for next iteration ids = torch.tensor(token_column, dtype=torch.long, device=device).unsqueeze(1) + logits = self.model.forward(ids, kv_cache=kv_cache_decode)[:, -1, :] # (B, vocab_size) def generate_batch(self, tokens, num_samples=1, **kwargs): """ diff --git a/tests/test_engine.py b/tests/test_engine.py index 7403b36..683f89b 100644 --- a/tests/test_engine.py +++ b/tests/test_engine.py @@ -5,7 +5,85 @@ python -m pytest tests/test_engine.py -v """ import torch -from nanochat.engine import KVCache +from nanochat.engine import KVCache, Engine +from dataclasses import dataclass + + +# ----------------------------------------------------------------------------- +# Mock classes for testing Engine without loading a real model + +@dataclass +class MockConfig: + """Minimal config for Engine tests.""" + n_kv_head: int = 4 + n_head: int = 4 + n_embd: int = 64 + n_layer: int = 2 + sequence_len: int = 128 + + +class MockModel: + """ + Mock model that returns uniform logits over the vocab. + This ensures that with temperature > 0, different samples should + (with very high probability) produce different tokens. + """ + def __init__(self, vocab_size=262): # 256 bytes + 6 special tokens + self.vocab_size = vocab_size + self.config = MockConfig() + self._device = "cpu" + + def get_device(self): + return self._device + + def forward(self, ids, kv_cache=None): + """Return uniform logits so sampling is spread across vocab.""" + B, T = ids.shape + # Simulate what a real transformer does: insert k,v into the cache for each layer + if kv_cache is not None: + head_dim = self.config.n_embd // self.config.n_head + for layer_idx in range(self.config.n_layer): + k = torch.zeros(B, self.config.n_kv_head, T, head_dim) + v = torch.zeros(B, self.config.n_kv_head, T, head_dim) + kv_cache.insert_kv(layer_idx, k, v) + # Uniform logits -> equal probability for all tokens + logits = torch.zeros(B, T, self.vocab_size) + return logits + + +class ByteTokenizer: + """ + Simple byte-level tokenizer for testing. + Tokens 0-255 are raw bytes, 256+ are special tokens. 
+ """ + def __init__(self): + # Special tokens start at 256 + self._special_tokens = { + "<|python_start|>": 256, + "<|python_end|>": 257, + "<|output_start|>": 258, + "<|output_end|>": 259, + "<|assistant_end|>": 260, + "<|bos|>": 261, + } + self._bos = 261 + + def encode_special(self, s): + return self._special_tokens[s] + + def get_bos_token_id(self): + return self._bos + + def encode(self, s, prepend=None): + tokens = list(s.encode("utf-8")) # bytes 0-255 + if prepend is not None: + tokens = [prepend] + tokens + return tokens + + def decode(self, tokens): + # Filter out special tokens before decoding + byte_tokens = [t for t in tokens if t < 256] + return bytes(byte_tokens).decode("utf-8", errors="replace") def test_kv_cache_resize(): """ @@ -64,3 +142,46 @@ def test_kv_cache_resize(): original_v = original_cache[layer_idx, 1, :, :, token_idx, :] assert (actual_k == original_k).all(), f"Layer {layer_idx}, token {token_idx}: key doesn't match original" assert (actual_v == original_v).all(), f"Layer {layer_idx}, token {token_idx}: value doesn't match original" + + +def test_multi_sample_first_token_diversity(): + """ + Test that when generating multiple samples, each sample gets an independently + sampled first token (not a broadcast of the same token to all rows). + + Previously, the first token after prefill was sampled once and broadcast to all + rows, causing all samples to start identically. The fix expands the prefill logits + to num_samples and samples independently for each row. + + With uniform logits over 262 tokens and 16 samples, the probability that all + samples independently pick the same token is (1/262)^15 ≈ 10^-36. So if they're + all identical, it indicates tokens are being broadcast instead of independently sampled. + """ + model = MockModel(vocab_size=262) + tokenizer = ByteTokenizer() + engine = Engine(model, tokenizer) + + # Generate 16 samples with temperature=1.0 (stochastic sampling) + prompt_tokens = [261, 72, 101, 108, 108, 111] # + "Hello" + num_samples = 16 + + # Collect the first generated token from each sample + first_tokens = [] + gen = engine.generate( + prompt_tokens, + num_samples=num_samples, + max_tokens=1, # We only need the first token + temperature=1.0, + seed=42, + ) + for token_column, token_masks in gen: + first_tokens = token_column # This is the first (and only) yield + + # With uniform distribution and 16 samples, they should NOT all be identical + # If they are all identical, the bug exists (broadcasting instead of sampling) + unique_tokens = set(first_tokens) + assert len(unique_tokens) > 1, ( + f"All {num_samples} samples got the same first token ({first_tokens[0]}). " + f"With uniform logits, this is statistically impossible (~10^-36 probability) " + f"unless tokens are being broadcast instead of independently sampled." + )
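
As a quick sanity check on the probability quoted in the docstring of the test above: with uniform logits over 262 tokens, the chance that the other 15 of the 16 samples all repeat the first sample's token is (1/262)^15, which a couple of lines of Python confirm is on the order of 10^-36:

    import math
    p_all_identical = (1 / 262) ** 15
    print(f"{p_all_identical:.2e}")      # ~5.3e-37
    print(math.log10(p_all_identical))   # ~-36.3
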