mirror of
https://github.com/karpathy/nanochat.git
synced 2026-01-20 18:34:14 +00:00
fix: sample first token independently for each row in multi-sample generation
Previously, when generating multiple samples (num_samples > 1), the first token after prefill was sampled once and broadcast to all rows, causing all samples to start identically. Now the prefill logits are expanded to num_samples and sampled independently for each row. Also simplified the generation loop by moving the forward pass to the end of the loop, eliminating the first_iteration flag and if/else branching. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
parent
2f2d7ab80c
commit
8f979a8bda
|
|
@ -19,7 +19,7 @@ from contextlib import contextmanager
|
|||
from collections import deque
|
||||
from nanochat.common import compute_init, autodetect_device_type
|
||||
from nanochat.checkpoint_manager import load_model
|
||||
from contextlib import nullcontext
|
||||
from contextlib import nullcontext
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Calculator tool helpers
|
||||
|
|
@ -107,23 +107,23 @@ class KVCache:
|
|||
# 1) validate the shapes
|
||||
assert self.kv_cache is None, "Cannot prefill a non-empty KV cache"
|
||||
assert other.kv_cache is not None, "Cannot prefill with a None KV cache"
|
||||
|
||||
|
||||
# Extract dimensions explicitly
|
||||
self_layers, self_kv, self_batch, self_heads, self_seq, self_head_dim = self.kv_shape
|
||||
other_layers, other_kv, other_batch, other_heads, other_seq, other_head_dim = other.kv_shape
|
||||
|
||||
|
||||
# Validate dimensions
|
||||
assert self_layers == other_layers, f"Layer count mismatch: {self_layers} != {other_layers}"
|
||||
assert self_kv == other_kv, f"K/V dimension mismatch: {self_kv} != {other_kv}"
|
||||
assert self_heads == other_heads, f"Head count mismatch: {self_heads} != {other_heads}"
|
||||
assert self_head_dim == other_head_dim, f"Head dim mismatch: {self_head_dim} != {other_head_dim}"
|
||||
|
||||
|
||||
# Batch size can be expanded (other can be 1, self can be larger)
|
||||
assert self_batch == other_batch or other_batch == 1, f"Batch size mismatch: {self_batch} vs {other_batch} (other must be 1 or equal)"
|
||||
|
||||
|
||||
# Sequence length: self must be longer than other
|
||||
assert self_seq >= other_seq, f"Sequence length mismatch: {self_seq} < {other_seq}"
|
||||
|
||||
|
||||
# 2) initialize the cache
|
||||
dtype, device = other.kv_cache.dtype, other.kv_cache.device
|
||||
self.kv_cache = torch.empty(self.kv_shape, dtype=dtype, device=device)
|
||||
|
|
@ -223,9 +223,7 @@ class Engine:
|
|||
)
|
||||
ids = torch.tensor([tokens], dtype=torch.long, device=device)
|
||||
logits = self.model.forward(ids, kv_cache=kv_cache_prefill)
|
||||
logits = logits[:, -1, :]
|
||||
next_ids = sample_next_token(logits, rng, temperature, top_k) # (B, 1)
|
||||
sampled_tokens = next_ids[:, 0].tolist()
|
||||
logits = logits[:, -1, :].expand(num_samples, -1) # (num_samples, vocab_size)
|
||||
|
||||
# 2) Replicate the KV cache for each sample/row
|
||||
kv_length_hint = (len(tokens) + max_tokens) if max_tokens is not None else self.model.config.sequence_len
|
||||
|
|
@ -242,7 +240,6 @@ class Engine:
|
|||
|
||||
# 4) Main generation loop
|
||||
num_generated = 0
|
||||
first_iteration = True
|
||||
while True:
|
||||
# Stop condition: we've reached max tokens
|
||||
if max_tokens is not None and num_generated >= max_tokens:
|
||||
|
|
@ -251,18 +248,9 @@ class Engine:
|
|||
if all(state.completed for state in row_states):
|
||||
break
|
||||
|
||||
# Get sampled tokens - either from prefill or from forward pass
|
||||
if first_iteration:
|
||||
# Use the tokens we already sampled from prefill
|
||||
sampled_tokens = [sampled_tokens[0]] * num_samples # Broadcast first token to all rows
|
||||
# TODO: we should sample a token for each row instead of broadcasting
|
||||
first_iteration = False
|
||||
else:
|
||||
# Forward the model and get the next token for each row
|
||||
logits = self.model.forward(ids, kv_cache=kv_cache_decode) # (B, T, vocab_size)
|
||||
logits = logits[:, -1, :] # (B, vocab_size) at last time step
|
||||
next_ids = sample_next_token(logits, rng, temperature, top_k) # (B, 1)
|
||||
sampled_tokens = next_ids[:, 0].tolist()
|
||||
# Sample the next token for each row
|
||||
next_ids = sample_next_token(logits, rng, temperature, top_k) # (B, 1)
|
||||
sampled_tokens = next_ids[:, 0].tolist()
|
||||
|
||||
# Process each row: choose the next token, update state, optional tool use
|
||||
token_column = [] # contains the next token id along each row
|
||||
|
|
@ -299,8 +287,10 @@ class Engine:
|
|||
# Yield the token column
|
||||
yield token_column, token_masks
|
||||
num_generated += 1
|
||||
# Prepare ids for next iteration
|
||||
|
||||
# Prepare logits for next iteration
|
||||
ids = torch.tensor(token_column, dtype=torch.long, device=device).unsqueeze(1)
|
||||
logits = self.model.forward(ids, kv_cache=kv_cache_decode)[:, -1, :] # (B, vocab_size)
|
||||
|
||||
def generate_batch(self, tokens, num_samples=1, **kwargs):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -5,7 +5,85 @@ python -m pytest tests/test_engine.py -v
|
|||
"""
|
||||
|
||||
import torch
|
||||
from nanochat.engine import KVCache
|
||||
from nanochat.engine import KVCache, Engine
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Mock classes for testing Engine without loading a real model
|
||||
|
||||
@dataclass
|
||||
class MockConfig:
|
||||
"""Minimal config for Engine tests."""
|
||||
n_kv_head: int = 4
|
||||
n_head: int = 4
|
||||
n_embd: int = 64
|
||||
n_layer: int = 2
|
||||
sequence_len: int = 128
|
||||
|
||||
|
||||
class MockModel:
|
||||
"""
|
||||
Mock model that returns uniform logits over the vocab.
|
||||
This ensures that with temperature > 0, different samples should
|
||||
(with very high probability) produce different tokens.
|
||||
"""
|
||||
def __init__(self, vocab_size=262): # 256 bytes + 6 special tokens
|
||||
self.vocab_size = vocab_size
|
||||
self.config = MockConfig()
|
||||
self._device = "cpu"
|
||||
|
||||
def get_device(self):
|
||||
return self._device
|
||||
|
||||
def forward(self, ids, kv_cache=None):
|
||||
"""Return uniform logits so sampling is spread across vocab."""
|
||||
B, T = ids.shape
|
||||
# Simulate what a real transformer does: insert k,v into the cache for each layer
|
||||
if kv_cache is not None:
|
||||
head_dim = self.config.n_embd // self.config.n_head
|
||||
for layer_idx in range(self.config.n_layer):
|
||||
k = torch.zeros(B, self.config.n_kv_head, T, head_dim)
|
||||
v = torch.zeros(B, self.config.n_kv_head, T, head_dim)
|
||||
kv_cache.insert_kv(layer_idx, k, v)
|
||||
# Uniform logits -> equal probability for all tokens
|
||||
logits = torch.zeros(B, T, self.vocab_size)
|
||||
return logits
|
||||
|
||||
|
||||
class ByteTokenizer:
|
||||
"""
|
||||
Simple byte-level tokenizer for testing.
|
||||
Tokens 0-255 are raw bytes, 256+ are special tokens.
|
||||
"""
|
||||
def __init__(self):
|
||||
# Special tokens start at 256
|
||||
self._special_tokens = {
|
||||
"<|python_start|>": 256,
|
||||
"<|python_end|>": 257,
|
||||
"<|output_start|>": 258,
|
||||
"<|output_end|>": 259,
|
||||
"<|assistant_end|>": 260,
|
||||
"<|bos|>": 261,
|
||||
}
|
||||
self._bos = 261
|
||||
|
||||
def encode_special(self, s):
|
||||
return self._special_tokens[s]
|
||||
|
||||
def get_bos_token_id(self):
|
||||
return self._bos
|
||||
|
||||
def encode(self, s, prepend=None):
|
||||
tokens = list(s.encode("utf-8")) # bytes 0-255
|
||||
if prepend is not None:
|
||||
tokens = [prepend] + tokens
|
||||
return tokens
|
||||
|
||||
def decode(self, tokens):
|
||||
# Filter out special tokens before decoding
|
||||
byte_tokens = [t for t in tokens if t < 256]
|
||||
return bytes(byte_tokens).decode("utf-8", errors="replace")
|
||||
|
||||
def test_kv_cache_resize():
|
||||
"""
|
||||
|
|
@ -64,3 +142,46 @@ def test_kv_cache_resize():
|
|||
original_v = original_cache[layer_idx, 1, :, :, token_idx, :]
|
||||
assert (actual_k == original_k).all(), f"Layer {layer_idx}, token {token_idx}: key doesn't match original"
|
||||
assert (actual_v == original_v).all(), f"Layer {layer_idx}, token {token_idx}: value doesn't match original"
|
||||
|
||||
|
||||
def test_multi_sample_first_token_diversity():
|
||||
"""
|
||||
Test that when generating multiple samples, each sample gets an independently
|
||||
sampled first token (not a broadcast of the same token to all rows).
|
||||
|
||||
Previously, the first token after prefill was sampled once and broadcast to all
|
||||
rows, causing all samples to start identically. The fix expands the prefill logits
|
||||
to num_samples and samples independently for each row.
|
||||
|
||||
With uniform logits over 262 tokens and 16 samples, the probability that all
|
||||
samples independently pick the same token is (1/262)^15 ≈ 10^-36. So if they're
|
||||
all identical, it indicates tokens are being broadcast instead of independently sampled.
|
||||
"""
|
||||
model = MockModel(vocab_size=262)
|
||||
tokenizer = ByteTokenizer()
|
||||
engine = Engine(model, tokenizer)
|
||||
|
||||
# Generate 16 samples with temperature=1.0 (stochastic sampling)
|
||||
prompt_tokens = [261, 72, 101, 108, 108, 111] # <bos> + "Hello"
|
||||
num_samples = 16
|
||||
|
||||
# Collect the first generated token from each sample
|
||||
first_tokens = []
|
||||
gen = engine.generate(
|
||||
prompt_tokens,
|
||||
num_samples=num_samples,
|
||||
max_tokens=1, # We only need the first token
|
||||
temperature=1.0,
|
||||
seed=42,
|
||||
)
|
||||
for token_column, token_masks in gen:
|
||||
first_tokens = token_column # This is the first (and only) yield
|
||||
|
||||
# With uniform distribution and 16 samples, they should NOT all be identical
|
||||
# If they are all identical, the bug exists (broadcasting instead of sampling)
|
||||
unique_tokens = set(first_tokens)
|
||||
assert len(unique_tokens) > 1, (
|
||||
f"All {num_samples} samples got the same first token ({first_tokens[0]}). "
|
||||
f"With uniform logits, this is statistically impossible (~10^-36 probability) "
|
||||
f"unless tokens are being broadcast instead of independently sampled."
|
||||
)
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user