diff --git a/modal/serve.py b/modal/serve.py
index f1a59a99..442da4a4 100644
--- a/modal/serve.py
+++ b/modal/serve.py
@@ -253,28 +253,21 @@ class Inference:
         if len(tokens) > max_context:
             tokens = tokens[-max_context:]
 
-        # Ordinary-text token-id sequences for the tool markers.
         # The SFT loader tokenizes assistant content with .encode() (not .encode_special()),
-        # so the model emits these as multi-token sequences, not single special-token ids.
-        # Match on the id sequence directly — more reliable than text (BPE partial UTF-8
-        # can make single-token decodes return empty strings).
-        tool_start_ids = tuple(self.tokenizer.encode("<|python_start|>"))
-        tool_end_ids = tuple(self.tokenizer.encode("<|python_end|>"))
+        # so these markers are emitted as multi-token byte sequences, and BPE has
+        # multiple valid tokenizations of the same string — so matching on a single
+        # expected id sequence is unreliable. Instead we decode the generated
+        # token stream and search for the marker TEXT.
+        tool_start_str = "<|python_start|>"
+        tool_end_str = "<|python_end|>"
         out_start_str = "<|output_start|>"
         out_end_str = "<|output_end|>"
 
-        # Suppression: model training has many convs that emit a fake
-        # <|output_start|>…<|output_end|> after our injected one. We stop the
-        # turn if we detect another full <|output_start|> sequence emitted
-        # after our injection.
-        out_start_ids = tuple(self.tokenizer.encode(out_start_str))
-
         async def stream():
             input_ids = torch.tensor([tokens], dtype=torch.long, device=self.device)
             gen_ids: list[int] = []  # everything the MODEL sampled this turn
-            tool_start_idx = -1  # position in gen_ids where <|python_start|> begins
             tool_injected = False  # once True, stop detecting further tool calls
-            injection_end_pos = -1  # index in gen_ids where our injected tokens end
+            pre_injection_len = 0  # len(gen_ids) right before we start injection
 
             def _append_token(tid):
                 nonlocal input_ids
@@ -283,19 +276,13 @@ class Inference:
                 if input_ids.size(1) > self.config.sequence_len:
                     input_ids = input_ids[:, -self.config.sequence_len:]
 
-            def _match_at(seq: list[int], pos: int, pat: tuple) -> bool:
-                if pos < 0 or pos + len(pat) > len(seq):
-                    return False
-                return tuple(seq[pos:pos + len(pat)]) == pat
-
-            def _find_subseq(seq: list[int], pat: tuple, start: int = 0) -> int:
-                L = len(pat)
-                if L == 0 or len(seq) < start + L:
-                    return -1
-                for i in range(start, len(seq) - L + 1):
-                    if tuple(seq[i:i + L]) == pat:
-                        return i
-                return -1
+            def _decode_gen_text() -> str:
+                if not gen_ids:
+                    return ""
+                try:
+                    return self.tokenizer.decode(gen_ids)
+                except Exception:
+                    return ""
 
             with torch.no_grad():
                 num_generated = 0
@@ -319,7 +306,7 @@ class Inference:
                     gen_ids.append(token_id)
                     num_generated += 1
 
-                    # Stream raw decoded text (may be empty for partial BPE bytes — that's OK)
+                    # Stream raw decoded text (may be empty for partial BPE bytes)
                     try:
                         token_text = self.tokenizer.decode([token_id])
                     except Exception:
@@ -327,46 +314,47 @@ class Inference:
                     if token_text:
                         yield "data: " + json.dumps({"token": token_text, "gpu": 0}) + "\n\n"
 
-                    # --- tool-call detection (id-sequence match) ---
-                    if not tool_injected and tool_start_idx < 0:
-                        idx = _find_subseq(gen_ids, tool_start_ids, max(0, len(gen_ids) - len(tool_start_ids) - 2))
-                        if idx >= 0:
-                            tool_start_idx = idx
-                    if not tool_injected and tool_start_idx >= 0:
-                        # look for <|python_end|> after the payload
-                        end_idx = _find_subseq(gen_ids, tool_end_ids, tool_start_idx + len(tool_start_ids))
-                        if end_idx >= 0:
-                            payload_ids = gen_ids[tool_start_idx + len(tool_start_ids):end_idx]
-                            try:
-                                payload_text = self.tokenizer.decode(payload_ids)
-                                invocation = self._parse_tool_call(payload_text)
-                                result = self.tool_registry.execute(invocation.tool_name, invocation.arguments)
-                                result_text = result.to_payload()[:4096]
-                            except Exception as exc:
-                                result_text = json.dumps({"error": str(exc)[:500]})
+                    # --- tool-call detection on decoded text ---
+                    if not tool_injected:
+                        # Decode the whole turn-so-far and look for the markers.
+                        full_text = _decode_gen_text()
+                        if full_text:
+                            ps = full_text.rfind(tool_start_str)
+                            if ps >= 0:
+                                pe = full_text.find(tool_end_str, ps + len(tool_start_str))
+                                if pe >= 0:
+                                    payload_text = full_text[ps + len(tool_start_str):pe]
+                                    try:
+                                        invocation = self._parse_tool_call(payload_text)
+                                        result = self.tool_registry.execute(invocation.tool_name, invocation.arguments)
+                                        result_text = result.to_payload()[:4096]
+                                    except Exception as exc:
+                                        result_text = json.dumps({"error": str(exc)[:500]})
 
-                            wrapped = out_start_str + result_text + out_end_str
-                            # Inject real result tokens into the model's context and the client stream.
-                            for rid in self.tokenizer.encode(wrapped):
-                                try:
-                                    rt = self.tokenizer.decode([rid])
-                                except Exception:
-                                    rt = ""
-                                if rt:
-                                    yield "data: " + json.dumps({"token": rt, "gpu": 0}) + "\n\n"
-                                _append_token(rid)
-                                gen_ids.append(rid)
-                                num_generated += 1
-                                if num_generated >= max_tokens:
-                                    break
-                            tool_injected = True
-                            injection_end_pos = len(gen_ids)
-                            tool_start_idx = -1
+                                    pre_injection_len = len(gen_ids)
+                                    wrapped = out_start_str + result_text + out_end_str
+                                    for rid in self.tokenizer.encode(wrapped):
+                                        try:
+                                            rt = self.tokenizer.decode([rid])
+                                        except Exception:
+                                            rt = ""
+                                        if rt:
+                                            yield "data: " + json.dumps({"token": rt, "gpu": 0}) + "\n\n"
+                                        _append_token(rid)
+                                        gen_ids.append(rid)
+                                        num_generated += 1
+                                        if num_generated >= max_tokens:
+                                            break
+                                    tool_injected = True
 
-                    # After injection, detect if the model emits ANOTHER full
-                    # <|output_start|> sequence (training-data loop artifact) and stop the turn.
-                    if tool_injected and injection_end_pos > 0:
-                        if _find_subseq(gen_ids, out_start_ids, injection_end_pos) >= 0:
+                    # After injection: if the model starts emitting another <|output_start|>,
+                    # break the turn — the grounded result already streamed.
+                    elif pre_injection_len > 0 and len(gen_ids) > pre_injection_len + 20:
+                        try:
+                            post_text = self.tokenizer.decode(gen_ids[pre_injection_len + 10:])
+                        except Exception:
+                            post_text = ""
+                        if out_start_str in post_text:
                            break
 
                 yield "data: " + json.dumps({"done": True}) + "\n\n"
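Review note: the rationale in the patch (BPE admits multiple valid splits of the same marker string, so a fixed id-sequence match can miss a marker the model actually emitted) can be reproduced in isolation. Below is a minimal, self-contained sketch using a toy vocabulary and a hypothetical alternative split, not the repository's real tokenizer; it shows the id-subsequence match failing where the decode-then-text-search succeeds.

# Toy vocabulary: ids 0-3 are one valid split of "<|python_start|>",
# ids 4-6 are a different but equally valid split the model might sample.
TOY_VOCAB = {
    0: "<|", 1: "python", 2: "_start", 3: "|>",
    4: "<|py", 5: "thon_", 6: "start|>",
    7: "print(2+2)", 8: "<|python_end|>",
}

def decode(ids):
    return "".join(TOY_VOCAB[i] for i in ids)

# What .encode("<|python_start|>") might return at match-setup time:
expected_start_ids = (0, 1, 2, 3)

# What the model actually sampled this turn (the alternative split):
gen_ids = [4, 5, 6, 7, 8]

def find_subseq(seq, pat):
    return next((i for i in range(len(seq) - len(pat) + 1)
                 if tuple(seq[i:i + len(pat)]) == pat), -1)

# Id-sequence match misses the marker entirely...
assert find_subseq(gen_ids, expected_start_ids) == -1

# ...while text search on the decoded stream finds it and the payload.
text = decode(gen_ids)
ps = text.rfind("<|python_start|>")
pe = text.find("<|python_end|>", ps + len("<|python_start|>"))
assert ps >= 0 and pe >= 0
print(text[ps + len("<|python_start|>"):pe])  # -> print(2+2)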
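The post-injection guard is easier to see in string space. A hypothetical sketch of the same idea follows (names and offsets are illustrative, not the patch's code): once the grounded result has been injected, any later "<|output_start|>" means the model is opening a fake second output block, so the turn should end. The patch applies this in token space, where the injection boundary is a token index rather than a character offset, which is why it skips ahead by a small token margin (+10) before decoding and waits (+20) before checking at all.

OUT_START = "<|output_start|>"
OUT_END = "<|output_end|>"

def should_stop_turn(turn_text: str, injection_end: int) -> bool:
    # Only inspect text produced AFTER our injected result, so the
    # injected "<|output_start|>" itself never triggers the stop.
    return OUT_START in turn_text[injection_end:]

turn = "<|python_start|>get_time()<|python_end|>"
turn += OUT_START + "12:04" + OUT_END          # our injected, real result
injection_end = len(turn)

turn += " The time is 12:04."                  # normal grounded continuation
assert not should_stop_turn(turn, injection_end)

turn += OUT_START                              # model starts a fake output block
assert should_stop_turn(turn, injection_end)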