diff --git a/modal/_query_classifier.py b/modal/_query_classifier.py
new file mode 100644
index 00000000..16caa193
--- /dev/null
+++ b/modal/_query_classifier.py
@@ -0,0 +1,161 @@
+"""Heuristic query classifier for forced tool use.
+
+This module answers a single question: given a user message, does it likely
+need a fresh web search to be answered correctly?
+
+We use this to *force* the model to call `web_search` even when its natural
+generation path would skip the tool (e.g. "whos the present president" where
+"present" doesn't appear in the tool-use SFT training distribution).
+
+Two outputs:
+ needs_web_search(text) -> (bool, rewritten_query)
    needs_calculator(text) -> (bool, stripped_original_text)
+
+Keep the rules simple and fast. False positives are cheap (we pay a Tavily
call); false negatives are the real cost (the user gets a stale training-data answer).
+"""
+from __future__ import annotations
+
+import re
+from typing import Tuple
+
# ---------------------------------------------------------------------------
# Web-search triggers
# ---------------------------------------------------------------------------

# Time words that almost always indicate the user wants fresh info.
# This fragment is interpolated into a re.VERBOSE pattern (_KEYWORD_GATE), so
# whitespace is insignificant and literal spaces inside multi-word phrases are
# backslash-escaped ("right\ now"). Note that "20\d{2}" makes ANY mention of a
# 20xx year count as a freshness signal.
_TIME_WORDS = r"""
(?:
    current | currently | now | today | tonight | tomorrow | yesterday |
    present | this\ (?:week|month|year|morning|afternoon|evening) |
    latest | recent | recently | upcoming | right\ now | as\ of |
    20\d{2} | this\ very\ (?:moment|second|minute)
)
"""

# Role / position words that change over time (elections, appointments,
# successions) — answers memorized at training time go stale.
_POSITION_WORDS = r"""
(?:
    president | vice[-\s]?president | prime\ minister | chancellor | governor |
    senator | congressman | congresswoman | representative | mayor |
    ceo | cto | cfo | coo | chairman | chief\s+executive |
    chief\s+justice | speaker | foreign\s+minister | attorney\s+general |
    pope | monarch | king | queen | emperor | dictator | leader
)
"""

# Topic categories that are inherently time-sensitive. Each entry is compiled
# with re.IGNORECASE below (_CATEGORY_REGEXES); a hit on ANY single pattern
# forces a web search.
_CATEGORY_PATTERNS = [
    # weather / climate
    r"\b(?:weather|temperature|forecast|humidity|rainfall|snowfall|wind\s+speed)\b",
    # finance / markets
    r"\b(?:stock\s+price|share\s+price|market\s+cap|exchange\s+rate|forex|crypto(?:currency)?\s+price)\b",
    r"\b(?:nasdaq|s&p\s*500|dow\s+jones|nifty|sensex|ftse|nikkei)\b",
    r"\bprice\s+of\s+(?:gold|silver|oil|bitcoin|ethereum)\b",
    # news / events
    r"\b(?:breaking\s+news|headlines?|latest\s+news|news\s+(?:today|now))\b",
    r"\bwhat(?:'|\s+i)s\s+happening\s+(?:in|with|at)\b",
    # sports scores
    r"\b(?:score|result|winner|loser|champion|finalist)\b.*(?:match|game|tournament|series|cup|open|championship)",
    r"\b(?:ipl|world\s+cup|super\s+bowl|nba|nfl|nhl|olympics?|wimbledon|us\s+open|australian\s+open|french\s+open)\b",
    # time of day / where
    r"\bwhat\s+time\s+is\s+it\b",
    # NOTE(review): "[A-Z]" below is compiled with re.IGNORECASE, so it matches
    # any letter — "time in paris" also hits, not only capitalized city names.
    # Confirm whether the looser behavior is intended (the module's stated
    # philosophy prefers false positives).
    r"\btime\s+in\s+[A-Z][a-z]+\b",
    # people (often changes) — VIPs, CEOs, recent releases
    r"\bwho\s+(?:is|'s|runs|leads|owns|founded|heads)\b",
    # "is X still alive" / "did X die"
    r"\bis\s+\w+\s+(?:still\s+)?(?:alive|dead)\b",
    # recent product / release
    r"\b(?:latest|newest|recent|upcoming)\s+(?:version|release|model|iphone|android|macbook|tesla|game|album|movie|film|book)\b",
    # recent statistics that change
    r"\b(?:population|number\s+of)\s+\w+\s+(?:of|in)\b",
    # explicit search instructions
    r"\b(?:search|look\s+up|google|find\s+out|check)\s+(?:online|on\s+the\s+web|on\s+google)\b",
    r"\buse\s+(?:the\s+)?web[_\s]?search\b",
]

# Standalone keyword gate. Three alternates:
#   (1) any time word;
#   (2) a wh-question word followed (anywhere later) by a position word;
#   (3) a position word followed by "of/in/for/at <word>".
# NOTE(review): branch (3)'s "[A-Z]" is also neutralized by re.IGNORECASE and
# matches lowercase text too ("president of america") — confirm intended.
_KEYWORD_GATE = re.compile(
    rf"""
    \b{_TIME_WORDS}\b # at least one time word
    | \b(?:who|what|where|when)\b.*\b{_POSITION_WORDS}\b # who/what + position
    | \b{_POSITION_WORDS}\b.*\b(?:of|in|for|at)\s+[A-Z] # position + proper noun (of America, in India)
    """,
    re.IGNORECASE | re.VERBOSE,
)

# Compiled once at import time; reused on every classification call.
_CATEGORY_REGEXES = [re.compile(p, re.IGNORECASE) for p in _CATEGORY_PATTERNS]
+
+
def needs_web_search(text: str) -> Tuple[bool, str]:
    """Decide whether *text* likely needs a live web search.

    Returns a ``(needs, rewritten_query)`` pair. When a search is warranted,
    ``rewritten_query`` is the ``_rewrite_query()`` normalization of the
    input (contractions expanded, filler stripped, year-anchored); otherwise
    it is the empty string.
    """
    # Non-string or empty input can never warrant a search.
    if not isinstance(text, str) or not text:
        return False, ""

    query = text.strip()
    # Too short to carry a meaningful question.
    if len(query) < 3:
        return False, ""

    # A hit on any time-sensitive category, or on the time-word / position
    # keyword gate, forces a search.
    category_hit = any(rx.search(query) for rx in _CATEGORY_REGEXES)
    if category_hit or _KEYWORD_GATE.search(query):
        return True, _rewrite_query(query)

    return False, ""
+
+
+def _rewrite_query(text: str) -> str:
+ """Clean up the query for Tavily — expand contractions, normalize 'present'->'current',
+ strip filler, add a year anchor."""
+ q = text.strip().rstrip("?.!")
+ # contractions
+ q = re.sub(r"\bwho(?:'| i)s\b", "who is", q, flags=re.IGNORECASE)
+ q = re.sub(r"\bwhat(?:'| i)s\b", "what is", q, flags=re.IGNORECASE)
+ q = re.sub(r"\bwhere(?:'| i)s\b", "where is", q, flags=re.IGNORECASE)
+ q = re.sub(r"\bwhen(?:'| i)s\b", "when is", q, flags=re.IGNORECASE)
+ q = re.sub(r"\bits\b", "it is", q, flags=re.IGNORECASE)
+ # strip quantifiers the model tends to hallucinate
+ q = re.sub(r"\b(please|kindly|could\s+you|can\s+you)\b\s*", "", q, flags=re.IGNORECASE)
+ # "present X" -> "current X" (better Tavily results)
+ q = re.sub(r"\bpresent\b", "current", q, flags=re.IGNORECASE)
+ # collapse whitespace
+ q = re.sub(r"\s+", " ", q).strip()
+ # anchor to the current year if no year is already present
+ if not re.search(r"\b20\d{2}\b", q):
+ q = q + " 2026"
+ return q
+
+
+# ---------------------------------------------------------------------------
+# Calculator triggers (cheap, local)
+# ---------------------------------------------------------------------------
+
# Lexical trigger for the local calculator tool: an arithmetic verb phrase
# followed somewhere by digits combined with an operator / percentage idiom.
_CALC_RX = re.compile(
    r"""
    \b(?:calculate|compute|what\s+is|what's)\b.*?
    (?:
        \d[\d,\.\s]*\s*[+\-\*/x×÷]\s*\d # basic arithmetic
        | \d+\s*%\s+(?:of|tip|tax|discount)\s+\d # percentage
        | \b(?:emi|cagr|compound\s+interest|tip|discount|percent(?:age)?)\b.*\d
    )
    """,
    re.IGNORECASE | re.VERBOSE,
)


def needs_calculator(text: str) -> Tuple[bool, str]:
    """Return ``(True, stripped_text)`` when *text* reads like an arithmetic
    request, else ``(False, "")``.

    Matching is purely lexical via ``_CALC_RX``; the caller receives the
    whole stripped input — not just the matched expression — so it can be
    forwarded verbatim.
    """
    if text and _CALC_RX.search(text):
        return True, text.strip()
    return False, ""
diff --git a/modal/serve.py b/modal/serve.py
index 442da4a4..536b6a46 100644
--- a/modal/serve.py
+++ b/modal/serve.py
@@ -51,6 +51,7 @@ inference_image = (
.add_local_file("modal/_model.py", "/root/_model.py")
.add_local_file("modal/_tokenizer.py", "/root/_tokenizer.py")
.add_local_file("modal/_tools.py", "/root/_tools.py")
+ .add_local_file("modal/_query_classifier.py", "/root/_query_classifier.py")
)
# Persistent volume for model weights
@@ -198,8 +199,10 @@ class Inference:
import sys as _sys
if '/root' not in _sys.path: _sys.path.insert(0, '/root')
from _tools import build_default_tool_registry, parse_tool_call_payload
+ from _query_classifier import needs_web_search
self.tool_registry = build_default_tool_registry()
self._parse_tool_call = parse_tool_call_payload
+ self._needs_web_search = needs_web_search
# Marker tokens for tool state machine
self.python_start_id = self.tokenizer.encode_special("<|python_start|>")[0]
self.python_end_id = self.tokenizer.encode_special("<|python_end|>")[0]
@@ -248,6 +251,42 @@ class Inference:
# Prompt the model to generate an assistant response
tokens.extend(assistant_start)
+ # --- Forced tool use ---
+ # The model's SFT training doesn't always trigger web_search even when
+ # a question clearly needs current info (e.g. "present president" vs
+ # "current president"). We classify the last user message and, if it
+ # matches tool-worthy patterns, pre-seed the assistant turn with a real
+ # tool call + Tavily result. The model then just writes the final
+ # grounded answer instead of hallucinating from stale memory.
+ forced_prefix_text = ""
+ last_user = ""
+ for msg in reversed(messages):
+ if msg.get("role") == "user":
+ last_user = msg.get("content", "")
+ break
+ try:
+ needs_search, rewritten = self._needs_web_search(last_user)
+ except Exception:
+ needs_search, rewritten = False, ""
+ if needs_search and rewritten:
+ preface = "I'll look that up for you. "
+ tool_call_json = json.dumps(
+ {"arguments": {"query": rewritten, "top_k": 1}, "tool": "web_search"},
+ separators=(",", ":"),
+ )
+ try:
+ invocation = self._parse_tool_call(tool_call_json)
+ tool_result = self.tool_registry.execute(invocation.tool_name, invocation.arguments)
+ result_text = tool_result.to_payload()[:4096]
+ except Exception as exc:
+ result_text = json.dumps({"error": str(exc)[:500]})
+ forced_prefix_text = (
+ preface
+ + "<|python_start|>" + tool_call_json + "<|python_end|>"
+ + "<|output_start|>" + result_text + "<|output_end|>\n"
+ )
+ tokens.extend(self.tokenizer.encode(forced_prefix_text))
+
# Truncate to fit context
max_context = self.config.sequence_len - max_tokens
if len(tokens) > max_context:
@@ -266,9 +305,14 @@ class Inference:
async def stream():
input_ids = torch.tensor([tokens], dtype=torch.long, device=self.device)
gen_ids: list[int] = [] # everything the MODEL sampled this turn
- tool_injected = False # once True, stop detecting further tool calls
+ tool_injected = bool(forced_prefix_text) # forced prefix counts as an injection
pre_injection_len = 0 # len(gen_ids) right before we start injection
+ # If we pre-seeded a forced tool call + result, stream it to the client
+ # now so the UI can render the tool-call / tool-result cards.
+ if forced_prefix_text:
+ yield "data: " + json.dumps({"token": forced_prefix_text, "gpu": 0}) + "\n\n"
+
def _append_token(tid):
nonlocal input_ids
nt = torch.tensor([[tid]], dtype=torch.long, device=self.device)
@@ -350,14 +394,22 @@ class Inference:
break
tool_injected = True
- # After injection: if the model starts emitting another <|output_start|>,
- # break the turn — the grounded result already streamed.
- elif pre_injection_len > 0 and len(gen_ids) > pre_injection_len + 20:
+ # After injection (forced OR runtime): the model often loops and
+ # emits another fake <|output_start|>…<|output_end|> / <|python_start|>…
+ # block. Break the turn as soon as ANY tool-marker appears in what the
+ # MODEL itself generated. We check the decoded text of gen_ids[pre_injection_len:].
+ elif tool_injected and len(gen_ids) > pre_injection_len + 6:
try:
- post_text = self.tokenizer.decode(gen_ids[pre_injection_len + 10:])
+ post_text = self.tokenizer.decode(gen_ids[pre_injection_len:])
except Exception:
post_text = ""
- if out_start_str in post_text:
+ for bad in (out_start_str, out_end_str, tool_start_str, tool_end_str):
+ if bad in post_text:
+ break_now = True
+ break
+ else:
+ break_now = False
+ if break_now:
break
yield "data: " + json.dumps({"done": True}) + "\n\n"
diff --git a/services/frontend/components/chat/MessageBubble.tsx b/services/frontend/components/chat/MessageBubble.tsx
index 0d7c1314..7d7f9a59 100644
--- a/services/frontend/components/chat/MessageBubble.tsx
+++ b/services/frontend/components/chat/MessageBubble.tsx
@@ -18,16 +18,19 @@ type Segment =
| { kind: 'tool_result'; content: string; closed: boolean };
function parseSegments(raw: string): Segment[] {
+ // First pass: strip orphan tool markers (end-tag without open-tag, or any
+ // stray marker outside a pair) that the model sometimes emits as loop
+ // artifacts — otherwise they leak into the message body as raw text.
+ raw = stripOrphanMarkers(raw);
+
const segs: Segment[] = [];
let i = 0;
- // marker -> [openTag, closeTag, kind]
const markers: Array<[string, string, Segment['kind']]> = [
['', '', 'think'],
['<|python_start|>', '<|python_end|>', 'tool_call'],
['<|output_start|>', '<|output_end|>', 'tool_result'],
];
while (i < raw.length) {
- // find the nearest opening marker
let bestOpen = -1;
let bestMarker: [string, string, Segment['kind']] | null = null;
for (const m of markers) {
@@ -50,7 +53,73 @@ function parseSegments(raw: string): Segment[] {
i = closeIdx + closeTag.length;
}
}
- return segs;
+
+ return dedupeAndClean(segs);
+}
+
+function stripOrphanMarkers(s: string): string {
+ // Walk the string left-to-right. For each opening marker we encounter, keep
+ // it only if its matching close exists somewhere after it. For each close
+ // marker encountered without a preceding open, drop it.
+ const pairs: Array<[string, string]> = [
+ ['', ''],
+ ['<|python_start|>', '<|python_end|>'],
+ ['<|output_start|>', '<|output_end|>'],
+ ];
+ for (const [open, close] of pairs) {
+ // Remove any close-tag that has no preceding open-tag
+ const openPositions: number[] = [];
+ let idx = 0;
+ while (true) {
+ const p = s.indexOf(open, idx);
+ if (p === -1) break;
+ openPositions.push(p);
+ idx = p + open.length;
+ }
+ const closePositions: number[] = [];
+ idx = 0;
+ while (true) {
+ const p = s.indexOf(close, idx);
+ if (p === -1) break;
+ closePositions.push(p);
+ idx = p + close.length;
+ }
+ // drop close tags that appear before any open tag
+ const firstOpen = openPositions[0] ?? Infinity;
+ const orphanCloses = closePositions.filter((c) => c < firstOpen);
+ if (orphanCloses.length) {
+ // remove each orphan close (work in reverse so indices stay valid)
+ for (const c of orphanCloses.reverse()) {
+ s = s.slice(0, c) + s.slice(c + close.length);
+ }
+ }
+ }
+ return s;
+}
+
+function dedupeAndClean(segs: Segment[]): Segment[] {
+ const out: Segment[] = [];
+ let lastResultKey: string | null = null;
+ for (const seg of segs) {
+ // collapse consecutive duplicate tool_result segments (model re-emits the
+ // same block as a training artifact)
+ if (seg.kind === 'tool_result') {
+ const key = seg.content.replace(/\s+/g, ' ').trim();
+ if (key === lastResultKey) continue;
+ lastResultKey = key;
+ } else {
+ lastResultKey = null;
+ }
+ // drop plain-text segments that are just leftover tool-marker fragments
+ if (seg.kind === 'text') {
+ const t = seg.content.replace(/<\|?(?:python|output)_(?:start|end)\|?>/g, '').trim();
+ if (!t) continue;
+ out.push({ kind: 'text', content: seg.content.replace(/<\|?(?:python|output)_(?:start|end)\|?>/g, '') });
+ continue;
+ }
+ out.push(seg);
+ }
+ return out;
}
function ThinkBlock({ content, closed }: { content: string; closed: boolean }) {