feat: batched generation from multiple prompts

Matěj Kripner 2026-01-02 18:10:21 +01:00
parent 48abd7d85f
commit b7df9f8eaa
3 changed files with 347 additions and 32 deletions

nanochat/engine.py

@@ -198,12 +198,32 @@ class Engine:
@torch.inference_mode()
def generate(self, tokens, num_samples=1, max_tokens=None, temperature=1.0, top_k=None, seed=42):
"""Same as generate, but does single prefill and then clones the KV cache."""
assert isinstance(tokens, list) and isinstance(tokens[0], int), "expecting list of ints"
"""
Generate tokens from prompt(s). Accepts either list[int] (single prompt) or
list[list[int]] (batched prompts).
Yields:
(token_column, token_masks) tuples where both are nested list[list[int]] of
shape (num_prompts, num_samples) for batched input, or list[int] of shape
(num_samples,) for single prompt. Masks: 1=sampled, 0=forced.
"""
assert isinstance(tokens, list), "tokens must be a list"
# Normalize input: convert single prompt to list of prompts
is_batched = len(tokens) > 0 and isinstance(tokens[0], list)
if is_batched:
prompts = tokens
else:
assert isinstance(tokens[0], int), "expecting list of ints or list of lists of ints"
prompts = [tokens]
device = self.model.get_device()
rng = torch.Generator(device=device)
rng.manual_seed(seed)
num_prompts = len(prompts)
total_rows = num_prompts * num_samples
# Get the special tokens we need to coordinate the tool use state machine
get_special = lambda s: self.tokenizer.encode_special(s)
python_start = get_special("<|python_start|>")
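A minimal usage sketch of the new streaming contract (editor's illustration, not part of the diff; engine and tokenizer are assumed to be set up as in tests/test_engine.py below): with batched input, each step yields one (num_prompts, num_samples) column of tokens and masks.

# Editor's sketch, not part of the commit.
bos = tokenizer.get_bos_token_id()
prompts = [
    tokenizer.encode("hi", prepend=bos),
    tokenizer.encode("the capital of France is", prepend=bos),
]
for token_column, token_masks in engine.generate(prompts, num_samples=2, max_tokens=8, temperature=0.0):
    # token_column[p][s]: token emitted this step for prompt p, sample s
    # token_masks[p][s]: 1 if the model sampled it, 0 if it was forced (e.g. injected tool output)
    assert len(token_column) == len(prompts) and len(token_column[0]) == 2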
@@ -213,33 +233,64 @@ class Engine:
assistant_end = get_special("<|assistant_end|>") # if sampled, ends row
bos = self.tokenizer.get_bos_token_id() # if sampled, ends row
# 1) Run a batch 1 prefill of the prompt tokens
# 1) Left-pad all prompts to max length and create attention mask
prompt_lengths = [len(p) for p in prompts]
max_prompt_len = max(prompt_lengths)
padded_prompts = [[0] * (max_prompt_len - len(p)) + p for p in prompts]
# Create attention masks if padding is needed
decode_mask = None
prefill_attn_mask = None
if any(length != max_prompt_len for length in prompt_lengths):
# prompt_mask[b, t] = True if position t is a real token (not padding) for prompt b
prompt_mask = torch.zeros((num_prompts, max_prompt_len), dtype=torch.bool, device=device)
for i, length in enumerate(prompt_lengths):
prompt_mask[i, max_prompt_len - length:] = True
# causal_mask[q, k] = True if query at position q can attend to key at position k
causal_mask = torch.tril(torch.ones((max_prompt_len, max_prompt_len), dtype=torch.bool, device=device))
# prefill_attn_mask combines prompt_mask and causal_mask: attend only to non-padding keys at or before the query position
# shape: (num_prompts, 1, max_prompt_len, max_prompt_len) - the 1 broadcasts across heads
prefill_attn_mask = (causal_mask.unsqueeze(0) & prompt_mask.unsqueeze(1)).unsqueeze(1)
# decode_mask tracks which positions are valid for each row during generation (will be updated after each step)
decode_mask = prompt_mask.repeat_interleave(num_samples, dim=0)
# 2) Run batched prefill
m = self.model.config
kv_model_kwargs = {"num_heads": m.n_kv_head, "head_dim": m.n_embd // m.n_head, "num_layers": m.n_layer}
kv_cache_prefill = KVCache(
batch_size=1,
seq_len=len(tokens),
batch_size=num_prompts,
seq_len=max_prompt_len,
**kv_model_kwargs,
)
ids = torch.tensor([tokens], dtype=torch.long, device=device)
logits = self.model.forward(ids, kv_cache=kv_cache_prefill)
logits = logits[:, -1, :].expand(num_samples, -1) # (num_samples, vocab_size)
# 2) Replicate the KV cache for each sample/row
kv_length_hint = (len(tokens) + max_tokens) if max_tokens is not None else self.model.config.sequence_len
ids = torch.tensor(padded_prompts, dtype=torch.long, device=device)
logits = self.model.forward(ids, kv_cache=kv_cache_prefill, attention_mask=prefill_attn_mask)
logits = logits[:, -1, :] # (num_prompts, vocab_size)
# 3) Expand KV cache for num_samples per prompt
kv_length_hint = (max_prompt_len + max_tokens) if max_tokens is not None else self.model.config.sequence_len
kv_cache_decode = KVCache(
batch_size=num_samples,
batch_size=total_rows,
seq_len=kv_length_hint,
**kv_model_kwargs,
)
kv_cache_decode.prefill(kv_cache_prefill)
del kv_cache_prefill # no need to keep this memory around
# Initialize the decode cache from prefill cache, replicating for each sample
dtype, dev = kv_cache_prefill.kv_cache.dtype, kv_cache_prefill.kv_cache.device
kv_cache_decode.kv_cache = torch.empty(kv_cache_decode.kv_shape, dtype=dtype, device=dev)
for i in range(num_prompts):
src = kv_cache_prefill.kv_cache[:, :, i:i + 1, :, :max_prompt_len, :]
for j in range(num_samples):
kv_cache_decode.kv_cache[:, :, i * num_samples + j:i * num_samples + j + 1, :, :max_prompt_len, :] = src
kv_cache_decode.pos = max_prompt_len
del kv_cache_prefill # no need to keep this memory around
# 3) Initialize states for each sample
row_states = [RowState(tokens.copy()) for _ in range(num_samples)]
# Expand logits for num_samples per prompt
logits = logits.repeat_interleave(num_samples, dim=0) # (total_rows, vocab_size)
# 4) Main generation loop
# 4) Initialize row states and run generation loop
row_states = [RowState(prompt.copy()) for prompt in prompts for _ in range(num_samples)]
num_generated = 0
while True:
# Stop condition: we've reached max tokens
if max_tokens is not None and num_generated >= max_tokens:
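The left-padding and prefill mask construction above can be reproduced standalone with toy sizes; the names mirror the diff, but this is an editor's illustration rather than repository code.

# Editor's sketch, not part of the commit.
import torch

prompts = [[7, 8, 9], [5]]                            # two prompts of different lengths
prompt_lengths = [len(p) for p in prompts]
max_prompt_len = max(prompt_lengths)                  # 3
padded_prompts = [[0] * (max_prompt_len - len(p)) + p for p in prompts]   # [[7, 8, 9], [0, 0, 5]]

# prompt_mask[b, t]: True where prompt b has a real (non-padding) token at position t
prompt_mask = torch.zeros((len(prompts), max_prompt_len), dtype=torch.bool)
for i, length in enumerate(prompt_lengths):
    prompt_mask[i, max_prompt_len - length:] = True

# query q may attend to key k iff k <= q (causal) and key k is not padding
causal_mask = torch.tril(torch.ones((max_prompt_len, max_prompt_len), dtype=torch.bool))
prefill_attn_mask = (causal_mask.unsqueeze(0) & prompt_mask.unsqueeze(1)).unsqueeze(1)
print(prefill_attn_mask.shape)        # torch.Size([2, 1, 3, 3]); the 1 broadcasts over heads
print(prefill_attn_mask[1, 0].int())  # prompt 1: only the last key column is ever attendable

decode_mask then starts out as prompt_mask repeated num_samples times along the batch dimension, so every sample row carries its prompt's padding pattern into the decode loop.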
@@ -284,26 +335,60 @@ class Engine:
elif state.in_python_block:
state.python_expr_tokens.append(next_token)
# Yield the token column
yield token_column, token_masks
if is_batched:
# Yield shape (num_prompts, num_samples)
yield ([token_column[i * num_samples:(i + 1) * num_samples] for i in range(num_prompts)],
[token_masks[i * num_samples:(i + 1) * num_samples] for i in range(num_prompts)])
else:
# Yield shape (num_samples,)
yield token_column, token_masks
num_generated += 1
# Prepare logits for next iteration
ids = torch.tensor(token_column, dtype=torch.long, device=device).unsqueeze(1)
logits = self.model.forward(ids, kv_cache=kv_cache_decode)[:, -1, :] # (B, vocab_size)
if decode_mask is not None:
# Extend decode_mask with True for the new tokens
decode_mask = torch.cat(
[decode_mask, torch.ones((total_rows, 1), dtype=torch.bool, device=device)], dim=1
)
logits = self.model.forward(
ids,
kv_cache=kv_cache_decode,
attention_mask=decode_mask.unsqueeze(1).unsqueeze(1), # (B, 1, 1, T)
)
else:
logits = self.model.forward(ids, kv_cache=kv_cache_decode)
logits = logits[:, -1, :]
def generate_batch(self, tokens, num_samples=1, **kwargs):
"""
Non-streaming batch generation that just returns the final token sequences.
Returns a list of token sequences (list of lists of ints).
Non-streaming batch generation that returns the final token sequences.
Terminal tokens (assistant_end, bos) are not included in the results.
Returns:
(results, masks): For batched input, both are list[list[list[int]]] of shape
(num_prompts, num_samples, seq_len). For single prompt, both are
list[list[int]] of shape (num_samples, seq_len). Masks: 1=sampled, 0=forced.
"""
assistant_end = self.tokenizer.encode_special("<|assistant_end|>")
bos = self.tokenizer.get_bos_token_id()
results = [tokens.copy() for _ in range(num_samples)]
masks = [[0] * len(tokens) for _ in range(num_samples)]
completed = [False] * num_samples
# Normalize input to list of prompts
is_batched = len(tokens) > 0 and isinstance(tokens[0], list)
prompts = tokens if is_batched else [tokens]
# Work with flat structure internally (prompt0_sample0, prompt0_sample1, ..., prompt1_sample0, ...)
results = [p.copy() for p in prompts for _ in range(num_samples)]
masks = [[0] * len(p) for p in prompts for _ in range(num_samples)]
completed = [False] * len(results)
for token_column, token_masks in self.generate(tokens, num_samples, **kwargs):
# Flatten nested output from generate() if batched
if is_batched:
token_column = [t for row in token_column for t in row]
token_masks = [m for row in token_masks for m in row]
for i, (token, mask) in enumerate(zip(token_column, token_masks)):
if not completed[i]:
if token == assistant_end or token == bos:
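Two conventions from this hunk, sketched with toy values (editor's illustration): rows are laid out flat in prompt-major order (row = prompt_idx * num_samples + sample_idx), which the nested yield and generate_batch's flatten/reshape both rely on, and decode_mask grows by one all-True column per generated token so the original left-padding stays masked.

# Editor's sketch, not part of the commit.
import torch

num_prompts, num_samples = 2, 3
total_rows = num_prompts * num_samples             # flat row order: p * num_samples + s

# Flat <-> nested conversion used by the yield and by generate_batch:
flat = list(range(total_rows))                     # stand-in for one token per row
nested = [flat[i * num_samples:(i + 1) * num_samples] for i in range(num_prompts)]
assert nested == [[0, 1, 2], [3, 4, 5]]
assert [t for row in nested for t in row] == flat  # round-trips back to flat

# decode_mask growth: each generated token appends an always-valid column,
# so new tokens are attendable while original left-padding stays masked.
decode_mask = torch.tensor([[True, True], [False, True]]).repeat_interleave(num_samples, dim=0)
for _ in range(2):                                 # two decode steps
    decode_mask = torch.cat([decode_mask, torch.ones((total_rows, 1), dtype=torch.bool)], dim=1)
print(decode_mask.shape)                           # torch.Size([6, 4]); passed to the model as (B, 1, 1, T)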
@@ -314,6 +399,11 @@ class Engine:
# Stop if all rows are completed
if all(completed):
break
# Reshape to nested structure for batched output
if is_batched:
results = [results[i * num_samples:(i + 1) * num_samples] for i in range(len(prompts))]
masks = [masks[i * num_samples:(i + 1) * num_samples] for i in range(len(prompts))]
return results, masks
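A hypothetical end-to-end call of the non-streaming path, assuming the same engine/tokenizer setup as the tests added below:

# Editor's sketch, not part of the commit.
bos = tokenizer.get_bos_token_id()
prompts = [
    tokenizer.encode("hi", prepend=bos),
    tokenizer.encode("hello, I'm a", prepend=bos),
]
results, masks = engine.generate_batch(prompts, num_samples=2, max_tokens=16, temperature=0.0, seed=0)
# results[p][s]: prompt p's tokens followed by the generated completion for sample s
# masks[p][s]:   0 for prompt/forced tokens, 1 for tokens the model sampled
assert len(results) == len(prompts) and len(results[0]) == 2
assert len(results[0][0]) == len(masks[0][0])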

nanochat/gpt.py

@@ -61,7 +61,7 @@ class CausalSelfAttention(nn.Module):
self.c_v = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)
def forward(self, x, cos_sin, kv_cache):
def forward(self, x, cos_sin, kv_cache, attention_mask=None):
B, T, C = x.size()
# Project the input to get queries, keys, and values
@@ -83,7 +83,10 @@ class CausalSelfAttention(nn.Module):
# Attention: queries attend to keys/values autoregressively. A few cases to handle:
enable_gqa = self.n_head != self.n_kv_head # Group Query Attention (GQA): duplicate key/value heads to match query heads if desired
if kv_cache is None or Tq == Tk:
if attention_mask is not None:
# Custom attention mask provided (for batched generation with padding)
y = F.scaled_dot_product_attention(q, k, v, attn_mask=attention_mask, enable_gqa=enable_gqa)
elif kv_cache is None or Tq == Tk:
# During training (no KV cache), attend as usual with causal attention
# And even if there is KV cache, we can still use this simple version when Tq == Tk
y = F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=enable_gqa)
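The new branch only takes effect when a mask is actually supplied. As a quick standalone check (plain PyTorch, independent of this repository), a boolean mask that encodes ordinary causality reproduces the is_causal=True path:

# Editor's sketch, not part of the commit.
import torch
import torch.nn.functional as F

B, H, T, D = 2, 4, 5, 16
q, k, v = (torch.randn(B, H, T, D) for _ in range(3))
causal = torch.tril(torch.ones(T, T, dtype=torch.bool)).view(1, 1, T, T)  # broadcasts over batch and heads
out_mask = F.scaled_dot_product_attention(q, k, v, attn_mask=causal)
out_causal = F.scaled_dot_product_attention(q, k, v, is_causal=True)
torch.testing.assert_close(out_mask, out_causal)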
@@ -126,8 +129,8 @@ class Block(nn.Module):
self.attn = CausalSelfAttention(config, layer_idx)
self.mlp = MLP(config)
def forward(self, x, cos_sin, kv_cache):
x = x + self.attn(norm(x), cos_sin, kv_cache)
def forward(self, x, cos_sin, kv_cache, attention_mask=None):
x = x + self.attn(norm(x), cos_sin, kv_cache, attention_mask)
x = x + self.mlp(norm(x))
return x
@@ -253,7 +256,7 @@ class GPT(nn.Module):
group["initial_lr"] = group["lr"]
return optimizers
def forward(self, idx, targets=None, kv_cache=None, loss_reduction='mean'):
def forward(self, idx, targets=None, kv_cache=None, attention_mask=None, loss_reduction='mean'):
B, T = idx.size()
# Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim/2))
@@ -268,7 +271,7 @@ class GPT(nn.Module):
x = self.transformer.wte(idx)
x = norm(x)
for block in self.transformer.h:
x = block(x, cos_sin, kv_cache)
x = block(x, cos_sin, kv_cache, attention_mask)
x = norm(x)
# Forward the lm_head (compute logits)
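During decoding, the Engine passes a key-only boolean mask of shape (B, 1, 1, T) (see engine.py above); scaled_dot_product_attention broadcasts it across heads and query positions, so the plumbing above needs no per-layer reshaping. A small sketch of that shape contract with illustrative values:

# Editor's sketch, not part of the commit.
import torch
import torch.nn.functional as F

B, H, D = 2, 4, 16
T_keys = 5                                       # prompt + tokens generated so far
q = torch.randn(B, H, 1, D)                      # single new query position per row
k = torch.randn(B, H, T_keys, D)
v = torch.randn(B, H, T_keys, D)
decode_mask = torch.ones(B, T_keys, dtype=torch.bool)
decode_mask[1, :2] = False                       # row 1 began with two left-padding positions
out = F.scaled_dot_product_attention(q, k, v, attn_mask=decode_mask.view(B, 1, 1, T_keys))
print(out.shape)                                 # torch.Size([2, 4, 1, 16])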

tests/test_engine.py

@@ -4,10 +4,27 @@ Test Engine class. Example run:
python -m pytest tests/test_engine.py -v
"""
import os
import json
import torch
from nanochat.engine import KVCache, Engine
import pytest
from dataclasses import dataclass
from huggingface_hub import snapshot_download
from nanochat.engine import KVCache, Engine
from nanochat.gpt import GPT, GPTConfig
from nanochat.tokenizer import RustBPETokenizer
from nanochat.checkpoint_manager import find_last_step
# -----------------------------------------------------------------------------
# Ensure deterministic behavior for reproducible tests
# See: https://docs.pytorch.org/docs/stable/notes/randomness.html
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" # Required for CUDA >= 10.2 determinism
torch.manual_seed(0)
torch.use_deterministic_algorithms(True)
torch.backends.cudnn.benchmark = False
# -----------------------------------------------------------------------------
# Mock classes for testing Engine without loading a real model
@@ -36,7 +53,7 @@ class MockModel:
def get_device(self):
return self._device
def forward(self, ids, kv_cache=None):
def forward(self, ids, kv_cache=None, attention_mask=None):
"""Return uniform logits so sampling is spread across vocab."""
B, T = ids.shape
# Simulate what a real transformer does: insert k,v into the cache for each layer
@@ -85,6 +102,80 @@ class ByteTokenizer:
byte_tokens = [t for t in tokens if t < 256]
return bytes(byte_tokens).decode("utf-8", errors="replace")
def get_model_and_tokenizer(use_pretrained=False):
"""
Get a model and tokenizer for testing. Requires CUDA.
Args:
use_pretrained: If True, download and load the pretrained nanochat-d34 model.
If False, create a small randomly initialized model.
Returns:
(model, tokenizer) tuple
"""
if not torch.cuda.is_available():
raise RuntimeError("CUDA is required for these tests")
device = torch.device("cuda")
if use_pretrained:
# Download the checkpoint
cache_dir = snapshot_download(repo_id="karpathy/nanochat-d34")
# Find the last step
step = find_last_step(cache_dir)
# Load model data
model_path = os.path.join(cache_dir, f"model_{step:06d}.pt")
model_data = torch.load(model_path, map_location=device)
# Fix torch compile key prefix
model_data = {k.removeprefix("_orig_mod."): v for k, v in model_data.items()}
# Convert all tensors to bfloat16 for consistent dtypes (checkpoint has mixed bfloat16/float32)
model_data = {
k: v.bfloat16() if v.is_floating_point() else v
for k, v in model_data.items()
}
# Load metadata
meta_path = os.path.join(cache_dir, f"meta_{step:06d}.json")
with open(meta_path, "r", encoding="utf-8") as f:
meta_data = json.load(f)
# Build model
model_config = GPTConfig(**meta_data["model_config"])
with torch.device("meta"):
model = GPT(model_config)
model.to_empty(device=device)
model.init_weights()
model.load_state_dict(model_data, strict=True, assign=True)
model.eval()
# Load tokenizer from the checkpoint directory
tokenizer = RustBPETokenizer.from_directory(cache_dir)
else:
# Small model for fast testing
config = GPTConfig(
sequence_len=256,
vocab_size=262, # 256 bytes + 6 special tokens
n_layer=2,
n_head=4,
n_kv_head=4,
n_embd=64,
)
model = GPT(config)
model.init_weights()
model = model.to(device)
model.eval()
tokenizer = ByteTokenizer()
return model, tokenizer
# -----------------------------------------------------------------------------
# KVCache tests
def test_kv_cache_resize():
"""
The KV cache was not resized correctly, more information here:
@@ -185,3 +276,134 @@ def test_multi_sample_first_token_diversity():
f"With uniform logits, this is statistically impossible (~10^-36 probability) "
f"unless tokens are being broadcast instead of independently sampled."
)
# -----------------------------------------------------------------------------
# Batched generation tests
@pytest.mark.parametrize("use_pretrained", [False, True])
def test_batched_generation_consistency(use_pretrained):
"""
Test that batched generation produces the same results as individual generation.
This test:
1. Generates from each prompt individually (existing single-prompt behavior)
2. Generates from all prompts together in a batch (new batched behavior)
3. Asserts that the results match exactly
Uses temperature=0.0 for deterministic outputs.
"""
try:
model, tokenizer = get_model_and_tokenizer(use_pretrained=use_pretrained)
except Exception as e:
if use_pretrained:
pytest.skip(f"Could not load pretrained model: {e}")
raise
engine = Engine(model, tokenizer)
# Define test prompts with different lengths
bos = tokenizer.get_bos_token_id()
prompts = [
tokenizer.encode("hi", prepend=bos),
tokenizer.encode("the capital of France is", prepend=bos),
tokenizer.encode("hello, I'm a", prepend=bos),
]
num_samples = 2
# Deterministic decoding
generation_kwargs = dict(max_tokens=10, temperature=0.0, seed=0)
# 1) Generate individually for each prompt
individual_results = []
individual_masks = []
for prompt in prompts:
results, masks = engine.generate_batch(prompt, num_samples=num_samples, **generation_kwargs)
individual_results.append(results) # results is list[list[int]] of shape (num_samples, seq_len)
individual_masks.append(masks) # masks is list[list[int]] of shape (num_samples, seq_len)
# 2) Generate batched (all prompts together)
batched_results, batched_masks = engine.generate_batch(prompts, num_samples=num_samples, **generation_kwargs)
# 3) Assert results match
assert len(individual_results) == len(batched_results), \
f"Prompt count mismatch: {len(individual_results)} vs {len(batched_results)}"
for prompt_idx, (ind_samples, batch_samples, ind_masks, batch_masks) in enumerate(
zip(individual_results, batched_results, individual_masks, batched_masks)):
assert len(ind_samples) == len(batch_samples), f"Sample count mismatch for prompt {prompt_idx}"
for sample_idx, (ind_result, batch_result, ind_mask, batch_mask) in enumerate(
zip(ind_samples, batch_samples, ind_masks, batch_masks)):
assert ind_result == batch_result, (
f"Mismatch for prompt {prompt_idx}, sample {sample_idx}:\n"
f" Individual: {ind_result}\n"
f" Batched: {batch_result}"
)
assert ind_mask == batch_mask, (
f"Mask mismatch for prompt {prompt_idx}, sample {sample_idx}:\n"
f" Individual: {ind_mask}\n"
f" Batched: {batch_mask}"
)
def test_batched_generation_single_prompt():
"""
Test that batched generation with a single prompt in the batch
produces the same result as non-batched single prompt generation.
"""
model, tokenizer = get_model_and_tokenizer(use_pretrained=False)
engine = Engine(model, tokenizer)
bos = tokenizer.get_bos_token_id()
prompt = tokenizer.encode("the capital of France is", prepend=bos)
num_samples = 3
generation_kwargs = dict(max_tokens=8, temperature=0.0, seed=0)
# Generate non-batched: returns shape (num_samples, seq_len)
single_results, single_masks = engine.generate_batch(prompt, num_samples=num_samples, **generation_kwargs)
# Generate batched with single prompt: returns shape (1, num_samples, seq_len)
batched_results, batched_masks = engine.generate_batch([prompt], num_samples=num_samples, **generation_kwargs)
assert single_results == batched_results[0], (
f"Single vs batched single-prompt mismatch:\n"
f" Single: {single_results}\n"
f" Batched: {batched_results[0]}"
)
assert single_masks == batched_masks[0], (
f"Single vs batched single-prompt mask mismatch:\n"
f" Single: {single_masks}\n"
f" Batched: {batched_masks[0]}"
)
def test_batched_generation_stochastic():
"""
Test that batched generation with temperature > 0 produces diverse outputs.
"""
model, tokenizer = get_model_and_tokenizer(use_pretrained=False)
engine = Engine(model, tokenizer)
bos = tokenizer.get_bos_token_id()
prompts = [
tokenizer.encode("hi", prepend=bos),
tokenizer.encode("the capital of France is", prepend=bos),
]
num_samples = 4
generation_kwargs = dict(max_tokens=64, temperature=1.0, seed=0)
# Generate batched: returns shape (num_prompts, num_samples, seq_len)
results, _ = engine.generate_batch(prompts, num_samples=num_samples, **generation_kwargs)
# Check structure
assert len(results) == len(prompts)
# Check that samples within each prompt are diverse (not all identical)
for prompt_idx, samples in enumerate(results):
assert len(samples) == num_samples
unique_samples = set(tuple(s) for s in samples)
assert len(unique_samples) > 1, (
f"All {num_samples} samples for prompt {prompt_idx} are identical. "
f"With temperature=1.0, samples should differ."
)
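The header of this test file shows the generic invocation; to run only the new batched-generation tests, pytest's name filter works as well. Note that the use_pretrained=True case downloads karpathy/nanochat-d34 via huggingface_hub and is skipped if the checkpoint cannot be loaded.

python -m pytest tests/test_engine.py -k batched -v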