Merge pull request #48 from manmohan659/fix/tool-id-sequence-match

fix(serve): tool-marker detection via token-id sequence
Manmohan 2026-04-22 17:42:23 -04:00 committed by GitHub
commit fd43d6399b

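Why text-stream matching was fragile: the tool markers are ordinary text to the tokenizer, so they encode to a multi-token BPE sequence, and a single token's decode can be a partial UTF-8 fragment (empty or garbled text). A minimal sketch of the failure mode, using tiktoken's gpt2 encoding as a stand-in for the project's tokenizer (an assumption; only the multi-token behavior is the point):

import tiktoken

enc = tiktoken.get_encoding("gpt2")

# The marker is ordinary text, so it encodes to several ids, never one special id.
ids = enc.encode("<|python_start|>")
print(len(ids))  # > 1: a multi-token sequence

# Per-token decodes are raw byte chunks that may not be valid UTF-8 on their
# own, which is why per-token text matching can silently miss the marker.
for tid in ids:
    print(tid, enc.decode_single_token_bytes(tid))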

@@ -253,22 +253,20 @@ class Inference:
         if len(tokens) > max_context:
             tokens = tokens[-max_context:]
+        # Ordinary-text token-id sequences for the tool markers.
+        # The SFT loader tokenizes assistant content with .encode() (not .encode_special()),
+        # so the model emits these as multi-token sequences, not single special-token ids.
+        # Match on the id sequence directly — more reliable than text (BPE partial UTF-8
+        # can make single-token decodes return empty strings).
+        tool_start_ids = tuple(self.tokenizer.encode("<|python_start|>"))
+        tool_end_ids = tuple(self.tokenizer.encode("<|python_end|>"))
+        out_start_str = "<|output_start|>"
+        out_end_str = "<|output_end|>"
         async def stream():
             input_ids = torch.tensor([tokens], dtype=torch.long, device=self.device)
-            # The model was SFT-trained with <|python_start|>/<|python_end|>/<|output_start|>/<|output_end|>
-            # encoded as ORDINARY token sequences (not single special-token ids), because the SFT
-            # loader tokenizes assistant content via .encode() rather than .encode_special().
-            # So we must detect the markers in the decoded TEXT stream, not in individual token ids.
-            # When we see <|python_end|> complete in the accumulated text, we execute the tool and
-            # inject the real <|output_start|>…<|output_end|> tokens into both the stream and the
-            # model's input_ids so subsequent generation conditions on the real result.
-            TOOL_START = "<|python_start|>"
-            TOOL_END = "<|python_end|>"
-            OUT_START = "<|output_start|>"
-            OUT_END = "<|output_end|>"
-            assistant_text = ""
-            # index of tool-call start in assistant_text, once we see it
-            tool_start_pos = -1
+            gen_ids: list[int] = []  # everything the MODEL sampled this turn
+            tool_start_idx = -1  # position in gen_ids where <|python_start|> begins
             def _append_token(tid):
                 nonlocal input_ids
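The hunk above precomputes the markers' id tuples once, outside the generation loop, so the detection step only compares id slices. A standalone sketch of the per-step check (hypothetical name; it mirrors the tail-window _find_subseq call in the hunks below):

def marker_just_completed(gen_ids: list[int], marker_ids: tuple[int, ...]) -> int:
    """Return the start index of marker_ids if it ends at or near the tail
    of gen_ids, else -1. Only the last len(marker_ids) + 2 positions need
    scanning per sampled token, so the check is cheap per step.
    """
    n, m = len(gen_ids), len(marker_ids)
    if m == 0 or n < m:
        return -1
    for i in range(max(0, n - m - 2), n - m + 1):
        if tuple(gen_ids[i:i + m]) == marker_ids:
            return i
    return -1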
@@ -277,6 +275,20 @@ class Inference:
                 if input_ids.size(1) > self.config.sequence_len:
                     input_ids = input_ids[:, -self.config.sequence_len:]
+            def _match_at(seq: list[int], pos: int, pat: tuple) -> bool:
+                if pos < 0 or pos + len(pat) > len(seq):
+                    return False
+                return tuple(seq[pos:pos + len(pat)]) == pat
+
+            def _find_subseq(seq: list[int], pat: tuple, start: int = 0) -> int:
+                L = len(pat)
+                if L == 0 or len(seq) < start + L:
+                    return -1
+                for i in range(start, len(seq) - L + 1):
+                    if tuple(seq[i:i + L]) == pat:
+                        return i
+                return -1
             with torch.no_grad():
                 num_generated = 0
                 while num_generated < max_tokens:
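A quick sanity check of the helper's semantics (the function is restated from the hunk above so the snippet runs standalone):

def _find_subseq(seq: list[int], pat: tuple, start: int = 0) -> int:
    L = len(pat)
    if L == 0 or len(seq) < start + L:
        return -1
    for i in range(start, len(seq) - L + 1):
        if tuple(seq[i:i + L]) == pat:
            return i
    return -1

assert _find_subseq([5, 1, 2, 3, 9], (1, 2, 3)) == 1            # found at index 1
assert _find_subseq([5, 1, 2, 3, 9], (1, 2, 3), start=2) == -1  # window starts past it
assert _find_subseq([1, 2], (1, 2, 3)) == -1                    # pattern longer than seq
assert _find_subseq([1, 2, 3], ()) == -1                        # empty pattern is a non-match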
@@ -291,47 +303,42 @@ class Inference:
                     next_token = torch.multinomial(probs, num_samples=1)
                     token_id = next_token.item()
                     # Stop on real terminators (assistant_end, bos)
                     if token_id in self._stop_token_ids:
                         break
+                    # Commit to context + sequence
+                    _append_token(token_id)
+                    gen_ids.append(token_id)
+                    num_generated += 1
+                    # Stream raw decoded text (may be empty for partial BPE bytes — that's OK)
                     try:
                         token_text = self.tokenizer.decode([token_id])
                     except Exception:
                         token_text = ""
-                    # Append to the model's context first
-                    _append_token(token_id)
-                    num_generated += 1
                     # Stream token to client (raw, UI parses markers)
                     if token_text:
                         yield "data: " + json.dumps({"token": token_text, "gpu": 0}) + "\n\n"
-                        assistant_text += token_text
-                    # --- tool-call detection on the text stream ---
-                    if tool_start_pos < 0:
-                        idx = assistant_text.find(TOOL_START)
+                    # --- tool-call detection (id-sequence match) ---
+                    if tool_start_idx < 0:
+                        idx = _find_subseq(gen_ids, tool_start_ids, max(0, len(gen_ids) - len(tool_start_ids) - 2))
                         if idx >= 0:
-                            tool_start_pos = idx
-                    if tool_start_pos >= 0:
-                        tail = assistant_text[tool_start_pos + len(TOOL_START):]
-                        end_rel = tail.find(TOOL_END)
-                        if end_rel >= 0:
-                            payload_text = tail[:end_rel]
-                            # reset detector for any subsequent call in the same turn
-                            tool_start_pos = -1
+                            tool_start_idx = idx
+                    if tool_start_idx >= 0:
+                        # look for <|python_end|> after the payload
+                        end_idx = _find_subseq(gen_ids, tool_end_ids, tool_start_idx + len(tool_start_ids))
+                        if end_idx >= 0:
+                            payload_ids = gen_ids[tool_start_idx + len(tool_start_ids):end_idx]
                             try:
+                                payload_text = self.tokenizer.decode(payload_ids)
                                 invocation = self._parse_tool_call(payload_text)
                                 result = self.tool_registry.execute(invocation.tool_name, invocation.arguments)
                                 result_text = result.to_payload()[:4096]
                             except Exception as exc:
                                 result_text = json.dumps({"error": str(exc)[:500]})
-                            wrapped = OUT_START + result_text + OUT_END
-                            # Push wrapped result into the model's context so the next-token
-                            # prediction is grounded on the real tool output, and also stream it
-                            # to the client so the UI can render the tool-result card.
+                            wrapped = out_start_str + result_text + out_end_str
+                            # Inject real result tokens into the model's context and the client stream.
                             for rid in self.tokenizer.encode(wrapped):
                                 try:
                                     rt = self.tokenizer.decode([rid])
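A distilled view of the injection step (names and the Protocol are illustrative, not the repo's API; the real handler also pushes each id into input_ids via _append_token and yields it to the SSE stream, as the hunk shows):

from typing import Protocol

class TokenizerLike(Protocol):
    def encode(self, text: str) -> list[int]: ...
    def decode(self, ids: list[int]) -> str: ...

def inject_tool_result(tokenizer: TokenizerLike, gen_ids: list[int], result_text: str) -> list[int]:
    """Wrap the tool output in the SAME ordinary-text encoding the SFT data
    used, and append it to the sampled-id buffer so id-sequence detection
    and the model's context stay consistent."""
    wrapped = "<|output_start|>" + result_text + "<|output_end|>"
    ids = tokenizer.encode(wrapped)
    gen_ids.extend(ids)  # subsequent generation now conditions on the real result
    return ids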
@@ -339,11 +346,13 @@ class Inference:
                                     rt = ""
                                 if rt:
                                     yield "data: " + json.dumps({"token": rt, "gpu": 0}) + "\n\n"
-                                    assistant_text += rt
                                 _append_token(rid)
+                                gen_ids.append(rid)
                                 num_generated += 1
                                 if num_generated >= max_tokens:
                                     break
+                            # Reset so a second tool call in the same turn still works
+                            tool_start_idx = -1
             yield "data: " + json.dumps({"done": True}) + "\n\n"