feat: deploy d24-sft-r6 with full reasoning mode + live tool use (Tavily)

Model R6 (97% pass rate on 33-probe eval, val_bpb 0.2635):
- modal/serve.py + modal/_tools.py: tool-aware streaming with
  TavilySearchBackend auto-detect, python_start/end state machine,
  output_start/end forcing; mount tavily secret
- modal/serve.py: MODEL_TAG=d24-sft-r6, model path points at new SFT r6
- services/chat-api/routes/messages.py: accept thinking_mode flag,
  inject samosaChaat system prompt (direct or <think> variant) into
  first user message before streaming to Modal
- services/frontend/components/chat/ChatInput.tsx: Brain toggle
  'Think' button next to send; when active, model uses think mode
- services/frontend/components/chat/ChatWindow.tsx: track
  thinkingMode state, pass through to API body as thinking_mode
- services/frontend/components/chat/MessageBubble.tsx: parse and
  render <think>...</think> as collapsible italic blocks;
  <|python_start|>...<|python_end|> as tool-call cards with icons
  per tool name; <|output_start|>...<|output_end|> as result cards
  with expandable JSON
- nanochat/tools.py: TavilySearchBackend class + auto-detect
- nanochat/ui.html: legacy UI reasoning toggle (kept for parity)

Tool execution verified live: query -> web_search via Tavily -> grounded
answer naming Macron.
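
For reference, a tool-using assistant turn on the wire looks roughly like this
(illustrative stream; the payloads are the compact JSON produced by
serialize_tool_call / serialize_tool_result in modal/_tools.py):

    <|python_start|>{"arguments":{"query":"current president of France"},"tool":"web_search"}<|python_end|>
    <|output_start|>{"error":null,"metadata":{...},"output":{"query":"current president of France","results":[...]},"success":true,"tool":"web_search"}<|output_end|>
    Emmanuel Macron is the current president of France.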
Manmohan authored 2026-04-22 20:38:21 +00:00, committed by Manmohan Sharma
parent 67f568a4f2 · commit 3ab89e7890
8 changed files with 972 additions and 56 deletions

modal/_tools.py (new file, 554 lines)

@@ -0,0 +1,554 @@
"""
Shared tool definitions for nanochat.
The current tokenizer only has python/output special tokens. To preserve
checkpoint compatibility, we reuse those tokens as a generic tool-call and
tool-result channel. Legacy "python" calculator payloads still work.
"""
from __future__ import annotations
import ast
import json
import math
import os
import time
from dataclasses import dataclass, field
from typing import Any, Protocol
import requests
TOOL_CALL_START = "<|python_start|>"
TOOL_CALL_END = "<|python_end|>"
TOOL_RESULT_START = "<|output_start|>"
TOOL_RESULT_END = "<|output_end|>"
MAX_TOOL_PAYLOAD_CHARS = 4096
DEFAULT_TOOL_SCHEMA = [
{
"name": "calculator",
"description": "Deterministic scientific calculator for exact arithmetic and common finance formulas.",
"arguments": {
"expression": "String expression using numbers, operators, and supported functions.",
},
},
{
"name": "web_search",
"description": "Search and fetch web content. Requires a search backend and optionally a page fetch client.",
"arguments": {
"query": "Search query string.",
"top_k": "Maximum number of results to return.",
"urls": "Optional explicit URLs to fetch instead of searching.",
},
},
]
def _compact_json(data: Any) -> str:
return json.dumps(data, ensure_ascii=True, separators=(",", ":"), sort_keys=True)
@dataclass
class ToolInvocation:
tool_name: str
arguments: dict[str, Any]
raw_text: str = ""
@dataclass
class ToolResult:
tool_name: str
success: bool
output: Any = None
error: str | None = None
metadata: dict[str, Any] = field(default_factory=dict)
def to_payload(self) -> str:
return _compact_json(
{
"tool": self.tool_name,
"success": self.success,
"output": self.output,
"error": self.error,
"metadata": self.metadata,
}
)
class BaseTool:
name: str
def run(self, arguments: dict[str, Any]) -> ToolResult:
raise NotImplementedError
class ToolRegistry:
def __init__(self, tools: list[BaseTool] | tuple[BaseTool, ...]):
self._tools = {tool.name: tool for tool in tools}
def execute(self, tool_name: str, arguments: dict[str, Any]) -> ToolResult:
tool = self._tools.get(tool_name)
if tool is None:
return ToolResult(tool_name=tool_name, success=False, error=f"Unknown tool: {tool_name}")
try:
return tool.run(arguments)
except Exception as exc: # defensive: tool failures should become model-visible outputs
return ToolResult(tool_name=tool_name, success=False, error=str(exc))
def schema(self) -> list[dict[str, Any]]:
return [item for item in DEFAULT_TOOL_SCHEMA if item["name"] in self._tools]
def serialize_tool_call(tool_name: str, arguments: dict[str, Any] | None = None) -> str:
payload = {
"tool": tool_name,
"arguments": arguments or {},
}
text = _compact_json(payload)
return text[:MAX_TOOL_PAYLOAD_CHARS]
def serialize_tool_result(
tool_name: str,
output: Any = None,
*,
success: bool = True,
error: str | None = None,
metadata: dict[str, Any] | None = None,
) -> str:
return ToolResult(
tool_name=tool_name,
success=success,
output=output,
error=error,
metadata=metadata or {},
).to_payload()[:MAX_TOOL_PAYLOAD_CHARS]
def parse_tool_call_payload(text: str) -> ToolInvocation:
stripped = text.strip()
if not stripped:
return ToolInvocation(tool_name="calculator", arguments={"expression": ""}, raw_text=text)
try:
payload = json.loads(stripped)
except json.JSONDecodeError:
return ToolInvocation(tool_name="calculator", arguments={"expression": stripped}, raw_text=text)
if isinstance(payload, dict):
tool_name = payload.get("tool") or payload.get("tool_name") or payload.get("name")
arguments = payload.get("arguments") or payload.get("args") or {}
if isinstance(tool_name, str) and isinstance(arguments, dict):
return ToolInvocation(tool_name=tool_name, arguments=arguments, raw_text=text)
return ToolInvocation(tool_name="calculator", arguments={"expression": stripped}, raw_text=text)
def parse_tool_result_payload(text: str) -> ToolResult | None:
stripped = text.strip()
try:
payload = json.loads(stripped)
except json.JSONDecodeError:
return None
if not isinstance(payload, dict):
return None
tool_name = payload.get("tool")
if not isinstance(tool_name, str):
return None
return ToolResult(
tool_name=tool_name,
success=bool(payload.get("success", True)),
output=payload.get("output"),
error=payload.get("error"),
metadata=payload.get("metadata") or {},
)
def _percent(value: float, rate: float) -> float:
return value * rate / 100.0
def _percent_change(old: float, new: float) -> float:
if old == 0:
raise ValueError("percent_change old value cannot be zero")
return ((new - old) / old) * 100.0
def _cagr(start: float, end: float, years: float) -> float:
if start <= 0 or end <= 0 or years <= 0:
raise ValueError("cagr inputs must be positive")
return ((end / start) ** (1.0 / years) - 1.0) * 100.0
def _simple_interest(principal: float, annual_rate: float, years: float) -> float:
return principal * annual_rate / 100.0 * years
def _compound_interest(principal: float, annual_rate: float, periods_per_year: float, years: float) -> float:
if periods_per_year <= 0:
raise ValueError("periods_per_year must be positive")
return principal * (1.0 + annual_rate / 100.0 / periods_per_year) ** (periods_per_year * years)
def _emi(principal: float, annual_rate: float, months: float) -> float:
if months <= 0:
raise ValueError("months must be positive")
monthly_rate = annual_rate / 100.0 / 12.0
if monthly_rate == 0:
return principal / months
growth = (1.0 + monthly_rate) ** months
return principal * monthly_rate * growth / (growth - 1.0)
ALLOWED_BINOPS = {
ast.Add: lambda a, b: a + b,
ast.Sub: lambda a, b: a - b,
ast.Mult: lambda a, b: a * b,
ast.Div: lambda a, b: a / b,
ast.Pow: lambda a, b: a ** b,
ast.Mod: lambda a, b: a % b,
}
ALLOWED_UNARYOPS = {
ast.UAdd: lambda a: a,
ast.USub: lambda a: -a,
}
ALLOWED_NAMES = {
"pi": math.pi,
"e": math.e,
"tau": math.tau,
}
ALLOWED_FUNCTIONS = {
"abs": abs,
"round": round,
"floor": math.floor,
"ceil": math.ceil,
"sqrt": math.sqrt,
"log": math.log,
"log10": math.log10,
"exp": math.exp,
"sin": math.sin,
"cos": math.cos,
"tan": math.tan,
"asin": math.asin,
"acos": math.acos,
"atan": math.atan,
"degrees": math.degrees,
"radians": math.radians,
"percent": _percent,
"percent_change": _percent_change,
"cagr": _cagr,
"simple_interest": _simple_interest,
"compound_interest": _compound_interest,
"emi": _emi,
}
class _SafeMathEvaluator:
def __init__(self, expression: str):
self.expression = expression
self.node_count = 0
def eval(self) -> float:
if len(self.expression) > 512:
raise ValueError("expression too long")
parsed = ast.parse(self.expression, mode="eval")
return self._visit(parsed.body)
def _visit(self, node: ast.AST) -> Any:
self.node_count += 1
if self.node_count > 128:
raise ValueError("expression too complex")
if isinstance(node, ast.Constant):
if isinstance(node.value, (int, float)):
return node.value
raise ValueError(f"unsupported constant: {node.value!r}")
if isinstance(node, ast.Num): # pragma: no cover - py<3.8 compatibility
return node.n
if isinstance(node, ast.BinOp):
op = ALLOWED_BINOPS.get(type(node.op))
if op is None:
raise ValueError(f"unsupported operator: {type(node.op).__name__}")
return op(self._visit(node.left), self._visit(node.right))
if isinstance(node, ast.UnaryOp):
op = ALLOWED_UNARYOPS.get(type(node.op))
if op is None:
raise ValueError(f"unsupported unary operator: {type(node.op).__name__}")
return op(self._visit(node.operand))
if isinstance(node, ast.Name):
if node.id not in ALLOWED_NAMES:
raise ValueError(f"unknown symbol: {node.id}")
return ALLOWED_NAMES[node.id]
if isinstance(node, ast.Call):
if not isinstance(node.func, ast.Name):
raise ValueError("only direct function calls are allowed")
fn = ALLOWED_FUNCTIONS.get(node.func.id)
if fn is None:
raise ValueError(f"unsupported function: {node.func.id}")
if node.keywords:
raise ValueError("keyword arguments are not supported")
args = [self._visit(arg) for arg in node.args]
return fn(*args)
raise ValueError(f"unsupported expression node: {type(node).__name__}")
def _normalize_numeric_output(value: Any) -> Any:
if isinstance(value, float):
if not math.isfinite(value):
raise ValueError("result is not finite")
return float(f"{value:.12g}")
return value
class CalculatorTool(BaseTool):
name = "calculator"
def run(self, arguments: dict[str, Any]) -> ToolResult:
expression = str(arguments.get("expression", "")).strip()
if not expression:
return ToolResult(tool_name=self.name, success=False, error="Missing expression")
value = _SafeMathEvaluator(expression).eval()
return ToolResult(
tool_name=self.name,
success=True,
output={"expression": expression, "value": _normalize_numeric_output(value)},
)
@dataclass
class SearchHit:
url: str
title: str = ""
snippet: str = ""
class SearchBackend(Protocol):
def search(self, query: str, top_k: int) -> list[SearchHit]:
raise NotImplementedError
class MockSearchBackend:
def __init__(self, canned_results: dict[str, list[dict[str, str]]] | None = None):
self.canned_results = canned_results or {
"browser rendering markdown": [
{
"url": "https://developers.cloudflare.com/browser-rendering/rest-api/markdown-endpoint/",
"title": "Cloudflare markdown endpoint",
"snippet": "Extract markdown from a webpage using Cloudflare Browser Rendering.",
}
],
"nanochat gpt2 speedrun": [
{
"url": "https://github.com/karpathy/nanochat",
"title": "karpathy/nanochat",
"snippet": "Minimal LLM training harness with pretraining, SFT, RL, and chat UI.",
}
],
}
def search(self, query: str, top_k: int) -> list[SearchHit]:
normalized = query.strip().lower()
rows = self.canned_results.get(normalized, [])
return [SearchHit(**row) for row in rows[:top_k]]
class TavilySearchBackend:
"""LLM-optimized web search via Tavily. Falls back silently on errors."""
def __init__(self, api_key: str | None = None, timeout: float = 15.0):
self.api_key = api_key or os.environ.get('TAVILY_API_KEY')
if not self.api_key:
raise ValueError('TavilySearchBackend requires TAVILY_API_KEY')
self.timeout = timeout
def search(self, query: str, top_k: int) -> list[SearchHit]:
import requests
try:
r = requests.post(
'https://api.tavily.com/search',
json={
'api_key': self.api_key,
'query': query,
'max_results': max(1, min(int(top_k), 8)),
'include_answer': False,
'include_raw_content': False,
'search_depth': 'basic',
},
timeout=self.timeout,
)
r.raise_for_status()
data = r.json()
except Exception:
return []
return [
SearchHit(
url=h.get('url', ''),
title=h.get('title', ''),
snippet=h.get('content', ''),
)
for h in data.get('results', [])[:top_k]
]
class CloudflareBrowserRenderingClient:
def __init__(
self,
*,
api_token: str | None = None,
account_id: str | None = None,
base_url: str = "https://api.cloudflare.com/client/v4",
timeout: float = 30.0,
max_retries: int = 3,
):
self.api_token = api_token or os.environ.get("CLOUDFLARE_API_TOKEN")
self.account_id = account_id or os.environ.get("CLOUDFLARE_ACCOUNT_ID")
if not self.api_token or not self.account_id:
raise ValueError("Cloudflare Browser Rendering requires CLOUDFLARE_API_TOKEN and CLOUDFLARE_ACCOUNT_ID")
self.base_url = base_url.rstrip("/")
self.timeout = timeout
self.max_retries = max_retries
self.session = requests.Session()
self.session.headers.update(
{
"Authorization": f"Bearer {self.api_token}",
"Content-Type": "application/json",
}
)
def _post(self, endpoint: str, body: dict[str, Any]) -> Any:
url = f"{self.base_url}/accounts/{self.account_id}/browser-rendering/{endpoint}"
last_error = None
for attempt in range(1, self.max_retries + 1):
response = self.session.post(url, json=body, timeout=self.timeout)
if response.status_code == 429:
retry_after = response.headers.get("Retry-After")
sleep_seconds = float(retry_after) if retry_after else float(attempt)
last_error = RuntimeError(f"Cloudflare Browser Rendering rate limited on {endpoint}")
time.sleep(min(sleep_seconds, 5.0))
continue
response.raise_for_status()
payload = response.json()
if not payload.get("success", False):
errors = payload.get("errors", [])
last_error = RuntimeError(f"Cloudflare Browser Rendering request failed: {errors}")
break
return payload.get("result")
if last_error is not None:
raise last_error
raise RuntimeError(f"Cloudflare Browser Rendering request failed for {endpoint}")
def markdown(self, url: str, **options: Any) -> str:
body = {"url": url}
body.update(options)
return self._post("markdown", body)
def links(self, url: str, **options: Any) -> list[str]:
body = {"url": url}
body.update(options)
return self._post("links", body)
def json_extract(self, url: str, *, prompt: str | None = None, schema: dict[str, Any] | None = None, **options: Any) -> dict[str, Any]:
body: dict[str, Any] = {"url": url}
if prompt is not None:
body["prompt"] = prompt
if schema is not None:
body["schema"] = schema
body.update(options)
return self._post("json", body)
class WebSearchTool(BaseTool):
name = "web_search"
def __init__(
self,
*,
search_backend: SearchBackend | None = None,
fetch_client: CloudflareBrowserRenderingClient | None = None,
max_results: int = 3,
):
self.search_backend = search_backend
self.fetch_client = fetch_client
self.max_results = max_results
def run(self, arguments: dict[str, Any]) -> ToolResult:
query = str(arguments.get("query", "")).strip()
requested_urls = arguments.get("urls") or []
if isinstance(requested_urls, str):
requested_urls = [requested_urls]
top_k = int(arguments.get("top_k", self.max_results) or self.max_results)
top_k = max(1, min(top_k, 8))
if not query and not requested_urls:
return ToolResult(tool_name=self.name, success=False, error="Missing query or urls")
hits: list[SearchHit]
if requested_urls:
hits = [SearchHit(url=str(url)) for url in requested_urls[:top_k]]
else:
if self.search_backend is None:
return ToolResult(
tool_name=self.name,
success=False,
error="No search backend configured. Cloudflare Browser Rendering can fetch pages but does not provide public web search by itself.",
)
hits = self.search_backend.search(query, top_k)
results = []
for hit in hits[:top_k]:
entry: dict[str, Any] = {"url": hit.url}
if hit.title:
entry["title"] = hit.title
if hit.snippet:
entry["snippet"] = hit.snippet
if self.fetch_client is not None:
try:
markdown = self.fetch_client.markdown(hit.url)
links = self.fetch_client.links(hit.url)
entry["markdown"] = markdown[:4000]
entry["links"] = links[:10]
except Exception as exc:
entry["fetch_error"] = str(exc)
results.append(entry)
return ToolResult(
tool_name=self.name,
success=True,
output={"query": query, "results": results},
metadata={
"search_backend": type(self.search_backend).__name__ if self.search_backend is not None else None,
"fetch_backend": type(self.fetch_client).__name__ if self.fetch_client is not None else None,
"num_results": len(results),
},
)
def build_default_tool_registry(
*,
cloudflare_token: str | None = None,
cloudflare_account_id: str | None = None,
search_backend: SearchBackend | None = None,
) -> ToolRegistry:
fetch_client = None
if cloudflare_token or os.environ.get("CLOUDFLARE_API_TOKEN"):
try:
fetch_client = CloudflareBrowserRenderingClient(
api_token=cloudflare_token,
account_id=cloudflare_account_id,
)
except Exception:
fetch_client = None
if search_backend is None:
if os.environ.get('TAVILY_API_KEY'):
try:
search_backend = TavilySearchBackend()
except Exception:
search_backend = MockSearchBackend()
else:
search_backend = MockSearchBackend()
registry = ToolRegistry(
[
CalculatorTool(),
WebSearchTool(
search_backend=search_backend,
fetch_client=fetch_client,
),
]
)
return registry
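
A quick round-trip sketch of the channel this module defines (hypothetical
session; the expressions and printed payloads are illustrative but follow from
the code above):

# Hedged usage sketch for modal/_tools.py, exercising the tool channel end to end.
from _tools import build_default_tool_registry, parse_tool_call_payload, serialize_tool_call

registry = build_default_tool_registry()  # MockSearchBackend unless TAVILY_API_KEY is set

# What the model is trained to emit between <|python_start|> and <|python_end|>:
call_text = serialize_tool_call("calculator", {"expression": "2 ** 10 + sqrt(16)"})
# -> '{"arguments":{"expression":"2 ** 10 + sqrt(16)"},"tool":"calculator"}'

invocation = parse_tool_call_payload(call_text)
result = registry.execute(invocation.tool_name, invocation.arguments)
print(result.to_payload())
# {"error":null,"metadata":{},"output":{"expression":"2 ** 10 + sqrt(16)","value":1028.0},"success":true,"tool":"calculator"}

# Bare, non-JSON payloads fall back to the legacy calculator path:
legacy = parse_tool_call_payload("emi(500000, 8.5, 240)")
assert legacy.tool_name == "calculator" and legacy.arguments == {"expression": "emi(500000, 8.5, 240)"}

The sort_keys/compact separators in _compact_json keep payloads byte-stable,
which presumably matters when the same strings appear as SFT targets.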

modal/serve.py

@@ -20,14 +20,15 @@ import modal
 # Configuration
 # ---------------------------------------------------------------------------
 MODEL_REPO = "ManmohanSharma/nanochat-d24"
-MODEL_PT = "chatsft_checkpoints/d24/model_000484.pt"
-META_JSON = "chatsft_checkpoints/d24/meta_000484.json"
+MODEL_PT = "chatsft_checkpoints/d24-sft-r6/model_000754.pt"
+META_JSON = "chatsft_checkpoints/d24-sft-r6/meta_000754.json"
 TOKENIZER_PKL = "tokenizer/tokenizer.pkl"
 TOKEN_BYTES = "tokenizer/token_bytes.pt"
-MODEL_TAG = "d24-sft"
+MODEL_TAG = "d24-sft-r6"
 GPU_TYPE = "L4"  # 24 GB VRAM — fits 4 GB bf16 model loaded as fp32
 VOLUME_NAME = "samosachaat-weights"
 HF_SECRET_NAME = "huggingface"  # Modal secret containing HF_TOKEN
+TAVILY_SECRET_NAME = "tavily"  # Modal secret containing TAVILY_API_KEY

 # ---------------------------------------------------------------------------
 # Modal app + image
@@ -42,12 +43,14 @@ inference_image = (
         "tiktoken>=0.11.0",
         "tokenizers>=0.22.0",
         "huggingface_hub>=0.25.0",
+        "requests>=2.31.0",
         "fastapi>=0.115.0",
         "uvicorn>=0.30.0",
         extra_index_url="https://download.pytorch.org/whl/cu124",
     )
     .add_local_file("modal/_model.py", "/root/_model.py")
     .add_local_file("modal/_tokenizer.py", "/root/_tokenizer.py")
+    .add_local_file("modal/_tools.py", "/root/_tools.py")
 )

 # Persistent volume for model weights
@@ -104,6 +107,7 @@ def download_weights():
     scaledown_window=300,  # keep warm for 5 min after last request
     # concurrency handled by @modal.concurrent below
     timeout=120,
+    secrets=[modal.Secret.from_name(TAVILY_SECRET_NAME)],
 )
 class Inference:
     model: object
@@ -190,8 +194,22 @@ class Inference:
         self.assistant_end_id = self.tokenizer.encode_special("<|assistant_end|>")[0]
         print(f"  Special token IDs: {sorted(self.special_token_ids)}")

+        # Initialize tool registry (Tavily web_search + calculator)
+        import sys as _sys
+        if '/root' not in _sys.path: _sys.path.insert(0, '/root')
+        from _tools import build_default_tool_registry, parse_tool_call_payload
+        self.tool_registry = build_default_tool_registry()
+        self._parse_tool_call = parse_tool_call_payload
+
+        # Marker tokens for tool state machine
+        self.python_start_id = self.tokenizer.encode_special("<|python_start|>")[0]
+        self.python_end_id = self.tokenizer.encode_special("<|python_end|>")[0]
+        self.output_start_id = self.tokenizer.encode_special("<|output_start|>")[0]
+        self.output_end_id = self.tokenizer.encode_special("<|output_end|>")[0]
+        # Stop tokens (exclude tool markers so generation continues through tool calls)
+        self._stop_token_ids = {self.assistant_end_id, self.tokenizer.get_bos_token_id() if hasattr(self.tokenizer, "get_bos_token_id") else self.tokenizer.encode_special("<|bos|>")[0]}
         dt = time.time() - t0
-        print(f"Model loaded in {dt:.1f}s on {device}")
+        print(f"Model loaded in {dt:.1f}s on {device} | tools: {[t for t in self.tool_registry._tools.keys()] if hasattr(self.tool_registry, '_tools') else 'registered'}")

     @modal.fastapi_endpoint(method="POST", docs=True)
     async def generate(self, request: dict):
@@ -236,49 +254,74 @@ class Inference:
             tokens = tokens[-max_context:]

         async def stream():
+            from collections import deque
             input_ids = torch.tensor([tokens], dtype=torch.long, device=self.device)
+            forced = deque()
+            in_tool = False
+            tool_payload_ids = []
+
+            def _append_token(tid):
+                nonlocal input_ids
+                nt = torch.tensor([[tid]], dtype=torch.long, device=self.device)
+                input_ids = torch.cat([input_ids, nt], dim=1)
+                if input_ids.size(1) > self.config.sequence_len:
+                    input_ids = input_ids[:, -self.config.sequence_len:]
+
             with torch.no_grad():
-                generated = []
-                for _ in range(max_tokens):
-                    # Forward pass
-                    logits = self.model(input_ids)
-                    next_logits = logits[:, -1, :]
-                    # Temperature
-                    if temperature > 0:
-                        next_logits = next_logits / temperature
-                    # Top-k filtering
-                    if top_k > 0:
-                        v, _ = torch.topk(next_logits, min(top_k, next_logits.size(-1)))
-                        next_logits[next_logits < v[:, [-1]]] = float('-inf')
-                    # Sample
-                    probs = torch.softmax(next_logits, dim=-1)
-                    next_token = torch.multinomial(probs, num_samples=1)
-                    token_id = next_token.item()
-                    # Stop on any special token (assistant_end, bos, etc.)
-                    if token_id in self.special_token_ids:
+                num_generated = 0
+                while num_generated < max_tokens:
+                    if forced:
+                        token_id = forced.popleft()
+                    else:
+                        logits = self.model(input_ids)
+                        next_logits = logits[:, -1, :]
+                        if temperature > 0:
+                            next_logits = next_logits / temperature
+                        if top_k > 0:
+                            v, _ = torch.topk(next_logits, min(top_k, next_logits.size(-1)))
+                            next_logits[next_logits < v[:, [-1]]] = float('-inf')
+                        probs = torch.softmax(next_logits, dim=-1)
+                        next_token = torch.multinomial(probs, num_samples=1)
+                        token_id = next_token.item()
+
+                    # Tool state machine: detect <|python_start|>...<|python_end|>,
+                    # execute tool, inject <|output_start|>...<|output_end|> as forced tokens
+                    if token_id == self.python_start_id:
+                        in_tool = True
+                        tool_payload_ids = []
+                    elif token_id == self.python_end_id and in_tool:
+                        in_tool = False
+                        if tool_payload_ids:
+                            try:
+                                payload_text = self.tokenizer.decode(tool_payload_ids)
+                                invocation = self._parse_tool_call(payload_text)
+                                result = self.tool_registry.execute(invocation.tool_name, invocation.arguments)
+                                result_text = result.to_payload()[:4096]
+                            except Exception as exc:
+                                result_text = json.dumps({"error": str(exc)[:500]})
+                            if result_text:
+                                forced.append(self.output_start_id)
+                                forced.extend(self.tokenizer.encode(result_text))
+                                forced.append(self.output_end_id)
+                            tool_payload_ids = []
+                    elif in_tool:
+                        tool_payload_ids.append(token_id)
+
+                    # Stop only on assistant_end or bos (NOT on tool markers)
+                    if token_id in self._stop_token_ids:
                         break
-                    # Decode and yield (skip tokens that can't be decoded)
+
+                    # Decode + stream to client (includes tool markers; UI renders)
                     try:
                         token_text = self.tokenizer.decode([token_id])
-                    except (KeyError, Exception):
-                        continue
-                    yield f"data: {json.dumps({'token': token_text, 'gpu': 0})}\n\n"
+                        yield "data: " + json.dumps({"token": token_text, "gpu": 0}) + "\n\n"
+                    except Exception:
+                        pass

-                    # Append for next iteration
-                    input_ids = torch.cat([input_ids, next_token], dim=1)
-                    # Truncate if exceeding sequence length
-                    if input_ids.size(1) > self.config.sequence_len:
-                        input_ids = input_ids[:, -self.config.sequence_len:]
+                    _append_token(token_id)
+                    num_generated += 1

-            yield f"data: {json.dumps({'done': True})}\n\n"
+            yield "data: " + json.dumps({"done": True}) + "\n\n"

         return StreamingResponse(
             stream(),

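Stripped of Modal and torch specifics, the forced-token pattern in stream()
reduces to the following sketch (stub callables and hypothetical names; not
the deployed code):

from collections import deque

def generate_with_tools(sample_next, encode, decode, run_tool, markers, max_tokens=256):
    """sample_next() -> token id; encode/decode map text <-> token ids;
    run_tool(payload_text) -> JSON result text; markers = (py_start, py_end, out_start, out_end)."""
    py_start, py_end, out_start, out_end = markers
    forced = deque()                           # queued tokens override sampling
    in_tool, payload_ids, out = False, [], []
    for _ in range(max_tokens):
        token_id = forced.popleft() if forced else sample_next()
        if token_id == py_start:               # model opened a tool call
            in_tool, payload_ids = True, []
        elif token_id == py_end and in_tool:   # call complete: run it, force the result back in
            in_tool = False
            result_text = run_tool(decode(payload_ids))
            forced.append(out_start)
            forced.extend(encode(result_text))
            forced.append(out_end)
        elif in_tool:                          # accumulate the JSON payload
            payload_ids.append(token_id)
        out.append(token_id)                   # markers stay in the stream; the UI renders them
    return out

The real loop additionally feeds every token, sampled or forced, back into the
model context via _append_token, and stops on assistant_end/bos.
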
nanochat/tools.py

@@ -348,6 +348,43 @@ class MockSearchBackend:
         rows = self.canned_results.get(normalized, [])
         return [SearchHit(**row) for row in rows[:top_k]]

+
+class TavilySearchBackend:
+    """LLM-optimized web search via Tavily. Falls back silently on errors."""
+
+    def __init__(self, api_key: str | None = None, timeout: float = 15.0):
+        self.api_key = api_key or os.environ.get('TAVILY_API_KEY')
+        if not self.api_key:
+            raise ValueError('TavilySearchBackend requires TAVILY_API_KEY')
+        self.timeout = timeout
+
+    def search(self, query: str, top_k: int) -> list[SearchHit]:
+        import requests
+        try:
+            r = requests.post(
+                'https://api.tavily.com/search',
+                json={
+                    'api_key': self.api_key,
+                    'query': query,
+                    'max_results': max(1, min(int(top_k), 8)),
+                    'include_answer': False,
+                    'include_raw_content': False,
+                    'search_depth': 'basic',
+                },
+                timeout=self.timeout,
+            )
+            r.raise_for_status()
+            data = r.json()
+        except Exception:
+            return []
+        return [
+            SearchHit(
+                url=h.get('url', ''),
+                title=h.get('title', ''),
+                snippet=h.get('content', ''),
+            )
+            for h in data.get('results', [])[:top_k]
+        ]

 class CloudflareBrowserRenderingClient:
     def __init__(
@@ -497,11 +534,19 @@ def build_default_tool_registry(
             )
         except Exception:
             fetch_client = None
+    if search_backend is None:
+        if os.environ.get('TAVILY_API_KEY'):
+            try:
+                search_backend = TavilySearchBackend()
+            except Exception:
+                search_backend = MockSearchBackend()
+        else:
+            search_backend = MockSearchBackend()
     registry = ToolRegistry(
         [
             CalculatorTool(),
             WebSearchTool(
-                search_backend=search_backend or MockSearchBackend(),
+                search_backend=search_backend,
                 fetch_client=fetch_client,
            ),
        ]
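
The effect of the auto-detect, as a sketch (assumes nanochat.tools is the
import path and no explicit backend is passed):

import os
from nanochat.tools import build_default_tool_registry

# Without TAVILY_API_KEY the registry silently degrades to canned results;
# with it, web_search goes out over the network via Tavily.
os.environ.pop("TAVILY_API_KEY", None)
offline = build_default_tool_registry()
payload = offline.execute("web_search", {"query": "nanochat gpt2 speedrun", "top_k": 1}).to_payload()
# metadata.search_backend is "MockSearchBackend"; the canned karpathy/nanochat hit comes back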

nanochat/ui.html

@@ -459,6 +459,52 @@
     .illust-right svg.kettle-svg { width: 70px; }
     .explore-tag, .chai-label { font-size: 0.85rem; }
 }

+/* --- Reasoning mode --- */
+.think-toggle {
+    background: transparent;
+    border: 1px solid rgba(255,255,255,0.15);
+    color: #b8a88a;
+    cursor: pointer;
+    padding: 6px 10px;
+    border-radius: 6px;
+    margin-right: 8px;
+    font-size: 0.85rem;
+    display: flex;
+    align-items: center;
+    gap: 4px;
+    transition: all 0.2s;
+}
+.think-toggle:hover { background: rgba(255,255,255,0.05); }
+.think-toggle.active {
+    background: rgba(184,168,138,0.15);
+    border-color: #b8a88a;
+    color: #fff;
+}
+.think-block {
+    background: rgba(100,100,100,0.08);
+    border-left: 3px solid #777;
+    padding: 10px 14px;
+    margin-bottom: 10px;
+    font-style: italic;
+    color: #999;
+    font-size: 0.88em;
+    white-space: pre-wrap;
+    border-radius: 4px;
+}
+.think-block::before {
+    content: "\1F4AD" " thinking";
+    display: block;
+    font-weight: 600;
+    margin-bottom: 6px;
+    color: #888;
+    font-style: normal;
+    font-size: 0.85em;
+    letter-spacing: 0.05em;
+    text-transform: uppercase;
+}
+.answer-block { white-space: pre-wrap; }
+
 </style>
 </head>
 <body>
@@ -609,6 +655,10 @@
 <div class="input-container landing-mode" id="inputContainer">
     <div class="input-wrapper">
         <textarea id="chatInput" class="chat-input" placeholder="Ask samosaChaat anything..." rows="1" onkeydown="handleKeyDown(event)"></textarea>
+        <button id="thinkToggle" class="think-toggle" onclick="toggleReasoning()" title="Reasoning mode (think step-by-step)" type="button">
+            <svg width="14" height="14" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M9.5 2A2.5 2.5 0 0 1 12 4.5v15a2.5 2.5 0 0 1-4.96.44 2.5 2.5 0 0 1-2.96-3.08 3 3 0 0 1-.34-5.58 2.5 2.5 0 0 1 1.32-4.24 2.5 2.5 0 0 1 1.98-3A2.5 2.5 0 0 1 9.5 2Z"></path><path d="M14.5 2A2.5 2.5 0 0 0 12 4.5v15a2.5 2.5 0 0 0 4.96.44 2.5 2.5 0 0 0 2.96-3.08 3 3 0 0 0 .34-5.58 2.5 2.5 0 0 0-1.32-4.24 2.5 2.5 0 0 0-1.98-3A2.5 2.5 0 0 0 14.5 2Z"></path></svg>
+            <span>Think</span>
+        </button>
         <button id="sendButton" class="send-button" onclick="sendMessage()" disabled>
             <svg width="22" height="22" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
                 <path d="M22 2L11 13"></path><path d="M22 2l-7 20-4-9-9-4 20-7z"></path>
@@ -639,6 +689,53 @@
 let isChatMode = false;
 let currentTemperature = 0.8;
 let currentTopK = 50;
+let reasoningMode = false;
+
+const SYS_DIRECT = "You are samosaChaat, a helpful AI assistant. Answer directly and concisely.";
+const SYS_THINK = "You are samosaChaat, a helpful AI assistant. Think step by step inside <think>...</think> tags, then give your final answer.";
+
+function toggleReasoning() {
+    reasoningMode = !reasoningMode;
+    const btn = document.getElementById("thinkToggle");
+    if (btn) btn.classList.toggle("active", reasoningMode);
+}
+
+function buildApiMessages() {
+    const out = messages.map(m => ({ role: m.role, content: m.content }));
+    if (out.length && out[0].role === "user") {
+        const sys = reasoningMode ? SYS_THINK : SYS_DIRECT;
+        out[0].content = sys + "\n" + out[0].content;
+    }
+    return out;
+}
+
+function renderAssistantContent(fullText, container) {
+    // Parse <think>...</think> blocks and render specially
+    const openIdx = fullText.indexOf("<think>");
+    if (openIdx === -1) { container.textContent = fullText; return; }
+    const closeIdx = fullText.indexOf("</think>", openIdx);
+    container.innerHTML = "";
+    if (openIdx > 0) {
+        const pre = document.createElement("div");
+        pre.className = "answer-block";
+        pre.textContent = fullText.slice(0, openIdx);
+        container.appendChild(pre);
+    }
+    const thinkText = closeIdx >= 0 ? fullText.slice(openIdx + 7, closeIdx) : fullText.slice(openIdx + 7);
+    const thinkDiv = document.createElement("div");
+    thinkDiv.className = "think-block";
+    thinkDiv.textContent = thinkText;
+    container.appendChild(thinkDiv);
+    if (closeIdx >= 0) {
+        const after = fullText.slice(closeIdx + 8);
+        const ansDiv = document.createElement("div");
+        ansDiv.className = "answer-block";
+        ansDiv.textContent = after;
+        container.appendChild(ansDiv);
+    }
+}

 // ================================================================
 // TRANSITION: Landing → Chat
@@ -776,7 +873,7 @@
         assistantContent.textContent = '';
         for await (const token of window.samosaChaat.generateLocal(messages)) {
             fullResponse += token;
-            assistantContent.textContent = fullResponse;
+            renderAssistantContent(fullResponse, assistantContent);
             chatContainer.scrollTop = chatContainer.scrollHeight;
         }
         const idx = messages.length;
@@ -788,7 +885,7 @@
         const response = await fetch(`${API_URL}/chat/completions`, {
             method: 'POST',
             headers: { 'Content-Type': 'application/json' },
-            body: JSON.stringify({ messages, temperature: currentTemperature, top_k: currentTopK, max_tokens: 512 }),
+            body: JSON.stringify({ messages: buildApiMessages(), temperature: currentTemperature, top_k: currentTopK, max_tokens: 512 }),
         });
         if (!response.ok) throw new Error(`HTTP error! status: ${response.status}`);
         const reader = response.body.getReader();
@@ -802,7 +899,7 @@
             if (line.startsWith('data: ')) {
                 try {
                     const data = JSON.parse(line.slice(6));
-                    if (data.token) { fullResponse += data.token; assistantContent.textContent = fullResponse; chatContainer.scrollTop = chatContainer.scrollHeight; }
+                    if (data.token) { fullResponse += data.token; renderAssistantContent(fullResponse, assistantContent); chatContainer.scrollTop = chatContainer.scrollHeight; }
                 } catch (e) {}
             }
         }

services/chat-api/routes/messages.py

@@ -27,12 +27,44 @@ class SendMessageRequest(BaseModel):
     temperature: float | None = Field(default=None, ge=0.0, le=2.0)
     max_tokens: int | None = Field(default=None, ge=1, le=4096)
     top_k: int | None = Field(default=None, ge=0, le=200)
+    thinking_mode: bool = Field(default=False)


 class RegenerateRequest(BaseModel):
     temperature: float | None = Field(default=None, ge=0.0, le=2.0)
     max_tokens: int | None = Field(default=None, ge=1, le=4096)
     top_k: int | None = Field(default=None, ge=0, le=200)
+    thinking_mode: bool = Field(default=False)
+
+
+# System prompts: tools are always implicitly available via the model's SFT training.
+# The toggle only affects whether the model is nudged into <think>...</think> mode.
+_SYS_DEFAULT = (
+    "You are samosaChaat, a helpful AI assistant created by Manmohan Sharma. "
+    "You have access to tools: use web_search for facts that may change over time or "
+    "require current information, and use calculator for arithmetic. Otherwise answer directly and concisely."
+)
+_SYS_THINK = (
+    "You are samosaChaat, a helpful AI assistant created by Manmohan Sharma. "
+    "You have access to tools: use web_search for facts that may change over time or "
+    "require current information, and use calculator for arithmetic. "
+    "Think step by step inside <think>...</think> tags, then give your final answer after the closing tag."
+)
+
+
+def _inject_system_prompt(history: list[dict[str, str]], thinking_mode: bool) -> list[dict[str, str]]:
+    """Merge a system prompt into the first user message. Upstream Modal serve
+    ignores role='system', so we prepend the system prompt inline to the first
+    user turn, mirroring nanochat's tokenizer convention."""
+    if not history:
+        return history
+    sys_prompt = _SYS_THINK if thinking_mode else _SYS_DEFAULT
+    out = [dict(m) for m in history]
+    for m in out:
+        if m.get("role") == "user":
+            m["content"] = sys_prompt + "\n\n" + m.get("content", "")
+            break
+    return out


 def _parse_uuid(raw: str) -> uuid.UUID:
@@ -182,18 +214,23 @@ async def send_message(
         db_session, conversation_id=conv_uuid
     )

+    # Inject system prompt (direct or think mode) into the first user message,
+    # since upstream Modal serve ignores role='system'.
+    history_with_sys = _inject_system_prompt(history, body.thinking_mode)
+
     logger.info(
         "send_message",
         conversation_id=str(conv_uuid),
         history_len=len(history),
         model_tag=model_tag,
+        thinking_mode=body.thinking_mode,
     )

     generator = _stream_and_persist(
         request=request,
         user_id=user_uuid,
         conversation_id=conv_uuid,
-        history=history,
+        history=history_with_sys,
         temperature=body.temperature,
         max_tokens=body.max_tokens,
         top_k=body.top_k,
@@ -235,17 +272,20 @@ async def regenerate(
             detail="conversation has no user messages to regenerate from",
         )

+    history_with_sys = _inject_system_prompt(history, body.thinking_mode)
+
     logger.info(
         "regenerate_message",
         conversation_id=str(conv_uuid),
         history_len=len(history),
+        thinking_mode=body.thinking_mode,
     )

     generator = _stream_and_persist(
         request=request,
         user_id=user_uuid,
         conversation_id=conv_uuid,
-        history=history,
+        history=history_with_sys,
         temperature=body.temperature,
         max_tokens=body.max_tokens,
         top_k=body.top_k,
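
Concretely, the injection behaves like this (illustrative history):

history = [
    {"role": "user", "content": "What is 17% of 2400?"},
    {"role": "assistant", "content": "408."},
    {"role": "user", "content": "And of 3600?"},
]
injected = _inject_system_prompt(history, thinking_mode=True)
assert injected[0]["content"].startswith("You are samosaChaat")  # _SYS_THINK prepended
assert injected[0]["content"].endswith("What is 17% of 2400?")   # original text kept
assert injected[1:] == history[1:]                               # later turns untouched
assert history[0]["content"] == "What is 17% of 2400?"           # input list not mutated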

services/frontend/components/chat/ChatInput.tsx

@@ -1,7 +1,7 @@
 'use client';

 import { useEffect, useRef } from 'react';
-import { ArrowUp, Square } from 'lucide-react';
+import { ArrowUp, Brain, Square } from 'lucide-react';
 import clsx from 'clsx';

 interface Props {
@@ -11,9 +11,11 @@ interface Props {
   onStop?: () => void;
   isStreaming?: boolean;
   disabled?: boolean;
+  thinkingMode?: boolean;
+  onToggleThinking?: () => void;
 }

-export default function ChatInput({ value, onChange, onSubmit, onStop, isStreaming, disabled }: Props) {
+export default function ChatInput({ value, onChange, onSubmit, onStop, isStreaming, disabled, thinkingMode, onToggleThinking }: Props) {
   const ref = useRef<HTMLTextAreaElement>(null);

   useEffect(() => {
@@ -58,6 +60,27 @@ export default function ChatInput({ value, onChange, onSubmit, onStop, isStreami
         className="flex-1 resize-none bg-transparent px-5 py-4 pr-2 text-[0.95rem] leading-relaxed text-gray-900 dark:text-ink-text placeholder-gray-400 dark:placeholder-ink-text-soft focus:outline-none min-h-[52px] max-h-[200px]"
       />

+      {/* Think toggle */}
+      {onToggleThinking && (
+        <div className="self-end p-2">
+          <button
+            type="button"
+            onClick={onToggleThinking}
+            aria-pressed={!!thinkingMode}
+            title={thinkingMode ? 'Reasoning mode ON — model will think step-by-step' : 'Enable reasoning mode'}
+            className={clsx(
+              'h-10 px-3 rounded-full flex items-center gap-1.5 text-xs font-medium transition-all border',
+              thinkingMode
+                ? 'bg-saffron/15 dark:bg-saffron/20 border-saffron/40 dark:border-saffron/50 text-saffron dark:text-saffron-soft shadow-[0_4px_14px_rgba(255,153,51,0.15)]'
+                : 'bg-transparent border-cream-border dark:border-ink-border text-gray-500 dark:text-ink-text-soft hover:bg-gray-50 dark:hover:bg-ink-elev',
+            )}
+          >
+            <Brain size={14} />
+            <span>Think</span>
+          </button>
+        </div>
+      )}
+
       {/* Send / stop button — vertically centered with the textarea baseline */}
       <div className="self-end p-2">
         {isStreaming && onStop ? (

services/frontend/components/chat/ChatWindow.tsx

@@ -31,6 +31,7 @@ export default function ChatWindow() {
   const [draft, setDraft] = useState('');
   const [streamingMsgId, setStreamingMsgId] = useState<string | null>(null);
+  const [thinkingMode, setThinkingMode] = useState(false);

   const streamingBufferRef = useRef('');
   const scrollRef = useRef<HTMLDivElement>(null);
@@ -60,7 +61,7 @@
     setIsStreaming(false);
   }, []);

-  const streamFromApi = useCallback(async (convId: string, assistantMsgId: string, content: string, temp?: number, topk?: number) => {
+  const streamFromApi = useCallback(async (convId: string, assistantMsgId: string, content: string, temp?: number, topk?: number, thinking?: boolean) => {
     stop();
     const ac = new AbortController();
     abortRef.current = ac;
@@ -76,7 +77,7 @@
       const res = await fetch(`/api/conversations/${convId}/messages`, {
         method: 'POST',
         headers,
-        body: JSON.stringify({ content, temperature: temp, max_tokens: 512, top_k: topk }),
+        body: JSON.stringify({ content, temperature: temp, max_tokens: 512, top_k: topk, thinking_mode: !!thinking }),
         signal: ac.signal,
       });
@@ -172,7 +173,7 @@
       setStreamingMsgId(assistantId);
       streamingBufferRef.current = '';

-      await streamFromApi(convId, assistantId, text, temperature, topK);
+      await streamFromApi(convId, assistantId, text, temperature, topK, thinkingMode);
     },
     [
       draft,
@@ -180,13 +181,13 @@
       ensureConversation,
       temperature,
       topK,
+      thinkingMode,
       appendMessage,
       streamFromApi,
       setTemperature,
       setTopK,
       createConversation,
       newConversation,
-      // streamFromApi in deps via earlier line
     ],
   );
@@ -238,6 +239,8 @@
         onSubmit={() => handleSend()}
         onStop={stop}
         isStreaming={isStreaming}
+        thinkingMode={thinkingMode}
+        onToggleThinking={() => setThinkingMode((v) => !v)}
       />
     </section>
   );

services/frontend/components/chat/MessageBubble.tsx

@@ -5,11 +5,114 @@ import ReactMarkdown from 'react-markdown';
 import remarkGfm from 'remark-gfm';
 import rehypeHighlight from 'rehype-highlight';
 import 'highlight.js/styles/github-dark.css';
-import { Check, Copy } from 'lucide-react';
+import { Check, ChevronDown, ChevronRight, Copy, Search, Calculator, Sparkles } from 'lucide-react';
 import clsx from 'clsx';
 import type { Message } from '@/types/chat';
 import SteamTyping from '@/components/svg/SteamTyping';

+// ---- Content parser: split into text / think / tool_call / tool_result segments ----
+
+type Segment =
+  | { kind: 'text'; content: string }
+  | { kind: 'think'; content: string; closed: boolean }
+  | { kind: 'tool_call'; content: string; closed: boolean }
+  | { kind: 'tool_result'; content: string; closed: boolean };
+
+function parseSegments(raw: string): Segment[] {
+  const segs: Segment[] = [];
+  let i = 0;
+  // marker -> [openTag, closeTag, kind]
+  const markers: Array<[string, string, Segment['kind']]> = [
+    ['<think>', '</think>', 'think'],
+    ['<|python_start|>', '<|python_end|>', 'tool_call'],
+    ['<|output_start|>', '<|output_end|>', 'tool_result'],
+  ];
+  while (i < raw.length) {
+    // find the nearest opening marker
+    let bestOpen = -1;
+    let bestMarker: [string, string, Segment['kind']] | null = null;
+    for (const m of markers) {
+      const p = raw.indexOf(m[0], i);
+      if (p !== -1 && (bestOpen === -1 || p < bestOpen)) { bestOpen = p; bestMarker = m; }
+    }
+    if (bestOpen === -1) {
+      if (i < raw.length) segs.push({ kind: 'text', content: raw.slice(i) });
+      break;
+    }
+    if (bestOpen > i) segs.push({ kind: 'text', content: raw.slice(i, bestOpen) });
+    const [openTag, closeTag, kind] = bestMarker!;
+    const afterOpen = bestOpen + openTag.length;
+    const closeIdx = raw.indexOf(closeTag, afterOpen);
+    if (closeIdx === -1) {
+      segs.push({ kind, content: raw.slice(afterOpen), closed: false });
+      i = raw.length;
+    } else {
+      segs.push({ kind, content: raw.slice(afterOpen, closeIdx), closed: true });
+      i = closeIdx + closeTag.length;
+    }
+  }
+  return segs;
+}
+
+function ThinkBlock({ content, closed }: { content: string; closed: boolean }) {
+  const [open, setOpen] = useState(true);
+  return (
+    <div className="my-3 rounded-lg border border-gray-200 dark:border-ink-border bg-gray-50/60 dark:bg-ink-soft/60">
+      <button type="button" onClick={() => setOpen(!open)} className="w-full flex items-center gap-2 px-3 py-2 text-xs uppercase tracking-wider text-gray-500 dark:text-ink-text-soft hover:bg-gray-100 dark:hover:bg-ink-elev/50">
+        {open ? <ChevronDown size={14} /> : <ChevronRight size={14} />}
+        <Sparkles size={12} />
+        <span>Thinking{closed ? '' : '…'}</span>
+      </button>
+      {open && (
+        <div className="px-4 py-3 text-sm text-gray-600 dark:text-ink-text-soft whitespace-pre-wrap italic leading-relaxed border-t border-gray-200 dark:border-ink-border">
+          {content}
+        </div>
+      )}
+    </div>
+  );
+}
+
+function ToolCallBlock({ content, closed }: { content: string; closed: boolean }) {
+  let parsed: { tool?: string; arguments?: Record<string, unknown> } | null = null;
+  try { parsed = JSON.parse(content); } catch { /* streaming — partial JSON */ }
+  const toolName = parsed?.tool ?? 'tool';
+  const icon = toolName === 'web_search' ? <Search size={12} /> : toolName === 'calculator' ? <Calculator size={12} /> : <Sparkles size={12} />;
+  const query = parsed?.arguments ? JSON.stringify(parsed.arguments) : content;
+  return (
+    <div className="my-2 rounded-lg border border-saffron/30 dark:border-saffron/40 bg-saffron/5 dark:bg-saffron/10 px-3 py-2">
+      <div className="flex items-center gap-2 text-xs font-medium text-saffron dark:text-saffron-soft uppercase tracking-wider">
+        {icon}
+        <span>Calling {toolName}{closed ? '' : '…'}</span>
+      </div>
+      <div className="mt-1 text-xs font-mono text-gray-600 dark:text-ink-text-soft truncate">{query}</div>
+    </div>
+  );
+}
+
+function ToolResultBlock({ content, closed }: { content: string; closed: boolean }) {
+  const [open, setOpen] = useState(false);
+  let summary = content;
+  try {
+    const j = JSON.parse(content);
+    if (j?.output?.results?.[0]?.snippet) summary = String(j.output.results[0].snippet).slice(0, 160);
+    else if (j?.output?.value !== undefined) summary = `= ${j.output.value}`;
+    else if (j?.error) summary = `error: ${j.error}`;
+  } catch { /* partial */ }
+  return (
+    <div className="my-2 rounded-lg border border-gray-200 dark:border-ink-border bg-white/60 dark:bg-ink-elev/60">
+      <button type="button" onClick={() => setOpen(!open)} className="w-full flex items-center justify-between gap-2 px-3 py-2 text-xs text-gray-600 dark:text-ink-text-soft hover:bg-gray-50 dark:hover:bg-ink-soft/50">
+        <span className="flex items-center gap-2">
+          {open ? <ChevronDown size={14} /> : <ChevronRight size={14} />}
+          <span className="uppercase tracking-wider">Result{closed ? '' : '…'}</span>
+          <span className="ml-2 truncate text-gray-500 dark:text-ink-text-soft normal-case">{summary}</span>
+        </span>
+      </button>
+      {open && (
+        <pre className="px-3 py-2 text-xs overflow-x-auto border-t border-gray-200 dark:border-ink-border">{content}</pre>
+      )}
+    </div>
+  );
+}
+
 interface Props {
   message: Message;
   isStreaming?: boolean;
@@ -91,13 +194,21 @@ export default function MessageBubble({ message, isStreaming }: Props) {
         </div>
       ) : (
         <div className="markdown-body text-[0.95rem] text-gray-900 dark:text-ink-text leading-relaxed">
-          <ReactMarkdown
-            remarkPlugins={[remarkGfm]}
-            rehypePlugins={[rehypeHighlight]}
-            components={{ code: CodeBlock as never }}
-          >
-            {message.content}
-          </ReactMarkdown>
+          {parseSegments(message.content).map((seg, idx) => {
+            if (seg.kind === 'think') return <ThinkBlock key={idx} content={seg.content} closed={seg.closed} />;
+            if (seg.kind === 'tool_call') return <ToolCallBlock key={idx} content={seg.content} closed={seg.closed} />;
+            if (seg.kind === 'tool_result') return <ToolResultBlock key={idx} content={seg.content} closed={seg.closed} />;
+            return (
+              <ReactMarkdown
+                key={idx}
+                remarkPlugins={[remarkGfm]}
+                rehypePlugins={[rehypeHighlight]}
+                components={{ code: CodeBlock as never }}
+              >
+                {seg.content}
+              </ReactMarkdown>
+            );
+          })}
         </div>
       )}
     </div>