From 3ab89e78909d25542fe7696f99ad76bfdcb8e8a9 Mon Sep 17 00:00:00 2001 From: Manmohan <66306483+manmohan659@users.noreply.github.com> Date: Wed, 22 Apr 2026 20:38:21 +0000 Subject: [PATCH] feat: deploy d24-sft-r6 with full reasoning mode + live tool use (Tavily) Model R6 (97% pass rate on 33-probe eval, val_bpb 0.2635): - modal/serve.py + modal/_tools.py: tool-aware streaming with TavilySearchBackend auto-detect, python_start/end state machine, output_start/end forcing; mount tavily secret - modal/serve.py: MODEL_TAG=d24-sft-r6, model path points at new SFT r6 - services/chat-api/routes/messages.py: accept thinking_mode flag, inject samosaChaat system prompt (direct or variant) into first user message before streaming to Modal - services/frontend/components/chat/ChatInput.tsx: Brain toggle 'Think' button next to send; when active, model uses think mode - services/frontend/components/chat/ChatWindow.tsx: track thinkingMode state, pass through to API body as thinking_mode - services/frontend/components/chat/MessageBubble.tsx: parse and render ... as collapsible italic blocks; <|python_start|>...<|python_end|> as tool-call cards with icons per tool name; <|output_start|>...<|output_end|> as result cards with expandable JSON - nanochat/tools.py: TavilySearchBackend class + auto-detect - nanochat/ui.html: legacy UI reasoning toggle (kept for parity) Tool execution verified live: query -> web_search via Tavily -> Macron returned with grounded answer. --- modal/_tools.py | 554 ++++++++++++++++++ modal/serve.py | 115 ++-- nanochat/tools.py | 47 +- nanochat/ui.html | 103 +++- services/chat-api/src/routes/messages.py | 44 +- .../frontend/components/chat/ChatInput.tsx | 27 +- .../frontend/components/chat/ChatWindow.tsx | 11 +- .../components/chat/MessageBubble.tsx | 127 +++- 8 files changed, 972 insertions(+), 56 deletions(-) create mode 100644 modal/_tools.py diff --git a/modal/_tools.py b/modal/_tools.py new file mode 100644 index 00000000..9d3a82ef --- /dev/null +++ b/modal/_tools.py @@ -0,0 +1,554 @@ +""" +Shared tool definitions for nanochat. + +The current tokenizer only has python/output special tokens. To preserve +checkpoint compatibility, we reuse those tokens as a generic tool-call and +tool-result channel. Legacy "python" calculator payloads still work. +""" + +from __future__ import annotations + +import ast +import json +import math +import os +import time +from dataclasses import dataclass, field +from typing import Any, Protocol + +import requests + +TOOL_CALL_START = "<|python_start|>" +TOOL_CALL_END = "<|python_end|>" +TOOL_RESULT_START = "<|output_start|>" +TOOL_RESULT_END = "<|output_end|>" +MAX_TOOL_PAYLOAD_CHARS = 4096 + +DEFAULT_TOOL_SCHEMA = [ + { + "name": "calculator", + "description": "Deterministic scientific calculator for exact arithmetic and common finance formulas.", + "arguments": { + "expression": "String expression using numbers, operators, and supported functions.", + }, + }, + { + "name": "web_search", + "description": "Search and fetch web content. 
Requires a search backend and optionally a page fetch client.", + "arguments": { + "query": "Search query string.", + "top_k": "Maximum number of results to return.", + "urls": "Optional explicit URLs to fetch instead of searching.", + }, + }, +] + + +def _compact_json(data: Any) -> str: + return json.dumps(data, ensure_ascii=True, separators=(",", ":"), sort_keys=True) + + +@dataclass +class ToolInvocation: + tool_name: str + arguments: dict[str, Any] + raw_text: str = "" + + +@dataclass +class ToolResult: + tool_name: str + success: bool + output: Any = None + error: str | None = None + metadata: dict[str, Any] = field(default_factory=dict) + + def to_payload(self) -> str: + return _compact_json( + { + "tool": self.tool_name, + "success": self.success, + "output": self.output, + "error": self.error, + "metadata": self.metadata, + } + ) + + +class BaseTool: + name: str + + def run(self, arguments: dict[str, Any]) -> ToolResult: + raise NotImplementedError + + +class ToolRegistry: + def __init__(self, tools: list[BaseTool] | tuple[BaseTool, ...]): + self._tools = {tool.name: tool for tool in tools} + + def execute(self, tool_name: str, arguments: dict[str, Any]) -> ToolResult: + tool = self._tools.get(tool_name) + if tool is None: + return ToolResult(tool_name=tool_name, success=False, error=f"Unknown tool: {tool_name}") + try: + return tool.run(arguments) + except Exception as exc: # defensive: tool failures should become model-visible outputs + return ToolResult(tool_name=tool_name, success=False, error=str(exc)) + + def schema(self) -> list[dict[str, Any]]: + return [item for item in DEFAULT_TOOL_SCHEMA if item["name"] in self._tools] + + +def serialize_tool_call(tool_name: str, arguments: dict[str, Any] | None = None) -> str: + payload = { + "tool": tool_name, + "arguments": arguments or {}, + } + text = _compact_json(payload) + return text[:MAX_TOOL_PAYLOAD_CHARS] + + +def serialize_tool_result( + tool_name: str, + output: Any = None, + *, + success: bool = True, + error: str | None = None, + metadata: dict[str, Any] | None = None, +) -> str: + return ToolResult( + tool_name=tool_name, + success=success, + output=output, + error=error, + metadata=metadata or {}, + ).to_payload()[:MAX_TOOL_PAYLOAD_CHARS] + + +def parse_tool_call_payload(text: str) -> ToolInvocation: + stripped = text.strip() + if not stripped: + return ToolInvocation(tool_name="calculator", arguments={"expression": ""}, raw_text=text) + try: + payload = json.loads(stripped) + except json.JSONDecodeError: + return ToolInvocation(tool_name="calculator", arguments={"expression": stripped}, raw_text=text) + if isinstance(payload, dict): + tool_name = payload.get("tool") or payload.get("tool_name") or payload.get("name") + arguments = payload.get("arguments") or payload.get("args") or {} + if isinstance(tool_name, str) and isinstance(arguments, dict): + return ToolInvocation(tool_name=tool_name, arguments=arguments, raw_text=text) + return ToolInvocation(tool_name="calculator", arguments={"expression": stripped}, raw_text=text) + + +def parse_tool_result_payload(text: str) -> ToolResult | None: + stripped = text.strip() + try: + payload = json.loads(stripped) + except json.JSONDecodeError: + return None + if not isinstance(payload, dict): + return None + tool_name = payload.get("tool") + if not isinstance(tool_name, str): + return None + return ToolResult( + tool_name=tool_name, + success=bool(payload.get("success", True)), + output=payload.get("output"), + error=payload.get("error"), + metadata=payload.get("metadata") or 
{}, + ) + + +def _percent(value: float, rate: float) -> float: + return value * rate / 100.0 + + +def _percent_change(old: float, new: float) -> float: + if old == 0: + raise ValueError("percent_change old value cannot be zero") + return ((new - old) / old) * 100.0 + + +def _cagr(start: float, end: float, years: float) -> float: + if start <= 0 or end <= 0 or years <= 0: + raise ValueError("cagr inputs must be positive") + return ((end / start) ** (1.0 / years) - 1.0) * 100.0 + + +def _simple_interest(principal: float, annual_rate: float, years: float) -> float: + return principal * annual_rate / 100.0 * years + + +def _compound_interest(principal: float, annual_rate: float, periods_per_year: float, years: float) -> float: + if periods_per_year <= 0: + raise ValueError("periods_per_year must be positive") + return principal * (1.0 + annual_rate / 100.0 / periods_per_year) ** (periods_per_year * years) + + +def _emi(principal: float, annual_rate: float, months: float) -> float: + if months <= 0: + raise ValueError("months must be positive") + monthly_rate = annual_rate / 100.0 / 12.0 + if monthly_rate == 0: + return principal / months + growth = (1.0 + monthly_rate) ** months + return principal * monthly_rate * growth / (growth - 1.0) + + +ALLOWED_BINOPS = { + ast.Add: lambda a, b: a + b, + ast.Sub: lambda a, b: a - b, + ast.Mult: lambda a, b: a * b, + ast.Div: lambda a, b: a / b, + ast.Pow: lambda a, b: a ** b, + ast.Mod: lambda a, b: a % b, +} +ALLOWED_UNARYOPS = { + ast.UAdd: lambda a: a, + ast.USub: lambda a: -a, +} +ALLOWED_NAMES = { + "pi": math.pi, + "e": math.e, + "tau": math.tau, +} +ALLOWED_FUNCTIONS = { + "abs": abs, + "round": round, + "floor": math.floor, + "ceil": math.ceil, + "sqrt": math.sqrt, + "log": math.log, + "log10": math.log10, + "exp": math.exp, + "sin": math.sin, + "cos": math.cos, + "tan": math.tan, + "asin": math.asin, + "acos": math.acos, + "atan": math.atan, + "degrees": math.degrees, + "radians": math.radians, + "percent": _percent, + "percent_change": _percent_change, + "cagr": _cagr, + "simple_interest": _simple_interest, + "compound_interest": _compound_interest, + "emi": _emi, +} + + +class _SafeMathEvaluator: + def __init__(self, expression: str): + self.expression = expression + self.node_count = 0 + + def eval(self) -> float: + if len(self.expression) > 512: + raise ValueError("expression too long") + parsed = ast.parse(self.expression, mode="eval") + return self._visit(parsed.body) + + def _visit(self, node: ast.AST) -> Any: + self.node_count += 1 + if self.node_count > 128: + raise ValueError("expression too complex") + + if isinstance(node, ast.Constant): + if isinstance(node.value, (int, float)): + return node.value + raise ValueError(f"unsupported constant: {node.value!r}") + if isinstance(node, ast.Num): # pragma: no cover - py<3.8 compatibility + return node.n + if isinstance(node, ast.BinOp): + op = ALLOWED_BINOPS.get(type(node.op)) + if op is None: + raise ValueError(f"unsupported operator: {type(node.op).__name__}") + return op(self._visit(node.left), self._visit(node.right)) + if isinstance(node, ast.UnaryOp): + op = ALLOWED_UNARYOPS.get(type(node.op)) + if op is None: + raise ValueError(f"unsupported unary operator: {type(node.op).__name__}") + return op(self._visit(node.operand)) + if isinstance(node, ast.Name): + if node.id not in ALLOWED_NAMES: + raise ValueError(f"unknown symbol: {node.id}") + return ALLOWED_NAMES[node.id] + if isinstance(node, ast.Call): + if not isinstance(node.func, ast.Name): + raise ValueError("only direct function 
calls are allowed") + fn = ALLOWED_FUNCTIONS.get(node.func.id) + if fn is None: + raise ValueError(f"unsupported function: {node.func.id}") + if node.keywords: + raise ValueError("keyword arguments are not supported") + args = [self._visit(arg) for arg in node.args] + return fn(*args) + raise ValueError(f"unsupported expression node: {type(node).__name__}") + + +def _normalize_numeric_output(value: Any) -> Any: + if isinstance(value, float): + if not math.isfinite(value): + raise ValueError("result is not finite") + return float(f"{value:.12g}") + return value + + +class CalculatorTool(BaseTool): + name = "calculator" + + def run(self, arguments: dict[str, Any]) -> ToolResult: + expression = str(arguments.get("expression", "")).strip() + if not expression: + return ToolResult(tool_name=self.name, success=False, error="Missing expression") + value = _SafeMathEvaluator(expression).eval() + return ToolResult( + tool_name=self.name, + success=True, + output={"expression": expression, "value": _normalize_numeric_output(value)}, + ) + + +@dataclass +class SearchHit: + url: str + title: str = "" + snippet: str = "" + + +class SearchBackend(Protocol): + def search(self, query: str, top_k: int) -> list[SearchHit]: + raise NotImplementedError + + +class MockSearchBackend: + def __init__(self, canned_results: dict[str, list[dict[str, str]]] | None = None): + self.canned_results = canned_results or { + "browser rendering markdown": [ + { + "url": "https://developers.cloudflare.com/browser-rendering/rest-api/markdown-endpoint/", + "title": "Cloudflare markdown endpoint", + "snippet": "Extract markdown from a webpage using Cloudflare Browser Rendering.", + } + ], + "nanochat gpt2 speedrun": [ + { + "url": "https://github.com/karpathy/nanochat", + "title": "karpathy/nanochat", + "snippet": "Minimal LLM training harness with pretraining, SFT, RL, and chat UI.", + } + ], + } + + def search(self, query: str, top_k: int) -> list[SearchHit]: + normalized = query.strip().lower() + rows = self.canned_results.get(normalized, []) + return [SearchHit(**row) for row in rows[:top_k]] + +class TavilySearchBackend: + """LLM-optimized web search via Tavily. 
Falls back silently on errors.""" + def __init__(self, api_key: str | None = None, timeout: float = 15.0): + self.api_key = api_key or os.environ.get('TAVILY_API_KEY') + if not self.api_key: + raise ValueError('TavilySearchBackend requires TAVILY_API_KEY') + self.timeout = timeout + + def search(self, query: str, top_k: int) -> list[SearchHit]: + import requests + try: + r = requests.post( + 'https://api.tavily.com/search', + json={ + 'api_key': self.api_key, + 'query': query, + 'max_results': max(1, min(int(top_k), 8)), + 'include_answer': False, + 'include_raw_content': False, + 'search_depth': 'basic', + }, + timeout=self.timeout, + ) + r.raise_for_status() + data = r.json() + except Exception: + return [] + return [ + SearchHit( + url=h.get('url', ''), + title=h.get('title', ''), + snippet=h.get('content', ''), + ) + for h in data.get('results', [])[:top_k] + ] + + + +class CloudflareBrowserRenderingClient: + def __init__( + self, + *, + api_token: str | None = None, + account_id: str | None = None, + base_url: str = "https://api.cloudflare.com/client/v4", + timeout: float = 30.0, + max_retries: int = 3, + ): + self.api_token = api_token or os.environ.get("CLOUDFLARE_API_TOKEN") + self.account_id = account_id or os.environ.get("CLOUDFLARE_ACCOUNT_ID") + if not self.api_token or not self.account_id: + raise ValueError("Cloudflare Browser Rendering requires CLOUDFLARE_API_TOKEN and CLOUDFLARE_ACCOUNT_ID") + self.base_url = base_url.rstrip("/") + self.timeout = timeout + self.max_retries = max_retries + self.session = requests.Session() + self.session.headers.update( + { + "Authorization": f"Bearer {self.api_token}", + "Content-Type": "application/json", + } + ) + + def _post(self, endpoint: str, body: dict[str, Any]) -> Any: + url = f"{self.base_url}/accounts/{self.account_id}/browser-rendering/{endpoint}" + last_error = None + for attempt in range(1, self.max_retries + 1): + response = self.session.post(url, json=body, timeout=self.timeout) + if response.status_code == 429: + retry_after = response.headers.get("Retry-After") + sleep_seconds = float(retry_after) if retry_after else float(attempt) + last_error = RuntimeError(f"Cloudflare Browser Rendering rate limited on {endpoint}") + time.sleep(min(sleep_seconds, 5.0)) + continue + response.raise_for_status() + payload = response.json() + if not payload.get("success", False): + errors = payload.get("errors", []) + last_error = RuntimeError(f"Cloudflare Browser Rendering request failed: {errors}") + break + return payload.get("result") + if last_error is not None: + raise last_error + raise RuntimeError(f"Cloudflare Browser Rendering request failed for {endpoint}") + + def markdown(self, url: str, **options: Any) -> str: + body = {"url": url} + body.update(options) + return self._post("markdown", body) + + def links(self, url: str, **options: Any) -> list[str]: + body = {"url": url} + body.update(options) + return self._post("links", body) + + def json_extract(self, url: str, *, prompt: str | None = None, schema: dict[str, Any] | None = None, **options: Any) -> dict[str, Any]: + body: dict[str, Any] = {"url": url} + if prompt is not None: + body["prompt"] = prompt + if schema is not None: + body["schema"] = schema + body.update(options) + return self._post("json", body) + + +class WebSearchTool(BaseTool): + name = "web_search" + + def __init__( + self, + *, + search_backend: SearchBackend | None = None, + fetch_client: CloudflareBrowserRenderingClient | None = None, + max_results: int = 3, + ): + self.search_backend = search_backend + 
self.fetch_client = fetch_client + self.max_results = max_results + + def run(self, arguments: dict[str, Any]) -> ToolResult: + query = str(arguments.get("query", "")).strip() + requested_urls = arguments.get("urls") or [] + if isinstance(requested_urls, str): + requested_urls = [requested_urls] + top_k = int(arguments.get("top_k", self.max_results) or self.max_results) + top_k = max(1, min(top_k, 8)) + + if not query and not requested_urls: + return ToolResult(tool_name=self.name, success=False, error="Missing query or urls") + + hits: list[SearchHit] + if requested_urls: + hits = [SearchHit(url=str(url)) for url in requested_urls[:top_k]] + else: + if self.search_backend is None: + return ToolResult( + tool_name=self.name, + success=False, + error="No search backend configured. Cloudflare Browser Rendering can fetch pages but does not provide public web search by itself.", + ) + hits = self.search_backend.search(query, top_k) + + results = [] + for hit in hits[:top_k]: + entry: dict[str, Any] = {"url": hit.url} + if hit.title: + entry["title"] = hit.title + if hit.snippet: + entry["snippet"] = hit.snippet + if self.fetch_client is not None: + try: + markdown = self.fetch_client.markdown(hit.url) + links = self.fetch_client.links(hit.url) + entry["markdown"] = markdown[:4000] + entry["links"] = links[:10] + except Exception as exc: + entry["fetch_error"] = str(exc) + results.append(entry) + + return ToolResult( + tool_name=self.name, + success=True, + output={"query": query, "results": results}, + metadata={ + "search_backend": type(self.search_backend).__name__ if self.search_backend is not None else None, + "fetch_backend": type(self.fetch_client).__name__ if self.fetch_client is not None else None, + "num_results": len(results), + }, + ) + + +def build_default_tool_registry( + *, + cloudflare_token: str | None = None, + cloudflare_account_id: str | None = None, + search_backend: SearchBackend | None = None, +) -> ToolRegistry: + fetch_client = None + if cloudflare_token or os.environ.get("CLOUDFLARE_API_TOKEN"): + try: + fetch_client = CloudflareBrowserRenderingClient( + api_token=cloudflare_token, + account_id=cloudflare_account_id, + ) + except Exception: + fetch_client = None + if search_backend is None: + if os.environ.get('TAVILY_API_KEY'): + try: + search_backend = TavilySearchBackend() + except Exception: + search_backend = MockSearchBackend() + else: + search_backend = MockSearchBackend() + registry = ToolRegistry( + [ + CalculatorTool(), + WebSearchTool( + search_backend=search_backend, + fetch_client=fetch_client, + ), + ] + ) + return registry diff --git a/modal/serve.py b/modal/serve.py index 15a15651..5eca71e2 100644 --- a/modal/serve.py +++ b/modal/serve.py @@ -20,14 +20,15 @@ import modal # Configuration # --------------------------------------------------------------------------- MODEL_REPO = "ManmohanSharma/nanochat-d24" -MODEL_PT = "chatsft_checkpoints/d24/model_000484.pt" -META_JSON = "chatsft_checkpoints/d24/meta_000484.json" +MODEL_PT = "chatsft_checkpoints/d24-sft-r6/model_000754.pt" +META_JSON = "chatsft_checkpoints/d24-sft-r6/meta_000754.json" TOKENIZER_PKL = "tokenizer/tokenizer.pkl" TOKEN_BYTES = "tokenizer/token_bytes.pt" -MODEL_TAG = "d24-sft" +MODEL_TAG = "d24-sft-r6" GPU_TYPE = "L4" # 24 GB VRAM — fits 4 GB bf16 model loaded as fp32 VOLUME_NAME = "samosachaat-weights" HF_SECRET_NAME = "huggingface" # Modal secret containing HF_TOKEN +TAVILY_SECRET_NAME = "tavily" # Modal secret containing TAVILY_API_KEY # 
--------------------------------------------------------------------------- # Modal app + image @@ -42,12 +43,14 @@ inference_image = ( "tiktoken>=0.11.0", "tokenizers>=0.22.0", "huggingface_hub>=0.25.0", + "requests>=2.31.0", "fastapi>=0.115.0", "uvicorn>=0.30.0", extra_index_url="https://download.pytorch.org/whl/cu124", ) .add_local_file("modal/_model.py", "/root/_model.py") .add_local_file("modal/_tokenizer.py", "/root/_tokenizer.py") + .add_local_file("modal/_tools.py", "/root/_tools.py") ) # Persistent volume for model weights @@ -104,6 +107,7 @@ def download_weights(): scaledown_window=300, # keep warm for 5 min after last request # concurrency handled by @modal.concurrent below timeout=120, + secrets=[modal.Secret.from_name(TAVILY_SECRET_NAME)], ) class Inference: model: object @@ -190,8 +194,22 @@ class Inference: self.assistant_end_id = self.tokenizer.encode_special("<|assistant_end|>")[0] print(f" Special token IDs: {sorted(self.special_token_ids)}") + # Initialize tool registry (Tavily web_search + calculator) + import sys as _sys + if '/root' not in _sys.path: _sys.path.insert(0, '/root') + from _tools import build_default_tool_registry, parse_tool_call_payload + self.tool_registry = build_default_tool_registry() + self._parse_tool_call = parse_tool_call_payload + # Marker tokens for tool state machine + self.python_start_id = self.tokenizer.encode_special("<|python_start|>")[0] + self.python_end_id = self.tokenizer.encode_special("<|python_end|>")[0] + self.output_start_id = self.tokenizer.encode_special("<|output_start|>")[0] + self.output_end_id = self.tokenizer.encode_special("<|output_end|>")[0] + # Stop tokens (exclude tool markers so generation continues through tool calls) + self._stop_token_ids = {self.assistant_end_id, self.tokenizer.get_bos_token_id() if hasattr(self.tokenizer, "get_bos_token_id") else self.tokenizer.encode_special("<|bos|>")[0]} + dt = time.time() - t0 - print(f"Model loaded in {dt:.1f}s on {device}") + print(f"Model loaded in {dt:.1f}s on {device} | tools: {[t for t in self.tool_registry._tools.keys()] if hasattr(self.tool_registry, '_tools') else 'registered'}") @modal.fastapi_endpoint(method="POST", docs=True) async def generate(self, request: dict): @@ -236,49 +254,74 @@ class Inference: tokens = tokens[-max_context:] async def stream(): + from collections import deque input_ids = torch.tensor([tokens], dtype=torch.long, device=self.device) + forced = deque() + in_tool = False + tool_payload_ids = [] + + def _append_token(tid): + nonlocal input_ids + nt = torch.tensor([[tid]], dtype=torch.long, device=self.device) + input_ids = torch.cat([input_ids, nt], dim=1) + if input_ids.size(1) > self.config.sequence_len: + input_ids = input_ids[:, -self.config.sequence_len:] with torch.no_grad(): - generated = [] - for _ in range(max_tokens): - # Forward pass - logits = self.model(input_ids) - next_logits = logits[:, -1, :] + num_generated = 0 + while num_generated < max_tokens: + if forced: + token_id = forced.popleft() + else: + logits = self.model(input_ids) + next_logits = logits[:, -1, :] + if temperature > 0: + next_logits = next_logits / temperature + if top_k > 0: + v, _ = torch.topk(next_logits, min(top_k, next_logits.size(-1))) + next_logits[next_logits < v[:, [-1]]] = float('-inf') + probs = torch.softmax(next_logits, dim=-1) + next_token = torch.multinomial(probs, num_samples=1) + token_id = next_token.item() - # Temperature - if temperature > 0: - next_logits = next_logits / temperature + # Tool state machine: detect 
<|python_start|>...<|python_end|>, + # execute tool, inject <|output_start|>...<|output_end|> as forced tokens + if token_id == self.python_start_id: + in_tool = True + tool_payload_ids = [] + elif token_id == self.python_end_id and in_tool: + in_tool = False + if tool_payload_ids: + try: + payload_text = self.tokenizer.decode(tool_payload_ids) + invocation = self._parse_tool_call(payload_text) + result = self.tool_registry.execute(invocation.tool_name, invocation.arguments) + result_text = result.to_payload()[:4096] + except Exception as exc: + result_text = json.dumps({"error": str(exc)[:500]}) + if result_text: + forced.append(self.output_start_id) + forced.extend(self.tokenizer.encode(result_text)) + forced.append(self.output_end_id) + tool_payload_ids = [] + elif in_tool: + tool_payload_ids.append(token_id) - # Top-k filtering - if top_k > 0: - v, _ = torch.topk(next_logits, min(top_k, next_logits.size(-1))) - next_logits[next_logits < v[:, [-1]]] = float('-inf') - - # Sample - probs = torch.softmax(next_logits, dim=-1) - next_token = torch.multinomial(probs, num_samples=1) - - token_id = next_token.item() - - # Stop on any special token (assistant_end, bos, etc.) - if token_id in self.special_token_ids: + # Stop only on assistant_end or bos (NOT on tool markers) + if token_id in self._stop_token_ids: break - # Decode and yield (skip tokens that can't be decoded) + # Decode + stream to client (includes tool markers; UI renders) try: token_text = self.tokenizer.decode([token_id]) - except (KeyError, Exception): - continue - yield f"data: {json.dumps({'token': token_text, 'gpu': 0})}\n\n" + yield "data: " + json.dumps({"token": token_text, "gpu": 0}) + "\n\n" + except Exception: + pass - # Append for next iteration - input_ids = torch.cat([input_ids, next_token], dim=1) + _append_token(token_id) + num_generated += 1 - # Truncate if exceeding sequence length - if input_ids.size(1) > self.config.sequence_len: - input_ids = input_ids[:, -self.config.sequence_len:] - - yield f"data: {json.dumps({'done': True})}\n\n" + yield "data: " + json.dumps({"done": True}) + "\n\n" return StreamingResponse( stream(), diff --git a/nanochat/tools.py b/nanochat/tools.py index df7ef64c..9d3a82ef 100644 --- a/nanochat/tools.py +++ b/nanochat/tools.py @@ -348,6 +348,43 @@ class MockSearchBackend: rows = self.canned_results.get(normalized, []) return [SearchHit(**row) for row in rows[:top_k]] +class TavilySearchBackend: + """LLM-optimized web search via Tavily. 
Falls back silently on errors.""" + def __init__(self, api_key: str | None = None, timeout: float = 15.0): + self.api_key = api_key or os.environ.get('TAVILY_API_KEY') + if not self.api_key: + raise ValueError('TavilySearchBackend requires TAVILY_API_KEY') + self.timeout = timeout + + def search(self, query: str, top_k: int) -> list[SearchHit]: + import requests + try: + r = requests.post( + 'https://api.tavily.com/search', + json={ + 'api_key': self.api_key, + 'query': query, + 'max_results': max(1, min(int(top_k), 8)), + 'include_answer': False, + 'include_raw_content': False, + 'search_depth': 'basic', + }, + timeout=self.timeout, + ) + r.raise_for_status() + data = r.json() + except Exception: + return [] + return [ + SearchHit( + url=h.get('url', ''), + title=h.get('title', ''), + snippet=h.get('content', ''), + ) + for h in data.get('results', [])[:top_k] + ] + + class CloudflareBrowserRenderingClient: def __init__( @@ -497,11 +534,19 @@ def build_default_tool_registry( ) except Exception: fetch_client = None + if search_backend is None: + if os.environ.get('TAVILY_API_KEY'): + try: + search_backend = TavilySearchBackend() + except Exception: + search_backend = MockSearchBackend() + else: + search_backend = MockSearchBackend() registry = ToolRegistry( [ CalculatorTool(), WebSearchTool( - search_backend=search_backend or MockSearchBackend(), + search_backend=search_backend, fetch_client=fetch_client, ), ] diff --git a/nanochat/ui.html b/nanochat/ui.html index 3c2d8681..0456f2fe 100644 --- a/nanochat/ui.html +++ b/nanochat/ui.html @@ -459,6 +459,52 @@ .illust-right svg.kettle-svg { width: 70px; } .explore-tag, .chai-label { font-size: 0.85rem; } } + + /* --- Reasoning mode --- */ + .think-toggle { + background: transparent; + border: 1px solid rgba(255,255,255,0.15); + color: #b8a88a; + cursor: pointer; + padding: 6px 10px; + border-radius: 6px; + margin-right: 8px; + font-size: 0.85rem; + display: flex; + align-items: center; + gap: 4px; + transition: all 0.2s; + } + .think-toggle:hover { background: rgba(255,255,255,0.05); } + .think-toggle.active { + background: rgba(184,168,138,0.15); + border-color: #b8a88a; + color: #fff; + } + .think-block { + background: rgba(100,100,100,0.08); + border-left: 3px solid #777; + padding: 10px 14px; + margin-bottom: 10px; + font-style: italic; + color: #999; + font-size: 0.88em; + white-space: pre-wrap; + border-radius: 4px; + } + .think-block::before { + content: "\1F4AD" " thinking"; + display: block; + font-weight: 600; + margin-bottom: 6px; + color: #888; + font-style: normal; + font-size: 0.85em; + letter-spacing: 0.05em; + text-transform: uppercase; + } + .answer-block { white-space: pre-wrap; } + @@ -609,6 +655,10 @@
+      <!-- reasoning toggle (legacy UI) — markup reconstructed, id/handler assumed -->
+      <button id="think-toggle" class="think-toggle" type="button">
+        💭 Think
+      </button>
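The services/chat-api change described in the commit message accepts a thinking_mode flag and injects the samosaChaat system prompt into the first user message before streaming to Modal. A minimal Python sketch of that behavior — the SAMOSA_CHAAT_PROMPT constant and prepare_messages helper are illustrative assumptions, not the actual route code:

# Sketch only — constant and helper names are assumptions, not the real messages.py.
SAMOSA_CHAAT_PROMPT = "You are samosaChaat. Reason step by step inside <think> tags before answering."

def prepare_messages(messages: list[dict], thinking_mode: bool) -> list[dict]:
    """Prepend the system prompt to the first user message before proxying to Modal."""
    if not thinking_mode or not messages:
        return messages
    out = [dict(m) for m in messages]
    for m in out:
        if m.get("role") == "user":
            m["content"] = f"{SAMOSA_CHAAT_PROMPT}\n\n{m['content']}"
            break
    return out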
+ )} + {/* Send / stop button — vertically centered with the textarea baseline */}
{isStreaming && onStop ? (
diff --git a/services/frontend/components/chat/ChatWindow.tsx b/services/frontend/components/chat/ChatWindow.tsx
index 505958ec..757ae260 100644
--- a/services/frontend/components/chat/ChatWindow.tsx
+++ b/services/frontend/components/chat/ChatWindow.tsx
@@ -31,6 +31,7 @@ export default function ChatWindow() {
   const [draft, setDraft] = useState('');
   const [streamingMsgId, setStreamingMsgId] = useState(null);
+  const [thinkingMode, setThinkingMode] = useState(false);
   const streamingBufferRef = useRef('');
   const scrollRef = useRef(null);
@@ -60,7 +61,7 @@
     setIsStreaming(false);
   }, []);

-  const streamFromApi = useCallback(async (convId: string, assistantMsgId: string, content: string, temp?: number, topk?: number) => {
+  const streamFromApi = useCallback(async (convId: string, assistantMsgId: string, content: string, temp?: number, topk?: number, thinking?: boolean) => {
     stop();
     const ac = new AbortController();
     abortRef.current = ac;
@@ -76,7 +77,7 @@
     const res = await fetch(`/api/conversations/${convId}/messages`, {
       method: 'POST',
       headers,
-      body: JSON.stringify({ content, temperature: temp, max_tokens: 512, top_k: topk }),
+      body: JSON.stringify({ content, temperature: temp, max_tokens: 512, top_k: topk, thinking_mode: !!thinking }),
       signal: ac.signal,
     });
@@ -172,7 +173,7 @@
       setStreamingMsgId(assistantId);
       streamingBufferRef.current = '';

-      await streamFromApi(convId, assistantId, text, temperature, topK);
+      await streamFromApi(convId, assistantId, text, temperature, topK, thinkingMode);
     },
     [
       draft,
@@ -180,13 +181,13 @@
       ensureConversation,
       temperature,
       topK,
+      thinkingMode,
       appendMessage,
       streamFromApi,
       setTemperature,
       setTopK,
       createConversation,
       newConversation,
-      // streamFromApi in deps via earlier line
     ],
   );
@@ -238,6 +239,8 @@
         onSubmit={() => handleSend()}
         onStop={stop}
         isStreaming={isStreaming}
+        thinkingMode={thinkingMode}
+        onToggleThinking={() => setThinkingMode((v) => !v)}
       />
     );
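For reference, the marker protocol that MessageBubble.tsx parses in the diff below, mirrored in Python — illustrative only, handy for testing the stream contract server-side; the real parser is parseSegments in the TSX:

# Python mirror of the frontend segment parser; not part of the patch itself.
MARKERS = [
    ("<think>", "</think>", "think"),
    ("<|python_start|>", "<|python_end|>", "tool_call"),
    ("<|output_start|>", "<|output_end|>", "tool_result"),
]

def parse_segments(raw: str) -> list[tuple[str, str, bool]]:
    """Split streamed text into (kind, content, closed) segments."""
    segs, i = [], 0
    while i < len(raw):
        hits = [(raw.find(o, i), o, c, k) for (o, c, k) in MARKERS]
        hits = [h for h in hits if h[0] != -1]
        if not hits:
            segs.append(("text", raw[i:], True))
            break
        pos, open_tag, close_tag, kind = min(hits)
        if pos > i:
            segs.append(("text", raw[i:pos], True))
        end = raw.find(close_tag, pos + len(open_tag))
        if end == -1:  # marker opened but not yet closed — still streaming
            segs.append((kind, raw[pos + len(open_tag):], False))
            break
        segs.append((kind, raw[pos + len(open_tag):end], True))
        i = end + len(close_tag)
    return segs

assert parse_segments("hi<think>plan</think>bye")[1] == ("think", "plan", True)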
diff --git a/services/frontend/components/chat/MessageBubble.tsx b/services/frontend/components/chat/MessageBubble.tsx
index b7a0c987..017e141e 100644
--- a/services/frontend/components/chat/MessageBubble.tsx
+++ b/services/frontend/components/chat/MessageBubble.tsx
@@ -5,11 +5,114 @@
 import ReactMarkdown from 'react-markdown';
 import remarkGfm from 'remark-gfm';
 import rehypeHighlight from 'rehype-highlight';
 import 'highlight.js/styles/github-dark.css';
-import { Check, Copy } from 'lucide-react';
+import { Check, ChevronDown, ChevronRight, Copy, Search, Calculator, Sparkles } from 'lucide-react';
 import clsx from 'clsx';
 import type { Message } from '@/types/chat';
 import SteamTyping from '@/components/svg/SteamTyping';

+// ---- Content parser: split into text / think / tool_call / tool_result segments ----
+type Segment =
+  | { kind: 'text'; content: string }
+  | { kind: 'think'; content: string; closed: boolean }
+  | { kind: 'tool_call'; content: string; closed: boolean }
+  | { kind: 'tool_result'; content: string; closed: boolean };
+
+function parseSegments(raw: string): Segment[] {
+  const segs: Segment[] = [];
+  let i = 0;
+  // marker -> [openTag, closeTag, kind]
+  const markers: Array<[string, string, Segment['kind']]> = [
+    ['<think>', '</think>', 'think'],
+    ['<|python_start|>', '<|python_end|>', 'tool_call'],
+    ['<|output_start|>', '<|output_end|>', 'tool_result'],
+  ];
+  while (i < raw.length) {
+    // find the nearest opening marker
+    let bestOpen = -1;
+    let bestMarker: [string, string, Segment['kind']] | null = null;
+    for (const m of markers) {
+      const p = raw.indexOf(m[0], i);
+      if (p !== -1 && (bestOpen === -1 || p < bestOpen)) { bestOpen = p; bestMarker = m; }
+    }
+    if (bestOpen === -1) {
+      if (i < raw.length) segs.push({ kind: 'text', content: raw.slice(i) });
+      break;
+    }
+    if (bestOpen > i) segs.push({ kind: 'text', content: raw.slice(i, bestOpen) });
+    const [openTag, closeTag, kind] = bestMarker!;
+    const afterOpen = bestOpen + openTag.length;
+    const closeIdx = raw.indexOf(closeTag, afterOpen);
+    if (closeIdx === -1) {
+      segs.push({ kind, content: raw.slice(afterOpen), closed: false });
+      i = raw.length;
+    } else {
+      segs.push({ kind, content: raw.slice(afterOpen, closeIdx), closed: true });
+      i = closeIdx + closeTag.length;
+    }
+  }
+  return segs;
+}
+
+// Element structure and class names in the three blocks below are reconstructed
+// around the surviving logic; treat them as approximate.
+function ThinkBlock({ content, closed }: { content: string; closed: boolean }) {
+  const [open, setOpen] = useState(true);
+  return (
+    <div className="think-block">
+      <button type="button" onClick={() => setOpen((v) => !v)}>
+        {open ? <ChevronDown size={14} /> : <ChevronRight size={14} />}
+        <span>{closed ? 'thinking' : 'thinking…'}</span>
+      </button>
+      {open && (
+        <div className="think-content">
+          {content}
+        </div>
+      )}
+    </div>
+  );
+}
+
+function ToolCallBlock({ content, closed }: { content: string; closed: boolean }) {
+  let parsed: { tool?: string; arguments?: Record<string, unknown> } | null = null;
+  try { parsed = JSON.parse(content); } catch { /* streaming — partial JSON */ }
+  const toolName = parsed?.tool ?? 'tool';
+  const icon = toolName === 'web_search' ? <Search size={14} /> : toolName === 'calculator' ? <Calculator size={14} /> : <Sparkles size={14} />;
+  const query = parsed?.arguments ? JSON.stringify(parsed.arguments) : content;
+  return (
+    <div className="tool-call-card">
+      <div className="tool-call-header">
+        {icon}
+        <span>Calling {toolName}{closed ? '' : '…'}</span>
+      </div>
+      <pre className="tool-call-args">{query}</pre>
+    </div>
+  );
+}
+
+function ToolResultBlock({ content, closed }: { content: string; closed: boolean }) {
+  const [open, setOpen] = useState(false);
+  let summary = content;
+  try {
+    const j = JSON.parse(content);
+    if (j?.output?.results?.[0]?.snippet) summary = String(j.output.results[0].snippet).slice(0, 160);
+    else if (j?.output?.value !== undefined) summary = `= ${j.output.value}`;
+    else if (j?.error) summary = `error: ${j.error}`;
+  } catch { /* partial */ }
+  return (
+    <div className="tool-result-card">
+      <button type="button" onClick={() => setOpen((v) => !v)}>
+        {open ? <ChevronDown size={14} /> : <ChevronRight size={14} />}
+        <span>{closed ? summary : 'receiving result…'}</span>
+      </button>
+      {open && (
+        <pre className="tool-result-json">{content}</pre>
+      )}
+    </div>
+  );
+}
+
 interface Props {
   message: Message;
   isStreaming?: boolean;
@@ -91,13 +194,21 @@ export default function MessageBubble({ message, isStreaming }: Props) {
) : (
-            <ReactMarkdown remarkPlugins={[remarkGfm]} rehypePlugins={[rehypeHighlight]}>
-              {message.content}
-            </ReactMarkdown>
+            {parseSegments(message.content).map((seg, idx) => {
+              if (seg.kind === 'think') return <ThinkBlock key={idx} content={seg.content} closed={seg.closed} />;
+              if (seg.kind === 'tool_call') return <ToolCallBlock key={idx} content={seg.content} closed={seg.closed} />;
+              if (seg.kind === 'tool_result') return <ToolResultBlock key={idx} content={seg.content} closed={seg.closed} />;
+              return (
+                <ReactMarkdown key={idx} remarkPlugins={[remarkGfm]} rehypePlugins={[rehypeHighlight]}>
+                  {seg.content}
+                </ReactMarkdown>
+              );
+            })}
)}
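A quick end-to-end smoke test of the deployed path, as a hedged sketch: the base URL and port are deployment-specific assumptions, as is the assumption that chat-api relays Modal's data: {"token": ...} SSE events unchanged.

# Illustrative client; URL, port, and auth are assumptions for local testing.
import json
import requests

def stream_chat(conv_id: str, content: str, thinking: bool = True) -> str:
    body = {"content": content, "temperature": 0.7, "max_tokens": 512,
            "top_k": 40, "thinking_mode": thinking}
    url = f"http://localhost:3000/api/conversations/{conv_id}/messages"
    chunks = []
    with requests.post(url, json=body, stream=True, timeout=120) as r:
        r.raise_for_status()
        for line in r.iter_lines(decode_unicode=True):
            if not line or not line.startswith("data: "):
                continue
            event = json.loads(line[len("data: "):])
            if event.get("done"):
                break
            chunks.append(event.get("token", ""))
    # The returned text still contains <think> and <|python_start|> markers;
    # MessageBubble's parseSegments handles presentation client-side.
    return "".join(chunks)

if __name__ == "__main__":
    print(stream_chat("demo", "Who is the current president of France?"))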