Merge pull request #54 from manmohan659/feat/force-web-search-toggle

feat(ui): Search toggle — force web_search on every message
This commit is contained in:
Manmohan 2026-04-22 18:20:51 -04:00 committed by GitHub
commit 31823b632a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 54 additions and 5 deletions

View File

@@ -229,6 +229,7 @@ class Inference:
temperature = min(max(request.get("temperature", 0.8), 0.0), 2.0)
max_tokens = min(max(request.get("max_tokens", 512), 1), 2048)
top_k = min(max(request.get("top_k", 50), 0), 200)
force_web_search = bool(request.get("force_web_search", False))
# Build token sequence from messages
tokens = []
@@ -274,6 +275,12 @@ class Inference:
needs_search, rewritten = self._needs_web_search(query_for_classify)
except Exception:
needs_search, rewritten = False, ""
# Explicit user toggle wins — always force when force_web_search is True
if force_web_search and query_for_classify:
needs_search = True
if not rewritten:
# if classifier didn't rewrite, do a minimal cleanup
rewritten = query_for_classify.strip().rstrip("?.!") + " 2026"
if needs_search and rewritten:
preface = "I'll look that up for you. "
tool_call_json = json.dumps(

View File

@@ -28,6 +28,7 @@ class SendMessageRequest(BaseModel):
max_tokens: int | None = Field(default=None, ge=1, le=4096)
top_k: int | None = Field(default=None, ge=0, le=200)
thinking_mode: bool = Field(default=False)
force_web_search: bool = Field(default=False)
class RegenerateRequest(BaseModel):
@@ -35,6 +36,7 @@ class RegenerateRequest(BaseModel):
max_tokens: int | None = Field(default=None, ge=1, le=4096)
top_k: int | None = Field(default=None, ge=0, le=200)
thinking_mode: bool = Field(default=False)
force_web_search: bool = Field(default=False)
# System prompts: tools are always implicitly available via the model's SFT training.
@@ -101,6 +103,7 @@ async def _stream_and_persist(
first_message: bool,
first_message_preview: str | None,
settings: Settings,
force_web_search: bool = False,
) -> AsyncIterator[dict]:
"""Generator that streams inference SSE events to the client and, after the
stream closes, persists the full assistant message in a fresh DB session.
@@ -119,6 +122,7 @@ async def _stream_and_persist(
temperature=temperature,
max_tokens=max_tokens,
top_k=top_k,
force_web_search=force_web_search,
) as response:
async for event in proxy_inference_stream(response, on_complete=_capture):
yield event
@@ -238,6 +242,7 @@ async def send_message(
first_message=first_message,
first_message_preview=first_preview,
settings=settings,
force_web_search=body.force_web_search,
)
return EventSourceResponse(generator, media_type="text/event-stream")
@@ -293,5 +298,6 @@ async def regenerate(
first_message=False,
first_message_preview=None,
settings=settings,
force_web_search=body.force_web_search,
)
return EventSourceResponse(generator, media_type="text/event-stream")

View File

@@ -63,6 +63,7 @@ class InferenceClient:
temperature: float | None = None,
max_tokens: int | None = None,
top_k: int | None = None,
force_web_search: bool = False,
) -> AsyncIterator[httpx.Response]:
temperature = (
temperature
@@ -81,6 +82,7 @@ class InferenceClient:
"temperature": temperature,
"max_tokens": max_tokens,
"top_k": top_k,
"force_web_search": force_web_search,
}
client = self._get_client()

View File

@@ -1,7 +1,7 @@
'use client';
import { useEffect, useRef } from 'react';
import { ArrowUp, Brain, Square } from 'lucide-react';
import { ArrowUp, Brain, Globe, Square } from 'lucide-react';
import clsx from 'clsx';
interface Props {
@@ -13,9 +13,11 @@ interface Props {
disabled?: boolean;
thinkingMode?: boolean;
onToggleThinking?: () => void;
webSearchMode?: boolean;
onToggleWebSearch?: () => void;
}
export default function ChatInput({ value, onChange, onSubmit, onStop, isStreaming, disabled, thinkingMode, onToggleThinking }: Props) {
export default function ChatInput({ value, onChange, onSubmit, onStop, isStreaming, disabled, thinkingMode, onToggleThinking, webSearchMode, onToggleWebSearch }: Props) {
const ref = useRef<HTMLTextAreaElement>(null);
useEffect(() => {
@@ -81,6 +83,27 @@ export default function ChatInput({ value, onChange, onSubmit, onStop, isStreami
</div>
)}
{/* Force web-search toggle */}
{onToggleWebSearch && (
<div className="self-end p-2">
<button
type="button"
onClick={onToggleWebSearch}
aria-pressed={!!webSearchMode}
title={webSearchMode ? 'Web search ON — every message will be searched online' : 'Force a web search for your next message'}
className={clsx(
'h-10 px-3 rounded-full flex items-center gap-1.5 text-xs font-medium transition-all border',
webSearchMode
? 'bg-emerald-500/15 dark:bg-emerald-500/20 border-emerald-500/50 text-emerald-600 dark:text-emerald-400 shadow-[0_4px_14px_rgba(16,185,129,0.15)]'
: 'bg-transparent border-cream-border dark:border-ink-border text-gray-500 dark:text-ink-text-soft hover:bg-gray-50 dark:hover:bg-ink-elev',
)}
>
<Globe size={14} />
<span>Search</span>
</button>
</div>
)}
{/* Send / stop button — vertically centered with the textarea baseline */}
<div className="self-end p-2">
{isStreaming && onStop ? (

View File

@@ -32,6 +32,7 @@ export default function ChatWindow() {
const [draft, setDraft] = useState('');
const [streamingMsgId, setStreamingMsgId] = useState<string | null>(null);
const [thinkingMode, setThinkingMode] = useState(false);
const [webSearchMode, setWebSearchMode] = useState(false);
const streamingBufferRef = useRef('');
const scrollRef = useRef<HTMLDivElement>(null);
@@ -61,7 +62,7 @@ export default function ChatWindow() {
setIsStreaming(false);
}, []);
const streamFromApi = useCallback(async (convId: string, assistantMsgId: string, content: string, temp?: number, topk?: number, thinking?: boolean) => {
const streamFromApi = useCallback(async (convId: string, assistantMsgId: string, content: string, temp?: number, topk?: number, thinking?: boolean, forceSearch?: boolean) => {
stop();
const ac = new AbortController();
abortRef.current = ac;
@@ -77,7 +78,14 @@ export default function ChatWindow() {
const res = await fetch(`/api/conversations/${convId}/messages`, {
method: 'POST',
headers,
body: JSON.stringify({ content, temperature: temp, max_tokens: 512, top_k: topk, thinking_mode: !!thinking }),
body: JSON.stringify({
content,
temperature: temp,
max_tokens: 512,
top_k: topk,
thinking_mode: !!thinking,
force_web_search: !!forceSearch,
}),
signal: ac.signal,
});
@@ -173,7 +181,7 @@ export default function ChatWindow() {
setStreamingMsgId(assistantId);
streamingBufferRef.current = '';
await streamFromApi(convId, assistantId, text, temperature, topK, thinkingMode);
await streamFromApi(convId, assistantId, text, temperature, topK, thinkingMode, webSearchMode);
},
[
draft,
@@ -182,6 +190,7 @@ export default function ChatWindow() {
temperature,
topK,
thinkingMode,
webSearchMode,
appendMessage,
streamFromApi,
setTemperature,
@@ -241,6 +250,8 @@ export default function ChatWindow() {
isStreaming={isStreaming}
thinkingMode={thinkingMode}
onToggleThinking={() => setThinkingMode((v) => !v)}
webSearchMode={webSearchMode}
onToggleWebSearch={() => setWebSearchMode((v) => !v)}
/>
</section>
);