nanochat/services/chat-api/src/routes/messages.py
Manmohan 3ab89e7890
feat: deploy d24-sft-r6 with full reasoning mode + live tool use (Tavily)
Model R6 (97% pass rate on 33-probe eval, val_bpb 0.2635):
- modal/serve.py + modal/_tools.py: tool-aware streaming with
  TavilySearchBackend auto-detect, python_start/end state machine,
  output_start/end forcing; mount tavily secret
- modal/serve.py: MODEL_TAG=d24-sft-r6, model path points at new SFT r6
- services/chat-api/routes/messages.py: accept thinking_mode flag,
  inject samosaChaat system prompt (direct or <think> variant) into
  first user message before streaming to Modal
- services/frontend/components/chat/ChatInput.tsx: Brain toggle
  'Think' button next to send; when active, model uses think mode
- services/frontend/components/chat/ChatWindow.tsx: track
  thinkingMode state, pass through to API body as thinking_mode
- services/frontend/components/chat/MessageBubble.tsx: parse and
  render <think>...</think> as collapsible italic blocks;
  <|python_start|>...<|python_end|> as tool-call cards with icons
  per tool name; <|output_start|>...<|output_end|> as result cards
  with expandable JSON
- nanochat/tools.py: TavilySearchBackend class + auto-detect
- nanochat/ui.html: legacy UI reasoning toggle (kept for parity)

Tool execution verified live: query -> web_search via Tavily ->
Macron returned with grounded answer.
2026-04-22 13:43:43 -07:00

298 lines
11 KiB
Python

"""The main chat route: send a message and stream an assistant response."""
from __future__ import annotations
import uuid
from typing import Annotated, AsyncIterator
from fastapi import APIRouter, Depends, HTTPException, Request, status
from pydantic import BaseModel, Field
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from sse_starlette.sse import EventSourceResponse
from ..config import Settings, get_settings
from ..database import get_session_factory
from ..logging_setup import get_logger
from ..middleware.auth_guard import AuthenticatedUser, require_user
from ..services import conversation_service
from ..services.inference_client import InferenceClient
from ..services.stream_proxy import StreamResult, proxy_inference_stream
logger = get_logger(__name__)
router = APIRouter(prefix="/api/conversations", tags=["messages"])
class SendMessageRequest(BaseModel):
content: str = Field(..., min_length=1, max_length=8000)
temperature: float | None = Field(default=None, ge=0.0, le=2.0)
max_tokens: int | None = Field(default=None, ge=1, le=4096)
top_k: int | None = Field(default=None, ge=0, le=200)
thinking_mode: bool = Field(default=False)
class RegenerateRequest(BaseModel):
temperature: float | None = Field(default=None, ge=0.0, le=2.0)
max_tokens: int | None = Field(default=None, ge=1, le=4096)
top_k: int | None = Field(default=None, ge=0, le=200)
thinking_mode: bool = Field(default=False)
# System prompts: tools are always implicitly available via the model's SFT training.
# The toggle only affects whether the model is nudged into <think>...</think> mode.
_SYS_DEFAULT = (
"You are samosaChaat, a helpful AI assistant created by Manmohan Sharma. "
"You have access to tools: use web_search for facts that may change over time or "
"require current information, and use calculator for arithmetic. Otherwise answer directly and concisely."
)
_SYS_THINK = (
"You are samosaChaat, a helpful AI assistant created by Manmohan Sharma. "
"You have access to tools: use web_search for facts that may change over time or "
"require current information, and use calculator for arithmetic. "
"Think step by step inside <think>...</think> tags, then give your final answer after the closing tag."
)
def _inject_system_prompt(history: list[dict[str, str]], thinking_mode: bool) -> list[dict[str, str]]:
"""Merge a system prompt into the first user message. Upstream Modal serve
ignores role='system', so we prepend the system prompt inline to the first
user turn — mirroring nanochat's tokenizer convention."""
if not history:
return history
sys_prompt = _SYS_THINK if thinking_mode else _SYS_DEFAULT
out = [dict(m) for m in history]
for m in out:
if m.get("role") == "user":
m["content"] = sys_prompt + "\n\n" + m.get("content", "")
break
return out
def _parse_uuid(raw: str) -> uuid.UUID:
try:
return uuid.UUID(raw)
except (ValueError, TypeError) as exc:
raise HTTPException(status_code=404, detail="conversation not found") from exc
async def _ensure_ownership(
session: AsyncSession,
*,
user_id: uuid.UUID,
conversation_id: uuid.UUID,
):
convo = await conversation_service.get_user_conversation(
session, user_id=user_id, conversation_id=conversation_id
)
if convo is None:
raise HTTPException(status_code=404, detail="conversation not found")
return convo
async def _stream_and_persist(
*,
request: Request,
user_id: uuid.UUID,
conversation_id: uuid.UUID,
history: list[dict[str, str]],
temperature: float | None,
max_tokens: int | None,
top_k: int | None,
model_tag: str,
first_message: bool,
first_message_preview: str | None,
settings: Settings,
) -> AsyncIterator[dict]:
"""Generator that streams inference SSE events to the client and, after the
stream closes, persists the full assistant message in a fresh DB session.
"""
http_client = getattr(request.app.state, "inference_http_client", None)
inference = InferenceClient(settings=settings, http_client=http_client)
accumulated: dict[str, StreamResult] = {}
def _capture(result: StreamResult) -> None:
accumulated["result"] = result
try:
async with inference.stream_generate(
messages=history,
temperature=temperature,
max_tokens=max_tokens,
top_k=top_k,
) as response:
async for event in proxy_inference_stream(response, on_complete=_capture):
yield event
except Exception as exc: # pragma: no cover - defensive path
logger.error(
"inference_stream_failed",
conversation_id=str(conversation_id),
error=str(exc),
)
yield {"data": '{"error":"inference_stream_failed"}'}
yield {"data": '{"done":true}'}
return
result = accumulated.get("result")
if result is None or not result.completed or not result.content:
logger.warning(
"assistant_message_not_persisted",
conversation_id=str(conversation_id),
completed=bool(result and result.completed),
content_len=len(result.content) if result else 0,
)
return
factory: async_sessionmaker[AsyncSession] = get_session_factory()
try:
async with factory() as persist_session:
await conversation_service.append_message(
persist_session,
conversation_id=conversation_id,
role="assistant",
content=result.content,
token_count=result.token_count,
model_tag=model_tag,
inference_time_ms=result.inference_time_ms,
)
if first_message and first_message_preview is not None:
await conversation_service.update_conversation_title(
persist_session,
user_id=user_id,
conversation_id=conversation_id,
title=first_message_preview,
)
logger.info(
"assistant_message_persisted",
conversation_id=str(conversation_id),
token_count=result.token_count,
inference_time_ms=result.inference_time_ms,
)
except Exception as exc: # pragma: no cover - defensive path
logger.error(
"assistant_message_persist_failed",
conversation_id=str(conversation_id),
error=str(exc),
)
@router.post("/{conversation_id}/messages")
async def send_message(
conversation_id: str,
body: SendMessageRequest,
request: Request,
user: Annotated[AuthenticatedUser, Depends(require_user)],
settings: Annotated[Settings, Depends(get_settings)] = None, # type: ignore[assignment]
):
# We own our DB sessions explicitly here so the background task that
# persists the streamed assistant response can open its own session once
# the request scope has already closed.
conv_uuid = _parse_uuid(conversation_id)
user_uuid = uuid.UUID(user.id)
session_factory = get_session_factory()
async with session_factory() as db_session:
convo = await _ensure_ownership(
db_session, user_id=user_uuid, conversation_id=conv_uuid
)
model_tag = convo.model_tag or "default"
existing = await conversation_service.get_conversation_messages(
db_session, conversation_id=conv_uuid, limit=1
)
first_message = len(existing) == 0
first_preview = body.content[:80] if first_message else None
await conversation_service.append_message(
db_session,
conversation_id=conv_uuid,
role="user",
content=body.content,
token_count=None,
model_tag=model_tag,
)
history = await conversation_service.build_history_for_inference(
db_session, conversation_id=conv_uuid
)
# Inject system prompt (direct or think mode) into the first user message,
# since upstream Modal serve ignores role='system'.
history_with_sys = _inject_system_prompt(history, body.thinking_mode)
logger.info(
"send_message",
conversation_id=str(conv_uuid),
history_len=len(history),
model_tag=model_tag,
thinking_mode=body.thinking_mode,
)
generator = _stream_and_persist(
request=request,
user_id=user_uuid,
conversation_id=conv_uuid,
history=history_with_sys,
temperature=body.temperature,
max_tokens=body.max_tokens,
top_k=body.top_k,
model_tag=model_tag,
first_message=first_message,
first_message_preview=first_preview,
settings=settings,
)
return EventSourceResponse(generator, media_type="text/event-stream")
@router.post("/{conversation_id}/regenerate")
async def regenerate(
conversation_id: str,
body: RegenerateRequest,
request: Request,
user: Annotated[AuthenticatedUser, Depends(require_user)],
settings: Annotated[Settings, Depends(get_settings)] = None, # type: ignore[assignment]
):
conv_uuid = _parse_uuid(conversation_id)
user_uuid = uuid.UUID(user.id)
session_factory = get_session_factory()
async with session_factory() as db_session:
convo = await _ensure_ownership(
db_session, user_id=user_uuid, conversation_id=conv_uuid
)
model_tag = convo.model_tag or "default"
await conversation_service.delete_last_assistant_message(
db_session, conversation_id=conv_uuid
)
history = await conversation_service.build_history_for_inference(
db_session, conversation_id=conv_uuid
)
if not history:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="conversation has no user messages to regenerate from",
)
history_with_sys = _inject_system_prompt(history, body.thinking_mode)
logger.info(
"regenerate_message",
conversation_id=str(conv_uuid),
history_len=len(history),
thinking_mode=body.thinking_mode,
)
generator = _stream_and_persist(
request=request,
user_id=user_uuid,
conversation_id=conv_uuid,
history=history_with_sys,
temperature=body.temperature,
max_tokens=body.max_tokens,
top_k=body.top_k,
model_tag=model_tag,
first_message=False,
first_message_preview=None,
settings=settings,
)
return EventSourceResponse(generator, media_type="text/event-stream")