This commit is contained in:
Luoject 2026-05-05 14:54:07 +08:00 committed by GitHub
commit 5b19737480
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -195,6 +195,7 @@ class Engine:
output_end = get_special("<|output_end|>")
assistant_end = get_special("<|assistant_end|>") # if sampled, ends row
bos = self.tokenizer.get_bos_token_id() # if sampled, ends row
pad = get_special("<|pad|>")
# 1) Run a batch 1 prefill of the prompt tokens
m = self.model.config
@ -243,6 +244,12 @@ class Engine:
token_column = [] # contains the next token id along each row
token_masks = [] # contains the mask (was it sampled (1) or forced (0)?) along each row
for i, state in enumerate(row_states):
# Skip the row completed
if state.completed:
token_column.append(pad)
token_masks.append(0)
continue
# Select the next token in this row
is_forced = len(state.forced_tokens) > 0 # are there tokens waiting to be forced in deque?
token_masks.append(0 if is_forced else 1) # mask is 0 if forced, 1 if sampled