diff --git a/modal/serve.py b/modal/serve.py
index 92083424..f1a59a99 100644
--- a/modal/serve.py
+++ b/modal/serve.py
@@ -263,10 +263,18 @@ class Inference:
         out_start_str = "<|output_start|>"
         out_end_str = "<|output_end|>"
 
+        # Suppression: model training has many convs that emit a fake
+        # <|output_start|>…<|output_end|> after our injected one. We stop the
+        # turn if we detect another full <|output_start|> sequence emitted
+        # after our injection.
+        out_start_ids = tuple(self.tokenizer.encode(out_start_str))
+
         async def stream():
             input_ids = torch.tensor([tokens], dtype=torch.long, device=self.device)
             gen_ids: list[int] = []  # everything the MODEL sampled this turn
             tool_start_idx = -1  # position in gen_ids where <|python_start|> begins
+            tool_injected = False  # once True, stop detecting further tool calls
+            injection_end_pos = -1  # index in gen_ids where our injected tokens end
 
             def _append_token(tid):
                 nonlocal input_ids
@@ -320,11 +328,11 @@ class Inference:
                 yield "data: " + json.dumps({"token": token_text, "gpu": 0}) + "\n\n"
 
                 # --- tool-call detection (id-sequence match) ---
-                if tool_start_idx < 0:
+                if not tool_injected and tool_start_idx < 0:
                     idx = _find_subseq(gen_ids, tool_start_ids, max(0, len(gen_ids) - len(tool_start_ids) - 2))
                     if idx >= 0:
                         tool_start_idx = idx
-                if tool_start_idx >= 0:
+                if not tool_injected and tool_start_idx >= 0:
                     # look for <|python_end|> after the payload
                     end_idx = _find_subseq(gen_ids, tool_end_ids, tool_start_idx + len(tool_start_ids))
                     if end_idx >= 0:
@@ -351,9 +359,16 @@ class Inference:
                             num_generated += 1
                             if num_generated >= max_tokens:
                                 break
-                    # Reset so a second tool call in the same turn still works
+                    tool_injected = True
+                    injection_end_pos = len(gen_ids)
                     tool_start_idx = -1
 
+                # After injection, detect if the model emits ANOTHER full
+                # <|output_start|> sequence (training-data loop artifact) and stop the turn.
+                if tool_injected and injection_end_pos > 0:
+                    if _find_subseq(gen_ids, out_start_ids, injection_end_pos) >= 0:
+                        break
+
             yield "data: " + json.dumps({"done": True}) + "\n\n"
 
         return StreamingResponse(