Merge pull request #14 from Dianababaei/refactor/engine-remove-token-broadcasting-first-iteration

refactor(engine): Remove 2 unnecessary lines from Engine class implementation
This commit is contained in:
Dianababaei 2025-11-05 20:01:54 +03:30 committed by GitHub
commit 737165ce44
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -188,9 +188,7 @@ class Engine:
ids = torch.tensor([tokens], dtype=torch.long, device=device)
logits = self.model.forward(ids, kv_cache=kv_cache_prefill)
logits = logits[:, -1, :]
# Sample num_samples independent tokens from the same distribution
logits_repeated = logits.repeat(num_samples, 1)
next_ids = sample_next_token(logits_repeated, rng, temperature, top_k) # (B, 1)
next_ids = sample_next_token(logits, rng, temperature, top_k) # (B, 1)
sampled_tokens = next_ids[:, 0].tolist()
# 2) Replicate the KV cache for each sample/row
@ -219,9 +217,7 @@ class Engine:
# Get sampled tokens - either from prefill or from forward pass
if first_iteration:
# Use the tokens we already sampled from prefill
sampled_tokens = [sampled_tokens[0]] * num_samples # Broadcast first token to all rows
# TODO: we should sample a token for each row instead of broadcasting
# sampled_tokens already contains num_samples independently sampled tokens
first_iteration = False
else:
# Forward the model and get the next token for each row