From 557b2d5840799fc80c93a5c6a92f2ec4041cdccb Mon Sep 17 00:00:00 2001 From: Azekowka Date: Tue, 14 Oct 2025 17:33:47 +0500 Subject: [PATCH] feat(engine.py): Sample unique tokens per row in generation stream Before, when initiating a batch generation, the first sampled token was broadcast to all rows. This change updates the logic to ensure that a unique token is sampled for each row, improving the diversity and independence of generated sequences within a batch. --- nanochat/engine.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/nanochat/engine.py b/nanochat/engine.py index de1253a..5daa8eb 100644 --- a/nanochat/engine.py +++ b/nanochat/engine.py @@ -185,7 +185,7 @@ class Engine: seq_len=len(tokens), **kv_model_kwargs, ) - ids = torch.tensor([tokens], dtype=torch.long, device=device) + ids = torch.tensor([tokens.copy() for _ in range(num_samples)], dtype=torch.long, device=device) logits = self.model.forward(ids, kv_cache=kv_cache_prefill) logits = logits[:, -1, :] next_ids = sample_next_token(logits, rng, temperature, top_k) # (B, 1) @@ -218,7 +218,7 @@ class Engine: # Get sampled tokens - either from prefill or from forward pass if first_iteration: # Use the tokens we already sampled from prefill - sampled_tokens = [sampled_tokens[0]] * num_samples # Broadcast first token to all rows + # sampled_tokens = [sampled_tokens[0]] * num_samples # Broadcast first token to all rows # TODO: we should sample a token for each row instead of broadcasting first_iteration = False else: