Mirror of https://github.com/karpathy/nanochat.git (synced 2026-05-07 16:30:11 +00:00)
fix(serve): detect tool markers in text stream, not token ids
The SFT loader tokenizes assistant content with .encode() (ordinary encoding), not .encode_special(), so the model was trained to emit <|python_start|> / <|python_end|> as the 7-token ordinary sequence [60, 124, 25145, 95, 17104, 124, 62] rather than as the single special token id 32764. My prior state machine matched token_id == python_start_id, which never fired, so tool calls were never executed and the model just hallucinated fake tool results ("Official leadership page", etc.).

Fix: detect the markers in the decoded text stream, parse the payload between <|python_start|> and <|python_end|>, execute the tool, and inject the real <|output_start|>…<|output_end|> tokens into both the SSE stream and the model's input_ids. Next-token prediction is now grounded on real Tavily output.
parent f642cb2eb6
commit 7a92f5b016
modal/serve.py (116 changes)
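To make the failure mode concrete before the diff: here is a small illustration of the ordinary-vs-special split using tiktoken's gpt2 encoding, with <|endoftext|> standing in for nanochat's <|python_start|>. This is an analogy sketch, not code from this repo; the ids differ from the 7-token sequence quoted above, but the mechanism is the same.

# Hypothetical illustration with tiktoken; <|endoftext|> stands in for
# nanochat's <|python_start|>. Not code from modal/serve.py.
import tiktoken

enc = tiktoken.get_encoding("gpt2")
marker = "<|endoftext|>"

# Ordinary encoding (what the SFT loader's .encode() does to assistant
# content): the marker text is spelled out as several plain BPE tokens.
ordinary = enc.encode(marker, disallowed_special=())

# Special encoding (the .encode_special() behavior): the whole marker
# collapses to one reserved id.
special = enc.encode(marker, allowed_special={marker})

print(ordinary)  # [27, 91, 437, 1659, 5239, 91, 29]  (7 ordinary tokens)
print(special)   # [50256]                            (1 special id)

# A sampler-side check like `token_id == special_id` can never fire on a
# model trained to emit the ordinary sequence: the special id never appears.
assert special[0] not in ordinary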
@@ -254,11 +254,21 @@ class Inference:
         tokens = tokens[-max_context:]
 
         async def stream():
-            from collections import deque
             input_ids = torch.tensor([tokens], dtype=torch.long, device=self.device)
-            forced = deque()
-            in_tool = False
-            tool_payload_ids = []
+            # The model was SFT-trained with <|python_start|>/<|python_end|>/<|output_start|>/<|output_end|>
+            # encoded as ORDINARY token sequences (not single special-token ids), because the SFT
+            # loader tokenizes assistant content via .encode() rather than .encode_special().
+            # So we must detect the markers in the decoded TEXT stream, not in individual token ids.
+            # When we see <|python_end|> complete in the accumulated text, we execute the tool and
+            # inject the real <|output_start|>…<|output_end|> tokens into both the stream and the
+            # model's input_ids so subsequent generation conditions on the real result.
+            TOOL_START = "<|python_start|>"
+            TOOL_END = "<|python_end|>"
+            OUT_START = "<|output_start|>"
+            OUT_END = "<|output_end|>"
+            assistant_text = ""
+            # index of tool-call start in assistant_text, once we see it
+            tool_start_pos = -1
 
             def _append_token(tid):
                 nonlocal input_ids
@@ -270,56 +280,70 @@ class Inference:
             with torch.no_grad():
                 num_generated = 0
                 while num_generated < max_tokens:
-                    if forced:
-                        token_id = forced.popleft()
-                    else:
-                        logits = self.model(input_ids)
-                        next_logits = logits[:, -1, :]
-                        if temperature > 0:
-                            next_logits = next_logits / temperature
-                        if top_k > 0:
-                            v, _ = torch.topk(next_logits, min(top_k, next_logits.size(-1)))
-                            next_logits[next_logits < v[:, [-1]]] = float('-inf')
-                        probs = torch.softmax(next_logits, dim=-1)
-                        next_token = torch.multinomial(probs, num_samples=1)
-                        token_id = next_token.item()
+                    logits = self.model(input_ids)
+                    next_logits = logits[:, -1, :]
+                    if temperature > 0:
+                        next_logits = next_logits / temperature
+                    if top_k > 0:
+                        v, _ = torch.topk(next_logits, min(top_k, next_logits.size(-1)))
+                        next_logits[next_logits < v[:, [-1]]] = float('-inf')
+                    probs = torch.softmax(next_logits, dim=-1)
+                    next_token = torch.multinomial(probs, num_samples=1)
+                    token_id = next_token.item()
 
-                    # Tool state machine: detect <|python_start|>...<|python_end|>,
-                    # execute tool, inject <|output_start|>...<|output_end|> as forced tokens
-                    if token_id == self.python_start_id:
-                        in_tool = True
-                        tool_payload_ids = []
-                    elif token_id == self.python_end_id and in_tool:
-                        in_tool = False
-                        if tool_payload_ids:
+                    # Stop on real terminators (assistant_end, bos)
+                    if token_id in self._stop_token_ids:
+                        break
+
+                    try:
+                        token_text = self.tokenizer.decode([token_id])
+                    except Exception:
+                        token_text = ""
+
+                    # Append to the model's context first
+                    _append_token(token_id)
+                    num_generated += 1
+
+                    # Stream token to client (raw, UI parses markers)
+                    if token_text:
+                        yield "data: " + json.dumps({"token": token_text, "gpu": 0}) + "\n\n"
+                        assistant_text += token_text
+
+                    # --- tool-call detection on the text stream ---
+                    if tool_start_pos < 0:
+                        idx = assistant_text.find(TOOL_START)
+                        if idx >= 0:
+                            tool_start_pos = idx
+                    if tool_start_pos >= 0:
+                        tail = assistant_text[tool_start_pos + len(TOOL_START):]
+                        end_rel = tail.find(TOOL_END)
+                        if end_rel >= 0:
+                            payload_text = tail[:end_rel]
+                            # reset detector for any subsequent call in the same turn
+                            tool_start_pos = -1
                             try:
-                                payload_text = self.tokenizer.decode(tool_payload_ids)
                                 invocation = self._parse_tool_call(payload_text)
                                 result = self.tool_registry.execute(invocation.tool_name, invocation.arguments)
                                 result_text = result.to_payload()[:4096]
                             except Exception as exc:
                                 result_text = json.dumps({"error": str(exc)[:500]})
-                            if result_text:
-                                forced.append(self.output_start_id)
-                                forced.extend(self.tokenizer.encode(result_text))
-                                forced.append(self.output_end_id)
-                            tool_payload_ids = []
-                    elif in_tool:
-                        tool_payload_ids.append(token_id)
-
-                    # Stop only on assistant_end or bos (NOT on tool markers)
-                    if token_id in self._stop_token_ids:
-                        break
-
-                    # Decode + stream to client (includes tool markers; UI renders)
-                    try:
-                        token_text = self.tokenizer.decode([token_id])
-                        yield "data: " + json.dumps({"token": token_text, "gpu": 0}) + "\n\n"
-                    except Exception:
-                        pass
-
-                    _append_token(token_id)
-                    num_generated += 1
+
+                            wrapped = OUT_START + result_text + OUT_END
+                            # Push wrapped result into the model's context so the next-token
+                            # prediction is grounded on the real tool output, and also stream it
+                            # to the client so the UI can render the tool-result card.
+                            for rid in self.tokenizer.encode(wrapped):
+                                try:
+                                    rt = self.tokenizer.decode([rid])
+                                except Exception:
+                                    rt = ""
+                                if rt:
+                                    yield "data: " + json.dumps({"token": rt, "gpu": 0}) + "\n\n"
+                                    assistant_text += rt
+                                _append_token(rid)
+                                num_generated += 1
+                            if num_generated >= max_tokens:
+                                break
 
                 yield "data: " + json.dumps({"done": True}) + "\n\n"
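For readers who want to poke at the detection logic without a model: below is a minimal, self-contained sketch of the text-stream state machine this commit adds. The names (MarkerDetector, feed) are hypothetical, not part of modal/serve.py, and the sketch adds a scan offset so a second tool call in the same turn is matched against fresh text rather than re-matching the first span.

TOOL_START = "<|python_start|>"
TOOL_END = "<|python_end|>"

class MarkerDetector:
    """Accumulates decoded token text; reports each complete
    <|python_start|>...<|python_end|> payload exactly once."""

    def __init__(self):
        self.text = ""       # decoded stream so far
        self.scan_from = 0   # skip spans already consumed by earlier calls
        self.start_pos = -1  # index of the open tool call, -1 if none

    def feed(self, token_text):
        self.text += token_text
        if self.start_pos < 0:
            idx = self.text.find(TOOL_START, self.scan_from)
            if idx >= 0:
                self.start_pos = idx
        if self.start_pos >= 0:
            body_at = self.start_pos + len(TOOL_START)
            end = self.text.find(TOOL_END, body_at)
            if end >= 0:
                self.scan_from = end + len(TOOL_END)  # never re-match this span
                self.start_pos = -1
                return self.text[body_at:end]
        return None  # no complete tool call yet

# The marker arrives split across several tokens, exactly as the ordinary
# encoding emits it:
det = MarkerDetector()
for piece in ["<|", "python", "_start", "|>", 'search("nanochat")', "<|python_end|>"]:
    payload = det.feed(piece)
    if payload is not None:
        print("execute tool with:", payload)  # -> search("nanochat")

In modal/serve.py the returned payload then goes through self._parse_tool_call and self.tool_registry.execute, and the <|output_start|>…<|output_end|>-wrapped result is both streamed to the client and appended to input_ids, as the second hunk above shows.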