""" Test Engine class. Example run: python -m pytest tests/test_engine.py -v """ import torch from nanochat.engine import KVCache, Engine from dataclasses import dataclass # ----------------------------------------------------------------------------- # Mock classes for testing Engine without loading a real model @dataclass class MockConfig: """Minimal config for Engine tests.""" n_kv_head: int = 4 n_head: int = 4 n_embd: int = 64 n_layer: int = 2 sequence_len: int = 128 class MockModel: """ Mock model that returns uniform logits over the vocab. This ensures that with temperature > 0, different samples should (with very high probability) produce different tokens. """ def __init__(self, vocab_size=262): # 256 bytes + 6 special tokens self.vocab_size = vocab_size self.config = MockConfig() self._device = "cpu" def get_device(self): return self._device def forward(self, ids, kv_cache=None): """Return uniform logits so sampling is spread across vocab.""" B, T = ids.shape # Simulate what a real transformer does: insert k,v into the cache for each layer if kv_cache is not None: head_dim = self.config.n_embd // self.config.n_head for layer_idx in range(self.config.n_layer): k = torch.zeros(B, self.config.n_kv_head, T, head_dim) v = torch.zeros(B, self.config.n_kv_head, T, head_dim) kv_cache.insert_kv(layer_idx, k, v) # Uniform logits -> equal probability for all tokens logits = torch.zeros(B, T, self.vocab_size) return logits class ByteTokenizer: """ Simple byte-level tokenizer for testing. Tokens 0-255 are raw bytes, 256+ are special tokens. """ def __init__(self): # Special tokens start at 256 self._special_tokens = { "<|python_start|>": 256, "<|python_end|>": 257, "<|output_start|>": 258, "<|output_end|>": 259, "<|assistant_end|>": 260, "<|bos|>": 261, } self._bos = 261 def encode_special(self, s): return self._special_tokens[s] def get_bos_token_id(self): return self._bos def encode(self, s, prepend=None): tokens = list(s.encode("utf-8")) # bytes 0-255 if prepend is not None: tokens = [prepend] + tokens return tokens def decode(self, tokens): # Filter out special tokens before decoding byte_tokens = [t for t in tokens if t < 256] return bytes(byte_tokens).decode("utf-8", errors="replace") def test_kv_cache_resize(): """ The KV cache was not resized correctly, more information here: https://github.com/karpathy/nanochat/pull/186 This test reproduces the issue and will be merged alongside the fix. """ batch_size = 2 num_heads = 3 seq_len = 4 head_dim = 5 num_layers = 6 kv_cache = KVCache( batch_size=batch_size, num_heads=num_heads, seq_len=seq_len, head_dim=head_dim, num_layers=num_layers ) # Insert a single token with a distinct fill value to all layers def insert_token(token_idx): for layer_idx in range(num_layers): k = torch.full((batch_size, num_heads, 1, head_dim), fill_value=float(token_idx), dtype=torch.float32) v = torch.full((batch_size, num_heads, 1, head_dim), fill_value=float(token_idx * 100), dtype=torch.float32) kv_cache.insert_kv(layer_idx, k, v) # Insert 4 tokens (fills the initial seq_len=4) for i in range(4): insert_token(i) # Record the original state of the cache original_cache = kv_cache.kv_cache.clone() original_seq_len = original_cache.shape[4] # Insert the 5th token, which will trigger a resize insert_token(4) # Verify that the cache actually resized new_seq_len = kv_cache.kv_cache.shape[4] assert new_seq_len > original_seq_len, f"Cache did not resize: original seq_len={original_seq_len}, new seq_len={new_seq_len}" # Verify that the original 4 tokens are still intact after resize for layer_idx in range(num_layers): for token_idx in range(4): # Check that resized cache matches expected values expected_k = float(token_idx) expected_v = float(token_idx * 100) actual_k = kv_cache.kv_cache[layer_idx, 0, :, :, token_idx, :] actual_v = kv_cache.kv_cache[layer_idx, 1, :, :, token_idx, :] assert (actual_k == expected_k).all(), f"Layer {layer_idx}, token {token_idx}: key corrupted, expected {expected_k}" assert (actual_v == expected_v).all(), f"Layer {layer_idx}, token {token_idx}: value corrupted, expected {expected_v}" # And that the original cache matches resized cache original_k = original_cache[layer_idx, 0, :, :, token_idx, :] original_v = original_cache[layer_idx, 1, :, :, token_idx, :] assert (actual_k == original_k).all(), f"Layer {layer_idx}, token {token_idx}: key doesn't match original" assert (actual_v == original_v).all(), f"Layer {layer_idx}, token {token_idx}: value doesn't match original" def test_multi_sample_first_token_diversity(): """ Test that when generating multiple samples, each sample gets an independently sampled first token (not a broadcast of the same token to all rows). Previously, the first token after prefill was sampled once and broadcast to all rows, causing all samples to start identically. The fix expands the prefill logits to num_samples and samples independently for each row. With uniform logits over 262 tokens and 16 samples, the probability that all samples independently pick the same token is (1/262)^15 ≈ 10^-36. So if they're all identical, it indicates tokens are being broadcast instead of independently sampled. """ model = MockModel(vocab_size=262) tokenizer = ByteTokenizer() engine = Engine(model, tokenizer) # Generate 16 samples with temperature=1.0 (stochastic sampling) prompt_tokens = [261, 72, 101, 108, 108, 111] # + "Hello" num_samples = 16 # Collect the first generated token from each sample first_tokens = [] gen = engine.generate( prompt_tokens, num_samples=num_samples, max_tokens=1, # We only need the first token temperature=1.0, seed=42, ) for token_column, token_masks in gen: first_tokens = token_column # This is the first (and only) yield # With uniform distribution and 16 samples, they should NOT all be identical # If they are all identical, the bug exists (broadcasting instead of sampling) unique_tokens = set(first_tokens) assert len(unique_tokens) > 1, ( f"All {num_samples} samples got the same first token ({first_tokens[0]}). " f"With uniform logits, this is statistically impossible (~10^-36 probability) " f"unless tokens are being broadcast instead of independently sampled." )