fix: sample first token independently for each row in multi-sample generation

Previously, when generating multiple samples (num_samples > 1), the first
token after prefill was sampled once and broadcast to all rows, causing
all samples to start identically. Now the prefill logits are expanded to
num_samples and sampled independently for each row.

Also simplified the generation loop by moving the forward pass to the end
of the loop, eliminating the first_iteration flag and if/else branching.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
Andrej Karpathy 2025-12-28 04:52:13 +00:00
parent 2f2d7ab80c
commit 8f979a8bda
2 changed files with 135 additions and 24 deletions

View File

@ -19,7 +19,7 @@ from contextlib import contextmanager
from collections import deque
from nanochat.common import compute_init, autodetect_device_type
from nanochat.checkpoint_manager import load_model
from contextlib import nullcontext
from contextlib import nullcontext
# -----------------------------------------------------------------------------
# Calculator tool helpers
@ -107,23 +107,23 @@ class KVCache:
# 1) validate the shapes
assert self.kv_cache is None, "Cannot prefill a non-empty KV cache"
assert other.kv_cache is not None, "Cannot prefill with a None KV cache"
# Extract dimensions explicitly
self_layers, self_kv, self_batch, self_heads, self_seq, self_head_dim = self.kv_shape
other_layers, other_kv, other_batch, other_heads, other_seq, other_head_dim = other.kv_shape
# Validate dimensions
assert self_layers == other_layers, f"Layer count mismatch: {self_layers} != {other_layers}"
assert self_kv == other_kv, f"K/V dimension mismatch: {self_kv} != {other_kv}"
assert self_heads == other_heads, f"Head count mismatch: {self_heads} != {other_heads}"
assert self_head_dim == other_head_dim, f"Head dim mismatch: {self_head_dim} != {other_head_dim}"
# Batch size can be expanded (other can be 1, self can be larger)
assert self_batch == other_batch or other_batch == 1, f"Batch size mismatch: {self_batch} vs {other_batch} (other must be 1 or equal)"
# Sequence length: self must be longer than other
assert self_seq >= other_seq, f"Sequence length mismatch: {self_seq} < {other_seq}"
# 2) initialize the cache
dtype, device = other.kv_cache.dtype, other.kv_cache.device
self.kv_cache = torch.empty(self.kv_shape, dtype=dtype, device=device)
@ -223,9 +223,7 @@ class Engine:
)
ids = torch.tensor([tokens], dtype=torch.long, device=device)
logits = self.model.forward(ids, kv_cache=kv_cache_prefill)
logits = logits[:, -1, :]
next_ids = sample_next_token(logits, rng, temperature, top_k) # (B, 1)
sampled_tokens = next_ids[:, 0].tolist()
logits = logits[:, -1, :].expand(num_samples, -1) # (num_samples, vocab_size)
# 2) Replicate the KV cache for each sample/row
kv_length_hint = (len(tokens) + max_tokens) if max_tokens is not None else self.model.config.sequence_len
@ -242,7 +240,6 @@ class Engine:
# 4) Main generation loop
num_generated = 0
first_iteration = True
while True:
# Stop condition: we've reached max tokens
if max_tokens is not None and num_generated >= max_tokens:
@ -251,18 +248,9 @@ class Engine:
if all(state.completed for state in row_states):
break
# Get sampled tokens - either from prefill or from forward pass
if first_iteration:
# Use the tokens we already sampled from prefill
sampled_tokens = [sampled_tokens[0]] * num_samples # Broadcast first token to all rows
# TODO: we should sample a token for each row instead of broadcasting
first_iteration = False
else:
# Forward the model and get the next token for each row
logits = self.model.forward(ids, kv_cache=kv_cache_decode) # (B, T, vocab_size)
logits = logits[:, -1, :] # (B, vocab_size) at last time step
next_ids = sample_next_token(logits, rng, temperature, top_k) # (B, 1)
sampled_tokens = next_ids[:, 0].tolist()
# Sample the next token for each row
next_ids = sample_next_token(logits, rng, temperature, top_k) # (B, 1)
sampled_tokens = next_ids[:, 0].tolist()
# Process each row: choose the next token, update state, optional tool use
token_column = [] # contains the next token id along each row
@ -299,8 +287,10 @@ class Engine:
# Yield the token column
yield token_column, token_masks
num_generated += 1
# Prepare ids for next iteration
# Prepare logits for next iteration
ids = torch.tensor(token_column, dtype=torch.long, device=device).unsqueeze(1)
logits = self.model.forward(ids, kv_cache=kv_cache_decode)[:, -1, :] # (B, vocab_size)
def generate_batch(self, tokens, num_samples=1, **kwargs):
"""

View File

@ -5,7 +5,85 @@ python -m pytest tests/test_engine.py -v
"""
import torch
from nanochat.engine import KVCache
from nanochat.engine import KVCache, Engine
from dataclasses import dataclass
# -----------------------------------------------------------------------------
# Mock classes for testing Engine without loading a real model
@dataclass
class MockConfig:
"""Minimal config for Engine tests."""
n_kv_head: int = 4
n_head: int = 4
n_embd: int = 64
n_layer: int = 2
sequence_len: int = 128
class MockModel:
"""
Mock model that returns uniform logits over the vocab.
This ensures that with temperature > 0, different samples should
(with very high probability) produce different tokens.
"""
def __init__(self, vocab_size=262): # 256 bytes + 6 special tokens
self.vocab_size = vocab_size
self.config = MockConfig()
self._device = "cpu"
def get_device(self):
return self._device
def forward(self, ids, kv_cache=None):
"""Return uniform logits so sampling is spread across vocab."""
B, T = ids.shape
# Simulate what a real transformer does: insert k,v into the cache for each layer
if kv_cache is not None:
head_dim = self.config.n_embd // self.config.n_head
for layer_idx in range(self.config.n_layer):
k = torch.zeros(B, self.config.n_kv_head, T, head_dim)
v = torch.zeros(B, self.config.n_kv_head, T, head_dim)
kv_cache.insert_kv(layer_idx, k, v)
# Uniform logits -> equal probability for all tokens
logits = torch.zeros(B, T, self.vocab_size)
return logits
class ByteTokenizer:
"""
Simple byte-level tokenizer for testing.
Tokens 0-255 are raw bytes, 256+ are special tokens.
"""
def __init__(self):
# Special tokens start at 256
self._special_tokens = {
"<|python_start|>": 256,
"<|python_end|>": 257,
"<|output_start|>": 258,
"<|output_end|>": 259,
"<|assistant_end|>": 260,
"<|bos|>": 261,
}
self._bos = 261
def encode_special(self, s):
return self._special_tokens[s]
def get_bos_token_id(self):
return self._bos
def encode(self, s, prepend=None):
tokens = list(s.encode("utf-8")) # bytes 0-255
if prepend is not None:
tokens = [prepend] + tokens
return tokens
def decode(self, tokens):
# Filter out special tokens before decoding
byte_tokens = [t for t in tokens if t < 256]
return bytes(byte_tokens).decode("utf-8", errors="replace")
def test_kv_cache_resize():
"""
@ -64,3 +142,46 @@ def test_kv_cache_resize():
original_v = original_cache[layer_idx, 1, :, :, token_idx, :]
assert (actual_k == original_k).all(), f"Layer {layer_idx}, token {token_idx}: key doesn't match original"
assert (actual_v == original_v).all(), f"Layer {layer_idx}, token {token_idx}: value doesn't match original"
def test_multi_sample_first_token_diversity():
"""
Test that when generating multiple samples, each sample gets an independently
sampled first token (not a broadcast of the same token to all rows).
Previously, the first token after prefill was sampled once and broadcast to all
rows, causing all samples to start identically. The fix expands the prefill logits
to num_samples and samples independently for each row.
With uniform logits over 262 tokens and 16 samples, the probability that all
samples independently pick the same token is (1/262)^15 10^-36. So if they're
all identical, it indicates tokens are being broadcast instead of independently sampled.
"""
model = MockModel(vocab_size=262)
tokenizer = ByteTokenizer()
engine = Engine(model, tokenizer)
# Generate 16 samples with temperature=1.0 (stochastic sampling)
prompt_tokens = [261, 72, 101, 108, 108, 111] # <bos> + "Hello"
num_samples = 16
# Collect the first generated token from each sample
first_tokens = []
gen = engine.generate(
prompt_tokens,
num_samples=num_samples,
max_tokens=1, # We only need the first token
temperature=1.0,
seed=42,
)
for token_column, token_masks in gen:
first_tokens = token_column # This is the first (and only) yield
# With uniform distribution and 16 samples, they should NOT all be identical
# If they are all identical, the bug exists (broadcasting instead of sampling)
unique_tokens = set(first_tokens)
assert len(unique_tokens) > 1, (
f"All {num_samples} samples got the same first token ({first_tokens[0]}). "
f"With uniform logits, this is statistically impossible (~10^-36 probability) "
f"unless tokens are being broadcast instead of independently sampled."
)