Mirror of https://github.com/karpathy/nanochat.git (synced 2026-02-06 02:29:53 +00:00)
feat: batched generation from multiple prompts
This commit is contained in: parent 48abd7d85f, commit b7df9f8eaa
nanochat/engine.py
@@ -198,12 +198,32 @@ class Engine:
    @torch.inference_mode()
    def generate(self, tokens, num_samples=1, max_tokens=None, temperature=1.0, top_k=None, seed=42):
        """Same as generate, but does single prefill and then clones the KV cache."""
        assert isinstance(tokens, list) and isinstance(tokens[0], int), "expecting list of ints"
        """
        Generate tokens from prompt(s). Accepts either list[int] (single prompt) or
        list[list[int]] (batched prompts).

        Yields:
            (token_column, token_masks) tuples where both are nested list[list[int]] of
            shape (num_prompts, num_samples) for batched input, or list[int] of shape
            (num_samples,) for single prompt. Masks: 1=sampled, 0=forced.
        """
        assert isinstance(tokens, list), "tokens must be a list"

        # Normalize input: convert single prompt to list of prompts
        is_batched = len(tokens) > 0 and isinstance(tokens[0], list)
        if is_batched:
            prompts = tokens
        else:
            assert isinstance(tokens[0], int), "expecting list of ints or list of lists of ints"
            prompts = [tokens]

        device = self.model.get_device()
        rng = torch.Generator(device=device)
        rng.manual_seed(seed)

        num_prompts = len(prompts)
        total_rows = num_prompts * num_samples

        # Get the special tokens we need to coordinate the tool use state machine
        get_special = lambda s: self.tokenizer.encode_special(s)
        python_start = get_special("<|python_start|>")
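The docstring above defines the new streaming contract. A minimal sketch of how a caller might consume it, assuming an Engine instance `engine` and its tokenizer `tok` built as in tests/test_engine.py further down; the helper name `stream_batched` is illustrative only, not part of this commit:

def stream_batched(engine, tok, texts, num_samples=2, max_tokens=16):
    # Encode each text into a token prompt; passing list[list[int]] selects the batched mode.
    bos = tok.get_bos_token_id()
    prompts = [tok.encode(t, prepend=bos) for t in texts]
    # rows[p][s] collects the tokens sampled for prompt p, sample s.
    rows = [[[] for _ in range(num_samples)] for _ in range(len(prompts))]
    for token_column, token_masks in engine.generate(prompts, num_samples=num_samples, max_tokens=max_tokens):
        # Batched input: each yield is shaped (num_prompts, num_samples).
        for p, (col, msk) in enumerate(zip(token_column, token_masks)):
            for s, (tok_id, m) in enumerate(zip(col, msk)):
                if m == 1:  # 1 = sampled by the model, 0 = forced (e.g. tool output)
                    rows[p][s].append(tok_id)
    return rows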
@@ -213,33 +233,64 @@ class Engine:
        assistant_end = get_special("<|assistant_end|>") # if sampled, ends row
        bos = self.tokenizer.get_bos_token_id() # if sampled, ends row

        # 1) Run a batch 1 prefill of the prompt tokens
        # 1) Left-pad all prompts to max length and create attention mask
        prompt_lengths = [len(p) for p in prompts]
        max_prompt_len = max(prompt_lengths)
        padded_prompts = [[0] * (max_prompt_len - len(p)) + p for p in prompts]

        # Create attention masks if padding is needed
        decode_mask = None
        prefill_attn_mask = None
        if any(length != max_prompt_len for length in prompt_lengths):
            # prompt_mask[b, t] = True if position t is a real token (not padding) for prompt b
            prompt_mask = torch.zeros((num_prompts, max_prompt_len), dtype=torch.bool, device=device)
            for i, length in enumerate(prompt_lengths):
                prompt_mask[i, max_prompt_len - length:] = True
            # causal_mask[q, k] = True if query at position q can attend to key at position k
            causal_mask = torch.tril(torch.ones((max_prompt_len, max_prompt_len), dtype=torch.bool, device=device))
            # prefill_attn_mask combines prompt_mask and causal_mask: attend only to non-padding keys before the query position
            # shape: (num_prompts, 1, max_prompt_len, max_prompt_len) - the 1 broadcasts across heads
            prefill_attn_mask = (causal_mask.unsqueeze(0) & prompt_mask.unsqueeze(1)).unsqueeze(1)
            # decode_mask tracks which positions are valid for each row during generation (will be updated after each step)
            decode_mask = prompt_mask.repeat_interleave(num_samples, dim=0)

        # 2) Run batched prefill
        m = self.model.config
        kv_model_kwargs = {"num_heads": m.n_kv_head, "head_dim": m.n_embd // m.n_head, "num_layers": m.n_layer}
        kv_cache_prefill = KVCache(
            batch_size=1,
            seq_len=len(tokens),
            batch_size=num_prompts,
            seq_len=max_prompt_len,
            **kv_model_kwargs,
        )
        ids = torch.tensor([tokens], dtype=torch.long, device=device)
        logits = self.model.forward(ids, kv_cache=kv_cache_prefill)
        logits = logits[:, -1, :].expand(num_samples, -1) # (num_samples, vocab_size)

        # 2) Replicate the KV cache for each sample/row
        kv_length_hint = (len(tokens) + max_tokens) if max_tokens is not None else self.model.config.sequence_len
        ids = torch.tensor(padded_prompts, dtype=torch.long, device=device)
        logits = self.model.forward(ids, kv_cache=kv_cache_prefill, attention_mask=prefill_attn_mask)
        logits = logits[:, -1, :] # (num_prompts, vocab_size)

        # 3) Expand KV cache for num_samples per prompt
        kv_length_hint = (max_prompt_len + max_tokens) if max_tokens is not None else self.model.config.sequence_len
        kv_cache_decode = KVCache(
            batch_size=num_samples,
            batch_size=total_rows,
            seq_len=kv_length_hint,
            **kv_model_kwargs,
        )
        kv_cache_decode.prefill(kv_cache_prefill)
        del kv_cache_prefill # no need to keep this memory around
        # Initialize the decode cache from prefill cache, replicating for each sample
        dtype, dev = kv_cache_prefill.kv_cache.dtype, kv_cache_prefill.kv_cache.device
        kv_cache_decode.kv_cache = torch.empty(kv_cache_decode.kv_shape, dtype=dtype, device=dev)
        for i in range(num_prompts):
            src = kv_cache_prefill.kv_cache[:, :, i:i + 1, :, :max_prompt_len, :]
            for j in range(num_samples):
                kv_cache_decode.kv_cache[:, :, i * num_samples + j:i * num_samples + j + 1, :, :max_prompt_len, :] = src
        kv_cache_decode.pos = max_prompt_len
        del kv_cache_prefill # no need to keep this memory around

        # 3) Initialize states for each sample
        row_states = [RowState(tokens.copy()) for _ in range(num_samples)]
        # Expand logits for num_samples per prompt
        logits = logits.repeat_interleave(num_samples, dim=0) # (total_rows, vocab_size)

        # 4) Main generation loop
        # 4) Initialize row states and run generation loop
        row_states = [RowState(prompt.copy()) for prompt in prompts for _ in range(num_samples)]
        num_generated = 0

        while True:
            # Stop condition: we've reached max tokens
            if max_tokens is not None and num_generated >= max_tokens:
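The left-padding and prefill-mask construction above can be exercised standalone. A pure-torch sketch on toy prompts (all names local to this example): the mask allows a query to attend to a key only if the key is at or before it and is not padding.

import torch

prompts = [[7, 8], [1, 2, 3, 4]]                     # two prompts of different lengths
lengths = [len(p) for p in prompts]
T = max(lengths)
padded = [[0] * (T - len(p)) + p for p in prompts]   # left-pad with token id 0

# True where a position holds a real token, False where it is padding
prompt_mask = torch.zeros((len(prompts), T), dtype=torch.bool)
for i, n in enumerate(lengths):
    prompt_mask[i, T - n:] = True
causal = torch.tril(torch.ones((T, T), dtype=torch.bool))
# (num_prompts, 1, T, T): query t may attend to key k only if k <= t and key k is not padding
attn_mask = (causal.unsqueeze(0) & prompt_mask.unsqueeze(1)).unsqueeze(1)
print(attn_mask[0, 0].int())                         # first prompt: the two padded key columns stay 0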
@@ -284,26 +335,60 @@ class Engine:
                elif state.in_python_block:
                    state.python_expr_tokens.append(next_token)

            # Yield the token column
            yield token_column, token_masks
            if is_batched:
                # Yield shape (num_prompts, num_samples)
                yield ([token_column[i * num_samples:(i + 1) * num_samples] for i in range(num_prompts)],
                       [token_masks[i * num_samples:(i + 1) * num_samples] for i in range(num_prompts)])
            else:
                # Yield shape (num_samples,)
                yield token_column, token_masks
            num_generated += 1

            # Prepare logits for next iteration
            ids = torch.tensor(token_column, dtype=torch.long, device=device).unsqueeze(1)
            logits = self.model.forward(ids, kv_cache=kv_cache_decode)[:, -1, :] # (B, vocab_size)

            if decode_mask is not None:
                # Extend decode_mask with True for the new tokens
                decode_mask = torch.cat(
                    [decode_mask, torch.ones((total_rows, 1), dtype=torch.bool, device=device)], dim=1
                )
                logits = self.model.forward(
                    ids,
                    kv_cache=kv_cache_decode,
                    attention_mask=decode_mask.unsqueeze(1).unsqueeze(1), # (B, 1, 1, T)
                )
            else:
                logits = self.model.forward(ids, kv_cache=kv_cache_decode)
            logits = logits[:, -1, :]

    def generate_batch(self, tokens, num_samples=1, **kwargs):
        """
        Non-streaming batch generation that just returns the final token sequences.
        Returns a list of token sequences (list of lists of ints).
        Non-streaming batch generation that returns the final token sequences.
        Terminal tokens (assistant_end, bos) are not included in the results.

        Returns:
            (results, masks): For batched input, both are list[list[list[int]]] of shape
            (num_prompts, num_samples, seq_len). For single prompt, both are
            list[list[int]] of shape (num_samples, seq_len). Masks: 1=sampled, 0=forced.
        """
        assistant_end = self.tokenizer.encode_special("<|assistant_end|>")
        bos = self.tokenizer.get_bos_token_id()
        results = [tokens.copy() for _ in range(num_samples)]
        masks = [[0] * len(tokens) for _ in range(num_samples)]
        completed = [False] * num_samples

        # Normalize input to list of prompts
        is_batched = len(tokens) > 0 and isinstance(tokens[0], list)
        prompts = tokens if is_batched else [tokens]

        # Work with flat structure internally (prompt0_sample0, prompt0_sample1, ..., prompt1_sample0, ...)
        results = [p.copy() for p in prompts for _ in range(num_samples)]
        masks = [[0] * len(p) for p in prompts for _ in range(num_samples)]
        completed = [False] * len(results)

        for token_column, token_masks in self.generate(tokens, num_samples, **kwargs):
            # Flatten nested output from generate() if batched
            if is_batched:
                token_column = [t for row in token_column for t in row]
                token_masks = [m for row in token_masks for m in row]

            for i, (token, mask) in enumerate(zip(token_column, token_masks)):
                if not completed[i]:
                    if token == assistant_end or token == bos:
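Rows are ordered prompt-major throughout (row i * num_samples + j belongs to prompt i, sample j), and the slicing above converts between the flat and nested views. A plain-Python sketch of that bookkeeping, independent of nanochat:

num_prompts, num_samples = 3, 2
flat = [f"p{i}s{j}" for i in range(num_prompts) for j in range(num_samples)]
# flat == ['p0s0', 'p0s1', 'p1s0', 'p1s1', 'p2s0', 'p2s1']

# Nest into (num_prompts, num_samples), as the batched yield and the final reshape do
nested = [flat[i * num_samples:(i + 1) * num_samples] for i in range(num_prompts)]
assert nested == [['p0s0', 'p0s1'], ['p1s0', 'p1s1'], ['p2s0', 'p2s1']]

# Flatten back, as generate_batch does with the nested yields
assert [x for row in nested for x in row] == flat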
@@ -314,6 +399,11 @@ class Engine:
            # Stop if all rows are completed
            if all(completed):
                break

        # Reshape to nested structure for batched output
        if is_batched:
            results = [results[i * num_samples:(i + 1) * num_samples] for i in range(len(prompts))]
            masks = [masks[i * num_samples:(i + 1) * num_samples] for i in range(len(prompts))]
        return results, masks
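A usage sketch for the non-streaming path, assuming `engine` and `tokenizer` are constructed as in tests/test_engine.py further down, and that the tokenizer exposes a decode() method (both assumptions, for illustration only):

bos = tokenizer.get_bos_token_id()
prompts = [
    tokenizer.encode("hi", prepend=bos),
    tokenizer.encode("the capital of France is", prepend=bos),
]
# Batched input: results[p][s] is the full token sequence (prompt + completion) for prompt p,
# sample s, and masks[p][s] marks sampled (1) vs forced (0) tokens.
results, masks = engine.generate_batch(prompts, num_samples=2, max_tokens=16, temperature=0.0)
for p, samples in enumerate(results):
    for s, toks in enumerate(samples):
        print(p, s, tokenizer.decode(toks))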
nanochat/gpt.py
@@ -61,7 +61,7 @@ class CausalSelfAttention(nn.Module):
        self.c_v = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
        self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)

    def forward(self, x, cos_sin, kv_cache):
    def forward(self, x, cos_sin, kv_cache, attention_mask=None):
        B, T, C = x.size()

        # Project the input to get queries, keys, and values
@@ -83,7 +83,10 @@ class CausalSelfAttention(nn.Module):
        # Attention: queries attend to keys/values autoregressively. A few cases to handle:
        enable_gqa = self.n_head != self.n_kv_head # Group Query Attention (GQA): duplicate key/value heads to match query heads if desired
        if kv_cache is None or Tq == Tk:
        if attention_mask is not None:
            # Custom attention mask provided (for batched generation with padding)
            y = F.scaled_dot_product_attention(q, k, v, attn_mask=attention_mask, enable_gqa=enable_gqa)
        elif kv_cache is None or Tq == Tk:
            # During training (no KV cache), attend as usual with causal attention
            # And even if there is KV cache, we can still use this simple version when Tq == Tk
            y = F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=enable_gqa)
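When no padding is involved, the new attn_mask branch should reduce to the is_causal path. A standalone pure-torch check (shapes and names are local to this sketch):

import torch
import torch.nn.functional as F

B, H, T, D = 2, 4, 5, 8
q, k, v = (torch.randn(B, H, T, D) for _ in range(3))
# Lower-triangular boolean mask, broadcast over batch and heads
causal = torch.tril(torch.ones(T, T, dtype=torch.bool)).view(1, 1, T, T)
y_mask = F.scaled_dot_product_attention(q, k, v, attn_mask=causal)
y_causal = F.scaled_dot_product_attention(q, k, v, is_causal=True)
assert torch.allclose(y_mask, y_causal, atol=1e-5)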
@@ -126,8 +129,8 @@ class Block(nn.Module):
        self.attn = CausalSelfAttention(config, layer_idx)
        self.mlp = MLP(config)

    def forward(self, x, cos_sin, kv_cache):
        x = x + self.attn(norm(x), cos_sin, kv_cache)
    def forward(self, x, cos_sin, kv_cache, attention_mask=None):
        x = x + self.attn(norm(x), cos_sin, kv_cache, attention_mask)
        x = x + self.mlp(norm(x))
        return x
@@ -253,7 +256,7 @@ class GPT(nn.Module):
            group["initial_lr"] = group["lr"]
        return optimizers

    def forward(self, idx, targets=None, kv_cache=None, loss_reduction='mean'):
    def forward(self, idx, targets=None, kv_cache=None, attention_mask=None, loss_reduction='mean'):
        B, T = idx.size()

        # Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim/2))
@@ -268,7 +271,7 @@ class GPT(nn.Module):
        x = self.transformer.wte(idx)
        x = norm(x)
        for block in self.transformer.h:
            x = block(x, cos_sin, kv_cache)
            x = block(x, cos_sin, kv_cache, attention_mask)
        x = norm(x)

        # Forward the lm_head (compute logits)
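During decoding, the mask threaded through here has shape (B, 1, 1, T): one query per row broadcasting over all cached keys, with padded key positions set to False. A standalone pure-torch sketch of that shape convention (names local to the example):

import torch
import torch.nn.functional as F

B, H, T, D = 2, 4, 6, 8
q = torch.randn(B, H, 1, D)            # a single decode step per row
k = torch.randn(B, H, T, D)            # T cached key positions
v = torch.randn(B, H, T, D)

valid = torch.ones(B, T, dtype=torch.bool)
valid[0, :2] = False                   # row 0 was left-padded by two positions
y = F.scaled_dot_product_attention(q, k, v, attn_mask=valid.view(B, 1, 1, T))
print(y.shape)                         # torch.Size([2, 4, 1, 8])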
tests/test_engine.py
@@ -4,10 +4,27 @@ Test Engine class. Example run:
python -m pytest tests/test_engine.py -v
"""

import os
import json
import torch
from nanochat.engine import KVCache, Engine
import pytest
from dataclasses import dataclass

from huggingface_hub import snapshot_download
from nanochat.engine import KVCache, Engine
from nanochat.gpt import GPT, GPTConfig
from nanochat.tokenizer import RustBPETokenizer
from nanochat.checkpoint_manager import find_last_step

# -----------------------------------------------------------------------------
# Ensure deterministic behavior for reproducible tests
# See: https://docs.pytorch.org/docs/stable/notes/randomness.html

os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" # Required for CUDA >= 10.2 determinism
torch.manual_seed(0)
torch.use_deterministic_algorithms(True)
torch.backends.cudnn.benchmark = False


# -----------------------------------------------------------------------------
# Mock classes for testing Engine without loading a real model
@@ -36,7 +53,7 @@ class MockModel:
    def get_device(self):
        return self._device

    def forward(self, ids, kv_cache=None):
    def forward(self, ids, kv_cache=None, attention_mask=None):
        """Return uniform logits so sampling is spread across vocab."""
        B, T = ids.shape
        # Simulate what a real transformer does: insert k,v into the cache for each layer
@@ -85,6 +102,80 @@ class ByteTokenizer:
        byte_tokens = [t for t in tokens if t < 256]
        return bytes(byte_tokens).decode("utf-8", errors="replace")


def get_model_and_tokenizer(use_pretrained=False):
    """
    Get a model and tokenizer for testing. Requires CUDA.

    Args:
        use_pretrained: If True, download and load the pretrained nanochat-d34 model.
                        If False, create a small randomly initialized model.

    Returns:
        (model, tokenizer) tuple
    """
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is required for these tests")
    device = torch.device("cuda")

    if use_pretrained:
        # Download the checkpoint
        cache_dir = snapshot_download(repo_id="karpathy/nanochat-d34")

        # Find the last step
        step = find_last_step(cache_dir)

        # Load model data
        model_path = os.path.join(cache_dir, f"model_{step:06d}.pt")
        model_data = torch.load(model_path, map_location=device)

        # Fix torch compile key prefix
        model_data = {k.removeprefix("_orig_mod."): v for k, v in model_data.items()}

        # Convert all tensors to bfloat16 for consistent dtypes (checkpoint has mixed bfloat16/float32)
        model_data = {
            k: v.bfloat16() if v.is_floating_point() else v
            for k, v in model_data.items()
        }

        # Load metadata
        meta_path = os.path.join(cache_dir, f"meta_{step:06d}.json")
        with open(meta_path, "r", encoding="utf-8") as f:
            meta_data = json.load(f)

        # Build model
        model_config = GPTConfig(**meta_data["model_config"])
        with torch.device("meta"):
            model = GPT(model_config)
        model.to_empty(device=device)
        model.init_weights()
        model.load_state_dict(model_data, strict=True, assign=True)
        model.eval()

        # Load tokenizer from the checkpoint directory
        tokenizer = RustBPETokenizer.from_directory(cache_dir)
    else:
        # Small model for fast testing
        config = GPTConfig(
            sequence_len=256,
            vocab_size=262, # 256 bytes + 6 special tokens
            n_layer=2,
            n_head=4,
            n_kv_head=4,
            n_embd=64,
        )
        model = GPT(config)
        model.init_weights()
        model = model.to(device)
        model.eval()
        tokenizer = ByteTokenizer()

    return model, tokenizer


# -----------------------------------------------------------------------------
# KVCache tests

def test_kv_cache_resize():
    """
    The KV cache was not resized correctly, more information here:
@@ -185,3 +276,134 @@ def test_multi_sample_first_token_diversity():
f"With uniform logits, this is statistically impossible (~10^-36 probability) "
|
||||
f"unless tokens are being broadcast instead of independently sampled."
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Batched generation tests
|
||||
|
||||
@pytest.mark.parametrize("use_pretrained", [False, True])
|
||||
def test_batched_generation_consistency(use_pretrained):
|
||||
"""
|
||||
Test that batched generation produces the same results as individual generation.
|
||||
|
||||
This test:
|
||||
1. Generates from each prompt individually (existing single-prompt behavior)
|
||||
2. Generates from all prompts together in a batch (new batched behavior)
|
||||
3. Asserts that the results match exactly
|
||||
|
||||
Uses temperature=0.0 for deterministic outputs.
|
||||
"""
|
||||
try:
|
||||
model, tokenizer = get_model_and_tokenizer(use_pretrained=use_pretrained)
|
||||
except Exception as e:
|
||||
if use_pretrained:
|
||||
pytest.skip(f"Could not load pretrained model: {e}")
|
||||
raise
|
||||
|
||||
engine = Engine(model, tokenizer)
|
||||
|
||||
# Define test prompts with different lengths
|
||||
bos = tokenizer.get_bos_token_id()
|
||||
prompts = [
|
||||
tokenizer.encode("hi", prepend=bos),
|
||||
tokenizer.encode("the capital of France is", prepend=bos),
|
||||
tokenizer.encode("hello, I'm a", prepend=bos),
|
||||
]
|
||||
|
||||
num_samples = 2
|
||||
# Deterministic decoding
|
||||
generation_kwargs = dict(max_tokens=10, temperature=0.0, seed=0)
|
||||
|
||||
# 1) Generate individually for each prompt
|
||||
individual_results = []
|
||||
individual_masks = []
|
||||
for prompt in prompts:
|
||||
results, masks = engine.generate_batch(prompt, num_samples=num_samples, **generation_kwargs)
|
||||
individual_results.append(results) # results is list[list[int]] of shape (num_samples, seq_len)
|
||||
individual_masks.append(masks) # masks is list[list[int]] of shape (num_samples, seq_len)
|
||||
|
||||
# 2) Generate batched (all prompts together)
|
||||
batched_results, batched_masks = engine.generate_batch(prompts, num_samples=num_samples, **generation_kwargs)
|
||||
|
||||
# 3) Assert results match
|
||||
assert len(individual_results) == len(batched_results), \
|
||||
f"Prompt count mismatch: {len(individual_results)} vs {len(batched_results)}"
|
||||
|
||||
for prompt_idx, (ind_samples, batch_samples, ind_masks, batch_masks) in enumerate(
|
||||
zip(individual_results, batched_results, individual_masks, batched_masks)):
|
||||
assert len(ind_samples) == len(batch_samples), f"Sample count mismatch for prompt {prompt_idx}"
|
||||
for sample_idx, (ind_result, batch_result, ind_mask, batch_mask) in enumerate(
|
||||
zip(ind_samples, batch_samples, ind_masks, batch_masks)):
|
||||
assert ind_result == batch_result, (
|
||||
f"Mismatch for prompt {prompt_idx}, sample {sample_idx}:\n"
|
||||
f" Individual: {ind_result}\n"
|
||||
f" Batched: {batch_result}"
|
||||
)
|
||||
assert ind_mask == batch_mask, (
|
||||
f"Mask mismatch for prompt {prompt_idx}, sample {sample_idx}:\n"
|
||||
f" Individual: {ind_mask}\n"
|
||||
f" Batched: {batch_mask}"
|
||||
)
|
||||
|
||||
|
||||
def test_batched_generation_single_prompt():
|
||||
"""
|
||||
Test that batched generation with a single prompt in the batch
|
||||
produces the same result as non-batched single prompt generation.
|
||||
"""
|
||||
model, tokenizer = get_model_and_tokenizer(use_pretrained=False)
|
||||
engine = Engine(model, tokenizer)
|
||||
|
||||
bos = tokenizer.get_bos_token_id()
|
||||
prompt = tokenizer.encode("the capital of France is", prepend=bos)
|
||||
num_samples = 3
|
||||
generation_kwargs = dict(max_tokens=8, temperature=0.0, seed=0)
|
||||
|
||||
# Generate non-batched: returns shape (num_samples, seq_len)
|
||||
single_results, single_masks = engine.generate_batch(prompt, num_samples=num_samples, **generation_kwargs)
|
||||
|
||||
# Generate batched with single prompt: returns shape (1, num_samples, seq_len)
|
||||
batched_results, batched_masks = engine.generate_batch([prompt], num_samples=num_samples, **generation_kwargs)
|
||||
|
||||
assert single_results == batched_results[0], (
|
||||
f"Single vs batched single-prompt mismatch:\n"
|
||||
f" Single: {single_results}\n"
|
||||
f" Batched: {batched_results[0]}"
|
||||
)
|
||||
assert single_masks == batched_masks[0], (
|
||||
f"Single vs batched single-prompt mask mismatch:\n"
|
||||
f" Single: {single_masks}\n"
|
||||
f" Batched: {batched_masks[0]}"
|
||||
)
|
||||
|
||||
|
||||
def test_batched_generation_stochastic():
|
||||
"""
|
||||
Test that batched generation with temperature > 0 produces diverse outputs.
|
||||
"""
|
||||
model, tokenizer = get_model_and_tokenizer(use_pretrained=False)
|
||||
engine = Engine(model, tokenizer)
|
||||
|
||||
bos = tokenizer.get_bos_token_id()
|
||||
prompts = [
|
||||
tokenizer.encode("hi", prepend=bos),
|
||||
tokenizer.encode("the capital of France is", prepend=bos),
|
||||
]
|
||||
|
||||
num_samples = 4
|
||||
generation_kwargs = dict(max_tokens=64, temperature=1.0, seed=0)
|
||||
|
||||
# Generate batched: returns shape (num_prompts, num_samples, seq_len)
|
||||
results, _ = engine.generate_batch(prompts, num_samples=num_samples, **generation_kwargs)
|
||||
|
||||
# Check structure
|
||||
assert len(results) == len(prompts)
|
||||
|
||||
# Check that samples within each prompt are diverse (not all identical)
|
||||
for prompt_idx, samples in enumerate(results):
|
||||
assert len(samples) == num_samples
|
||||
unique_samples = set(tuple(s) for s in samples)
|
||||
assert len(unique_samples) > 1, (
|
||||
f"All {num_samples} samples for prompt {prompt_idx} are identical. "
|
||||
f"With temperature=1.0, samples should differ."
|
||||
)