diff --git a/modal/serve.py b/modal/serve.py
index 01554a65..00acc47f 100644
--- a/modal/serve.py
+++ b/modal/serve.py
@@ -229,6 +229,7 @@ class Inference:
         temperature = min(max(request.get("temperature", 0.8), 0.0), 2.0)
         max_tokens = min(max(request.get("max_tokens", 512), 1), 2048)
         top_k = min(max(request.get("top_k", 50), 0), 200)
+        force_web_search = bool(request.get("force_web_search", False))

         # Build token sequence from messages
         tokens = []
@@ -274,6 +275,12 @@ class Inference:
             needs_search, rewritten = self._needs_web_search(query_for_classify)
         except Exception:
             needs_search, rewritten = False, ""
+        # Explicit user toggle wins — always force when force_web_search is True
+        if force_web_search and query_for_classify:
+            needs_search = True
+            if not rewritten:
+                # if classifier didn't rewrite, do a minimal cleanup
+                rewritten = query_for_classify.strip().rstrip("?.!") + " 2026"
         if needs_search and rewritten:
             preface = "I'll look that up for you. "
             tool_call_json = json.dumps(
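The intent of the new branch is easier to see in isolation: the classifier still runs first (and may still supply a better rewritten query), but an explicit force_web_search demotes its verdict from binding to advisory, and the fallback rewrite only applies when the classifier produced nothing. A minimal standalone sketch of that decision flow; resolve_search and the injected classify callable are illustrative stand-ins, not names from serve.py:

def resolve_search(query: str, force: bool, classify) -> tuple[bool, str]:
    # The classifier runs first and is never allowed to raise past this point.
    try:
        needs_search, rewritten = classify(query)
    except Exception:
        needs_search, rewritten = False, ""
    # Explicit user toggle wins, mirroring the serve.py branch above.
    if force and query:
        needs_search = True
        if not rewritten:
            # Same fallback as the diff: trim trailing punctuation, append a recency hint.
            rewritten = query.strip().rstrip("?.!") + " 2026"
    return needs_search, rewritten

For example, resolve_search("Who won the cup?", force=True, classify=lambda q: (False, "")) returns (True, "Who won the cup 2026").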
diff --git a/services/chat-api/src/routes/messages.py b/services/chat-api/src/routes/messages.py
index 880ad63b..641f2206 100644
--- a/services/chat-api/src/routes/messages.py
+++ b/services/chat-api/src/routes/messages.py
@@ -28,6 +28,7 @@ class SendMessageRequest(BaseModel):
     max_tokens: int | None = Field(default=None, ge=1, le=4096)
     top_k: int | None = Field(default=None, ge=0, le=200)
     thinking_mode: bool = Field(default=False)
+    force_web_search: bool = Field(default=False)


 class RegenerateRequest(BaseModel):
@@ -35,6 +36,7 @@
     max_tokens: int | None = Field(default=None, ge=1, le=4096)
     top_k: int | None = Field(default=None, ge=0, le=200)
     thinking_mode: bool = Field(default=False)
+    force_web_search: bool = Field(default=False)


 # System prompts: tools are always implicitly available via the model's SFT training.
@@ -101,6 +103,7 @@ async def _stream_and_persist(
     first_message: bool,
     first_message_preview: str | None,
     settings: Settings,
+    force_web_search: bool = False,
 ) -> AsyncIterator[dict]:
     """Generator that streams inference SSE events to the client and, after
     the stream closes, persists the full assistant message in a fresh DB session.
@@ -119,6 +122,7 @@ async def _stream_and_persist(
         temperature=temperature,
         max_tokens=max_tokens,
         top_k=top_k,
+        force_web_search=force_web_search,
     ) as response:
         async for event in proxy_inference_stream(response, on_complete=_capture):
             yield event
@@ -238,6 +242,7 @@ async def send_message(
         first_message=first_message,
         first_message_preview=first_preview,
         settings=settings,
+        force_web_search=body.force_web_search,
     )
     return EventSourceResponse(generator, media_type="text/event-stream")

@@ -293,5 +298,6 @@ async def regenerate(
         first_message=False,
         first_message_preview=None,
         settings=settings,
+        force_web_search=body.force_web_search,
     )
     return EventSourceResponse(generator, media_type="text/event-stream")
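On the wire the new field is one more JSON key with a safe default, so clients that predate it are unaffected. A sketch of a request that exercises the toggle end to end, assuming the route is mounted at /conversations/{id}/messages behind the frontend's /api proxy as the ChatWindow fetch below suggests; the base URL, port, and conversation id are placeholders:

import httpx

payload = {
    "content": "what's the latest stable Postgres release?",
    "thinking_mode": False,
    "force_web_search": True,  # omit it (or send False) to let the classifier decide
}

# Stream the SSE response; timeout=None keeps the connection open while tokens arrive.
with httpx.stream(
    "POST",
    "http://localhost:8000/conversations/123/messages",
    json=payload,
    timeout=None,
) as response:
    for line in response.iter_lines():
        if line:
            print(line)  # raw SSE frames: "event: ...", "data: {...}"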
diff --git a/services/chat-api/src/services/inference_client.py b/services/chat-api/src/services/inference_client.py
index 911afa2a..bdaa129c 100644
--- a/services/chat-api/src/services/inference_client.py
+++ b/services/chat-api/src/services/inference_client.py
@@ -63,6 +63,7 @@ class InferenceClient:
         temperature: float | None = None,
         max_tokens: int | None = None,
         top_k: int | None = None,
+        force_web_search: bool = False,
     ) -> AsyncIterator[httpx.Response]:
         temperature = (
             temperature
@@ -81,6 +82,7 @@
             "temperature": temperature,
             "max_tokens": max_tokens,
             "top_k": top_k,
+            "force_web_search": force_web_search,
         }

         client = self._get_client()
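Every key in that payload is read back defensively on the Modal side with request.get and a default (see the serve.py hunk above), so version skew between the two services degrades gracefully: an older chat-api that never sends force_web_search simply leaves the toggle off. A tiny sketch of that defensive read, where request stands in for the decoded JSON body:

# An older client's payload: no force_web_search key at all.
request = {"temperature": 0.7, "max_tokens": 512, "top_k": 50}

# Mirrors serve.py: a missing key degrades to False instead of raising KeyError.
force_web_search = bool(request.get("force_web_search", False))
assert force_web_search is False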
diff --git a/services/frontend/components/chat/ChatInput.tsx b/services/frontend/components/chat/ChatInput.tsx
index c389e4bd..90639708 100644
--- a/services/frontend/components/chat/ChatInput.tsx
+++ b/services/frontend/components/chat/ChatInput.tsx
@@ -1,7 +1,7 @@
 'use client';

 import { useEffect, useRef } from 'react';
-import { ArrowUp, Brain, Square } from 'lucide-react';
+import { ArrowUp, Brain, Globe, Square } from 'lucide-react';
 import clsx from 'clsx';

 interface Props {
@@ -13,9 +13,11 @@
   disabled?: boolean;
   thinkingMode?: boolean;
   onToggleThinking?: () => void;
+  webSearchMode?: boolean;
+  onToggleWebSearch?: () => void;
 }

-export default function ChatInput({ value, onChange, onSubmit, onStop, isStreaming, disabled, thinkingMode, onToggleThinking }: Props) {
+export default function ChatInput({ value, onChange, onSubmit, onStop, isStreaming, disabled, thinkingMode, onToggleThinking, webSearchMode, onToggleWebSearch }: Props) {
   const ref = useRef<HTMLTextAreaElement>(null);

   useEffect(() => {
@@ -81,6 +83,27 @@ export default function ChatInput({ value, onChange, onSubmit, onStop, isStreami
           </button>
         )}

+        {/* Force web-search toggle */}
+        {onToggleWebSearch && (
+          <button
+            type="button"
+            onClick={onToggleWebSearch}
+            disabled={disabled}
+            title={webSearchMode ? 'Web search: forced on' : 'Web search: auto'}
+            aria-pressed={webSearchMode}
+            className={clsx(
+              'flex items-center gap-1 rounded-full px-2 py-1 text-xs transition-colors',
+              webSearchMode
+                ? 'bg-sky-500/15 text-sky-400'
+                : 'text-zinc-400 hover:text-zinc-200',
+              disabled && 'opacity-50',
+            )}
+          >
+            <Globe size={14} />
+            <span>Search</span>
+          </button>
+        )}
+
         {/* Send / stop button — vertically centered with the textarea baseline */}
         <div className="self-center">
           {isStreaming && onStop ? (
diff --git a/services/frontend/components/chat/ChatWindow.tsx b/services/frontend/components/chat/ChatWindow.tsx
index 757ae260..7003cf1a 100644
--- a/services/frontend/components/chat/ChatWindow.tsx
+++ b/services/frontend/components/chat/ChatWindow.tsx
@@ -32,6 +32,7 @@ export default function ChatWindow() {
   const [draft, setDraft] = useState('');
   const [streamingMsgId, setStreamingMsgId] = useState<string | null>(null);
   const [thinkingMode, setThinkingMode] = useState(false);
+  const [webSearchMode, setWebSearchMode] = useState(false);

   const streamingBufferRef = useRef('');
   const scrollRef = useRef<HTMLDivElement>(null);
@@ -61,7 +62,7 @@ export default function ChatWindow() {
     setIsStreaming(false);
   }, []);

-  const streamFromApi = useCallback(async (convId: string, assistantMsgId: string, content: string, temp?: number, topk?: number, thinking?: boolean) => {
+  const streamFromApi = useCallback(async (convId: string, assistantMsgId: string, content: string, temp?: number, topk?: number, thinking?: boolean, forceSearch?: boolean) => {
     stop();
     const ac = new AbortController();
     abortRef.current = ac;
@@ -77,7 +78,14 @@ export default function ChatWindow() {
     const res = await fetch(`/api/conversations/${convId}/messages`, {
       method: 'POST',
       headers,
-      body: JSON.stringify({ content, temperature: temp, max_tokens: 512, top_k: topk, thinking_mode: !!thinking }),
+      body: JSON.stringify({
+        content,
+        temperature: temp,
+        max_tokens: 512,
+        top_k: topk,
+        thinking_mode: !!thinking,
+        force_web_search: !!forceSearch,
+      }),
       signal: ac.signal,
     });

@@ -173,7 +181,7 @@ export default function ChatWindow() {
     setStreamingMsgId(assistantId);
     streamingBufferRef.current = '';

-    await streamFromApi(convId, assistantId, text, temperature, topK, thinkingMode);
+    await streamFromApi(convId, assistantId, text, temperature, topK, thinkingMode, webSearchMode);
   }, [
     draft,
     isStreaming,
@@ -182,6 +190,7 @@
     temperature,
     topK,
     thinkingMode,
+    webSearchMode,
     appendMessage,
     streamFromApi,
     setTemperature,
@@ -241,6 +250,8 @@ export default function ChatWindow() {
       isStreaming={isStreaming}
       thinkingMode={thinkingMode}
       onToggleThinking={() => setThinkingMode((v) => !v)}
+      webSearchMode={webSearchMode}
+      onToggleWebSearch={() => setWebSearchMode((v) => !v)}
     />
   );
 }
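Both request models gain the same default-False field, so omitting it keeps today's classifier-driven behavior for every existing caller. A quick contract check, pytest-style; the import path is hypothetical, content is assumed to be the only required field on SendMessageRequest, RegenerateRequest is assumed to have no required fields, and model_dump assumes Pydantic v2:

from chat_api.routes.messages import RegenerateRequest, SendMessageRequest

def test_force_web_search_defaults_off():
    # Omitting the field preserves existing behavior for old clients.
    assert SendMessageRequest(content="hi").force_web_search is False
    assert RegenerateRequest().force_web_search is False

def test_force_web_search_round_trips():
    # The explicit toggle survives serialization toward the inference payload.
    body = SendMessageRequest(content="hi", force_web_search=True)
    assert body.model_dump()["force_web_search"] is True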