mirror of
https://github.com/karpathy/nanochat.git
synced 2026-06-12 09:58:54 +00:00
Merge pull request #49 from manmohan659/fix/tool-loop-suppression
fix(serve): suppress post-injection tool-call loop
This commit is contained in:
commit
eabfbd6d49
|
|
@ -263,10 +263,18 @@ class Inference:
|
||||||
out_start_str = "<|output_start|>"
|
out_start_str = "<|output_start|>"
|
||||||
out_end_str = "<|output_end|>"
|
out_end_str = "<|output_end|>"
|
||||||
|
|
||||||
|
# Suppression: model training has many convs that emit a fake
|
||||||
|
# <|output_start|>…<|output_end|> after our injected one. We stop the
|
||||||
|
# turn if we detect another full <|output_start|> sequence emitted
|
||||||
|
# after our injection.
|
||||||
|
out_start_ids = tuple(self.tokenizer.encode(out_start_str))
|
||||||
|
|
||||||
async def stream():
|
async def stream():
|
||||||
input_ids = torch.tensor([tokens], dtype=torch.long, device=self.device)
|
input_ids = torch.tensor([tokens], dtype=torch.long, device=self.device)
|
||||||
gen_ids: list[int] = [] # everything the MODEL sampled this turn
|
gen_ids: list[int] = [] # everything the MODEL sampled this turn
|
||||||
tool_start_idx = -1 # position in gen_ids where <|python_start|> begins
|
tool_start_idx = -1 # position in gen_ids where <|python_start|> begins
|
||||||
|
tool_injected = False # once True, stop detecting further tool calls
|
||||||
|
injection_end_pos = -1 # index in gen_ids where our injected tokens end
|
||||||
|
|
||||||
def _append_token(tid):
|
def _append_token(tid):
|
||||||
nonlocal input_ids
|
nonlocal input_ids
|
||||||
|
|
@ -320,11 +328,11 @@ class Inference:
|
||||||
yield "data: " + json.dumps({"token": token_text, "gpu": 0}) + "\n\n"
|
yield "data: " + json.dumps({"token": token_text, "gpu": 0}) + "\n\n"
|
||||||
|
|
||||||
# --- tool-call detection (id-sequence match) ---
|
# --- tool-call detection (id-sequence match) ---
|
||||||
if tool_start_idx < 0:
|
if not tool_injected and tool_start_idx < 0:
|
||||||
idx = _find_subseq(gen_ids, tool_start_ids, max(0, len(gen_ids) - len(tool_start_ids) - 2))
|
idx = _find_subseq(gen_ids, tool_start_ids, max(0, len(gen_ids) - len(tool_start_ids) - 2))
|
||||||
if idx >= 0:
|
if idx >= 0:
|
||||||
tool_start_idx = idx
|
tool_start_idx = idx
|
||||||
if tool_start_idx >= 0:
|
if not tool_injected and tool_start_idx >= 0:
|
||||||
# look for <|python_end|> after the payload
|
# look for <|python_end|> after the payload
|
||||||
end_idx = _find_subseq(gen_ids, tool_end_ids, tool_start_idx + len(tool_start_ids))
|
end_idx = _find_subseq(gen_ids, tool_end_ids, tool_start_idx + len(tool_start_ids))
|
||||||
if end_idx >= 0:
|
if end_idx >= 0:
|
||||||
|
|
@ -351,9 +359,16 @@ class Inference:
|
||||||
num_generated += 1
|
num_generated += 1
|
||||||
if num_generated >= max_tokens:
|
if num_generated >= max_tokens:
|
||||||
break
|
break
|
||||||
# Reset so a second tool call in the same turn still works
|
tool_injected = True
|
||||||
|
injection_end_pos = len(gen_ids)
|
||||||
tool_start_idx = -1
|
tool_start_idx = -1
|
||||||
|
|
||||||
|
# After injection, detect if the model emits ANOTHER full
|
||||||
|
# <|output_start|> sequence (training-data loop artifact) and stop the turn.
|
||||||
|
if tool_injected and injection_end_pos > 0:
|
||||||
|
if _find_subseq(gen_ids, out_start_ids, injection_end_pos) >= 0:
|
||||||
|
break
|
||||||
|
|
||||||
yield "data: " + json.dumps({"done": True}) + "\n\n"
|
yield "data: " + json.dumps({"done": True}) + "\n\n"
|
||||||
|
|
||||||
return StreamingResponse(
|
return StreamingResponse(
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user