"""
|
|
Tests for the inference engine with KV cache and tool use.
|
|
|
|
Run with:
|
|
python -m pytest tests/test_engine.py -v -s --timeout=60
|
|
"""
|
|
|
|
import torch
|
|
import pytest
|
|
from nanochat.gpt import GPT, GPTConfig
|
|
from nanochat.engine import Engine, KVCache, use_calculator, sample_next_token
|
|
from nanochat.tokenizer import RustBPETokenizer
|
|
|
|
|
|
@pytest.fixture
|
|
def tiny_model():
|
|
"""Create a tiny model for testing."""
|
|
config = GPTConfig(
|
|
sequence_len=128,
|
|
vocab_size=256,
|
|
n_layer=2,
|
|
n_head=4,
|
|
n_kv_head=2,
|
|
n_embd=64,
|
|
)
|
|
model = GPT(config)
|
|
model.init_weights()
|
|
# Prepare for CPU testing
|
|
model = model.float()
|
|
model.cos = model.cos.bfloat16()
|
|
model.sin = model.sin.bfloat16()
|
|
model.eval()
|
|
return model
|
|
|
|
|
|
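# Note on the fixture: n_head=4 with n_kv_head=2 exercises grouped-query
# attention (two query heads share each KV head), so the KV-cache tests below
# see a non-trivial head layout. Casting the rotary cos/sin buffers to
# bfloat16 while keeping weights in fp32 is assumed to match what GPT.forward
# expects on CPU; the tests here do not verify that contract themselves.
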
@pytest.fixture
def mock_tokenizer():
    """Create a mock tokenizer for testing."""
    class MockTokenizer:
        def encode(self, text):
            """Simple encode: just return char codes."""
            if isinstance(text, str):
                return [ord(c) % 256 for c in text]
            elif isinstance(text, list):
                return [self.encode(t) for t in text]

        def decode(self, ids):
            """Simple decode: convert back to chars."""
            return ''.join(chr(i) for i in ids)

        def encode_special(self, token):
            """Encode special tokens to specific IDs."""
            special_map = {
                '<|python_start|>': 250,
                '<|python_end|>': 251,
                '<|output_start|>': 252,
                '<|output_end|>': 253,
                '<|assistant_end|>': 254,
                '<|bos|>': 255,
            }
            return special_map.get(token, 0)

        def get_bos_token_id(self):
            return 255

    return MockTokenizer()

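# A small added sanity check: because encode maps each character to
# ord(c) % 256, the mock round-trips any string whose characters all have
# code points below 256 (plain ASCII in particular).
def test_mock_tokenizer_roundtrip(mock_tokenizer):
    """The mock's encode/decode should round-trip plain ASCII text."""
    text = "hello"
    assert mock_tokenizer.decode(mock_tokenizer.encode(text)) == text
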
def test_use_calculator():
    """Test calculator functionality."""
    # Basic arithmetic
    assert use_calculator("2+2") == 4
    assert use_calculator("10-3") == 7
    assert use_calculator("4*5") == 20
    assert use_calculator("15/3") == 5

    # With spaces
    assert use_calculator("2 + 2") == 4

    # Order of operations
    assert use_calculator("2+3*4") == 14

    # Parentheses
    assert use_calculator("(2+3)*4") == 20

    # Decimals
    result = use_calculator("10.5+2.5")
    assert result == 13.0

    # Commas should be removed
    result = use_calculator("1,000+500")
    assert result == 1500

def test_use_calculator_invalid():
    """Test calculator with invalid inputs."""
    # Non-numeric characters should be rejected
    assert use_calculator("abc") is None
    assert use_calculator("2+x") is None
    assert use_calculator("import os") is None

    # Power operator disabled
    assert use_calculator("2**10") is None

    # Division by zero should fail gracefully
    assert use_calculator("1/0") is None

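# For orientation, a guard compatible with everything asserted above: strip
# commas, allow only arithmetic characters, reject the power operator, and
# swallow evaluation errors such as ZeroDivisionError. This is a sketch of the
# behaviour these tests assume, not necessarily engine.py's actual
# implementation.
def _use_calculator_sketch(expr):
    expr = expr.replace(",", "")  # "1,000" -> "1000"
    if "**" in expr or any(c not in "0123456789+-*/(). " for c in expr):
        return None  # reject non-arithmetic input and the power operator
    try:
        return eval(expr)
    except Exception:
        return None  # e.g. ZeroDivisionError for "1/0"
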
def test_kv_cache_initialization():
    """Test KV cache initialization."""
    batch_size, num_heads, seq_len, head_dim, num_layers = 2, 4, 128, 16, 3

    cache = KVCache(batch_size, num_heads, seq_len, head_dim, num_layers)

    assert cache.pos == 0
    assert cache.kv_cache is None  # Lazy initialization
    assert cache.kv_shape == (num_layers, 2, batch_size, num_heads, seq_len, head_dim)

def test_kv_cache_insert_and_retrieve():
    """Test inserting and retrieving from KV cache."""
    batch_size, num_heads, seq_len, head_dim, num_layers = 1, 2, 32, 8, 2

    cache = KVCache(batch_size, num_heads, seq_len, head_dim, num_layers)

    # Create some keys and values
    k = torch.randn(batch_size, num_heads, 4, head_dim)
    v = torch.randn(batch_size, num_heads, 4, head_dim)

    # Insert for layer 0
    k_out, v_out = cache.insert_kv(0, k, v)

    # Should return views of size 4 (what we inserted)
    assert k_out.shape == (batch_size, num_heads, 4, head_dim)
    assert v_out.shape == (batch_size, num_heads, 4, head_dim)

    # Values should match
    torch.testing.assert_close(k_out, k)
    torch.testing.assert_close(v_out, v)

    # Position should not advance until last layer
    assert cache.pos == 0

    # Insert for layer 1 (last layer)
    k_out, v_out = cache.insert_kv(1, k, v)

    # Now position should advance
    assert cache.pos == 4

def test_kv_cache_sequential_inserts():
    """Test sequential token generation with KV cache."""
    batch_size, num_heads, seq_len, head_dim, num_layers = 1, 2, 64, 8, 1

    cache = KVCache(batch_size, num_heads, seq_len, head_dim, num_layers)

    # Insert tokens one at a time
    for i in range(5):
        k = torch.randn(batch_size, num_heads, 1, head_dim)
        v = torch.randn(batch_size, num_heads, 1, head_dim)

        k_out, v_out = cache.insert_kv(0, k, v)

        # Should return all keys/values so far
        assert k_out.shape == (batch_size, num_heads, i + 1, head_dim)
        assert v_out.shape == (batch_size, num_heads, i + 1, head_dim)

    assert cache.pos == 5

def test_kv_cache_reset():
    """Test resetting KV cache."""
    cache = KVCache(1, 2, 32, 8, 1)

    k = torch.randn(1, 2, 4, 8)
    v = torch.randn(1, 2, 4, 8)
    cache.insert_kv(0, k, v)

    assert cache.pos == 4

    cache.reset()
    assert cache.pos == 0

def test_kv_cache_prefill():
    """Test prefilling KV cache from another cache."""
    # Create a small cache sized exactly for the data we'll insert
    # (this matches the actual usage pattern in engine.py)
    num_tokens = 4
    small_cache = KVCache(1, 2, num_tokens, 8, 2)
    k = torch.randn(1, 2, num_tokens, 8)
    v = torch.randn(1, 2, num_tokens, 8)
    small_cache.insert_kv(0, k, v)
    small_cache.insert_kv(1, k, v)

    assert small_cache.pos == num_tokens

    # Create a larger cache and prefill from small
    large_cache = KVCache(1, 2, 128, 8, 2)
    large_cache.prefill(small_cache)

    assert large_cache.pos == small_cache.pos
    assert large_cache.pos == num_tokens

def test_kv_cache_dynamic_growth():
    """Test that KV cache grows dynamically."""
    cache = KVCache(1, 2, 16, 8, 1)  # Start with small size

    # Insert more tokens than initial capacity
    for i in range(20):
        k = torch.randn(1, 2, 1, 8)
        v = torch.randn(1, 2, 1, 8)
        k_out, v_out = cache.insert_kv(0, k, v)

        assert k_out.shape[2] == i + 1  # Should have all tokens so far

    # Cache should have grown
    assert cache.kv_cache.shape[4] >= 20

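# Putting the KV-cache pieces together: the decode loop the engine is assumed
# to run looks roughly like the sketch below (illustrative pseudocode with
# hypothetical helpers, not engine.py verbatim). Each step inserts one new
# token's K/V per layer and attends over everything cached so far; cache.pos
# advances once per step, after the last layer's insert, as asserted above.
#
#   cache = KVCache(batch_size, n_kv_head, capacity, head_dim, n_layer)
#   for step in range(max_tokens):
#       for layer in range(n_layer):
#           k, v = project_kv(layer, new_token)   # hypothetical projection
#           k_all, v_all = cache.insert_kv(layer, k, v)
#           attend(q, k_all, v_all)               # attention over full history
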
def test_sample_next_token_greedy():
    """Test greedy sampling (temperature=0)."""
    vocab_size = 10
    batch_size = 2

    logits = torch.randn(batch_size, vocab_size)
    rng = torch.Generator()
    rng.manual_seed(42)

    tokens = sample_next_token(logits, rng, temperature=0.0)

    assert tokens.shape == (batch_size, 1)

    # Should be argmax
    expected = torch.argmax(logits, dim=-1, keepdim=True)
    torch.testing.assert_close(tokens, expected)

def test_sample_next_token_with_temperature():
    """Test sampling with temperature."""
    vocab_size = 10
    batch_size = 2

    logits = torch.randn(batch_size, vocab_size)
    rng = torch.Generator()
    rng.manual_seed(42)

    tokens = sample_next_token(logits, rng, temperature=1.0)

    assert tokens.shape == (batch_size, 1)
    assert torch.all((tokens >= 0) & (tokens < vocab_size))

def test_sample_next_token_top_k():
    """Test top-k sampling."""
    vocab_size = 100
    batch_size = 1

    logits = torch.randn(batch_size, vocab_size)
    rng = torch.Generator()

    # Sample multiple times (reseeding each draw) and check all are in top-k
    top_k = 5
    samples = []
    for i in range(20):
        rng.manual_seed(42 + i)
        token = sample_next_token(logits, rng, temperature=1.0, top_k=top_k)
        samples.append(token.item())

    # Get the actual top-k indices
    _, top_k_indices = torch.topk(logits[0], top_k)
    top_k_set = set(top_k_indices.tolist())

    # All samples should be in top-k
    for sample in samples:
        assert sample in top_k_set

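# A minimal reference of the semantics the three sampling tests above pin down
# (an illustrative sketch; sample_next_token's actual implementation in
# engine.py may differ in detail):
def _reference_sample(logits, rng, temperature=1.0, top_k=None):
    if temperature == 0.0:
        # Greedy: take the argmax, no randomness involved
        return torch.argmax(logits, dim=-1, keepdim=True)
    logits = logits / temperature
    if top_k is not None:
        # Mask everything outside the k largest logits
        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
        logits = logits.masked_fill(logits < v[:, [-1]], float("-inf"))
    probs = torch.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1, generator=rng)
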
def test_engine_initialization(tiny_model, mock_tokenizer):
    """Test Engine initialization."""
    engine = Engine(tiny_model, mock_tokenizer)

    assert engine.model is tiny_model
    assert engine.tokenizer is mock_tokenizer

def test_engine_generate_batch(tiny_model, mock_tokenizer):
    """Test batch generation with Engine."""
    engine = Engine(tiny_model, mock_tokenizer)

    # Start with some tokens
    initial_tokens = [1, 2, 3, 4]
    num_samples = 2
    max_tokens = 10

    results, masks = engine.generate_batch(
        initial_tokens,
        num_samples=num_samples,
        max_tokens=max_tokens,
        temperature=1.0,
        seed=42
    )

    # Should return num_samples results
    assert len(results) == num_samples
    assert len(masks) == num_samples

    # Each result should contain at least the initial tokens
    for result, mask in zip(results, masks):
        assert len(result) >= len(initial_tokens)
        assert len(result) == len(mask)
        # Initial tokens should match
        assert result[:len(initial_tokens)] == initial_tokens

def test_engine_generate_deterministic(tiny_model, mock_tokenizer):
    """Test that generation is deterministic with same seed."""
    engine = Engine(tiny_model, mock_tokenizer)

    initial_tokens = [1, 2, 3]

    results1, _ = engine.generate_batch(initial_tokens, num_samples=1, max_tokens=5, seed=42)
    results2, _ = engine.generate_batch(initial_tokens, num_samples=1, max_tokens=5, seed=42)

    assert results1[0] == results2[0], "Results should be identical with same seed"

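# The converse (different seeds yielding different continuations) is only
# probabilistically true, so it is deliberately not asserted here.
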
def test_engine_generate_streaming(tiny_model, mock_tokenizer):
    """Test streaming generation."""
    engine = Engine(tiny_model, mock_tokenizer)

    initial_tokens = [1, 2, 3]
    max_tokens = 5

    token_columns = []
    mask_columns = []

    for token_col, mask_col in engine.generate(
        initial_tokens,
        num_samples=2,
        max_tokens=max_tokens,
        seed=42
    ):
        token_columns.append(token_col)
        mask_columns.append(mask_col)

    # Should generate max_tokens columns
    assert len(token_columns) == max_tokens

    # Each column should have 2 tokens (num_samples=2)
    for col in token_columns:
        assert len(col) == 2

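# Note the transpose relative to generate_batch: generate() streams one
# "column" per step (one new token for each of the num_samples sequences),
# while generate_batch returns one complete sequence per sample.
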
def test_engine_greedy_decode(tiny_model, mock_tokenizer):
    """Test greedy decoding."""
    engine = Engine(tiny_model, mock_tokenizer)

    initial_tokens = [1, 2, 3]

    results, _ = engine.generate_batch(
        initial_tokens,
        num_samples=1,
        max_tokens=5,
        temperature=0.0  # Greedy
    )

    assert len(results) == 1
    assert len(results[0]) >= len(initial_tokens)

def test_engine_max_tokens_limit(tiny_model, mock_tokenizer):
    """Test that generation respects max_tokens."""
    engine = Engine(tiny_model, mock_tokenizer)

    initial_tokens = [1, 2, 3]
    max_tokens = 3

    results, _ = engine.generate_batch(
        initial_tokens,
        num_samples=1,
        max_tokens=max_tokens,
        seed=42
    )

    # Should not exceed initial + max_tokens
    assert len(results[0]) <= len(initial_tokens) + max_tokens

def test_engine_multiple_samples(tiny_model, mock_tokenizer):
    """Test generating multiple samples in parallel."""
    engine = Engine(tiny_model, mock_tokenizer)

    initial_tokens = [1, 2, 3]
    num_samples = 4

    results, _ = engine.generate_batch(
        initial_tokens,
        num_samples=num_samples,
        max_tokens=5,
        temperature=1.0,  # Non-zero temp for diversity
        seed=42
    )

    assert len(results) == num_samples

    # All should start with same initial tokens
    for result in results:
        assert result[:len(initial_tokens)] == initial_tokens

def test_engine_with_kv_cache(tiny_model, mock_tokenizer):
    """Test that KV cache is used during generation."""
    engine = Engine(tiny_model, mock_tokenizer)

    initial_tokens = [1, 2, 3, 4, 5]

    # Generate with a longer prompt to benefit from KV cache
    results, _ = engine.generate_batch(
        initial_tokens,
        num_samples=1,
        max_tokens=5,
        seed=42
    )

    # Should successfully generate
    assert len(results) == 1
    assert len(results[0]) > len(initial_tokens)