""" Shared tool definitions for nanochat. The current tokenizer only has python/output special tokens. To preserve checkpoint compatibility, we reuse those tokens as a generic tool-call and tool-result channel. Legacy "python" calculator payloads still work. """ from __future__ import annotations import ast import json import math import os import time from dataclasses import dataclass, field from typing import Any, Protocol import requests TOOL_CALL_START = "<|python_start|>" TOOL_CALL_END = "<|python_end|>" TOOL_RESULT_START = "<|output_start|>" TOOL_RESULT_END = "<|output_end|>" MAX_TOOL_PAYLOAD_CHARS = 4096 DEFAULT_TOOL_SCHEMA = [ { "name": "calculator", "description": "Deterministic scientific calculator for exact arithmetic and common finance formulas.", "arguments": { "expression": "String expression using numbers, operators, and supported functions.", }, }, { "name": "web_search", "description": "Search and fetch web content. Requires a search backend and optionally a page fetch client.", "arguments": { "query": "Search query string.", "top_k": "Maximum number of results to return.", "urls": "Optional explicit URLs to fetch instead of searching.", }, }, ] def _compact_json(data: Any) -> str: return json.dumps(data, ensure_ascii=True, separators=(",", ":"), sort_keys=True) @dataclass class ToolInvocation: tool_name: str arguments: dict[str, Any] raw_text: str = "" @dataclass class ToolResult: tool_name: str success: bool output: Any = None error: str | None = None metadata: dict[str, Any] = field(default_factory=dict) def to_payload(self) -> str: return _compact_json( { "tool": self.tool_name, "success": self.success, "output": self.output, "error": self.error, "metadata": self.metadata, } ) class BaseTool: name: str def run(self, arguments: dict[str, Any]) -> ToolResult: raise NotImplementedError class ToolRegistry: def __init__(self, tools: list[BaseTool] | tuple[BaseTool, ...]): self._tools = {tool.name: tool for tool in tools} def execute(self, tool_name: str, arguments: dict[str, Any]) -> ToolResult: tool = self._tools.get(tool_name) if tool is None: return ToolResult(tool_name=tool_name, success=False, error=f"Unknown tool: {tool_name}") try: return tool.run(arguments) except Exception as exc: # defensive: tool failures should become model-visible outputs return ToolResult(tool_name=tool_name, success=False, error=str(exc)) def schema(self) -> list[dict[str, Any]]: return [item for item in DEFAULT_TOOL_SCHEMA if item["name"] in self._tools] def serialize_tool_call(tool_name: str, arguments: dict[str, Any] | None = None) -> str: payload = { "tool": tool_name, "arguments": arguments or {}, } text = _compact_json(payload) return text[:MAX_TOOL_PAYLOAD_CHARS] def serialize_tool_result( tool_name: str, output: Any = None, *, success: bool = True, error: str | None = None, metadata: dict[str, Any] | None = None, ) -> str: return ToolResult( tool_name=tool_name, success=success, output=output, error=error, metadata=metadata or {}, ).to_payload()[:MAX_TOOL_PAYLOAD_CHARS] def parse_tool_call_payload(text: str) -> ToolInvocation: stripped = text.strip() if not stripped: return ToolInvocation(tool_name="calculator", arguments={"expression": ""}, raw_text=text) try: payload = json.loads(stripped) except json.JSONDecodeError: return ToolInvocation(tool_name="calculator", arguments={"expression": stripped}, raw_text=text) if isinstance(payload, dict): tool_name = payload.get("tool") or payload.get("tool_name") or payload.get("name") arguments = payload.get("arguments") or payload.get("args") or {} if isinstance(tool_name, str) and isinstance(arguments, dict): return ToolInvocation(tool_name=tool_name, arguments=arguments, raw_text=text) return ToolInvocation(tool_name="calculator", arguments={"expression": stripped}, raw_text=text) def parse_tool_result_payload(text: str) -> ToolResult | None: stripped = text.strip() try: payload = json.loads(stripped) except json.JSONDecodeError: return None if not isinstance(payload, dict): return None tool_name = payload.get("tool") if not isinstance(tool_name, str): return None return ToolResult( tool_name=tool_name, success=bool(payload.get("success", True)), output=payload.get("output"), error=payload.get("error"), metadata=payload.get("metadata") or {}, ) def _percent(value: float, rate: float) -> float: return value * rate / 100.0 def _percent_change(old: float, new: float) -> float: if old == 0: raise ValueError("percent_change old value cannot be zero") return ((new - old) / old) * 100.0 def _cagr(start: float, end: float, years: float) -> float: if start <= 0 or end <= 0 or years <= 0: raise ValueError("cagr inputs must be positive") return ((end / start) ** (1.0 / years) - 1.0) * 100.0 def _simple_interest(principal: float, annual_rate: float, years: float) -> float: return principal * annual_rate / 100.0 * years def _compound_interest(principal: float, annual_rate: float, periods_per_year: float, years: float) -> float: if periods_per_year <= 0: raise ValueError("periods_per_year must be positive") return principal * (1.0 + annual_rate / 100.0 / periods_per_year) ** (periods_per_year * years) def _emi(principal: float, annual_rate: float, months: float) -> float: if months <= 0: raise ValueError("months must be positive") monthly_rate = annual_rate / 100.0 / 12.0 if monthly_rate == 0: return principal / months growth = (1.0 + monthly_rate) ** months return principal * monthly_rate * growth / (growth - 1.0) ALLOWED_BINOPS = { ast.Add: lambda a, b: a + b, ast.Sub: lambda a, b: a - b, ast.Mult: lambda a, b: a * b, ast.Div: lambda a, b: a / b, ast.Pow: lambda a, b: a ** b, ast.Mod: lambda a, b: a % b, } ALLOWED_UNARYOPS = { ast.UAdd: lambda a: a, ast.USub: lambda a: -a, } ALLOWED_NAMES = { "pi": math.pi, "e": math.e, "tau": math.tau, } ALLOWED_FUNCTIONS = { "abs": abs, "round": round, "floor": math.floor, "ceil": math.ceil, "sqrt": math.sqrt, "log": math.log, "log10": math.log10, "exp": math.exp, "sin": math.sin, "cos": math.cos, "tan": math.tan, "asin": math.asin, "acos": math.acos, "atan": math.atan, "degrees": math.degrees, "radians": math.radians, "percent": _percent, "percent_change": _percent_change, "cagr": _cagr, "simple_interest": _simple_interest, "compound_interest": _compound_interest, "emi": _emi, } class _SafeMathEvaluator: def __init__(self, expression: str): self.expression = expression self.node_count = 0 def eval(self) -> float: if len(self.expression) > 512: raise ValueError("expression too long") parsed = ast.parse(self.expression, mode="eval") return self._visit(parsed.body) def _visit(self, node: ast.AST) -> Any: self.node_count += 1 if self.node_count > 128: raise ValueError("expression too complex") if isinstance(node, ast.Constant): if isinstance(node.value, (int, float)): return node.value raise ValueError(f"unsupported constant: {node.value!r}") if isinstance(node, ast.Num): # pragma: no cover - py<3.8 compatibility return node.n if isinstance(node, ast.BinOp): op = ALLOWED_BINOPS.get(type(node.op)) if op is None: raise ValueError(f"unsupported operator: {type(node.op).__name__}") return op(self._visit(node.left), self._visit(node.right)) if isinstance(node, ast.UnaryOp): op = ALLOWED_UNARYOPS.get(type(node.op)) if op is None: raise ValueError(f"unsupported unary operator: {type(node.op).__name__}") return op(self._visit(node.operand)) if isinstance(node, ast.Name): if node.id not in ALLOWED_NAMES: raise ValueError(f"unknown symbol: {node.id}") return ALLOWED_NAMES[node.id] if isinstance(node, ast.Call): if not isinstance(node.func, ast.Name): raise ValueError("only direct function calls are allowed") fn = ALLOWED_FUNCTIONS.get(node.func.id) if fn is None: raise ValueError(f"unsupported function: {node.func.id}") if node.keywords: raise ValueError("keyword arguments are not supported") args = [self._visit(arg) for arg in node.args] return fn(*args) raise ValueError(f"unsupported expression node: {type(node).__name__}") def _normalize_numeric_output(value: Any) -> Any: if isinstance(value, float): if not math.isfinite(value): raise ValueError("result is not finite") return float(f"{value:.12g}") return value class CalculatorTool(BaseTool): name = "calculator" def run(self, arguments: dict[str, Any]) -> ToolResult: expression = str(arguments.get("expression", "")).strip() if not expression: return ToolResult(tool_name=self.name, success=False, error="Missing expression") value = _SafeMathEvaluator(expression).eval() return ToolResult( tool_name=self.name, success=True, output={"expression": expression, "value": _normalize_numeric_output(value)}, ) @dataclass class SearchHit: url: str title: str = "" snippet: str = "" class SearchBackend(Protocol): def search(self, query: str, top_k: int) -> list[SearchHit]: raise NotImplementedError class MockSearchBackend: def __init__(self, canned_results: dict[str, list[dict[str, str]]] | None = None): self.canned_results = canned_results or { "browser rendering markdown": [ { "url": "https://developers.cloudflare.com/browser-rendering/rest-api/markdown-endpoint/", "title": "Cloudflare markdown endpoint", "snippet": "Extract markdown from a webpage using Cloudflare Browser Rendering.", } ], "nanochat gpt2 speedrun": [ { "url": "https://github.com/karpathy/nanochat", "title": "karpathy/nanochat", "snippet": "Minimal LLM training harness with pretraining, SFT, RL, and chat UI.", } ], } def search(self, query: str, top_k: int) -> list[SearchHit]: normalized = query.strip().lower() rows = self.canned_results.get(normalized, []) return [SearchHit(**row) for row in rows[:top_k]] class TavilySearchBackend: """LLM-optimized web search via Tavily. Falls back silently on errors.""" def __init__(self, api_key: str | None = None, timeout: float = 15.0): self.api_key = api_key or os.environ.get('TAVILY_API_KEY') if not self.api_key: raise ValueError('TavilySearchBackend requires TAVILY_API_KEY') self.timeout = timeout def search(self, query: str, top_k: int) -> list[SearchHit]: import requests try: r = requests.post( 'https://api.tavily.com/search', json={ 'api_key': self.api_key, 'query': query, 'max_results': max(1, min(int(top_k), 8)), 'include_answer': True, 'include_raw_content': False, 'search_depth': 'advanced', }, timeout=self.timeout, ) r.raise_for_status() data = r.json() except Exception: return [] direct_answer = (data.get('answer') or '').strip() hits: list[SearchHit] = [] # Surface Tavily's synthesized answer as the first hit so a 1.4B model # can parrot a clean, grounded sentence instead of fighting with noisy snippets. if direct_answer: hits.append(SearchHit( url='https://tavily.com/answer', title='Tavily direct answer', snippet=direct_answer, )) for h in data.get('results', [])[: max(0, top_k - (1 if direct_answer else 0))]: hits.append(SearchHit( url=h.get('url', ''), title=h.get('title', ''), snippet=h.get('content', ''), )) return hits class CloudflareBrowserRenderingClient: def __init__( self, *, api_token: str | None = None, account_id: str | None = None, base_url: str = "https://api.cloudflare.com/client/v4", timeout: float = 30.0, max_retries: int = 3, ): self.api_token = api_token or os.environ.get("CLOUDFLARE_API_TOKEN") self.account_id = account_id or os.environ.get("CLOUDFLARE_ACCOUNT_ID") if not self.api_token or not self.account_id: raise ValueError("Cloudflare Browser Rendering requires CLOUDFLARE_API_TOKEN and CLOUDFLARE_ACCOUNT_ID") self.base_url = base_url.rstrip("/") self.timeout = timeout self.max_retries = max_retries self.session = requests.Session() self.session.headers.update( { "Authorization": f"Bearer {self.api_token}", "Content-Type": "application/json", } ) def _post(self, endpoint: str, body: dict[str, Any]) -> Any: url = f"{self.base_url}/accounts/{self.account_id}/browser-rendering/{endpoint}" last_error = None for attempt in range(1, self.max_retries + 1): response = self.session.post(url, json=body, timeout=self.timeout) if response.status_code == 429: retry_after = response.headers.get("Retry-After") sleep_seconds = float(retry_after) if retry_after else float(attempt) last_error = RuntimeError(f"Cloudflare Browser Rendering rate limited on {endpoint}") time.sleep(min(sleep_seconds, 5.0)) continue response.raise_for_status() payload = response.json() if not payload.get("success", False): errors = payload.get("errors", []) last_error = RuntimeError(f"Cloudflare Browser Rendering request failed: {errors}") break return payload.get("result") if last_error is not None: raise last_error raise RuntimeError(f"Cloudflare Browser Rendering request failed for {endpoint}") def markdown(self, url: str, **options: Any) -> str: body = {"url": url} body.update(options) return self._post("markdown", body) def links(self, url: str, **options: Any) -> list[str]: body = {"url": url} body.update(options) return self._post("links", body) def json_extract(self, url: str, *, prompt: str | None = None, schema: dict[str, Any] | None = None, **options: Any) -> dict[str, Any]: body: dict[str, Any] = {"url": url} if prompt is not None: body["prompt"] = prompt if schema is not None: body["schema"] = schema body.update(options) return self._post("json", body) class WebSearchTool(BaseTool): name = "web_search" def __init__( self, *, search_backend: SearchBackend | None = None, fetch_client: CloudflareBrowserRenderingClient | None = None, max_results: int = 3, ): self.search_backend = search_backend self.fetch_client = fetch_client self.max_results = max_results def run(self, arguments: dict[str, Any]) -> ToolResult: query = str(arguments.get("query", "")).strip() requested_urls = arguments.get("urls") or [] if isinstance(requested_urls, str): requested_urls = [requested_urls] top_k = int(arguments.get("top_k", self.max_results) or self.max_results) top_k = max(1, min(top_k, 8)) if not query and not requested_urls: return ToolResult(tool_name=self.name, success=False, error="Missing query or urls") hits: list[SearchHit] if requested_urls: hits = [SearchHit(url=str(url)) for url in requested_urls[:top_k]] else: if self.search_backend is None: return ToolResult( tool_name=self.name, success=False, error="No search backend configured. Cloudflare Browser Rendering can fetch pages but does not provide public web search by itself.", ) hits = self.search_backend.search(query, top_k) results = [] for hit in hits[:top_k]: entry: dict[str, Any] = {"url": hit.url} if hit.title: entry["title"] = hit.title if hit.snippet: entry["snippet"] = hit.snippet if self.fetch_client is not None: try: markdown = self.fetch_client.markdown(hit.url) links = self.fetch_client.links(hit.url) entry["markdown"] = markdown[:4000] entry["links"] = links[:10] except Exception as exc: entry["fetch_error"] = str(exc) results.append(entry) return ToolResult( tool_name=self.name, success=True, output={"query": query, "results": results}, metadata={ "search_backend": type(self.search_backend).__name__ if self.search_backend is not None else None, "fetch_backend": type(self.fetch_client).__name__ if self.fetch_client is not None else None, "num_results": len(results), }, ) def build_default_tool_registry( *, cloudflare_token: str | None = None, cloudflare_account_id: str | None = None, search_backend: SearchBackend | None = None, ) -> ToolRegistry: fetch_client = None if cloudflare_token or os.environ.get("CLOUDFLARE_API_TOKEN"): try: fetch_client = CloudflareBrowserRenderingClient( api_token=cloudflare_token, account_id=cloudflare_account_id, ) except Exception: fetch_client = None if search_backend is None: if os.environ.get('TAVILY_API_KEY'): try: search_backend = TavilySearchBackend() except Exception: search_backend = MockSearchBackend() else: search_backend = MockSearchBackend() registry = ToolRegistry( [ CalculatorTool(), WebSearchTool( search_backend=search_backend, fetch_client=fetch_client, ), ] ) return registry