"""
Tests for the Engine class. Example run:
python -m pytest tests/test_engine.py -v
"""
import torch
from nanochat.engine import KVCache, Engine
from dataclasses import dataclass

# -----------------------------------------------------------------------------
# Mock classes for testing Engine without loading a real model
@dataclass
class MockConfig:
"""Minimal config for Engine tests."""
n_kv_head: int = 4
n_head: int = 4
n_embd: int = 64
n_layer: int = 2
sequence_len: int = 128

class MockModel:
"""
Mock model that returns uniform logits over the vocab.
This ensures that with temperature > 0, different samples should
(with very high probability) produce different tokens.
"""
def __init__(self, vocab_size=262): # 256 bytes + 6 special tokens
self.vocab_size = vocab_size
self.config = MockConfig()
self._device = "cpu"
def get_device(self):
return self._device
def forward(self, ids, kv_cache=None):
"""Return uniform logits so sampling is spread across vocab."""
B, T = ids.shape
        # With FA3, flash_attn_with_kvcache updates the cache in-place; the mock only advances the position
if kv_cache is not None:
kv_cache.advance(T)
# Uniform logits -> equal probability for all tokens
logits = torch.zeros(B, T, self.vocab_size)
return logits
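
# Illustrative sanity check (added sketch, not from the original suite): verifies the
# mock's contract described above -- all-zero (uniform) logits of shape (B, T, vocab_size)
# and a KV cache position that advances by T on each forward pass.
def test_mock_model_forward_contract():
    model = MockModel(vocab_size=262)
    head_dim = model.config.n_embd // model.config.n_head  # 64 // 4 = 16
    kv_cache = KVCache(
        batch_size=1, num_heads=model.config.n_kv_head, seq_len=8,
        head_dim=head_dim, num_layers=model.config.n_layer, device="cpu",
    )
    ids = torch.tensor([[261, 72, 101]])  # <|bos|> + "He"
    logits = model.forward(ids, kv_cache=kv_cache)
    assert logits.shape == (1, 3, model.vocab_size)
    assert (logits == 0).all()  # uniform logits: every token equally likely
    assert kv_cache.get_pos() == 3  # forward advanced the cache by T=3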

class ByteTokenizer:
"""
Simple byte-level tokenizer for testing.
Tokens 0-255 are raw bytes, 256+ are special tokens.
"""
def __init__(self):
# Special tokens start at 256
self._special_tokens = {
"<|python_start|>": 256,
"<|python_end|>": 257,
"<|output_start|>": 258,
"<|output_end|>": 259,
"<|assistant_end|>": 260,
"<|bos|>": 261,
}
self._bos = 261
def encode_special(self, s):
return self._special_tokens[s]
def get_bos_token_id(self):
return self._bos
def encode(self, s, prepend=None):
tokens = list(s.encode("utf-8")) # bytes 0-255
if prepend is not None:
tokens = [prepend] + tokens
return tokens
def decode(self, tokens):
# Filter out special tokens before decoding
byte_tokens = [t for t in tokens if t < 256]
return bytes(byte_tokens).decode("utf-8", errors="replace")
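
# Illustrative example (added sketch): the byte-level tokenizer round-trips plain text,
# and decode() silently drops any special tokens mixed into the token stream.
def test_byte_tokenizer_roundtrip():
    tokenizer = ByteTokenizer()
    text = "Hello"
    tokens = tokenizer.encode(text, prepend=tokenizer.get_bos_token_id())
    assert tokens == [261, 72, 101, 108, 108, 111]  # <|bos|> + "Hello" as UTF-8 bytes
    assert tokenizer.decode(tokens) == text  # <|bos|> (id 261) is filtered out
    assert tokenizer.encode_special("<|assistant_end|>") == 260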

def test_kv_cache_basic():
"""Test basic KVCache functionality for FA3."""
batch_size = 2
num_heads = 3
seq_len = 64
head_dim = 5
num_layers = 6
kv_cache = KVCache(
batch_size=batch_size,
num_heads=num_heads,
seq_len=seq_len,
head_dim=head_dim,
num_layers=num_layers,
device="cpu",
)
# Check initial state
assert kv_cache.get_pos() == 0
assert kv_cache.k_cache.shape == (num_layers, batch_size, seq_len, num_heads, head_dim)
assert kv_cache.v_cache.shape == (num_layers, batch_size, seq_len, num_heads, head_dim)
# Test advance
kv_cache.advance(10)
assert kv_cache.get_pos() == 10
kv_cache.advance(5)
assert kv_cache.get_pos() == 15
# Test reset
kv_cache.reset()
assert kv_cache.get_pos() == 0
# Test get_layer_cache returns correct views
k_layer0, v_layer0 = kv_cache.get_layer_cache(0)
assert k_layer0.shape == (batch_size, seq_len, num_heads, head_dim)
assert v_layer0.shape == (batch_size, seq_len, num_heads, head_dim)
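
# Follow-up sketch (assumes, as the comment above suggests, that get_layer_cache()
# returns views into the underlying storage rather than copies): writes through the
# per-layer view should be visible in the full k_cache tensor.
def test_kv_cache_layer_view_shares_storage():
    kv_cache = KVCache(
        batch_size=1, num_heads=2, seq_len=4, head_dim=3, num_layers=2, device="cpu",
    )
    k_layer1, _ = kv_cache.get_layer_cache(1)
    k_layer1.fill_(7.0)  # write through the view
    assert (kv_cache.k_cache[1] == 7.0).all()  # reflected in the full cache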

def test_kv_cache_prefill():
"""Test KVCache.prefill() copies data correctly."""
batch_size = 1
num_heads = 4
head_dim = 8
num_layers = 2
# Create source cache and advance it
src_cache = KVCache(
batch_size=batch_size, num_heads=num_heads, seq_len=32,
head_dim=head_dim, num_layers=num_layers, device="cpu",
)
# Write some data to source cache
src_cache.k_cache[0, 0, :16, :, :] = 1.0
src_cache.v_cache[0, 0, :16, :, :] = 2.0
src_cache.advance(16)
# Create destination cache with larger seq_len
dst_cache = KVCache(
batch_size=batch_size, num_heads=num_heads, seq_len=64,
head_dim=head_dim, num_layers=num_layers, device="cpu",
)
# Prefill
dst_cache.prefill(src_cache)
# Check position was copied
assert dst_cache.get_pos() == 16
# Check data was copied
assert (dst_cache.k_cache[0, 0, :16, :, :] == 1.0).all()
assert (dst_cache.v_cache[0, 0, :16, :, :] == 2.0).all()
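
# Follow-up sketch: after prefill() the destination cache keeps advancing from the
# copied position, which is how decoding would continue past the prefilled prompt.
def test_kv_cache_prefill_then_advance():
    src = KVCache(batch_size=1, num_heads=2, seq_len=8, head_dim=4, num_layers=1, device="cpu")
    src.advance(5)
    dst = KVCache(batch_size=1, num_heads=2, seq_len=16, head_dim=4, num_layers=1, device="cpu")
    dst.prefill(src)
    dst.advance(3)  # e.g. three decode steps after the prompt
    assert dst.get_pos() == 8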

def test_multi_sample_first_token_diversity():
"""
Test that when generating multiple samples, each sample gets an independently
sampled first token (not a broadcast of the same token to all rows).
Previously, the first token after prefill was sampled once and broadcast to all
rows, causing all samples to start identically. The fix expands the prefill logits
to num_samples and samples independently for each row.
With uniform logits over 262 tokens and 16 samples, the probability that all
samples independently pick the same token is (1/262)^15 ≈ 10^-36. So if they're
all identical, it indicates tokens are being broadcast instead of independently sampled.
"""
model = MockModel(vocab_size=262)
tokenizer = ByteTokenizer()
engine = Engine(model, tokenizer)
# Generate 16 samples with temperature=1.0 (stochastic sampling)
    prompt_tokens = [261, 72, 101, 108, 108, 111]  # <|bos|> + "Hello"
num_samples = 16
# Collect the first generated token from each sample
first_tokens = []
gen = engine.generate(
prompt_tokens,
num_samples=num_samples,
max_tokens=1, # We only need the first token
temperature=1.0,
seed=42,
)
for token_column, token_masks in gen:
first_tokens = token_column # This is the first (and only) yield
# With uniform distribution and 16 samples, they should NOT all be identical
# If they are all identical, the bug exists (broadcasting instead of sampling)
unique_tokens = set(first_tokens)
assert len(unique_tokens) > 1, (
f"All {num_samples} samples got the same first token ({first_tokens[0]}). "
f"With uniform logits, this is statistically impossible (~10^-36 probability) "
f"unless tokens are being broadcast instead of independently sampled."
)
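
# Worked check (illustrative) of the probability quoted in the docstring above: with a
# uniform distribution over 262 tokens, the chance that all 16 independent samples match
# the first one is (1/262)**15, on the order of 10^-36.
def test_all_identical_probability_is_negligible():
    import math
    p_all_identical = (1.0 / 262.0) ** 15
    assert p_all_identical < 1e-35  # actually ~5e-37
    assert math.isclose(math.log10(p_all_identical), -36.27, abs_tol=0.05)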