"""
|
|
Test Engine class. Example run:
|
|
|
|
python -m pytest tests/test_engine.py -v
|
|
"""
|
|
|
|
import torch
|
|
from nanochat.engine import KVCache, Engine
|
|
from dataclasses import dataclass
|
|
|
|
|
|
# -----------------------------------------------------------------------------
|
|
# Mock classes for testing Engine without loading a real model

@dataclass
class MockConfig:
    """Minimal config for Engine tests."""
    n_kv_head: int = 4
    n_head: int = 4
    n_embd: int = 64
    n_layer: int = 2
    sequence_len: int = 128


class MockModel:
    """
    Mock model that returns uniform logits over the vocab.
    This ensures that with temperature > 0, different samples should
    (with very high probability) produce different tokens.
    """
    def __init__(self, vocab_size=262):  # 256 bytes + 6 special tokens
        self.vocab_size = vocab_size
        self.config = MockConfig()
        self._device = "cpu"

    def get_device(self):
        return self._device

    def forward(self, ids, kv_cache=None):
        """Return uniform logits so sampling is spread across vocab."""
        B, T = ids.shape
        # Simulate what a real transformer does: insert k,v into the cache for each layer
        if kv_cache is not None:
            head_dim = self.config.n_embd // self.config.n_head
            for layer_idx in range(self.config.n_layer):
                k = torch.zeros(B, self.config.n_kv_head, T, head_dim)
                v = torch.zeros(B, self.config.n_kv_head, T, head_dim)
                kv_cache.insert_kv(layer_idx, k, v)
        # Uniform logits -> equal probability for all tokens
        logits = torch.zeros(B, T, self.vocab_size)
        return logits
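
# Note: with all-zero logits, softmax assigns probability 1/vocab_size (= 1/262 here)
# to every token, which is the property the first-token diversity test below relies on.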


class ByteTokenizer:
    """
    Simple byte-level tokenizer for testing.
    Tokens 0-255 are raw bytes, 256+ are special tokens.
    """
    def __init__(self):
        # Special tokens start at 256
        self._special_tokens = {
            "<|python_start|>": 256,
            "<|python_end|>": 257,
            "<|output_start|>": 258,
            "<|output_end|>": 259,
            "<|assistant_end|>": 260,
            "<|bos|>": 261,
        }
        self._bos = 261

    def encode_special(self, s):
        return self._special_tokens[s]

    def get_bos_token_id(self):
        return self._bos

    def encode(self, s, prepend=None):
        tokens = list(s.encode("utf-8"))  # bytes 0-255
        if prepend is not None:
            tokens = [prepend] + tokens
        return tokens

    def decode(self, tokens):
        # Filter out special tokens before decoding
        byte_tokens = [t for t in tokens if t < 256]
        return bytes(byte_tokens).decode("utf-8", errors="replace")
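
# For reference, a ByteTokenizer round trip: encode("Hi") -> [72, 105] (raw UTF-8 bytes),
# and decode([261, 72, 105]) -> "Hi", since decode() drops special tokens >= 256.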

def test_kv_cache_resize():
    """
    The KV cache was not resized correctly; more information here:
    https://github.com/karpathy/nanochat/pull/186
    This test reproduces the issue and will be merged alongside the fix.
    """

    batch_size = 2
    num_heads = 3
    seq_len = 4
    head_dim = 5
    num_layers = 6

    kv_cache = KVCache(
        batch_size=batch_size,
        num_heads=num_heads,
        seq_len=seq_len,
        head_dim=head_dim,
        num_layers=num_layers,
    )

    # Insert a single token with a distinct fill value into all layers
    def insert_token(token_idx):
        for layer_idx in range(num_layers):
            k = torch.full((batch_size, num_heads, 1, head_dim), fill_value=float(token_idx), dtype=torch.float32)
            v = torch.full((batch_size, num_heads, 1, head_dim), fill_value=float(token_idx * 100), dtype=torch.float32)
            kv_cache.insert_kv(layer_idx, k, v)

    # Insert 4 tokens (fills the initial seq_len=4)
    for i in range(4):
        insert_token(i)

    # Record the original state of the cache
    original_cache = kv_cache.kv_cache.clone()
    original_seq_len = original_cache.shape[4]

    # Insert the 5th token, which will trigger a resize
    insert_token(4)
    # Verify that the cache actually resized
    new_seq_len = kv_cache.kv_cache.shape[4]
    assert new_seq_len > original_seq_len, f"Cache did not resize: original seq_len={original_seq_len}, new seq_len={new_seq_len}"
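
    # The indexing below assumes the cache tensor layout
    # (num_layers, 2 [key/value], batch_size, num_heads, seq_len, head_dim),
    # inferred from this test rather than from Engine's documentation;
    # dim 4 is the sequence dimension that grows when the cache resizes.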

    # Verify that the original 4 tokens are still intact after the resize
    for layer_idx in range(num_layers):
        for token_idx in range(4):
            # Check that the resized cache matches the expected values
            expected_k = float(token_idx)
            expected_v = float(token_idx * 100)
            actual_k = kv_cache.kv_cache[layer_idx, 0, :, :, token_idx, :]
            actual_v = kv_cache.kv_cache[layer_idx, 1, :, :, token_idx, :]
            assert (actual_k == expected_k).all(), f"Layer {layer_idx}, token {token_idx}: key corrupted, expected {expected_k}"
            assert (actual_v == expected_v).all(), f"Layer {layer_idx}, token {token_idx}: value corrupted, expected {expected_v}"
            # And that the resized cache matches the original cache
            original_k = original_cache[layer_idx, 0, :, :, token_idx, :]
            original_v = original_cache[layer_idx, 1, :, :, token_idx, :]
            assert (actual_k == original_k).all(), f"Layer {layer_idx}, token {token_idx}: key doesn't match original"
            assert (actual_v == original_v).all(), f"Layer {layer_idx}, token {token_idx}: value doesn't match original"


def test_multi_sample_first_token_diversity():
    """
    Test that when generating multiple samples, each sample gets an independently
    sampled first token (not a broadcast of the same token to all rows).

    Previously, the first token after prefill was sampled once and broadcast to all
    rows, causing all samples to start identically. The fix expands the prefill logits
    to num_samples and samples independently for each row.

    With uniform logits over 262 tokens and 16 samples, the probability that all
    samples independently pick the same token is (1/262)^15 ≈ 10^-36. So if they're
    all identical, it indicates tokens are being broadcast instead of independently sampled.
    """
    model = MockModel(vocab_size=262)
    tokenizer = ByteTokenizer()
    engine = Engine(model, tokenizer)

    # Generate 16 samples with temperature=1.0 (stochastic sampling)
    prompt_tokens = [261, 72, 101, 108, 108, 111]  # <|bos|> + "Hello"
    num_samples = 16

    # Collect the first generated token from each sample
    first_tokens = []
    gen = engine.generate(
        prompt_tokens,
        num_samples=num_samples,
        max_tokens=1,  # We only need the first token
        temperature=1.0,
        seed=42,
    )
    for token_column, token_masks in gen:
        first_tokens = token_column  # This is the first (and only) yield
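
    # Each yielded token_column is (by this test's assumption about Engine.generate)
    # a sequence of length num_samples: one sampled token per row at this step.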

    # With a uniform distribution and 16 samples, they should NOT all be identical.
    # If they are all identical, the bug exists (broadcasting instead of sampling).
    unique_tokens = set(first_tokens)
    assert len(unique_tokens) > 1, (
        f"All {num_samples} samples got the same first token ({first_tokens[0]}). "
        f"With uniform logits, this is essentially impossible (~10^-36 probability) "
        f"unless tokens are being broadcast instead of independently sampled."
    )