mirror of
https://github.com/karpathy/nanochat.git
synced 2026-05-07 08:19:52 +00:00
Merge pull request #50 from manmohan659/fix/tool-decode-text-match
fix(serve): decode-tail text match for tool markers
This commit is contained in:
commit
f41da418ab
123
modal/serve.py
123
modal/serve.py
|
|
@ -253,28 +253,21 @@ class Inference:
|
|||
if len(tokens) > max_context:
|
||||
tokens = tokens[-max_context:]
|
||||
|
||||
# Ordinary-text token-id sequences for the tool markers.
|
||||
# The SFT loader tokenizes assistant content with .encode() (not .encode_special()),
|
||||
# so the model emits these as multi-token sequences, not single special-token ids.
|
||||
# Match on the id sequence directly — more reliable than text (BPE partial UTF-8
|
||||
# can make single-token decodes return empty strings).
|
||||
tool_start_ids = tuple(self.tokenizer.encode("<|python_start|>"))
|
||||
tool_end_ids = tuple(self.tokenizer.encode("<|python_end|>"))
|
||||
# so these markers are emitted as multi-token byte sequences, and BPE has
|
||||
# multiple valid tokenizations of the same string — so matching on a single
|
||||
# expected id sequence is unreliable. Instead we decode the tail of the
|
||||
# generated token stream and search for the marker TEXT.
|
||||
tool_start_str = "<|python_start|>"
|
||||
tool_end_str = "<|python_end|>"
|
||||
out_start_str = "<|output_start|>"
|
||||
out_end_str = "<|output_end|>"
|
||||
|
||||
# Suppression: model training has many convs that emit a fake
|
||||
# <|output_start|>…<|output_end|> after our injected one. We stop the
|
||||
# turn if we detect another full <|output_start|> sequence emitted
|
||||
# after our injection.
|
||||
out_start_ids = tuple(self.tokenizer.encode(out_start_str))
|
||||
|
||||
async def stream():
|
||||
input_ids = torch.tensor([tokens], dtype=torch.long, device=self.device)
|
||||
gen_ids: list[int] = [] # everything the MODEL sampled this turn
|
||||
tool_start_idx = -1 # position in gen_ids where <|python_start|> begins
|
||||
tool_injected = False # once True, stop detecting further tool calls
|
||||
injection_end_pos = -1 # index in gen_ids where our injected tokens end
|
||||
pre_injection_len = 0 # len(gen_ids) right before we start injection
|
||||
|
||||
def _append_token(tid):
|
||||
nonlocal input_ids
|
||||
|
|
@ -283,19 +276,13 @@ class Inference:
|
|||
if input_ids.size(1) > self.config.sequence_len:
|
||||
input_ids = input_ids[:, -self.config.sequence_len:]
|
||||
|
||||
def _match_at(seq: list[int], pos: int, pat: tuple) -> bool:
|
||||
if pos < 0 or pos + len(pat) > len(seq):
|
||||
return False
|
||||
return tuple(seq[pos:pos + len(pat)]) == pat
|
||||
|
||||
def _find_subseq(seq: list[int], pat: tuple, start: int = 0) -> int:
|
||||
L = len(pat)
|
||||
if L == 0 or len(seq) < start + L:
|
||||
return -1
|
||||
for i in range(start, len(seq) - L + 1):
|
||||
if tuple(seq[i:i + L]) == pat:
|
||||
return i
|
||||
return -1
|
||||
def _decode_tail_text(last_n: int = 40) -> str:
|
||||
if not gen_ids:
|
||||
return ""
|
||||
try:
|
||||
return self.tokenizer.decode(gen_ids[-last_n:])
|
||||
except Exception:
|
||||
return ""
|
||||
|
||||
with torch.no_grad():
|
||||
num_generated = 0
|
||||
|
|
@ -319,7 +306,7 @@ class Inference:
|
|||
gen_ids.append(token_id)
|
||||
num_generated += 1
|
||||
|
||||
# Stream raw decoded text (may be empty for partial BPE bytes — that's OK)
|
||||
# Stream raw decoded text (may be empty for partial BPE bytes)
|
||||
try:
|
||||
token_text = self.tokenizer.decode([token_id])
|
||||
except Exception:
|
||||
|
|
@ -327,46 +314,50 @@ class Inference:
|
|||
if token_text:
|
||||
yield "data: " + json.dumps({"token": token_text, "gpu": 0}) + "\n\n"
|
||||
|
||||
# --- tool-call detection (id-sequence match) ---
|
||||
if not tool_injected and tool_start_idx < 0:
|
||||
idx = _find_subseq(gen_ids, tool_start_ids, max(0, len(gen_ids) - len(tool_start_ids) - 2))
|
||||
if idx >= 0:
|
||||
tool_start_idx = idx
|
||||
if not tool_injected and tool_start_idx >= 0:
|
||||
# look for <|python_end|> after the payload
|
||||
end_idx = _find_subseq(gen_ids, tool_end_ids, tool_start_idx + len(tool_start_ids))
|
||||
if end_idx >= 0:
|
||||
payload_ids = gen_ids[tool_start_idx + len(tool_start_ids):end_idx]
|
||||
try:
|
||||
payload_text = self.tokenizer.decode(payload_ids)
|
||||
invocation = self._parse_tool_call(payload_text)
|
||||
result = self.tool_registry.execute(invocation.tool_name, invocation.arguments)
|
||||
result_text = result.to_payload()[:4096]
|
||||
except Exception as exc:
|
||||
result_text = json.dumps({"error": str(exc)[:500]})
|
||||
# --- tool-call detection on decoded-tail text ---
|
||||
if not tool_injected:
|
||||
# Decode the whole turn-so-far and look for markers
|
||||
try:
|
||||
full_text = self.tokenizer.decode(gen_ids)
|
||||
except Exception:
|
||||
full_text = ""
|
||||
if full_text:
|
||||
ps = full_text.rfind(tool_start_str)
|
||||
if ps >= 0:
|
||||
pe = full_text.find(tool_end_str, ps + len(tool_start_str))
|
||||
if pe >= 0:
|
||||
payload_text = full_text[ps + len(tool_start_str):pe]
|
||||
try:
|
||||
invocation = self._parse_tool_call(payload_text)
|
||||
result = self.tool_registry.execute(invocation.tool_name, invocation.arguments)
|
||||
result_text = result.to_payload()[:4096]
|
||||
except Exception as exc:
|
||||
result_text = json.dumps({"error": str(exc)[:500]})
|
||||
|
||||
wrapped = out_start_str + result_text + out_end_str
|
||||
# Inject real result tokens into the model's context and the client stream.
|
||||
for rid in self.tokenizer.encode(wrapped):
|
||||
try:
|
||||
rt = self.tokenizer.decode([rid])
|
||||
except Exception:
|
||||
rt = ""
|
||||
if rt:
|
||||
yield "data: " + json.dumps({"token": rt, "gpu": 0}) + "\n\n"
|
||||
_append_token(rid)
|
||||
gen_ids.append(rid)
|
||||
num_generated += 1
|
||||
if num_generated >= max_tokens:
|
||||
break
|
||||
tool_injected = True
|
||||
injection_end_pos = len(gen_ids)
|
||||
tool_start_idx = -1
|
||||
pre_injection_len = len(gen_ids)
|
||||
wrapped = out_start_str + result_text + out_end_str
|
||||
for rid in self.tokenizer.encode(wrapped):
|
||||
try:
|
||||
rt = self.tokenizer.decode([rid])
|
||||
except Exception:
|
||||
rt = ""
|
||||
if rt:
|
||||
yield "data: " + json.dumps({"token": rt, "gpu": 0}) + "\n\n"
|
||||
_append_token(rid)
|
||||
gen_ids.append(rid)
|
||||
num_generated += 1
|
||||
if num_generated >= max_tokens:
|
||||
break
|
||||
tool_injected = True
|
||||
|
||||
# After injection, detect if the model emits ANOTHER full
|
||||
# <|output_start|> sequence (training-data loop artifact) and stop the turn.
|
||||
if tool_injected and injection_end_pos > 0:
|
||||
if _find_subseq(gen_ids, out_start_ids, injection_end_pos) >= 0:
|
||||
# After injection: if the model starts emitting another <|output_start|>,
|
||||
# break the turn — the grounded result already streamed.
|
||||
elif pre_injection_len > 0 and len(gen_ids) > pre_injection_len + 20:
|
||||
try:
|
||||
post_text = self.tokenizer.decode(gen_ids[pre_injection_len + 10:])
|
||||
except Exception:
|
||||
post_text = ""
|
||||
if out_start_str in post_text:
|
||||
break
|
||||
|
||||
yield "data: " + json.dumps({"done": True}) + "\n\n"
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user