diff --git a/modal/serve.py b/modal/serve.py
index 5eca71e2..9bafe13a 100644
--- a/modal/serve.py
+++ b/modal/serve.py
@@ -254,11 +254,22 @@ class Inference:
         tokens = tokens[-max_context:]
 
         async def stream():
-            from collections import deque
             input_ids = torch.tensor([tokens], dtype=torch.long, device=self.device)
-            forced = deque()
-            in_tool = False
-            tool_payload_ids = []
+            # The model was SFT-trained with <|python_start|>/<|python_end|>/<|output_start|>/<|output_end|>
+            # encoded as ORDINARY token sequences (not single special-token ids), because the SFT
+            # loader tokenizes assistant content via .encode() rather than .encode_special().
+            # So we must detect the markers in the decoded TEXT stream, not in individual token ids.
+            # When <|python_end|> appears complete in the accumulated text, we execute the tool and
+            # inject the real <|output_start|>…<|output_end|> tokens into both the stream and the
+            # model's input_ids so subsequent generation conditions on the real result.
+            TOOL_START = "<|python_start|>"
+            TOOL_END = "<|python_end|>"
+            OUT_START = "<|output_start|>"
+            OUT_END = "<|output_end|>"
+            assistant_text = ""
+            # index of the current call's TOOL_START in assistant_text, once seen
+            tool_start_pos = -1
+            scan_from = 0  # offset before which assistant_text has already been scanned
 
             def _append_token(tid):
                 nonlocal input_ids
@@ -270,56 +281,71 @@ class Inference:
             with torch.no_grad():
                 num_generated = 0
                 while num_generated < max_tokens:
-                    if forced:
-                        token_id = forced.popleft()
-                    else:
-                        logits = self.model(input_ids)
-                        next_logits = logits[:, -1, :]
-                        if temperature > 0:
-                            next_logits = next_logits / temperature
-                        if top_k > 0:
-                            v, _ = torch.topk(next_logits, min(top_k, next_logits.size(-1)))
-                            next_logits[next_logits < v[:, [-1]]] = float('-inf')
-                        probs = torch.softmax(next_logits, dim=-1)
-                        next_token = torch.multinomial(probs, num_samples=1)
-                        token_id = next_token.item()
+                    logits = self.model(input_ids)
+                    next_logits = logits[:, -1, :]
+                    if temperature > 0:
+                        next_logits = next_logits / temperature
+                    if top_k > 0:
+                        v, _ = torch.topk(next_logits, min(top_k, next_logits.size(-1)))
+                        next_logits[next_logits < v[:, [-1]]] = float('-inf')
+                    probs = torch.softmax(next_logits, dim=-1)
+                    next_token = torch.multinomial(probs, num_samples=1)
+                    token_id = next_token.item()
 
-                    # Tool state machine: detect <|python_start|>...<|python_end|>,
-                    # execute tool, inject <|output_start|>...<|output_end|> as forced tokens
-                    if token_id == self.python_start_id:
-                        in_tool = True
-                        tool_payload_ids = []
-                    elif token_id == self.python_end_id and in_tool:
-                        in_tool = False
-                        if tool_payload_ids:
+                    # Stop on real terminators (assistant_end, bos)
+                    if token_id in self._stop_token_ids:
+                        break
+
+                    try:
+                        token_text = self.tokenizer.decode([token_id])
+                    except Exception:
+                        token_text = ""
+
+                    # Append to the model's context first
+                    _append_token(token_id)
+                    num_generated += 1
+
+                    # Stream the raw token to the client; the UI parses the markers
+                    if token_text:
+                        yield "data: " + json.dumps({"token": token_text, "gpu": 0}) + "\n\n"
+                        assistant_text += token_text
+
+                    # --- tool-call detection on the text stream ---
+                    if tool_start_pos < 0:
+                        idx = assistant_text.find(TOOL_START, scan_from)
+                        if idx >= 0:
+                            tool_start_pos = idx
+                    if tool_start_pos >= 0:
+                        tail = assistant_text[tool_start_pos + len(TOOL_START):]
+                        end_rel = tail.find(TOOL_END)
+                        if end_rel >= 0:
+                            payload_text = tail[:end_rel]
+                            # advance the scan past this call and reset the detector so the
+                            # same call is never re-detected on later tokens
+                            scan_from = tool_start_pos + len(TOOL_START) + end_rel + len(TOOL_END)
+                            tool_start_pos = -1
                             try:
-                                payload_text = self.tokenizer.decode(tool_payload_ids)
                                 invocation = self._parse_tool_call(payload_text)
                                 result = self.tool_registry.execute(invocation.tool_name, invocation.arguments)
                                 result_text = result.to_payload()[:4096]
                             except Exception as exc:
                                 result_text = json.dumps({"error": str(exc)[:500]})
-                            if result_text:
-                                forced.append(self.output_start_id)
-                                forced.extend(self.tokenizer.encode(result_text))
-                                forced.append(self.output_end_id)
-                        tool_payload_ids = []
-                    elif in_tool:
-                        tool_payload_ids.append(token_id)
-
-                    # Stop only on assistant_end or bos (NOT on tool markers)
-                    if token_id in self._stop_token_ids:
-                        break
-
-                    # Decode + stream to client (includes tool markers; UI renders)
-                    try:
-                        token_text = self.tokenizer.decode([token_id])
-                        yield "data: " + json.dumps({"token": token_text, "gpu": 0}) + "\n\n"
-                    except Exception:
-                        pass
-
-                    _append_token(token_id)
-                    num_generated += 1
+                            wrapped = OUT_START + result_text + OUT_END
+                            # Push the wrapped result into the model's context so the next-token
+                            # prediction is grounded on the real tool output, and also stream it
+                            # to the client so the UI can render the tool-result card.
+                            for rid in self.tokenizer.encode(wrapped):
+                                try:
+                                    rt = self.tokenizer.decode([rid])
+                                except Exception:
+                                    rt = ""
+                                if rt:
+                                    yield "data: " + json.dumps({"token": rt, "gpu": 0}) + "\n\n"
+                                    assistant_text += rt
+                                _append_token(rid)
+                                num_generated += 1
+                            if num_generated >= max_tokens:
+                                break
 
             yield "data: " + json.dumps({"done": True}) + "\n\n"
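
Reviewer note: the marker detection above is pure string logic and can be exercised
without the model or tokenizer. A minimal standalone sketch; the marker strings are
taken from the diff, while detect_calls, chunks, and the asserts are hypothetical test
scaffolding, not part of serve.py:

    TOOL_START = "<|python_start|>"
    TOOL_END = "<|python_end|>"

    def detect_calls(chunks):
        # Mirrors the serving loop: accumulate text, latch onto TOOL_START once,
        # then wait for TOOL_END; scan_from skips over already-handled calls.
        text, start_pos, scan_from, calls = "", -1, 0, []
        for chunk in chunks:
            text += chunk
            if start_pos < 0:
                idx = text.find(TOOL_START, scan_from)
                if idx >= 0:
                    start_pos = idx
            if start_pos >= 0:
                tail = text[start_pos + len(TOOL_START):]
                end_rel = tail.find(TOOL_END)
                if end_rel >= 0:
                    calls.append(tail[:end_rel])
                    scan_from = start_pos + len(TOOL_START) + end_rel + len(TOOL_END)
                    start_pos = -1
        return calls

    # Markers split arbitrarily across token boundaries are still caught, and the
    # first call is not re-triggered while the second streams in:
    assert detect_calls(["<|python_", "start|>2+2<|python_end|>",
                         " and <|python_start|>3*3<|pyth", "on_end|>"]) == ["2+2", "3*3"]

The scan_from offset is the load-bearing piece: without it, str.find() keeps returning
the first, already-executed call after the detector resets, so the same tool would be
re-run on every subsequent token.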