mirror of
https://github.com/karpathy/nanochat.git
synced 2026-05-07 00:09:50 +00:00
563 lines
19 KiB
Python
563 lines
19 KiB
Python
"""
|
|
Shared tool definitions for nanochat.
|
|
|
|
The current tokenizer only has python/output special tokens. To preserve
|
|
checkpoint compatibility, we reuse those tokens as a generic tool-call and
|
|
tool-result channel. Legacy "python" calculator payloads still work.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import ast
|
|
import json
|
|
import math
|
|
import os
|
|
import time
|
|
from dataclasses import dataclass, field
|
|
from typing import Any, Protocol
|
|
|
|
import requests
|
|
|
|
TOOL_CALL_START = "<|python_start|>"
|
|
TOOL_CALL_END = "<|python_end|>"
|
|
TOOL_RESULT_START = "<|output_start|>"
|
|
TOOL_RESULT_END = "<|output_end|>"
|
|
MAX_TOOL_PAYLOAD_CHARS = 4096
|
|
|
|
DEFAULT_TOOL_SCHEMA = [
|
|
{
|
|
"name": "calculator",
|
|
"description": "Deterministic scientific calculator for exact arithmetic and common finance formulas.",
|
|
"arguments": {
|
|
"expression": "String expression using numbers, operators, and supported functions.",
|
|
},
|
|
},
|
|
{
|
|
"name": "web_search",
|
|
"description": "Search and fetch web content. Requires a search backend and optionally a page fetch client.",
|
|
"arguments": {
|
|
"query": "Search query string.",
|
|
"top_k": "Maximum number of results to return.",
|
|
"urls": "Optional explicit URLs to fetch instead of searching.",
|
|
},
|
|
},
|
|
]
|
|
|
|
|
|
def _compact_json(data: Any) -> str:
|
|
return json.dumps(data, ensure_ascii=True, separators=(",", ":"), sort_keys=True)
|
|
|
|
|
|
@dataclass
|
|
class ToolInvocation:
|
|
tool_name: str
|
|
arguments: dict[str, Any]
|
|
raw_text: str = ""
|
|
|
|
|
|
@dataclass
|
|
class ToolResult:
|
|
tool_name: str
|
|
success: bool
|
|
output: Any = None
|
|
error: str | None = None
|
|
metadata: dict[str, Any] = field(default_factory=dict)
|
|
|
|
def to_payload(self) -> str:
|
|
return _compact_json(
|
|
{
|
|
"tool": self.tool_name,
|
|
"success": self.success,
|
|
"output": self.output,
|
|
"error": self.error,
|
|
"metadata": self.metadata,
|
|
}
|
|
)
|
|
|
|
|
|
class BaseTool:
|
|
name: str
|
|
|
|
def run(self, arguments: dict[str, Any]) -> ToolResult:
|
|
raise NotImplementedError
|
|
|
|
|
|
class ToolRegistry:
|
|
def __init__(self, tools: list[BaseTool] | tuple[BaseTool, ...]):
|
|
self._tools = {tool.name: tool for tool in tools}
|
|
|
|
def execute(self, tool_name: str, arguments: dict[str, Any]) -> ToolResult:
|
|
tool = self._tools.get(tool_name)
|
|
if tool is None:
|
|
return ToolResult(tool_name=tool_name, success=False, error=f"Unknown tool: {tool_name}")
|
|
try:
|
|
return tool.run(arguments)
|
|
except Exception as exc: # defensive: tool failures should become model-visible outputs
|
|
return ToolResult(tool_name=tool_name, success=False, error=str(exc))
|
|
|
|
def schema(self) -> list[dict[str, Any]]:
|
|
return [item for item in DEFAULT_TOOL_SCHEMA if item["name"] in self._tools]
|
|
|
|
|
|
def serialize_tool_call(tool_name: str, arguments: dict[str, Any] | None = None) -> str:
|
|
payload = {
|
|
"tool": tool_name,
|
|
"arguments": arguments or {},
|
|
}
|
|
text = _compact_json(payload)
|
|
return text[:MAX_TOOL_PAYLOAD_CHARS]
|
|
|
|
|
|
def serialize_tool_result(
|
|
tool_name: str,
|
|
output: Any = None,
|
|
*,
|
|
success: bool = True,
|
|
error: str | None = None,
|
|
metadata: dict[str, Any] | None = None,
|
|
) -> str:
|
|
return ToolResult(
|
|
tool_name=tool_name,
|
|
success=success,
|
|
output=output,
|
|
error=error,
|
|
metadata=metadata or {},
|
|
).to_payload()[:MAX_TOOL_PAYLOAD_CHARS]
|
|
|
|
|
|
def parse_tool_call_payload(text: str) -> ToolInvocation:
|
|
stripped = text.strip()
|
|
if not stripped:
|
|
return ToolInvocation(tool_name="calculator", arguments={"expression": ""}, raw_text=text)
|
|
try:
|
|
payload = json.loads(stripped)
|
|
except json.JSONDecodeError:
|
|
return ToolInvocation(tool_name="calculator", arguments={"expression": stripped}, raw_text=text)
|
|
if isinstance(payload, dict):
|
|
tool_name = payload.get("tool") or payload.get("tool_name") or payload.get("name")
|
|
arguments = payload.get("arguments") or payload.get("args") or {}
|
|
if isinstance(tool_name, str) and isinstance(arguments, dict):
|
|
return ToolInvocation(tool_name=tool_name, arguments=arguments, raw_text=text)
|
|
return ToolInvocation(tool_name="calculator", arguments={"expression": stripped}, raw_text=text)
|
|
|
|
|
|
def parse_tool_result_payload(text: str) -> ToolResult | None:
|
|
stripped = text.strip()
|
|
try:
|
|
payload = json.loads(stripped)
|
|
except json.JSONDecodeError:
|
|
return None
|
|
if not isinstance(payload, dict):
|
|
return None
|
|
tool_name = payload.get("tool")
|
|
if not isinstance(tool_name, str):
|
|
return None
|
|
return ToolResult(
|
|
tool_name=tool_name,
|
|
success=bool(payload.get("success", True)),
|
|
output=payload.get("output"),
|
|
error=payload.get("error"),
|
|
metadata=payload.get("metadata") or {},
|
|
)
|
|
|
|
|
|
def _percent(value: float, rate: float) -> float:
|
|
return value * rate / 100.0
|
|
|
|
|
|
def _percent_change(old: float, new: float) -> float:
|
|
if old == 0:
|
|
raise ValueError("percent_change old value cannot be zero")
|
|
return ((new - old) / old) * 100.0
|
|
|
|
|
|
def _cagr(start: float, end: float, years: float) -> float:
|
|
if start <= 0 or end <= 0 or years <= 0:
|
|
raise ValueError("cagr inputs must be positive")
|
|
return ((end / start) ** (1.0 / years) - 1.0) * 100.0
|
|
|
|
|
|
def _simple_interest(principal: float, annual_rate: float, years: float) -> float:
|
|
return principal * annual_rate / 100.0 * years
|
|
|
|
|
|
def _compound_interest(principal: float, annual_rate: float, periods_per_year: float, years: float) -> float:
|
|
if periods_per_year <= 0:
|
|
raise ValueError("periods_per_year must be positive")
|
|
return principal * (1.0 + annual_rate / 100.0 / periods_per_year) ** (periods_per_year * years)
|
|
|
|
|
|
def _emi(principal: float, annual_rate: float, months: float) -> float:
|
|
if months <= 0:
|
|
raise ValueError("months must be positive")
|
|
monthly_rate = annual_rate / 100.0 / 12.0
|
|
if monthly_rate == 0:
|
|
return principal / months
|
|
growth = (1.0 + monthly_rate) ** months
|
|
return principal * monthly_rate * growth / (growth - 1.0)
|
|
|
|
|
|
ALLOWED_BINOPS = {
|
|
ast.Add: lambda a, b: a + b,
|
|
ast.Sub: lambda a, b: a - b,
|
|
ast.Mult: lambda a, b: a * b,
|
|
ast.Div: lambda a, b: a / b,
|
|
ast.Pow: lambda a, b: a ** b,
|
|
ast.Mod: lambda a, b: a % b,
|
|
}
|
|
ALLOWED_UNARYOPS = {
|
|
ast.UAdd: lambda a: a,
|
|
ast.USub: lambda a: -a,
|
|
}
|
|
ALLOWED_NAMES = {
|
|
"pi": math.pi,
|
|
"e": math.e,
|
|
"tau": math.tau,
|
|
}
|
|
ALLOWED_FUNCTIONS = {
|
|
"abs": abs,
|
|
"round": round,
|
|
"floor": math.floor,
|
|
"ceil": math.ceil,
|
|
"sqrt": math.sqrt,
|
|
"log": math.log,
|
|
"log10": math.log10,
|
|
"exp": math.exp,
|
|
"sin": math.sin,
|
|
"cos": math.cos,
|
|
"tan": math.tan,
|
|
"asin": math.asin,
|
|
"acos": math.acos,
|
|
"atan": math.atan,
|
|
"degrees": math.degrees,
|
|
"radians": math.radians,
|
|
"percent": _percent,
|
|
"percent_change": _percent_change,
|
|
"cagr": _cagr,
|
|
"simple_interest": _simple_interest,
|
|
"compound_interest": _compound_interest,
|
|
"emi": _emi,
|
|
}
|
|
|
|
|
|
class _SafeMathEvaluator:
|
|
def __init__(self, expression: str):
|
|
self.expression = expression
|
|
self.node_count = 0
|
|
|
|
def eval(self) -> float:
|
|
if len(self.expression) > 512:
|
|
raise ValueError("expression too long")
|
|
parsed = ast.parse(self.expression, mode="eval")
|
|
return self._visit(parsed.body)
|
|
|
|
def _visit(self, node: ast.AST) -> Any:
|
|
self.node_count += 1
|
|
if self.node_count > 128:
|
|
raise ValueError("expression too complex")
|
|
|
|
if isinstance(node, ast.Constant):
|
|
if isinstance(node.value, (int, float)):
|
|
return node.value
|
|
raise ValueError(f"unsupported constant: {node.value!r}")
|
|
if isinstance(node, ast.Num): # pragma: no cover - py<3.8 compatibility
|
|
return node.n
|
|
if isinstance(node, ast.BinOp):
|
|
op = ALLOWED_BINOPS.get(type(node.op))
|
|
if op is None:
|
|
raise ValueError(f"unsupported operator: {type(node.op).__name__}")
|
|
return op(self._visit(node.left), self._visit(node.right))
|
|
if isinstance(node, ast.UnaryOp):
|
|
op = ALLOWED_UNARYOPS.get(type(node.op))
|
|
if op is None:
|
|
raise ValueError(f"unsupported unary operator: {type(node.op).__name__}")
|
|
return op(self._visit(node.operand))
|
|
if isinstance(node, ast.Name):
|
|
if node.id not in ALLOWED_NAMES:
|
|
raise ValueError(f"unknown symbol: {node.id}")
|
|
return ALLOWED_NAMES[node.id]
|
|
if isinstance(node, ast.Call):
|
|
if not isinstance(node.func, ast.Name):
|
|
raise ValueError("only direct function calls are allowed")
|
|
fn = ALLOWED_FUNCTIONS.get(node.func.id)
|
|
if fn is None:
|
|
raise ValueError(f"unsupported function: {node.func.id}")
|
|
if node.keywords:
|
|
raise ValueError("keyword arguments are not supported")
|
|
args = [self._visit(arg) for arg in node.args]
|
|
return fn(*args)
|
|
raise ValueError(f"unsupported expression node: {type(node).__name__}")
|
|
|
|
|
|
def _normalize_numeric_output(value: Any) -> Any:
|
|
if isinstance(value, float):
|
|
if not math.isfinite(value):
|
|
raise ValueError("result is not finite")
|
|
return float(f"{value:.12g}")
|
|
return value
|
|
|
|
|
|
class CalculatorTool(BaseTool):
|
|
name = "calculator"
|
|
|
|
def run(self, arguments: dict[str, Any]) -> ToolResult:
|
|
expression = str(arguments.get("expression", "")).strip()
|
|
if not expression:
|
|
return ToolResult(tool_name=self.name, success=False, error="Missing expression")
|
|
value = _SafeMathEvaluator(expression).eval()
|
|
return ToolResult(
|
|
tool_name=self.name,
|
|
success=True,
|
|
output={"expression": expression, "value": _normalize_numeric_output(value)},
|
|
)
|
|
|
|
|
|
@dataclass
|
|
class SearchHit:
|
|
url: str
|
|
title: str = ""
|
|
snippet: str = ""
|
|
|
|
|
|
class SearchBackend(Protocol):
|
|
def search(self, query: str, top_k: int) -> list[SearchHit]:
|
|
raise NotImplementedError
|
|
|
|
|
|
class MockSearchBackend:
|
|
def __init__(self, canned_results: dict[str, list[dict[str, str]]] | None = None):
|
|
self.canned_results = canned_results or {
|
|
"browser rendering markdown": [
|
|
{
|
|
"url": "https://developers.cloudflare.com/browser-rendering/rest-api/markdown-endpoint/",
|
|
"title": "Cloudflare markdown endpoint",
|
|
"snippet": "Extract markdown from a webpage using Cloudflare Browser Rendering.",
|
|
}
|
|
],
|
|
"nanochat gpt2 speedrun": [
|
|
{
|
|
"url": "https://github.com/karpathy/nanochat",
|
|
"title": "karpathy/nanochat",
|
|
"snippet": "Minimal LLM training harness with pretraining, SFT, RL, and chat UI.",
|
|
}
|
|
],
|
|
}
|
|
|
|
def search(self, query: str, top_k: int) -> list[SearchHit]:
|
|
normalized = query.strip().lower()
|
|
rows = self.canned_results.get(normalized, [])
|
|
return [SearchHit(**row) for row in rows[:top_k]]
|
|
|
|
class TavilySearchBackend:
|
|
"""LLM-optimized web search via Tavily. Falls back silently on errors."""
|
|
def __init__(self, api_key: str | None = None, timeout: float = 15.0):
|
|
self.api_key = api_key or os.environ.get('TAVILY_API_KEY')
|
|
if not self.api_key:
|
|
raise ValueError('TavilySearchBackend requires TAVILY_API_KEY')
|
|
self.timeout = timeout
|
|
|
|
def search(self, query: str, top_k: int) -> list[SearchHit]:
|
|
import requests
|
|
try:
|
|
r = requests.post(
|
|
'https://api.tavily.com/search',
|
|
json={
|
|
'api_key': self.api_key,
|
|
'query': query,
|
|
'max_results': max(1, min(int(top_k), 8)),
|
|
'include_answer': True,
|
|
'include_raw_content': False,
|
|
'search_depth': 'advanced',
|
|
},
|
|
timeout=self.timeout,
|
|
)
|
|
r.raise_for_status()
|
|
data = r.json()
|
|
except Exception:
|
|
return []
|
|
direct_answer = (data.get('answer') or '').strip()
|
|
hits: list[SearchHit] = []
|
|
# Surface Tavily's synthesized answer as the first hit so a 1.4B model
|
|
# can parrot a clean, grounded sentence instead of fighting with noisy snippets.
|
|
if direct_answer:
|
|
hits.append(SearchHit(
|
|
url='https://tavily.com/answer',
|
|
title='Tavily direct answer',
|
|
snippet=direct_answer,
|
|
))
|
|
for h in data.get('results', [])[: max(0, top_k - (1 if direct_answer else 0))]:
|
|
hits.append(SearchHit(
|
|
url=h.get('url', ''),
|
|
title=h.get('title', ''),
|
|
snippet=h.get('content', ''),
|
|
))
|
|
return hits
|
|
|
|
|
|
class CloudflareBrowserRenderingClient:
|
|
def __init__(
|
|
self,
|
|
*,
|
|
api_token: str | None = None,
|
|
account_id: str | None = None,
|
|
base_url: str = "https://api.cloudflare.com/client/v4",
|
|
timeout: float = 30.0,
|
|
max_retries: int = 3,
|
|
):
|
|
self.api_token = api_token or os.environ.get("CLOUDFLARE_API_TOKEN")
|
|
self.account_id = account_id or os.environ.get("CLOUDFLARE_ACCOUNT_ID")
|
|
if not self.api_token or not self.account_id:
|
|
raise ValueError("Cloudflare Browser Rendering requires CLOUDFLARE_API_TOKEN and CLOUDFLARE_ACCOUNT_ID")
|
|
self.base_url = base_url.rstrip("/")
|
|
self.timeout = timeout
|
|
self.max_retries = max_retries
|
|
self.session = requests.Session()
|
|
self.session.headers.update(
|
|
{
|
|
"Authorization": f"Bearer {self.api_token}",
|
|
"Content-Type": "application/json",
|
|
}
|
|
)
|
|
|
|
def _post(self, endpoint: str, body: dict[str, Any]) -> Any:
|
|
url = f"{self.base_url}/accounts/{self.account_id}/browser-rendering/{endpoint}"
|
|
last_error = None
|
|
for attempt in range(1, self.max_retries + 1):
|
|
response = self.session.post(url, json=body, timeout=self.timeout)
|
|
if response.status_code == 429:
|
|
retry_after = response.headers.get("Retry-After")
|
|
sleep_seconds = float(retry_after) if retry_after else float(attempt)
|
|
last_error = RuntimeError(f"Cloudflare Browser Rendering rate limited on {endpoint}")
|
|
time.sleep(min(sleep_seconds, 5.0))
|
|
continue
|
|
response.raise_for_status()
|
|
payload = response.json()
|
|
if not payload.get("success", False):
|
|
errors = payload.get("errors", [])
|
|
last_error = RuntimeError(f"Cloudflare Browser Rendering request failed: {errors}")
|
|
break
|
|
return payload.get("result")
|
|
if last_error is not None:
|
|
raise last_error
|
|
raise RuntimeError(f"Cloudflare Browser Rendering request failed for {endpoint}")
|
|
|
|
def markdown(self, url: str, **options: Any) -> str:
|
|
body = {"url": url}
|
|
body.update(options)
|
|
return self._post("markdown", body)
|
|
|
|
def links(self, url: str, **options: Any) -> list[str]:
|
|
body = {"url": url}
|
|
body.update(options)
|
|
return self._post("links", body)
|
|
|
|
def json_extract(self, url: str, *, prompt: str | None = None, schema: dict[str, Any] | None = None, **options: Any) -> dict[str, Any]:
|
|
body: dict[str, Any] = {"url": url}
|
|
if prompt is not None:
|
|
body["prompt"] = prompt
|
|
if schema is not None:
|
|
body["schema"] = schema
|
|
body.update(options)
|
|
return self._post("json", body)
|
|
|
|
|
|
class WebSearchTool(BaseTool):
|
|
name = "web_search"
|
|
|
|
def __init__(
|
|
self,
|
|
*,
|
|
search_backend: SearchBackend | None = None,
|
|
fetch_client: CloudflareBrowserRenderingClient | None = None,
|
|
max_results: int = 3,
|
|
):
|
|
self.search_backend = search_backend
|
|
self.fetch_client = fetch_client
|
|
self.max_results = max_results
|
|
|
|
def run(self, arguments: dict[str, Any]) -> ToolResult:
|
|
query = str(arguments.get("query", "")).strip()
|
|
requested_urls = arguments.get("urls") or []
|
|
if isinstance(requested_urls, str):
|
|
requested_urls = [requested_urls]
|
|
top_k = int(arguments.get("top_k", self.max_results) or self.max_results)
|
|
top_k = max(1, min(top_k, 8))
|
|
|
|
if not query and not requested_urls:
|
|
return ToolResult(tool_name=self.name, success=False, error="Missing query or urls")
|
|
|
|
hits: list[SearchHit]
|
|
if requested_urls:
|
|
hits = [SearchHit(url=str(url)) for url in requested_urls[:top_k]]
|
|
else:
|
|
if self.search_backend is None:
|
|
return ToolResult(
|
|
tool_name=self.name,
|
|
success=False,
|
|
error="No search backend configured. Cloudflare Browser Rendering can fetch pages but does not provide public web search by itself.",
|
|
)
|
|
hits = self.search_backend.search(query, top_k)
|
|
|
|
results = []
|
|
for hit in hits[:top_k]:
|
|
entry: dict[str, Any] = {"url": hit.url}
|
|
if hit.title:
|
|
entry["title"] = hit.title
|
|
if hit.snippet:
|
|
entry["snippet"] = hit.snippet
|
|
if self.fetch_client is not None:
|
|
try:
|
|
markdown = self.fetch_client.markdown(hit.url)
|
|
links = self.fetch_client.links(hit.url)
|
|
entry["markdown"] = markdown[:4000]
|
|
entry["links"] = links[:10]
|
|
except Exception as exc:
|
|
entry["fetch_error"] = str(exc)
|
|
results.append(entry)
|
|
|
|
return ToolResult(
|
|
tool_name=self.name,
|
|
success=True,
|
|
output={"query": query, "results": results},
|
|
metadata={
|
|
"search_backend": type(self.search_backend).__name__ if self.search_backend is not None else None,
|
|
"fetch_backend": type(self.fetch_client).__name__ if self.fetch_client is not None else None,
|
|
"num_results": len(results),
|
|
},
|
|
)
|
|
|
|
|
|
def build_default_tool_registry(
|
|
*,
|
|
cloudflare_token: str | None = None,
|
|
cloudflare_account_id: str | None = None,
|
|
search_backend: SearchBackend | None = None,
|
|
) -> ToolRegistry:
|
|
fetch_client = None
|
|
if cloudflare_token or os.environ.get("CLOUDFLARE_API_TOKEN"):
|
|
try:
|
|
fetch_client = CloudflareBrowserRenderingClient(
|
|
api_token=cloudflare_token,
|
|
account_id=cloudflare_account_id,
|
|
)
|
|
except Exception:
|
|
fetch_client = None
|
|
if search_backend is None:
|
|
if os.environ.get('TAVILY_API_KEY'):
|
|
try:
|
|
search_backend = TavilySearchBackend()
|
|
except Exception:
|
|
search_backend = MockSearchBackend()
|
|
else:
|
|
search_backend = MockSearchBackend()
|
|
registry = ToolRegistry(
|
|
[
|
|
CalculatorTool(),
|
|
WebSearchTool(
|
|
search_backend=search_backend,
|
|
fetch_client=fetch_client,
|
|
),
|
|
]
|
|
)
|
|
return registry
|