diff --git a/nanochat/engine.py b/nanochat/engine.py
index aa2e6a98..423d7920 100644
--- a/nanochat/engine.py
+++ b/nanochat/engine.py
@@ -195,6 +195,7 @@ class Engine:
         output_end = get_special("<|output_end|>")
         assistant_end = get_special("<|assistant_end|>") # if sampled, ends row
         bos = self.tokenizer.get_bos_token_id() # if sampled, ends row
+        pad = get_special("<|pad|>")
 
         # 1) Run a batch 1 prefill of the prompt tokens
         m = self.model.config
@@ -243,6 +244,12 @@ class Engine:
         token_column = [] # contains the next token id along each row
         token_masks = [] # contains the mask (was it sampled (1) or forced (0)?) along each row
         for i, state in enumerate(row_states):
+            # Skip rows that have already completed
+            if state.completed:
+                token_column.append(pad)
+                token_masks.append(0)
+                continue
+
             # Select the next token in this row
             is_forced = len(state.forced_tokens) > 0 # are there tokens waiting to be forced in deque?
             token_masks.append(0 if is_forced else 1) # mask is 0 if forced, 1 if sampled