mirror of https://github.com/karpathy/nanochat.git
synced 2026-05-12 02:40:17 +00:00

Merge pull request #51 from manmohan659/fix/forced-web-search-and-ui-cleanup

fix: forced web_search classifier + UI orphan-marker cleanup

This commit is contained in: commit 65e681add5

modal/_query_classifier.py (new file, 161 lines added)
@@ -0,0 +1,161 @@
"""Heuristic query classifier for forced tool use.

This module answers a single question: given a user message, does it likely
need a fresh web search to be answered correctly?

We use this to *force* the model to call `web_search` even when its natural
generation path would skip the tool (e.g. "whos the present president" where
"present" doesn't appear in the tool-use SFT training distribution).

Two outputs:
    needs_web_search(text) -> (bool, rewritten_query)
    needs_calculator(text) -> (bool, expression)

Keep the rules simple and fast. False positives are cheap (we pay a Tavily
call); false negatives are the real cost (the user gets a stale answer from
training data).
"""
from __future__ import annotations

import re
from typing import Tuple

# ---------------------------------------------------------------------------
# Web-search triggers
# ---------------------------------------------------------------------------

# Time words that almost always indicate the user wants fresh info
_TIME_WORDS = r"""
(?:
    current | currently | now | today | tonight | tomorrow | yesterday |
    present | this\ (?:week|month|year|morning|afternoon|evening) |
    latest | recent | recently | upcoming | right\ now | as\ of |
    20\d{2} | this\ very\ (?:moment|second|minute)
)
"""

# Role / position words that change over time
_POSITION_WORDS = r"""
(?:
    president | vice[-\s]?president | prime\ minister | chancellor | governor |
    senator | congressman | congresswoman | representative | mayor |
    ceo | cto | cfo | coo | chairman | chief\s+executive |
    chief\s+justice | speaker | foreign\s+minister | attorney\s+general |
    pope | monarch | king | queen | emperor | dictator | leader
)
"""

# Topic categories that are inherently time-sensitive
_CATEGORY_PATTERNS = [
    # weather / climate
    r"\b(?:weather|temperature|forecast|humidity|rainfall|snowfall|wind\s+speed)\b",
    # finance / markets
    r"\b(?:stock\s+price|share\s+price|market\s+cap|exchange\s+rate|forex|crypto(?:currency)?\s+price)\b",
    r"\b(?:nasdaq|s&p\s*500|dow\s+jones|nifty|sensex|ftse|nikkei)\b",
    r"\bprice\s+of\s+(?:gold|silver|oil|bitcoin|ethereum)\b",
    # news / events
    r"\b(?:breaking\s+news|headlines?|latest\s+news|news\s+(?:today|now))\b",
    r"\bwhat(?:'|\s+i)s\s+happening\s+(?:in|with|at)\b",
    # sports scores
    r"\b(?:score|result|winner|loser|champion|finalist)\b.*(?:match|game|tournament|series|cup|open|championship)",
    r"\b(?:ipl|world\s+cup|super\s+bowl|nba|nfl|nhl|olympics?|wimbledon|us\s+open|australian\s+open|french\s+open)\b",
    # time of day / where
    r"\bwhat\s+time\s+is\s+it\b",
    r"\btime\s+in\s+[A-Z][a-z]+\b",
    # people (often changes) — VIPs, CEOs, recent releases
    r"\bwho\s+(?:is|'s|runs|leads|owns|founded|heads)\b",
    # "is X still alive" / "did X die"
    r"\bis\s+\w+\s+(?:still\s+)?(?:alive|dead)\b",
    # recent product / release
    r"\b(?:latest|newest|recent|upcoming)\s+(?:version|release|model|iphone|android|macbook|tesla|game|album|movie|film|book)\b",
    # recent statistics that change
    r"\b(?:population|number\s+of)\s+\w+\s+(?:of|in)\b",
    # explicit search instructions
    r"\b(?:search|look\s+up|google|find\s+out|check)\s+(?:online|on\s+the\s+web|on\s+google)\b",
    r"\buse\s+(?:the\s+)?web[_\s]?search\b",
]

# Standalone keyword patterns that, combined with any position/role, trigger search
_KEYWORD_GATE = re.compile(
    rf"""
    \b{_TIME_WORDS}\b                                      # at least one time word
    | \b(?:who|what|where|when)\b.*\b{_POSITION_WORDS}\b   # who/what + position
    | \b{_POSITION_WORDS}\b.*\b(?:of|in|for|at)\s+[A-Z]    # position + proper noun (of America, in India)
    """,
    re.IGNORECASE | re.VERBOSE,
)

_CATEGORY_REGEXES = [re.compile(p, re.IGNORECASE) for p in _CATEGORY_PATTERNS]
def needs_web_search(text: str) -> Tuple[bool, str]:
    """Classify whether a user query likely needs a live web search.

    Returns (needs, rewritten_query). The rewritten_query strips filler and
    reformulates for better Tavily results (e.g. "whos the present president" ->
    "who is the current president of the United States 2026").
    """
    if not text or not isinstance(text, str):
        return False, ""

    stripped = text.strip()
    if len(stripped) < 3:
        return False, ""

    # Any category pattern hit
    for rx in _CATEGORY_REGEXES:
        if rx.search(stripped):
            return True, _rewrite_query(stripped)

    # Keyword gate (time words or position + specifier)
    if _KEYWORD_GATE.search(stripped):
        return True, _rewrite_query(stripped)

    return False, ""

def _rewrite_query(text: str) -> str:
    """Clean up the query for Tavily — expand contractions, normalize 'present'->'current',
    strip filler, add a year anchor."""
    q = text.strip().rstrip("?.!")
    # contractions
    q = re.sub(r"\bwho(?:'| i)s\b", "who is", q, flags=re.IGNORECASE)
    q = re.sub(r"\bwhat(?:'| i)s\b", "what is", q, flags=re.IGNORECASE)
    q = re.sub(r"\bwhere(?:'| i)s\b", "where is", q, flags=re.IGNORECASE)
    q = re.sub(r"\bwhen(?:'| i)s\b", "when is", q, flags=re.IGNORECASE)
    q = re.sub(r"\bits\b", "it is", q, flags=re.IGNORECASE)
    # strip politeness filler that only dilutes the search query
    q = re.sub(r"\b(please|kindly|could\s+you|can\s+you)\b\s*", "", q, flags=re.IGNORECASE)
    # "present X" -> "current X" (better Tavily results)
    q = re.sub(r"\bpresent\b", "current", q, flags=re.IGNORECASE)
    # collapse whitespace
    q = re.sub(r"\s+", " ", q).strip()
    # anchor to the current year if no year is already present
    if not re.search(r"\b20\d{2}\b", q):
        q = q + " 2026"
    return q

# ---------------------------------------------------------------------------
# Calculator triggers (cheap, local)
# ---------------------------------------------------------------------------

_CALC_RX = re.compile(
    r"""
    \b(?:calculate|compute|what\s+is|what's)\b.*?
    (?:
        \d[\d,\.\s]*\s*[+\-\*/x×÷]\s*\d                # basic arithmetic
        | \d+\s*%\s+(?:of|tip|tax|discount)\s+\d       # percentage
        | \b(?:emi|cagr|compound\s+interest|tip|discount|percent(?:age)?)\b.*\d
    )
    """,
    re.IGNORECASE | re.VERBOSE,
)


def needs_calculator(text: str) -> Tuple[bool, str]:
    if not text:
        return False, ""
    m = _CALC_RX.search(text)
    if not m:
        return False, ""
    return True, text.strip()
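For orientation, a minimal usage sketch of the classifier's contract (not part of the diff; it assumes the module is importable, e.g. with /root on sys.path as the inference image arranges, and the rewritten strings shown are approximate):

    # Illustrative only; exact outputs depend on the regexes above.
    from _query_classifier import needs_web_search, needs_calculator

    needs, query = needs_web_search("who's the present president")
    # needs is True (time-word gate hits on "present"); query is roughly
    # "who is the current president 2026": contraction expanded, "present"
    # normalized to "current", and a year anchor appended.

    print(needs_web_search("explain how attention works"))
    # (False, "") since no time words, position words, or category patterns match

    print(needs_calculator("what is 15% of 2400?"))
    # (True, "what is 15% of 2400?"): the raw text is returned as the expression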
@@ -51,6 +51,7 @@ inference_image = (
    .add_local_file("modal/_model.py", "/root/_model.py")
    .add_local_file("modal/_tokenizer.py", "/root/_tokenizer.py")
    .add_local_file("modal/_tools.py", "/root/_tools.py")
    .add_local_file("modal/_query_classifier.py", "/root/_query_classifier.py")
)

# Persistent volume for model weights
@@ -198,8 +199,10 @@ class Inference:
        import sys as _sys
        if '/root' not in _sys.path: _sys.path.insert(0, '/root')
        from _tools import build_default_tool_registry, parse_tool_call_payload
        from _query_classifier import needs_web_search
        self.tool_registry = build_default_tool_registry()
        self._parse_tool_call = parse_tool_call_payload
        self._needs_web_search = needs_web_search
        # Marker tokens for tool state machine
        self.python_start_id = self.tokenizer.encode_special("<|python_start|>")[0]
        self.python_end_id = self.tokenizer.encode_special("<|python_end|>")[0]
@@ -248,6 +251,42 @@ class Inference:
        # Prompt the model to generate an assistant response
        tokens.extend(assistant_start)

        # --- Forced tool use ---
        # The model's SFT training doesn't always trigger web_search even when
        # a question clearly needs current info (e.g. "present president" vs
        # "current president"). We classify the last user message and, if it
        # matches tool-worthy patterns, pre-seed the assistant turn with a real
        # tool call + Tavily result. The model then just writes the final
        # grounded answer instead of hallucinating from stale memory.
        forced_prefix_text = ""
        last_user = ""
        for msg in reversed(messages):
            if msg.get("role") == "user":
                last_user = msg.get("content", "")
                break
        try:
            needs_search, rewritten = self._needs_web_search(last_user)
        except Exception:
            needs_search, rewritten = False, ""
        if needs_search and rewritten:
            preface = "I'll look that up for you. "
            tool_call_json = json.dumps(
                {"arguments": {"query": rewritten, "top_k": 1}, "tool": "web_search"},
                separators=(",", ":"),
            )
            try:
                invocation = self._parse_tool_call(tool_call_json)
                tool_result = self.tool_registry.execute(invocation.tool_name, invocation.arguments)
                result_text = tool_result.to_payload()[:4096]
            except Exception as exc:
                result_text = json.dumps({"error": str(exc)[:500]})
            forced_prefix_text = (
                preface
                + "<|python_start|>" + tool_call_json + "<|python_end|>"
                + "<|output_start|>" + result_text + "<|output_end|>\n"
            )
            tokens.extend(self.tokenizer.encode(forced_prefix_text))

        # Truncate to fit context
        max_context = self.config.sequence_len - max_tokens
        if len(tokens) > max_context:
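To make the pre-seeded turn concrete, the forced prefix built above looks roughly like this for a query such as "who's the present president"; the result payload shown is hypothetical, since its exact shape comes from tool_result.to_payload():

    # Hypothetical example of forced_prefix_text (marker layout and compact JSON as built above).
    forced_prefix_text = (
        "I'll look that up for you. "
        '<|python_start|>'
        '{"arguments":{"query":"who is the current president 2026","top_k":1},"tool":"web_search"}'
        '<|python_end|>'
        '<|output_start|>{"results":[{"title":"...","content":"..."}]}<|output_end|>\n'
    )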
@@ -266,9 +305,14 @@ class Inference:
        async def stream():
            input_ids = torch.tensor([tokens], dtype=torch.long, device=self.device)
            gen_ids: list[int] = []  # everything the MODEL sampled this turn
-            tool_injected = False  # once True, stop detecting further tool calls
+            tool_injected = bool(forced_prefix_text)  # forced prefix counts as an injection
+            pre_injection_len = 0  # len(gen_ids) right before we start injection
+
+            # If we pre-seeded a forced tool call + result, stream it to the client
+            # now so the UI can render the tool-call / tool-result cards.
+            if forced_prefix_text:
+                yield "data: " + json.dumps({"token": forced_prefix_text, "gpu": 0}) + "\n\n"

            def _append_token(tid):
                nonlocal input_ids
                nt = torch.tensor([[tid]], dtype=torch.long, device=self.device)
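On the wire, a forced-search turn is delivered as server-sent events; a rough sketch of the frames a client might see (the middle, per-token frame for the model's own continuation is an assumption based on the {"token": ..., "gpu": 0} format used here):

    import json

    # Sketch of the SSE frames for a forced-search turn; contents are hypothetical.
    frames = [
        "data: " + json.dumps({"token": "I'll look that up for you. <|python_start|>...<|python_end|>"
                                        "<|output_start|>...<|output_end|>\n", "gpu": 0}) + "\n\n",
        "data: " + json.dumps({"token": " The current president is ...", "gpu": 0}) + "\n\n",  # assumed
        "data: " + json.dumps({"done": True}) + "\n\n",
    ]
    for frame in frames:
        print(frame, end="")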
@@ -350,14 +394,22 @@ class Inference:
                        break
                    tool_injected = True

-                # After injection: if the model starts emitting another <|output_start|>,
-                # break the turn — the grounded result already streamed.
-                elif pre_injection_len > 0 and len(gen_ids) > pre_injection_len + 20:
+                # After injection (forced OR runtime): the model often loops and
+                # emits another fake <|output_start|>…<|output_end|> / <|python_start|>…
+                # block. Break the turn as soon as ANY tool-marker appears in what the
+                # MODEL itself generated. We check the decoded text of gen_ids[pre_injection_len:].
+                elif tool_injected and len(gen_ids) > pre_injection_len + 6:
                    try:
-                        post_text = self.tokenizer.decode(gen_ids[pre_injection_len + 10:])
+                        post_text = self.tokenizer.decode(gen_ids[pre_injection_len:])
                    except Exception:
                        post_text = ""
-                    if out_start_str in post_text:
+                    for bad in (out_start_str, out_end_str, tool_start_str, tool_end_str):
+                        if bad in post_text:
+                            break_now = True
+                            break
+                    else:
+                        break_now = False
+                    if break_now:
                        break

            yield "data: " + json.dumps({"done": True}) + "\n\n"
@@ -18,16 +18,19 @@ type Segment =
  | { kind: 'tool_result'; content: string; closed: boolean };

function parseSegments(raw: string): Segment[] {
  // First pass: strip orphan tool markers (end-tag without open-tag, or any
  // stray marker outside a pair) that the model sometimes emits as loop
  // artifacts — otherwise they leak into the message body as raw text.
  raw = stripOrphanMarkers(raw);

  const segs: Segment[] = [];
  let i = 0;
  // marker -> [openTag, closeTag, kind]
  const markers: Array<[string, string, Segment['kind']]> = [
    ['<think>', '</think>', 'think'],
    ['<|python_start|>', '<|python_end|>', 'tool_call'],
    ['<|output_start|>', '<|output_end|>', 'tool_result'],
  ];
  while (i < raw.length) {
    // find the nearest opening marker
    let bestOpen = -1;
    let bestMarker: [string, string, Segment['kind']] | null = null;
    for (const m of markers) {
@@ -50,7 +53,73 @@ function parseSegments(raw: string): Segment[] {
      i = closeIdx + closeTag.length;
    }
  }
-  return segs;
+  return dedupeAndClean(segs);
}

function stripOrphanMarkers(s: string): string {
  // Walk the string left-to-right. For each opening marker we encounter, keep
  // it only if its matching close exists somewhere after it. For each close
  // marker encountered without a preceding open, drop it.
  const pairs: Array<[string, string]> = [
    ['<think>', '</think>'],
    ['<|python_start|>', '<|python_end|>'],
    ['<|output_start|>', '<|output_end|>'],
  ];
  for (const [open, close] of pairs) {
    // Remove any close-tag that has no preceding open-tag
    const openPositions: number[] = [];
    let idx = 0;
    while (true) {
      const p = s.indexOf(open, idx);
      if (p === -1) break;
      openPositions.push(p);
      idx = p + open.length;
    }
    const closePositions: number[] = [];
    idx = 0;
    while (true) {
      const p = s.indexOf(close, idx);
      if (p === -1) break;
      closePositions.push(p);
      idx = p + close.length;
    }
    // drop close tags that appear before any open tag
    const firstOpen = openPositions[0] ?? Infinity;
    const orphanCloses = closePositions.filter((c) => c < firstOpen);
    if (orphanCloses.length) {
      // remove each orphan close (work in reverse so indices stay valid)
      for (const c of orphanCloses.reverse()) {
        s = s.slice(0, c) + s.slice(c + close.length);
      }
    }
  }
  return s;
}

function dedupeAndClean(segs: Segment[]): Segment[] {
  const out: Segment[] = [];
  let lastResultKey: string | null = null;
  for (const seg of segs) {
    // collapse consecutive duplicate tool_result segments (model re-emits the
    // same block as a training artifact)
    if (seg.kind === 'tool_result') {
      const key = seg.content.replace(/\s+/g, ' ').trim();
      if (key === lastResultKey) continue;
      lastResultKey = key;
    } else {
      lastResultKey = null;
    }
    // drop plain-text segments that are just leftover tool-marker fragments
    if (seg.kind === 'text') {
      const t = seg.content.replace(/<\|?(?:python|output)_(?:start|end)\|?>/g, '').trim();
      if (!t) continue;
      out.push({ kind: 'text', content: seg.content.replace(/<\|?(?:python|output)_(?:start|end)\|?>/g, '') });
      continue;
    }
    out.push(seg);
  }
  return out;
}

function ThinkBlock({ content, closed }: { content: string; closed: boolean }) {