Merge pull request #1 from manmohan659/codex/pre-gpu-readiness

Add pre-GPU tool training and checkpoint plumbing
Manmohan 2026-03-24 20:53:39 -04:00 committed by GitHub
commit baef0a3d66
21 changed files with 1714 additions and 81 deletions

docs/pre_gpu_runbook.md Normal file

@ -0,0 +1,97 @@
# Pre-GPU Runbook
This runbook is the minimum operational checklist before spending GPU time.
## 1. Local Prep
1. Build the seed tool datasets:
```bash
python -m scripts.build_tool_datasets
```
2. Import the starting checkpoint from Hugging Face into native nanochat format:
```bash
python -m scripts.import_hf_checkpoint \
--repo-id ManmohanSharma/nanochat-d24 \
--model-tag d24_hf_import
```
3. Validate tool tokenization and mock tool execution with local tests:
```bash
python -m pytest tests/test_engine.py tests/test_tools.py -v
```
4. Dry-run tool evaluation on CPU:
```bash
python -m scripts.chat_eval \
-i sft \
-a ToolJSON \
--tool-jsonl seed_data/tool_eval_seed.jsonl \
--device-type cpu \
-x 3
```
## 2. 48-Hour GPU Schedule
1. Pilot CPT
- Run a short continuation test from the imported base checkpoint.
- Confirm the loss is decreasing, checkpoint saving works, and HF sync works.
2. Full CPT
- Run the main continuation stage on ClimbMix backbone.
- Save staged checkpoints at planned intervals.
3. SFT
- Include the local tool SFT JSONL via `--extra-train-jsonl` (see the example command after this list).
- Validate that calculator/web_search traces render correctly.
4. RL / tool tuning
- Keep this stage narrow and short.
- Focus on tool-choice correctness and grounded answers.
5. Eval
- Run ARC, MMLU, GSM8K, HumanEval, and ToolJSON checks.
- Do not ship if tool behavior regresses or citations are missing.
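For step 3, a representative invocation (a sketch: only `--extra-train-jsonl` is specific to this change; device, batch, and run flags are setup-dependent and left at their defaults here):
```bash
python -m scripts.chat_sft \
  --extra-train-jsonl seed_data/tool_sft_seed.jsonl
```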
## 3. Checkpoint Upload Cadence
Upload at every stage boundary and at any explicit resume point:
```bash
python -m scripts.hf_sync_checkpoint \
--repo-id ManmohanSharma/nanochat-d24 \
--source base \
--model-tag d24_hf_import \
--step 0
```
If a whole checkpoint directory should be mirrored:
```bash
python -m scripts.hf_sync_checkpoint \
--repo-id ManmohanSharma/nanochat-d24 \
--source base \
--model-tag d24_hf_import
```
## 4. Go / No-Go
Go only if:
- HF import works.
- HF sync works.
- Mock tool execution works.
- Tool seed datasets are generated.
- Tool eval runs locally.
- The search backend plan is explicit: search provider plus Cloudflare fetch/crawl.
No-Go if:
- Any tokenizer mismatch appears during HF import.
- Tool blocks fail to render.
- `web_search` still has no backend plan beyond fetch-only Cloudflare Browser Rendering.
- Local tool eval is missing or failing.


@ -0,0 +1,31 @@
{
"name": "cpt_recipe_v1",
"goal": "Continue pretraining nanochat-d24 on a safe, high-value mixture for a 48-hour run.",
"backbone_dataset": {
"repo_id": "karpathy/climbmix-400b-shuffle",
"role": "primary",
"notes": "Use ClimbMix as the main pretraining backbone because it is already the best-performing dataset in this fork."
},
"optional_augmentations": [
{
"repo_id": "nvidia/Nemotron-Pretraining-Code-v2",
"role": "code",
"notes": "Selective code augmentation only."
},
{
"repo_id": "nvidia/Nemotron-CC-Math-v1",
"role": "math_science",
"notes": "Selective math/science augmentation only."
},
{
"repo_id": "nvidia/Nemotron-Pretraining-Specialized-v1.1",
"role": "specialized_web",
"notes": "Use sparingly for targeted specialization."
}
],
"do_not_do": [
"Do not replace ClimbMix wholesale.",
"Do not attempt architecture changes during this 48-hour continuation run.",
"Do not rely on continued pretraining for current-events freshness."
]
}
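The recipe is consumed as plain JSON; a minimal loader sketch (the file path is an assumption, and sampling weights are deliberately absent from the recipe itself):
```python
import json

with open("cpt_recipe_v1.json") as f:  # path is an assumption
    recipe = json.load(f)

# Primary backbone first, then the optional augmentations by role.
plan = [(recipe["backbone_dataset"]["repo_id"], recipe["backbone_dataset"]["role"])]
plan += [(aug["repo_id"], aug["role"]) for aug in recipe["optional_augmentations"]]
for repo_id, role in plan:
    print(f"{role:>16}  {repo_id}")
```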


@ -0,0 +1,26 @@
{
"name": "rl_recipe_v1",
"goal": "Lightweight tool-tuning stage after SFT.",
"seed_eval_data": "seed_data/tool_eval_seed.jsonl",
"reward_components": [
"valid tool call",
"correct tool choice",
"answer contains expected answer fragment",
"citation required when using web_search",
"penalty for unnecessary tool calls"
],
"external_templates": [
{
"repo_id": "nvidia/Nemotron-RL-Agentic-Conversational-Tool-Use-Pivot-v1",
"role": "tool_use_structure"
},
{
"repo_id": "nvidia/Nemotron-RL-Agentic-Function-Calling-Pivot-v1",
"role": "function_call_structure"
}
],
"notes": [
"Keep RL narrow; do not turn this into broad RLHF.",
"Use the local tool eval dataset for dry runs before GPU."
]
}
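The reward components above map onto the `checks` schema of the seed eval rows shown later in this diff. A minimal sketch of how such a reward could be scored; the actual scoring lives in `tasks/tool_json.py`, which is not shown above, so the weights and structure here are illustrative:
```python
from nanochat.tools import TOOL_CALL_START, TOOL_CALL_END, parse_tool_call_payload

def sketch_reward(generated_text: str, checks: dict) -> float:
    # Which tool, if any, did the sample call?
    called = None
    if TOOL_CALL_START in generated_text and TOOL_CALL_END in generated_text:
        payload = generated_text.split(TOOL_CALL_START, 1)[1].split(TOOL_CALL_END, 1)[0]
        called = parse_tool_call_payload(payload).tool_name
    reward = 0.0
    if called is not None:
        reward += 0.25                                    # valid tool call
    if checks.get("must_call") and called == checks["must_call"]:
        reward += 0.25                                    # correct tool choice
    if called in (checks.get("must_not_call") or []):
        reward -= 0.5                                     # penalty for unnecessary tool calls
    if all(frag in generated_text for frag in checks.get("answer_contains", [])):
        reward += 0.25                                    # expected answer fragment present
    if checks.get("citation_required") and "http" in generated_text:
        reward += 0.25                                    # grounded answer cites a URL
    return reward
```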


@ -0,0 +1,32 @@
{
"name": "sft_recipe_v1",
"goal": "Teach instruction following, tool routing, grounded web answers, and calculator usage.",
"base_mixture": [
{
"source": "existing_nanochat",
"notes": "Keep the existing chat mixture in scripts/chat_sft.py as the base."
},
{
"source": "local_seed_data",
"path": "seed_data/tool_sft_seed.jsonl",
"notes": "Seed examples for direct answers, calculator, and web_search."
}
],
"external_templates": [
{
"repo_id": "nvidia/Nemotron-Instruction-Following-Chat-v1",
"role": "instruction_following"
},
{
"collection": "nvidia/nemotron-post-training-v3",
"role": "agentic_templates"
}
],
"target_behaviors": [
"Answer directly when no tool is needed.",
"Use calculator for exact numeric computation.",
"Use web_search for recency and external facts.",
"Cite URLs when the response depends on retrieved web content.",
"Avoid unnecessary tool calls."
]
}


@ -13,70 +13,10 @@ The whole thing is made as efficient as possible.
import torch
import torch.nn.functional as F
import signal
import warnings
from contextlib import contextmanager
from collections import deque
from nanochat.common import compute_init, autodetect_device_type, COMPUTE_DTYPE
from nanochat.checkpoint_manager import load_model
# -----------------------------------------------------------------------------
# Calculator tool helpers
@contextmanager
def timeout(duration, formula):
def timeout_handler(signum, frame):
raise Exception(f"'{formula}': timed out after {duration} seconds")
signal.signal(signal.SIGALRM, timeout_handler)
signal.alarm(duration)
yield
signal.alarm(0)
def eval_with_timeout(formula, max_time=3):
try:
with timeout(max_time, formula):
with warnings.catch_warnings():
warnings.simplefilter("ignore", SyntaxWarning)
return eval(formula, {"__builtins__": {}}, {})
except Exception as e:
signal.alarm(0)
# print(f"Warning: Failed to eval {formula}, exception: {e}") # it's ok ignore wrong calculator usage
return None
def use_calculator(expr):
"""
Evaluate a Python expression safely.
Supports both math expressions and string operations like .count()
"""
# Remove commas from numbers
expr = expr.replace(",", "")
# Check if it's a pure math expression (old behavior)
if all([x in "0123456789*+-/.() " for x in expr]):
if "**" in expr: # disallow power operator
return None
return eval_with_timeout(expr)
# Check if it's a string operation we support
# Allow: strings (single/double quotes), .count(), letters, numbers, spaces, parens
allowed_chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'\"()._ "
if not all([x in allowed_chars for x in expr]):
return None
# Disallow dangerous patterns
dangerous_patterns = ['__', 'import', 'exec', 'eval', 'compile', 'open', 'file',
'input', 'raw_input', 'globals', 'locals', 'vars', 'dir',
'getattr', 'setattr', 'delattr', 'hasattr']
expr_lower = expr.lower()
if any(pattern in expr_lower for pattern in dangerous_patterns):
return None
# Only allow .count() method for now (can expand later)
if '.count(' not in expr:
return None
# Evaluate with timeout
return eval_with_timeout(expr)
from nanochat.tools import build_default_tool_registry, parse_tool_call_payload
# -----------------------------------------------------------------------------
class KVCache:
@ -162,15 +102,16 @@ class RowState:
def __init__(self, current_tokens=None):
self.current_tokens = current_tokens or [] # Current token sequence for this row
self.forced_tokens = deque() # Queue of tokens to force inject
self.in_python_block = False # Whether we are inside a python block
self.python_expr_tokens = [] # Tokens of the current python expression
self.in_tool_block = False # Whether we are inside a tool-call block
self.tool_payload_tokens = [] # Tokens of the current tool-call payload
self.completed = False # Whether this row has completed generation
class Engine:
def __init__(self, model, tokenizer):
def __init__(self, model, tokenizer, tools=None):
self.model = model
self.tokenizer = tokenizer # needed for tool use
self.tools = tools or build_default_tool_registry()
@torch.inference_mode()
def generate(self, tokens, num_samples=1, max_tokens=None, temperature=1.0, top_k=None, seed=42):
@ -255,21 +196,34 @@ class Engine:
state.completed = True
# Handle tool logic
if next_token == python_start:
state.in_python_block = True
state.python_expr_tokens = []
elif next_token == python_end and state.in_python_block:
state.in_python_block = False
if state.python_expr_tokens:
expr = self.tokenizer.decode(state.python_expr_tokens)
result = use_calculator(expr)
if result is not None:
result_tokens = self.tokenizer.encode(str(result))
state.in_tool_block = True
state.tool_payload_tokens = []
elif next_token == python_end and state.in_tool_block:
state.in_tool_block = False
if state.tool_payload_tokens:
payload_text = self.tokenizer.decode(state.tool_payload_tokens)
invocation = parse_tool_call_payload(payload_text)
result = self.tools.execute(invocation.tool_name, invocation.arguments)
legacy_calculator_call = (
invocation.tool_name == "calculator"
and invocation.raw_text.strip()
and not invocation.raw_text.strip().startswith("{")
)
if legacy_calculator_call:
if result.success and isinstance(result.output, dict) and "value" in result.output:
result_text = str(result.output["value"])
else:
result_text = result.error or ""
else:
result_text = result.to_payload()[:4096]
if result_text:
result_tokens = self.tokenizer.encode(result_text)
state.forced_tokens.append(output_start)
state.forced_tokens.extend(result_tokens)
state.forced_tokens.append(output_end)
state.python_expr_tokens = []
elif state.in_python_block:
state.python_expr_tokens.append(next_token)
state.tool_payload_tokens = []
elif state.in_tool_block:
state.tool_payload_tokens.append(next_token)
# Yield the token column
yield token_column, token_masks
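For reference, a minimal end-to-end sketch of the new Engine wiring. `generate_batch` mirrors its use in scripts/chat_tool_rl.py below; passing the device as the string "cpu" is an assumption (the training scripts pass the device from `compute_init`):
```python
from nanochat.checkpoint_manager import load_model
from nanochat.engine import Engine
from nanochat.tools import build_default_tool_registry

# Load an SFT checkpoint on CPU (tag/step default to the latest available).
model, tokenizer, meta = load_model("sft", "cpu", phase="eval")
engine = Engine(model, tokenizer, tools=build_default_tool_registry())

tokens = tokenizer.render_for_completion(
    [{"role": "user", "content": "What is 18% of 250?"}]
)
sequences, masks = engine.generate_batch(
    tokens, num_samples=1, max_tokens=128, temperature=0.0
)
print(tokenizer.decode(sequences[0][len(tokens):]))
```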


@ -270,6 +270,7 @@ class RustBPETokenizer:
- ids: list[int] is a list of token ids of this rendered conversation
- mask: list[int] of same length, mask = 1 for tokens that the Assistant is expected to train on.
"""
from nanochat.tools import serialize_tool_call, serialize_tool_result
# ids, masks that we will return and a helper function to help build them up.
ids, mask = [], []
def add_tokens(token_ids, mask_val):
@ -327,14 +328,29 @@ class RustBPETokenizer:
if part["type"] == "text":
# string part => simply add the tokens
add_tokens(value_ids, 1)
elif part["type"] == "python":
elif part["type"] in {"python", "tool_call"}:
# python or tool call => add the payload tokens inside <|python_start|> and <|python_end|>
if part["type"] == "tool_call":
payload = serialize_tool_call(
tool_name=part["tool_name"],
arguments=part.get("arguments") or {},
)
value_ids = self.encode(payload)
add_tokens(python_start, 1)
add_tokens(value_ids, 1)
add_tokens(python_end, 1)
elif part["type"] == "python_output":
elif part["type"] in {"python_output", "tool_result"}:
# python output or tool result => add the tokens inside <|output_start|> and <|output_end|>
# none of these tokens are supervised because they come from the tool at test time
if part["type"] == "tool_result":
payload = serialize_tool_result(
tool_name=part["tool_name"],
output=part.get("output"),
success=part.get("success", True),
error=part.get("error"),
metadata=part.get("metadata"),
)
value_ids = self.encode(payload)
add_tokens(output_start, 0)
add_tokens(value_ids, 0)
add_tokens(output_end, 0)
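On the wire, these parts serialize to compact, key-sorted JSON (see `_compact_json` in nanochat/tools.py):
```python
from nanochat.tools import serialize_tool_call, serialize_tool_result

print(serialize_tool_call("calculator", {"expression": "percent(250,18)"}))
# {"arguments":{"expression":"percent(250,18)"},"tool":"calculator"}

print(serialize_tool_result("calculator", {"expression": "percent(250,18)", "value": 45.0}))
# {"error":null,"metadata":{},"output":{"expression":"percent(250,18)","value":45.0},"success":true,"tool":"calculator"}
```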

nanochat/tools.py Normal file

@ -0,0 +1,509 @@
"""
Shared tool definitions for nanochat.
The current tokenizer only has python/output special tokens. To preserve
checkpoint compatibility, we reuse those tokens as a generic tool-call and
tool-result channel. Legacy "python" calculator payloads still work.
"""
from __future__ import annotations
import ast
import json
import math
import os
import time
from dataclasses import dataclass, field
from typing import Any, Protocol
import requests
TOOL_CALL_START = "<|python_start|>"
TOOL_CALL_END = "<|python_end|>"
TOOL_RESULT_START = "<|output_start|>"
TOOL_RESULT_END = "<|output_end|>"
MAX_TOOL_PAYLOAD_CHARS = 4096
DEFAULT_TOOL_SCHEMA = [
{
"name": "calculator",
"description": "Deterministic scientific calculator for exact arithmetic and common finance formulas.",
"arguments": {
"expression": "String expression using numbers, operators, and supported functions.",
},
},
{
"name": "web_search",
"description": "Search and fetch web content. Requires a search backend and optionally a page fetch client.",
"arguments": {
"query": "Search query string.",
"top_k": "Maximum number of results to return.",
"urls": "Optional explicit URLs to fetch instead of searching.",
},
},
]
def _compact_json(data: Any) -> str:
return json.dumps(data, ensure_ascii=True, separators=(",", ":"), sort_keys=True)
@dataclass
class ToolInvocation:
tool_name: str
arguments: dict[str, Any]
raw_text: str = ""
@dataclass
class ToolResult:
tool_name: str
success: bool
output: Any = None
error: str | None = None
metadata: dict[str, Any] = field(default_factory=dict)
def to_payload(self) -> str:
return _compact_json(
{
"tool": self.tool_name,
"success": self.success,
"output": self.output,
"error": self.error,
"metadata": self.metadata,
}
)
class BaseTool:
name: str
def run(self, arguments: dict[str, Any]) -> ToolResult:
raise NotImplementedError
class ToolRegistry:
def __init__(self, tools: list[BaseTool] | tuple[BaseTool, ...]):
self._tools = {tool.name: tool for tool in tools}
def execute(self, tool_name: str, arguments: dict[str, Any]) -> ToolResult:
tool = self._tools.get(tool_name)
if tool is None:
return ToolResult(tool_name=tool_name, success=False, error=f"Unknown tool: {tool_name}")
try:
return tool.run(arguments)
except Exception as exc: # defensive: tool failures should become model-visible outputs
return ToolResult(tool_name=tool_name, success=False, error=str(exc))
def schema(self) -> list[dict[str, Any]]:
return [item for item in DEFAULT_TOOL_SCHEMA if item["name"] in self._tools]
def serialize_tool_call(tool_name: str, arguments: dict[str, Any] | None = None) -> str:
payload = {
"tool": tool_name,
"arguments": arguments or {},
}
text = _compact_json(payload)
return text[:MAX_TOOL_PAYLOAD_CHARS]
def serialize_tool_result(
tool_name: str,
output: Any = None,
*,
success: bool = True,
error: str | None = None,
metadata: dict[str, Any] | None = None,
) -> str:
return ToolResult(
tool_name=tool_name,
success=success,
output=output,
error=error,
metadata=metadata or {},
).to_payload()[:MAX_TOOL_PAYLOAD_CHARS]
def parse_tool_call_payload(text: str) -> ToolInvocation:
stripped = text.strip()
if not stripped:
return ToolInvocation(tool_name="calculator", arguments={"expression": ""}, raw_text=text)
try:
payload = json.loads(stripped)
except json.JSONDecodeError:
return ToolInvocation(tool_name="calculator", arguments={"expression": stripped}, raw_text=text)
if isinstance(payload, dict):
tool_name = payload.get("tool") or payload.get("tool_name") or payload.get("name")
arguments = payload.get("arguments") or payload.get("args") or {}
if isinstance(tool_name, str) and isinstance(arguments, dict):
return ToolInvocation(tool_name=tool_name, arguments=arguments, raw_text=text)
return ToolInvocation(tool_name="calculator", arguments={"expression": stripped}, raw_text=text)
def parse_tool_result_payload(text: str) -> ToolResult | None:
stripped = text.strip()
try:
payload = json.loads(stripped)
except json.JSONDecodeError:
return None
if not isinstance(payload, dict):
return None
tool_name = payload.get("tool")
if not isinstance(tool_name, str):
return None
return ToolResult(
tool_name=tool_name,
success=bool(payload.get("success", True)),
output=payload.get("output"),
error=payload.get("error"),
metadata=payload.get("metadata") or {},
)
def _percent(value: float, rate: float) -> float:
return value * rate / 100.0
def _percent_change(old: float, new: float) -> float:
if old == 0:
raise ValueError("percent_change old value cannot be zero")
return ((new - old) / old) * 100.0
def _cagr(start: float, end: float, years: float) -> float:
if start <= 0 or end <= 0 or years <= 0:
raise ValueError("cagr inputs must be positive")
return ((end / start) ** (1.0 / years) - 1.0) * 100.0
def _simple_interest(principal: float, annual_rate: float, years: float) -> float:
return principal * annual_rate / 100.0 * years
def _compound_interest(principal: float, annual_rate: float, periods_per_year: float, years: float) -> float:
if periods_per_year <= 0:
raise ValueError("periods_per_year must be positive")
return principal * (1.0 + annual_rate / 100.0 / periods_per_year) ** (periods_per_year * years)
def _emi(principal: float, annual_rate: float, months: float) -> float:
if months <= 0:
raise ValueError("months must be positive")
monthly_rate = annual_rate / 100.0 / 12.0
if monthly_rate == 0:
return principal / months
growth = (1.0 + monthly_rate) ** months
return principal * monthly_rate * growth / (growth - 1.0)
ALLOWED_BINOPS = {
ast.Add: lambda a, b: a + b,
ast.Sub: lambda a, b: a - b,
ast.Mult: lambda a, b: a * b,
ast.Div: lambda a, b: a / b,
ast.Pow: lambda a, b: a ** b,
ast.Mod: lambda a, b: a % b,
}
ALLOWED_UNARYOPS = {
ast.UAdd: lambda a: a,
ast.USub: lambda a: -a,
}
ALLOWED_NAMES = {
"pi": math.pi,
"e": math.e,
"tau": math.tau,
}
ALLOWED_FUNCTIONS = {
"abs": abs,
"round": round,
"floor": math.floor,
"ceil": math.ceil,
"sqrt": math.sqrt,
"log": math.log,
"log10": math.log10,
"exp": math.exp,
"sin": math.sin,
"cos": math.cos,
"tan": math.tan,
"asin": math.asin,
"acos": math.acos,
"atan": math.atan,
"degrees": math.degrees,
"radians": math.radians,
"percent": _percent,
"percent_change": _percent_change,
"cagr": _cagr,
"simple_interest": _simple_interest,
"compound_interest": _compound_interest,
"emi": _emi,
}
class _SafeMathEvaluator:
def __init__(self, expression: str):
self.expression = expression
self.node_count = 0
def eval(self) -> float:
if len(self.expression) > 512:
raise ValueError("expression too long")
parsed = ast.parse(self.expression, mode="eval")
return self._visit(parsed.body)
def _visit(self, node: ast.AST) -> Any:
self.node_count += 1
if self.node_count > 128:
raise ValueError("expression too complex")
if isinstance(node, ast.Constant):
if isinstance(node.value, (int, float)):
return node.value
raise ValueError(f"unsupported constant: {node.value!r}")
if isinstance(node, ast.Num): # pragma: no cover - py<3.8 compatibility
return node.n
if isinstance(node, ast.BinOp):
op = ALLOWED_BINOPS.get(type(node.op))
if op is None:
raise ValueError(f"unsupported operator: {type(node.op).__name__}")
return op(self._visit(node.left), self._visit(node.right))
if isinstance(node, ast.UnaryOp):
op = ALLOWED_UNARYOPS.get(type(node.op))
if op is None:
raise ValueError(f"unsupported unary operator: {type(node.op).__name__}")
return op(self._visit(node.operand))
if isinstance(node, ast.Name):
if node.id not in ALLOWED_NAMES:
raise ValueError(f"unknown symbol: {node.id}")
return ALLOWED_NAMES[node.id]
if isinstance(node, ast.Call):
if not isinstance(node.func, ast.Name):
raise ValueError("only direct function calls are allowed")
fn = ALLOWED_FUNCTIONS.get(node.func.id)
if fn is None:
raise ValueError(f"unsupported function: {node.func.id}")
if node.keywords:
raise ValueError("keyword arguments are not supported")
args = [self._visit(arg) for arg in node.args]
return fn(*args)
raise ValueError(f"unsupported expression node: {type(node).__name__}")
def _normalize_numeric_output(value: Any) -> Any:
if isinstance(value, float):
if not math.isfinite(value):
raise ValueError("result is not finite")
return float(f"{value:.12g}")
return value
class CalculatorTool(BaseTool):
name = "calculator"
def run(self, arguments: dict[str, Any]) -> ToolResult:
expression = str(arguments.get("expression", "")).strip()
if not expression:
return ToolResult(tool_name=self.name, success=False, error="Missing expression")
value = _SafeMathEvaluator(expression).eval()
return ToolResult(
tool_name=self.name,
success=True,
output={"expression": expression, "value": _normalize_numeric_output(value)},
)
@dataclass
class SearchHit:
url: str
title: str = ""
snippet: str = ""
class SearchBackend(Protocol):
def search(self, query: str, top_k: int) -> list[SearchHit]:
raise NotImplementedError
class MockSearchBackend:
def __init__(self, canned_results: dict[str, list[dict[str, str]]] | None = None):
self.canned_results = canned_results or {
"browser rendering markdown": [
{
"url": "https://developers.cloudflare.com/browser-rendering/rest-api/markdown-endpoint/",
"title": "Cloudflare markdown endpoint",
"snippet": "Extract markdown from a webpage using Cloudflare Browser Rendering.",
}
],
"nanochat gpt2 speedrun": [
{
"url": "https://github.com/karpathy/nanochat",
"title": "karpathy/nanochat",
"snippet": "Minimal LLM training harness with pretraining, SFT, RL, and chat UI.",
}
],
}
def search(self, query: str, top_k: int) -> list[SearchHit]:
normalized = query.strip().lower()
rows = self.canned_results.get(normalized, [])
return [SearchHit(**row) for row in rows[:top_k]]
class CloudflareBrowserRenderingClient:
def __init__(
self,
*,
api_token: str | None = None,
account_id: str | None = None,
base_url: str = "https://api.cloudflare.com/client/v4",
timeout: float = 30.0,
max_retries: int = 3,
):
self.api_token = api_token or os.environ.get("CLOUDFLARE_API_TOKEN")
self.account_id = account_id or os.environ.get("CLOUDFLARE_ACCOUNT_ID")
if not self.api_token or not self.account_id:
raise ValueError("Cloudflare Browser Rendering requires CLOUDFLARE_API_TOKEN and CLOUDFLARE_ACCOUNT_ID")
self.base_url = base_url.rstrip("/")
self.timeout = timeout
self.max_retries = max_retries
self.session = requests.Session()
self.session.headers.update(
{
"Authorization": f"Bearer {self.api_token}",
"Content-Type": "application/json",
}
)
def _post(self, endpoint: str, body: dict[str, Any]) -> Any:
url = f"{self.base_url}/accounts/{self.account_id}/browser-rendering/{endpoint}"
last_error = None
for attempt in range(1, self.max_retries + 1):
response = self.session.post(url, json=body, timeout=self.timeout)
if response.status_code == 429:
retry_after = response.headers.get("Retry-After")
sleep_seconds = float(retry_after) if retry_after else float(attempt)
last_error = RuntimeError(f"Cloudflare Browser Rendering rate limited on {endpoint}")
time.sleep(min(sleep_seconds, 5.0))
continue
response.raise_for_status()
payload = response.json()
if not payload.get("success", False):
errors = payload.get("errors", [])
last_error = RuntimeError(f"Cloudflare Browser Rendering request failed: {errors}")
break
return payload.get("result")
if last_error is not None:
raise last_error
raise RuntimeError(f"Cloudflare Browser Rendering request failed for {endpoint}")
def markdown(self, url: str, **options: Any) -> str:
body = {"url": url}
body.update(options)
return self._post("markdown", body)
def links(self, url: str, **options: Any) -> list[str]:
body = {"url": url}
body.update(options)
return self._post("links", body)
def json_extract(self, url: str, *, prompt: str | None = None, schema: dict[str, Any] | None = None, **options: Any) -> dict[str, Any]:
body: dict[str, Any] = {"url": url}
if prompt is not None:
body["prompt"] = prompt
if schema is not None:
body["schema"] = schema
body.update(options)
return self._post("json", body)
class WebSearchTool(BaseTool):
name = "web_search"
def __init__(
self,
*,
search_backend: SearchBackend | None = None,
fetch_client: CloudflareBrowserRenderingClient | None = None,
max_results: int = 3,
):
self.search_backend = search_backend
self.fetch_client = fetch_client
self.max_results = max_results
def run(self, arguments: dict[str, Any]) -> ToolResult:
query = str(arguments.get("query", "")).strip()
requested_urls = arguments.get("urls") or []
if isinstance(requested_urls, str):
requested_urls = [requested_urls]
top_k = int(arguments.get("top_k", self.max_results) or self.max_results)
top_k = max(1, min(top_k, 8))
if not query and not requested_urls:
return ToolResult(tool_name=self.name, success=False, error="Missing query or urls")
hits: list[SearchHit]
if requested_urls:
hits = [SearchHit(url=str(url)) for url in requested_urls[:top_k]]
else:
if self.search_backend is None:
return ToolResult(
tool_name=self.name,
success=False,
error="No search backend configured. Cloudflare Browser Rendering can fetch pages but does not provide public web search by itself.",
)
hits = self.search_backend.search(query, top_k)
results = []
for hit in hits[:top_k]:
entry: dict[str, Any] = {"url": hit.url}
if hit.title:
entry["title"] = hit.title
if hit.snippet:
entry["snippet"] = hit.snippet
if self.fetch_client is not None:
try:
markdown = self.fetch_client.markdown(hit.url)
links = self.fetch_client.links(hit.url)
entry["markdown"] = markdown[:4000]
entry["links"] = links[:10]
except Exception as exc:
entry["fetch_error"] = str(exc)
results.append(entry)
return ToolResult(
tool_name=self.name,
success=True,
output={"query": query, "results": results},
metadata={
"search_backend": type(self.search_backend).__name__ if self.search_backend is not None else None,
"fetch_backend": type(self.fetch_client).__name__ if self.fetch_client is not None else None,
"num_results": len(results),
},
)
def build_default_tool_registry(
*,
cloudflare_token: str | None = None,
cloudflare_account_id: str | None = None,
search_backend: SearchBackend | None = None,
) -> ToolRegistry:
fetch_client = None
if cloudflare_token or os.environ.get("CLOUDFLARE_API_TOKEN"):
try:
fetch_client = CloudflareBrowserRenderingClient(
api_token=cloudflare_token,
account_id=cloudflare_account_id,
)
except Exception:
fetch_client = None
registry = ToolRegistry(
[
CalculatorTool(),
WebSearchTool(
search_backend=search_backend or MockSearchBackend(),
fetch_client=fetch_client,
),
]
)
return registry
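A round-trip sketch of the new registry (runnable as-is; the Cloudflare fetch client only attaches when credentials are present in the environment):
```python
from nanochat.tools import build_default_tool_registry, parse_tool_call_payload

registry = build_default_tool_registry()

# JSON payload, as the model is trained to emit it.
inv = parse_tool_call_payload('{"tool":"calculator","arguments":{"expression":"cagr(100,200,10)"}}')
result = registry.execute(inv.tool_name, inv.arguments)
print(result.output)  # {'expression': 'cagr(100,200,10)', 'value': 7.17734625363}

# Legacy non-JSON payloads fall back to the calculator path.
legacy = parse_tool_call_payload("2+2")
print(registry.execute(legacy.tool_name, legacy.arguments).output)  # {'expression': '2+2', 'value': 4}
```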


@ -7,6 +7,7 @@ requires-python = ">=3.10"
dependencies = [
"datasets>=4.0.0",
"fastapi>=0.117.1",
"huggingface_hub>=0.34.4",
"ipykernel>=7.1.0",
"kernels>=0.11.7",
"matplotlib>=3.10.8",


@ -33,6 +33,7 @@ from nanochat.checkpoint_manager import save_checkpoint, load_checkpoint
from nanochat.loss_eval import evaluate_bpb
from nanochat.engine import Engine
from nanochat.flash_attention import HAS_FA3
from nanochat.tools import DEFAULT_TOOL_SCHEMA
from scripts.base_eval import evaluate_core
print_banner()
@ -481,9 +482,12 @@ while True:
optimizer.state_dict(), # optimizer state
{ # metadata saved as json
"step": step,
"stage": "base",
"val_bpb": val_bpb, # loss at last step
"model_config": model_config_kwargs,
"user_config": user_config, # inputs to the training script
"tool_schema": DEFAULT_TOOL_SCHEMA,
"source_hf_repo": os.environ.get("NANOCHAT_SOURCE_HF_REPO"),
"device_batch_size": args.device_batch_size,
"max_seq_len": args.max_seq_len,
"total_batch_size": total_batch_size,


@ -0,0 +1,167 @@
"""
Build small local tool-use datasets for SFT/eval dry runs.
This is not intended to replace larger curated corpora. It creates schema-valid
seed data so tool routing, tokenization, and evaluation can be tested locally.
"""
import argparse
import json
import os
def write_jsonl(path, rows):
dirname = os.path.dirname(path)
if dirname:  # guard: os.makedirs("") raises when the path has no directory part
os.makedirs(dirname, exist_ok=True)
with open(path, "w", encoding="utf-8") as f:
for row in rows:
f.write(json.dumps(row, ensure_ascii=True) + "\n")
def calculator_conversation(user_text, expression, value, final_text):
return [
{"role": "user", "content": user_text},
{
"role": "assistant",
"content": [
{"type": "text", "text": "Let me compute that exactly. "},
{"type": "tool_call", "tool_name": "calculator", "arguments": {"expression": expression}},
{"type": "tool_result", "tool_name": "calculator", "output": {"expression": expression, "value": value}},
{"type": "text", "text": final_text},
],
},
]
def web_search_conversation(user_text, query, url, final_text):
return [
{"role": "user", "content": user_text},
{
"role": "assistant",
"content": [
{"type": "text", "text": "I should verify this with a fresh web lookup. "},
{"type": "tool_call", "tool_name": "web_search", "arguments": {"query": query, "top_k": 1}},
{
"type": "tool_result",
"tool_name": "web_search",
"output": {
"query": query,
"results": [
{
"url": url,
"title": "Source result",
"snippet": "Fresh page content retrieved for grounding.",
}
],
},
},
{"type": "text", "text": final_text},
],
},
]
def direct_conversation(user_text, answer_text):
return [
{"role": "user", "content": user_text},
{"role": "assistant", "content": answer_text},
]
def build_train_rows():
return [
direct_conversation(
"Explain what gradient descent is in one short paragraph.",
"Gradient descent is an optimization method that repeatedly nudges model parameters in the direction that reduces error, using the gradient to decide both direction and magnitude of each update.",
),
calculator_conversation(
"What is 18% of 250?",
"percent(250,18)",
45.0,
"18% of 250 is 45.",
),
calculator_conversation(
"If revenue grew from 120 to 150, what is the percent change?",
"percent_change(120,150)",
25.0,
"The percent change from 120 to 150 is 25%.",
),
calculator_conversation(
"What is the monthly EMI for a 500000 loan at 8% annual interest over 240 months?",
"emi(500000,8,240)",
4182.198594391402,
"The monthly EMI is about 4182.2.",
),
web_search_conversation(
"Find the official nanochat repository and give me the link.",
"official nanochat repository",
"https://github.com/karpathy/nanochat",
"The official nanochat repository is https://github.com/karpathy/nanochat",
),
web_search_conversation(
"What do the Cloudflare Browser Rendering docs say about the markdown endpoint?",
"Cloudflare Browser Rendering markdown endpoint docs",
"https://developers.cloudflare.com/browser-rendering/rest-api/markdown-endpoint/",
"Cloudflare documents a markdown endpoint for extracting rendered page content: https://developers.cloudflare.com/browser-rendering/rest-api/markdown-endpoint/",
),
]
def build_eval_rows():
return [
{
"conversation": {
"messages": [
{"role": "user", "content": "What is 12% of 250?"},
{"role": "assistant", "content": "45"},
]
},
"checks": {
"must_call": "calculator",
"must_not_call": ["web_search"],
"answer_contains": ["45"],
},
},
{
"conversation": {
"messages": [
{"role": "user", "content": "Find the official nanochat repository and cite the link."},
{"role": "assistant", "content": "https://github.com/karpathy/nanochat"},
]
},
"checks": {
"must_call": "web_search",
"citation_required": True,
"answer_contains": ["github.com/karpathy/nanochat"],
},
},
{
"conversation": {
"messages": [
{"role": "user", "content": "Explain what overfitting means in one sentence."},
{"role": "assistant", "content": "Overfitting means a model memorizes training patterns too closely and generalizes poorly to new data."},
]
},
"checks": {
"must_not_call": ["web_search"],
"answer_contains": ["generalizes poorly"],
},
},
]
def main():
parser = argparse.ArgumentParser(description="Build seed tool-use datasets")
parser.add_argument("--train-out", default="seed_data/tool_sft_seed.jsonl", help="Output JSONL for SFT conversations")
parser.add_argument("--eval-out", default="seed_data/tool_eval_seed.jsonl", help="Output JSONL for eval/reward objects")
args = parser.parse_args()
train_rows = build_train_rows()
eval_rows = build_eval_rows()
write_jsonl(args.train_out, train_rows)
write_jsonl(args.eval_out, eval_rows)
print(f"Wrote {len(train_rows)} SFT rows to {args.train_out}")
print(f"Wrote {len(eval_rows)} eval rows to {args.eval_out}")
if __name__ == "__main__":
main()


@ -22,6 +22,7 @@ from tasks.mmlu import MMLU
from tasks.arc import ARC
from tasks.gsm8k import GSM8K
from tasks.spellingbee import SpellingBee
from tasks.tool_json import ToolJSON
# -----------------------------------------------------------------------------
# Generative evaluation loop (we go one problem at a time, sample, evaluate)
@ -156,8 +157,10 @@ def run_categorical_eval(task_object, tokenizer, model, batch_size, max_problems
def run_chat_eval(task_name, model, tokenizer, engine,
batch_size=1, num_samples=1, max_new_tokens=512, temperature=0.0, top_k=50,
max_problems=None):
max_problems=None, tool_jsonl=None):
# Create the evaluation object
if task_name == "ToolJSON" and not tool_jsonl:
raise ValueError("ToolJSON evaluation requires --tool-jsonl")
task_module = {
'HumanEval': HumanEval,
'MMLU': partial(MMLU, subset="all", split="test"),
@ -165,6 +168,7 @@ def run_chat_eval(task_name, model, tokenizer, engine,
'ARC-Challenge': partial(ARC, subset="ARC-Challenge", split="test"),
'GSM8K': partial(GSM8K, subset="main", split="test"),
'SpellingBee': partial(SpellingBee, size=256, split="test"),
'ToolJSON': partial(ToolJSON, filepath=tool_jsonl),
}[task_name]
task_object = task_module()
# Run the evaluation
@ -191,6 +195,7 @@ if __name__ == "__main__":
parser.add_argument('-g', '--model-tag', type=str, default=None, help='Model tag to load')
parser.add_argument('-s', '--step', type=int, default=None, help='Step to load')
parser.add_argument('-x', '--max-problems', type=int, default=None, help='Max problems to evaluate')
parser.add_argument('--tool-jsonl', type=str, default=None, help='Tool JSONL dataset path used when task-name includes ToolJSON')
parser.add_argument('--device-type', type=str, default='', choices=['cuda', 'cpu', 'mps'], help='Device type for evaluation: cuda|cpu|mps. empty => autodetect')
args = parser.parse_args()
@ -224,6 +229,7 @@ if __name__ == "__main__":
temperature=args.temperature,
top_k=args.top_k,
max_problems=args.max_problems,
tool_jsonl=args.tool_jsonl,
)
results[task_name] = acc
print0(f"{task_name} accuracy: {100 * acc:.2f}%")


@ -25,6 +25,7 @@ import torch.distributed as dist
from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, DummyWandb, autodetect_device_type
from nanochat.checkpoint_manager import save_checkpoint, load_model
from nanochat.engine import Engine
from nanochat.tools import DEFAULT_TOOL_SCHEMA
from tasks.gsm8k import GSM8K
# -----------------------------------------------------------------------------
@ -317,7 +318,12 @@ for step in range(num_steps):
model.state_dict(),
None, # note: we don't bother to save the optimizer state
{
"step": step,
"stage": "rl",
"model_config": model_config_kwargs,
"user_config": user_config,
"tool_schema": DEFAULT_TOOL_SCHEMA,
"source_hf_repo": meta.get("source_hf_repo"),
}
)
print(f"✅ Saved model checkpoint to {checkpoint_dir}")


@ -31,6 +31,7 @@ from tasks.mmlu import MMLU
from tasks.smoltalk import SmolTalk
from tasks.customjson import CustomJSON
from tasks.spellingbee import SimpleSpelling, SpellingBee
from nanochat.tools import DEFAULT_TOOL_SCHEMA
# -----------------------------------------------------------------------------
# CLI arguments
@ -66,6 +67,8 @@ parser.add_argument("--chatcore-max-sample", type=int, default=24, help="max pro
# Data mixture
parser.add_argument("--mmlu-epochs", type=int, default=3, help="number of epochs of MMLU in training mixture (teaches Multiple Choice)")
parser.add_argument("--gsm8k-epochs", type=int, default=4, help="number of epochs of GSM8K in training mixture (teaches Math and Tool Use)")
parser.add_argument("--extra-train-jsonl", type=str, default="", help="comma-separated JSONL conversation files to append to the SFT train mixture")
parser.add_argument("--extra-val-jsonl", type=str, default="", help="comma-separated JSONL conversation files to append to the SFT val mixture")
args = parser.parse_args()
user_config = vars(args).copy()
# -----------------------------------------------------------------------------
@ -162,6 +165,8 @@ for group in optimizer.param_groups:
# SFT data mixture and DataLoader
identity_conversations_filepath = os.path.join(base_dir, "identity_conversations.jsonl")
extra_train_jsonl = [path for path in args.extra_train_jsonl.split(",") if path]
extra_val_jsonl = [path for path in args.extra_val_jsonl.split(",") if path]
train_tasks = [
SmolTalk(split="train"), # 460K rows of general conversations
CustomJSON(filepath=identity_conversations_filepath), # 1000 rows of synthetic identity conversations
@ -171,13 +176,17 @@ train_tasks = [
SimpleSpelling(size=200000, split="train"), # 200K rows of Simple Spelling (e.g. spell the word 'apple')
SpellingBee(size=80000, split="train"), # 80K rows of Spelling Bee (e.g. how many 'r' are in 'strawberry'?)
]
train_tasks.extend(CustomJSON(filepath=path) for path in extra_train_jsonl)
train_dataset = TaskMixture(train_tasks)
print0(f"Training mixture: {len(train_dataset):,} rows (MMLU x{args.mmlu_epochs}, GSM8K x{args.gsm8k_epochs})")
print0(
f"Training mixture: {len(train_dataset):,} rows "
f"(MMLU x{args.mmlu_epochs}, GSM8K x{args.gsm8k_epochs}, extra_jsonl={len(extra_train_jsonl)})"
)
val_dataset = TaskMixture([
SmolTalk(split="test"), # 24K rows in test set
MMLU(subset="all", split="test", stop=5200), # 14K rows in test set, use only 5.2K to match the train ratios
GSM8K(subset="main", split="test", stop=420), # 1.32K rows in test set, use only 420 to match the train ratios
]) # total: 24K + 14K + 1.32K ~= 39K rows
] + [CustomJSON(filepath=path) for path in extra_val_jsonl]) # total: 24K + 14K + 1.32K ~= 39K rows plus custom validation
# DataLoader is defined here, it emits inputs, targets : 2D tensors of shape (device_batch_size, max_seq_len)
# A big problem is that we don't know the final num_iterations in advance. So we create
# these two global variables and update them from within the data generator.
@ -406,6 +415,7 @@ while True:
optimizer.state_dict(),
{
"step": step,
"stage": "sft",
"val_bpb": val_bpb, # loss at last step
"model_config": {
"sequence_len": args.max_seq_len,
@ -417,6 +427,10 @@ while True:
"window_pattern": model.config.window_pattern,
},
"user_config": user_config, # inputs to the training script
"tool_schema": DEFAULT_TOOL_SCHEMA,
"source_hf_repo": meta.get("source_hf_repo"),
"extra_train_jsonl": extra_train_jsonl,
"extra_val_jsonl": extra_val_jsonl,
},
rank=ddp_rank,
)

scripts/chat_tool_rl.py Normal file

@ -0,0 +1,219 @@
"""
Lightweight RL stage for tool-use tuning on local ToolJSON datasets.
1 GPU:
python -m scripts.chat_tool_rl --train-jsonl seed_data/tool_eval_seed.jsonl --eval-jsonl seed_data/tool_eval_seed.jsonl
8 GPUs:
torchrun --standalone --nproc_per_node=8 -m scripts.chat_tool_rl -- --train-jsonl seed_data/tool_eval_seed.jsonl --eval-jsonl seed_data/tool_eval_seed.jsonl
"""
import argparse
import itertools
import os
import torch
import torch.distributed as dist
import wandb
from nanochat.checkpoint_manager import load_model, save_checkpoint
from nanochat.common import DummyWandb, autodetect_device_type, compute_cleanup, compute_init, get_base_dir, print0
from nanochat.engine import Engine
from nanochat.tools import DEFAULT_TOOL_SCHEMA
from tasks.tool_json import ToolJSON
parser = argparse.ArgumentParser(description="RL tuning on local tool-use JSONL tasks")
parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)")
parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
parser.add_argument("--model-tag", type=str, default=None, help="SFT model tag to load from")
parser.add_argument("--model-step", type=int, default=None, help="SFT model step to load from")
parser.add_argument("--train-jsonl", type=str, required=True, help="Training ToolJSONL file")
parser.add_argument("--eval-jsonl", type=str, default=None, help="Evaluation ToolJSONL file")
parser.add_argument("--num-epochs", type=int, default=1, help="Number of epochs over ToolJSON")
parser.add_argument("--device-batch-size", type=int, default=8, help="Max batch size per forward pass")
parser.add_argument("--examples-per-step", type=int, default=8, help="Examples per optimization step across all ranks")
parser.add_argument("--num-samples", type=int, default=8, help="Number of samples per example")
parser.add_argument("--max-new-tokens", type=int, default=256, help="Max tokens to generate")
parser.add_argument("--temperature", type=float, default=0.7, help="Sampling temperature")
parser.add_argument("--top-k", type=int, default=50, help="Top-k sampling (0 disables)")
parser.add_argument("--embedding-lr", type=float, default=0.2, help="Embedding LR")
parser.add_argument("--unembedding-lr", type=float, default=0.004, help="Unembedding LR")
parser.add_argument("--matrix-lr", type=float, default=0.02, help="Matrix LR")
parser.add_argument("--weight-decay", type=float, default=0.0, help="Weight decay")
parser.add_argument("--init-lr-frac", type=float, default=0.05, help="Initial LR multiplier")
parser.add_argument("--eval-every", type=int, default=40, help="Evaluate every N steps")
parser.add_argument("--eval-examples", type=int, default=64, help="Maximum eval examples")
parser.add_argument("--save-every", type=int, default=40, help="Save every N steps")
args = parser.parse_args()
user_config = vars(args).copy()
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
master_process = ddp_rank == 0
use_dummy_wandb = args.run == "dummy" or not master_process
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-tool-rl", name=args.run, config=user_config)
model, tokenizer, meta = load_model("sft", device, phase="eval", model_tag=args.model_tag, step=args.model_step)
engine = Engine(model, tokenizer)
train_task = ToolJSON(filepath=args.train_jsonl)
eval_task = ToolJSON(filepath=args.eval_jsonl or args.train_jsonl)
num_steps = (len(train_task) // args.examples_per_step) * args.num_epochs
print0(f"Calculated number of steps: {num_steps}")
@torch.no_grad()
def get_batch():
assistant_end = tokenizer.encode_special("<|assistant_end|>")
rank_indices = range(ddp_rank, len(train_task), ddp_world_size)
for example_idx in itertools.cycle(rank_indices):
conversation = train_task[example_idx]
tokens = tokenizer.render_for_completion(conversation)
prefix_length = len(tokens)
model.eval()
generated_token_sequences = []
masks = []
num_sampling_steps = args.num_samples // args.device_batch_size
if args.num_samples % args.device_batch_size != 0:
raise ValueError("num_samples must be divisible by device_batch_size")
for sampling_step in range(num_sampling_steps):
seed = hash((step, example_idx, sampling_step)) & 0x7FFFFFFF  # `step` is the global training-loop variable below; tuples of ints hash deterministically
generated_batch, mask_batch = engine.generate_batch(
tokens,
num_samples=args.device_batch_size,
max_tokens=args.max_new_tokens,
temperature=args.temperature,
top_k=args.top_k,
seed=seed,
)
generated_token_sequences.extend(generated_batch)
masks.extend(mask_batch)
rewards = []
for sample_tokens in generated_token_sequences:
generated_text = tokenizer.decode(sample_tokens[prefix_length:])
rewards.append(train_task.reward(conversation, generated_text))
max_length = max(len(seq) for seq in generated_token_sequences)
padded_sequences = [seq + [assistant_end] * (max_length - len(seq)) for seq in generated_token_sequences]
padded_masks = [mask + [0] * (max_length - len(mask)) for mask in masks]
ids = torch.tensor(padded_sequences, dtype=torch.long, device=device)
mask_ids = torch.tensor(padded_masks, dtype=torch.long, device=device)
inputs = ids[:, :-1]
targets = ids[:, 1:].clone()
targets[mask_ids[:, 1:] == 0] = -1
rewards_tensor = torch.tensor(rewards, dtype=torch.float, device=device)
advantages = rewards_tensor - rewards_tensor.mean()
yield generated_token_sequences, inputs, targets, rewards_tensor, advantages
@torch.no_grad()
def run_tool_eval(task, max_examples):
max_examples = min(max_examples, len(task))
total = 0
reward_sum = 0.0
passed = 0
for idx in range(ddp_rank, max_examples, ddp_world_size):
conversation = task[idx]
tokens = tokenizer.render_for_completion(conversation)
generated_sequences, _ = engine.generate_batch(tokens, num_samples=1, max_tokens=args.max_new_tokens, temperature=0.0, top_k=args.top_k)
generated_text = tokenizer.decode(generated_sequences[0][len(tokens):])
reward_sum += task.reward(conversation, generated_text)
passed += task.evaluate(conversation, generated_text)
total += 1
reward_tensor = torch.tensor([reward_sum], dtype=torch.float, device=device)
passed_tensor = torch.tensor([passed], dtype=torch.long, device=device)
total_tensor = torch.tensor([total], dtype=torch.long, device=device)
if ddp:
dist.all_reduce(reward_tensor, op=dist.ReduceOp.SUM)
dist.all_reduce(passed_tensor, op=dist.ReduceOp.SUM)
dist.all_reduce(total_tensor, op=dist.ReduceOp.SUM)
total = max(total_tensor.item(), 1)
return reward_tensor.item() / total, passed_tensor.item() / total
optimizer = model.setup_optimizer(
unembedding_lr=args.unembedding_lr,
embedding_lr=args.embedding_lr,
matrix_lr=args.matrix_lr,
weight_decay=args.weight_decay,
)
for group in optimizer.param_groups:
group["lr"] = group["lr"] * args.init_lr_frac
group["initial_lr"] = group["lr"]
def get_lr_multiplier(it):
return 1.0 - it / max(num_steps, 1)
assert args.examples_per_step % ddp_world_size == 0, "examples_per_step must be divisible by number of ranks"
examples_per_rank = args.examples_per_step // ddp_world_size
batch_iterator = get_batch()
for step in range(num_steps):
if step % args.eval_every == 0:
model.eval()
mean_reward, pass_rate = run_tool_eval(eval_task, args.eval_examples)
print0(f"Step {step} | tool_eval_reward={mean_reward:.4f} | tool_eval_pass={pass_rate:.4f}")
wandb_run.log({"step": step, "tool_eval_reward": mean_reward, "tool_eval_pass": pass_rate})
rewards_list = []
for example_step in range(examples_per_rank):
_, inputs_all, targets_all, rewards_all, advantages_all = next(batch_iterator)
model.train()
assert inputs_all.size(0) % args.device_batch_size == 0
num_passes = inputs_all.size(0) // args.device_batch_size
for pass_idx in range(num_passes):
b0, b1 = pass_idx * args.device_batch_size, (pass_idx + 1) * args.device_batch_size
inputs = inputs_all[b0:b1]
targets = targets_all[b0:b1]
rewards = rewards_all[b0:b1]
advantages = advantages_all[b0:b1]
logp = -model(inputs, targets, loss_reduction="none").view_as(inputs)
pg_obj = (logp * advantages.unsqueeze(-1)).sum()
num_valid = (targets >= 0).sum().clamp(min=1)
pg_obj = pg_obj / (num_valid * num_passes * examples_per_rank)
loss = -pg_obj
loss.backward()
print0(
f"Step {step}/{num_steps} | Example step {example_step} | Pass {pass_idx} "
f"| loss={loss.item():.6f} | reward={rewards.mean().item():.4f}"
)
rewards_list.append(rewards_all.mean().item())
lrm = get_lr_multiplier(step)
for group in optimizer.param_groups:
group["lr"] = group["initial_lr"] * lrm
optimizer.step()
model.zero_grad(set_to_none=True)
wandb_run.log({"step": step, "lrm": lrm, "mean_reward": sum(rewards_list) / max(len(rewards_list), 1)})
if master_process and ((step > 0 and step % args.save_every == 0) or step == num_steps - 1):
base_dir = get_base_dir()
output_dirname = args.model_tag if args.model_tag else f"d{model.config.n_layer}"
checkpoint_dir = os.path.join(base_dir, "chatrl_checkpoints", output_dirname)
save_checkpoint(
checkpoint_dir,
step,
model.state_dict(),
None,
{
"step": step,
"stage": "tool_rl",
"model_config": model.config.__dict__,
"user_config": user_config,
"tool_schema": DEFAULT_TOOL_SCHEMA,
"source_hf_repo": meta.get("source_hf_repo"),
"train_jsonl": args.train_jsonl,
"eval_jsonl": args.eval_jsonl,
},
)
print0(f"Saved tool RL checkpoint to {checkpoint_dir}")
wandb_run.finish()
compute_cleanup()
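The update rule above is plain REINFORCE with a per-example mean-reward baseline. A toy illustration of the shapes involved (not the training code):
```python
import torch

# 4 samples for one example: center rewards by their mean so above-average
# samples are reinforced and below-average ones are suppressed.
rewards = torch.tensor([1.0, 0.5, 0.0, 0.5])
advantages = rewards - rewards.mean()      # tensor([ 0.5,  0.0, -0.5,  0.0])

# Per-token log-probs (stand-in values), scaled by each sequence's advantage;
# negating the sum gives the loss that loss.backward() minimizes above.
logp = torch.randn(4, 8)
loss = -(logp * advantages.unsqueeze(-1)).sum()
```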


@ -0,0 +1,86 @@
"""
Upload native nanochat checkpoints and related artifacts to a Hugging Face repo.
Examples:
python -m scripts.hf_sync_checkpoint --repo-id ManmohanSharma/nanochat-d24 --source base --model-tag d24_hf_import
python -m scripts.hf_sync_checkpoint --repo-id ManmohanSharma/nanochat-d24 --source base --model-tag d24_hf_import --step 0
"""
import argparse
import os
from huggingface_hub import HfApi
from nanochat.common import get_base_dir
def resolve_checkpoint_dir(source, model_tag):
phase_dir = {
"base": "base_checkpoints",
"sft": "chatsft_checkpoints",
"rl": "chatrl_checkpoints",
}[source]
return os.path.join(get_base_dir(), phase_dir, model_tag)
def main():
parser = argparse.ArgumentParser(description="Upload native nanochat checkpoints to Hugging Face")
parser.add_argument("--repo-id", required=True, help="Destination HF model repo")
parser.add_argument("--source", choices=["base", "sft", "rl"], required=True, help="Checkpoint phase")
parser.add_argument("--model-tag", required=True, help="Local nanochat model tag")
parser.add_argument("--step", type=int, default=None, help="Optional specific step to upload")
parser.add_argument("--token-env", default="HF_TOKEN", help="Environment variable containing the HF token")
parser.add_argument("--private", type=int, default=0, help="Create the repo as private if it does not exist")
parser.add_argument("--repo-subdir", default="native_checkpoints", help="Subdirectory inside the repo")
args = parser.parse_args()
token = os.environ.get(args.token_env)
if not token:
raise ValueError(f"Missing Hugging Face token in {args.token_env}")
checkpoint_dir = resolve_checkpoint_dir(args.source, args.model_tag)
if not os.path.isdir(checkpoint_dir):
raise FileNotFoundError(f"Checkpoint directory not found: {checkpoint_dir}")
api = HfApi(token=token)
api.create_repo(repo_id=args.repo_id, repo_type="model", private=bool(args.private), exist_ok=True)
if args.step is None:
path_in_repo = f"{args.repo_subdir}/{args.source}/{args.model_tag}"
api.upload_folder(
folder_path=checkpoint_dir,
repo_id=args.repo_id,
repo_type="model",
path_in_repo=path_in_repo,
commit_message=f"Upload native {args.source} checkpoint folder for {args.model_tag}",
)
print(f"Uploaded {checkpoint_dir} to {args.repo_id}:{path_in_repo}")
return
step_str = f"{args.step:06d}"
files = [
f"model_{step_str}.pt",
f"meta_{step_str}.json",
]
optimizer_pattern = f"optim_{step_str}_"
for filename in sorted(os.listdir(checkpoint_dir)):
if filename.startswith(optimizer_pattern) and filename.endswith(".pt"):
files.append(filename)
for filename in files:
local_path = os.path.join(checkpoint_dir, filename)
if not os.path.exists(local_path):
continue
path_in_repo = f"{args.repo_subdir}/{args.source}/{args.model_tag}/{filename}"
api.upload_file(
path_or_fileobj=local_path,
repo_id=args.repo_id,
repo_type="model",
path_in_repo=path_in_repo,
commit_message=f"Upload {args.source} checkpoint {args.model_tag} step {step_str}",
)
print(f"Uploaded step {step_str} for {args.model_tag} to {args.repo_id}")
if __name__ == "__main__":
main()


@ -0,0 +1,162 @@
"""
Import a Hugging Face model repo into nanochat's native checkpoint format.
This is intended for base-model continuation before multi-stage nanochat runs.
Examples:
python -m scripts.import_hf_checkpoint --repo-id ManmohanSharma/nanochat-d24 --model-tag d24_hf_import
python -m scripts.import_hf_checkpoint --repo-id ManmohanSharma/nanochat-d24 --local-dir /path/to/snapshot
"""
import argparse
import os
from dataclasses import asdict
import torch
from huggingface_hub import snapshot_download
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from nanochat.checkpoint_manager import save_checkpoint
from nanochat.common import get_base_dir
from nanochat.gpt import GPT, GPTConfig
from nanochat.tokenizer import get_tokenizer
from nanochat.tools import DEFAULT_TOOL_SCHEMA
def normalize_hf_state_dict_keys(state_dict):
normalized = {}
prefixes = ("_orig_mod.", "module.", "model.")
for key, value in state_dict.items():
normalized_key = key
for prefix in prefixes:
if normalized_key.startswith(prefix):
normalized_key = normalized_key[len(prefix):]
normalized[normalized_key] = value
return normalized
def infer_gpt_config(hf_config):
kwargs = {
"sequence_len": getattr(
hf_config,
"sequence_len",
getattr(hf_config, "max_position_embeddings", getattr(hf_config, "n_positions", 2048)),
),
"vocab_size": getattr(hf_config, "vocab_size"),
"n_layer": getattr(
hf_config,
"n_layer",
getattr(hf_config, "num_hidden_layers", getattr(hf_config, "num_layers")),
),
"n_head": getattr(
hf_config,
"n_head",
getattr(hf_config, "num_attention_heads", None),
),
"n_kv_head": getattr(
hf_config,
"n_kv_head",
getattr(hf_config, "num_key_value_heads", getattr(hf_config, "num_attention_heads", None)),
),
"n_embd": getattr(
hf_config,
"n_embd",
getattr(hf_config, "hidden_size", getattr(hf_config, "d_model", None)),
),
"window_pattern": getattr(hf_config, "window_pattern", "L"),
}
missing = [key for key, value in kwargs.items() if value is None]
if missing:
raise ValueError(f"Could not infer nanochat GPTConfig fields from HF config: {missing}")
return GPTConfig(**kwargs)
def verify_tokenizer_compatibility(hf_tokenizer, nanochat_tokenizer):
hf_vocab = hf_tokenizer.vocab_size
local_vocab = nanochat_tokenizer.get_vocab_size()
if hf_vocab != local_vocab:
raise ValueError(
f"Tokenizer vocab mismatch: HF repo has vocab_size={hf_vocab}, "
f"local nanochat tokenizer has vocab_size={local_vocab}"
)
def load_hf_snapshot(repo_id, revision, token, local_dir):
if local_dir is not None:
return local_dir
return snapshot_download(
repo_id=repo_id,
revision=revision,
token=token,
repo_type="model",
)
def main():
parser = argparse.ArgumentParser(description="Import HF repo into native nanochat checkpoints")
parser.add_argument("--repo-id", required=True, help="HF model repo id, e.g. ManmohanSharma/nanochat-d24")
parser.add_argument("--revision", default=None, help="Optional HF revision")
parser.add_argument("--local-dir", default=None, help="Use an already-downloaded HF snapshot instead of downloading")
parser.add_argument("--token-env", default="HF_TOKEN", help="Environment variable containing the HF token")
parser.add_argument("--model-tag", default=None, help="Destination model tag. Defaults to repo name slug")
parser.add_argument("--step", type=int, default=0, help="Checkpoint step number to write")
parser.add_argument("--source", choices=["base", "sft", "rl"], default="base", help="Destination checkpoint phase")
parser.add_argument("--trust-remote-code", type=int, default=1, help="Pass trust_remote_code to Transformers loaders")
args = parser.parse_args()
token = os.environ.get(args.token_env)
snapshot_path = load_hf_snapshot(args.repo_id, args.revision, token, args.local_dir)
trust_remote_code = bool(args.trust_remote_code)
hf_config = AutoConfig.from_pretrained(snapshot_path, token=token, trust_remote_code=trust_remote_code)
hf_tokenizer = AutoTokenizer.from_pretrained(snapshot_path, token=token, trust_remote_code=trust_remote_code)
nanochat_tokenizer = get_tokenizer()
verify_tokenizer_compatibility(hf_tokenizer, nanochat_tokenizer)
local_config = infer_gpt_config(hf_config)
with torch.no_grad():
model = AutoModelForCausalLM.from_pretrained(
snapshot_path,
token=token,
trust_remote_code=trust_remote_code,
)
hf_state_dict = normalize_hf_state_dict_keys(model.state_dict())
with torch.device("meta"):
local_model = GPT(local_config)
expected_keys = set(local_model.state_dict().keys())
provided_keys = set(hf_state_dict.keys())
missing = sorted(expected_keys - provided_keys)
extra = sorted(provided_keys - expected_keys)
if missing or extra:
message = [
"HF checkpoint keys do not match native nanochat keys after normalization.",
f"Missing keys: {missing[:12]}",
f"Extra keys: {extra[:12]}",
]
raise ValueError("\n".join(message))
model_data = {key: value.detach().cpu() for key, value in hf_state_dict.items()}
meta_data = {
"model_config": asdict(local_config),
"imported_from_hf": True,
"source_hf_repo": args.repo_id,
"source_hf_revision": args.revision,
"tool_schema": DEFAULT_TOOL_SCHEMA,
"tokenizer_vocab_size": nanochat_tokenizer.get_vocab_size(),
}
model_tag = args.model_tag or args.repo_id.split("/")[-1].replace("-", "_")
base_dir = get_base_dir()
phase_dir = {
"base": "base_checkpoints",
"sft": "chatsft_checkpoints",
"rl": "chatrl_checkpoints",
}[args.source]
checkpoint_dir = os.path.join(base_dir, phase_dir, model_tag)
save_checkpoint(checkpoint_dir, args.step, model_data, optimizer_data=None, meta_data=meta_data, rank=0)
print(f"Imported {args.repo_id} into {checkpoint_dir} at step {args.step}")
if __name__ == "__main__":
main()
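Key normalization in action, using the helper defined above (a sketch; the transformer key names are illustrative stand-ins for nanochat's native GPT state dict):
```python
import torch

sd = {
    "_orig_mod.transformer.wte.weight": torch.zeros(1),       # torch.compile wrapper
    "model.transformer.h.0.attn.c_q.weight": torch.zeros(1),  # HF wrapper
}
print(sorted(normalize_hf_state_dict_keys(sd)))
# ['transformer.h.0.attn.c_q.weight', 'transformer.wte.weight']
```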


@ -0,0 +1,60 @@
"""
Verify external service access from environment variables.
Expected env vars:
- HF_TOKEN
- CLOUDFLARE_API_TOKEN
- CLOUDFLARE_ACCOUNT_ID
"""
import argparse
import os
import requests
from huggingface_hub import HfApi
def verify_hf(repo_id, token_env):
token = os.environ.get(token_env)
if not token:
return False, f"Missing {token_env}"
api = HfApi(token=token)
info = api.model_info(repo_id)
return True, f"HF access OK: {info.id}"
def verify_cloudflare(token_env):
    token = os.environ.get(token_env)
    if not token:
        return False, f"Missing {token_env}"
    try:
        response = requests.get(
            "https://api.cloudflare.com/client/v4/user/tokens/verify",
            headers={"Authorization": f"Bearer {token}"},
            timeout=20,
        )
        response.raise_for_status()
    except requests.RequestException as error:  # invalid tokens often surface as HTTP 4xx here
        return False, f"Cloudflare verify failed: {error}"
    payload = response.json()
    if not payload.get("success", False):
        return False, f"Cloudflare verify failed: {payload.get('errors')}"
    status = payload.get("result", {}).get("status")
    return True, f"Cloudflare token OK: status={status}"
def main():
parser = argparse.ArgumentParser(description="Verify Hugging Face and Cloudflare access from env vars")
parser.add_argument("--hf-repo-id", default="ManmohanSharma/nanochat-d24", help="HF repo id to verify")
parser.add_argument("--hf-token-env", default="HF_TOKEN", help="HF token env var")
parser.add_argument("--cloudflare-token-env", default="CLOUDFLARE_API_TOKEN", help="Cloudflare token env var")
args = parser.parse_args()
hf_ok, hf_message = verify_hf(args.hf_repo_id, args.hf_token_env)
cf_ok, cf_message = verify_cloudflare(args.cloudflare_token_env)
print(hf_message)
print(cf_message)
if not (hf_ok and cf_ok):
raise SystemExit(1)
if __name__ == "__main__":
main()

3
seed_data/tool_eval_seed.jsonl Normal file
View File

@ -0,0 +1,3 @@
{"conversation": {"messages": [{"role": "user", "content": "What is 12% of 250?"}, {"role": "assistant", "content": "45"}]}, "checks": {"must_call": "calculator", "must_not_call": ["web_search"], "answer_contains": ["45"]}}
{"conversation": {"messages": [{"role": "user", "content": "Find the official nanochat repository and cite the link."}, {"role": "assistant", "content": "https://github.com/karpathy/nanochat"}]}, "checks": {"must_call": "web_search", "citation_required": true, "answer_contains": ["github.com/karpathy/nanochat"]}}
{"conversation": {"messages": [{"role": "user", "content": "Explain what overfitting means in one sentence."}, {"role": "assistant", "content": "Overfitting means a model memorizes training patterns too closely and generalizes poorly to new data."}]}, "checks": {"must_not_call": ["web_search"], "answer_contains": ["generalizes poorly"]}}

View File

@ -0,0 +1,6 @@
[{"role": "user", "content": "Explain what gradient descent is in one short paragraph."}, {"role": "assistant", "content": "Gradient descent is an optimization method that repeatedly nudges model parameters in the direction that reduces error, using the gradient to decide both direction and magnitude of each update."}]
[{"role": "user", "content": "What is 18% of 250?"}, {"role": "assistant", "content": [{"type": "text", "text": "Let me compute that exactly. "}, {"type": "tool_call", "tool_name": "calculator", "arguments": {"expression": "percent(250,18)"}}, {"type": "tool_result", "tool_name": "calculator", "output": {"expression": "percent(250,18)", "value": 45.0}}, {"type": "text", "text": "18% of 250 is 45."}]}]
[{"role": "user", "content": "If revenue grew from 120 to 150, what is the percent change?"}, {"role": "assistant", "content": [{"type": "text", "text": "Let me compute that exactly. "}, {"type": "tool_call", "tool_name": "calculator", "arguments": {"expression": "percent_change(120,150)"}}, {"type": "tool_result", "tool_name": "calculator", "output": {"expression": "percent_change(120,150)", "value": 25.0}}, {"type": "text", "text": "The percent change from 120 to 150 is 25%."}]}]
[{"role": "user", "content": "What is the monthly EMI for a 500000 loan at 8% annual interest over 240 months?"}, {"role": "assistant", "content": [{"type": "text", "text": "Let me compute that exactly. "}, {"type": "tool_call", "tool_name": "calculator", "arguments": {"expression": "emi(500000,8,240)"}}, {"type": "tool_result", "tool_name": "calculator", "output": {"expression": "emi(500000,8,240)", "value": 4182.198594391402}}, {"type": "text", "text": "The monthly EMI is about 4182.2."}]}]
[{"role": "user", "content": "Find the official nanochat repository and give me the link."}, {"role": "assistant", "content": [{"type": "text", "text": "I should verify this with a fresh web lookup. "}, {"type": "tool_call", "tool_name": "web_search", "arguments": {"query": "official nanochat repository", "top_k": 1}}, {"type": "tool_result", "tool_name": "web_search", "output": {"query": "official nanochat repository", "results": [{"url": "https://github.com/karpathy/nanochat", "title": "Source result", "snippet": "Fresh page content retrieved for grounding."}]}}, {"type": "text", "text": "The official nanochat repository is https://github.com/karpathy/nanochat"}]}]
[{"role": "user", "content": "What do the Cloudflare Browser Rendering docs say about the markdown endpoint?"}, {"role": "assistant", "content": [{"type": "text", "text": "I should verify this with a fresh web lookup. "}, {"type": "tool_call", "tool_name": "web_search", "arguments": {"query": "Cloudflare Browser Rendering markdown endpoint docs", "top_k": 1}}, {"type": "tool_result", "tool_name": "web_search", "output": {"query": "Cloudflare Browser Rendering markdown endpoint docs", "results": [{"url": "https://developers.cloudflare.com/browser-rendering/rest-api/markdown-endpoint/", "title": "Source result", "snippet": "Fresh page content retrieved for grounding."}]}}, {"type": "text", "text": "Cloudflare documents a markdown endpoint for extracting rendered page content: https://developers.cloudflare.com/browser-rendering/rest-api/markdown-endpoint/"}]}]

101
tasks/tool_json.py Normal file
View File

@ -0,0 +1,101 @@
"""
Local JSONL task for tool-use evaluation and lightweight RL reward shaping.
Each line should be a JSON object with:
{
"conversation": {"messages": [...]},
"checks": {
"must_call": "calculator",
"must_not_call": ["web_search"],
"answer_contains": ["42"],
"citation_required": false
}
}
"""
import json
import os
import re
from nanochat.tools import TOOL_CALL_END, TOOL_CALL_START, parse_tool_call_payload
from tasks.common import Task
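# Non-greedy match of everything between the tool-call sentinel tokens;
# re.DOTALL lets a JSON payload span multiple lines.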
TOOL_BLOCK_RE = re.compile(re.escape(TOOL_CALL_START) + r"(.*?)" + re.escape(TOOL_CALL_END), re.DOTALL)
class ToolJSON(Task):
def __init__(self, filepath, split="eval", **kwargs):
super().__init__(**kwargs)
self.filepath = filepath
self.split = split
self.rows = []
if not os.path.exists(filepath):
raise FileNotFoundError(f"Tool JSONL dataset not found: {filepath}")
with open(filepath, "r", encoding="utf-8") as f:
for line in f:
line = line.strip()
if not line:
continue
row = json.loads(line)
if "conversation" not in row:
raise ValueError(f"Row missing conversation field: {row}")
row.setdefault("checks", {})
self.rows.append(row)
@property
def eval_type(self):
return "generative"
def num_examples(self):
return len(self.rows)
def get_example(self, index):
row = self.rows[index]
conversation = dict(row["conversation"])
conversation["checks"] = row.get("checks", {})
return conversation
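    # Parse every tool-call block in the response and return the invoked
    # tool names in order of appearance.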
def _tool_calls(self, assistant_response):
calls = []
for payload in TOOL_BLOCK_RE.findall(assistant_response):
invocation = parse_tool_call_payload(payload)
calls.append(invocation.tool_name)
return calls
    def evaluate(self, conversation, assistant_response):
        # Binary pass only when every configured check is satisfied.
        score = self.reward(conversation, assistant_response)
return int(score >= 0.999)
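    # Partial-credit reward in [0, 1]: each configured check contributes one
    # point, and the score is the fraction of checks that pass.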
def reward(self, conversation, assistant_response):
checks = conversation.get("checks", {})
total = 0.0
passed = 0.0
tool_calls = self._tool_calls(assistant_response)
must_call = checks.get("must_call")
if must_call:
total += 1.0
passed += float(must_call in tool_calls)
for tool_name in checks.get("must_not_call", []):
total += 1.0
passed += float(tool_name not in tool_calls)
for needle in checks.get("answer_contains", []):
total += 1.0
passed += float(needle in assistant_response)
answer_regex = checks.get("answer_regex")
if answer_regex:
total += 1.0
passed += float(re.search(answer_regex, assistant_response) is not None)
if checks.get("citation_required", False):
total += 1.0
passed += float(("http://" in assistant_response) or ("https://" in assistant_response))
if total == 0.0:
return 0.0
return passed / total

133
tests/test_tools.py Normal file
View File

@ -0,0 +1,133 @@
import json
import torch
from nanochat.engine import Engine
from nanochat.tools import CalculatorTool, MockSearchBackend, ToolRegistry, WebSearchTool, parse_tool_call_payload, serialize_tool_call
from tasks.tool_json import ToolJSON
class MockConfig:
n_kv_head = 4
n_head = 4
n_embd = 64
n_layer = 2
sequence_len = 128
class ByteTokenizer:
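    # Byte-level stand-in for the real tokenizer: ids 0-255 are raw bytes and
    # chat special tokens start at 256.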
def __init__(self):
self._special_tokens = {
"<|python_start|>": 256,
"<|python_end|>": 257,
"<|output_start|>": 258,
"<|output_end|>": 259,
"<|assistant_end|>": 260,
"<|bos|>": 261,
}
self._bos = 261
def encode_special(self, s):
return self._special_tokens[s]
def get_bos_token_id(self):
return self._bos
def encode(self, s, prepend=None):
tokens = list(s.encode("utf-8"))
if prepend is not None:
tokens = [prepend] + tokens
return tokens
def decode(self, tokens):
byte_tokens = [t for t in tokens if t < 256]
return bytes(byte_tokens).decode("utf-8", errors="replace")
class SequencedModel:
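    # Deterministic mock model: each forward() puts all logit mass on the next
    # scripted token, so greedy decoding replays `sequence` exactly.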
def __init__(self, sequence, vocab_size=262):
self.sequence = sequence
self.vocab_size = vocab_size
self.config = MockConfig()
self._device = torch.device("cpu")
self.call_index = 0
def get_device(self):
return self._device
def forward(self, ids, kv_cache=None):
if kv_cache is not None:
kv_cache.advance(ids.shape[1])
target = self.sequence[min(self.call_index, len(self.sequence) - 1)]
logits = torch.full((ids.shape[0], ids.shape[1], self.vocab_size), -1e9)
logits[:, -1, target] = 0.0
self.call_index += 1
return logits
def test_parse_tool_call_payload_supports_json_and_legacy():
legacy = parse_tool_call_payload("6*7")
assert legacy.tool_name == "calculator"
assert legacy.arguments == {"expression": "6*7"}
payload = parse_tool_call_payload(serialize_tool_call("calculator", {"expression": "6*7"}))
assert payload.tool_name == "calculator"
assert payload.arguments == {"expression": "6*7"}
def test_calculator_tool_supports_scientific_and_business_functions():
tool = CalculatorTool()
trig = tool.run({"expression": "round(sin(pi/2), 4)"})
assert trig.success
assert trig.output["value"] == 1.0
emi = tool.run({"expression": "round(emi(500000,8,240), 2)"})
assert emi.success
assert emi.output["value"] == 4182.2
def test_web_search_tool_returns_mock_results():
tool = WebSearchTool(search_backend=MockSearchBackend())
result = tool.run({"query": "nanochat gpt2 speedrun", "top_k": 1})
assert result.success
assert result.output["results"][0]["url"] == "https://github.com/karpathy/nanochat"
def test_engine_executes_json_calculator_tool_call():
tokenizer = ByteTokenizer()
payload = serialize_tool_call("calculator", {"expression": "6*7"})
sequence = [
tokenizer.encode_special("<|python_start|>"),
*tokenizer.encode(payload),
tokenizer.encode_special("<|python_end|>"),
tokenizer.encode_special("<|assistant_end|>"),
]
model = SequencedModel(sequence)
registry = ToolRegistry([CalculatorTool()])
engine = Engine(model, tokenizer, tools=registry)
prompt = [tokenizer.get_bos_token_id(), 72]
results, _ = engine.generate_batch(prompt, num_samples=1, max_tokens=128, temperature=0.0)
completion = tokenizer.decode(results[0][len(prompt):])
assert "42" in completion
def test_tool_json_reward_and_eval(tmp_path):
path = tmp_path / "tool_eval.jsonl"
row = {
"conversation": {
"messages": [
{"role": "user", "content": "What is 6 times 7?"},
{"role": "assistant", "content": "42"},
]
},
"checks": {
"must_call": "calculator",
"answer_contains": ["42"],
},
}
path.write_text(json.dumps(row) + "\n", encoding="utf-8")
task = ToolJSON(filepath=str(path))
conversation = task[0]
response = '<|python_start|>{"tool":"calculator","arguments":{"expression":"6*7"}}<|python_end|>42'
assert task.reward(conversation, response) == 1.0
assert task.evaluate(conversation, response) == 1