diff --git a/nanochat/engine.py b/nanochat/engine.py index 49b10b1..cc207e8 100644 --- a/nanochat/engine.py +++ b/nanochat/engine.py @@ -198,12 +198,32 @@ class Engine: @torch.inference_mode() def generate(self, tokens, num_samples=1, max_tokens=None, temperature=1.0, top_k=None, seed=42): - """Same as generate, but does single prefill and then clones the KV cache.""" - assert isinstance(tokens, list) and isinstance(tokens[0], int), "expecting list of ints" + """ + Generate tokens from prompt(s). Accepts either list[int] (single prompt) or + list[list[int]] (batched prompts). + + Yields: + (token_column, token_masks) tuples where both are nested list[list[int]] of + shape (num_prompts, num_samples) for batched input, or list[int] of shape + (num_samples,) for single prompt. Masks: 1=sampled, 0=forced. + """ + assert isinstance(tokens, list), "tokens must be a list" + + # Normalize input: convert single prompt to list of prompts + is_batched = len(tokens) > 0 and isinstance(tokens[0], list) + if is_batched: + prompts = tokens + else: + assert isinstance(tokens[0], int), "expecting list of ints or list of lists of ints" + prompts = [tokens] + device = self.model.get_device() rng = torch.Generator(device=device) rng.manual_seed(seed) + num_prompts = len(prompts) + total_rows = num_prompts * num_samples + # Get the special tokens we need to coordinate the tool use state machine get_special = lambda s: self.tokenizer.encode_special(s) python_start = get_special("<|python_start|>") @@ -213,33 +233,64 @@ class Engine: assistant_end = get_special("<|assistant_end|>") # if sampled, ends row bos = self.tokenizer.get_bos_token_id() # if sampled, ends row - # 1) Run a batch 1 prefill of the prompt tokens + # 1) Left-pad all prompts to max length and create attention mask + prompt_lengths = [len(p) for p in prompts] + max_prompt_len = max(prompt_lengths) + padded_prompts = [[0] * (max_prompt_len - len(p)) + p for p in prompts] + + # Create attention masks if padding is needed + decode_mask = None + prefill_attn_mask = None + if any(length != max_prompt_len for length in prompt_lengths): + # prompt_mask[b, t] = True if position t is a real token (not padding) for prompt b + prompt_mask = torch.zeros((num_prompts, max_prompt_len), dtype=torch.bool, device=device) + for i, length in enumerate(prompt_lengths): + prompt_mask[i, max_prompt_len - length:] = True + # causal_mask[q, k] = True if query at position q can attend to key at position k + causal_mask = torch.tril(torch.ones((max_prompt_len, max_prompt_len), dtype=torch.bool, device=device)) + # prefill_attn_mask combines prompt_mask and causal_mask: attend only to non-padding keys before the query position + # shape: (num_prompts, 1, max_prompt_len, max_prompt_len) - the 1 broadcasts across heads + prefill_attn_mask = (causal_mask.unsqueeze(0) & prompt_mask.unsqueeze(1)).unsqueeze(1) + # decode_mask tracks which positions are valid for each row during generation (will be updated after each step) + decode_mask = prompt_mask.repeat_interleave(num_samples, dim=0) + + # 2) Run batched prefill m = self.model.config kv_model_kwargs = {"num_heads": m.n_kv_head, "head_dim": m.n_embd // m.n_head, "num_layers": m.n_layer} kv_cache_prefill = KVCache( - batch_size=1, - seq_len=len(tokens), + batch_size=num_prompts, + seq_len=max_prompt_len, **kv_model_kwargs, ) - ids = torch.tensor([tokens], dtype=torch.long, device=device) - logits = self.model.forward(ids, kv_cache=kv_cache_prefill) - logits = logits[:, -1, :].expand(num_samples, -1) # (num_samples, vocab_size) 
- # 2) Replicate the KV cache for each sample/row - kv_length_hint = (len(tokens) + max_tokens) if max_tokens is not None else self.model.config.sequence_len + ids = torch.tensor(padded_prompts, dtype=torch.long, device=device) + logits = self.model.forward(ids, kv_cache=kv_cache_prefill, attention_mask=prefill_attn_mask) + logits = logits[:, -1, :] # (num_prompts, vocab_size) + + # 3) Expand KV cache for num_samples per prompt + kv_length_hint = (max_prompt_len + max_tokens) if max_tokens is not None else self.model.config.sequence_len kv_cache_decode = KVCache( - batch_size=num_samples, + batch_size=total_rows, seq_len=kv_length_hint, **kv_model_kwargs, ) - kv_cache_decode.prefill(kv_cache_prefill) - del kv_cache_prefill # no need to keep this memory around + # Initialize the decode cache from prefill cache, replicating for each sample + dtype, dev = kv_cache_prefill.kv_cache.dtype, kv_cache_prefill.kv_cache.device + kv_cache_decode.kv_cache = torch.empty(kv_cache_decode.kv_shape, dtype=dtype, device=dev) + for i in range(num_prompts): + src = kv_cache_prefill.kv_cache[:, :, i:i + 1, :, :max_prompt_len, :] + for j in range(num_samples): + kv_cache_decode.kv_cache[:, :, i * num_samples + j:i * num_samples + j + 1, :, :max_prompt_len, :] = src + kv_cache_decode.pos = max_prompt_len + del kv_cache_prefill # no need to keep this memory around - # 3) Initialize states for each sample - row_states = [RowState(tokens.copy()) for _ in range(num_samples)] + # Expand logits for num_samples per prompt + logits = logits.repeat_interleave(num_samples, dim=0) # (total_rows, vocab_size) - # 4) Main generation loop + # 4) Initialize row states and run generation loop + row_states = [RowState(prompt.copy()) for prompt in prompts for _ in range(num_samples)] num_generated = 0 + while True: # Stop condition: we've reached max tokens if max_tokens is not None and num_generated >= max_tokens: @@ -284,26 +335,60 @@ class Engine: elif state.in_python_block: state.python_expr_tokens.append(next_token) - # Yield the token column - yield token_column, token_masks + if is_batched: + # Yield shape (num_prompts, num_samples) + yield ([token_column[i * num_samples:(i + 1) * num_samples] for i in range(num_prompts)], + [token_masks[i * num_samples:(i + 1) * num_samples] for i in range(num_prompts)]) + else: + # Yield shape (num_samples,) + yield token_column, token_masks num_generated += 1 # Prepare logits for next iteration ids = torch.tensor(token_column, dtype=torch.long, device=device).unsqueeze(1) - logits = self.model.forward(ids, kv_cache=kv_cache_decode)[:, -1, :] # (B, vocab_size) + + if decode_mask is not None: + # Extend decode_mask with True for the new tokens + decode_mask = torch.cat( + [decode_mask, torch.ones((total_rows, 1), dtype=torch.bool, device=device)], dim=1 + ) + logits = self.model.forward( + ids, + kv_cache=kv_cache_decode, + attention_mask=decode_mask.unsqueeze(1).unsqueeze(1), # (B, 1, 1, T) + ) + else: + logits = self.model.forward(ids, kv_cache=kv_cache_decode) + logits = logits[:, -1, :] def generate_batch(self, tokens, num_samples=1, **kwargs): """ - Non-streaming batch generation that just returns the final token sequences. - Returns a list of token sequences (list of lists of ints). + Non-streaming batch generation that returns the final token sequences. Terminal tokens (assistant_end, bos) are not included in the results. + + Returns: + (results, masks): For batched input, both are list[list[list[int]]] of shape + (num_prompts, num_samples, seq_len). 
For single prompt, both are + list[list[int]] of shape (num_samples, seq_len). Masks: 1=sampled, 0=forced. """ assistant_end = self.tokenizer.encode_special("<|assistant_end|>") bos = self.tokenizer.get_bos_token_id() - results = [tokens.copy() for _ in range(num_samples)] - masks = [[0] * len(tokens) for _ in range(num_samples)] - completed = [False] * num_samples + + # Normalize input to list of prompts + is_batched = len(tokens) > 0 and isinstance(tokens[0], list) + prompts = tokens if is_batched else [tokens] + + # Work with flat structure internally (prompt0_sample0, prompt0_sample1, ..., prompt1_sample0, ...) + results = [p.copy() for p in prompts for _ in range(num_samples)] + masks = [[0] * len(p) for p in prompts for _ in range(num_samples)] + completed = [False] * len(results) + for token_column, token_masks in self.generate(tokens, num_samples, **kwargs): + # Flatten nested output from generate() if batched + if is_batched: + token_column = [t for row in token_column for t in row] + token_masks = [m for row in token_masks for m in row] + for i, (token, mask) in enumerate(zip(token_column, token_masks)): if not completed[i]: if token == assistant_end or token == bos: @@ -314,6 +399,11 @@ class Engine: # Stop if all rows are completed if all(completed): break + + # Reshape to nested structure for batched output + if is_batched: + results = [results[i * num_samples:(i + 1) * num_samples] for i in range(len(prompts))] + masks = [masks[i * num_samples:(i + 1) * num_samples] for i in range(len(prompts))] return results, masks diff --git a/nanochat/gpt.py b/nanochat/gpt.py index e6027a9..0d0be2c 100644 --- a/nanochat/gpt.py +++ b/nanochat/gpt.py @@ -61,7 +61,7 @@ class CausalSelfAttention(nn.Module): self.c_v = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False) self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False) - def forward(self, x, cos_sin, kv_cache): + def forward(self, x, cos_sin, kv_cache, attention_mask=None): B, T, C = x.size() # Project the input to get queries, keys, and values @@ -83,7 +83,10 @@ class CausalSelfAttention(nn.Module): # Attention: queries attend to keys/values autoregressively. 
A few cases to handle: enable_gqa = self.n_head != self.n_kv_head # Group Query Attention (GQA): duplicate key/value heads to match query heads if desired - if kv_cache is None or Tq == Tk: + if attention_mask is not None: + # Custom attention mask provided (for batched generation with padding) + y = F.scaled_dot_product_attention(q, k, v, attn_mask=attention_mask, enable_gqa=enable_gqa) + elif kv_cache is None or Tq == Tk: # During training (no KV cache), attend as usual with causal attention # And even if there is KV cache, we can still use this simple version when Tq == Tk y = F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=enable_gqa) @@ -126,8 +129,8 @@ class Block(nn.Module): self.attn = CausalSelfAttention(config, layer_idx) self.mlp = MLP(config) - def forward(self, x, cos_sin, kv_cache): - x = x + self.attn(norm(x), cos_sin, kv_cache) + def forward(self, x, cos_sin, kv_cache, attention_mask=None): + x = x + self.attn(norm(x), cos_sin, kv_cache, attention_mask) x = x + self.mlp(norm(x)) return x @@ -253,7 +256,7 @@ class GPT(nn.Module): group["initial_lr"] = group["lr"] return optimizers - def forward(self, idx, targets=None, kv_cache=None, loss_reduction='mean'): + def forward(self, idx, targets=None, kv_cache=None, attention_mask=None, loss_reduction='mean'): B, T = idx.size() # Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim/2)) @@ -268,7 +271,7 @@ class GPT(nn.Module): x = self.transformer.wte(idx) x = norm(x) for block in self.transformer.h: - x = block(x, cos_sin, kv_cache) + x = block(x, cos_sin, kv_cache, attention_mask) x = norm(x) # Forward the lm_head (compute logits) diff --git a/tests/test_engine.py b/tests/test_engine.py index 683f89b..ec601b5 100644 --- a/tests/test_engine.py +++ b/tests/test_engine.py @@ -4,10 +4,27 @@ Test Engine class. Example run: python -m pytest tests/test_engine.py -v """ +import os +import json import torch -from nanochat.engine import KVCache, Engine +import pytest from dataclasses import dataclass +from huggingface_hub import snapshot_download +from nanochat.engine import KVCache, Engine +from nanochat.gpt import GPT, GPTConfig +from nanochat.tokenizer import RustBPETokenizer +from nanochat.checkpoint_manager import find_last_step + +# ----------------------------------------------------------------------------- +# Ensure deterministic behavior for reproducible tests +# See: https://docs.pytorch.org/docs/stable/notes/randomness.html + +os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" # Required for CUDA >= 10.2 determinism +torch.manual_seed(0) +torch.use_deterministic_algorithms(True) +torch.backends.cudnn.benchmark = False + # ----------------------------------------------------------------------------- # Mock classes for testing Engine without loading a real model @@ -36,7 +53,7 @@ class MockModel: def get_device(self): return self._device - def forward(self, ids, kv_cache=None): + def forward(self, ids, kv_cache=None, attention_mask=None): """Return uniform logits so sampling is spread across vocab.""" B, T = ids.shape # Simulate what a real transformer does: insert k,v into the cache for each layer @@ -85,6 +102,80 @@ class ByteTokenizer: byte_tokens = [t for t in tokens if t < 256] return bytes(byte_tokens).decode("utf-8", errors="replace") + +def get_model_and_tokenizer(use_pretrained=False): + """ + Get a model and tokenizer for testing. Requires CUDA. + + Args: + use_pretrained: If True, download and load the pretrained nanochat-d34 model. 
+ If False, create a small randomly initialized model. + + Returns: + (model, tokenizer) tuple + """ + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required for these tests") + device = torch.device("cuda") + + if use_pretrained: + # Download the checkpoint + cache_dir = snapshot_download(repo_id="karpathy/nanochat-d34") + + # Find the last step + step = find_last_step(cache_dir) + + # Load model data + model_path = os.path.join(cache_dir, f"model_{step:06d}.pt") + model_data = torch.load(model_path, map_location=device) + + # Fix torch compile key prefix + model_data = {k.removeprefix("_orig_mod."): v for k, v in model_data.items()} + + # Convert all tensors to bfloat16 for consistent dtypes (checkpoint has mixed bfloat16/float32) + model_data = { + k: v.bfloat16() if v.is_floating_point() else v + for k, v in model_data.items() + } + + # Load metadata + meta_path = os.path.join(cache_dir, f"meta_{step:06d}.json") + with open(meta_path, "r", encoding="utf-8") as f: + meta_data = json.load(f) + + # Build model + model_config = GPTConfig(**meta_data["model_config"]) + with torch.device("meta"): + model = GPT(model_config) + model.to_empty(device=device) + model.init_weights() + model.load_state_dict(model_data, strict=True, assign=True) + model.eval() + + # Load tokenizer from the checkpoint directory + tokenizer = RustBPETokenizer.from_directory(cache_dir) + else: + # Small model for fast testing + config = GPTConfig( + sequence_len=256, + vocab_size=262, # 256 bytes + 6 special tokens + n_layer=2, + n_head=4, + n_kv_head=4, + n_embd=64, + ) + model = GPT(config) + model.init_weights() + model = model.to(device) + model.eval() + tokenizer = ByteTokenizer() + + return model, tokenizer + + +# ----------------------------------------------------------------------------- +# KVCache tests + def test_kv_cache_resize(): """ The KV cache was not resized correctly, more information here: @@ -185,3 +276,134 @@ def test_multi_sample_first_token_diversity(): f"With uniform logits, this is statistically impossible (~10^-36 probability) " f"unless tokens are being broadcast instead of independently sampled." ) + + +# ----------------------------------------------------------------------------- +# Batched generation tests + +@pytest.mark.parametrize("use_pretrained", [False, True]) +def test_batched_generation_consistency(use_pretrained): + """ + Test that batched generation produces the same results as individual generation. + + This test: + 1. Generates from each prompt individually (existing single-prompt behavior) + 2. Generates from all prompts together in a batch (new batched behavior) + 3. Asserts that the results match exactly + + Uses temperature=0.0 for deterministic outputs. 
+ """ + try: + model, tokenizer = get_model_and_tokenizer(use_pretrained=use_pretrained) + except Exception as e: + if use_pretrained: + pytest.skip(f"Could not load pretrained model: {e}") + raise + + engine = Engine(model, tokenizer) + + # Define test prompts with different lengths + bos = tokenizer.get_bos_token_id() + prompts = [ + tokenizer.encode("hi", prepend=bos), + tokenizer.encode("the capital of France is", prepend=bos), + tokenizer.encode("hello, I'm a", prepend=bos), + ] + + num_samples = 2 + # Deterministic decoding + generation_kwargs = dict(max_tokens=10, temperature=0.0, seed=0) + + # 1) Generate individually for each prompt + individual_results = [] + individual_masks = [] + for prompt in prompts: + results, masks = engine.generate_batch(prompt, num_samples=num_samples, **generation_kwargs) + individual_results.append(results) # results is list[list[int]] of shape (num_samples, seq_len) + individual_masks.append(masks) # masks is list[list[int]] of shape (num_samples, seq_len) + + # 2) Generate batched (all prompts together) + batched_results, batched_masks = engine.generate_batch(prompts, num_samples=num_samples, **generation_kwargs) + + # 3) Assert results match + assert len(individual_results) == len(batched_results), \ + f"Prompt count mismatch: {len(individual_results)} vs {len(batched_results)}" + + for prompt_idx, (ind_samples, batch_samples, ind_masks, batch_masks) in enumerate( + zip(individual_results, batched_results, individual_masks, batched_masks)): + assert len(ind_samples) == len(batch_samples), f"Sample count mismatch for prompt {prompt_idx}" + for sample_idx, (ind_result, batch_result, ind_mask, batch_mask) in enumerate( + zip(ind_samples, batch_samples, ind_masks, batch_masks)): + assert ind_result == batch_result, ( + f"Mismatch for prompt {prompt_idx}, sample {sample_idx}:\n" + f" Individual: {ind_result}\n" + f" Batched: {batch_result}" + ) + assert ind_mask == batch_mask, ( + f"Mask mismatch for prompt {prompt_idx}, sample {sample_idx}:\n" + f" Individual: {ind_mask}\n" + f" Batched: {batch_mask}" + ) + + +def test_batched_generation_single_prompt(): + """ + Test that batched generation with a single prompt in the batch + produces the same result as non-batched single prompt generation. + """ + model, tokenizer = get_model_and_tokenizer(use_pretrained=False) + engine = Engine(model, tokenizer) + + bos = tokenizer.get_bos_token_id() + prompt = tokenizer.encode("the capital of France is", prepend=bos) + num_samples = 3 + generation_kwargs = dict(max_tokens=8, temperature=0.0, seed=0) + + # Generate non-batched: returns shape (num_samples, seq_len) + single_results, single_masks = engine.generate_batch(prompt, num_samples=num_samples, **generation_kwargs) + + # Generate batched with single prompt: returns shape (1, num_samples, seq_len) + batched_results, batched_masks = engine.generate_batch([prompt], num_samples=num_samples, **generation_kwargs) + + assert single_results == batched_results[0], ( + f"Single vs batched single-prompt mismatch:\n" + f" Single: {single_results}\n" + f" Batched: {batched_results[0]}" + ) + assert single_masks == batched_masks[0], ( + f"Single vs batched single-prompt mask mismatch:\n" + f" Single: {single_masks}\n" + f" Batched: {batched_masks[0]}" + ) + + +def test_batched_generation_stochastic(): + """ + Test that batched generation with temperature > 0 produces diverse outputs. 
+ """ + model, tokenizer = get_model_and_tokenizer(use_pretrained=False) + engine = Engine(model, tokenizer) + + bos = tokenizer.get_bos_token_id() + prompts = [ + tokenizer.encode("hi", prepend=bos), + tokenizer.encode("the capital of France is", prepend=bos), + ] + + num_samples = 4 + generation_kwargs = dict(max_tokens=64, temperature=1.0, seed=0) + + # Generate batched: returns shape (num_prompts, num_samples, seq_len) + results, _ = engine.generate_batch(prompts, num_samples=num_samples, **generation_kwargs) + + # Check structure + assert len(results) == len(prompts) + + # Check that samples within each prompt are diverse (not all identical) + for prompt_idx, samples in enumerate(results): + assert len(samples) == num_samples + unique_samples = set(tuple(s) for s in samples) + assert len(unique_samples) > 1, ( + f"All {num_samples} samples for prompt {prompt_idx} are identical. " + f"With temperature=1.0, samples should differ." + )
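Note on the batched prefill (illustrative sketch, not part of the patch): the heart of the engine.py change is left-padding all prompts to a common length and building an attention mask that combines the padding mask with causality, so padded positions never act as keys. The snippet below mirrors that mask construction under the patch's shape conventions, using made-up prompt token values.

import torch

# Two made-up prompts of different lengths (token values are arbitrary).
prompts = [[7, 8, 9], [5]]
prompt_lengths = [len(p) for p in prompts]
max_prompt_len = max(prompt_lengths)

# Left-pad with 0s so both prompts fit one (num_prompts, max_prompt_len) tensor.
padded_prompts = [[0] * (max_prompt_len - len(p)) + p for p in prompts]

# prompt_mask[b, t] = True where position t holds a real (non-padding) token of prompt b.
prompt_mask = torch.zeros((len(prompts), max_prompt_len), dtype=torch.bool)
for i, length in enumerate(prompt_lengths):
    prompt_mask[i, max_prompt_len - length:] = True

# causal_mask[q, k] = True where query position q may attend to key position k.
causal_mask = torch.tril(torch.ones((max_prompt_len, max_prompt_len), dtype=torch.bool))

# Combined mask: attend only to non-padding keys at or before the query position.
# Shape (num_prompts, 1, max_prompt_len, max_prompt_len); the 1 broadcasts over heads.
prefill_attn_mask = (causal_mask.unsqueeze(0) & prompt_mask.unsqueeze(1)).unsqueeze(1)

print(padded_prompts)                 # [[7, 8, 9], [0, 0, 5]]
print(prefill_attn_mask[1, 0].int())  # the short prompt's only real token attends only to itself:
# tensor([[0, 0, 0],
#         [0, 0, 0],
#         [0, 0, 1]], dtype=torch.int32)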
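Note on the flat row layout (illustrative sketch, not part of the patch): the batched path keeps rows ordered as (prompt0_sample0, prompt0_sample1, ..., prompt1_sample0, ...), which is exactly what repeat_interleave(num_samples, dim=0) produces, and each yielded token column is sliced back into a (num_prompts, num_samples) nesting. During decode, the mask simply gains one always-True column per generated token, so padded prompt positions stay masked for the whole generation. A small sketch with toy values rather than real model outputs:

import torch

num_prompts, num_samples = 2, 3
total_rows = num_prompts * num_samples

# One logits row per prompt after prefill (toy 1-entry "vocab" for readability).
logits = torch.tensor([[10.0], [20.0]])
rows = logits.repeat_interleave(num_samples, dim=0)
print(rows.squeeze(1).tolist())  # [10.0, 10.0, 10.0, 20.0, 20.0, 20.0]

# A flat token column of length total_rows is regrouped per prompt, as in generate():
token_column = [101, 102, 103, 201, 202, 203]
nested = [token_column[i * num_samples:(i + 1) * num_samples] for i in range(num_prompts)]
print(nested)                    # [[101, 102, 103], [201, 202, 203]]

# The decode mask starts from the per-row prompt mask and grows by one True column per step.
prompt_mask = torch.tensor([[True, True, True], [False, False, True]])
decode_mask = prompt_mask.repeat_interleave(num_samples, dim=0)  # (total_rows, max_prompt_len)
decode_mask = torch.cat([decode_mask, torch.ones((total_rows, 1), dtype=torch.bool)], dim=1)
print(decode_mask.shape)         # torch.Size([6, 4])

The same row ordering is what the KV-cache replication loop and generate_batch's flatten/regroup steps rely on, so all three must stay consistent.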
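Note on the new attention branch in gpt.py (illustrative sketch, not part of the patch): when attention_mask is provided, CausalSelfAttention.forward() hands it to F.scaled_dot_product_attention as a boolean attn_mask (True = may attend). This branch takes precedence over the KV-cache branches, so the mask must already cover all cached positions, which is why decode_mask spans the full cache length. With no padding, a plain lower-triangular mask should reproduce is_causal=True, which gives a quick sanity check of the branch; a rough sketch with arbitrary shapes:

import torch
import torch.nn.functional as F

B, H, T, D = 2, 4, 5, 8
q, k, v = (torch.randn(B, H, T, D) for _ in range(3))

# (1, 1, T, T) lower-triangular boolean mask, broadcast over batch and heads.
causal = torch.tril(torch.ones(T, T, dtype=torch.bool)).unsqueeze(0).unsqueeze(1)

y_causal = F.scaled_dot_product_attention(q, k, v, is_causal=True)
y_masked = F.scaled_dot_product_attention(q, k, v, attn_mask=causal)
print(torch.allclose(y_causal, y_masked, atol=1e-5))  # expected: True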