From 7bd999ba0252e318ce1a44b249626f4c6b222346 Mon Sep 17 00:00:00 2001
From: Azekowka
Date: Wed, 29 Oct 2025 22:02:28 +0500
Subject: [PATCH] feat(engine.py): Sample unique initial tokens for each sequence in a batch

Previously, when starting a batch generation, the first sampled token was
broadcast to all sequences. This change expands the output logits to
(num_samples, vocab_size) after a single efficient prefill, so that a
unique starting token is sampled for each sequence in the batch.
---
 nanochat/engine.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/nanochat/engine.py b/nanochat/engine.py
index 44ed16b..06e8671 100644
--- a/nanochat/engine.py
+++ b/nanochat/engine.py
@@ -216,6 +216,8 @@ class Engine:
         ids = torch.tensor([tokens], dtype=torch.long, device=device)
         logits = self.model.forward(ids, kv_cache=kv_cache_prefill)
         logits = logits[:, -1, :]
+        # expand the logits to be of size (num_samples, vocab_size)
+        logits = logits.expand(num_samples, -1)
         next_ids = sample_next_token(logits, rng, temperature, top_k) # (B, 1)
         sampled_tokens = next_ids[:, 0].tolist()
 
@@ -246,8 +248,6 @@ class Engine:
             # Get sampled tokens - either from prefill or from forward pass
             if first_iteration:
                 # Use the tokens we already sampled from prefill
-                sampled_tokens = [sampled_tokens[0]] * num_samples # Broadcast first token to all rows
-                # TODO: we should sample a token for each row instead of broadcasting
                 first_iteration = False
             else:
                 # Forward the model and get the next token for each row
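
---
A minimal standalone sketch of the idea for reviewers, not part of the
patch: it substitutes a toy sampler for engine.py's sample_next_token
(whose internals are not shown here), and demonstrates that expand()
turns the single prefill logits row into a (num_samples, vocab_size)
view so that each sequence draws its own starting token.

import torch

# Toy stand-in for engine.py's sample_next_token; assumed behavior,
# the real implementation may differ (e.g. top-k filtering).
def sample_next_token(logits, rng, temperature=1.0, top_k=None):
    probs = torch.softmax(logits / temperature, dim=-1)
    return torch.multinomial(probs, num_samples=1, generator=rng)  # (B, 1)

vocab_size, num_samples = 8, 4
rng = torch.Generator().manual_seed(0)

# A single prefill pass yields logits of shape (1, vocab_size).
logits = torch.randn(1, vocab_size, generator=rng)

# expand() creates a (num_samples, vocab_size) view without copying
# memory (unlike repeat(), which would allocate); each row is then
# sampled independently instead of one token being broadcast to all.
logits = logits.expand(num_samples, -1)
next_ids = sample_next_token(logits, rng)  # (num_samples, 1)
print(next_ids[:, 0].tolist())  # independent draws per row, so tokens can differ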