fix(serve): detect tool markers in text stream, not token ids

The SFT loader tokenizes assistant content with .encode() (ordinary), not .encode_special(), so the model was trained to emit <|python_start|> / <|python_end|> as the 7-token ordinary sequence [60, 124, 25145, 95, 17104, 124, 62] rather than as the single special token id 32764. My prior state machine compared token_id == python_start_id, which never fired, so tool calls were never executed and the model simply hallucinated fake tool results ("Official leadership page", etc.).

Fix: detect the markers in the decoded text stream, parse the payload between <|python_start|> and <|python_end|>, execute the tool, and inject the real <|output_start|>…<|output_end|> tokens into both the SSE stream and the model's input_ids. Next-token prediction is now grounded on real Tavily output.
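To make the mismatch concrete, here is a pure-Python illustration. The ids are the ones quoted above; the tokenizer calls appear only in comments because the exact API surface is assumed, not verbatim:

marker = "<|python_start|>"

# What the SFT training data contains (tokenizer.encode(marker): ordinary BPE pieces):
ordinary_ids = [60, 124, 25145, 95, 17104, 124, 62]

# What the old serving check compared against (tokenizer.encode_special(marker)):
python_start_id = 32764

# The model only ever emits the ordinary pieces, so an id-level comparison
# can never observe the marker:
assert python_start_id not in ordinary_ids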
Author: Manmohan Sharma
Date:   2026-04-22 14:39:36 -07:00
Parent: f642cb2eb6
Commit: 7a92f5b016


@@ -254,11 +254,21 @@ class Inference:
         tokens = tokens[-max_context:]
 
         async def stream():
-            from collections import deque
             input_ids = torch.tensor([tokens], dtype=torch.long, device=self.device)
-            forced = deque()
-            in_tool = False
-            tool_payload_ids = []
+            # The model was SFT-trained with <|python_start|>/<|python_end|>/<|output_start|>/<|output_end|>
+            # encoded as ORDINARY token sequences (not single special-token ids), because the SFT
+            # loader tokenizes assistant content via .encode() rather than .encode_special().
+            # So we must detect the markers in the decoded TEXT stream, not in individual token ids.
+            # When we see <|python_end|> complete in the accumulated text, we execute the tool and
+            # inject the real <|output_start|>…<|output_end|> tokens into both the stream and the
+            # model's input_ids so subsequent generation conditions on the real result.
+            TOOL_START = "<|python_start|>"
+            TOOL_END = "<|python_end|>"
+            OUT_START = "<|output_start|>"
+            OUT_END = "<|output_end|>"
+            assistant_text = ""
+            # index of tool-call start in assistant_text, once we see it
+            tool_start_pos = -1
 
             def _append_token(tid):
                 nonlocal input_ids
@@ -270,56 +280,70 @@ class Inference:
             with torch.no_grad():
                 num_generated = 0
                 while num_generated < max_tokens:
-                    if forced:
-                        token_id = forced.popleft()
-                    else:
-                        logits = self.model(input_ids)
-                        next_logits = logits[:, -1, :]
-                        if temperature > 0:
-                            next_logits = next_logits / temperature
-                        if top_k > 0:
-                            v, _ = torch.topk(next_logits, min(top_k, next_logits.size(-1)))
-                            next_logits[next_logits < v[:, [-1]]] = float('-inf')
-                        probs = torch.softmax(next_logits, dim=-1)
-                        next_token = torch.multinomial(probs, num_samples=1)
-                        token_id = next_token.item()
+                    logits = self.model(input_ids)
+                    next_logits = logits[:, -1, :]
+                    if temperature > 0:
+                        next_logits = next_logits / temperature
+                    if top_k > 0:
+                        v, _ = torch.topk(next_logits, min(top_k, next_logits.size(-1)))
+                        next_logits[next_logits < v[:, [-1]]] = float('-inf')
+                    probs = torch.softmax(next_logits, dim=-1)
+                    next_token = torch.multinomial(probs, num_samples=1)
+                    token_id = next_token.item()
-                    # Tool state machine: detect <|python_start|>...<|python_end|>,
-                    # execute tool, inject <|output_start|>...<|output_end|> as forced tokens
-                    if token_id == self.python_start_id:
-                        in_tool = True
-                        tool_payload_ids = []
-                    elif token_id == self.python_end_id and in_tool:
-                        in_tool = False
-                        if tool_payload_ids:
+                    # Stop on real terminators (assistant_end, bos)
+                    if token_id in self._stop_token_ids:
+                        break
+                    try:
+                        token_text = self.tokenizer.decode([token_id])
+                    except Exception:
+                        token_text = ""
+                    # Append to the model's context first
+                    _append_token(token_id)
+                    num_generated += 1
+                    # Stream token to client (raw, UI parses markers)
+                    if token_text:
+                        yield "data: " + json.dumps({"token": token_text, "gpu": 0}) + "\n\n"
+                    assistant_text += token_text
+                    # --- tool-call detection on the text stream ---
+                    if tool_start_pos < 0:
+                        idx = assistant_text.find(TOOL_START)
+                        if idx >= 0:
+                            tool_start_pos = idx
+                    if tool_start_pos >= 0:
+                        tail = assistant_text[tool_start_pos + len(TOOL_START):]
+                        end_rel = tail.find(TOOL_END)
+                        if end_rel >= 0:
+                            payload_text = tail[:end_rel]
+                            # consume the handled call so it cannot be re-detected on later
+                            # tokens, and reset the detector for any subsequent call
+                            assistant_text = tail[end_rel + len(TOOL_END):]
+                            tool_start_pos = -1
                             try:
-                                payload_text = self.tokenizer.decode(tool_payload_ids)
                                 invocation = self._parse_tool_call(payload_text)
                                 result = self.tool_registry.execute(invocation.tool_name, invocation.arguments)
                                 result_text = result.to_payload()[:4096]
                             except Exception as exc:
                                 result_text = json.dumps({"error": str(exc)[:500]})
-                            if result_text:
-                                forced.append(self.output_start_id)
-                                forced.extend(self.tokenizer.encode(result_text))
-                                forced.append(self.output_end_id)
-                            tool_payload_ids = []
-                    elif in_tool:
-                        tool_payload_ids.append(token_id)
-                    # Stop only on assistant_end or bos (NOT on tool markers)
-                    if token_id in self._stop_token_ids:
-                        break
-                    # Decode + stream to client (includes tool markers; UI renders)
-                    try:
-                        token_text = self.tokenizer.decode([token_id])
-                        yield "data: " + json.dumps({"token": token_text, "gpu": 0}) + "\n\n"
-                    except Exception:
-                        pass
-                    _append_token(token_id)
-                    num_generated += 1
+                            wrapped = OUT_START + result_text + OUT_END
+                            # Push wrapped result into the model's context so the next-token
+                            # prediction is grounded on the real tool output, and also stream it
+                            # to the client so the UI can render the tool-result card.
+                            for rid in self.tokenizer.encode(wrapped):
+                                try:
+                                    rt = self.tokenizer.decode([rid])
+                                except Exception:
+                                    rt = ""
+                                if rt:
+                                    yield "data: " + json.dumps({"token": rt, "gpu": 0}) + "\n\n"
+                                assistant_text += rt
+                                _append_token(rid)
+                                num_generated += 1
+                            if num_generated >= max_tokens:
+                                break
             yield "data: " + json.dumps({"done": True}) + "\n\n"