From ba727cb4d59b21e7770f91d98cf0bf07618577a2 Mon Sep 17 00:00:00 2001
From: Manmohan Sharma
Date: Wed, 22 Apr 2026 14:42:07 -0700
Subject: [PATCH] fix(serve): match tool markers on token-id sequences,
 not decoded text

The previous text-stream approach lost markers because BPE partial-byte
tokens decode to empty strings, so assistant_text never accumulated the
full marker. Switch to matching the ordinary-token id sequence directly
(python_start = [60, 124, 25145, 95, 17104, 124, 62]).
---
 modal/serve.py | 85 ++++++++++++++++++++++++++++----------------------
 1 file changed, 47 insertions(+), 38 deletions(-)

diff --git a/modal/serve.py b/modal/serve.py
index 9bafe13a..92083424 100644
--- a/modal/serve.py
+++ b/modal/serve.py
@@ -253,22 +253,20 @@ class Inference:
         if len(tokens) > max_context:
             tokens = tokens[-max_context:]
 
+        # Ordinary-text token-id sequences for the tool markers.
+        # The SFT loader tokenizes assistant content with .encode() (not .encode_special()),
+        # so the model emits these as multi-token sequences, not single special-token ids.
+        # Match on the id sequence directly: more reliable than text (BPE partial UTF-8
+        # can make single-token decodes return empty strings).
+        tool_start_ids = tuple(self.tokenizer.encode("<|python_start|>"))
+        tool_end_ids = tuple(self.tokenizer.encode("<|python_end|>"))
+        out_start_str = "<|output_start|>"
+        out_end_str = "<|output_end|>"
+
         async def stream():
             input_ids = torch.tensor([tokens], dtype=torch.long, device=self.device)
 
-            # The model was SFT-trained with <|python_start|>/<|python_end|>/<|output_start|>/<|output_end|>
-            # encoded as ORDINARY token sequences (not single special-token ids), because the SFT
-            # loader tokenizes assistant content via .encode() rather than .encode_special().
-            # So we must detect the markers in the decoded TEXT stream, not in individual token ids.
-            # When we see <|python_end|> complete in the accumulated text, we execute the tool and
-            # inject the real <|output_start|>…<|output_end|> tokens into both the stream and the
-            # model's input_ids so subsequent generation conditions on the real result.
-            TOOL_START = "<|python_start|>"
-            TOOL_END = "<|python_end|>"
-            OUT_START = "<|output_start|>"
-            OUT_END = "<|output_end|>"
-            assistant_text = ""
-            # index of tool-call start in assistant_text, once we see it
-            tool_start_pos = -1
+            gen_ids: list[int] = []  # everything the MODEL sampled this turn
+            tool_start_idx = -1  # position in gen_ids where <|python_start|> begins
 
             def _append_token(tid):
                 nonlocal input_ids
@@ -277,6 +275,20 @@ class Inference:
             if input_ids.size(1) > self.config.sequence_len:
                 input_ids = input_ids[:, -self.config.sequence_len:]
 
+            def _match_at(seq: list[int], pos: int, pat: tuple) -> bool:
+                if pos < 0 or pos + len(pat) > len(seq):
+                    return False
+                return tuple(seq[pos:pos + len(pat)]) == pat
+
+            def _find_subseq(seq: list[int], pat: tuple, start: int = 0) -> int:
+                L = len(pat)
+                if L == 0 or len(seq) < start + L:
+                    return -1
+                for i in range(start, len(seq) - L + 1):
+                    if tuple(seq[i:i + L]) == pat:
+                        return i
+                return -1
+
             with torch.no_grad():
                 num_generated = 0
                 while num_generated < max_tokens:
@@ -291,47 +303,42 @@ class Inference:
                     next_token = torch.multinomial(probs, num_samples=1)
                     token_id = next_token.item()
 
-                    # Stop on real terminators (assistant_end, bos)
                     if token_id in self._stop_token_ids:
                         break
 
+                    # Commit to context + sequence
+                    _append_token(token_id)
+                    gen_ids.append(token_id)
+                    num_generated += 1
+
+                    # Stream raw decoded text (may be empty for partial BPE bytes, that's OK)
                     try:
                         token_text = self.tokenizer.decode([token_id])
                     except Exception:
                         token_text = ""
-
-                    # Append to the model's context first
-                    _append_token(token_id)
-                    num_generated += 1
-
-                    # Stream token to client (raw, UI parses markers)
                     if token_text:
                         yield "data: " + json.dumps({"token": token_text, "gpu": 0}) + "\n\n"
-                    assistant_text += token_text
 
-                    # --- tool-call detection on the text stream ---
-                    if tool_start_pos < 0:
-                        idx = assistant_text.find(TOOL_START)
+                    # --- tool-call detection (id-sequence match) ---
+                    if tool_start_idx < 0:
+                        idx = _find_subseq(gen_ids, tool_start_ids, max(0, len(gen_ids) - len(tool_start_ids) - 2))
                         if idx >= 0:
-                            tool_start_pos = idx
-                    if tool_start_pos >= 0:
-                        tail = assistant_text[tool_start_pos + len(TOOL_START):]
-                        end_rel = tail.find(TOOL_END)
-                        if end_rel >= 0:
-                            payload_text = tail[:end_rel]
-                            # reset detector for any subsequent call in the same turn
-                            tool_start_pos = -1
+                            tool_start_idx = idx
+                    if tool_start_idx >= 0:
+                        # look for <|python_end|> after the payload
+                        end_idx = _find_subseq(gen_ids, tool_end_ids, tool_start_idx + len(tool_start_ids))
+                        if end_idx >= 0:
+                            payload_ids = gen_ids[tool_start_idx + len(tool_start_ids):end_idx]
                             try:
+                                payload_text = self.tokenizer.decode(payload_ids)
                                 invocation = self._parse_tool_call(payload_text)
                                 result = self.tool_registry.execute(invocation.tool_name, invocation.arguments)
                                 result_text = result.to_payload()[:4096]
                             except Exception as exc:
                                 result_text = json.dumps({"error": str(exc)[:500]})
-                            wrapped = OUT_START + result_text + OUT_END
-                            # Push wrapped result into the model's context so the next-token
-                            # prediction is grounded on the real tool output, and also stream it
-                            # to the client so the UI can render the tool-result card.
+                            wrapped = out_start_str + result_text + out_end_str
+                            # Inject real result tokens into the model's context and the client stream.
                             for rid in self.tokenizer.encode(wrapped):
                                 try:
                                     rt = self.tokenizer.decode([rid])
                                 except Exception:
                                     rt = ""
                                 if rt:
                                     yield "data: " + json.dumps({"token": rt, "gpu": 0}) + "\n\n"
-                                assistant_text += rt
                                 _append_token(rid)
+                                gen_ids.append(rid)
                                 num_generated += 1
                                 if num_generated >= max_tokens:
                                     break
+                            # Reset so a second tool call in the same turn still works
+                            tool_start_idx = -1
 
             yield "data: " + json.dumps({"done": True}) + "\n\n"
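
A minimal, self-contained sketch of the id-sequence matcher above, runnable
outside the server. Only START below is the id sequence quoted in the commit
message; END, the sampled stream, and the helper name find_subseq are
illustrative stand-ins, since the real values depend on the tokenizer.

# Toy rehearsal of the matching loop in this patch (stand-in data).

def find_subseq(seq: list, pat: tuple, start: int = 0) -> int:
    """First index >= start where pat occurs contiguously in seq, else -1."""
    L = len(pat)
    if L == 0 or len(seq) < start + L:
        return -1
    for i in range(start, len(seq) - L + 1):
        if tuple(seq[i:i + L]) == pat:
            return i
    return -1

START = (60, 124, 25145, 95, 17104, 124, 62)  # encode("<|python_start|>"), per commit message
END = (60, 124, 25145, 95, 437, 124, 62)      # hypothetical end-marker ids

gen_ids: list[int] = []
tool_start_idx = -1
sampled = [11, 22, *START, 5, 6, 7, *END, 33]  # pretend per-token model output

for tid in sampled:
    gen_ids.append(tid)
    if tool_start_idx < 0:
        # A newly completed start marker can only end at the tail.
        tool_start_idx = find_subseq(gen_ids, START, max(0, len(gen_ids) - len(START)))
    if tool_start_idx >= 0:
        end_idx = find_subseq(gen_ids, END, tool_start_idx + len(START))
        if end_idx >= 0:
            print("payload ids:", gen_ids[tool_start_idx + len(START):end_idx])
            tool_start_idx = -1  # ready for a second call in the same turn

# prints: payload ids: [5, 6, 7]

Matching on ids sidesteps decoding entirely, so a marker can never be hidden
by partial-UTF-8 tokens; the cost is a short tail-window scan per sampled
token plus one forward scan while a call is open.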