Compare commits


2 Commits

Author:  Abdulaziz Gabitov
SHA1:    920fa8663d
Date:    2025-11-15 22:38:03 +03:00
Message: Merge 7bd999ba02 into bc1fca39f3

Author:  Azekowka
SHA1:    7bd999ba02
Date:    2025-10-29 22:02:28 +05:00
Message: feat(engine.py): Sample unique initial tokens for each sequence in a batch

         Before, when initiating a batch generation, the first sampled token was
         broadcast to all sequences. This change expands the output logits after
         a single efficient prefill, allowing a unique starting token to be
         sampled for each sequence in the batch.
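
For context, a minimal sketch of the old and new behaviors. The names prefill_logits, num_samples, and temperature below are illustrative, and torch.multinomial stands in for the repository's sample_next_token, whose internals are not shown in this diff.

import torch

torch.manual_seed(0)
vocab_size, num_samples, temperature = 16, 4, 1.0

# A single prefill over the shared prompt yields one row of last-position logits.
prefill_logits = torch.randn(1, vocab_size)  # (1, vocab_size)
probs = torch.softmax(prefill_logits / temperature, dim=-1)

# Old behavior: sample once, then broadcast that token to every row.
first = torch.multinomial(probs, num_samples=1).item()
broadcast_tokens = [first] * num_samples  # every sequence starts identically

# New behavior: expand to (num_samples, vocab_size) and draw an independent
# sample per row, so each sequence can start with a different token.
expanded_probs = probs.expand(num_samples, -1)
per_row_tokens = torch.multinomial(expanded_probs, num_samples=1)[:, 0].tolist()

print(broadcast_tokens)  # e.g. [7, 7, 7, 7]
print(per_row_tokens)    # e.g. [7, 2, 7, 11] -- rows may now differ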

engine.py

@@ -218,6 +218,8 @@ class Engine:
         ids = torch.tensor([tokens], dtype=torch.long, device=device)
         logits = self.model.forward(ids, kv_cache=kv_cache_prefill)
         logits = logits[:, -1, :]
+        # expand the logits to be of size (num_samples, vocab_size)
+        logits = logits.expand(num_samples, -1)
         next_ids = sample_next_token(logits, rng, temperature, top_k)  # (B, 1)
         sampled_tokens = next_ids[:, 0].tolist()
@@ -248,8 +250,6 @@ class Engine:
             # Get sampled tokens - either from prefill or from forward pass
             if first_iteration:
                 # Use the tokens we already sampled from prefill
-                sampled_tokens = [sampled_tokens[0]] * num_samples  # Broadcast first token to all rows
-                # TODO: we should sample a token for each row instead of broadcasting
                 first_iteration = False
             else:
                 # Forward the model and get the next token for each row
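
A note on the design choice: torch.Tensor.expand returns a broadcasted view over the original storage rather than a copy, so replicating the single prefill row across the batch adds no memory or copy cost, unlike repeat. A quick check (sizes here are made up):

import torch

row = torch.randn(1, 32)
view = row.expand(8, -1)  # zero-copy broadcast view
copy = row.repeat(8, 1)   # materializes 8 physical rows

assert view.data_ptr() == row.data_ptr()  # same underlying storage
assert copy.data_ptr() != row.data_ptr()  # separate allocation
print(view.shape, copy.shape)  # torch.Size([8, 32]) torch.Size([8, 32])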