"""
Tests for the Engine class. Example run:

python -m pytest tests/test_engine.py -v
"""

from dataclasses import dataclass

import torch

from nanochat.engine import KVCache, Engine


# -----------------------------------------------------------------------------
# Mock classes for testing Engine without loading a real model

@dataclass
class MockConfig:
    """Smallest config object the Engine tests need."""
    n_kv_head: int = 4
    n_head: int = 4
    n_embd: int = 64
    n_layer: int = 2
    sequence_len: int = 128


class MockModel:
    """
    Stand-in model whose logits are uniform over the vocabulary.

    Because every token is equally likely, sampling with temperature > 0
    should (with very high probability) pick different tokens across samples.
    """

    def __init__(self, vocab_size=262):
        # Default vocab: 256 raw bytes + 6 special tokens.
        self.vocab_size = vocab_size
        self.config = MockConfig()
        self._device = "cpu"

    def get_device(self):
        """Return the device string this mock reports ("cpu")."""
        return self._device

    def forward(self, ids, kv_cache=None):
        """Return all-zero (i.e. uniform) logits of shape (B, T, vocab_size)."""
        batch, steps = ids.shape
        # With FA3, flash_attn_with_kvcache updates the cache in-place and the
        # model advances the position; the mock only does the advancing part.
        if kv_cache is not None:
            kv_cache.advance(steps)
        # Identical logits everywhere -> equal probability for every token.
        return torch.zeros(batch, steps, self.vocab_size)


class ByteTokenizer:
    """
    Minimal byte-level tokenizer for tests.

    Ids 0-255 are raw bytes; ids from 256 upward are special tokens.
    """

    def __init__(self):
        # Special tokens occupy consecutive ids starting at 256.
        special_names = (
            "<|python_start|>",
            "<|python_end|>",
            "<|output_start|>",
            "<|output_end|>",
            "<|assistant_end|>",
            "<|bos|>",
        )
        self._special_tokens = {
            name: token_id for token_id, name in enumerate(special_names, start=256)
        }
        self._bos = 261

    def encode_special(self, s):
        """Look up the id of special token string `s` (KeyError if unknown)."""
        return self._special_tokens[s]

    def get_bos_token_id(self):
        """Return the id of the beginning-of-sequence token."""
        return self._bos

    def encode(self, s, prepend=None):
        """Encode `s` as its UTF-8 byte values, optionally led by `prepend`."""
        byte_ids = list(s.encode("utf-8"))  # each value is in 0-255
        if prepend is None:
            return byte_ids
        return [prepend, *byte_ids]

    def decode(self, tokens):
        """Decode byte tokens to text, silently dropping special tokens."""
        raw = bytes(tok for tok in tokens if tok < 256)
        return raw.decode("utf-8", errors="replace")


def test_kv_cache_basic():
    """Exercise the basic KVCache API (FA3 layout): shapes, advance, reset."""
    dims = {
        "batch_size": 2,
        "num_heads": 3,
        "seq_len": 64,
        "head_dim": 5,
        "num_layers": 6,
    }
    cache = KVCache(device="cpu", **dims)

    per_layer_shape = (
        dims["batch_size"],
        dims["seq_len"],
        dims["num_heads"],
        dims["head_dim"],
    )
    full_shape = (dims["num_layers"], *per_layer_shape)

    # Fresh cache: position zero and fully allocated K/V buffers.
    assert cache.get_pos() == 0
    assert cache.k_cache.shape == full_shape
    assert cache.v_cache.shape == full_shape

    # advance() accumulates the position.
    for step, expected_pos in ((10, 10), (5, 15)):
        cache.advance(step)
        assert cache.get_pos() == expected_pos

    # reset() rewinds to the start.
    cache.reset()
    assert cache.get_pos() == 0

    # get_layer_cache() hands back per-layer K/V with the layer axis removed.
    layer_k, layer_v = cache.get_layer_cache(0)
    assert layer_k.shape == per_layer_shape
    assert layer_v.shape == per_layer_shape


def test_kv_cache_prefill():
    """KVCache.prefill() must copy both the position and the cached data."""
    shared = dict(batch_size=1, num_heads=4, head_dim=8, num_layers=2, device="cpu")
    prefix_len = 16

    # Source cache: mark the first `prefix_len` slots, then advance past them.
    src_cache = KVCache(seq_len=32, **shared)
    src_cache.k_cache[0, 0, :prefix_len, :, :] = 1.0
    src_cache.v_cache[0, 0, :prefix_len, :, :] = 2.0
    src_cache.advance(prefix_len)

    # Destination cache has a longer sequence capacity than the source.
    dst_cache = KVCache(seq_len=64, **shared)
    dst_cache.prefill(src_cache)

    # Position carried over.
    assert dst_cache.get_pos() == prefix_len

    # Data carried over.
    assert (dst_cache.k_cache[0, 0, :prefix_len, :, :] == 1.0).all()
    assert (dst_cache.v_cache[0, 0, :prefix_len, :, :] == 2.0).all()


def test_multi_sample_first_token_diversity():
    """
    Each sample must get its own independently sampled first token.

    Previously the first token after prefill was sampled once and broadcast to
    every row, so all samples started identically. The fix expands the prefill
    logits to num_samples and samples each row independently.

    With uniform logits over 262 tokens and 16 samples, the chance that all
    samples independently land on the same token is (1/262)^15 ≈ 10^-36, so
    identical first tokens indicate broadcasting rather than sampling.
    """
    engine = Engine(MockModel(vocab_size=262), ByteTokenizer())

    prompt_tokens = [261, 72, 101, 108, 108, 111]  # <bos> + "Hello"
    num_samples = 16

    # Stochastic sampling (temperature=1.0); one generated token is enough.
    stream = engine.generate(
        prompt_tokens,
        num_samples=num_samples,
        max_tokens=1,
        temperature=1.0,
        seed=42,
    )

    # Keep the last yielded column; with max_tokens=1 that is the only one.
    first_tokens = []
    for column, _masks in stream:
        first_tokens = column

    # All-identical tokens would mean the broadcast bug is present.
    unique_tokens = set(first_tokens)
    assert len(unique_tokens) > 1, (
        f"All {num_samples} samples got the same first token ({first_tokens[0]}). "
        f"With uniform logits, this is statistically impossible (~10^-36 probability) "
        f"unless tokens are being broadcast instead of independently sampled."
    )


def test_seed_reproducibility():
    """Repeated runs with one seed must return identical output."""
    engine = Engine(MockModel(), ByteTokenizer())
    prompt = [261, 72, 101, 108, 108, 111]  # <bos> + "Hello"

    for seed in (1, 42, 123, 999):
        runs = []
        for _ in range(3):
            sequences, _masks = engine.generate_batch(prompt, max_tokens=5, seed=seed)
            runs.append(sequences)
        first, second, third = runs
        assert first == second == third, "Same seed must produce identical output for the same prompt."


def test_temperature_zero_determinism():
    """At temperature=0 the output must not depend on the seed."""
    engine = Engine(MockModel(), ByteTokenizer())
    prompt = [261, 72, 101, 108, 108, 111]

    runs = []
    for seed in (1, 42, 123):
        sequences, _masks = engine.generate_batch(
            prompt, temperature=0.0, max_tokens=5, seed=seed
        )
        runs.append(sequences)

    first, second, third = runs
    assert first == second == third, "Temperature=0 must result in the same output for the same prompt regardless of seed."


def test_max_tokens_respected():
    """Generation must never exceed the max_tokens budget."""
    engine = Engine(MockModel(), ByteTokenizer())
    prompt = [261, 72, 101, 108, 108, 111]
    prompt_len = len(prompt)

    for max_tokens in (1, 4, 16, 64):
        results, _masks = engine.generate_batch(prompt, max_tokens=max_tokens)
        # The returned sequence includes the prompt, so subtract its length.
        num_generated_tokens = len(results[0]) - prompt_len
        assert num_generated_tokens <= max_tokens, f"Generated {num_generated_tokens} tokens, expected max_tokens={max_tokens} or less."
def test_num_samples_count():
    """Asking for N samples must yield exactly N sequences."""
    engine = Engine(MockModel(), ByteTokenizer())
    prompt = [261, 72, 101, 108, 108, 111]

    for num_samples in (1, 4, 16, 64):
        results, _masks = engine.generate_batch(
            prompt, num_samples=num_samples, max_tokens=3
        )
        assert len(results) == num_samples, f"Expected {num_samples} sequences from {num_samples} samples, got {len(results)}"


def test_different_seeds_introduce_variation_when_temperature_nonzero():
    """Different seeds should give different samples when temperature > 0."""
    engine = Engine(MockModel(), ByteTokenizer())
    prompt = [261, 72, 101, 108, 108, 111]  # <bos> + "Hello"
    seeds = (1, 42, 123, 999, 1000, 1001, 1002, 1003, 1004, 1005)

    def first_sequence(seed):
        # Generate one batch for `seed` and return its first row as a hashable tuple.
        results, _masks = engine.generate_batch(
            prompt,
            temperature=1.0,
            max_tokens=5,
            seed=seed,
        )
        return tuple(results[0])

    outputs = {first_sequence(seed) for seed in seeds}

    # Sanity check: sampling actually introduces variation.
    assert len(outputs) > 1, "All seeds produced the same output which is statistically highly improbable."