mirror of
https://github.com/karpathy/nanochat.git
synced 2026-05-20 14:47:34 +00:00
Skip the row completed in inference sample loop
For a single prompt, the inference engine generates num_samples of result independently. Once a row is completed, stop the row from generating tokens to reduce computation. Signed-off-by: Luojects <13113951796@163.com>
This commit is contained in:
parent
0aaca56805
commit
91cb0761ab
|
|
@ -195,6 +195,7 @@ class Engine:
|
|||
output_end = get_special("<|output_end|>")
|
||||
assistant_end = get_special("<|assistant_end|>") # if sampled, ends row
|
||||
bos = self.tokenizer.get_bos_token_id() # if sampled, ends row
|
||||
pad = get_special("<|pad|>")
|
||||
|
||||
# 1) Run a batch 1 prefill of the prompt tokens
|
||||
m = self.model.config
|
||||
|
|
@ -243,6 +244,12 @@ class Engine:
|
|||
token_column = [] # contains the next token id along each row
|
||||
token_masks = [] # contains the mask (was it sampled (1) or forced (0)?) along each row
|
||||
for i, state in enumerate(row_states):
|
||||
# Skip the row completed
|
||||
if state.completed:
|
||||
token_column.append(pad)
|
||||
token_masks.append(0)
|
||||
continue
|
||||
|
||||
# Select the next token in this row
|
||||
is_forced = len(state.forced_tokens) > 0 # are there tokens waiting to be forced in deque?
|
||||
token_masks.append(0 if is_forced else 1) # mask is 0 if forced, 1 if sampled
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user