diff --git a/nanochat/engine.py b/nanochat/engine.py index f585b87..522d78f 100644 --- a/nanochat/engine.py +++ b/nanochat/engine.py @@ -188,9 +188,7 @@ class Engine: ids = torch.tensor([tokens], dtype=torch.long, device=device) logits = self.model.forward(ids, kv_cache=kv_cache_prefill) logits = logits[:, -1, :] - # Sample num_samples independent tokens from the same distribution - logits_repeated = logits.repeat(num_samples, 1) - next_ids = sample_next_token(logits_repeated, rng, temperature, top_k) # (B, 1) + next_ids = sample_next_token(logits, rng, temperature, top_k) # (B, 1) sampled_tokens = next_ids[:, 0].tolist() # 2) Replicate the KV cache for each sample/row @@ -219,9 +217,7 @@ class Engine: # Get sampled tokens - either from prefill or from forward pass if first_iteration: - # Use the tokens we already sampled from prefill - sampled_tokens = [sampled_tokens[0]] * num_samples # Broadcast first token to all rows - # TODO: we should sample a token for each row instead of broadcasting + # sampled_tokens already contains num_samples independently sampled tokens first_iteration = False else: # Forward the model and get the next token for each row