commit 920fa8663d
Author: Abdulaziz Gabitov
Date: 2025-11-15 22:38:03 +03:00 (committed by GitHub)

@@ -218,6 +218,8 @@ class Engine:
 ids = torch.tensor([tokens], dtype=torch.long, device=device)
 logits = self.model.forward(ids, kv_cache=kv_cache_prefill)
 logits = logits[:, -1, :]
+# expand the logits to be of size (num_samples, vocab_size)
+logits = logits.expand(num_samples, -1)
 next_ids = sample_next_token(logits, rng, temperature, top_k) # (B, 1)
 sampled_tokens = next_ids[:, 0].tolist()
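
Note on the added expand call: torch.expand broadcasts the single row of prefill logits to num_samples rows as a view, without copying the underlying data. A minimal sketch of that step, with illustrative values for num_samples and vocab_size (not taken from the repository):

import torch

num_samples, vocab_size = 4, 8                 # illustrative sizes, not from the repo
logits = torch.randn(1, vocab_size)            # logits for the last prefill position
expanded = logits.expand(num_samples, -1)      # shape (num_samples, vocab_size)
assert expanded.shape == (num_samples, vocab_size)
assert expanded.data_ptr() == logits.data_ptr()  # expand returns a view, no copy is made
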
@@ -248,8 +250,6 @@ class Engine:
 # Get sampled tokens - either from prefill or from forward pass
 if first_iteration:
     # Use the tokens we already sampled from prefill
-    sampled_tokens = [sampled_tokens[0]] * num_samples # Broadcast first token to all rows
-    # TODO: we should sample a token for each row instead of broadcasting
     first_iteration = False
 else:
     # Forward the model and get the next token for each row
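
Taken together, the two hunks change how the first token of each sample is chosen: previously one token was sampled from the prefill logits and broadcast to every row, whereas now the logits are expanded so each row draws its own token. A hedged sketch of the difference, using torch.multinomial as a stand-in for sample_next_token (whose exact behavior is not shown in this diff):

import torch

def sample_rows(logits, generator):
    # Stand-in for sample_next_token: plain multinomial sampling per row,
    # ignoring temperature and top-k for brevity.
    probs = torch.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1, generator=generator)  # (B, 1)

num_samples, vocab_size = 4, 8                 # illustrative sizes
rng = torch.Generator().manual_seed(0)
prefill_logits = torch.randn(1, vocab_size)    # logits from the single prefill row

# Old behavior: sample once, then broadcast that token to all rows.
old_tokens = [sample_rows(prefill_logits, rng)[0, 0].item()] * num_samples

# New behavior: expand first, so every row samples independently.
new_tokens = sample_rows(prefill_logits.expand(num_samples, -1), rng)[:, 0].tolist()

print(old_tokens)  # the same token repeated num_samples times
print(new_tokens)  # tokens can differ across rows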