fix(serve): don't scan our own injected tokens for the loop-break check

Bug: after runtime tool injection, the post-injection loop-break check scanned gen_ids[pre_injection_len:], which included our own injected <|output_start|>…<|output_end|> block. The check therefore fired immediately and ended the turn before the model could write its final answer. This was visible on multi-turn queries such as a follow-up 'tell me more about him', where the model naturally issued a tool call, got real Tavily output, and was then cut off.

Fix: track post_injection_start (the index in gen_ids just after the injected tokens) and scan for stray markers only from there.
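A minimal sketch of why the scan offset matters, not the actual serving code: tokens are plain strings and decode() is a simple join standing in for the real tokenizer. The helper should_break, the sample token lists, and the <|python_end|> marker string are assumptions made for illustration; only the variable names pre_injection_len and post_injection_start and the "+ 6" threshold come from the diff.

```python
# Toy illustration only. Tokens are strings; the real code uses token ids
# and self.tokenizer.decode(). Marker strings other than the output markers
# and <|python_start|> are assumed.
OUT_START, OUT_END = "<|output_start|>", "<|output_end|>"
TOOL_START, TOOL_END = "<|python_start|>", "<|python_end|>"
MARKERS = (OUT_START, OUT_END, TOOL_START, TOOL_END)


def decode(tokens):
    """Stand-in for tokenizer.decode(): concatenate the token strings."""
    return "".join(tokens)


def should_break(gen_ids, scan_from, min_new=6):
    """Hypothetical helper: True if a tool marker appears in gen_ids[scan_from:]
    once more than min_new tokens exist past scan_from."""
    if len(gen_ids) <= scan_from + min_new:
        return False
    text = decode(gen_ids[scan_from:])
    return any(marker in text for marker in MARKERS)


# The model issues a tool call ...
gen_ids = ["Let me look that up. ", TOOL_START, "search('him')", TOOL_END]
pre_injection_len = len(gen_ids)

# ... then the server injects the tool result (not sampled by the model).
gen_ids += [OUT_START, "tool result text", OUT_END]
post_injection_start = len(gen_ids)

# The model continues with its final answer.
gen_ids += ["He", " was", " born", " in", " 1970", ".", " More", "..."]

# Old check: the scanned slice contains our own injected markers.
old_check = should_break(gen_ids, pre_injection_len)
# New check: only the model's post-injection tokens are scanned.
new_check = should_break(gen_ids, post_injection_start)

# A genuine loop (model emits a fake output block) is still caught.
looping = should_break(gen_ids + [OUT_START], post_injection_start)
```

In this sketch old_check is True, so the turn would be cut off, while new_check is False and looping is True: the fix stops the false positive without disabling the guard.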
Manmohan Sharma 2026-04-22 15:15:34 -07:00
parent 07b7629ba7
commit 57be688fdc

@@ -313,6 +313,7 @@ class Inference:
 gen_ids: list[int] = [] # everything the MODEL sampled this turn
 tool_injected = bool(forced_prefix_text) # forced prefix counts as an injection
 pre_injection_len = 0 # len(gen_ids) right before we start injection
+post_injection_start = 0 # index in gen_ids AFTER injection finished
 # If we pre-seeded a forced tool call + result, stream it to the client
 # now so the UI can render the tool-call / tool-result cards.
@@ -399,14 +400,14 @@ class Inference:
 if num_generated >= max_tokens:
 break
 tool_injected = True
+post_injection_start = len(gen_ids) # ← scan only what the model generates AFTER our injection
 # After injection (forced OR runtime): the model often loops and
 # emits another fake <|output_start|>…<|output_end|> / <|python_start|>…
-# block. Break the turn as soon as ANY tool-marker appears in what the
-# MODEL itself generated. We check the decoded text of gen_ids[pre_injection_len:].
-elif tool_injected and len(gen_ids) > pre_injection_len + 6:
+# block. Scan only the model's POST-injection tokens — not our own.
+elif tool_injected and len(gen_ids) > post_injection_start + 6:
 try:
-post_text = self.tokenizer.decode(gen_ids[pre_injection_len:])
+post_text = self.tokenizer.decode(gen_ids[post_injection_start:])
 except Exception:
 post_text = ""
 for bad in (out_start_str, out_end_str, tool_start_str, tool_end_str):