Merge pull request #51 from manmohan659/fix/forced-web-search-and-ui-cleanup

fix: forced web_search classifier + UI orphan-marker cleanup
Manmohan 2026-04-22 18:01:13 -04:00 committed by GitHub
commit 65e681add5
3 changed files with 291 additions and 9 deletions

modal/_query_classifier.py (new file, 161 additions)

@@ -0,0 +1,161 @@
"""Heuristic query classifier for forced tool use.
This module answers a single question: given a user message, does it likely
need a fresh web search to be answered correctly?
We use this to *force* the model to call `web_search` even when its natural
generation path would skip the tool (e.g. "whos the present president" where
"present" doesn't appear in the tool-use SFT training distribution).
Two outputs:
needs_web_search(text) -> (bool, rewritten_query)
needs_calculator(text) -> (bool, expression)
Keep the rules simple and fast. False positives are cheap (we pay a Tavily
call); false negatives are the real cost (user gets stale training-data).
"""
from __future__ import annotations

import re
from datetime import date
from typing import Tuple
# ---------------------------------------------------------------------------
# Web-search triggers
# ---------------------------------------------------------------------------
# Time words that almost always indicate the user wants fresh info
_TIME_WORDS = r"""
(?:
current | currently | now | today | tonight | tomorrow | yesterday |
present | this\ (?:week|month|year|morning|afternoon|evening) |
latest | recent | recently | upcoming | right\ now | as\ of |
20\d{2} | this\ very\ (?:moment|second|minute)
)
"""
# Role / position words that change over time
_POSITION_WORDS = r"""
(?:
president | vice[-\s]?president | prime\ minister | chancellor | governor |
senator | congressman | congresswoman | representative | mayor |
ceo | cto | cfo | coo | chairman | chief\s+executive |
chief\s+justice | speaker | foreign\s+minister | attorney\s+general |
pope | monarch | king | queen | emperor | dictator | leader
)
"""
# Topic categories that are inherently time-sensitive
_CATEGORY_PATTERNS = [
# weather / climate
r"\b(?:weather|temperature|forecast|humidity|rainfall|snowfall|wind\s+speed)\b",
# finance / markets
r"\b(?:stock\s+price|share\s+price|market\s+cap|exchange\s+rate|forex|crypto(?:currency)?\s+price)\b",
r"\b(?:nasdaq|s&p\s*500|dow\s+jones|nifty|sensex|ftse|nikkei)\b",
r"\bprice\s+of\s+(?:gold|silver|oil|bitcoin|ethereum)\b",
# news / events
r"\b(?:breaking\s+news|headlines?|latest\s+news|news\s+(?:today|now))\b",
r"\bwhat(?:'|\s+i)s\s+happening\s+(?:in|with|at)\b",
# sports scores
r"\b(?:score|result|winner|loser|champion|finalist)\b.*(?:match|game|tournament|series|cup|open|championship)",
r"\b(?:ipl|world\s+cup|super\s+bowl|nba|nfl|nhl|olympics?|wimbledon|us\s+open|australian\s+open|french\s+open)\b",
# time of day / where
r"\bwhat\s+time\s+is\s+it\b",
r"\btime\s+in\s+[A-Z][a-z]+\b",
# people (often changes) — VIPs, CEOs, recent releases
r"\bwho\s+(?:is|'s|runs|leads|owns|founded|heads)\b",
# "is X still alive" / "did X die"
r"\bis\s+\w+\s+(?:still\s+)?(?:alive|dead)\b",
# recent product / release
r"\b(?:latest|newest|recent|upcoming)\s+(?:version|release|model|iphone|android|macbook|tesla|game|album|movie|film|book)\b",
    # statistics that drift over time ("population of France", "number of users in India")
    r"\b(?:population|number\s+of)\s+\w+\s+(?:of|in)\b",
    r"\bpopulation\s+of\s+\w+\b",
# explicit search instructions
r"\b(?:search|look\s+up|google|find\s+out|check)\s+(?:online|on\s+the\s+web|on\s+google)\b",
r"\buse\s+(?:the\s+)?web[_\s]?search\b",
]
# Keyword gate: a time word, a question word followed by a position/role, or a
# position/role followed by a place/organization specifier
_KEYWORD_GATE = re.compile(
rf"""
\b{_TIME_WORDS}\b # at least one time word
| \b(?:who|what|where|when)\b.*\b{_POSITION_WORDS}\b # who/what + position
    | \b{_POSITION_WORDS}\b.*\b(?:of|in|for|at)\s+[A-Z]    # position + "of/in/for/at" + a following word (of America, in India)
""",
re.IGNORECASE | re.VERBOSE,
)
_CATEGORY_REGEXES = [re.compile(p, re.IGNORECASE) for p in _CATEGORY_PATTERNS]
def needs_web_search(text: str) -> Tuple[bool, str]:
"""Classify whether a user query likely needs a live web search.
Returns (needs, rewritten_query). The rewritten_query strips filler and
reformulates for better Tavily results (e.g. "whos the present president" ->
"who is the current president of the United States 2026").
"""
if not text or not isinstance(text, str):
return False, ""
stripped = text.strip()
if len(stripped) < 3:
return False, ""
# Any category pattern hit
for rx in _CATEGORY_REGEXES:
if rx.search(stripped):
return True, _rewrite_query(stripped)
# Keyword gate (time words or position + specifier)
if _KEYWORD_GATE.search(stripped):
return True, _rewrite_query(stripped)
return False, ""
def _rewrite_query(text: str) -> str:
"""Clean up the query for Tavily — expand contractions, normalize 'present'->'current',
strip filler, add a year anchor."""
q = text.strip().rstrip("?.!")
    # contractions and common misspellings ("who's" / "whos" -> "who is", etc.)
    q = re.sub(r"\bwho'?s\b", "who is", q, flags=re.IGNORECASE)
    q = re.sub(r"\bwhat'?s\b", "what is", q, flags=re.IGNORECASE)
    q = re.sub(r"\bwhere'?s\b", "where is", q, flags=re.IGNORECASE)
    q = re.sub(r"\bwhen'?s\b", "when is", q, flags=re.IGNORECASE)
    q = re.sub(r"\bit's\b", "it is", q, flags=re.IGNORECASE)  # contraction only, not possessive "its"
    # strip politeness filler that only adds noise to the search query
q = re.sub(r"\b(please|kindly|could\s+you|can\s+you)\b\s*", "", q, flags=re.IGNORECASE)
# "present X" -> "current X" (better Tavily results)
q = re.sub(r"\bpresent\b", "current", q, flags=re.IGNORECASE)
# collapse whitespace
q = re.sub(r"\s+", " ", q).strip()
    # anchor to the current calendar year if no year is already present
    if not re.search(r"\b20\d{2}\b", q):
        q = f"{q} {date.today().year}"
return q
# ---------------------------------------------------------------------------
# Calculator triggers (cheap, local)
# ---------------------------------------------------------------------------
_CALC_RX = re.compile(
r"""
\b(?:calculate|compute|what\s+is|what's)\b.*?
(?:
\d[\d,\.\s]*\s*[+\-\*/x×÷]\s*\d # basic arithmetic
| \d+\s*%\s+(?:of|tip|tax|discount)\s+\d # percentage
| \b(?:emi|cagr|compound\s+interest|tip|discount|percent(?:age)?)\b.*\d
)
""",
re.IGNORECASE | re.VERBOSE,
)
def needs_calculator(text: str) -> Tuple[bool, str]:
    """Return (True, stripped query text) when the query looks like arithmetic.

    The second element is the whole query, not an isolated expression.
    """
    if not text:
        return False, ""
m = _CALC_RX.search(text)
if not m:
return False, ""
return True, text.strip()
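
For reference, a minimal usage sketch of the classifier (an editor's illustration, not part of the committed file). It assumes the module is importable as `_query_classifier`, which is how the inference container loads it from /root:

# usage sketch (illustrative only)
from _query_classifier import needs_calculator, needs_web_search

needs, query = needs_web_search("whos the present president")
# needs -> True ("present" is a time word in the keyword gate)
# query -> "who is the current president <current year>"

needs, query = needs_web_search("explain how a binary heap works")
# needs -> False: no time words, positions, or time-sensitive categories

calc, expr = needs_calculator("what is 15% of 240")
# calc -> True; expr is the stripped query text, not a parsed expression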


@@ -51,6 +51,7 @@ inference_image = (
.add_local_file("modal/_model.py", "/root/_model.py")
.add_local_file("modal/_tokenizer.py", "/root/_tokenizer.py")
.add_local_file("modal/_tools.py", "/root/_tools.py")
.add_local_file("modal/_query_classifier.py", "/root/_query_classifier.py")
)
# Persistent volume for model weights
@@ -198,8 +199,10 @@ class Inference:
import sys as _sys
if '/root' not in _sys.path: _sys.path.insert(0, '/root')
from _tools import build_default_tool_registry, parse_tool_call_payload
from _query_classifier import needs_web_search
self.tool_registry = build_default_tool_registry()
self._parse_tool_call = parse_tool_call_payload
self._needs_web_search = needs_web_search
# Marker tokens for tool state machine
self.python_start_id = self.tokenizer.encode_special("<|python_start|>")[0]
self.python_end_id = self.tokenizer.encode_special("<|python_end|>")[0]
@@ -248,6 +251,42 @@ class Inference:
# Prompt the model to generate an assistant response
tokens.extend(assistant_start)
# --- Forced tool use ---
# The model's SFT training doesn't always trigger web_search even when
# a question clearly needs current info (e.g. "present president" vs
# "current president"). We classify the last user message and, if it
# matches tool-worthy patterns, pre-seed the assistant turn with a real
# tool call + Tavily result. The model then just writes the final
# grounded answer instead of hallucinating from stale memory.
forced_prefix_text = ""
last_user = ""
for msg in reversed(messages):
if msg.get("role") == "user":
last_user = msg.get("content", "")
break
try:
needs_search, rewritten = self._needs_web_search(last_user)
except Exception:
needs_search, rewritten = False, ""
if needs_search and rewritten:
preface = "I'll look that up for you. "
tool_call_json = json.dumps(
{"arguments": {"query": rewritten, "top_k": 1}, "tool": "web_search"},
separators=(",", ":"),
)
try:
invocation = self._parse_tool_call(tool_call_json)
tool_result = self.tool_registry.execute(invocation.tool_name, invocation.arguments)
result_text = tool_result.to_payload()[:4096]
except Exception as exc:
result_text = json.dumps({"error": str(exc)[:500]})
forced_prefix_text = (
preface
+ "<|python_start|>" + tool_call_json + "<|python_end|>"
+ "<|output_start|>" + result_text + "<|output_end|>\n"
)
tokens.extend(self.tokenizer.encode(forced_prefix_text))
# Truncate to fit context
max_context = self.config.sequence_len - max_tokens
if len(tokens) > max_context:
@@ -266,9 +305,14 @@ class Inference:
async def stream():
input_ids = torch.tensor([tokens], dtype=torch.long, device=self.device)
gen_ids: list[int] = [] # everything the MODEL sampled this turn
tool_injected = False # once True, stop detecting further tool calls
tool_injected = bool(forced_prefix_text) # forced prefix counts as an injection
pre_injection_len = 0 # len(gen_ids) right before we start injection
# If we pre-seeded a forced tool call + result, stream it to the client
# now so the UI can render the tool-call / tool-result cards.
if forced_prefix_text:
yield "data: " + json.dumps({"token": forced_prefix_text, "gpu": 0}) + "\n\n"
def _append_token(tid):
nonlocal input_ids
nt = torch.tensor([[tid]], dtype=torch.long, device=self.device)
@@ -350,14 +394,22 @@ class Inference:
break
tool_injected = True
# After injection: if the model starts emitting another <|output_start|>,
# break the turn — the grounded result already streamed.
elif pre_injection_len > 0 and len(gen_ids) > pre_injection_len + 20:
# After injection (forced OR runtime): the model often loops and
# emits another fake <|output_start|>…<|output_end|> / <|python_start|>…
# block. Break the turn as soon as ANY tool-marker appears in what the
# MODEL itself generated. We check the decoded text of gen_ids[pre_injection_len:].
elif tool_injected and len(gen_ids) > pre_injection_len + 6:
try:
post_text = self.tokenizer.decode(gen_ids[pre_injection_len + 10:])
post_text = self.tokenizer.decode(gen_ids[pre_injection_len:])
except Exception:
post_text = ""
if out_start_str in post_text:
for bad in (out_start_str, out_end_str, tool_start_str, tool_end_str):
if bad in post_text:
break_now = True
break
else:
break_now = False
if break_now:
break
yield "data: " + json.dumps({"done": True}) + "\n\n"


@@ -18,16 +18,19 @@ type Segment =
| { kind: 'tool_result'; content: string; closed: boolean };
function parseSegments(raw: string): Segment[] {
  // First pass: strip orphan close-markers (an end-tag emitted before any
  // matching open-tag) that the model sometimes produces as loop artifacts;
  // otherwise they leak into the message body as raw text.
  raw = stripOrphanMarkers(raw);
const segs: Segment[] = [];
let i = 0;
// marker -> [openTag, closeTag, kind]
const markers: Array<[string, string, Segment['kind']]> = [
['<think>', '</think>', 'think'],
['<|python_start|>', '<|python_end|>', 'tool_call'],
['<|output_start|>', '<|output_end|>', 'tool_result'],
];
while (i < raw.length) {
// find the nearest opening marker
let bestOpen = -1;
let bestMarker: [string, string, Segment['kind']] | null = null;
for (const m of markers) {
@@ -50,7 +53,73 @@ function parseSegments(raw: string): Segment[] {
i = closeIdx + closeTag.length;
}
}
return segs;
return dedupeAndClean(segs);
}
function stripOrphanMarkers(s: string): string {
  // Drop any close-marker that appears before its pair's first open-marker;
  // these are loop artifacts. Unclosed open-markers are deliberately left
  // alone so partially streamed blocks still render while tokens arrive.
const pairs: Array<[string, string]> = [
['<think>', '</think>'],
['<|python_start|>', '<|python_end|>'],
['<|output_start|>', '<|output_end|>'],
];
for (const [open, close] of pairs) {
// Remove any close-tag that has no preceding open-tag
const openPositions: number[] = [];
let idx = 0;
while (true) {
const p = s.indexOf(open, idx);
if (p === -1) break;
openPositions.push(p);
idx = p + open.length;
}
const closePositions: number[] = [];
idx = 0;
while (true) {
const p = s.indexOf(close, idx);
if (p === -1) break;
closePositions.push(p);
idx = p + close.length;
}
// drop close tags that appear before any open tag
const firstOpen = openPositions[0] ?? Infinity;
const orphanCloses = closePositions.filter((c) => c < firstOpen);
if (orphanCloses.length) {
// remove each orphan close (work in reverse so indices stay valid)
for (const c of orphanCloses.reverse()) {
s = s.slice(0, c) + s.slice(c + close.length);
}
}
}
return s;
}
function dedupeAndClean(segs: Segment[]): Segment[] {
const out: Segment[] = [];
let lastResultKey: string | null = null;
for (const seg of segs) {
// collapse consecutive duplicate tool_result segments (model re-emits the
// same block as a training artifact)
if (seg.kind === 'tool_result') {
const key = seg.content.replace(/\s+/g, ' ').trim();
if (key === lastResultKey) continue;
lastResultKey = key;
} else {
lastResultKey = null;
}
// drop plain-text segments that are just leftover tool-marker fragments
if (seg.kind === 'text') {
const t = seg.content.replace(/<\|?(?:python|output)_(?:start|end)\|?>/g, '').trim();
if (!t) continue;
out.push({ kind: 'text', content: seg.content.replace(/<\|?(?:python|output)_(?:start|end)\|?>/g, '') });
continue;
}
out.push(seg);
}
return out;
}
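
As a side note on the orphan-close rule: the same behavior expressed as a small Python sketch (an editor's illustration, not code in this PR). The helper name strip_orphan_closes is hypothetical, but it can be handy for checking raw streamed transcripts offline:

# Python sketch of the orphan-close rule implemented by stripOrphanMarkers above
MARKER_PAIRS = [
    ("<think>", "</think>"),
    ("<|python_start|>", "<|python_end|>"),
    ("<|output_start|>", "<|output_end|>"),
]

def strip_orphan_closes(s: str) -> str:
    for open_tag, close_tag in MARKER_PAIRS:
        first_open = s.find(open_tag)
        limit = first_open if first_open != -1 else len(s)
        # Drop close tags that appear before the first matching open tag;
        # unclosed opens are kept so partially streamed blocks survive.
        s = s[:limit].replace(close_tag, "") + s[limit:]
    return s

# Example: a stray "<|output_end|>" emitted before any "<|output_start|>" is removed.
assert strip_orphan_closes("<|output_end|>Hello") == "Hello"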
function ThinkBlock({ content, closed }: { content: string; closed: boolean }) {