Skip the row completed in inference sample loop

For a single prompt, the inference engine generates num_samples of result independently.
Once a row is completed, stop the row from generating tokens to reduce computation.

Signed-off-by: Luojects <13113951796@163.com>
This commit is contained in:
Luojects 2026-04-29 20:25:23 +08:00
parent 0aaca56805
commit 91cb0761ab

View File

@ -195,6 +195,7 @@ class Engine:
output_end = get_special("<|output_end|>")
assistant_end = get_special("<|assistant_end|>") # if sampled, ends row
bos = self.tokenizer.get_bos_token_id() # if sampled, ends row
pad = get_special("<|pad|>")
# 1) Run a batch 1 prefill of the prompt tokens
m = self.model.config
@ -243,6 +244,12 @@ class Engine:
token_column = [] # contains the next token id along each row
token_masks = [] # contains the mask (was it sampled (1) or forced (0)?) along each row
for i, state in enumerate(row_states):
# Skip the row completed
if state.completed:
token_column.append(pad)
token_masks.append(0)
continue
# Select the next token in this row
is_forced = len(state.forced_tokens) > 0 # are there tokens waiting to be forced in deque?
token_masks.append(0 if is_forced else 1) # mask is 0 if forced, 1 if sampled