# nanochat/tests/test_engine.py
"""
Tests for the inference engine with KV cache and tool use.
Run with:
python -m pytest tests/test_engine.py -v -s --timeout=60
"""
import torch
import pytest
from nanochat.gpt import GPT, GPTConfig
from nanochat.engine import Engine, KVCache, use_calculator, sample_next_token
from nanochat.tokenizer import RustBPETokenizer
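
# The suite below pins down the engine surface exactly as these tests call it:
#   - KVCache(batch_size, num_heads, seq_len, head_dim, num_layers), with
#     insert_kv(layer, k, v) -> (k_so_far, v_so_far), reset(), and prefill()
#   - sample_next_token(logits, rng, temperature, top_k) -> (batch, 1) ids
#   - Engine(model, tokenizer), with generate() (streaming, yields per-step
#     columns) and generate_batch() (returns full sequences and masks)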


@pytest.fixture
def tiny_model():
    """Create a tiny model for testing."""
    config = GPTConfig(
        sequence_len=128,
        vocab_size=256,
        n_layer=2,
        n_head=4,
        n_kv_head=2,  # fewer KV heads than query heads exercises the grouped/multi-query attention path
        n_embd=64,
    )
    model = GPT(config)
    model.init_weights()
    # Prepare for CPU testing: weights in float32, rotary tables in bfloat16
    model = model.float()
    model.cos = model.cos.bfloat16()
    model.sin = model.sin.bfloat16()
    model.eval()
    return model


@pytest.fixture
def mock_tokenizer():
    """Create a mock tokenizer for testing."""
    class MockTokenizer:
        def encode(self, text):
            """Simple encode: just return char codes."""
            if isinstance(text, str):
                return [ord(c) % 256 for c in text]
            elif isinstance(text, list):
                return [self.encode(t) for t in text]

        def decode(self, ids):
            """Simple decode: convert back to chars."""
            return ''.join(chr(i) for i in ids)

        def encode_special(self, token):
            """Encode special tokens to specific IDs."""
            special_map = {
                '<|python_start|>': 250,
                '<|python_end|>': 251,
                '<|output_start|>': 252,
                '<|output_end|>': 253,
                '<|assistant_end|>': 254,
                '<|bos|>': 255,
            }
            return special_map.get(token, 0)

        def get_bos_token_id(self):
            return 255

    return MockTokenizer()
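
# Rationale for the mock: the tiny model above only has a 256-entry vocab, so
# encode() hashes characters into that range with ord(c) % 256, and ids 250-255
# are reserved for the special tokens the engine looks up via encode_special()
# (the <|python_*|> / <|output_*|> markers used by the tool-use path). This
# keeps the tests independent of a trained RustBPETokenizer vocabulary.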


def test_use_calculator():
    """Test calculator functionality."""
    # Basic arithmetic
    assert use_calculator("2+2") == 4
    assert use_calculator("10-3") == 7
    assert use_calculator("4*5") == 20
    assert use_calculator("15/3") == 5
    # With spaces
    assert use_calculator("2 + 2") == 4
    # Order of operations
    assert use_calculator("2+3*4") == 14
    # Parentheses
    assert use_calculator("(2+3)*4") == 20
    # Decimals
    result = use_calculator("10.5+2.5")
    assert result == 13.0
    # Commas should be removed
    result = use_calculator("1,000+500")
    assert result == 1500


def test_use_calculator_invalid():
    """Test calculator with invalid inputs."""
    # Non-numeric characters should fail
    assert use_calculator("abc") is None
    assert use_calculator("2+x") is None
    assert use_calculator("import os") is None
    # Power operator disabled
    assert use_calculator("2**10") is None
    # Division by zero should fail gracefully
    assert use_calculator("1/0") is None


def test_kv_cache_initialization():
    """Test KV cache initialization."""
    batch_size, num_heads, seq_len, head_dim, num_layers = 2, 4, 128, 16, 3
    cache = KVCache(batch_size, num_heads, seq_len, head_dim, num_layers)
    assert cache.pos == 0
    assert cache.kv_cache is None  # Lazy initialization
    assert cache.kv_shape == (num_layers, 2, batch_size, num_heads, seq_len, head_dim)


def test_kv_cache_insert_and_retrieve():
    """Test inserting and retrieving from KV cache."""
    batch_size, num_heads, seq_len, head_dim, num_layers = 1, 2, 32, 8, 2
    cache = KVCache(batch_size, num_heads, seq_len, head_dim, num_layers)
    # Create some keys and values
    k = torch.randn(batch_size, num_heads, 4, head_dim)
    v = torch.randn(batch_size, num_heads, 4, head_dim)
    # Insert for layer 0
    k_out, v_out = cache.insert_kv(0, k, v)
    # Should return views of size 4 (what we inserted)
    assert k_out.shape == (batch_size, num_heads, 4, head_dim)
    assert v_out.shape == (batch_size, num_heads, 4, head_dim)
    # Values should match
    torch.testing.assert_close(k_out, k)
    torch.testing.assert_close(v_out, v)
    # Position should not advance until last layer
    assert cache.pos == 0
    # Insert for layer 1 (last layer)
    k_out, v_out = cache.insert_kv(1, k, v)
    # Now position should advance
    assert cache.pos == 4
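
# Note on cache.pos: one decoding step inserts K/V into every layer in turn,
# so the write position is shared across layers and only advances after the
# last layer's insert. If it advanced earlier, later layers would write their
# K/V for the same tokens at the wrong offset.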


def test_kv_cache_sequential_inserts():
    """Test sequential token generation with KV cache."""
    batch_size, num_heads, seq_len, head_dim, num_layers = 1, 2, 64, 8, 1
    cache = KVCache(batch_size, num_heads, seq_len, head_dim, num_layers)
    # Insert tokens one at a time
    for i in range(5):
        k = torch.randn(batch_size, num_heads, 1, head_dim)
        v = torch.randn(batch_size, num_heads, 1, head_dim)
        k_out, v_out = cache.insert_kv(0, k, v)
        # Should return all keys/values so far
        assert k_out.shape == (batch_size, num_heads, i + 1, head_dim)
        assert v_out.shape == (batch_size, num_heads, i + 1, head_dim)
    assert cache.pos == 5


def test_kv_cache_reset():
    """Test resetting KV cache."""
    cache = KVCache(1, 2, 32, 8, 1)
    k = torch.randn(1, 2, 4, 8)
    v = torch.randn(1, 2, 4, 8)
    cache.insert_kv(0, k, v)
    assert cache.pos == 4
    cache.reset()
    assert cache.pos == 0


def test_kv_cache_prefill():
    """Test prefilling a KV cache from another cache."""
    # Create a small cache sized exactly for the data we'll insert
    # (this matches the actual usage pattern in engine.py)
    num_tokens = 4
    small_cache = KVCache(1, 2, num_tokens, 8, 2)
    k = torch.randn(1, 2, num_tokens, 8)
    v = torch.randn(1, 2, num_tokens, 8)
    small_cache.insert_kv(0, k, v)
    small_cache.insert_kv(1, k, v)
    assert small_cache.pos == num_tokens
    # Create a larger cache and prefill it from the small one
    large_cache = KVCache(1, 2, 128, 8, 2)
    large_cache.prefill(small_cache)
    assert large_cache.pos == small_cache.pos
    assert large_cache.pos == num_tokens
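
# Prefill rationale: a prompt can be processed once into a cache sized exactly
# to its length, then copied into a larger cache that has room for the tokens
# still to be generated. Per the comment above, this mirrors how engine.py
# uses prefill() during generation.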


def test_kv_cache_dynamic_growth():
    """Test that the KV cache grows dynamically."""
    cache = KVCache(1, 2, 16, 8, 1)  # Start with a small capacity
    # Insert more tokens than the initial capacity
    for i in range(20):
        k = torch.randn(1, 2, 1, 8)
        v = torch.randn(1, 2, 1, 8)
        k_out, v_out = cache.insert_kv(0, k, v)
        assert k_out.shape[2] == i + 1  # Should have all tokens so far
    # Cache should have grown
    assert cache.kv_cache.shape[4] >= 20
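
# Dynamic growth: the cache was created with capacity 16 along the sequence
# axis but accepted 20 single-token inserts, so insert_kv must grow the
# underlying storage on demand; the final assert checks that the sequence
# dimension (dim 4 of the cache tensor) reached at least 20.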


def test_sample_next_token_greedy():
    """Test greedy sampling (temperature=0)."""
    vocab_size = 10
    batch_size = 2
    logits = torch.randn(batch_size, vocab_size)
    rng = torch.Generator()
    rng.manual_seed(42)
    tokens = sample_next_token(logits, rng, temperature=0.0)
    assert tokens.shape == (batch_size, 1)
    # Should be argmax
    expected = torch.argmax(logits, dim=-1, keepdim=True)
    torch.testing.assert_close(tokens, expected)
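
# Temperature background: sampling typically draws from softmax(logits / T),
#   p_i = exp(z_i / T) / sum_j exp(z_j / T),
# and as T -> 0 this distribution collapses onto the argmax. That is why
# temperature=0.0 is treated as greedy decoding above rather than as an
# actual division by zero.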


def test_sample_next_token_with_temperature():
    """Test sampling with temperature."""
    vocab_size = 10
    batch_size = 2
    logits = torch.randn(batch_size, vocab_size)
    rng = torch.Generator()
    rng.manual_seed(42)
    tokens = sample_next_token(logits, rng, temperature=1.0)
    assert tokens.shape == (batch_size, 1)
    assert torch.all((tokens >= 0) & (tokens < vocab_size))


def test_sample_next_token_top_k():
    """Test top-k sampling."""
    vocab_size = 100
    batch_size = 1
    logits = torch.randn(batch_size, vocab_size)
    rng = torch.Generator()
    rng.manual_seed(42)
    # Sample multiple times and check all are in top-k
    top_k = 5
    samples = []
    for i in range(20):
        rng.manual_seed(42 + i)
        token = sample_next_token(logits, rng, temperature=1.0, top_k=top_k)
        samples.append(token.item())
    # Get the actual top-k indices
    _, top_k_indices = torch.topk(logits[0], top_k)
    top_k_set = set(top_k_indices.tolist())
    # All samples should be in top-k
    for sample in samples:
        assert sample in top_k_set
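
# Top-k background: with top_k set, everything outside the k largest logits is
# masked out (effectively sent to -inf) before sampling, so every draw must
# land in the top-k index set -- which is exactly what the 20 reseeded draws
# above verify.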


def test_engine_initialization(tiny_model, mock_tokenizer):
    """Test Engine initialization."""
    engine = Engine(tiny_model, mock_tokenizer)
    assert engine.model is tiny_model
    assert engine.tokenizer is mock_tokenizer


def test_engine_generate_batch(tiny_model, mock_tokenizer):
    """Test batch generation with Engine."""
    engine = Engine(tiny_model, mock_tokenizer)
    # Start with some tokens
    initial_tokens = [1, 2, 3, 4]
    num_samples = 2
    max_tokens = 10
    results, masks = engine.generate_batch(
        initial_tokens,
        num_samples=num_samples,
        max_tokens=max_tokens,
        temperature=1.0,
        seed=42
    )
    # Should return num_samples results
    assert len(results) == num_samples
    assert len(masks) == num_samples
    # Each result should have at least the initial tokens
    for result, mask in zip(results, masks):
        assert len(result) >= len(initial_tokens)
        assert len(result) == len(mask)
        # Initial tokens should match
        assert result[:len(initial_tokens)] == initial_tokens


def test_engine_generate_deterministic(tiny_model, mock_tokenizer):
    """Test that generation is deterministic with the same seed."""
    engine = Engine(tiny_model, mock_tokenizer)
    initial_tokens = [1, 2, 3]
    results1, _ = engine.generate_batch(initial_tokens, num_samples=1, max_tokens=5, seed=42)
    results2, _ = engine.generate_batch(initial_tokens, num_samples=1, max_tokens=5, seed=42)
    assert results1[0] == results2[0], "Results should be identical with the same seed"
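
# Determinism here presumably comes from threading the seed through a local
# torch.Generator (as sample_next_token's rng argument suggests) rather than
# the global RNG, so two runs with the same seed replay identical draws.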


def test_engine_generate_streaming(tiny_model, mock_tokenizer):
    """Test streaming generation."""
    engine = Engine(tiny_model, mock_tokenizer)
    initial_tokens = [1, 2, 3]
    max_tokens = 5
    token_columns = []
    mask_columns = []
    for token_col, mask_col in engine.generate(
        initial_tokens,
        num_samples=2,
        max_tokens=max_tokens,
        seed=42
    ):
        token_columns.append(token_col)
        mask_columns.append(mask_col)
    # Should generate max_tokens columns
    assert len(token_columns) == max_tokens
    # Each column should have 2 tokens (num_samples=2)
    for col in token_columns:
        assert len(col) == 2
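
# Streaming contract: generate() yields one "column" per decoding step -- the
# tokens (and corresponding masks) produced at that step for all num_samples
# rows -- so callers can surface output token-by-token instead of waiting for
# a full generate_batch() call to complete.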


def test_engine_greedy_decode(tiny_model, mock_tokenizer):
    """Test greedy decoding."""
    engine = Engine(tiny_model, mock_tokenizer)
    initial_tokens = [1, 2, 3]
    results, _ = engine.generate_batch(
        initial_tokens,
        num_samples=1,
        max_tokens=5,
        temperature=0.0  # Greedy
    )
    assert len(results) == 1
    assert len(results[0]) >= len(initial_tokens)


def test_engine_max_tokens_limit(tiny_model, mock_tokenizer):
    """Test that generation respects max_tokens."""
    engine = Engine(tiny_model, mock_tokenizer)
    initial_tokens = [1, 2, 3]
    max_tokens = 3
    results, _ = engine.generate_batch(
        initial_tokens,
        num_samples=1,
        max_tokens=max_tokens,
        seed=42
    )
    # Should not exceed initial + max_tokens
    assert len(results[0]) <= len(initial_tokens) + max_tokens


def test_engine_multiple_samples(tiny_model, mock_tokenizer):
    """Test generating multiple samples in parallel."""
    engine = Engine(tiny_model, mock_tokenizer)
    initial_tokens = [1, 2, 3]
    num_samples = 4
    results, _ = engine.generate_batch(
        initial_tokens,
        num_samples=num_samples,
        max_tokens=5,
        temperature=1.0,  # Non-zero temp for diversity
        seed=42
    )
    assert len(results) == num_samples
    # All should start with same initial tokens
    for result in results:
        assert result[:len(initial_tokens)] == initial_tokens


def test_engine_with_kv_cache(tiny_model, mock_tokenizer):
    """Smoke test: generation that exercises the KV cache completes successfully."""
    engine = Engine(tiny_model, mock_tokenizer)
    initial_tokens = [1, 2, 3, 4, 5]
    # Generate with a longer prompt so decoding actually reuses cached K/V
    results, _ = engine.generate_batch(
        initial_tokens,
        num_samples=1,
        max_tokens=5,
        seed=42
    )
    # Should successfully generate past the prompt
    assert len(results) == 1
    assert len(results[0]) > len(initial_tokens)