mirror of
https://github.com/karpathy/nanochat.git
synced 2026-05-09 17:30:14 +00:00
feat: deploy d24-sft-r6 with full reasoning mode + live tool use (Tavily)
Model R6 (97% pass rate on 33-probe eval, val_bpb 0.2635): - modal/serve.py + modal/_tools.py: tool-aware streaming with TavilySearchBackend auto-detect, python_start/end state machine, output_start/end forcing; mount tavily secret - modal/serve.py: MODEL_TAG=d24-sft-r6, model path points at new SFT r6 - services/chat-api/routes/messages.py: accept thinking_mode flag, inject samosaChaat system prompt (direct or <think> variant) into first user message before streaming to Modal - services/frontend/components/chat/ChatInput.tsx: Brain toggle 'Think' button next to send; when active, model uses think mode - services/frontend/components/chat/ChatWindow.tsx: track thinkingMode state, pass through to API body as thinking_mode - services/frontend/components/chat/MessageBubble.tsx: parse and render <think>...</think> as collapsible italic blocks; <|python_start|>...<|python_end|> as tool-call cards with icons per tool name; <|output_start|>...<|output_end|> as result cards with expandable JSON - nanochat/tools.py: TavilySearchBackend class + auto-detect - nanochat/ui.html: legacy UI reasoning toggle (kept for parity) Tool execution verified live: query -> web_search via Tavily -> Macron returned with grounded answer.
This commit is contained in:
parent
67f568a4f2
commit
3ab89e7890
554
modal/_tools.py
Normal file
554
modal/_tools.py
Normal file
|
|
@ -0,0 +1,554 @@
|
|||
"""
|
||||
Shared tool definitions for nanochat.
|
||||
|
||||
The current tokenizer only has python/output special tokens. To preserve
|
||||
checkpoint compatibility, we reuse those tokens as a generic tool-call and
|
||||
tool-result channel. Legacy "python" calculator payloads still work.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Protocol
|
||||
|
||||
import requests
|
||||
|
||||
# Special tokens reused as a generic tool channel. The tokenizer only ships
# python/output markers, so tool calls ride on the python tokens and tool
# results on the output tokens (see module docstring).
TOOL_CALL_START = "<|python_start|>"
TOOL_CALL_END = "<|python_end|>"
TOOL_RESULT_START = "<|output_start|>"
TOOL_RESULT_END = "<|output_end|>"
# Hard cap on the serialized payload length fed back into the model context.
MAX_TOOL_PAYLOAD_CHARS = 4096

# Advertised tool schemas. ToolRegistry.schema() filters this list down to
# the tools actually registered, so entries here are a superset.
DEFAULT_TOOL_SCHEMA = [
    {
        "name": "calculator",
        "description": "Deterministic scientific calculator for exact arithmetic and common finance formulas.",
        "arguments": {
            "expression": "String expression using numbers, operators, and supported functions.",
        },
    },
    {
        "name": "web_search",
        "description": "Search and fetch web content. Requires a search backend and optionally a page fetch client.",
        "arguments": {
            "query": "Search query string.",
            "top_k": "Maximum number of results to return.",
            "urls": "Optional explicit URLs to fetch instead of searching.",
        },
    },
]
|
||||
|
||||
|
||||
def _compact_json(data: Any) -> str:
|
||||
return json.dumps(data, ensure_ascii=True, separators=(",", ":"), sort_keys=True)
|
||||
|
||||
|
||||
@dataclass
class ToolInvocation:
    """A parsed tool-call request extracted from the model's tool channel."""

    # Name of the tool being invoked (e.g. "calculator", "web_search").
    tool_name: str
    # Keyword arguments for the tool, as decoded from the JSON payload.
    arguments: dict[str, Any]
    # Original, undecoded payload text (kept for debugging/fallback).
    raw_text: str = ""
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolResult:
|
||||
tool_name: str
|
||||
success: bool
|
||||
output: Any = None
|
||||
error: str | None = None
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def to_payload(self) -> str:
|
||||
return _compact_json(
|
||||
{
|
||||
"tool": self.tool_name,
|
||||
"success": self.success,
|
||||
"output": self.output,
|
||||
"error": self.error,
|
||||
"metadata": self.metadata,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class BaseTool:
    """Abstract base for executable tools; subclasses set `name` and implement `run`."""

    # Unique registry key for this tool (set by subclasses).
    name: str

    def run(self, arguments: dict[str, Any]) -> ToolResult:
        """Execute the tool with decoded arguments and return a ToolResult."""
        raise NotImplementedError
|
||||
|
||||
|
||||
class ToolRegistry:
    """Name-indexed collection of tools with exception-safe dispatch."""

    def __init__(self, tools: list[BaseTool] | tuple[BaseTool, ...]):
        # Index tools by their declared name for O(1) dispatch.
        self._tools = {}
        for tool in tools:
            self._tools[tool.name] = tool

    def execute(self, tool_name: str, arguments: dict[str, Any]) -> ToolResult:
        """Run the named tool; unknown names and raised errors become failure results."""
        try:
            tool = self._tools[tool_name]
        except KeyError:
            return ToolResult(tool_name=tool_name, success=False, error=f"Unknown tool: {tool_name}")
        try:
            return tool.run(arguments)
        except Exception as exc:  # defensive: tool failures should become model-visible outputs
            return ToolResult(tool_name=tool_name, success=False, error=str(exc))

    def schema(self) -> list[dict[str, Any]]:
        """Return DEFAULT_TOOL_SCHEMA entries for registered tools only."""
        registered = self._tools
        return [entry for entry in DEFAULT_TOOL_SCHEMA if entry["name"] in registered]
|
||||
|
||||
|
||||
def serialize_tool_call(tool_name: str, arguments: dict[str, Any] | None = None) -> str:
    """Encode a tool invocation as compact JSON, capped at MAX_TOOL_PAYLOAD_CHARS."""
    encoded = _compact_json({"tool": tool_name, "arguments": arguments or {}})
    return encoded[:MAX_TOOL_PAYLOAD_CHARS]
|
||||
|
||||
|
||||
def serialize_tool_result(
    tool_name: str,
    output: Any = None,
    *,
    success: bool = True,
    error: str | None = None,
    metadata: dict[str, Any] | None = None,
) -> str:
    """Build a ToolResult and return its JSON payload, truncated to the cap."""
    result = ToolResult(
        tool_name=tool_name,
        success=success,
        output=output,
        error=error,
        metadata=metadata or {},
    )
    return result.to_payload()[:MAX_TOOL_PAYLOAD_CHARS]
|
||||
|
||||
|
||||
def parse_tool_call_payload(text: str) -> ToolInvocation:
    """Decode a tool-call payload; non-JSON text falls back to a calculator expression."""
    cleaned = text.strip()

    def _calculator_fallback(expression: str) -> ToolInvocation:
        # Legacy payloads are raw calculator expressions rather than JSON.
        return ToolInvocation(tool_name="calculator", arguments={"expression": expression}, raw_text=text)

    if not cleaned:
        return _calculator_fallback("")
    try:
        decoded = json.loads(cleaned)
    except json.JSONDecodeError:
        return _calculator_fallback(cleaned)
    if isinstance(decoded, dict):
        # Accept several key spellings for robustness against model drift.
        name = decoded.get("tool") or decoded.get("tool_name") or decoded.get("name")
        args = decoded.get("arguments") or decoded.get("args") or {}
        if isinstance(name, str) and isinstance(args, dict):
            return ToolInvocation(tool_name=name, arguments=args, raw_text=text)
    return _calculator_fallback(cleaned)
|
||||
|
||||
|
||||
def parse_tool_result_payload(text: str) -> ToolResult | None:
    """Decode a tool-result payload; return None when it is not a valid result dict."""
    try:
        decoded = json.loads(text.strip())
    except json.JSONDecodeError:
        return None
    if not isinstance(decoded, dict):
        return None
    name = decoded.get("tool")
    if not isinstance(name, str):
        return None
    return ToolResult(
        tool_name=name,
        success=bool(decoded.get("success", True)),
        output=decoded.get("output"),
        error=decoded.get("error"),
        metadata=decoded.get("metadata") or {},
    )
|
||||
|
||||
|
||||
def _percent(value: float, rate: float) -> float:
|
||||
return value * rate / 100.0
|
||||
|
||||
|
||||
def _percent_change(old: float, new: float) -> float:
|
||||
if old == 0:
|
||||
raise ValueError("percent_change old value cannot be zero")
|
||||
return ((new - old) / old) * 100.0
|
||||
|
||||
|
||||
def _cagr(start: float, end: float, years: float) -> float:
|
||||
if start <= 0 or end <= 0 or years <= 0:
|
||||
raise ValueError("cagr inputs must be positive")
|
||||
return ((end / start) ** (1.0 / years) - 1.0) * 100.0
|
||||
|
||||
|
||||
def _simple_interest(principal: float, annual_rate: float, years: float) -> float:
|
||||
return principal * annual_rate / 100.0 * years
|
||||
|
||||
|
||||
def _compound_interest(principal: float, annual_rate: float, periods_per_year: float, years: float) -> float:
|
||||
if periods_per_year <= 0:
|
||||
raise ValueError("periods_per_year must be positive")
|
||||
return principal * (1.0 + annual_rate / 100.0 / periods_per_year) ** (periods_per_year * years)
|
||||
|
||||
|
||||
def _emi(principal: float, annual_rate: float, months: float) -> float:
|
||||
if months <= 0:
|
||||
raise ValueError("months must be positive")
|
||||
monthly_rate = annual_rate / 100.0 / 12.0
|
||||
if monthly_rate == 0:
|
||||
return principal / months
|
||||
growth = (1.0 + monthly_rate) ** months
|
||||
return principal * monthly_rate * growth / (growth - 1.0)
|
||||
|
||||
|
||||
# Whitelisted AST binary operators for the safe expression evaluator.
ALLOWED_BINOPS = {
    ast.Add: lambda a, b: a + b,
    ast.Sub: lambda a, b: a - b,
    ast.Mult: lambda a, b: a * b,
    ast.Div: lambda a, b: a / b,
    ast.Pow: lambda a, b: a ** b,
    ast.Mod: lambda a, b: a % b,
}
# Whitelisted unary operators (sign prefixes only).
ALLOWED_UNARYOPS = {
    ast.UAdd: lambda a: a,
    ast.USub: lambda a: -a,
}
# Named mathematical constants available inside expressions.
ALLOWED_NAMES = {
    "pi": math.pi,
    "e": math.e,
    "tau": math.tau,
}
# Callable whitelist: math builtins plus the finance helpers defined above.
ALLOWED_FUNCTIONS = {
    "abs": abs,
    "round": round,
    "floor": math.floor,
    "ceil": math.ceil,
    "sqrt": math.sqrt,
    "log": math.log,
    "log10": math.log10,
    "exp": math.exp,
    "sin": math.sin,
    "cos": math.cos,
    "tan": math.tan,
    "asin": math.asin,
    "acos": math.acos,
    "atan": math.atan,
    "degrees": math.degrees,
    "radians": math.radians,
    "percent": _percent,
    "percent_change": _percent_change,
    "cagr": _cagr,
    "simple_interest": _simple_interest,
    "compound_interest": _compound_interest,
    "emi": _emi,
}
|
||||
|
||||
|
||||
class _SafeMathEvaluator:
    """Whitelist-based evaluator for calculator expressions.

    Parses the expression with `ast` and walks the tree, permitting only
    numeric literals, operators in ALLOWED_BINOPS/ALLOWED_UNARYOPS, constants
    in ALLOWED_NAMES, and direct calls to ALLOWED_FUNCTIONS. Anything else
    raises ValueError. Length and node-count caps bound adversarial input.
    """

    def __init__(self, expression: str):
        # Expression text to evaluate.
        self.expression = expression
        # Number of AST nodes visited in the current eval (complexity guard).
        self.node_count = 0

    def eval(self) -> float:
        """Parse and evaluate the expression.

        Raises ValueError for disallowed constructs or oversized input, and
        SyntaxError (from ast.parse) for malformed expressions.
        """
        if len(self.expression) > 512:
            raise ValueError("expression too long")
        # Bug fix: reset the counter so repeated eval() calls on the same
        # instance do not accumulate and spuriously trip the complexity cap.
        self.node_count = 0
        parsed = ast.parse(self.expression, mode="eval")
        return self._visit(parsed.body)

    def _visit(self, node: ast.AST) -> Any:
        """Recursively evaluate one AST node against the whitelists."""
        self.node_count += 1
        if self.node_count > 128:
            raise ValueError("expression too complex")

        if isinstance(node, ast.Constant):
            if isinstance(node.value, (int, float)):
                # NOTE(review): bool is an int subclass, so True/False are
                # accepted here as 1/0 — confirm this is intended.
                return node.value
            raise ValueError(f"unsupported constant: {node.value!r}")
        if isinstance(node, ast.Num):  # pragma: no cover - py<3.8 compatibility
            return node.n
        if isinstance(node, ast.BinOp):
            op = ALLOWED_BINOPS.get(type(node.op))
            if op is None:
                raise ValueError(f"unsupported operator: {type(node.op).__name__}")
            return op(self._visit(node.left), self._visit(node.right))
        if isinstance(node, ast.UnaryOp):
            op = ALLOWED_UNARYOPS.get(type(node.op))
            if op is None:
                raise ValueError(f"unsupported unary operator: {type(node.op).__name__}")
            return op(self._visit(node.operand))
        if isinstance(node, ast.Name):
            if node.id not in ALLOWED_NAMES:
                raise ValueError(f"unknown symbol: {node.id}")
            return ALLOWED_NAMES[node.id]
        if isinstance(node, ast.Call):
            # Only f(...) with a bare name; attribute calls (math.sqrt) are rejected.
            if not isinstance(node.func, ast.Name):
                raise ValueError("only direct function calls are allowed")
            fn = ALLOWED_FUNCTIONS.get(node.func.id)
            if fn is None:
                raise ValueError(f"unsupported function: {node.func.id}")
            if node.keywords:
                raise ValueError("keyword arguments are not supported")
            args = [self._visit(arg) for arg in node.args]
            return fn(*args)
        raise ValueError(f"unsupported expression node: {type(node).__name__}")
|
||||
|
||||
|
||||
def _normalize_numeric_output(value: Any) -> Any:
|
||||
if isinstance(value, float):
|
||||
if not math.isfinite(value):
|
||||
raise ValueError("result is not finite")
|
||||
return float(f"{value:.12g}")
|
||||
return value
|
||||
|
||||
|
||||
class CalculatorTool(BaseTool):
    """Deterministic calculator tool backed by the safe AST evaluator."""

    name = "calculator"

    def run(self, arguments: dict[str, Any]) -> ToolResult:
        """Evaluate `arguments["expression"]` and return the numeric result.

        Returns a failure ToolResult for a missing expression or any
        evaluation error, so direct callers get the same model-visible error
        channel that ToolRegistry.execute would otherwise synthesize.
        """
        expression = str(arguments.get("expression", "")).strip()
        if not expression:
            return ToolResult(tool_name=self.name, success=False, error="Missing expression")
        try:
            value = _SafeMathEvaluator(expression).eval()
            normalized = _normalize_numeric_output(value)
        except Exception as exc:
            # Bad syntax, disallowed constructs, math domain errors, non-finite
            # results — all become an error payload instead of propagating.
            return ToolResult(tool_name=self.name, success=False, error=str(exc))
        return ToolResult(
            tool_name=self.name,
            success=True,
            output={"expression": expression, "value": normalized},
        )
|
||||
|
||||
|
||||
@dataclass
class SearchHit:
    """A single search result: URL plus optional title/snippet text."""

    # Result URL (always present).
    url: str
    # Page title, when the backend provides one.
    title: str = ""
    # Short content excerpt, when available.
    snippet: str = ""
|
||||
|
||||
|
||||
class SearchBackend(Protocol):
    """Structural interface for pluggable web-search backends."""

    def search(self, query: str, top_k: int) -> list[SearchHit]:
        """Return up to *top_k* hits for *query*."""
        raise NotImplementedError
|
||||
|
||||
|
||||
class MockSearchBackend:
    """Offline search backend that serves canned results for known queries."""

    def __init__(self, canned_results: dict[str, list[dict[str, str]]] | None = None):
        # Falsy input (None or empty) falls back to the built-in fixtures.
        defaults = {
            "browser rendering markdown": [
                {
                    "url": "https://developers.cloudflare.com/browser-rendering/rest-api/markdown-endpoint/",
                    "title": "Cloudflare markdown endpoint",
                    "snippet": "Extract markdown from a webpage using Cloudflare Browser Rendering.",
                }
            ],
            "nanochat gpt2 speedrun": [
                {
                    "url": "https://github.com/karpathy/nanochat",
                    "title": "karpathy/nanochat",
                    "snippet": "Minimal LLM training harness with pretraining, SFT, RL, and chat UI.",
                }
            ],
        }
        self.canned_results = canned_results or defaults

    def search(self, query: str, top_k: int) -> list[SearchHit]:
        """Return canned hits for the normalized (trimmed, lowercased) query."""
        key = query.strip().lower()
        matched = self.canned_results.get(key, [])
        return [SearchHit(**entry) for entry in matched[:top_k]]
|
||||
|
||||
class TavilySearchBackend:
    """LLM-optimized web search via the Tavily REST API.

    Any network, HTTP, or decoding error is swallowed and yields an empty
    result list, so a Tavily outage degrades to "no results" instead of
    breaking generation.
    """

    def __init__(self, api_key: str | None = None, timeout: float = 15.0):
        """Use *api_key* or the TAVILY_API_KEY env var; raises ValueError if absent."""
        self.api_key = api_key or os.environ.get('TAVILY_API_KEY')
        if not self.api_key:
            raise ValueError('TavilySearchBackend requires TAVILY_API_KEY')
        self.timeout = timeout

    def search(self, query: str, top_k: int) -> list[SearchHit]:
        """POST the query to Tavily and map results to SearchHits; [] on any error."""
        # Fix: use the module-level `requests` import (top of file); the previous
        # function-local re-import was redundant.
        try:
            response = requests.post(
                'https://api.tavily.com/search',
                json={
                    'api_key': self.api_key,
                    'query': query,
                    # Tavily caps at 8 results per request on this plan tier.
                    'max_results': max(1, min(int(top_k), 8)),
                    'include_answer': False,
                    'include_raw_content': False,
                    'search_depth': 'basic',
                },
                timeout=self.timeout,
            )
            response.raise_for_status()
            data = response.json()
        except Exception:
            # Deliberate best-effort: any failure (network, HTTP, JSON) → no hits.
            return []
        return [
            SearchHit(
                url=hit.get('url', ''),
                title=hit.get('title', ''),
                snippet=hit.get('content', ''),
            )
            for hit in data.get('results', [])[:top_k]
        ]
|
||||
|
||||
|
||||
|
||||
class CloudflareBrowserRenderingClient:
    """Thin client for Cloudflare's Browser Rendering REST API.

    Fetches rendered page content (markdown, links, structured JSON) for the
    web_search tool. Retries on HTTP 429 with Retry-After-aware backoff.
    """

    def __init__(
        self,
        *,
        api_token: str | None = None,
        account_id: str | None = None,
        base_url: str = "https://api.cloudflare.com/client/v4",
        timeout: float = 30.0,
        max_retries: int = 3,
    ):
        # Credentials fall back to the standard Cloudflare environment variables.
        self.api_token = api_token or os.environ.get("CLOUDFLARE_API_TOKEN")
        self.account_id = account_id or os.environ.get("CLOUDFLARE_ACCOUNT_ID")
        if not self.api_token or not self.account_id:
            raise ValueError("Cloudflare Browser Rendering requires CLOUDFLARE_API_TOKEN and CLOUDFLARE_ACCOUNT_ID")
        self.base_url = base_url.rstrip("/")
        self.timeout = timeout
        self.max_retries = max_retries
        # One shared session so auth headers and connections persist across calls.
        self.session = requests.Session()
        self.session.headers.update(
            {
                "Authorization": f"Bearer {self.api_token}",
                "Content-Type": "application/json",
            }
        )

    def _post(self, endpoint: str, body: dict[str, Any]) -> Any:
        """POST *body* to the account-scoped endpoint and return the API "result" field.

        Retries up to max_retries times on HTTP 429, sleeping per Retry-After
        (capped at 5s, linear fallback). Raises requests.HTTPError for other
        HTTP failures and RuntimeError when the API reports success=False or
        the retry budget is exhausted.
        """
        url = f"{self.base_url}/accounts/{self.account_id}/browser-rendering/{endpoint}"
        last_error = None
        for attempt in range(1, self.max_retries + 1):
            response = self.session.post(url, json=body, timeout=self.timeout)
            if response.status_code == 429:
                retry_after = response.headers.get("Retry-After")
                # Linear backoff (attempt number in seconds) when no header given.
                sleep_seconds = float(retry_after) if retry_after else float(attempt)
                last_error = RuntimeError(f"Cloudflare Browser Rendering rate limited on {endpoint}")
                time.sleep(min(sleep_seconds, 5.0))
                continue
            response.raise_for_status()
            payload = response.json()
            if not payload.get("success", False):
                # API-level failure is not retried — break and raise below.
                errors = payload.get("errors", [])
                last_error = RuntimeError(f"Cloudflare Browser Rendering request failed: {errors}")
                break
            return payload.get("result")
        if last_error is not None:
            raise last_error
        raise RuntimeError(f"Cloudflare Browser Rendering request failed for {endpoint}")

    def markdown(self, url: str, **options: Any) -> str:
        """Render *url* and return its content as markdown.

        NOTE(review): _post returns the raw "result" field; the str annotation
        assumes the markdown endpoint always returns a string — confirm.
        """
        body = {"url": url}
        body.update(options)
        return self._post("markdown", body)

    def links(self, url: str, **options: Any) -> list[str]:
        """Render *url* and return the hyperlinks found on the page."""
        body = {"url": url}
        body.update(options)
        return self._post("links", body)

    def json_extract(self, url: str, *, prompt: str | None = None, schema: dict[str, Any] | None = None, **options: Any) -> dict[str, Any]:
        """Extract structured JSON from *url*, optionally guided by a prompt and/or schema."""
        body: dict[str, Any] = {"url": url}
        if prompt is not None:
            body["prompt"] = prompt
        if schema is not None:
            body["schema"] = schema
        body.update(options)
        return self._post("json", body)
|
||||
|
||||
|
||||
class WebSearchTool(BaseTool):
    """Web search tool: resolves a query (or explicit URLs) into result entries,
    optionally enriching each hit with fetched markdown and page links."""

    name = "web_search"

    def __init__(
        self,
        *,
        search_backend: SearchBackend | None = None,
        fetch_client: CloudflareBrowserRenderingClient | None = None,
        max_results: int = 3,
    ):
        # Maps a query string to SearchHits; None disables query-based search.
        self.search_backend = search_backend
        # Optional page-fetch client used to attach markdown/links to each hit.
        self.fetch_client = fetch_client
        # Default (and soft upper bound source) for top_k when not supplied.
        self.max_results = max_results

    def run(self, arguments: dict[str, Any]) -> ToolResult:
        """Execute a search. Arguments: query (str), top_k (int), urls (str or list).

        Explicit `urls` bypass the search backend entirely. top_k is clamped
        to [1, 8]. Per-URL fetch failures are recorded as `fetch_error` on the
        entry rather than failing the whole call.
        """
        query = str(arguments.get("query", "")).strip()
        requested_urls = arguments.get("urls") or []
        if isinstance(requested_urls, str):
            requested_urls = [requested_urls]
        # NOTE(review): int(...) raises on a non-numeric top_k; ToolRegistry.execute
        # converts that exception into a failure result.
        top_k = int(arguments.get("top_k", self.max_results) or self.max_results)
        top_k = max(1, min(top_k, 8))

        if not query and not requested_urls:
            return ToolResult(tool_name=self.name, success=False, error="Missing query or urls")

        hits: list[SearchHit]
        if requested_urls:
            # Explicit-URL mode: wrap the URLs as hits without searching.
            hits = [SearchHit(url=str(url)) for url in requested_urls[:top_k]]
        else:
            if self.search_backend is None:
                return ToolResult(
                    tool_name=self.name,
                    success=False,
                    error="No search backend configured. Cloudflare Browser Rendering can fetch pages but does not provide public web search by itself.",
                )
            hits = self.search_backend.search(query, top_k)

        results = []
        for hit in hits[:top_k]:
            entry: dict[str, Any] = {"url": hit.url}
            if hit.title:
                entry["title"] = hit.title
            if hit.snippet:
                entry["snippet"] = hit.snippet
            if self.fetch_client is not None:
                # Best-effort enrichment; failures are surfaced per-entry.
                try:
                    markdown = self.fetch_client.markdown(hit.url)
                    links = self.fetch_client.links(hit.url)
                    entry["markdown"] = markdown[:4000]
                    entry["links"] = links[:10]
                except Exception as exc:
                    entry["fetch_error"] = str(exc)
            results.append(entry)

        return ToolResult(
            tool_name=self.name,
            success=True,
            output={"query": query, "results": results},
            metadata={
                "search_backend": type(self.search_backend).__name__ if self.search_backend is not None else None,
                "fetch_backend": type(self.fetch_client).__name__ if self.fetch_client is not None else None,
                "num_results": len(results),
            },
        )
|
||||
|
||||
|
||||
def build_default_tool_registry(
    *,
    cloudflare_token: str | None = None,
    cloudflare_account_id: str | None = None,
    search_backend: SearchBackend | None = None,
) -> ToolRegistry:
    """Assemble the standard registry: calculator plus web_search.

    Page fetching is enabled when Cloudflare credentials are available; the
    search backend prefers Tavily (TAVILY_API_KEY) and otherwise falls back
    to canned mock results.
    """
    fetch_client = None
    if cloudflare_token or os.environ.get("CLOUDFLARE_API_TOKEN"):
        try:
            fetch_client = CloudflareBrowserRenderingClient(
                api_token=cloudflare_token,
                account_id=cloudflare_account_id,
            )
        except Exception:
            # Missing account id or other setup problem → run without page fetch.
            fetch_client = None
    if search_backend is None:
        if os.environ.get('TAVILY_API_KEY'):
            try:
                search_backend = TavilySearchBackend()
            except Exception:
                search_backend = None
        if search_backend is None:
            search_backend = MockSearchBackend()
    tools = [
        CalculatorTool(),
        WebSearchTool(search_backend=search_backend, fetch_client=fetch_client),
    ]
    return ToolRegistry(tools)
|
||||
115
modal/serve.py
115
modal/serve.py
|
|
@ -20,14 +20,15 @@ import modal
|
|||
# Configuration
|
||||
# ---------------------------------------------------------------------------
|
||||
MODEL_REPO = "ManmohanSharma/nanochat-d24"
|
||||
MODEL_PT = "chatsft_checkpoints/d24/model_000484.pt"
|
||||
META_JSON = "chatsft_checkpoints/d24/meta_000484.json"
|
||||
MODEL_PT = "chatsft_checkpoints/d24-sft-r6/model_000754.pt"
|
||||
META_JSON = "chatsft_checkpoints/d24-sft-r6/meta_000754.json"
|
||||
TOKENIZER_PKL = "tokenizer/tokenizer.pkl"
|
||||
TOKEN_BYTES = "tokenizer/token_bytes.pt"
|
||||
MODEL_TAG = "d24-sft"
|
||||
MODEL_TAG = "d24-sft-r6"
|
||||
GPU_TYPE = "L4" # 24 GB VRAM — fits 4 GB bf16 model loaded as fp32
|
||||
VOLUME_NAME = "samosachaat-weights"
|
||||
HF_SECRET_NAME = "huggingface" # Modal secret containing HF_TOKEN
|
||||
TAVILY_SECRET_NAME = "tavily" # Modal secret containing TAVILY_API_KEY
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Modal app + image
|
||||
|
|
@ -42,12 +43,14 @@ inference_image = (
|
|||
"tiktoken>=0.11.0",
|
||||
"tokenizers>=0.22.0",
|
||||
"huggingface_hub>=0.25.0",
|
||||
"requests>=2.31.0",
|
||||
"fastapi>=0.115.0",
|
||||
"uvicorn>=0.30.0",
|
||||
extra_index_url="https://download.pytorch.org/whl/cu124",
|
||||
)
|
||||
.add_local_file("modal/_model.py", "/root/_model.py")
|
||||
.add_local_file("modal/_tokenizer.py", "/root/_tokenizer.py")
|
||||
.add_local_file("modal/_tools.py", "/root/_tools.py")
|
||||
)
|
||||
|
||||
# Persistent volume for model weights
|
||||
|
|
@ -104,6 +107,7 @@ def download_weights():
|
|||
scaledown_window=300, # keep warm for 5 min after last request
|
||||
# concurrency handled by @modal.concurrent below
|
||||
timeout=120,
|
||||
secrets=[modal.Secret.from_name(TAVILY_SECRET_NAME)],
|
||||
)
|
||||
class Inference:
|
||||
model: object
|
||||
|
|
@ -190,8 +194,22 @@ class Inference:
|
|||
self.assistant_end_id = self.tokenizer.encode_special("<|assistant_end|>")[0]
|
||||
print(f" Special token IDs: {sorted(self.special_token_ids)}")
|
||||
|
||||
# Initialize tool registry (Tavily web_search + calculator)
|
||||
import sys as _sys
|
||||
if '/root' not in _sys.path: _sys.path.insert(0, '/root')
|
||||
from _tools import build_default_tool_registry, parse_tool_call_payload
|
||||
self.tool_registry = build_default_tool_registry()
|
||||
self._parse_tool_call = parse_tool_call_payload
|
||||
# Marker tokens for tool state machine
|
||||
self.python_start_id = self.tokenizer.encode_special("<|python_start|>")[0]
|
||||
self.python_end_id = self.tokenizer.encode_special("<|python_end|>")[0]
|
||||
self.output_start_id = self.tokenizer.encode_special("<|output_start|>")[0]
|
||||
self.output_end_id = self.tokenizer.encode_special("<|output_end|>")[0]
|
||||
# Stop tokens (exclude tool markers so generation continues through tool calls)
|
||||
self._stop_token_ids = {self.assistant_end_id, self.tokenizer.get_bos_token_id() if hasattr(self.tokenizer, "get_bos_token_id") else self.tokenizer.encode_special("<|bos|>")[0]}
|
||||
|
||||
dt = time.time() - t0
|
||||
print(f"Model loaded in {dt:.1f}s on {device}")
|
||||
print(f"Model loaded in {dt:.1f}s on {device} | tools: {[t for t in self.tool_registry._tools.keys()] if hasattr(self.tool_registry, '_tools') else 'registered'}")
|
||||
|
||||
@modal.fastapi_endpoint(method="POST", docs=True)
|
||||
async def generate(self, request: dict):
|
||||
|
|
@ -236,49 +254,74 @@ class Inference:
|
|||
tokens = tokens[-max_context:]
|
||||
|
||||
async def stream():
|
||||
from collections import deque
|
||||
input_ids = torch.tensor([tokens], dtype=torch.long, device=self.device)
|
||||
forced = deque()
|
||||
in_tool = False
|
||||
tool_payload_ids = []
|
||||
|
||||
def _append_token(tid):
|
||||
nonlocal input_ids
|
||||
nt = torch.tensor([[tid]], dtype=torch.long, device=self.device)
|
||||
input_ids = torch.cat([input_ids, nt], dim=1)
|
||||
if input_ids.size(1) > self.config.sequence_len:
|
||||
input_ids = input_ids[:, -self.config.sequence_len:]
|
||||
|
||||
with torch.no_grad():
|
||||
generated = []
|
||||
for _ in range(max_tokens):
|
||||
# Forward pass
|
||||
logits = self.model(input_ids)
|
||||
next_logits = logits[:, -1, :]
|
||||
num_generated = 0
|
||||
while num_generated < max_tokens:
|
||||
if forced:
|
||||
token_id = forced.popleft()
|
||||
else:
|
||||
logits = self.model(input_ids)
|
||||
next_logits = logits[:, -1, :]
|
||||
if temperature > 0:
|
||||
next_logits = next_logits / temperature
|
||||
if top_k > 0:
|
||||
v, _ = torch.topk(next_logits, min(top_k, next_logits.size(-1)))
|
||||
next_logits[next_logits < v[:, [-1]]] = float('-inf')
|
||||
probs = torch.softmax(next_logits, dim=-1)
|
||||
next_token = torch.multinomial(probs, num_samples=1)
|
||||
token_id = next_token.item()
|
||||
|
||||
# Temperature
|
||||
if temperature > 0:
|
||||
next_logits = next_logits / temperature
|
||||
# Tool state machine: detect <|python_start|>...<|python_end|>,
|
||||
# execute tool, inject <|output_start|>...<|output_end|> as forced tokens
|
||||
if token_id == self.python_start_id:
|
||||
in_tool = True
|
||||
tool_payload_ids = []
|
||||
elif token_id == self.python_end_id and in_tool:
|
||||
in_tool = False
|
||||
if tool_payload_ids:
|
||||
try:
|
||||
payload_text = self.tokenizer.decode(tool_payload_ids)
|
||||
invocation = self._parse_tool_call(payload_text)
|
||||
result = self.tool_registry.execute(invocation.tool_name, invocation.arguments)
|
||||
result_text = result.to_payload()[:4096]
|
||||
except Exception as exc:
|
||||
result_text = json.dumps({"error": str(exc)[:500]})
|
||||
if result_text:
|
||||
forced.append(self.output_start_id)
|
||||
forced.extend(self.tokenizer.encode(result_text))
|
||||
forced.append(self.output_end_id)
|
||||
tool_payload_ids = []
|
||||
elif in_tool:
|
||||
tool_payload_ids.append(token_id)
|
||||
|
||||
# Top-k filtering
|
||||
if top_k > 0:
|
||||
v, _ = torch.topk(next_logits, min(top_k, next_logits.size(-1)))
|
||||
next_logits[next_logits < v[:, [-1]]] = float('-inf')
|
||||
|
||||
# Sample
|
||||
probs = torch.softmax(next_logits, dim=-1)
|
||||
next_token = torch.multinomial(probs, num_samples=1)
|
||||
|
||||
token_id = next_token.item()
|
||||
|
||||
# Stop on any special token (assistant_end, bos, etc.)
|
||||
if token_id in self.special_token_ids:
|
||||
# Stop only on assistant_end or bos (NOT on tool markers)
|
||||
if token_id in self._stop_token_ids:
|
||||
break
|
||||
|
||||
# Decode and yield (skip tokens that can't be decoded)
|
||||
# Decode + stream to client (includes tool markers; UI renders)
|
||||
try:
|
||||
token_text = self.tokenizer.decode([token_id])
|
||||
except (KeyError, Exception):
|
||||
continue
|
||||
yield f"data: {json.dumps({'token': token_text, 'gpu': 0})}\n\n"
|
||||
yield "data: " + json.dumps({"token": token_text, "gpu": 0}) + "\n\n"
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Append for next iteration
|
||||
input_ids = torch.cat([input_ids, next_token], dim=1)
|
||||
_append_token(token_id)
|
||||
num_generated += 1
|
||||
|
||||
# Truncate if exceeding sequence length
|
||||
if input_ids.size(1) > self.config.sequence_len:
|
||||
input_ids = input_ids[:, -self.config.sequence_len:]
|
||||
|
||||
yield f"data: {json.dumps({'done': True})}\n\n"
|
||||
yield "data: " + json.dumps({"done": True}) + "\n\n"
|
||||
|
||||
return StreamingResponse(
|
||||
stream(),
|
||||
|
|
|
|||
|
|
@ -348,6 +348,43 @@ class MockSearchBackend:
|
|||
rows = self.canned_results.get(normalized, [])
|
||||
return [SearchHit(**row) for row in rows[:top_k]]
|
||||
|
||||
class TavilySearchBackend:
|
||||
"""LLM-optimized web search via Tavily. Falls back silently on errors."""
|
||||
def __init__(self, api_key: str | None = None, timeout: float = 15.0):
|
||||
self.api_key = api_key or os.environ.get('TAVILY_API_KEY')
|
||||
if not self.api_key:
|
||||
raise ValueError('TavilySearchBackend requires TAVILY_API_KEY')
|
||||
self.timeout = timeout
|
||||
|
||||
def search(self, query: str, top_k: int) -> list[SearchHit]:
|
||||
import requests
|
||||
try:
|
||||
r = requests.post(
|
||||
'https://api.tavily.com/search',
|
||||
json={
|
||||
'api_key': self.api_key,
|
||||
'query': query,
|
||||
'max_results': max(1, min(int(top_k), 8)),
|
||||
'include_answer': False,
|
||||
'include_raw_content': False,
|
||||
'search_depth': 'basic',
|
||||
},
|
||||
timeout=self.timeout,
|
||||
)
|
||||
r.raise_for_status()
|
||||
data = r.json()
|
||||
except Exception:
|
||||
return []
|
||||
return [
|
||||
SearchHit(
|
||||
url=h.get('url', ''),
|
||||
title=h.get('title', ''),
|
||||
snippet=h.get('content', ''),
|
||||
)
|
||||
for h in data.get('results', [])[:top_k]
|
||||
]
|
||||
|
||||
|
||||
|
||||
class CloudflareBrowserRenderingClient:
|
||||
def __init__(
|
||||
|
|
@ -497,11 +534,19 @@ def build_default_tool_registry(
|
|||
)
|
||||
except Exception:
|
||||
fetch_client = None
|
||||
if search_backend is None:
|
||||
if os.environ.get('TAVILY_API_KEY'):
|
||||
try:
|
||||
search_backend = TavilySearchBackend()
|
||||
except Exception:
|
||||
search_backend = MockSearchBackend()
|
||||
else:
|
||||
search_backend = MockSearchBackend()
|
||||
registry = ToolRegistry(
|
||||
[
|
||||
CalculatorTool(),
|
||||
WebSearchTool(
|
||||
search_backend=search_backend or MockSearchBackend(),
|
||||
search_backend=search_backend,
|
||||
fetch_client=fetch_client,
|
||||
),
|
||||
]
|
||||
|
|
|
|||
103
nanochat/ui.html
103
nanochat/ui.html
|
|
@ -459,6 +459,52 @@
|
|||
.illust-right svg.kettle-svg { width: 70px; }
|
||||
.explore-tag, .chai-label { font-size: 0.85rem; }
|
||||
}
|
||||
|
||||
/* --- Reasoning mode --- */
|
||||
.think-toggle {
|
||||
background: transparent;
|
||||
border: 1px solid rgba(255,255,255,0.15);
|
||||
color: #b8a88a;
|
||||
cursor: pointer;
|
||||
padding: 6px 10px;
|
||||
border-radius: 6px;
|
||||
margin-right: 8px;
|
||||
font-size: 0.85rem;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 4px;
|
||||
transition: all 0.2s;
|
||||
}
|
||||
.think-toggle:hover { background: rgba(255,255,255,0.05); }
|
||||
.think-toggle.active {
|
||||
background: rgba(184,168,138,0.15);
|
||||
border-color: #b8a88a;
|
||||
color: #fff;
|
||||
}
|
||||
.think-block {
|
||||
background: rgba(100,100,100,0.08);
|
||||
border-left: 3px solid #777;
|
||||
padding: 10px 14px;
|
||||
margin-bottom: 10px;
|
||||
font-style: italic;
|
||||
color: #999;
|
||||
font-size: 0.88em;
|
||||
white-space: pre-wrap;
|
||||
border-radius: 4px;
|
||||
}
|
||||
.think-block::before {
|
||||
content: "\1F4AD" " thinking";
|
||||
display: block;
|
||||
font-weight: 600;
|
||||
margin-bottom: 6px;
|
||||
color: #888;
|
||||
font-style: normal;
|
||||
font-size: 0.85em;
|
||||
letter-spacing: 0.05em;
|
||||
text-transform: uppercase;
|
||||
}
|
||||
.answer-block { white-space: pre-wrap; }
|
||||
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
|
|
@ -609,6 +655,10 @@
|
|||
<div class="input-container landing-mode" id="inputContainer">
|
||||
<div class="input-wrapper">
|
||||
<textarea id="chatInput" class="chat-input" placeholder="Ask samosaChaat anything..." rows="1" onkeydown="handleKeyDown(event)"></textarea>
|
||||
<button id="thinkToggle" class="think-toggle" onclick="toggleReasoning()" title="Reasoning mode (think step-by-step)" type="button">
|
||||
<svg width="14" height="14" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M9.5 2A2.5 2.5 0 0 1 12 4.5v15a2.5 2.5 0 0 1-4.96.44 2.5 2.5 0 0 1-2.96-3.08 3 3 0 0 1-.34-5.58 2.5 2.5 0 0 1 1.32-4.24 2.5 2.5 0 0 1 1.98-3A2.5 2.5 0 0 1 9.5 2Z"></path><path d="M14.5 2A2.5 2.5 0 0 0 12 4.5v15a2.5 2.5 0 0 0 4.96.44 2.5 2.5 0 0 0 2.96-3.08 3 3 0 0 0 .34-5.58 2.5 2.5 0 0 0-1.32-4.24 2.5 2.5 0 0 0-1.98-3A2.5 2.5 0 0 0 14.5 2Z"></path></svg>
|
||||
<span>Think</span>
|
||||
</button>
|
||||
<button id="sendButton" class="send-button" onclick="sendMessage()" disabled>
|
||||
<svg width="22" height="22" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
|
||||
<path d="M22 2L11 13"></path><path d="M22 2l-7 20-4-9-9-4 20-7z"></path>
|
||||
|
|
@ -639,6 +689,53 @@
|
|||
let isChatMode = false;
|
||||
let currentTemperature = 0.8;
|
||||
let currentTopK = 50;
|
||||
let reasoningMode = false;
|
||||
|
||||
const SYS_DIRECT = "You are samosaChaat, a helpful AI assistant. Answer directly and concisely.";
|
||||
const SYS_THINK = "You are samosaChaat, a helpful AI assistant. Think step by step inside <think>...</think> tags, then give your final answer.";
|
||||
|
||||
function toggleReasoning() {
|
||||
reasoningMode = !reasoningMode;
|
||||
const btn = document.getElementById("thinkToggle");
|
||||
if (btn) btn.classList.toggle("active", reasoningMode);
|
||||
}
|
||||
|
||||
function buildApiMessages() {
|
||||
const out = messages.map(m => ({ role: m.role, content: m.content }));
|
||||
if (out.length && out[0].role === "user") {
|
||||
const sys = reasoningMode ? SYS_THINK : SYS_DIRECT;
|
||||
out[0].content = sys + "
|
||||
|
||||
" + out[0].content;
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
function renderAssistantContent(fullText, container) {
|
||||
// Parse <think>...</think> blocks and render specially
|
||||
const openIdx = fullText.indexOf("<think>");
|
||||
if (openIdx === -1) { container.textContent = fullText; return; }
|
||||
const closeIdx = fullText.indexOf("</think>", openIdx);
|
||||
container.innerHTML = "";
|
||||
if (openIdx > 0) {
|
||||
const pre = document.createElement("div");
|
||||
pre.className = "answer-block";
|
||||
pre.textContent = fullText.slice(0, openIdx);
|
||||
container.appendChild(pre);
|
||||
}
|
||||
const thinkText = closeIdx >= 0 ? fullText.slice(openIdx+7, closeIdx) : fullText.slice(openIdx+7);
|
||||
const thinkDiv = document.createElement("div");
|
||||
thinkDiv.className = "think-block";
|
||||
thinkDiv.textContent = thinkText;
|
||||
container.appendChild(thinkDiv);
|
||||
if (closeIdx >= 0) {
|
||||
const after = fullText.slice(closeIdx+8);
|
||||
const ansDiv = document.createElement("div");
|
||||
ansDiv.className = "answer-block";
|
||||
ansDiv.textContent = after;
|
||||
container.appendChild(ansDiv);
|
||||
}
|
||||
}
|
||||
|
||||
// ================================================================
|
||||
// TRANSITION: Landing → Chat
|
||||
|
|
@ -776,7 +873,7 @@
|
|||
assistantContent.textContent = '';
|
||||
for await (const token of window.samosaChaat.generateLocal(messages)) {
|
||||
fullResponse += token;
|
||||
assistantContent.textContent = fullResponse;
|
||||
renderAssistantContent(fullResponse, assistantContent);
|
||||
chatContainer.scrollTop = chatContainer.scrollHeight;
|
||||
}
|
||||
const idx = messages.length;
|
||||
|
|
@ -788,7 +885,7 @@
|
|||
const response = await fetch(`${API_URL}/chat/completions`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({ messages, temperature: currentTemperature, top_k: currentTopK, max_tokens: 512 }),
|
||||
body: JSON.stringify({ messages: buildApiMessages(), temperature: currentTemperature, top_k: currentTopK, max_tokens: 512 }),
|
||||
});
|
||||
if (!response.ok) throw new Error(`HTTP error! status: ${response.status}`);
|
||||
const reader = response.body.getReader();
|
||||
|
|
@ -802,7 +899,7 @@
|
|||
if (line.startsWith('data: ')) {
|
||||
try {
|
||||
const data = JSON.parse(line.slice(6));
|
||||
if (data.token) { fullResponse += data.token; assistantContent.textContent = fullResponse; chatContainer.scrollTop = chatContainer.scrollHeight; }
|
||||
if (data.token) { fullResponse += data.token; renderAssistantContent(fullResponse, assistantContent); chatContainer.scrollTop = chatContainer.scrollHeight; }
|
||||
} catch (e) {}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -27,12 +27,44 @@ class SendMessageRequest(BaseModel):
|
|||
temperature: float | None = Field(default=None, ge=0.0, le=2.0)
|
||||
max_tokens: int | None = Field(default=None, ge=1, le=4096)
|
||||
top_k: int | None = Field(default=None, ge=0, le=200)
|
||||
thinking_mode: bool = Field(default=False)
|
||||
|
||||
|
||||
class RegenerateRequest(BaseModel):
|
||||
temperature: float | None = Field(default=None, ge=0.0, le=2.0)
|
||||
max_tokens: int | None = Field(default=None, ge=1, le=4096)
|
||||
top_k: int | None = Field(default=None, ge=0, le=200)
|
||||
thinking_mode: bool = Field(default=False)
|
||||
|
||||
|
||||
# System prompts: tools are always implicitly available via the model's SFT training.
|
||||
# The toggle only affects whether the model is nudged into <think>...</think> mode.
|
||||
_SYS_DEFAULT = (
|
||||
"You are samosaChaat, a helpful AI assistant created by Manmohan Sharma. "
|
||||
"You have access to tools: use web_search for facts that may change over time or "
|
||||
"require current information, and use calculator for arithmetic. Otherwise answer directly and concisely."
|
||||
)
|
||||
_SYS_THINK = (
|
||||
"You are samosaChaat, a helpful AI assistant created by Manmohan Sharma. "
|
||||
"You have access to tools: use web_search for facts that may change over time or "
|
||||
"require current information, and use calculator for arithmetic. "
|
||||
"Think step by step inside <think>...</think> tags, then give your final answer after the closing tag."
|
||||
)
|
||||
|
||||
|
||||
def _inject_system_prompt(history: list[dict[str, str]], thinking_mode: bool) -> list[dict[str, str]]:
|
||||
"""Merge a system prompt into the first user message. Upstream Modal serve
|
||||
ignores role='system', so we prepend the system prompt inline to the first
|
||||
user turn — mirroring nanochat's tokenizer convention."""
|
||||
if not history:
|
||||
return history
|
||||
sys_prompt = _SYS_THINK if thinking_mode else _SYS_DEFAULT
|
||||
out = [dict(m) for m in history]
|
||||
for m in out:
|
||||
if m.get("role") == "user":
|
||||
m["content"] = sys_prompt + "\n\n" + m.get("content", "")
|
||||
break
|
||||
return out
|
||||
|
||||
|
||||
def _parse_uuid(raw: str) -> uuid.UUID:
|
||||
|
|
@ -182,18 +214,23 @@ async def send_message(
|
|||
db_session, conversation_id=conv_uuid
|
||||
)
|
||||
|
||||
# Inject system prompt (direct or think mode) into the first user message,
|
||||
# since upstream Modal serve ignores role='system'.
|
||||
history_with_sys = _inject_system_prompt(history, body.thinking_mode)
|
||||
|
||||
logger.info(
|
||||
"send_message",
|
||||
conversation_id=str(conv_uuid),
|
||||
history_len=len(history),
|
||||
model_tag=model_tag,
|
||||
thinking_mode=body.thinking_mode,
|
||||
)
|
||||
|
||||
generator = _stream_and_persist(
|
||||
request=request,
|
||||
user_id=user_uuid,
|
||||
conversation_id=conv_uuid,
|
||||
history=history,
|
||||
history=history_with_sys,
|
||||
temperature=body.temperature,
|
||||
max_tokens=body.max_tokens,
|
||||
top_k=body.top_k,
|
||||
|
|
@ -235,17 +272,20 @@ async def regenerate(
|
|||
detail="conversation has no user messages to regenerate from",
|
||||
)
|
||||
|
||||
history_with_sys = _inject_system_prompt(history, body.thinking_mode)
|
||||
|
||||
logger.info(
|
||||
"regenerate_message",
|
||||
conversation_id=str(conv_uuid),
|
||||
history_len=len(history),
|
||||
thinking_mode=body.thinking_mode,
|
||||
)
|
||||
|
||||
generator = _stream_and_persist(
|
||||
request=request,
|
||||
user_id=user_uuid,
|
||||
conversation_id=conv_uuid,
|
||||
history=history,
|
||||
history=history_with_sys,
|
||||
temperature=body.temperature,
|
||||
max_tokens=body.max_tokens,
|
||||
top_k=body.top_k,
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
'use client';
|
||||
|
||||
import { useEffect, useRef } from 'react';
|
||||
import { ArrowUp, Square } from 'lucide-react';
|
||||
import { ArrowUp, Brain, Square } from 'lucide-react';
|
||||
import clsx from 'clsx';
|
||||
|
||||
interface Props {
|
||||
|
|
@ -11,9 +11,11 @@ interface Props {
|
|||
onStop?: () => void;
|
||||
isStreaming?: boolean;
|
||||
disabled?: boolean;
|
||||
thinkingMode?: boolean;
|
||||
onToggleThinking?: () => void;
|
||||
}
|
||||
|
||||
export default function ChatInput({ value, onChange, onSubmit, onStop, isStreaming, disabled }: Props) {
|
||||
export default function ChatInput({ value, onChange, onSubmit, onStop, isStreaming, disabled, thinkingMode, onToggleThinking }: Props) {
|
||||
const ref = useRef<HTMLTextAreaElement>(null);
|
||||
|
||||
useEffect(() => {
|
||||
|
|
@ -58,6 +60,27 @@ export default function ChatInput({ value, onChange, onSubmit, onStop, isStreami
|
|||
className="flex-1 resize-none bg-transparent px-5 py-4 pr-2 text-[0.95rem] leading-relaxed text-gray-900 dark:text-ink-text placeholder-gray-400 dark:placeholder-ink-text-soft focus:outline-none min-h-[52px] max-h-[200px]"
|
||||
/>
|
||||
|
||||
{/* Think toggle */}
|
||||
{onToggleThinking && (
|
||||
<div className="self-end p-2">
|
||||
<button
|
||||
type="button"
|
||||
onClick={onToggleThinking}
|
||||
aria-pressed={!!thinkingMode}
|
||||
title={thinkingMode ? 'Reasoning mode ON — model will think step-by-step' : 'Enable reasoning mode'}
|
||||
className={clsx(
|
||||
'h-10 px-3 rounded-full flex items-center gap-1.5 text-xs font-medium transition-all border',
|
||||
thinkingMode
|
||||
? 'bg-saffron/15 dark:bg-saffron/20 border-saffron/40 dark:border-saffron/50 text-saffron dark:text-saffron-soft shadow-[0_4px_14px_rgba(255,153,51,0.15)]'
|
||||
: 'bg-transparent border-cream-border dark:border-ink-border text-gray-500 dark:text-ink-text-soft hover:bg-gray-50 dark:hover:bg-ink-elev',
|
||||
)}
|
||||
>
|
||||
<Brain size={14} />
|
||||
<span>Think</span>
|
||||
</button>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Send / stop button — vertically centered with the textarea baseline */}
|
||||
<div className="self-end p-2">
|
||||
{isStreaming && onStop ? (
|
||||
|
|
|
|||
|
|
@ -31,6 +31,7 @@ export default function ChatWindow() {
|
|||
|
||||
const [draft, setDraft] = useState('');
|
||||
const [streamingMsgId, setStreamingMsgId] = useState<string | null>(null);
|
||||
const [thinkingMode, setThinkingMode] = useState(false);
|
||||
const streamingBufferRef = useRef('');
|
||||
const scrollRef = useRef<HTMLDivElement>(null);
|
||||
|
||||
|
|
@ -60,7 +61,7 @@ export default function ChatWindow() {
|
|||
setIsStreaming(false);
|
||||
}, []);
|
||||
|
||||
const streamFromApi = useCallback(async (convId: string, assistantMsgId: string, content: string, temp?: number, topk?: number) => {
|
||||
const streamFromApi = useCallback(async (convId: string, assistantMsgId: string, content: string, temp?: number, topk?: number, thinking?: boolean) => {
|
||||
stop();
|
||||
const ac = new AbortController();
|
||||
abortRef.current = ac;
|
||||
|
|
@ -76,7 +77,7 @@ export default function ChatWindow() {
|
|||
const res = await fetch(`/api/conversations/${convId}/messages`, {
|
||||
method: 'POST',
|
||||
headers,
|
||||
body: JSON.stringify({ content, temperature: temp, max_tokens: 512, top_k: topk }),
|
||||
body: JSON.stringify({ content, temperature: temp, max_tokens: 512, top_k: topk, thinking_mode: !!thinking }),
|
||||
signal: ac.signal,
|
||||
});
|
||||
|
||||
|
|
@ -172,7 +173,7 @@ export default function ChatWindow() {
|
|||
setStreamingMsgId(assistantId);
|
||||
streamingBufferRef.current = '';
|
||||
|
||||
await streamFromApi(convId, assistantId, text, temperature, topK);
|
||||
await streamFromApi(convId, assistantId, text, temperature, topK, thinkingMode);
|
||||
},
|
||||
[
|
||||
draft,
|
||||
|
|
@ -180,13 +181,13 @@ export default function ChatWindow() {
|
|||
ensureConversation,
|
||||
temperature,
|
||||
topK,
|
||||
thinkingMode,
|
||||
appendMessage,
|
||||
streamFromApi,
|
||||
setTemperature,
|
||||
setTopK,
|
||||
createConversation,
|
||||
newConversation,
|
||||
// streamFromApi in deps via earlier line
|
||||
],
|
||||
);
|
||||
|
||||
|
|
@ -238,6 +239,8 @@ export default function ChatWindow() {
|
|||
onSubmit={() => handleSend()}
|
||||
onStop={stop}
|
||||
isStreaming={isStreaming}
|
||||
thinkingMode={thinkingMode}
|
||||
onToggleThinking={() => setThinkingMode((v) => !v)}
|
||||
/>
|
||||
</section>
|
||||
);
|
||||
|
|
|
|||
|
|
@ -5,11 +5,114 @@ import ReactMarkdown from 'react-markdown';
|
|||
import remarkGfm from 'remark-gfm';
|
||||
import rehypeHighlight from 'rehype-highlight';
|
||||
import 'highlight.js/styles/github-dark.css';
|
||||
import { Check, Copy } from 'lucide-react';
|
||||
import { Check, ChevronDown, ChevronRight, Copy, Search, Calculator, Sparkles } from 'lucide-react';
|
||||
import clsx from 'clsx';
|
||||
import type { Message } from '@/types/chat';
|
||||
import SteamTyping from '@/components/svg/SteamTyping';
|
||||
|
||||
// ---- Content parser: split into text / think / tool_call / tool_result segments ----
|
||||
type Segment =
|
||||
| { kind: 'text'; content: string }
|
||||
| { kind: 'think'; content: string; closed: boolean }
|
||||
| { kind: 'tool_call'; content: string; closed: boolean }
|
||||
| { kind: 'tool_result'; content: string; closed: boolean };
|
||||
|
||||
function parseSegments(raw: string): Segment[] {
|
||||
const segs: Segment[] = [];
|
||||
let i = 0;
|
||||
// marker -> [openTag, closeTag, kind]
|
||||
const markers: Array<[string, string, Segment['kind']]> = [
|
||||
['<think>', '</think>', 'think'],
|
||||
['<|python_start|>', '<|python_end|>', 'tool_call'],
|
||||
['<|output_start|>', '<|output_end|>', 'tool_result'],
|
||||
];
|
||||
while (i < raw.length) {
|
||||
// find the nearest opening marker
|
||||
let bestOpen = -1;
|
||||
let bestMarker: [string, string, Segment['kind']] | null = null;
|
||||
for (const m of markers) {
|
||||
const p = raw.indexOf(m[0], i);
|
||||
if (p !== -1 && (bestOpen === -1 || p < bestOpen)) { bestOpen = p; bestMarker = m; }
|
||||
}
|
||||
if (bestOpen === -1) {
|
||||
if (i < raw.length) segs.push({ kind: 'text', content: raw.slice(i) });
|
||||
break;
|
||||
}
|
||||
if (bestOpen > i) segs.push({ kind: 'text', content: raw.slice(i, bestOpen) });
|
||||
const [openTag, closeTag, kind] = bestMarker!;
|
||||
const afterOpen = bestOpen + openTag.length;
|
||||
const closeIdx = raw.indexOf(closeTag, afterOpen);
|
||||
if (closeIdx === -1) {
|
||||
segs.push({ kind, content: raw.slice(afterOpen), closed: false });
|
||||
i = raw.length;
|
||||
} else {
|
||||
segs.push({ kind, content: raw.slice(afterOpen, closeIdx), closed: true });
|
||||
i = closeIdx + closeTag.length;
|
||||
}
|
||||
}
|
||||
return segs;
|
||||
}
|
||||
|
||||
function ThinkBlock({ content, closed }: { content: string; closed: boolean }) {
|
||||
const [open, setOpen] = useState(true);
|
||||
return (
|
||||
<div className="my-3 rounded-lg border border-gray-200 dark:border-ink-border bg-gray-50/60 dark:bg-ink-soft/60">
|
||||
<button type="button" onClick={() => setOpen(!open)} className="w-full flex items-center gap-2 px-3 py-2 text-xs uppercase tracking-wider text-gray-500 dark:text-ink-text-soft hover:bg-gray-100 dark:hover:bg-ink-elev/50">
|
||||
{open ? <ChevronDown size={14} /> : <ChevronRight size={14} />}
|
||||
<Sparkles size={12} />
|
||||
<span>Thinking{closed ? '' : '…'}</span>
|
||||
</button>
|
||||
{open && (
|
||||
<div className="px-4 py-3 text-sm text-gray-600 dark:text-ink-text-soft whitespace-pre-wrap italic leading-relaxed border-t border-gray-200 dark:border-ink-border">
|
||||
{content}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
function ToolCallBlock({ content, closed }: { content: string; closed: boolean }) {
|
||||
let parsed: { tool?: string; arguments?: Record<string, unknown> } | null = null;
|
||||
try { parsed = JSON.parse(content); } catch { /* streaming — partial JSON */ }
|
||||
const toolName = parsed?.tool ?? 'tool';
|
||||
const icon = toolName === 'web_search' ? <Search size={12} /> : toolName === 'calculator' ? <Calculator size={12} /> : <Sparkles size={12} />;
|
||||
const query = parsed?.arguments ? JSON.stringify(parsed.arguments) : content;
|
||||
return (
|
||||
<div className="my-2 rounded-lg border border-saffron/30 dark:border-saffron/40 bg-saffron/5 dark:bg-saffron/10 px-3 py-2">
|
||||
<div className="flex items-center gap-2 text-xs font-medium text-saffron dark:text-saffron-soft uppercase tracking-wider">
|
||||
{icon}
|
||||
<span>Calling {toolName}{closed ? '' : '…'}</span>
|
||||
</div>
|
||||
<div className="mt-1 text-xs font-mono text-gray-600 dark:text-ink-text-soft truncate">{query}</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
function ToolResultBlock({ content, closed }: { content: string; closed: boolean }) {
|
||||
const [open, setOpen] = useState(false);
|
||||
let summary = content;
|
||||
try {
|
||||
const j = JSON.parse(content);
|
||||
if (j?.output?.results?.[0]?.snippet) summary = String(j.output.results[0].snippet).slice(0, 160);
|
||||
else if (j?.output?.value !== undefined) summary = `= ${j.output.value}`;
|
||||
else if (j?.error) summary = `error: ${j.error}`;
|
||||
} catch { /* partial */ }
|
||||
return (
|
||||
<div className="my-2 rounded-lg border border-gray-200 dark:border-ink-border bg-white/60 dark:bg-ink-elev/60">
|
||||
<button type="button" onClick={() => setOpen(!open)} className="w-full flex items-center justify-between gap-2 px-3 py-2 text-xs text-gray-600 dark:text-ink-text-soft hover:bg-gray-50 dark:hover:bg-ink-soft/50">
|
||||
<span className="flex items-center gap-2">
|
||||
{open ? <ChevronDown size={14} /> : <ChevronRight size={14} />}
|
||||
<span className="uppercase tracking-wider">Result{closed ? '' : '…'}</span>
|
||||
<span className="ml-2 truncate text-gray-500 dark:text-ink-text-soft normal-case">{summary}</span>
|
||||
</span>
|
||||
</button>
|
||||
{open && (
|
||||
<pre className="px-3 py-2 text-xs overflow-x-auto border-t border-gray-200 dark:border-ink-border">{content}</pre>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
interface Props {
|
||||
message: Message;
|
||||
isStreaming?: boolean;
|
||||
|
|
@ -91,13 +194,21 @@ export default function MessageBubble({ message, isStreaming }: Props) {
|
|||
</div>
|
||||
) : (
|
||||
<div className="markdown-body text-[0.95rem] text-gray-900 dark:text-ink-text leading-relaxed">
|
||||
<ReactMarkdown
|
||||
remarkPlugins={[remarkGfm]}
|
||||
rehypePlugins={[rehypeHighlight]}
|
||||
components={{ code: CodeBlock as never }}
|
||||
>
|
||||
{message.content}
|
||||
</ReactMarkdown>
|
||||
{parseSegments(message.content).map((seg, idx) => {
|
||||
if (seg.kind === 'think') return <ThinkBlock key={idx} content={seg.content} closed={seg.closed} />;
|
||||
if (seg.kind === 'tool_call') return <ToolCallBlock key={idx} content={seg.content} closed={seg.closed} />;
|
||||
if (seg.kind === 'tool_result') return <ToolResultBlock key={idx} content={seg.content} closed={seg.closed} />;
|
||||
return (
|
||||
<ReactMarkdown
|
||||
key={idx}
|
||||
remarkPlugins={[remarkGfm]}
|
||||
rehypePlugins={[rehypeHighlight]}
|
||||
components={{ code: CodeBlock as never }}
|
||||
>
|
||||
{seg.content}
|
||||
</ReactMarkdown>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user