mirror of
https://github.com/karpathy/nanochat.git
synced 2026-05-10 18:00:17 +00:00
Merge 91cb0761ab into dc54a1a307
This commit is contained in:
commit
5b19737480
|
|
@ -195,6 +195,7 @@ class Engine:
|
|||
output_end = get_special("<|output_end|>")
|
||||
assistant_end = get_special("<|assistant_end|>") # if sampled, ends row
|
||||
bos = self.tokenizer.get_bos_token_id() # if sampled, ends row
|
||||
pad = get_special("<|pad|>")
|
||||
|
||||
# 1) Run a batch 1 prefill of the prompt tokens
|
||||
m = self.model.config
|
||||
|
|
@ -243,6 +244,12 @@ class Engine:
|
|||
token_column = [] # contains the next token id along each row
|
||||
token_masks = [] # contains the mask (was it sampled (1) or forced (0)?) along each row
|
||||
for i, state in enumerate(row_states):
|
||||
# Skip the row completed
|
||||
if state.completed:
|
||||
token_column.append(pad)
|
||||
token_masks.append(0)
|
||||
continue
|
||||
|
||||
# Select the next token in this row
|
||||
is_forced = len(state.forced_tokens) > 0 # are there tokens waiting to be forced in deque?
|
||||
token_masks.append(0 if is_forced else 1) # mask is 0 if forced, 1 if sampled
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user