diff --git a/modal/serve.py b/modal/serve.py index 2df9763f..01554a65 100644 --- a/modal/serve.py +++ b/modal/serve.py @@ -313,6 +313,7 @@ class Inference: gen_ids: list[int] = [] # everything the MODEL sampled this turn tool_injected = bool(forced_prefix_text) # forced prefix counts as an injection pre_injection_len = 0 # len(gen_ids) right before we start injection + post_injection_start = 0 # index in gen_ids AFTER injection finished # If we pre-seeded a forced tool call + result, stream it to the client # now so the UI can render the tool-call / tool-result cards. @@ -399,14 +400,14 @@ class Inference: if num_generated >= max_tokens: break tool_injected = True + post_injection_start = len(gen_ids) # ← scan only what the model generates AFTER our injection # After injection (forced OR runtime): the model often loops and # emits another fake <|output_start|>…<|output_end|> / <|python_start|>… - # block. Break the turn as soon as ANY tool-marker appears in what the - # MODEL itself generated. We check the decoded text of gen_ids[pre_injection_len:]. - elif tool_injected and len(gen_ids) > pre_injection_len + 6: + # block. Scan only the model's POST-injection tokens — not our own. + elif tool_injected and len(gen_ids) > post_injection_start + 6: try: - post_text = self.tokenizer.decode(gen_ids[pre_injection_len:]) + post_text = self.tokenizer.decode(gen_ids[post_injection_start:]) except Exception: post_text = "" for bad in (out_start_str, out_end_str, tool_start_str, tool_end_str):