From 91cb0761ab921b56ee5b088a944523e0722cb1dc Mon Sep 17 00:00:00 2001
From: Luojects <13113951796@163.com>
Date: Wed, 29 Apr 2026 20:25:23 +0800
Subject: [PATCH] Skip completed rows in the inference sample loop

For a single prompt, the inference engine generates num_samples results
independently. Once a row is completed, stop generating tokens for it to
reduce computation.

Signed-off-by: Luojects <13113951796@163.com>
---
 nanochat/engine.py | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/nanochat/engine.py b/nanochat/engine.py
index aa2e6a98..423d7920 100644
--- a/nanochat/engine.py
+++ b/nanochat/engine.py
@@ -195,6 +195,7 @@ class Engine:
         output_end = get_special("<|output_end|>")
         assistant_end = get_special("<|assistant_end|>") # if sampled, ends row
         bos = self.tokenizer.get_bos_token_id() # if sampled, ends row
+        pad = get_special("<|pad|>")
 
         # 1) Run a batch 1 prefill of the prompt tokens
         m = self.model.config
@@ -243,6 +244,12 @@ class Engine:
             token_column = [] # contains the next token id along each row
             token_masks = [] # contains the mask (was it sampled (1) or forced (0)?) along each row
             for i, state in enumerate(row_states):
+                # Skip rows that have already completed
+                if state.completed:
+                    token_column.append(pad)
+                    token_masks.append(0)
+                    continue
+
                 # Select the next token in this row
                 is_forced = len(state.forced_tokens) > 0 # are there tokens waiting to be forced in deque?
                 token_masks.append(0 if is_forced else 1) # mask is 0 if forced, 1 if sampled