diff --git a/docs/pre_gpu_runbook.md b/docs/pre_gpu_runbook.md new file mode 100644 index 00000000..b54f7c3a --- /dev/null +++ b/docs/pre_gpu_runbook.md @@ -0,0 +1,97 @@ +# Pre-GPU Runbook + +This runbook is the minimum operational checklist before spending GPU time. + +## 1. Local Prep + +1. Build the seed tool datasets: + +```bash +python -m scripts.build_tool_datasets +``` + +2. Import the starting checkpoint from Hugging Face into native nanochat format: + +```bash +python -m scripts.import_hf_checkpoint \ + --repo-id ManmohanSharma/nanochat-d24 \ + --model-tag d24_hf_import +``` + +3. Validate tool tokenization and mock tool execution with local tests: + +```bash +python -m pytest tests/test_engine.py tests/test_tools.py -v +``` + +4. Dry-run tool evaluation on CPU: + +```bash +python -m scripts.chat_eval \ + -i sft \ + -a ToolJSON \ + --tool-jsonl seed_data/tool_eval_seed.jsonl \ + --device-type cpu \ + -x 3 +``` + +## 2. 48-Hour GPU Schedule + +1. Pilot CPT + - Run a short continuation test from the imported base checkpoint. + - Confirm loss is moving, checkpoint save works, and HF sync works. + +2. Full CPT + - Run the main continuation stage on ClimbMix backbone. + - Save staged checkpoints at planned intervals. + +3. SFT + - Include the local tool SFT JSONL via `--extra-train-jsonl`. + - Validate that calculator/web_search traces render correctly. + +4. RL / tool tuning + - Keep this stage narrow and short. + - Focus on tool-choice correctness and grounded answers. + +5. Eval + - Run ARC, MMLU, GSM8K, HumanEval, and ToolJSON checks. + - Do not ship if tool behavior regresses or citations are missing. + +## 3. Checkpoint Upload Cadence + +Upload every stage boundary and any explicit resume point: + +```bash +python -m scripts.hf_sync_checkpoint \ + --repo-id ManmohanSharma/nanochat-d24 \ + --source base \ + --model-tag d24_hf_import \ + --step 0 +``` + +If a whole checkpoint directory should be mirrored: + +```bash +python -m scripts.hf_sync_checkpoint \ + --repo-id ManmohanSharma/nanochat-d24 \ + --source base \ + --model-tag d24_hf_import +``` + +## 4. Go / No-Go + +Go only if: + +- HF import works. +- HF sync works. +- Mock tool execution works. +- Tool seed datasets are generated. +- Tool eval runs locally. +- The search backend plan is explicit: search provider plus Cloudflare fetch/crawl. + +No-Go if: + +- Any tokenizer mismatch appears during HF import. +- Tool blocks fail to render. +- `web_search` still has no backend plan beyond fetch-only Cloudflare Browser Rendering. +- Local tool eval is missing or failing. diff --git a/manifests/cpt_recipe_v1.json b/manifests/cpt_recipe_v1.json new file mode 100644 index 00000000..84fe9414 --- /dev/null +++ b/manifests/cpt_recipe_v1.json @@ -0,0 +1,31 @@ +{ + "name": "cpt_recipe_v1", + "goal": "Continue pretraining nanochat-d24 on a safe, high-value mixture for a 48-hour run.", + "backbone_dataset": { + "repo_id": "karpathy/climbmix-400b-shuffle", + "role": "primary", + "notes": "Use ClimbMix as the main pretraining backbone because it is already the best-performing dataset in this fork." + }, + "optional_augmentations": [ + { + "repo_id": "nvidia/Nemotron-Pretraining-Code-v2", + "role": "code", + "notes": "Selective code augmentation only." + }, + { + "repo_id": "nvidia/Nemotron-CC-Math-v1", + "role": "math_science", + "notes": "Selective math/science augmentation only." + }, + { + "repo_id": "nvidia/Nemotron-Pretraining-Specialized-v1.1", + "role": "specialized_web", + "notes": "Use sparingly for targeted specialization." 
+ } + ], + "do_not_do": [ + "Do not replace ClimbMix wholesale.", + "Do not attempt architecture changes during this 48-hour continuation run.", + "Do not rely on continued pretraining for current-events freshness." + ] +} diff --git a/manifests/rl_recipe_v1.json b/manifests/rl_recipe_v1.json new file mode 100644 index 00000000..323ea319 --- /dev/null +++ b/manifests/rl_recipe_v1.json @@ -0,0 +1,26 @@ +{ + "name": "rl_recipe_v1", + "goal": "Lightweight tool-tuning stage after SFT.", + "seed_eval_data": "seed_data/tool_eval_seed.jsonl", + "reward_components": [ + "valid tool call", + "correct tool choice", + "answer contains expected answer fragment", + "citation required when using web_search", + "penalty for unnecessary tool calls" + ], + "external_templates": [ + { + "repo_id": "nvidia/Nemotron-RL-Agentic-Conversational-Tool-Use-Pivot-v1", + "role": "tool_use_structure" + }, + { + "repo_id": "nvidia/Nemotron-RL-Agentic-Function-Calling-Pivot-v1", + "role": "function_call_structure" + } + ], + "notes": [ + "Keep RL narrow; do not turn this into broad RLHF.", + "Use the local tool eval dataset for dry runs before GPU." + ] +} diff --git a/manifests/sft_recipe_v1.json b/manifests/sft_recipe_v1.json new file mode 100644 index 00000000..ac45e050 --- /dev/null +++ b/manifests/sft_recipe_v1.json @@ -0,0 +1,32 @@ +{ + "name": "sft_recipe_v1", + "goal": "Teach instruction following, tool routing, grounded web answers, and calculator usage.", + "base_mixture": [ + { + "source": "existing_nanochat", + "notes": "Keep the existing chat mixture in scripts/chat_sft.py as the base." + }, + { + "source": "local_seed_data", + "path": "seed_data/tool_sft_seed.jsonl", + "notes": "Seed examples for direct answers, calculator, and web_search." + } + ], + "external_templates": [ + { + "repo_id": "nvidia/Nemotron-Instruction-Following-Chat-v1", + "role": "instruction_following" + }, + { + "collection": "nvidia/nemotron-post-training-v3", + "role": "agentic_templates" + } + ], + "target_behaviors": [ + "Answer directly when no tool is needed.", + "Use calculator for exact numeric computation.", + "Use web_search for recency and external facts.", + "Cite URLs when the response depends on retrieved web content.", + "Avoid unnecessary tool calls." + ] +} diff --git a/nanochat/engine.py b/nanochat/engine.py index 4bdfd654..a0c74475 100644 --- a/nanochat/engine.py +++ b/nanochat/engine.py @@ -13,70 +13,10 @@ The whole thing is made as efficient as possible. 
import torch import torch.nn.functional as F -import signal -import warnings -from contextlib import contextmanager from collections import deque from nanochat.common import compute_init, autodetect_device_type, COMPUTE_DTYPE from nanochat.checkpoint_manager import load_model - -# ----------------------------------------------------------------------------- -# Calculator tool helpers -@contextmanager -def timeout(duration, formula): - def timeout_handler(signum, frame): - raise Exception(f"'{formula}': timed out after {duration} seconds") - - signal.signal(signal.SIGALRM, timeout_handler) - signal.alarm(duration) - yield - signal.alarm(0) - -def eval_with_timeout(formula, max_time=3): - try: - with timeout(max_time, formula): - with warnings.catch_warnings(): - warnings.simplefilter("ignore", SyntaxWarning) - return eval(formula, {"__builtins__": {}}, {}) - except Exception as e: - signal.alarm(0) - # print(f"Warning: Failed to eval {formula}, exception: {e}") # it's ok ignore wrong calculator usage - return None - -def use_calculator(expr): - """ - Evaluate a Python expression safely. - Supports both math expressions and string operations like .count() - """ - # Remove commas from numbers - expr = expr.replace(",", "") - - # Check if it's a pure math expression (old behavior) - if all([x in "0123456789*+-/.() " for x in expr]): - if "**" in expr: # disallow power operator - return None - return eval_with_timeout(expr) - - # Check if it's a string operation we support - # Allow: strings (single/double quotes), .count(), letters, numbers, spaces, parens - allowed_chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'\"()._ " - if not all([x in allowed_chars for x in expr]): - return None - - # Disallow dangerous patterns - dangerous_patterns = ['__', 'import', 'exec', 'eval', 'compile', 'open', 'file', - 'input', 'raw_input', 'globals', 'locals', 'vars', 'dir', - 'getattr', 'setattr', 'delattr', 'hasattr'] - expr_lower = expr.lower() - if any(pattern in expr_lower for pattern in dangerous_patterns): - return None - - # Only allow .count() method for now (can expand later) - if '.count(' not in expr: - return None - - # Evaluate with timeout - return eval_with_timeout(expr) +from nanochat.tools import build_default_tool_registry, parse_tool_call_payload # ----------------------------------------------------------------------------- class KVCache: @@ -162,15 +102,16 @@ class RowState: def __init__(self, current_tokens=None): self.current_tokens = current_tokens or [] # Current token sequence for this row self.forced_tokens = deque() # Queue of tokens to force inject - self.in_python_block = False # Whether we are inside a python block - self.python_expr_tokens = [] # Tokens of the current python expression + self.in_tool_block = False # Whether we are inside a tool-call block + self.tool_payload_tokens = [] # Tokens of the current tool-call payload self.completed = False # Whether this row has completed generation class Engine: - def __init__(self, model, tokenizer): + def __init__(self, model, tokenizer, tools=None): self.model = model self.tokenizer = tokenizer # needed for tool use + self.tools = tools or build_default_tool_registry() @torch.inference_mode() def generate(self, tokens, num_samples=1, max_tokens=None, temperature=1.0, top_k=None, seed=42): @@ -255,21 +196,34 @@ class Engine: state.completed = True # Handle tool logic if next_token == python_start: - state.in_python_block = True - state.python_expr_tokens = [] - elif next_token == python_end and 
state.in_python_block: - state.in_python_block = False - if state.python_expr_tokens: - expr = self.tokenizer.decode(state.python_expr_tokens) - result = use_calculator(expr) - if result is not None: - result_tokens = self.tokenizer.encode(str(result)) + state.in_tool_block = True + state.tool_payload_tokens = [] + elif next_token == python_end and state.in_tool_block: + state.in_tool_block = False + if state.tool_payload_tokens: + payload_text = self.tokenizer.decode(state.tool_payload_tokens) + invocation = parse_tool_call_payload(payload_text) + result = self.tools.execute(invocation.tool_name, invocation.arguments) + legacy_calculator_call = ( + invocation.tool_name == "calculator" + and invocation.raw_text.strip() + and not invocation.raw_text.strip().startswith("{") + ) + if legacy_calculator_call: + if result.success and isinstance(result.output, dict) and "value" in result.output: + result_text = str(result.output["value"]) + else: + result_text = result.error or "" + else: + result_text = result.to_payload()[:4096] + if result_text: + result_tokens = self.tokenizer.encode(result_text) state.forced_tokens.append(output_start) state.forced_tokens.extend(result_tokens) state.forced_tokens.append(output_end) - state.python_expr_tokens = [] - elif state.in_python_block: - state.python_expr_tokens.append(next_token) + state.tool_payload_tokens = [] + elif state.in_tool_block: + state.tool_payload_tokens.append(next_token) # Yield the token column yield token_column, token_masks diff --git a/nanochat/tokenizer.py b/nanochat/tokenizer.py index a2146c2e..2b4080d4 100644 --- a/nanochat/tokenizer.py +++ b/nanochat/tokenizer.py @@ -270,6 +270,7 @@ class RustBPETokenizer: - ids: list[int] is a list of token ids of this rendered conversation - mask: list[int] of same length, mask = 1 for tokens that the Assistant is expected to train on. """ + from nanochat.tools import serialize_tool_call, serialize_tool_result # ids, masks that we will return and a helper function to help build them up. ids, mask = [], [] def add_tokens(token_ids, mask_val): @@ -327,14 +328,29 @@ class RustBPETokenizer: if part["type"] == "text": # string part => simply add the tokens add_tokens(value_ids, 1) - elif part["type"] == "python": + elif part["type"] in {"python", "tool_call"}: # python tool call => add the tokens inside <|python_start|> and <|python_end|> + if part["type"] == "tool_call": + payload = serialize_tool_call( + tool_name=part["tool_name"], + arguments=part.get("arguments") or {}, + ) + value_ids = self.encode(payload) add_tokens(python_start, 1) add_tokens(value_ids, 1) add_tokens(python_end, 1) - elif part["type"] == "python_output": + elif part["type"] in {"python_output", "tool_result"}: # python output => add the tokens inside <|output_start|> and <|output_end|> # none of these tokens are supervised because the tokens come from Python at test time + if part["type"] == "tool_result": + payload = serialize_tool_result( + tool_name=part["tool_name"], + output=part.get("output"), + success=part.get("success", True), + error=part.get("error"), + metadata=part.get("metadata"), + ) + value_ids = self.encode(payload) add_tokens(output_start, 0) add_tokens(value_ids, 0) add_tokens(output_end, 0) diff --git a/nanochat/tools.py b/nanochat/tools.py new file mode 100644 index 00000000..df7ef64c --- /dev/null +++ b/nanochat/tools.py @@ -0,0 +1,509 @@ +""" +Shared tool definitions for nanochat. + +The current tokenizer only has python/output special tokens. 
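+(Concretely, a call travels between those tokens as compact JSON, e.g.
+<|python_start|>{"arguments":{"expression":"percent(250,18)"},"tool":"calculator"}<|python_end|>,
+and the tool's result is echoed back between <|output_start|> and <|output_end|>.)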
To preserve +checkpoint compatibility, we reuse those tokens as a generic tool-call and +tool-result channel. Legacy "python" calculator payloads still work. +""" + +from __future__ import annotations + +import ast +import json +import math +import os +import time +from dataclasses import dataclass, field +from typing import Any, Protocol + +import requests + +TOOL_CALL_START = "<|python_start|>" +TOOL_CALL_END = "<|python_end|>" +TOOL_RESULT_START = "<|output_start|>" +TOOL_RESULT_END = "<|output_end|>" +MAX_TOOL_PAYLOAD_CHARS = 4096 + +DEFAULT_TOOL_SCHEMA = [ + { + "name": "calculator", + "description": "Deterministic scientific calculator for exact arithmetic and common finance formulas.", + "arguments": { + "expression": "String expression using numbers, operators, and supported functions.", + }, + }, + { + "name": "web_search", + "description": "Search and fetch web content. Requires a search backend and optionally a page fetch client.", + "arguments": { + "query": "Search query string.", + "top_k": "Maximum number of results to return.", + "urls": "Optional explicit URLs to fetch instead of searching.", + }, + }, +] + + +def _compact_json(data: Any) -> str: + return json.dumps(data, ensure_ascii=True, separators=(",", ":"), sort_keys=True) + + +@dataclass +class ToolInvocation: + tool_name: str + arguments: dict[str, Any] + raw_text: str = "" + + +@dataclass +class ToolResult: + tool_name: str + success: bool + output: Any = None + error: str | None = None + metadata: dict[str, Any] = field(default_factory=dict) + + def to_payload(self) -> str: + return _compact_json( + { + "tool": self.tool_name, + "success": self.success, + "output": self.output, + "error": self.error, + "metadata": self.metadata, + } + ) + + +class BaseTool: + name: str + + def run(self, arguments: dict[str, Any]) -> ToolResult: + raise NotImplementedError + + +class ToolRegistry: + def __init__(self, tools: list[BaseTool] | tuple[BaseTool, ...]): + self._tools = {tool.name: tool for tool in tools} + + def execute(self, tool_name: str, arguments: dict[str, Any]) -> ToolResult: + tool = self._tools.get(tool_name) + if tool is None: + return ToolResult(tool_name=tool_name, success=False, error=f"Unknown tool: {tool_name}") + try: + return tool.run(arguments) + except Exception as exc: # defensive: tool failures should become model-visible outputs + return ToolResult(tool_name=tool_name, success=False, error=str(exc)) + + def schema(self) -> list[dict[str, Any]]: + return [item for item in DEFAULT_TOOL_SCHEMA if item["name"] in self._tools] + + +def serialize_tool_call(tool_name: str, arguments: dict[str, Any] | None = None) -> str: + payload = { + "tool": tool_name, + "arguments": arguments or {}, + } + text = _compact_json(payload) + return text[:MAX_TOOL_PAYLOAD_CHARS] + + +def serialize_tool_result( + tool_name: str, + output: Any = None, + *, + success: bool = True, + error: str | None = None, + metadata: dict[str, Any] | None = None, +) -> str: + return ToolResult( + tool_name=tool_name, + success=success, + output=output, + error=error, + metadata=metadata or {}, + ).to_payload()[:MAX_TOOL_PAYLOAD_CHARS] + + +def parse_tool_call_payload(text: str) -> ToolInvocation: + stripped = text.strip() + if not stripped: + return ToolInvocation(tool_name="calculator", arguments={"expression": ""}, raw_text=text) + try: + payload = json.loads(stripped) + except json.JSONDecodeError: + return ToolInvocation(tool_name="calculator", arguments={"expression": stripped}, raw_text=text) + if isinstance(payload, dict): + 
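+        # tolerate common key aliases ("tool"/"tool_name"/"name", "arguments"/"args")
+        # so slightly off-schema model output still parses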
tool_name = payload.get("tool") or payload.get("tool_name") or payload.get("name") + arguments = payload.get("arguments") or payload.get("args") or {} + if isinstance(tool_name, str) and isinstance(arguments, dict): + return ToolInvocation(tool_name=tool_name, arguments=arguments, raw_text=text) + return ToolInvocation(tool_name="calculator", arguments={"expression": stripped}, raw_text=text) + + +def parse_tool_result_payload(text: str) -> ToolResult | None: + stripped = text.strip() + try: + payload = json.loads(stripped) + except json.JSONDecodeError: + return None + if not isinstance(payload, dict): + return None + tool_name = payload.get("tool") + if not isinstance(tool_name, str): + return None + return ToolResult( + tool_name=tool_name, + success=bool(payload.get("success", True)), + output=payload.get("output"), + error=payload.get("error"), + metadata=payload.get("metadata") or {}, + ) + + +def _percent(value: float, rate: float) -> float: + return value * rate / 100.0 + + +def _percent_change(old: float, new: float) -> float: + if old == 0: + raise ValueError("percent_change old value cannot be zero") + return ((new - old) / old) * 100.0 + + +def _cagr(start: float, end: float, years: float) -> float: + if start <= 0 or end <= 0 or years <= 0: + raise ValueError("cagr inputs must be positive") + return ((end / start) ** (1.0 / years) - 1.0) * 100.0 + + +def _simple_interest(principal: float, annual_rate: float, years: float) -> float: + return principal * annual_rate / 100.0 * years + + +def _compound_interest(principal: float, annual_rate: float, periods_per_year: float, years: float) -> float: + if periods_per_year <= 0: + raise ValueError("periods_per_year must be positive") + return principal * (1.0 + annual_rate / 100.0 / periods_per_year) ** (periods_per_year * years) + + +def _emi(principal: float, annual_rate: float, months: float) -> float: + if months <= 0: + raise ValueError("months must be positive") + monthly_rate = annual_rate / 100.0 / 12.0 + if monthly_rate == 0: + return principal / months + growth = (1.0 + monthly_rate) ** months + return principal * monthly_rate * growth / (growth - 1.0) + + +ALLOWED_BINOPS = { + ast.Add: lambda a, b: a + b, + ast.Sub: lambda a, b: a - b, + ast.Mult: lambda a, b: a * b, + ast.Div: lambda a, b: a / b, + ast.Pow: lambda a, b: a ** b, + ast.Mod: lambda a, b: a % b, +} +ALLOWED_UNARYOPS = { + ast.UAdd: lambda a: a, + ast.USub: lambda a: -a, +} +ALLOWED_NAMES = { + "pi": math.pi, + "e": math.e, + "tau": math.tau, +} +ALLOWED_FUNCTIONS = { + "abs": abs, + "round": round, + "floor": math.floor, + "ceil": math.ceil, + "sqrt": math.sqrt, + "log": math.log, + "log10": math.log10, + "exp": math.exp, + "sin": math.sin, + "cos": math.cos, + "tan": math.tan, + "asin": math.asin, + "acos": math.acos, + "atan": math.atan, + "degrees": math.degrees, + "radians": math.radians, + "percent": _percent, + "percent_change": _percent_change, + "cagr": _cagr, + "simple_interest": _simple_interest, + "compound_interest": _compound_interest, + "emi": _emi, +} + + +class _SafeMathEvaluator: + def __init__(self, expression: str): + self.expression = expression + self.node_count = 0 + + def eval(self) -> float: + if len(self.expression) > 512: + raise ValueError("expression too long") + parsed = ast.parse(self.expression, mode="eval") + return self._visit(parsed.body) + + def _visit(self, node: ast.AST) -> Any: + self.node_count += 1 + if self.node_count > 128: + raise ValueError("expression too complex") + + if isinstance(node, ast.Constant): + if 
isinstance(node.value, (int, float)): + return node.value + raise ValueError(f"unsupported constant: {node.value!r}") + if isinstance(node, ast.Num): # pragma: no cover - py<3.8 compatibility + return node.n + if isinstance(node, ast.BinOp): + op = ALLOWED_BINOPS.get(type(node.op)) + if op is None: + raise ValueError(f"unsupported operator: {type(node.op).__name__}") + return op(self._visit(node.left), self._visit(node.right)) + if isinstance(node, ast.UnaryOp): + op = ALLOWED_UNARYOPS.get(type(node.op)) + if op is None: + raise ValueError(f"unsupported unary operator: {type(node.op).__name__}") + return op(self._visit(node.operand)) + if isinstance(node, ast.Name): + if node.id not in ALLOWED_NAMES: + raise ValueError(f"unknown symbol: {node.id}") + return ALLOWED_NAMES[node.id] + if isinstance(node, ast.Call): + if not isinstance(node.func, ast.Name): + raise ValueError("only direct function calls are allowed") + fn = ALLOWED_FUNCTIONS.get(node.func.id) + if fn is None: + raise ValueError(f"unsupported function: {node.func.id}") + if node.keywords: + raise ValueError("keyword arguments are not supported") + args = [self._visit(arg) for arg in node.args] + return fn(*args) + raise ValueError(f"unsupported expression node: {type(node).__name__}") + + +def _normalize_numeric_output(value: Any) -> Any: + if isinstance(value, float): + if not math.isfinite(value): + raise ValueError("result is not finite") + return float(f"{value:.12g}") + return value + + +class CalculatorTool(BaseTool): + name = "calculator" + + def run(self, arguments: dict[str, Any]) -> ToolResult: + expression = str(arguments.get("expression", "")).strip() + if not expression: + return ToolResult(tool_name=self.name, success=False, error="Missing expression") + value = _SafeMathEvaluator(expression).eval() + return ToolResult( + tool_name=self.name, + success=True, + output={"expression": expression, "value": _normalize_numeric_output(value)}, + ) + + +@dataclass +class SearchHit: + url: str + title: str = "" + snippet: str = "" + + +class SearchBackend(Protocol): + def search(self, query: str, top_k: int) -> list[SearchHit]: + raise NotImplementedError + + +class MockSearchBackend: + def __init__(self, canned_results: dict[str, list[dict[str, str]]] | None = None): + self.canned_results = canned_results or { + "browser rendering markdown": [ + { + "url": "https://developers.cloudflare.com/browser-rendering/rest-api/markdown-endpoint/", + "title": "Cloudflare markdown endpoint", + "snippet": "Extract markdown from a webpage using Cloudflare Browser Rendering.", + } + ], + "nanochat gpt2 speedrun": [ + { + "url": "https://github.com/karpathy/nanochat", + "title": "karpathy/nanochat", + "snippet": "Minimal LLM training harness with pretraining, SFT, RL, and chat UI.", + } + ], + } + + def search(self, query: str, top_k: int) -> list[SearchHit]: + normalized = query.strip().lower() + rows = self.canned_results.get(normalized, []) + return [SearchHit(**row) for row in rows[:top_k]] + + +class CloudflareBrowserRenderingClient: + def __init__( + self, + *, + api_token: str | None = None, + account_id: str | None = None, + base_url: str = "https://api.cloudflare.com/client/v4", + timeout: float = 30.0, + max_retries: int = 3, + ): + self.api_token = api_token or os.environ.get("CLOUDFLARE_API_TOKEN") + self.account_id = account_id or os.environ.get("CLOUDFLARE_ACCOUNT_ID") + if not self.api_token or not self.account_id: + raise ValueError("Cloudflare Browser Rendering requires CLOUDFLARE_API_TOKEN and CLOUDFLARE_ACCOUNT_ID") + 
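+        # a single authenticated requests.Session is reused for every endpoint call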
self.base_url = base_url.rstrip("/") + self.timeout = timeout + self.max_retries = max_retries + self.session = requests.Session() + self.session.headers.update( + { + "Authorization": f"Bearer {self.api_token}", + "Content-Type": "application/json", + } + ) + + def _post(self, endpoint: str, body: dict[str, Any]) -> Any: + url = f"{self.base_url}/accounts/{self.account_id}/browser-rendering/{endpoint}" + last_error = None + for attempt in range(1, self.max_retries + 1): + response = self.session.post(url, json=body, timeout=self.timeout) + if response.status_code == 429: + retry_after = response.headers.get("Retry-After") + sleep_seconds = float(retry_after) if retry_after else float(attempt) + last_error = RuntimeError(f"Cloudflare Browser Rendering rate limited on {endpoint}") + time.sleep(min(sleep_seconds, 5.0)) + continue + response.raise_for_status() + payload = response.json() + if not payload.get("success", False): + errors = payload.get("errors", []) + last_error = RuntimeError(f"Cloudflare Browser Rendering request failed: {errors}") + break + return payload.get("result") + if last_error is not None: + raise last_error + raise RuntimeError(f"Cloudflare Browser Rendering request failed for {endpoint}") + + def markdown(self, url: str, **options: Any) -> str: + body = {"url": url} + body.update(options) + return self._post("markdown", body) + + def links(self, url: str, **options: Any) -> list[str]: + body = {"url": url} + body.update(options) + return self._post("links", body) + + def json_extract(self, url: str, *, prompt: str | None = None, schema: dict[str, Any] | None = None, **options: Any) -> dict[str, Any]: + body: dict[str, Any] = {"url": url} + if prompt is not None: + body["prompt"] = prompt + if schema is not None: + body["schema"] = schema + body.update(options) + return self._post("json", body) + + +class WebSearchTool(BaseTool): + name = "web_search" + + def __init__( + self, + *, + search_backend: SearchBackend | None = None, + fetch_client: CloudflareBrowserRenderingClient | None = None, + max_results: int = 3, + ): + self.search_backend = search_backend + self.fetch_client = fetch_client + self.max_results = max_results + + def run(self, arguments: dict[str, Any]) -> ToolResult: + query = str(arguments.get("query", "")).strip() + requested_urls = arguments.get("urls") or [] + if isinstance(requested_urls, str): + requested_urls = [requested_urls] + top_k = int(arguments.get("top_k", self.max_results) or self.max_results) + top_k = max(1, min(top_k, 8)) + + if not query and not requested_urls: + return ToolResult(tool_name=self.name, success=False, error="Missing query or urls") + + hits: list[SearchHit] + if requested_urls: + hits = [SearchHit(url=str(url)) for url in requested_urls[:top_k]] + else: + if self.search_backend is None: + return ToolResult( + tool_name=self.name, + success=False, + error="No search backend configured. 
Cloudflare Browser Rendering can fetch pages but does not provide public web search by itself.", + ) + hits = self.search_backend.search(query, top_k) + + results = [] + for hit in hits[:top_k]: + entry: dict[str, Any] = {"url": hit.url} + if hit.title: + entry["title"] = hit.title + if hit.snippet: + entry["snippet"] = hit.snippet + if self.fetch_client is not None: + try: + markdown = self.fetch_client.markdown(hit.url) + links = self.fetch_client.links(hit.url) + entry["markdown"] = markdown[:4000] + entry["links"] = links[:10] + except Exception as exc: + entry["fetch_error"] = str(exc) + results.append(entry) + + return ToolResult( + tool_name=self.name, + success=True, + output={"query": query, "results": results}, + metadata={ + "search_backend": type(self.search_backend).__name__ if self.search_backend is not None else None, + "fetch_backend": type(self.fetch_client).__name__ if self.fetch_client is not None else None, + "num_results": len(results), + }, + ) + + +def build_default_tool_registry( + *, + cloudflare_token: str | None = None, + cloudflare_account_id: str | None = None, + search_backend: SearchBackend | None = None, +) -> ToolRegistry: + fetch_client = None + if cloudflare_token or os.environ.get("CLOUDFLARE_API_TOKEN"): + try: + fetch_client = CloudflareBrowserRenderingClient( + api_token=cloudflare_token, + account_id=cloudflare_account_id, + ) + except Exception: + fetch_client = None + registry = ToolRegistry( + [ + CalculatorTool(), + WebSearchTool( + search_backend=search_backend or MockSearchBackend(), + fetch_client=fetch_client, + ), + ] + ) + return registry diff --git a/pyproject.toml b/pyproject.toml index 8b6fd954..4c7fc56c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,6 +7,7 @@ requires-python = ">=3.10" dependencies = [ "datasets>=4.0.0", "fastapi>=0.117.1", + "huggingface_hub>=0.34.4", "ipykernel>=7.1.0", "kernels>=0.11.7", "matplotlib>=3.10.8", diff --git a/scripts/base_train.py b/scripts/base_train.py index 86aa770b..cb8071eb 100644 --- a/scripts/base_train.py +++ b/scripts/base_train.py @@ -33,6 +33,7 @@ from nanochat.checkpoint_manager import save_checkpoint, load_checkpoint from nanochat.loss_eval import evaluate_bpb from nanochat.engine import Engine from nanochat.flash_attention import HAS_FA3 +from nanochat.tools import DEFAULT_TOOL_SCHEMA from scripts.base_eval import evaluate_core print_banner() @@ -481,9 +482,12 @@ while True: optimizer.state_dict(), # optimizer state { # metadata saved as json "step": step, + "stage": "base", "val_bpb": val_bpb, # loss at last step "model_config": model_config_kwargs, "user_config": user_config, # inputs to the training script + "tool_schema": DEFAULT_TOOL_SCHEMA, + "source_hf_repo": os.environ.get("NANOCHAT_SOURCE_HF_REPO"), "device_batch_size": args.device_batch_size, "max_seq_len": args.max_seq_len, "total_batch_size": total_batch_size, diff --git a/scripts/build_tool_datasets.py b/scripts/build_tool_datasets.py new file mode 100644 index 00000000..75ec4e6f --- /dev/null +++ b/scripts/build_tool_datasets.py @@ -0,0 +1,167 @@ +""" +Build small local tool-use datasets for SFT/eval dry runs. + +This is not intended to replace larger curated corpora. It creates schema-valid +seed data so tool routing, tokenization, and evaluation can be tested locally. 
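+
+Usage (the paths below are the argparse defaults):
+
+    python -m scripts.build_tool_datasets
+    python -m scripts.build_tool_datasets --train-out seed_data/tool_sft_seed.jsonl --eval-out seed_data/tool_eval_seed.jsonl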
+""" + +import argparse +import json +import os + + +def write_jsonl(path, rows): + os.makedirs(os.path.dirname(path), exist_ok=True) + with open(path, "w", encoding="utf-8") as f: + for row in rows: + f.write(json.dumps(row, ensure_ascii=True) + "\n") + + +def calculator_conversation(user_text, expression, value, final_text): + return [ + {"role": "user", "content": user_text}, + { + "role": "assistant", + "content": [ + {"type": "text", "text": "Let me compute that exactly. "}, + {"type": "tool_call", "tool_name": "calculator", "arguments": {"expression": expression}}, + {"type": "tool_result", "tool_name": "calculator", "output": {"expression": expression, "value": value}}, + {"type": "text", "text": final_text}, + ], + }, + ] + + +def web_search_conversation(user_text, query, url, final_text): + return [ + {"role": "user", "content": user_text}, + { + "role": "assistant", + "content": [ + {"type": "text", "text": "I should verify this with a fresh web lookup. "}, + {"type": "tool_call", "tool_name": "web_search", "arguments": {"query": query, "top_k": 1}}, + { + "type": "tool_result", + "tool_name": "web_search", + "output": { + "query": query, + "results": [ + { + "url": url, + "title": "Source result", + "snippet": "Fresh page content retrieved for grounding.", + } + ], + }, + }, + {"type": "text", "text": final_text}, + ], + }, + ] + + +def direct_conversation(user_text, answer_text): + return [ + {"role": "user", "content": user_text}, + {"role": "assistant", "content": answer_text}, + ] + + +def build_train_rows(): + return [ + direct_conversation( + "Explain what gradient descent is in one short paragraph.", + "Gradient descent is an optimization method that repeatedly nudges model parameters in the direction that reduces error, using the gradient to decide both direction and magnitude of each update.", + ), + calculator_conversation( + "What is 18% of 250?", + "percent(250,18)", + 45.0, + "18% of 250 is 45.", + ), + calculator_conversation( + "If revenue grew from 120 to 150, what is the percent change?", + "percent_change(120,150)", + 25.0, + "The percent change from 120 to 150 is 25%.", + ), + calculator_conversation( + "What is the monthly EMI for a 500000 loan at 8% annual interest over 240 months?", + "emi(500000,8,240)", + 4182.198594391402, + "The monthly EMI is about 4182.2.", + ), + web_search_conversation( + "Find the official nanochat repository and give me the link.", + "official nanochat repository", + "https://github.com/karpathy/nanochat", + "The official nanochat repository is https://github.com/karpathy/nanochat", + ), + web_search_conversation( + "What do the Cloudflare Browser Rendering docs say about the markdown endpoint?", + "Cloudflare Browser Rendering markdown endpoint docs", + "https://developers.cloudflare.com/browser-rendering/rest-api/markdown-endpoint/", + "Cloudflare documents a markdown endpoint for extracting rendered page content: https://developers.cloudflare.com/browser-rendering/rest-api/markdown-endpoint/", + ), + ] + + +def build_eval_rows(): + return [ + { + "conversation": { + "messages": [ + {"role": "user", "content": "What is 12% of 250?"}, + {"role": "assistant", "content": "45"}, + ] + }, + "checks": { + "must_call": "calculator", + "must_not_call": ["web_search"], + "answer_contains": ["45"], + }, + }, + { + "conversation": { + "messages": [ + {"role": "user", "content": "Find the official nanochat repository and cite the link."}, + {"role": "assistant", "content": "https://github.com/karpathy/nanochat"}, + ] + }, + "checks": { + 
"must_call": "web_search", + "citation_required": True, + "answer_contains": ["github.com/karpathy/nanochat"], + }, + }, + { + "conversation": { + "messages": [ + {"role": "user", "content": "Explain what overfitting means in one sentence."}, + {"role": "assistant", "content": "Overfitting means a model memorizes training patterns too closely and generalizes poorly to new data."}, + ] + }, + "checks": { + "must_not_call": ["web_search"], + "answer_contains": ["generalizes poorly"], + }, + }, + ] + + +def main(): + parser = argparse.ArgumentParser(description="Build seed tool-use datasets") + parser.add_argument("--train-out", default="seed_data/tool_sft_seed.jsonl", help="Output JSONL for SFT conversations") + parser.add_argument("--eval-out", default="seed_data/tool_eval_seed.jsonl", help="Output JSONL for eval/reward objects") + args = parser.parse_args() + + train_rows = build_train_rows() + eval_rows = build_eval_rows() + write_jsonl(args.train_out, train_rows) + write_jsonl(args.eval_out, eval_rows) + print(f"Wrote {len(train_rows)} SFT rows to {args.train_out}") + print(f"Wrote {len(eval_rows)} eval rows to {args.eval_out}") + + +if __name__ == "__main__": + main() diff --git a/scripts/chat_eval.py b/scripts/chat_eval.py index 858d4c29..2646d4d4 100644 --- a/scripts/chat_eval.py +++ b/scripts/chat_eval.py @@ -22,6 +22,7 @@ from tasks.mmlu import MMLU from tasks.arc import ARC from tasks.gsm8k import GSM8K from tasks.spellingbee import SpellingBee +from tasks.tool_json import ToolJSON # ----------------------------------------------------------------------------- # Generative evaluation loop (we go one problem at a time, sample, evaluate) @@ -156,8 +157,10 @@ def run_categorical_eval(task_object, tokenizer, model, batch_size, max_problems def run_chat_eval(task_name, model, tokenizer, engine, batch_size=1, num_samples=1, max_new_tokens=512, temperature=0.0, top_k=50, - max_problems=None): + max_problems=None, tool_jsonl=None): # Create the evaluation object + if task_name == "ToolJSON" and not tool_jsonl: + raise ValueError("ToolJSON evaluation requires --tool-jsonl") task_module = { 'HumanEval': HumanEval, 'MMLU': partial(MMLU, subset="all", split="test"), @@ -165,6 +168,7 @@ def run_chat_eval(task_name, model, tokenizer, engine, 'ARC-Challenge': partial(ARC, subset="ARC-Challenge", split="test"), 'GSM8K': partial(GSM8K, subset="main", split="test"), 'SpellingBee': partial(SpellingBee, size=256, split="test"), + 'ToolJSON': partial(ToolJSON, filepath=tool_jsonl), }[task_name] task_object = task_module() # Run the evaluation @@ -191,6 +195,7 @@ if __name__ == "__main__": parser.add_argument('-g', '--model-tag', type=str, default=None, help='Model tag to load') parser.add_argument('-s', '--step', type=int, default=None, help='Step to load') parser.add_argument('-x', '--max-problems', type=int, default=None, help='Max problems to evaluate') + parser.add_argument('--tool-jsonl', type=str, default=None, help='Tool JSONL dataset path used when task-name includes ToolJSON') parser.add_argument('--device-type', type=str, default='', choices=['cuda', 'cpu', 'mps'], help='Device type for evaluation: cuda|cpu|mps. 
empty => autodetect') args = parser.parse_args() @@ -224,6 +229,7 @@ if __name__ == "__main__": temperature=args.temperature, top_k=args.top_k, max_problems=args.max_problems, + tool_jsonl=args.tool_jsonl, ) results[task_name] = acc print0(f"{task_name} accuracy: {100 * acc:.2f}%") diff --git a/scripts/chat_rl.py b/scripts/chat_rl.py index cb2cb0e0..b0909083 100644 --- a/scripts/chat_rl.py +++ b/scripts/chat_rl.py @@ -25,6 +25,7 @@ import torch.distributed as dist from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, DummyWandb, autodetect_device_type from nanochat.checkpoint_manager import save_checkpoint, load_model from nanochat.engine import Engine +from nanochat.tools import DEFAULT_TOOL_SCHEMA from tasks.gsm8k import GSM8K # ----------------------------------------------------------------------------- @@ -317,7 +318,12 @@ for step in range(num_steps): model.state_dict(), None, # note: we don't bother to save the optimizer state { + "step": step, + "stage": "rl", "model_config": model_config_kwargs, + "user_config": user_config, + "tool_schema": DEFAULT_TOOL_SCHEMA, + "source_hf_repo": meta.get("source_hf_repo"), } ) print(f"✅ Saved model checkpoint to {checkpoint_dir}") diff --git a/scripts/chat_sft.py b/scripts/chat_sft.py index c1adbb69..fecc208b 100644 --- a/scripts/chat_sft.py +++ b/scripts/chat_sft.py @@ -31,6 +31,7 @@ from tasks.mmlu import MMLU from tasks.smoltalk import SmolTalk from tasks.customjson import CustomJSON from tasks.spellingbee import SimpleSpelling, SpellingBee +from nanochat.tools import DEFAULT_TOOL_SCHEMA # ----------------------------------------------------------------------------- # CLI arguments @@ -66,6 +67,8 @@ parser.add_argument("--chatcore-max-sample", type=int, default=24, help="max pro # Data mixture parser.add_argument("--mmlu-epochs", type=int, default=3, help="number of epochs of MMLU in training mixture (teaches Multiple Choice)") parser.add_argument("--gsm8k-epochs", type=int, default=4, help="number of epochs of GSM8K in training mixture (teaches Math and Tool Use)") +parser.add_argument("--extra-train-jsonl", type=str, default="", help="comma-separated JSONL conversation files to append to the SFT train mixture") +parser.add_argument("--extra-val-jsonl", type=str, default="", help="comma-separated JSONL conversation files to append to the SFT val mixture") args = parser.parse_args() user_config = vars(args).copy() # ----------------------------------------------------------------------------- @@ -162,6 +165,8 @@ for group in optimizer.param_groups: # SFT data mixture and DataLoader identity_conversations_filepath = os.path.join(base_dir, "identity_conversations.jsonl") +extra_train_jsonl = [path for path in args.extra_train_jsonl.split(",") if path] +extra_val_jsonl = [path for path in args.extra_val_jsonl.split(",") if path] train_tasks = [ SmolTalk(split="train"), # 460K rows of general conversations CustomJSON(filepath=identity_conversations_filepath), # 1000 rows of synthetic identity conversations @@ -171,13 +176,17 @@ train_tasks = [ SimpleSpelling(size=200000, split="train"), # 200K rows of Simple Spelling (e.g. spell the word 'apple') SpellingBee(size=80000, split="train"), # 80K rows of Spelling Bee (e.g. how many 'r' are in 'strawberry'?) 
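+    # extra local JSONL conversations (e.g. seed_data/tool_sft_seed.jsonl) are appended just below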
] +train_tasks.extend(CustomJSON(filepath=path) for path in extra_train_jsonl) train_dataset = TaskMixture(train_tasks) -print0(f"Training mixture: {len(train_dataset):,} rows (MMLU x{args.mmlu_epochs}, GSM8K x{args.gsm8k_epochs})") +print0( + f"Training mixture: {len(train_dataset):,} rows " + f"(MMLU x{args.mmlu_epochs}, GSM8K x{args.gsm8k_epochs}, extra_jsonl={len(extra_train_jsonl)})" +) val_dataset = TaskMixture([ SmolTalk(split="test"), # 24K rows in test set MMLU(subset="all", split="test", stop=5200), # 14K rows in test set, use only 5.2K to match the train ratios GSM8K(subset="main", split="test", stop=420), # 1.32K rows in test set, use only 420 to match the train ratios -]) # total: 24K + 14K + 1.32K ~= 39K rows +] + [CustomJSON(filepath=path) for path in extra_val_jsonl]) # total: 24K + 14K + 1.32K ~= 39K rows plus custom validation # DataLoader is defined here, it emits inputs, targets : 2D tensors of shape (device_batch_size, max_seq_len) # A big problem is that we don't know the final num_iterations in advance. So we create # these two global variables and update them from within the data generator. @@ -406,6 +415,7 @@ while True: optimizer.state_dict(), { "step": step, + "stage": "sft", "val_bpb": val_bpb, # loss at last step "model_config": { "sequence_len": args.max_seq_len, @@ -417,6 +427,10 @@ while True: "window_pattern": model.config.window_pattern, }, "user_config": user_config, # inputs to the training script + "tool_schema": DEFAULT_TOOL_SCHEMA, + "source_hf_repo": meta.get("source_hf_repo"), + "extra_train_jsonl": extra_train_jsonl, + "extra_val_jsonl": extra_val_jsonl, }, rank=ddp_rank, ) diff --git a/scripts/chat_tool_rl.py b/scripts/chat_tool_rl.py new file mode 100644 index 00000000..48995c6e --- /dev/null +++ b/scripts/chat_tool_rl.py @@ -0,0 +1,219 @@ +""" +Lightweight RL stage for tool-use tuning on local ToolJSON datasets. 
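+
+Each step samples several completions per ToolJSON example, scores them with the
+task's reward (tool choice, citations, expected answer fragments), and applies a
+REINFORCE-style policy-gradient update with mean-baselined advantages.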
+ +1 GPU: +python -m scripts.chat_tool_rl --train-jsonl seed_data/tool_eval_seed.jsonl --eval-jsonl seed_data/tool_eval_seed.jsonl + +8 GPUs: +torchrun --standalone --nproc_per_node=8 -m scripts.chat_tool_rl -- --train-jsonl seed_data/tool_eval_seed.jsonl --eval-jsonl seed_data/tool_eval_seed.jsonl +""" + +import argparse +import itertools +import os + +import torch +import torch.distributed as dist +import wandb + +from nanochat.checkpoint_manager import load_model, save_checkpoint +from nanochat.common import DummyWandb, autodetect_device_type, compute_cleanup, compute_init, get_base_dir, print0 +from nanochat.engine import Engine +from nanochat.tools import DEFAULT_TOOL_SCHEMA +from tasks.tool_json import ToolJSON + + +parser = argparse.ArgumentParser(description="RL tuning on local tool-use JSONL tasks") +parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)") +parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)") +parser.add_argument("--model-tag", type=str, default=None, help="SFT model tag to load from") +parser.add_argument("--model-step", type=int, default=None, help="SFT model step to load from") +parser.add_argument("--train-jsonl", type=str, required=True, help="Training ToolJSONL file") +parser.add_argument("--eval-jsonl", type=str, default=None, help="Evaluation ToolJSONL file") +parser.add_argument("--num-epochs", type=int, default=1, help="Number of epochs over ToolJSON") +parser.add_argument("--device-batch-size", type=int, default=8, help="Max batch size per forward pass") +parser.add_argument("--examples-per-step", type=int, default=8, help="Examples per optimization step across all ranks") +parser.add_argument("--num-samples", type=int, default=8, help="Number of samples per example") +parser.add_argument("--max-new-tokens", type=int, default=256, help="Max tokens to generate") +parser.add_argument("--temperature", type=float, default=0.7, help="Sampling temperature") +parser.add_argument("--top-k", type=int, default=50, help="Top-k sampling (0 disables)") +parser.add_argument("--embedding-lr", type=float, default=0.2, help="Embedding LR") +parser.add_argument("--unembedding-lr", type=float, default=0.004, help="Unembedding LR") +parser.add_argument("--matrix-lr", type=float, default=0.02, help="Matrix LR") +parser.add_argument("--weight-decay", type=float, default=0.0, help="Weight decay") +parser.add_argument("--init-lr-frac", type=float, default=0.05, help="Initial LR multiplier") +parser.add_argument("--eval-every", type=int, default=40, help="Evaluate every N steps") +parser.add_argument("--eval-examples", type=int, default=64, help="Maximum eval examples") +parser.add_argument("--save-every", type=int, default=40, help="Save every N steps") +args = parser.parse_args() +user_config = vars(args).copy() + + +device_type = autodetect_device_type() if args.device_type == "" else args.device_type +ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type) +master_process = ddp_rank == 0 + +use_dummy_wandb = args.run == "dummy" or not master_process +wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-tool-rl", name=args.run, config=user_config) + +model, tokenizer, meta = load_model("sft", device, phase="eval", model_tag=args.model_tag, step=args.model_step) +engine = Engine(model, tokenizer) + +train_task = ToolJSON(filepath=args.train_jsonl) +eval_task = ToolJSON(filepath=args.eval_jsonl or args.train_jsonl) +num_steps 
= (len(train_task) // args.examples_per_step) * args.num_epochs +print0(f"Calculated number of steps: {num_steps}") + + +@torch.no_grad() +def get_batch(): + assistant_end = tokenizer.encode_special("<|assistant_end|>") + rank_indices = range(ddp_rank, len(train_task), ddp_world_size) + for example_idx in itertools.cycle(rank_indices): + conversation = train_task[example_idx] + tokens = tokenizer.render_for_completion(conversation) + prefix_length = len(tokens) + model.eval() + generated_token_sequences = [] + masks = [] + num_sampling_steps = args.num_samples // args.device_batch_size + if args.num_samples % args.device_batch_size != 0: + raise ValueError("num_samples must be divisible by device_batch_size") + for sampling_step in range(num_sampling_steps): + seed = hash((step, example_idx, sampling_step)) & 0x7FFFFFFF + generated_batch, mask_batch = engine.generate_batch( + tokens, + num_samples=args.device_batch_size, + max_tokens=args.max_new_tokens, + temperature=args.temperature, + top_k=args.top_k, + seed=seed, + ) + generated_token_sequences.extend(generated_batch) + masks.extend(mask_batch) + + rewards = [] + for sample_tokens in generated_token_sequences: + generated_text = tokenizer.decode(sample_tokens[prefix_length:]) + rewards.append(train_task.reward(conversation, generated_text)) + + max_length = max(len(seq) for seq in generated_token_sequences) + padded_sequences = [seq + [assistant_end] * (max_length - len(seq)) for seq in generated_token_sequences] + padded_masks = [mask + [0] * (max_length - len(mask)) for mask in masks] + ids = torch.tensor(padded_sequences, dtype=torch.long, device=device) + mask_ids = torch.tensor(padded_masks, dtype=torch.long, device=device) + inputs = ids[:, :-1] + targets = ids[:, 1:].clone() + targets[mask_ids[:, 1:] == 0] = -1 + rewards_tensor = torch.tensor(rewards, dtype=torch.float, device=device) + advantages = rewards_tensor - rewards_tensor.mean() + yield generated_token_sequences, inputs, targets, rewards_tensor, advantages + + +@torch.no_grad() +def run_tool_eval(task, max_examples): + max_examples = min(max_examples, len(task)) + total = 0 + reward_sum = 0.0 + passed = 0 + for idx in range(ddp_rank, max_examples, ddp_world_size): + conversation = task[idx] + tokens = tokenizer.render_for_completion(conversation) + generated_sequences, _ = engine.generate_batch(tokens, num_samples=1, max_tokens=args.max_new_tokens, temperature=0.0, top_k=args.top_k) + generated_text = tokenizer.decode(generated_sequences[0][len(tokens):]) + reward_sum += task.reward(conversation, generated_text) + passed += task.evaluate(conversation, generated_text) + total += 1 + + reward_tensor = torch.tensor([reward_sum], dtype=torch.float, device=device) + passed_tensor = torch.tensor([passed], dtype=torch.long, device=device) + total_tensor = torch.tensor([total], dtype=torch.long, device=device) + if ddp: + dist.all_reduce(reward_tensor, op=dist.ReduceOp.SUM) + dist.all_reduce(passed_tensor, op=dist.ReduceOp.SUM) + dist.all_reduce(total_tensor, op=dist.ReduceOp.SUM) + total = max(total_tensor.item(), 1) + return reward_tensor.item() / total, passed_tensor.item() / total + + +optimizer = model.setup_optimizer( + unembedding_lr=args.unembedding_lr, + embedding_lr=args.embedding_lr, + matrix_lr=args.matrix_lr, + weight_decay=args.weight_decay, +) +for group in optimizer.param_groups: + group["lr"] = group["lr"] * args.init_lr_frac + group["initial_lr"] = group["lr"] + + +def get_lr_multiplier(it): + return 1.0 - it / max(num_steps, 1) + + +assert 
args.examples_per_step % ddp_world_size == 0, "examples_per_step must be divisible by number of ranks" +examples_per_rank = args.examples_per_step // ddp_world_size +batch_iterator = get_batch() + +for step in range(num_steps): + if step % args.eval_every == 0: + model.eval() + mean_reward, pass_rate = run_tool_eval(eval_task, args.eval_examples) + print0(f"Step {step} | tool_eval_reward={mean_reward:.4f} | tool_eval_pass={pass_rate:.4f}") + wandb_run.log({"step": step, "tool_eval_reward": mean_reward, "tool_eval_pass": pass_rate}) + + rewards_list = [] + for example_step in range(examples_per_rank): + _, inputs_all, targets_all, rewards_all, advantages_all = next(batch_iterator) + model.train() + assert inputs_all.size(0) % args.device_batch_size == 0 + num_passes = inputs_all.size(0) // args.device_batch_size + for pass_idx in range(num_passes): + b0, b1 = pass_idx * args.device_batch_size, (pass_idx + 1) * args.device_batch_size + inputs = inputs_all[b0:b1] + targets = targets_all[b0:b1] + rewards = rewards_all[b0:b1] + advantages = advantages_all[b0:b1] + logp = -model(inputs, targets, loss_reduction="none").view_as(inputs) + pg_obj = (logp * advantages.unsqueeze(-1)).sum() + num_valid = (targets >= 0).sum().clamp(min=1) + pg_obj = pg_obj / (num_valid * num_passes * examples_per_rank) + loss = -pg_obj + loss.backward() + print0( + f"Step {step}/{num_steps} | Example step {example_step} | Pass {pass_idx} " + f"| loss={loss.item():.6f} | reward={rewards.mean().item():.4f}" + ) + rewards_list.append(rewards_all.mean().item()) + + lrm = get_lr_multiplier(step) + for group in optimizer.param_groups: + group["lr"] = group["initial_lr"] * lrm + optimizer.step() + model.zero_grad(set_to_none=True) + wandb_run.log({"step": step, "lrm": lrm, "mean_reward": sum(rewards_list) / max(len(rewards_list), 1)}) + + if master_process and ((step > 0 and step % args.save_every == 0) or step == num_steps - 1): + base_dir = get_base_dir() + output_dirname = args.model_tag if args.model_tag else f"d{model.config.n_layer}" + checkpoint_dir = os.path.join(base_dir, "chatrl_checkpoints", output_dirname) + save_checkpoint( + checkpoint_dir, + step, + model.state_dict(), + None, + { + "step": step, + "stage": "tool_rl", + "model_config": model.config.__dict__, + "user_config": user_config, + "tool_schema": DEFAULT_TOOL_SCHEMA, + "source_hf_repo": meta.get("source_hf_repo"), + "train_jsonl": args.train_jsonl, + "eval_jsonl": args.eval_jsonl, + }, + ) + print0(f"Saved tool RL checkpoint to {checkpoint_dir}") + +wandb_run.finish() +compute_cleanup() diff --git a/scripts/hf_sync_checkpoint.py b/scripts/hf_sync_checkpoint.py new file mode 100644 index 00000000..2f0aac7b --- /dev/null +++ b/scripts/hf_sync_checkpoint.py @@ -0,0 +1,86 @@ +""" +Upload native nanochat checkpoints and related artifacts to a Hugging Face repo. 
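+
+Files land under <repo-subdir>/<source>/<model-tag> in the destination repo
+(default subdir: native_checkpoints). The token is read from HF_TOKEN, or from
+the env var named by --token-env.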
+ +Examples: +python -m scripts.hf_sync_checkpoint --repo-id ManmohanSharma/nanochat-d24 --source base --model-tag d24_hf_import +python -m scripts.hf_sync_checkpoint --repo-id ManmohanSharma/nanochat-d24 --source base --model-tag d24_hf_import --step 0 +""" + +import argparse +import os + +from huggingface_hub import HfApi + +from nanochat.common import get_base_dir + + +def resolve_checkpoint_dir(source, model_tag): + phase_dir = { + "base": "base_checkpoints", + "sft": "chatsft_checkpoints", + "rl": "chatrl_checkpoints", + }[source] + return os.path.join(get_base_dir(), phase_dir, model_tag) + + +def main(): + parser = argparse.ArgumentParser(description="Upload native nanochat checkpoints to Hugging Face") + parser.add_argument("--repo-id", required=True, help="Destination HF model repo") + parser.add_argument("--source", choices=["base", "sft", "rl"], required=True, help="Checkpoint phase") + parser.add_argument("--model-tag", required=True, help="Local nanochat model tag") + parser.add_argument("--step", type=int, default=None, help="Optional specific step to upload") + parser.add_argument("--token-env", default="HF_TOKEN", help="Environment variable containing the HF token") + parser.add_argument("--private", type=int, default=0, help="Create the repo as private if it does not exist") + parser.add_argument("--repo-subdir", default="native_checkpoints", help="Subdirectory inside the repo") + args = parser.parse_args() + + token = os.environ.get(args.token_env) + if not token: + raise ValueError(f"Missing Hugging Face token in {args.token_env}") + + checkpoint_dir = resolve_checkpoint_dir(args.source, args.model_tag) + if not os.path.isdir(checkpoint_dir): + raise FileNotFoundError(f"Checkpoint directory not found: {checkpoint_dir}") + + api = HfApi(token=token) + api.create_repo(repo_id=args.repo_id, repo_type="model", private=bool(args.private), exist_ok=True) + + if args.step is None: + path_in_repo = f"{args.repo_subdir}/{args.source}/{args.model_tag}" + api.upload_folder( + folder_path=checkpoint_dir, + repo_id=args.repo_id, + repo_type="model", + path_in_repo=path_in_repo, + commit_message=f"Upload native {args.source} checkpoint folder for {args.model_tag}", + ) + print(f"Uploaded {checkpoint_dir} to {args.repo_id}:{path_in_repo}") + return + + step_str = f"{args.step:06d}" + files = [ + f"model_{step_str}.pt", + f"meta_{step_str}.json", + ] + optimizer_pattern = f"optim_{step_str}_" + for filename in sorted(os.listdir(checkpoint_dir)): + if filename.startswith(optimizer_pattern) and filename.endswith(".pt"): + files.append(filename) + + for filename in files: + local_path = os.path.join(checkpoint_dir, filename) + if not os.path.exists(local_path): + continue + path_in_repo = f"{args.repo_subdir}/{args.source}/{args.model_tag}/{filename}" + api.upload_file( + path_or_fileobj=local_path, + repo_id=args.repo_id, + repo_type="model", + path_in_repo=path_in_repo, + commit_message=f"Upload {args.source} checkpoint {args.model_tag} step {step_str}", + ) + print(f"Uploaded step {step_str} for {args.model_tag} to {args.repo_id}") + + +if __name__ == "__main__": + main() diff --git a/scripts/import_hf_checkpoint.py b/scripts/import_hf_checkpoint.py new file mode 100644 index 00000000..9f1d975f --- /dev/null +++ b/scripts/import_hf_checkpoint.py @@ -0,0 +1,162 @@ +""" +Import a Hugging Face model repo into nanochat's native checkpoint format. + +This is intended for base-model continuation before multi-stage nanochat runs. 
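+
+The snapshot's config, tokenizer, and weights are checked against the local
+nanochat GPTConfig and tokenizer before anything is written; a vocab-size or
+state-dict key mismatch aborts the import.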
+
+Examples:
+python -m scripts.import_hf_checkpoint --repo-id ManmohanSharma/nanochat-d24 --model-tag d24_hf_import
+python -m scripts.import_hf_checkpoint --repo-id ManmohanSharma/nanochat-d24 --local-dir /path/to/snapshot
+"""
+
+import argparse
+import os
+from dataclasses import asdict
+
+import torch
+from huggingface_hub import snapshot_download
+from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
+
+from nanochat.checkpoint_manager import save_checkpoint
+from nanochat.common import get_base_dir
+from nanochat.gpt import GPT, GPTConfig
+from nanochat.tokenizer import get_tokenizer
+from nanochat.tools import DEFAULT_TOOL_SCHEMA
+
+
+def normalize_hf_state_dict_keys(state_dict):
+    normalized = {}
+    prefixes = ("_orig_mod.", "module.", "model.")
+    for key, value in state_dict.items():
+        normalized_key = key
+        for prefix in prefixes:
+            if normalized_key.startswith(prefix):
+                normalized_key = normalized_key[len(prefix):]
+        normalized[normalized_key] = value
+    return normalized
+
+
+def infer_gpt_config(hf_config):
+    # every fallback getattr carries an explicit default so a missing field
+    # shows up in `missing` below instead of raising AttributeError early
+    kwargs = {
+        "sequence_len": getattr(
+            hf_config,
+            "sequence_len",
+            getattr(hf_config, "max_position_embeddings", getattr(hf_config, "n_positions", 2048)),
+        ),
+        "vocab_size": getattr(hf_config, "vocab_size", None),
+        "n_layer": getattr(
+            hf_config,
+            "n_layer",
+            getattr(hf_config, "num_hidden_layers", getattr(hf_config, "num_layers", None)),
+        ),
+        "n_head": getattr(
+            hf_config,
+            "n_head",
+            getattr(hf_config, "num_attention_heads", None),
+        ),
+        "n_kv_head": getattr(
+            hf_config,
+            "n_kv_head",
+            getattr(hf_config, "num_key_value_heads", getattr(hf_config, "num_attention_heads", None)),
+        ),
+        "n_embd": getattr(
+            hf_config,
+            "n_embd",
+            getattr(hf_config, "hidden_size", getattr(hf_config, "d_model", None)),
+        ),
+        "window_pattern": getattr(hf_config, "window_pattern", "L"),
+    }
+    missing = [key for key, value in kwargs.items() if value is None]
+    if missing:
+        raise ValueError(f"Could not infer nanochat GPTConfig fields from HF config: {missing}")
+    return GPTConfig(**kwargs)
+
+
+def verify_tokenizer_compatibility(hf_tokenizer, nanochat_tokenizer):
+    hf_vocab = hf_tokenizer.vocab_size
+    local_vocab = nanochat_tokenizer.get_vocab_size()
+    if hf_vocab != local_vocab:
+        raise ValueError(
+            f"Tokenizer vocab mismatch: HF repo has vocab_size={hf_vocab}, "
+            f"local nanochat tokenizer has vocab_size={local_vocab}"
+        )
+
+
+def load_hf_snapshot(repo_id, revision, token, local_dir):
+    if local_dir is not None:
+        return local_dir
+    return snapshot_download(
+        repo_id=repo_id,
+        revision=revision,
+        token=token,
+        repo_type="model",
+    )
+
+
+def main():
+    parser = argparse.ArgumentParser(description="Import HF repo into native nanochat checkpoints")
+    parser.add_argument("--repo-id", required=True, help="HF model repo id, e.g. ManmohanSharma/nanochat-d24")
+    parser.add_argument("--revision", default=None, help="Optional HF revision")
+    parser.add_argument("--local-dir", default=None, help="Use an already-downloaded HF snapshot instead of downloading")
+    parser.add_argument("--token-env", default="HF_TOKEN", help="Environment variable containing the HF token")
+    parser.add_argument("--model-tag", default=None, help="Destination model tag.
Defaults to repo name slug") + parser.add_argument("--step", type=int, default=0, help="Checkpoint step number to write") + parser.add_argument("--source", choices=["base", "sft", "rl"], default="base", help="Destination checkpoint phase") + parser.add_argument("--trust-remote-code", type=int, default=1, help="Pass trust_remote_code to Transformers loaders") + args = parser.parse_args() + + token = os.environ.get(args.token_env) + snapshot_path = load_hf_snapshot(args.repo_id, args.revision, token, args.local_dir) + trust_remote_code = bool(args.trust_remote_code) + + hf_config = AutoConfig.from_pretrained(snapshot_path, token=token, trust_remote_code=trust_remote_code) + hf_tokenizer = AutoTokenizer.from_pretrained(snapshot_path, token=token, trust_remote_code=trust_remote_code) + nanochat_tokenizer = get_tokenizer() + verify_tokenizer_compatibility(hf_tokenizer, nanochat_tokenizer) + + local_config = infer_gpt_config(hf_config) + with torch.no_grad(): + model = AutoModelForCausalLM.from_pretrained( + snapshot_path, + token=token, + trust_remote_code=trust_remote_code, + ) + hf_state_dict = normalize_hf_state_dict_keys(model.state_dict()) + + with torch.device("meta"): + local_model = GPT(local_config) + expected_keys = set(local_model.state_dict().keys()) + provided_keys = set(hf_state_dict.keys()) + missing = sorted(expected_keys - provided_keys) + extra = sorted(provided_keys - expected_keys) + if missing or extra: + message = [ + "HF checkpoint keys do not match native nanochat keys after normalization.", + f"Missing keys: {missing[:12]}", + f"Extra keys: {extra[:12]}", + ] + raise ValueError("\n".join(message)) + + model_data = {key: value.detach().cpu() for key, value in hf_state_dict.items()} + meta_data = { + "model_config": asdict(local_config), + "imported_from_hf": True, + "source_hf_repo": args.repo_id, + "source_hf_revision": args.revision, + "tool_schema": DEFAULT_TOOL_SCHEMA, + "tokenizer_vocab_size": nanochat_tokenizer.get_vocab_size(), + } + + model_tag = args.model_tag or args.repo_id.split("/")[-1].replace("-", "_") + base_dir = get_base_dir() + phase_dir = { + "base": "base_checkpoints", + "sft": "chatsft_checkpoints", + "rl": "chatrl_checkpoints", + }[args.source] + checkpoint_dir = os.path.join(base_dir, phase_dir, model_tag) + save_checkpoint(checkpoint_dir, args.step, model_data, optimizer_data=None, meta_data=meta_data, rank=0) + print(f"Imported {args.repo_id} into {checkpoint_dir} at step {args.step}") + + +if __name__ == "__main__": + main() diff --git a/scripts/verify_external_access.py b/scripts/verify_external_access.py new file mode 100644 index 00000000..d87e9ca5 --- /dev/null +++ b/scripts/verify_external_access.py @@ -0,0 +1,60 @@ +""" +Verify external service access from environment variables. 
+ +Expected env vars: +- HF_TOKEN +- CLOUDFLARE_API_TOKEN +- CLOUDFLARE_ACCOUNT_ID +""" + +import argparse +import os + +import requests +from huggingface_hub import HfApi + + +def verify_hf(repo_id, token_env): + token = os.environ.get(token_env) + if not token: + return False, f"Missing {token_env}" + api = HfApi(token=token) + info = api.model_info(repo_id) + return True, f"HF access OK: {info.id}" + + +def verify_cloudflare(token_env): + token = os.environ.get(token_env) + if not token: + return False, f"Missing {token_env}" + response = requests.get( + "https://api.cloudflare.com/client/v4/user/tokens/verify", + headers={"Authorization": f"Bearer {token}"}, + timeout=20, + ) + response.raise_for_status() + payload = response.json() + if not payload.get("success", False): + return False, f"Cloudflare verify failed: {payload.get('errors')}" + status = payload.get("result", {}).get("status") + return True, f"Cloudflare token OK: status={status}" + + +def main(): + parser = argparse.ArgumentParser(description="Verify Hugging Face and Cloudflare access from env vars") + parser.add_argument("--hf-repo-id", default="ManmohanSharma/nanochat-d24", help="HF repo id to verify") + parser.add_argument("--hf-token-env", default="HF_TOKEN", help="HF token env var") + parser.add_argument("--cloudflare-token-env", default="CLOUDFLARE_API_TOKEN", help="Cloudflare token env var") + args = parser.parse_args() + + hf_ok, hf_message = verify_hf(args.hf_repo_id, args.hf_token_env) + cf_ok, cf_message = verify_cloudflare(args.cloudflare_token_env) + + print(hf_message) + print(cf_message) + if not (hf_ok and cf_ok): + raise SystemExit(1) + + +if __name__ == "__main__": + main() diff --git a/seed_data/tool_eval_seed.jsonl b/seed_data/tool_eval_seed.jsonl new file mode 100644 index 00000000..de4fe06a --- /dev/null +++ b/seed_data/tool_eval_seed.jsonl @@ -0,0 +1,3 @@ +{"conversation": {"messages": [{"role": "user", "content": "What is 12% of 250?"}, {"role": "assistant", "content": "45"}]}, "checks": {"must_call": "calculator", "must_not_call": ["web_search"], "answer_contains": ["45"]}} +{"conversation": {"messages": [{"role": "user", "content": "Find the official nanochat repository and cite the link."}, {"role": "assistant", "content": "https://github.com/karpathy/nanochat"}]}, "checks": {"must_call": "web_search", "citation_required": true, "answer_contains": ["github.com/karpathy/nanochat"]}} +{"conversation": {"messages": [{"role": "user", "content": "Explain what overfitting means in one sentence."}, {"role": "assistant", "content": "Overfitting means a model memorizes training patterns too closely and generalizes poorly to new data."}]}, "checks": {"must_not_call": ["web_search"], "answer_contains": ["generalizes poorly"]}} diff --git a/seed_data/tool_sft_seed.jsonl b/seed_data/tool_sft_seed.jsonl new file mode 100644 index 00000000..2a0e64bc --- /dev/null +++ b/seed_data/tool_sft_seed.jsonl @@ -0,0 +1,6 @@ +[{"role": "user", "content": "Explain what gradient descent is in one short paragraph."}, {"role": "assistant", "content": "Gradient descent is an optimization method that repeatedly nudges model parameters in the direction that reduces error, using the gradient to decide both direction and magnitude of each update."}] +[{"role": "user", "content": "What is 18% of 250?"}, {"role": "assistant", "content": [{"type": "text", "text": "Let me compute that exactly. 
"}, {"type": "tool_call", "tool_name": "calculator", "arguments": {"expression": "percent(250,18)"}}, {"type": "tool_result", "tool_name": "calculator", "output": {"expression": "percent(250,18)", "value": 45.0}}, {"type": "text", "text": "18% of 250 is 45."}]}] +[{"role": "user", "content": "If revenue grew from 120 to 150, what is the percent change?"}, {"role": "assistant", "content": [{"type": "text", "text": "Let me compute that exactly. "}, {"type": "tool_call", "tool_name": "calculator", "arguments": {"expression": "percent_change(120,150)"}}, {"type": "tool_result", "tool_name": "calculator", "output": {"expression": "percent_change(120,150)", "value": 25.0}}, {"type": "text", "text": "The percent change from 120 to 150 is 25%."}]}] +[{"role": "user", "content": "What is the monthly EMI for a 500000 loan at 8% annual interest over 240 months?"}, {"role": "assistant", "content": [{"type": "text", "text": "Let me compute that exactly. "}, {"type": "tool_call", "tool_name": "calculator", "arguments": {"expression": "emi(500000,8,240)"}}, {"type": "tool_result", "tool_name": "calculator", "output": {"expression": "emi(500000,8,240)", "value": 4182.198594391402}}, {"type": "text", "text": "The monthly EMI is about 4182.2."}]}] +[{"role": "user", "content": "Find the official nanochat repository and give me the link."}, {"role": "assistant", "content": [{"type": "text", "text": "I should verify this with a fresh web lookup. "}, {"type": "tool_call", "tool_name": "web_search", "arguments": {"query": "official nanochat repository", "top_k": 1}}, {"type": "tool_result", "tool_name": "web_search", "output": {"query": "official nanochat repository", "results": [{"url": "https://github.com/karpathy/nanochat", "title": "Source result", "snippet": "Fresh page content retrieved for grounding."}]}}, {"type": "text", "text": "The official nanochat repository is https://github.com/karpathy/nanochat"}]}] +[{"role": "user", "content": "What do the Cloudflare Browser Rendering docs say about the markdown endpoint?"}, {"role": "assistant", "content": [{"type": "text", "text": "I should verify this with a fresh web lookup. "}, {"type": "tool_call", "tool_name": "web_search", "arguments": {"query": "Cloudflare Browser Rendering markdown endpoint docs", "top_k": 1}}, {"type": "tool_result", "tool_name": "web_search", "output": {"query": "Cloudflare Browser Rendering markdown endpoint docs", "results": [{"url": "https://developers.cloudflare.com/browser-rendering/rest-api/markdown-endpoint/", "title": "Source result", "snippet": "Fresh page content retrieved for grounding."}]}}, {"type": "text", "text": "Cloudflare documents a markdown endpoint for extracting rendered page content: https://developers.cloudflare.com/browser-rendering/rest-api/markdown-endpoint/"}]}] diff --git a/tasks/tool_json.py b/tasks/tool_json.py new file mode 100644 index 00000000..edbc7f53 --- /dev/null +++ b/tasks/tool_json.py @@ -0,0 +1,101 @@ +""" +Local JSONL task for tool-use evaluation and lightweight RL reward shaping. 
+ +Each line should be a JSON object with: +{ + "conversation": {"messages": [...]}, + "checks": { + "must_call": "calculator", + "must_not_call": ["web_search"], + "answer_contains": ["42"], + "citation_required": false + } +} +""" + +import json +import os +import re + +from nanochat.tools import TOOL_CALL_END, TOOL_CALL_START, parse_tool_call_payload +from tasks.common import Task + + +TOOL_BLOCK_RE = re.compile(re.escape(TOOL_CALL_START) + r"(.*?)" + re.escape(TOOL_CALL_END), re.DOTALL) + + +class ToolJSON(Task): + def __init__(self, filepath, split="eval", **kwargs): + super().__init__(**kwargs) + self.filepath = filepath + self.split = split + self.rows = [] + if not os.path.exists(filepath): + raise FileNotFoundError(f"Tool JSONL dataset not found: {filepath}") + with open(filepath, "r", encoding="utf-8") as f: + for line in f: + line = line.strip() + if not line: + continue + row = json.loads(line) + if "conversation" not in row: + raise ValueError(f"Row missing conversation field: {row}") + row.setdefault("checks", {}) + self.rows.append(row) + + @property + def eval_type(self): + return "generative" + + def num_examples(self): + return len(self.rows) + + def get_example(self, index): + row = self.rows[index] + conversation = dict(row["conversation"]) + conversation["checks"] = row.get("checks", {}) + return conversation + + def _tool_calls(self, assistant_response): + calls = [] + for payload in TOOL_BLOCK_RE.findall(assistant_response): + invocation = parse_tool_call_payload(payload) + calls.append(invocation.tool_name) + return calls + + def evaluate(self, conversation, assistant_response): + checks = conversation.get("checks", {}) + score = self.reward(conversation, assistant_response) + return int(score >= 0.999) + + def reward(self, conversation, assistant_response): + checks = conversation.get("checks", {}) + total = 0.0 + passed = 0.0 + tool_calls = self._tool_calls(assistant_response) + + must_call = checks.get("must_call") + if must_call: + total += 1.0 + passed += float(must_call in tool_calls) + + for tool_name in checks.get("must_not_call", []): + total += 1.0 + passed += float(tool_name not in tool_calls) + + for needle in checks.get("answer_contains", []): + total += 1.0 + passed += float(needle in assistant_response) + + answer_regex = checks.get("answer_regex") + if answer_regex: + total += 1.0 + passed += float(re.search(answer_regex, assistant_response) is not None) + + if checks.get("citation_required", False): + total += 1.0 + passed += float(("http://" in assistant_response) or ("https://" in assistant_response)) + + if total == 0.0: + return 0.0 + return passed / total diff --git a/tests/test_tools.py b/tests/test_tools.py new file mode 100644 index 00000000..9a460b01 --- /dev/null +++ b/tests/test_tools.py @@ -0,0 +1,133 @@ +import json + +import torch + +from nanochat.engine import Engine +from nanochat.tools import CalculatorTool, MockSearchBackend, ToolRegistry, WebSearchTool, parse_tool_call_payload, serialize_tool_call +from tasks.tool_json import ToolJSON + + +class MockConfig: + n_kv_head = 4 + n_head = 4 + n_embd = 64 + n_layer = 2 + sequence_len = 128 + + +class ByteTokenizer: + def __init__(self): + self._special_tokens = { + "<|python_start|>": 256, + "<|python_end|>": 257, + "<|output_start|>": 258, + "<|output_end|>": 259, + "<|assistant_end|>": 260, + "<|bos|>": 261, + } + self._bos = 261 + + def encode_special(self, s): + return self._special_tokens[s] + + def get_bos_token_id(self): + return self._bos + + def encode(self, s, prepend=None): + 
tokens = list(s.encode("utf-8")) + if prepend is not None: + tokens = [prepend] + tokens + return tokens + + def decode(self, tokens): + byte_tokens = [t for t in tokens if t < 256] + return bytes(byte_tokens).decode("utf-8", errors="replace") + + +class SequencedModel: + def __init__(self, sequence, vocab_size=262): + self.sequence = sequence + self.vocab_size = vocab_size + self.config = MockConfig() + self._device = torch.device("cpu") + self.call_index = 0 + + def get_device(self): + return self._device + + def forward(self, ids, kv_cache=None): + if kv_cache is not None: + kv_cache.advance(ids.shape[1]) + target = self.sequence[min(self.call_index, len(self.sequence) - 1)] + logits = torch.full((ids.shape[0], ids.shape[1], self.vocab_size), -1e9) + logits[:, -1, target] = 0.0 + self.call_index += 1 + return logits + + +def test_parse_tool_call_payload_supports_json_and_legacy(): + legacy = parse_tool_call_payload("6*7") + assert legacy.tool_name == "calculator" + assert legacy.arguments == {"expression": "6*7"} + + payload = parse_tool_call_payload(serialize_tool_call("calculator", {"expression": "6*7"})) + assert payload.tool_name == "calculator" + assert payload.arguments == {"expression": "6*7"} + + +def test_calculator_tool_supports_scientific_and_business_functions(): + tool = CalculatorTool() + trig = tool.run({"expression": "round(sin(pi/2), 4)"}) + assert trig.success + assert trig.output["value"] == 1.0 + + emi = tool.run({"expression": "round(emi(500000,8,240), 2)"}) + assert emi.success + assert emi.output["value"] == 4182.2 + + +def test_web_search_tool_returns_mock_results(): + tool = WebSearchTool(search_backend=MockSearchBackend()) + result = tool.run({"query": "nanochat gpt2 speedrun", "top_k": 1}) + assert result.success + assert result.output["results"][0]["url"] == "https://github.com/karpathy/nanochat" + + +def test_engine_executes_json_calculator_tool_call(): + tokenizer = ByteTokenizer() + payload = serialize_tool_call("calculator", {"expression": "6*7"}) + sequence = [ + tokenizer.encode_special("<|python_start|>"), + *tokenizer.encode(payload), + tokenizer.encode_special("<|python_end|>"), + tokenizer.encode_special("<|assistant_end|>"), + ] + model = SequencedModel(sequence) + registry = ToolRegistry([CalculatorTool()]) + engine = Engine(model, tokenizer, tools=registry) + prompt = [tokenizer.get_bos_token_id(), 72] + results, _ = engine.generate_batch(prompt, num_samples=1, max_tokens=128, temperature=0.0) + completion = tokenizer.decode(results[0][len(prompt):]) + assert "42" in completion + + +def test_tool_json_reward_and_eval(tmp_path): + path = tmp_path / "tool_eval.jsonl" + row = { + "conversation": { + "messages": [ + {"role": "user", "content": "What is 6 times 7?"}, + {"role": "assistant", "content": "42"}, + ] + }, + "checks": { + "must_call": "calculator", + "answer_contains": ["42"], + }, + } + path.write_text(json.dumps(row) + "\n", encoding="utf-8") + task = ToolJSON(filepath=str(path)) + conversation = task[0] + response = '<|python_start|>{"tool":"calculator","arguments":{"expression":"6*7"}}<|python_end|>42' + assert task.reward(conversation, response) == 1.0 + assert task.evaluate(conversation, response) == 1
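+
+
+# Follow-on sketch using the same ToolJSON API as the test above: verify that
+# citation_required gates the reward when a response omits any URL.
+def test_tool_json_reward_requires_citation(tmp_path):
+    path = tmp_path / "tool_eval_citation.jsonl"
+    row = {
+        "conversation": {
+            "messages": [
+                {"role": "user", "content": "Find the official nanochat repository and cite the link."},
+                {"role": "assistant", "content": "https://github.com/karpathy/nanochat"},
+            ]
+        },
+        "checks": {
+            "citation_required": True,
+            "answer_contains": ["github.com/karpathy/nanochat"],
+        },
+    }
+    path.write_text(json.dumps(row) + "\n", encoding="utf-8")
+    task = ToolJSON(filepath=str(path))
+    conversation = task[0]
+    cited = "The official repository is https://github.com/karpathy/nanochat"
+    uncited = "It is on GitHub under karpathy."
+    # Both checks pass with a cited answer; both fail without the URL.
+    assert task.reward(conversation, cited) == 1.0
+    assert task.reward(conversation, uncited) < 1.0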