From 4628d53d67455edbfd5ff62d748f5e71b4aa99e0 Mon Sep 17 00:00:00 2001
From: Manmohan Sharma
Date: Wed, 22 Apr 2026 15:01:07 -0700
Subject: [PATCH] fix(tools): force web_search on tool-worthy queries + strip
 orphan markers in UI
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Adds modal/_query_classifier.py with regex patterns covering time-sensitive
queries (current/present/latest/today/weather/CEO/president/stock/news/sports/etc).
modal/serve.py classifies each user message and, when it matches, pre-seeds
the assistant turn with a real Tavily-backed tool call + result — so 'whos the
present president' now triggers web_search the same as 'current president'.
Also tightens the post-injection break to fire on any leaked tool marker.

UI: MessageBubble.tsx now strips orphan close-tags (<|output_end|> without an
open), dedupes consecutive identical tool-result blocks, and removes fragment
markers from text segments so they don't leak into the message body.
---
 modal/_query_classifier.py                    | 161 ++++++++++++++++++
 modal/serve.py                                |  64 ++++++-
 .../components/chat/MessageBubble.tsx         |  75 +++++++-
 3 files changed, 291 insertions(+), 9 deletions(-)
 create mode 100644 modal/_query_classifier.py

diff --git a/modal/_query_classifier.py b/modal/_query_classifier.py
new file mode 100644
index 00000000..16caa193
--- /dev/null
+++ b/modal/_query_classifier.py
@@ -0,0 +1,161 @@
+"""Heuristic query classifier for forced tool use.
+
+This module answers a single question: given a user message, does it likely
+need a fresh web search to be answered correctly?
+
+We use this to *force* the model to call `web_search` even when its natural
+generation path would skip the tool (e.g. "whos the present president" where
+"present" doesn't appear in the tool-use SFT training distribution).
+
+Two outputs:
+    needs_web_search(text) -> (bool, rewritten_query)
+    needs_calculator(text) -> (bool, expression)
+
+Keep the rules simple and fast. False positives are cheap (we pay a Tavily
+call); false negatives are the real cost (the user gets a stale answer from
+training data).
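+
+Illustrative call pattern (the rewritten strings depend on the regex rules
+below, so treat the outputs here as approximate rather than exact):
+
+    needs, query = needs_web_search("whos the present president")
+    # needs -> True; query -> a cleaned-up rewrite along the lines of
+    # "who is the current president 2026"
+
+    needs, expr = needs_calculator("what is 18% tip of 2,450")
+    # needs -> True; expr is currently just the stripped original message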
+""" +from __future__ import annotations + +import re +from typing import Tuple + +# --------------------------------------------------------------------------- +# Web-search triggers +# --------------------------------------------------------------------------- + +# Time words that almost always indicate the user wants fresh info +_TIME_WORDS = r""" +(?: + current | currently | now | today | tonight | tomorrow | yesterday | + present | this\ (?:week|month|year|morning|afternoon|evening) | + latest | recent | recently | upcoming | right\ now | as\ of | + 20\d{2} | this\ very\ (?:moment|second|minute) +) +""" + +# Role / position words that change over time +_POSITION_WORDS = r""" +(?: + president | vice[-\s]?president | prime\ minister | chancellor | governor | + senator | congressman | congresswoman | representative | mayor | + ceo | cto | cfo | coo | chairman | chief\s+executive | + chief\s+justice | speaker | foreign\s+minister | attorney\s+general | + pope | monarch | king | queen | emperor | dictator | leader +) +""" + +# Topic categories that are inherently time-sensitive +_CATEGORY_PATTERNS = [ + # weather / climate + r"\b(?:weather|temperature|forecast|humidity|rainfall|snowfall|wind\s+speed)\b", + # finance / markets + r"\b(?:stock\s+price|share\s+price|market\s+cap|exchange\s+rate|forex|crypto(?:currency)?\s+price)\b", + r"\b(?:nasdaq|s&p\s*500|dow\s+jones|nifty|sensex|ftse|nikkei)\b", + r"\bprice\s+of\s+(?:gold|silver|oil|bitcoin|ethereum)\b", + # news / events + r"\b(?:breaking\s+news|headlines?|latest\s+news|news\s+(?:today|now))\b", + r"\bwhat(?:'|\s+i)s\s+happening\s+(?:in|with|at)\b", + # sports scores + r"\b(?:score|result|winner|loser|champion|finalist)\b.*(?:match|game|tournament|series|cup|open|championship)", + r"\b(?:ipl|world\s+cup|super\s+bowl|nba|nfl|nhl|olympics?|wimbledon|us\s+open|australian\s+open|french\s+open)\b", + # time of day / where + r"\bwhat\s+time\s+is\s+it\b", + r"\btime\s+in\s+[A-Z][a-z]+\b", + # people (often changes) — VIPs, CEOs, recent releases + r"\bwho\s+(?:is|'s|runs|leads|owns|founded|heads)\b", + # "is X still alive" / "did X die" + r"\bis\s+\w+\s+(?:still\s+)?(?:alive|dead)\b", + # recent product / release + r"\b(?:latest|newest|recent|upcoming)\s+(?:version|release|model|iphone|android|macbook|tesla|game|album|movie|film|book)\b", + # recent statistics that change + r"\b(?:population|number\s+of)\s+\w+\s+(?:of|in)\b", + # explicit search instructions + r"\b(?:search|look\s+up|google|find\s+out|check)\s+(?:online|on\s+the\s+web|on\s+google)\b", + r"\buse\s+(?:the\s+)?web[_\s]?search\b", +] + +# Standalone keyword patterns that, combined with any position/role, trigger search +_KEYWORD_GATE = re.compile( + rf""" + \b{_TIME_WORDS}\b # at least one time word + | \b(?:who|what|where|when)\b.*\b{_POSITION_WORDS}\b # who/what + position + | \b{_POSITION_WORDS}\b.*\b(?:of|in|for|at)\s+[A-Z] # position + proper noun (of America, in India) + """, + re.IGNORECASE | re.VERBOSE, +) + +_CATEGORY_REGEXES = [re.compile(p, re.IGNORECASE) for p in _CATEGORY_PATTERNS] + + +def needs_web_search(text: str) -> Tuple[bool, str]: + """Classify whether a user query likely needs a live web search. + + Returns (needs, rewritten_query). The rewritten_query strips filler and + reformulates for better Tavily results (e.g. "whos the present president" -> + "who is the current president of the United States 2026"). 
+ """ + if not text or not isinstance(text, str): + return False, "" + + stripped = text.strip() + if len(stripped) < 3: + return False, "" + + # Any category pattern hit + for rx in _CATEGORY_REGEXES: + if rx.search(stripped): + return True, _rewrite_query(stripped) + + # Keyword gate (time words or position + specifier) + if _KEYWORD_GATE.search(stripped): + return True, _rewrite_query(stripped) + + return False, "" + + +def _rewrite_query(text: str) -> str: + """Clean up the query for Tavily — expand contractions, normalize 'present'->'current', + strip filler, add a year anchor.""" + q = text.strip().rstrip("?.!") + # contractions + q = re.sub(r"\bwho(?:'| i)s\b", "who is", q, flags=re.IGNORECASE) + q = re.sub(r"\bwhat(?:'| i)s\b", "what is", q, flags=re.IGNORECASE) + q = re.sub(r"\bwhere(?:'| i)s\b", "where is", q, flags=re.IGNORECASE) + q = re.sub(r"\bwhen(?:'| i)s\b", "when is", q, flags=re.IGNORECASE) + q = re.sub(r"\bits\b", "it is", q, flags=re.IGNORECASE) + # strip quantifiers the model tends to hallucinate + q = re.sub(r"\b(please|kindly|could\s+you|can\s+you)\b\s*", "", q, flags=re.IGNORECASE) + # "present X" -> "current X" (better Tavily results) + q = re.sub(r"\bpresent\b", "current", q, flags=re.IGNORECASE) + # collapse whitespace + q = re.sub(r"\s+", " ", q).strip() + # anchor to the current year if no year is already present + if not re.search(r"\b20\d{2}\b", q): + q = q + " 2026" + return q + + +# --------------------------------------------------------------------------- +# Calculator triggers (cheap, local) +# --------------------------------------------------------------------------- + +_CALC_RX = re.compile( + r""" + \b(?:calculate|compute|what\s+is|what's)\b.*? + (?: + \d[\d,\.\s]*\s*[+\-\*/x×÷]\s*\d # basic arithmetic + | \d+\s*%\s+(?:of|tip|tax|discount)\s+\d # percentage + | \b(?:emi|cagr|compound\s+interest|tip|discount|percent(?:age)?)\b.*\d + ) + """, + re.IGNORECASE | re.VERBOSE, +) + + +def needs_calculator(text: str) -> Tuple[bool, str]: + if not text: + return False, "" + m = _CALC_RX.search(text) + if not m: + return False, "" + return True, text.strip() diff --git a/modal/serve.py b/modal/serve.py index 442da4a4..536b6a46 100644 --- a/modal/serve.py +++ b/modal/serve.py @@ -51,6 +51,7 @@ inference_image = ( .add_local_file("modal/_model.py", "/root/_model.py") .add_local_file("modal/_tokenizer.py", "/root/_tokenizer.py") .add_local_file("modal/_tools.py", "/root/_tools.py") + .add_local_file("modal/_query_classifier.py", "/root/_query_classifier.py") ) # Persistent volume for model weights @@ -198,8 +199,10 @@ class Inference: import sys as _sys if '/root' not in _sys.path: _sys.path.insert(0, '/root') from _tools import build_default_tool_registry, parse_tool_call_payload + from _query_classifier import needs_web_search self.tool_registry = build_default_tool_registry() self._parse_tool_call = parse_tool_call_payload + self._needs_web_search = needs_web_search # Marker tokens for tool state machine self.python_start_id = self.tokenizer.encode_special("<|python_start|>")[0] self.python_end_id = self.tokenizer.encode_special("<|python_end|>")[0] @@ -248,6 +251,42 @@ class Inference: # Prompt the model to generate an assistant response tokens.extend(assistant_start) + # --- Forced tool use --- + # The model's SFT training doesn't always trigger web_search even when + # a question clearly needs current info (e.g. "present president" vs + # "current president"). 
We classify the last user message and, if it + # matches tool-worthy patterns, pre-seed the assistant turn with a real + # tool call + Tavily result. The model then just writes the final + # grounded answer instead of hallucinating from stale memory. + forced_prefix_text = "" + last_user = "" + for msg in reversed(messages): + if msg.get("role") == "user": + last_user = msg.get("content", "") + break + try: + needs_search, rewritten = self._needs_web_search(last_user) + except Exception: + needs_search, rewritten = False, "" + if needs_search and rewritten: + preface = "I'll look that up for you. " + tool_call_json = json.dumps( + {"arguments": {"query": rewritten, "top_k": 1}, "tool": "web_search"}, + separators=(",", ":"), + ) + try: + invocation = self._parse_tool_call(tool_call_json) + tool_result = self.tool_registry.execute(invocation.tool_name, invocation.arguments) + result_text = tool_result.to_payload()[:4096] + except Exception as exc: + result_text = json.dumps({"error": str(exc)[:500]}) + forced_prefix_text = ( + preface + + "<|python_start|>" + tool_call_json + "<|python_end|>" + + "<|output_start|>" + result_text + "<|output_end|>\n" + ) + tokens.extend(self.tokenizer.encode(forced_prefix_text)) + # Truncate to fit context max_context = self.config.sequence_len - max_tokens if len(tokens) > max_context: @@ -266,9 +305,14 @@ class Inference: async def stream(): input_ids = torch.tensor([tokens], dtype=torch.long, device=self.device) gen_ids: list[int] = [] # everything the MODEL sampled this turn - tool_injected = False # once True, stop detecting further tool calls + tool_injected = bool(forced_prefix_text) # forced prefix counts as an injection pre_injection_len = 0 # len(gen_ids) right before we start injection + # If we pre-seeded a forced tool call + result, stream it to the client + # now so the UI can render the tool-call / tool-result cards. + if forced_prefix_text: + yield "data: " + json.dumps({"token": forced_prefix_text, "gpu": 0}) + "\n\n" + def _append_token(tid): nonlocal input_ids nt = torch.tensor([[tid]], dtype=torch.long, device=self.device) @@ -350,14 +394,22 @@ class Inference: break tool_injected = True - # After injection: if the model starts emitting another <|output_start|>, - # break the turn — the grounded result already streamed. - elif pre_injection_len > 0 and len(gen_ids) > pre_injection_len + 20: + # After injection (forced OR runtime): the model often loops and + # emits another fake <|output_start|>…<|output_end|> / <|python_start|>… + # block. Break the turn as soon as ANY tool-marker appears in what the + # MODEL itself generated. We check the decoded text of gen_ids[pre_injection_len:]. 
+ elif tool_injected and len(gen_ids) > pre_injection_len + 6: try: - post_text = self.tokenizer.decode(gen_ids[pre_injection_len + 10:]) + post_text = self.tokenizer.decode(gen_ids[pre_injection_len:]) except Exception: post_text = "" - if out_start_str in post_text: + for bad in (out_start_str, out_end_str, tool_start_str, tool_end_str): + if bad in post_text: + break_now = True + break + else: + break_now = False + if break_now: break yield "data: " + json.dumps({"done": True}) + "\n\n" diff --git a/services/frontend/components/chat/MessageBubble.tsx b/services/frontend/components/chat/MessageBubble.tsx index 0d7c1314..7d7f9a59 100644 --- a/services/frontend/components/chat/MessageBubble.tsx +++ b/services/frontend/components/chat/MessageBubble.tsx @@ -18,16 +18,19 @@ type Segment = | { kind: 'tool_result'; content: string; closed: boolean }; function parseSegments(raw: string): Segment[] { + // First pass: strip orphan tool markers (end-tag without open-tag, or any + // stray marker outside a pair) that the model sometimes emits as loop + // artifacts — otherwise they leak into the message body as raw text. + raw = stripOrphanMarkers(raw); + const segs: Segment[] = []; let i = 0; - // marker -> [openTag, closeTag, kind] const markers: Array<[string, string, Segment['kind']]> = [ ['', '', 'think'], ['<|python_start|>', '<|python_end|>', 'tool_call'], ['<|output_start|>', '<|output_end|>', 'tool_result'], ]; while (i < raw.length) { - // find the nearest opening marker let bestOpen = -1; let bestMarker: [string, string, Segment['kind']] | null = null; for (const m of markers) { @@ -50,7 +53,73 @@ function parseSegments(raw: string): Segment[] { i = closeIdx + closeTag.length; } } - return segs; + + return dedupeAndClean(segs); +} + +function stripOrphanMarkers(s: string): string { + // Walk the string left-to-right. For each opening marker we encounter, keep + // it only if its matching close exists somewhere after it. For each close + // marker encountered without a preceding open, drop it. + const pairs: Array<[string, string]> = [ + ['', ''], + ['<|python_start|>', '<|python_end|>'], + ['<|output_start|>', '<|output_end|>'], + ]; + for (const [open, close] of pairs) { + // Remove any close-tag that has no preceding open-tag + const openPositions: number[] = []; + let idx = 0; + while (true) { + const p = s.indexOf(open, idx); + if (p === -1) break; + openPositions.push(p); + idx = p + open.length; + } + const closePositions: number[] = []; + idx = 0; + while (true) { + const p = s.indexOf(close, idx); + if (p === -1) break; + closePositions.push(p); + idx = p + close.length; + } + // drop close tags that appear before any open tag + const firstOpen = openPositions[0] ?? 
+    const orphanCloses = closePositions.filter((c) => c < firstOpen);
+    if (orphanCloses.length) {
+      // remove each orphan close (work in reverse so indices stay valid)
+      for (const c of orphanCloses.reverse()) {
+        s = s.slice(0, c) + s.slice(c + close.length);
+      }
+    }
+  }
+  return s;
+}
+
+function dedupeAndClean(segs: Segment[]): Segment[] {
+  const out: Segment[] = [];
+  let lastResultKey: string | null = null;
+  for (const seg of segs) {
+    // collapse consecutive duplicate tool_result segments (model re-emits the
+    // same block as a training artifact)
+    if (seg.kind === 'tool_result') {
+      const key = seg.content.replace(/\s+/g, ' ').trim();
+      if (key === lastResultKey) continue;
+      lastResultKey = key;
+      out.push(seg);
+      continue;
+    }
+    // drop plain-text segments that are just leftover tool-marker fragments or
+    // whitespace; dropped segments keep lastResultKey alive so duplicate results
+    // separated only by whitespace still collapse
+    if (seg.kind === 'text') {
+      const cleaned = seg.content.replace(/<\|?(?:python|output)_(?:start|end)\|?>/g, '');
+      if (!cleaned.trim()) continue;
+      lastResultKey = null;
+      out.push({ kind: 'text', content: cleaned });
+      continue;
+    }
+    lastResultKey = null;
+    out.push(seg);
+  }
+  return out;
+}
+
 function ThinkBlock({ content, closed }: { content: string; closed: boolean }) {