nanochat/tests/test_engine.py
2026-01-02 18:10:21 +01:00

410 lines
15 KiB
Python

"""
Test Engine class. Example run:
python -m pytest tests/test_engine.py -v
"""
import os
import json
import torch
import pytest
from dataclasses import dataclass
from huggingface_hub import snapshot_download
from nanochat.engine import KVCache, Engine
from nanochat.gpt import GPT, GPTConfig
from nanochat.tokenizer import RustBPETokenizer
from nanochat.checkpoint_manager import find_last_step
# -----------------------------------------------------------------------------
# Reproducibility: configure PyTorch for deterministic execution at import time.
# See: https://docs.pytorch.org/docs/stable/notes/randomness.html

def _enable_determinism():
    """Pin cuBLAS workspace config, seed the global RNG, and force deterministic kernels."""
    # Required for CUDA >= 10.2 determinism (see the PyTorch randomness notes above)
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
    torch.manual_seed(0)
    torch.use_deterministic_algorithms(True)
    # cuDNN autotuning may select different algorithms run to run
    torch.backends.cudnn.benchmark = False


_enable_determinism()
# -----------------------------------------------------------------------------
# Mock classes for testing Engine without loading a real model

@dataclass
class MockConfig:
    """Small set of model hyperparameters consumed by MockModel."""

    # Number of key/value heads and query heads
    n_kv_head: int = 4
    n_head: int = 4
    # Embedding width (MockModel derives head_dim as n_embd // n_head)
    n_embd: int = 64
    # Number of layers (MockModel writes one KV entry per layer)
    n_layer: int = 2
    # Sequence-length setting, mirroring GPTConfig.sequence_len
    sequence_len: int = 128
class MockModel:
    """
    Stand-in for GPT whose forward pass always returns all-zero logits.

    All-zero logits are a uniform distribution over the vocabulary, so sampling
    with temperature > 0 should (with very high probability) give different
    tokens across samples.
    """

    def __init__(self, vocab_size=262):  # 256 byte tokens + 6 special tokens
        self.vocab_size = vocab_size
        self.config = MockConfig()
        self._device = "cpu"

    def get_device(self):
        """Return the stored device string."""
        return self._device

    def forward(self, ids, kv_cache=None, attention_mask=None):
        """
        Mimic a transformer forward pass.

        Args:
            ids: 2-D token id tensor; its shape is unpacked as (B, T).
            kv_cache: if given, receives one zero key tensor and one zero value
                tensor of shape (B, n_kv_head, T, n_embd // n_head) per layer,
                inserted in layer order like a real model would.
            attention_mask: accepted for interface compatibility; unused.
        Returns:
            Zero logits of shape (B, T, vocab_size).
        """
        batch, steps = ids.shape
        if kv_cache is not None:
            cfg = self.config
            kv_shape = (batch, cfg.n_kv_head, steps, cfg.n_embd // cfg.n_head)
            for layer in range(cfg.n_layer):
                kv_cache.insert_kv(layer, torch.zeros(kv_shape), torch.zeros(kv_shape))
        # Equal logits everywhere -> every token equally likely
        return torch.zeros(batch, steps, self.vocab_size)
class ByteTokenizer:
    """
    Simple byte-level tokenizer for testing.

    Token ids 0-255 are raw UTF-8 bytes; ids 256 and up are the special tokens
    listed in __init__.
    """

    def __init__(self):
        # Special tokens start at 256, directly after the byte values
        self._special_tokens = {
            "<|python_start|>": 256,
            "<|python_end|>": 257,
            "<|output_start|>": 258,
            "<|output_end|>": 259,
            "<|assistant_end|>": 260,
            "<|bos|>": 261,
        }
        # Derived from the table so BOS can never disagree with it
        self._bos = self._special_tokens["<|bos|>"]

    def encode_special(self, s):
        """Return the id of special token ``s``; raises KeyError for unknown names."""
        return self._special_tokens[s]

    def get_bos_token_id(self):
        """Return the id of the <|bos|> token."""
        return self._bos

    def encode(self, s, prepend=None):
        """
        Encode ``s`` as its UTF-8 byte values.

        Args:
            s: text to encode.
            prepend: optional token placed before the bytes, given either as a
                token id (int) or as a special-token name such as "<|bos|>".
        Returns:
            List of int token ids.
        """
        tokens = list(s.encode("utf-8"))  # bytes 0-255
        if prepend is not None:
            if isinstance(prepend, str):
                # Resolve a special-token name to its id instead of inserting the raw string
                prepend = self.encode_special(prepend)
            tokens = [prepend] + tokens
        return tokens

    def decode(self, tokens):
        """Decode ids to text, dropping special tokens (ids >= 256); invalid UTF-8 becomes U+FFFD."""
        byte_tokens = [t for t in tokens if t < 256]
        return bytes(byte_tokens).decode("utf-8", errors="replace")
def _load_pretrained_model_and_tokenizer(device):
    """Download karpathy/nanochat-d34 and return (model, tokenizer), with the model on ``device`` in eval mode."""
    # Download (or reuse the local huggingface_hub cache of) the checkpoint repo
    cache_dir = snapshot_download(repo_id="karpathy/nanochat-d34")
    step = find_last_step(cache_dir)
    model_path = os.path.join(cache_dir, f"model_{step:06d}.pt")
    # Every value is used as a tensor below, so weights_only=True loads the same
    # state dict without letting torch.load unpickle arbitrary objects from a
    # downloaded file.
    model_data = torch.load(model_path, map_location=device, weights_only=True)
    # Fix torch compile key prefix
    model_data = {k.removeprefix("_orig_mod."): v for k, v in model_data.items()}
    # Convert all floating tensors to bfloat16 for consistent dtypes (checkpoint has mixed bfloat16/float32)
    model_data = {
        k: v.bfloat16() if v.is_floating_point() else v
        for k, v in model_data.items()
    }
    meta_path = os.path.join(cache_dir, f"meta_{step:06d}.json")
    with open(meta_path, "r", encoding="utf-8") as f:
        meta_data = json.load(f)
    # Build on the meta device, materialize on `device`, then assign checkpoint tensors
    model_config = GPTConfig(**meta_data["model_config"])
    with torch.device("meta"):
        model = GPT(model_config)
    model.to_empty(device=device)
    # NOTE(review): presumably initializes state not covered by the checkpoint
    # (e.g. non-persistent buffers) before load_state_dict — TODO confirm
    model.init_weights()
    model.load_state_dict(model_data, strict=True, assign=True)
    model.eval()
    # The tokenizer ships in the same checkpoint directory
    tokenizer = RustBPETokenizer.from_directory(cache_dir)
    return model, tokenizer


def _build_small_model_and_tokenizer(device):
    """Return a tiny randomly initialized GPT (eval mode, on ``device``) and a ByteTokenizer."""
    config = GPTConfig(
        sequence_len=256,
        vocab_size=262,  # 256 bytes + 6 special tokens, matching ByteTokenizer
        n_layer=2,
        n_head=4,
        n_kv_head=4,
        n_embd=64,
    )
    model = GPT(config)
    model.init_weights()
    model = model.to(device)
    model.eval()
    return model, ByteTokenizer()


def get_model_and_tokenizer(use_pretrained=False):
    """
    Get a CUDA model and tokenizer for testing.

    Args:
        use_pretrained: If True, download and load the pretrained nanochat-d34 model.
            If False, create a small randomly initialized model.
    Returns:
        (model, tokenizer) tuple
    Skips:
        Calls pytest.skip when CUDA is not available. The calling test is
        reported as skipped rather than failed, so CPU-only machines can still
        run the mock-based tests in this file.
    """
    if not torch.cuda.is_available():
        pytest.skip("CUDA is required for these tests")
    device = torch.device("cuda")
    if use_pretrained:
        return _load_pretrained_model_and_tokenizer(device)
    return _build_small_model_and_tokenizer(device)
# -----------------------------------------------------------------------------
# KVCache tests

def test_kv_cache_resize():
    """
    Regression test for KV cache resizing (https://github.com/karpathy/nanochat/pull/186).

    Fills the cache to its initial capacity and inserts one more token to force
    a resize. Then checks that:
    - the sequence dimension actually grew;
    - every previously inserted key/value survived the resize unchanged;
    - the token that triggered the resize was stored at the next position.
    """
    batch_size = 2
    num_heads = 3
    seq_len = 4
    head_dim = 5
    num_layers = 6
    kv_cache = KVCache(
        batch_size=batch_size,
        num_heads=num_heads,
        seq_len=seq_len,
        head_dim=head_dim,
        num_layers=num_layers,
    )

    def insert_token(token_idx):
        """Insert one token into every layer: keys filled with token_idx, values with token_idx * 100."""
        for layer_idx in range(num_layers):
            k = torch.full((batch_size, num_heads, 1, head_dim), fill_value=float(token_idx), dtype=torch.float32)
            v = torch.full((batch_size, num_heads, 1, head_dim), fill_value=float(token_idx * 100), dtype=torch.float32)
            kv_cache.insert_kv(layer_idx, k, v)

    # Fill exactly the initial capacity
    for i in range(seq_len):
        insert_token(i)
    # Snapshot the cache; dim 4 is the sequence axis (see the indexing below)
    original_cache = kv_cache.kv_cache.clone()
    original_seq_len = original_cache.shape[4]

    # One token past capacity must trigger a resize
    insert_token(seq_len)
    new_seq_len = kv_cache.kv_cache.shape[4]
    assert new_seq_len > original_seq_len, f"Cache did not resize: original seq_len={original_seq_len}, new seq_len={new_seq_len}"

    for layer_idx in range(num_layers):
        # The tokens inserted before the resize must be intact
        for token_idx in range(seq_len):
            expected_k = float(token_idx)
            expected_v = float(token_idx * 100)
            actual_k = kv_cache.kv_cache[layer_idx, 0, :, :, token_idx, :]
            actual_v = kv_cache.kv_cache[layer_idx, 1, :, :, token_idx, :]
            assert (actual_k == expected_k).all(), f"Layer {layer_idx}, token {token_idx}: key corrupted, expected {expected_k}"
            assert (actual_v == expected_v).all(), f"Layer {layer_idx}, token {token_idx}: value corrupted, expected {expected_v}"
            # ...and identical to the pre-resize snapshot
            original_k = original_cache[layer_idx, 0, :, :, token_idx, :]
            original_v = original_cache[layer_idx, 1, :, :, token_idx, :]
            assert (actual_k == original_k).all(), f"Layer {layer_idx}, token {token_idx}: key doesn't match original"
            assert (actual_v == original_v).all(), f"Layer {layer_idx}, token {token_idx}: value doesn't match original"
        # The token that triggered the resize must land right after the old ones
        overflow_k = kv_cache.kv_cache[layer_idx, 0, :, :, seq_len, :]
        overflow_v = kv_cache.kv_cache[layer_idx, 1, :, :, seq_len, :]
        assert (overflow_k == float(seq_len)).all(), f"Layer {layer_idx}: key of the resize-triggering token is wrong"
        assert (overflow_v == float(seq_len * 100)).all(), f"Layer {layer_idx}: value of the resize-triggering token is wrong"
def test_multi_sample_first_token_diversity():
    """
    Test that when generating multiple samples, each sample gets an independently
    sampled first token (not a broadcast of the same token to all rows).

    Previously, the first token after prefill was sampled once and broadcast to all
    rows, causing all samples to start identically. The fix expands the prefill logits
    to num_samples and samples independently for each row.

    With uniform logits over 262 tokens and 16 samples, the probability that all
    samples independently pick the same token is (1/262)^15 ≈ 10^-36. So if they're
    all identical, it indicates tokens are being broadcast instead of independently sampled.
    """
    model = MockModel(vocab_size=262)
    tokenizer = ByteTokenizer()
    engine = Engine(model, tokenizer)
    # <bos> + "Hello" (ids 261, 72, 101, 108, 108, 111)
    prompt_tokens = tokenizer.encode("Hello", prepend=tokenizer.get_bos_token_id())
    num_samples = 16
    # temperature=1.0 -> stochastic sampling
    gen = engine.generate(
        prompt_tokens,
        num_samples=num_samples,
        max_tokens=1,  # We only need the first token
        temperature=1.0,
        seed=42,
    )
    # Take the first yielded column explicitly. Failing here gives a clear message
    # instead of an IndexError from the diversity assertion's message below.
    first_step = next(gen, None)
    assert first_step is not None, "Engine.generate() yielded no tokens"
    first_tokens, _token_masks = first_step
    assert len(first_tokens) == num_samples, (
        f"Expected one first token per sample ({num_samples}), got {len(first_tokens)}"
    )
    # With uniform distribution and 16 samples, they should NOT all be identical.
    # If they are all identical, the bug exists (broadcasting instead of sampling).
    unique_tokens = set(first_tokens)
    assert len(unique_tokens) > 1, (
        f"All {num_samples} samples got the same first token ({first_tokens[0]}). "
        f"With uniform logits, this is statistically impossible (~10^-36 probability) "
        f"unless tokens are being broadcast instead of independently sampled."
    )
# -----------------------------------------------------------------------------
# Batched generation tests

@pytest.mark.parametrize("use_pretrained", [False, True])
def test_batched_generation_consistency(use_pretrained):
    """
    Batched generation must reproduce per-prompt generation exactly.

    Every prompt is first generated on its own. Then all prompts are generated
    together in one batch. Tokens and masks must match sample for sample.
    Greedy decoding (temperature=0.0) keeps both runs deterministic. If the
    pretrained model cannot be loaded, that parametrization is skipped.
    """
    try:
        model, tokenizer = get_model_and_tokenizer(use_pretrained=use_pretrained)
    except Exception as e:
        if not use_pretrained:
            raise
        pytest.skip(f"Could not load pretrained model: {e}")
    engine = Engine(model, tokenizer)

    # Prompts of deliberately different lengths
    bos = tokenizer.get_bos_token_id()
    prompts = [
        tokenizer.encode(text, prepend=bos)
        for text in ("hi", "the capital of France is", "hello, I'm a")
    ]
    num_samples = 2
    generation_kwargs = dict(max_tokens=10, temperature=0.0, seed=0)

    # Reference run: one call per prompt, each giving (num_samples, seq_len) token and mask lists
    individual_results, individual_masks = [], []
    for prompt in prompts:
        per_prompt_tokens, per_prompt_masks = engine.generate_batch(prompt, num_samples=num_samples, **generation_kwargs)
        individual_results.append(per_prompt_tokens)
        individual_masks.append(per_prompt_masks)

    # Batched run: all prompts in a single call
    batched_results, batched_masks = engine.generate_batch(prompts, num_samples=num_samples, **generation_kwargs)

    assert len(individual_results) == len(batched_results), \
        f"Prompt count mismatch: {len(individual_results)} vs {len(batched_results)}"
    per_prompt = zip(individual_results, batched_results, individual_masks, batched_masks)
    for prompt_idx, (ind_samples, batch_samples, ind_masks, batch_masks) in enumerate(per_prompt):
        assert len(ind_samples) == len(batch_samples), f"Sample count mismatch for prompt {prompt_idx}"
        per_sample = zip(ind_samples, batch_samples, ind_masks, batch_masks)
        for sample_idx, (ind_result, batch_result, ind_mask, batch_mask) in enumerate(per_sample):
            assert ind_result == batch_result, (
                f"Mismatch for prompt {prompt_idx}, sample {sample_idx}:\n"
                f" Individual: {ind_result}\n"
                f" Batched: {batch_result}"
            )
            assert ind_mask == batch_mask, (
                f"Mask mismatch for prompt {prompt_idx}, sample {sample_idx}:\n"
                f" Individual: {ind_mask}\n"
                f" Batched: {batch_mask}"
            )
def test_batched_generation_single_prompt():
    """
    A batch containing a single prompt must match passing that prompt directly.

    The bare prompt gives (num_samples, seq_len) outputs. The one-element batch
    gives (1, num_samples, seq_len) outputs, whose only entry must be equal.
    """
    model, tokenizer = get_model_and_tokenizer(use_pretrained=False)
    engine = Engine(model, tokenizer)
    prompt = tokenizer.encode("the capital of France is", prepend=tokenizer.get_bos_token_id())
    # Greedy decoding so both calls are deterministic
    call_kwargs = dict(num_samples=3, max_tokens=8, temperature=0.0, seed=0)

    single_results, single_masks = engine.generate_batch(prompt, **call_kwargs)
    batched_results, batched_masks = engine.generate_batch([prompt], **call_kwargs)

    assert single_results == batched_results[0], (
        f"Single vs batched single-prompt mismatch:\n"
        f" Single: {single_results}\n"
        f" Batched: {batched_results[0]}"
    )
    assert single_masks == batched_masks[0], (
        f"Single vs batched single-prompt mask mismatch:\n"
        f" Single: {single_masks}\n"
        f" Batched: {batched_masks[0]}"
    )
def test_batched_generation_stochastic():
    """
    With temperature=1.0, the samples generated for one prompt must not all be identical.
    """
    model, tokenizer = get_model_and_tokenizer(use_pretrained=False)
    engine = Engine(model, tokenizer)
    bos = tokenizer.get_bos_token_id()
    prompts = [tokenizer.encode(text, prepend=bos) for text in ("hi", "the capital of France is")]
    num_samples = 4

    # Output nesting: (num_prompts, num_samples, seq_len); masks are not needed here
    results, _ = engine.generate_batch(
        prompts, num_samples=num_samples, max_tokens=64, temperature=1.0, seed=0
    )

    assert len(results) == len(prompts)
    for prompt_idx, samples in enumerate(results):
        assert len(samples) == num_samples
        distinct = {tuple(sample) for sample in samples}
        assert len(distinct) > 1, (
            f"All {num_samples} samples for prompt {prompt_idx} are identical. "
            f"With temperature=1.0, samples should differ."
        )