nanochat/tests/test_engine.py
Andrej Karpathy 8f979a8bda fix: sample first token independently for each row in multi-sample generation
Previously, when generating multiple samples (num_samples > 1), the first
token after prefill was sampled once and broadcast to all rows, causing
all samples to start identically. Now the prefill logits are expanded to
num_samples and sampled independently for each row.

Also simplified the generation loop by moving the forward pass to the end
of the loop, eliminating the first_iteration flag and if/else branching.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2025-12-28 04:52:13 +00:00

188 lines
7.0 KiB
Python

"""
Test Engine class. Example run:
python -m pytest tests/test_engine.py -v
"""
import torch
from nanochat.engine import KVCache, Engine
from dataclasses import dataclass
# -----------------------------------------------------------------------------
# Mock classes for testing Engine without loading a real model
@dataclass()
class MockConfig:
    """Minimal config consumed by the mock model in these Engine tests.

    The defaults describe a 2-layer model with 4 query heads and 4 key/value
    heads over a 64-dim embedding, with a 128-token sequence length.
    """
    n_kv_head: int = 4  # number of key/value heads
    n_head: int = 4  # number of query heads
    n_embd: int = 64  # embedding width; MockModel uses head_dim = n_embd // n_head
    n_layer: int = 2  # number of layers the mock writes into the KV cache
    sequence_len: int = 128  # sequence length setting
class MockModel:
    """
    Stand-in model whose logits are all zero, i.e. every token is equally likely.

    Because the distribution is uniform, sampling with temperature > 0 should
    (with very high probability) pick different tokens for different rows.
    """

    def __init__(self, vocab_size=262):  # 256 bytes + 6 special tokens
        self.vocab_size = vocab_size
        self.config = MockConfig()
        self._device = "cpu"

    def get_device(self):
        """Return the device string this mock reports."""
        return self._device

    def forward(self, ids, kv_cache=None):
        """Return all-zero logits of shape (B, T, vocab_size) for `ids` of shape (B, T).

        When a KV cache is given, zero-filled keys and values of shape
        (B, n_kv_head, T, n_embd // n_head) are inserted for every layer, in the
        way a real transformer would populate the cache during its forward pass.
        """
        batch, steps = ids.shape
        if kv_cache is not None:
            cfg = self.config
            kv_shape = (batch, cfg.n_kv_head, steps, cfg.n_embd // cfg.n_head)
            for layer in range(cfg.n_layer):
                # Keys first, then values, one insert per layer.
                kv_cache.insert_kv(layer, torch.zeros(kv_shape), torch.zeros(kv_shape))
        # Zero logits -> uniform probability over the whole vocab.
        return torch.zeros(batch, steps, self.vocab_size)
class ByteTokenizer:
    """
    Simple byte-level tokenizer for testing.

    Tokens 0-255 are raw UTF-8 byte values; 256 and above are special tokens.
    """

    def __init__(self):
        # Special tokens start at 256, immediately after the 256 byte values.
        self._special_tokens = {
            "<|python_start|>": 256,
            "<|python_end|>": 257,
            "<|output_start|>": 258,
            "<|output_end|>": 259,
            "<|assistant_end|>": 260,
            "<|bos|>": 261,
        }
        # Derive BOS from the table so the two definitions cannot drift apart.
        self._bos = self._special_tokens["<|bos|>"]

    def encode_special(self, s):
        """Return the id of special token `s`; raises KeyError for unknown names."""
        return self._special_tokens[s]

    def get_bos_token_id(self):
        """Return the id of the <|bos|> token."""
        return self._bos

    def _resolve(self, token):
        """Map a special-token name to its id; pass anything else (an id) through."""
        return self.encode_special(token) if isinstance(token, str) else token

    def encode(self, s, prepend=None, append=None):
        """
        Encode string `s` as the list of its UTF-8 byte values.

        `prepend` / `append` may each be a token id or a special-token name
        (e.g. "<|bos|>"); when given, that token is added at the front / back.
        Names are resolved through `encode_special` so the result always holds
        integer ids (previously a name was inserted verbatim as a string).
        """
        tokens = list(s.encode("utf-8"))  # bytes 0-255
        if prepend is not None:
            tokens.insert(0, self._resolve(prepend))
        if append is not None:
            tokens.append(self._resolve(append))
        return tokens

    def decode(self, tokens):
        """Decode ids back to text, dropping special tokens (ids >= 256)."""
        # Filter out special tokens; invalid UTF-8 sequences become U+FFFD.
        byte_tokens = [t for t in tokens if t < 256]
        return bytes(byte_tokens).decode("utf-8", errors="replace")
def test_kv_cache_resize():
    """
    Regression test for KV cache resizing; background:
    https://github.com/karpathy/nanochat/pull/186

    Fill the cache to its initial seq_len, insert one extra token to force a
    resize, then check that the cache grew and that every entry written before
    the resize is still intact (both against the known fill values and against
    a snapshot taken before the resize).
    """
    batch_size, num_heads, seq_len, head_dim, num_layers = 2, 3, 4, 5, 6
    kv_cache = KVCache(
        batch_size=batch_size,
        num_heads=num_heads,
        seq_len=seq_len,
        head_dim=head_dim,
        num_layers=num_layers,
    )
    slot_shape = (batch_size, num_heads, 1, head_dim)

    def insert_token(token_idx):
        # Write one position to every layer: keys filled with token_idx and
        # values with token_idx * 100, so each position has a recognizable signature.
        for layer_idx in range(num_layers):
            kv_cache.insert_kv(
                layer_idx,
                torch.full(slot_shape, fill_value=float(token_idx), dtype=torch.float32),
                torch.full(slot_shape, fill_value=float(token_idx * 100), dtype=torch.float32),
            )

    # Exactly fill the initial capacity.
    for token_idx in range(seq_len):
        insert_token(token_idx)

    # Snapshot before the resize. Entries are read below as
    # [layer, 0 (keys) / 1 (values), :, :, position, :].
    original_cache = kv_cache.kv_cache.clone()
    original_seq_len = original_cache.shape[4]

    # One token past capacity should make the cache grow.
    insert_token(seq_len)
    new_seq_len = kv_cache.kv_cache.shape[4]
    assert new_seq_len > original_seq_len, f"Cache did not resize: original seq_len={original_seq_len}, new seq_len={new_seq_len}"

    resized = kv_cache.kv_cache
    for layer_idx in range(num_layers):
        resized_layer = resized[layer_idx]
        snapshot_layer = original_cache[layer_idx]
        for token_idx in range(seq_len):
            expected_k, expected_v = float(token_idx), float(token_idx * 100)
            actual_k = resized_layer[0, :, :, token_idx, :]
            actual_v = resized_layer[1, :, :, token_idx, :]
            # The resized cache must still hold the values that were written...
            assert (actual_k == expected_k).all(), f"Layer {layer_idx}, token {token_idx}: key corrupted, expected {expected_k}"
            assert (actual_v == expected_v).all(), f"Layer {layer_idx}, token {token_idx}: value corrupted, expected {expected_v}"
            # ...and agree with the pre-resize snapshot.
            original_k = snapshot_layer[0, :, :, token_idx, :]
            original_v = snapshot_layer[1, :, :, token_idx, :]
            assert (actual_k == original_k).all(), f"Layer {layer_idx}, token {token_idx}: key doesn't match original"
            assert (actual_v == original_v).all(), f"Layer {layer_idx}, token {token_idx}: value doesn't match original"
def test_multi_sample_first_token_diversity():
    """
    Test that when generating multiple samples, each sample gets an independently
    sampled first token (not a broadcast of the same token to all rows).

    Previously, the first token after prefill was sampled once and broadcast to all
    rows, causing all samples to start identically. The fix expands the prefill logits
    to num_samples and samples independently for each row.

    With uniform logits over 262 tokens and 16 samples, the probability that all
    samples independently pick the same token is (1/262)^15 ≈ 10^-36. So if they're
    all identical, it indicates tokens are being broadcast instead of independently sampled.
    """
    model = MockModel(vocab_size=262)
    tokenizer = ByteTokenizer()
    engine = Engine(model, tokenizer)
    prompt_tokens = [261, 72, 101, 108, 108, 111]  # <bos> + "Hello"
    num_samples = 16
    # Generate with temperature=1.0 (stochastic sampling); only the first token matters.
    gen = engine.generate(
        prompt_tokens,
        num_samples=num_samples,
        max_tokens=1,  # We only need the first token
        temperature=1.0,
        seed=42,
    )
    # Drain the generator instead of keeping whatever was yielded last: with
    # max_tokens=1 there must be exactly one column, holding one token per sample.
    # This also keeps the final assertion message from raising IndexError when
    # nothing was generated.
    columns = [token_column for token_column, _token_masks in gen]
    assert len(columns) == 1, (
        f"Expected exactly 1 generated column with max_tokens=1, got {len(columns)}"
    )
    first_tokens = columns[0]
    assert len(first_tokens) == num_samples, (
        f"Expected {num_samples} first tokens (one per sample), got {len(first_tokens)}"
    )
    # With a uniform distribution and 16 samples, they should NOT all be identical.
    # If they are, the bug exists (broadcasting instead of sampling).
    unique_tokens = set(first_tokens)
    assert len(unique_tokens) > 1, (
        f"All {num_samples} samples got the same first token ({first_tokens[0]}). "
        f"With uniform logits, this is statistically impossible (~10^-36 probability) "
        f"unless tokens are being broadcast instead of independently sampled."
    )