diff --git a/nanochat/engine.py b/nanochat/engine.py
index dc43faf..49b10b1 100644
--- a/nanochat/engine.py
+++ b/nanochat/engine.py
@@ -19,7 +19,7 @@
 from contextlib import contextmanager
 from collections import deque
 from nanochat.common import compute_init, autodetect_device_type
 from nanochat.checkpoint_manager import load_model
-from contextlib import nullcontext
+from contextlib import nullcontext
 # -----------------------------------------------------------------------------
 # Calculator tool helpers
@@ -107,23 +107,23 @@ class KVCache:
         # 1) validate the shapes
         assert self.kv_cache is None, "Cannot prefill a non-empty KV cache"
         assert other.kv_cache is not None, "Cannot prefill with a None KV cache"
-
+
         # Extract dimensions explicitly
         self_layers, self_kv, self_batch, self_heads, self_seq, self_head_dim = self.kv_shape
         other_layers, other_kv, other_batch, other_heads, other_seq, other_head_dim = other.kv_shape
-
+
         # Validate dimensions
         assert self_layers == other_layers, f"Layer count mismatch: {self_layers} != {other_layers}"
         assert self_kv == other_kv, f"K/V dimension mismatch: {self_kv} != {other_kv}"
         assert self_heads == other_heads, f"Head count mismatch: {self_heads} != {other_heads}"
         assert self_head_dim == other_head_dim, f"Head dim mismatch: {self_head_dim} != {other_head_dim}"
-
+
         # Batch size can be expanded (other can be 1, self can be larger)
         assert self_batch == other_batch or other_batch == 1, f"Batch size mismatch: {self_batch} vs {other_batch} (other must be 1 or equal)"
-
+
         # Sequence length: self must be longer than other
         assert self_seq >= other_seq, f"Sequence length mismatch: {self_seq} < {other_seq}"
-
+
         # 2) initialize the cache
         dtype, device = other.kv_cache.dtype, other.kv_cache.device
         self.kv_cache = torch.empty(self.kv_shape, dtype=dtype, device=device)
@@ -223,9 +223,7 @@ class Engine:
         )
         ids = torch.tensor([tokens], dtype=torch.long, device=device)
         logits = self.model.forward(ids, kv_cache=kv_cache_prefill)
-        logits = logits[:, -1, :]
-        next_ids = sample_next_token(logits, rng, temperature, top_k) # (B, 1)
-        sampled_tokens = next_ids[:, 0].tolist()
+        logits = logits[:, -1, :].expand(num_samples, -1) # (num_samples, vocab_size)
 
         # 2) Replicate the KV cache for each sample/row
         kv_length_hint = (len(tokens) + max_tokens) if max_tokens is not None else self.model.config.sequence_len
@@ -242,7 +240,6 @@
         # 4) Main generation loop
         num_generated = 0
-        first_iteration = True
         while True:
             # Stop condition: we've reached max tokens
             if max_tokens is not None and num_generated >= max_tokens:
                 break
@@ -251,18 +248,9 @@
             if all(state.completed for state in row_states):
                 break
 
-            # Get sampled tokens - either from prefill or from forward pass
-            if first_iteration:
-                # Use the tokens we already sampled from prefill
-                sampled_tokens = [sampled_tokens[0]] * num_samples # Broadcast first token to all rows
-                # TODO: we should sample a token for each row instead of broadcasting
-                first_iteration = False
-            else:
-                # Forward the model and get the next token for each row
-                logits = self.model.forward(ids, kv_cache=kv_cache_decode) # (B, T, vocab_size)
-                logits = logits[:, -1, :] # (B, vocab_size) at last time step
-                next_ids = sample_next_token(logits, rng, temperature, top_k) # (B, 1)
-                sampled_tokens = next_ids[:, 0].tolist()
+            # Sample the next token for each row
+            next_ids = sample_next_token(logits, rng, temperature, top_k) # (B, 1)
+            sampled_tokens = next_ids[:, 0].tolist()
 
             # Process each row: choose the next token, update state, optional tool use
             token_column = [] # contains the next token id along each row
@@ -299,8 +287,10 @@
             # Yield the token column
             yield token_column, token_masks
            num_generated += 1
-            # Prepare ids for next iteration
+
+            # Prepare logits for next iteration
             ids = torch.tensor(token_column, dtype=torch.long, device=device).unsqueeze(1)
+            logits = self.model.forward(ids, kv_cache=kv_cache_decode)[:, -1, :] # (B, vocab_size)
 
     def generate_batch(self, tokens, num_samples=1, **kwargs):
         """
diff --git a/tests/test_engine.py b/tests/test_engine.py
index 7403b36..683f89b 100644
--- a/tests/test_engine.py
+++ b/tests/test_engine.py
@@ -5,7 +5,85 @@
 python -m pytest tests/test_engine.py -v
 """
 import torch
-from nanochat.engine import KVCache
+from nanochat.engine import KVCache, Engine
+from dataclasses import dataclass
+
+
+# -----------------------------------------------------------------------------
+# Mock classes for testing Engine without loading a real model
+
+@dataclass
+class MockConfig:
+    """Minimal config for Engine tests."""
+    n_kv_head: int = 4
+    n_head: int = 4
+    n_embd: int = 64
+    n_layer: int = 2
+    sequence_len: int = 128
+
+
+class MockModel:
+    """
+    Mock model that returns uniform logits over the vocab.
+    This ensures that with temperature > 0, different samples should
+    (with very high probability) produce different tokens.
+    """
+    def __init__(self, vocab_size=262): # 256 bytes + 6 special tokens
+        self.vocab_size = vocab_size
+        self.config = MockConfig()
+        self._device = "cpu"
+
+    def get_device(self):
+        return self._device
+
+    def forward(self, ids, kv_cache=None):
+        """Return uniform logits so sampling is spread across vocab."""
+        B, T = ids.shape
+        # Simulate what a real transformer does: insert k,v into the cache for each layer
+        if kv_cache is not None:
+            head_dim = self.config.n_embd // self.config.n_head
+            for layer_idx in range(self.config.n_layer):
+                k = torch.zeros(B, self.config.n_kv_head, T, head_dim)
+                v = torch.zeros(B, self.config.n_kv_head, T, head_dim)
+                kv_cache.insert_kv(layer_idx, k, v)
+        # Uniform logits -> equal probability for all tokens
+        logits = torch.zeros(B, T, self.vocab_size)
+        return logits
+
+
+class ByteTokenizer:
+    """
+    Simple byte-level tokenizer for testing.
+    Tokens 0-255 are raw bytes, 256+ are special tokens.
+    """
+    def __init__(self):
+        # Special tokens start at 256
+        self._special_tokens = {
+            "<|python_start|>": 256,
+            "<|python_end|>": 257,
+            "<|output_start|>": 258,
+            "<|output_end|>": 259,
+            "<|assistant_end|>": 260,
+            "<|bos|>": 261,
+        }
+        self._bos = 261
+
+    def encode_special(self, s):
+        return self._special_tokens[s]
+
+    def get_bos_token_id(self):
+        return self._bos
+
+    def encode(self, s, prepend=None):
+        tokens = list(s.encode("utf-8")) # bytes 0-255
+        if prepend is not None:
+            tokens = [prepend] + tokens
+        return tokens
+
+    def decode(self, tokens):
+        # Filter out special tokens before decoding
+        byte_tokens = [t for t in tokens if t < 256]
+        return bytes(byte_tokens).decode("utf-8", errors="replace")
 
 def test_kv_cache_resize():
     """
@@ -64,3 +142,46 @@
             original_v = original_cache[layer_idx, 1, :, :, token_idx, :]
             assert (actual_k == original_k).all(), f"Layer {layer_idx}, token {token_idx}: key doesn't match original"
             assert (actual_v == original_v).all(), f"Layer {layer_idx}, token {token_idx}: value doesn't match original"
+
+
+def test_multi_sample_first_token_diversity():
+    """
+    Test that when generating multiple samples, each sample gets an independently
+    sampled first token (not a broadcast of the same token to all rows).
+
+    Previously, the first token after prefill was sampled once and broadcast to all
+    rows, causing all samples to start identically. The fix expands the prefill logits
+    to num_samples and samples independently for each row.
+
+    With uniform logits over 262 tokens and 16 samples, the probability that all
+    samples independently pick the same token is (1/262)^15 ≈ 10^-36. So if they're
+    all identical, it indicates tokens are being broadcast instead of independently sampled.
+    """
+    model = MockModel(vocab_size=262)
+    tokenizer = ByteTokenizer()
+    engine = Engine(model, tokenizer)
+
+    # Generate 16 samples with temperature=1.0 (stochastic sampling)
+    prompt_tokens = [261, 72, 101, 108, 108, 111] # <|bos|> + "Hello"
+    num_samples = 16
+
+    # Collect the first generated token from each sample
+    first_tokens = []
+    gen = engine.generate(
+        prompt_tokens,
+        num_samples=num_samples,
+        max_tokens=1, # We only need the first token
+        temperature=1.0,
+        seed=42,
+    )
+    for token_column, token_masks in gen:
+        first_tokens = token_column # This is the first (and only) yield
+
+    # With uniform distribution and 16 samples, they should NOT all be identical
+    # If they are all identical, the bug exists (broadcasting instead of sampling)
+    unique_tokens = set(first_tokens)
+    assert len(unique_tokens) > 1, (
+        f"All {num_samples} samples got the same first token ({first_tokens[0]}). "
+        f"With uniform logits, this is statistically impossible (~10^-36 probability) "
+        f"unless tokens are being broadcast instead of independently sampled."
+    )