Mirror of https://github.com/karpathy/nanochat.git, synced 2026-05-09 01:10:10 +00:00

Merge pull request #1 from manmohan659/codex/pre-gpu-readiness

Add pre-GPU tool training and checkpoint plumbing

Commit baef0a3d66
docs/pre_gpu_runbook.md (new file, 97 lines)
@@ -0,0 +1,97 @@
# Pre-GPU Runbook

This runbook is the minimum operational checklist before spending GPU time.

## 1. Local Prep

1. Build the seed tool datasets:

```bash
python -m scripts.build_tool_datasets
```

2. Import the starting checkpoint from Hugging Face into native nanochat format:

```bash
python -m scripts.import_hf_checkpoint \
    --repo-id ManmohanSharma/nanochat-d24 \
    --model-tag d24_hf_import
```

3. Validate tool tokenization and mock tool execution with local tests:

```bash
python -m pytest tests/test_engine.py tests/test_tools.py -v
```

4. Dry-run tool evaluation on CPU:

```bash
python -m scripts.chat_eval \
    -i sft \
    -a ToolJSON \
    --tool-jsonl seed_data/tool_eval_seed.jsonl \
    --device-type cpu \
    -x 3
```
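For reference, each line of `seed_data/tool_eval_seed.jsonl` is one check object pairing a conversation with the tool-routing assertions the evaluator enforces. This row mirrors the seed builder in `scripts/build_tool_datasets.py`:

```json
{"conversation": {"messages": [{"role": "user", "content": "What is 18% of 250?"}, {"role": "assistant", "content": "45"}]}, "checks": {"must_call": "calculator", "must_not_call": ["web_search"], "answer_contains": ["45"]}}
```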
## 2. 48-Hour GPU Schedule

1. Pilot CPT
   - Run a short continuation test from the imported base checkpoint.
   - Confirm loss is moving, checkpoint save works, and HF sync works.

2. Full CPT
   - Run the main continuation stage on the ClimbMix backbone.
   - Save staged checkpoints at planned intervals.

3. SFT
   - Include the local tool SFT JSONL via `--extra-train-jsonl`.
   - Validate that calculator/web_search traces render correctly.

4. RL / tool tuning
   - Keep this stage narrow and short.
   - Focus on tool-choice correctness and grounded answers.

5. Eval
   - Run ARC, MMLU, GSM8K, HumanEval, and ToolJSON checks.
   - Do not ship if tool behavior regresses or citations are missing.

## 3. Checkpoint Upload Cadence

Upload at every stage boundary and at any explicit resume point:

```bash
python -m scripts.hf_sync_checkpoint \
    --repo-id ManmohanSharma/nanochat-d24 \
    --source base \
    --model-tag d24_hf_import \
    --step 0
```

If a whole checkpoint directory should be mirrored, omit `--step`:

```bash
python -m scripts.hf_sync_checkpoint \
    --repo-id ManmohanSharma/nanochat-d24 \
    --source base \
    --model-tag d24_hf_import
```

## 4. Go / No-Go

Go only if:

- HF import works.
- HF sync works.
- Mock tool execution works.
- Tool seed datasets are generated.
- Tool eval runs locally.
- The search backend plan is explicit: a search provider plus Cloudflare fetch/crawl.

No-Go if:

- Any tokenizer mismatch appears during HF import.
- Tool blocks fail to render.
- `web_search` still has no backend plan beyond fetch-only Cloudflare Browser Rendering.
- Local tool eval is missing or failing.
manifests/cpt_recipe_v1.json (new file, 31 lines)
@@ -0,0 +1,31 @@
{
  "name": "cpt_recipe_v1",
  "goal": "Continue pretraining nanochat-d24 on a safe, high-value mixture for a 48-hour run.",
  "backbone_dataset": {
    "repo_id": "karpathy/climbmix-400b-shuffle",
    "role": "primary",
    "notes": "Use ClimbMix as the main pretraining backbone because it is already the best-performing dataset in this fork."
  },
  "optional_augmentations": [
    {
      "repo_id": "nvidia/Nemotron-Pretraining-Code-v2",
      "role": "code",
      "notes": "Selective code augmentation only."
    },
    {
      "repo_id": "nvidia/Nemotron-CC-Math-v1",
      "role": "math_science",
      "notes": "Selective math/science augmentation only."
    },
    {
      "repo_id": "nvidia/Nemotron-Pretraining-Specialized-v1.1",
      "role": "specialized_web",
      "notes": "Use sparingly for targeted specialization."
    }
  ],
  "do_not_do": [
    "Do not replace ClimbMix wholesale.",
    "Do not attempt architecture changes during this 48-hour continuation run.",
    "Do not rely on continued pretraining for current-events freshness."
  ]
}
manifests/rl_recipe_v1.json (new file, 26 lines)
@@ -0,0 +1,26 @@
{
  "name": "rl_recipe_v1",
  "goal": "Lightweight tool-tuning stage after SFT.",
  "seed_eval_data": "seed_data/tool_eval_seed.jsonl",
  "reward_components": [
    "valid tool call",
    "correct tool choice",
    "answer contains expected answer fragment",
    "citation required when using web_search",
    "penalty for unnecessary tool calls"
  ],
  "external_templates": [
    {
      "repo_id": "nvidia/Nemotron-RL-Agentic-Conversational-Tool-Use-Pivot-v1",
      "role": "tool_use_structure"
    },
    {
      "repo_id": "nvidia/Nemotron-RL-Agentic-Function-Calling-Pivot-v1",
      "role": "function_call_structure"
    }
  ],
  "notes": [
    "Keep RL narrow; do not turn this into broad RLHF.",
    "Use the local tool eval dataset for dry runs before GPU."
  ]
}
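For intuition, here is a minimal sketch of how these components could combine into a scalar reward. The real scoring lives in `tasks/tool_json.py`, which this page does not show; the weights and the `called`/`answer` inputs below are illustrative assumptions, not the shipped implementation:

```python
# Illustrative only: an assumed additive reward over the recipe's components.
def tool_reward(checks: dict, called: list[str], answer: str) -> float:
    r = 0.0
    if checks.get("must_call") in called:
        r += 1.0  # valid call to the correct tool
    if any(frag in answer for frag in checks.get("answer_contains", [])):
        r += 1.0  # expected answer fragment present
    if checks.get("citation_required") and "http" in answer:
        r += 0.5  # citation present when web_search was used
    r -= 0.25 * sum(t in called for t in checks.get("must_not_call", []))  # unnecessary tool calls
    return r
```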
manifests/sft_recipe_v1.json (new file, 32 lines)
@@ -0,0 +1,32 @@
{
  "name": "sft_recipe_v1",
  "goal": "Teach instruction following, tool routing, grounded web answers, and calculator usage.",
  "base_mixture": [
    {
      "source": "existing_nanochat",
      "notes": "Keep the existing chat mixture in scripts/chat_sft.py as the base."
    },
    {
      "source": "local_seed_data",
      "path": "seed_data/tool_sft_seed.jsonl",
      "notes": "Seed examples for direct answers, calculator, and web_search."
    }
  ],
  "external_templates": [
    {
      "repo_id": "nvidia/Nemotron-Instruction-Following-Chat-v1",
      "role": "instruction_following"
    },
    {
      "collection": "nvidia/nemotron-post-training-v3",
      "role": "agentic_templates"
    }
  ],
  "target_behaviors": [
    "Answer directly when no tool is needed.",
    "Use calculator for exact numeric computation.",
    "Use web_search for recency and external facts.",
    "Cite URLs when the response depends on retrieved web content.",
    "Avoid unnecessary tool calls."
  ]
}
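In practice the local seed rows ride along with the existing mixture via the new `--extra-train-jsonl` flag added to `scripts/chat_sft.py` further down; a single-GPU invocation would look roughly like:

```bash
python -m scripts.chat_sft --extra-train-jsonl seed_data/tool_sft_seed.jsonl
```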
nanochat/engine.py
@@ -13,70 +13,10 @@ The whole thing is made as efficient as possible.
import torch
import torch.nn.functional as F
import signal
import warnings
from contextlib import contextmanager
from collections import deque
from nanochat.common import compute_init, autodetect_device_type, COMPUTE_DTYPE
from nanochat.checkpoint_manager import load_model

# -----------------------------------------------------------------------------
# Calculator tool helpers
@contextmanager
def timeout(duration, formula):
    def timeout_handler(signum, frame):
        raise Exception(f"'{formula}': timed out after {duration} seconds")
    signal.signal(signal.SIGALRM, timeout_handler)
    signal.alarm(duration)
    yield
    signal.alarm(0)

def eval_with_timeout(formula, max_time=3):
    try:
        with timeout(max_time, formula):
            with warnings.catch_warnings():
                warnings.simplefilter("ignore", SyntaxWarning)
                return eval(formula, {"__builtins__": {}}, {})
    except Exception as e:
        signal.alarm(0)
        # print(f"Warning: Failed to eval {formula}, exception: {e}")  # it's ok, ignore wrong calculator usage
        return None

def use_calculator(expr):
    """
    Evaluate a Python expression safely.
    Supports both math expressions and string operations like .count()
    """
    # Remove commas from numbers
    expr = expr.replace(",", "")

    # Check if it's a pure math expression (old behavior)
    if all([x in "0123456789*+-/.() " for x in expr]):
        if "**" in expr:  # disallow power operator
            return None
        return eval_with_timeout(expr)

    # Check if it's a string operation we support
    # Allow: strings (single/double quotes), .count(), letters, numbers, spaces, parens
    allowed_chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'\"()._ "
    if not all([x in allowed_chars for x in expr]):
        return None

    # Disallow dangerous patterns
    dangerous_patterns = ['__', 'import', 'exec', 'eval', 'compile', 'open', 'file',
                          'input', 'raw_input', 'globals', 'locals', 'vars', 'dir',
                          'getattr', 'setattr', 'delattr', 'hasattr']
    expr_lower = expr.lower()
    if any(pattern in expr_lower for pattern in dangerous_patterns):
        return None

    # Only allow .count() method for now (can expand later)
    if '.count(' not in expr:
        return None

    # Evaluate with timeout
    return eval_with_timeout(expr)
from nanochat.tools import build_default_tool_registry, parse_tool_call_payload

# -----------------------------------------------------------------------------
class KVCache:
@@ -162,15 +102,16 @@ class RowState:
    def __init__(self, current_tokens=None):
        self.current_tokens = current_tokens or []  # Current token sequence for this row
        self.forced_tokens = deque()  # Queue of tokens to force inject
        self.in_python_block = False  # Whether we are inside a python block
        self.python_expr_tokens = []  # Tokens of the current python expression
        self.in_tool_block = False  # Whether we are inside a tool-call block
        self.tool_payload_tokens = []  # Tokens of the current tool-call payload
        self.completed = False  # Whether this row has completed generation

class Engine:

    def __init__(self, model, tokenizer):
    def __init__(self, model, tokenizer, tools=None):
        self.model = model
        self.tokenizer = tokenizer  # needed for tool use
        self.tools = tools or build_default_tool_registry()

    @torch.inference_mode()
    def generate(self, tokens, num_samples=1, max_tokens=None, temperature=1.0, top_k=None, seed=42):
@@ -255,21 +196,34 @@
                    state.completed = True
                # Handle tool logic
                if next_token == python_start:
                    state.in_python_block = True
                    state.python_expr_tokens = []
                elif next_token == python_end and state.in_python_block:
                    state.in_python_block = False
                    if state.python_expr_tokens:
                        expr = self.tokenizer.decode(state.python_expr_tokens)
                        result = use_calculator(expr)
                        if result is not None:
                            result_tokens = self.tokenizer.encode(str(result))
                    state.in_tool_block = True
                    state.tool_payload_tokens = []
                elif next_token == python_end and state.in_tool_block:
                    state.in_tool_block = False
                    if state.tool_payload_tokens:
                        payload_text = self.tokenizer.decode(state.tool_payload_tokens)
                        invocation = parse_tool_call_payload(payload_text)
                        result = self.tools.execute(invocation.tool_name, invocation.arguments)
                        legacy_calculator_call = (
                            invocation.tool_name == "calculator"
                            and invocation.raw_text.strip()
                            and not invocation.raw_text.strip().startswith("{")
                        )
                        if legacy_calculator_call:
                            if result.success and isinstance(result.output, dict) and "value" in result.output:
                                result_text = str(result.output["value"])
                            else:
                                result_text = result.error or ""
                        else:
                            result_text = result.to_payload()[:4096]
                        if result_text:
                            result_tokens = self.tokenizer.encode(result_text)
                            state.forced_tokens.append(output_start)
                            state.forced_tokens.extend(result_tokens)
                            state.forced_tokens.append(output_end)
                    state.python_expr_tokens = []
                elif state.in_python_block:
                    state.python_expr_tokens.append(next_token)
                    state.tool_payload_tokens = []
                elif state.in_tool_block:
                    state.tool_payload_tokens.append(next_token)

            # Yield the token column
            yield token_column, token_masks
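For orientation, the token stream this loop parses and force-feeds back looks roughly like the following (payload shapes come from `serialize_tool_call` and `ToolResult.to_payload` in `nanochat/tools.py`; key order is fixed by its sorted compact-JSON writer):

```
<|python_start|>{"arguments":{"expression":"2+2"},"tool":"calculator"}<|python_end|>
<|output_start|>{"error":null,"metadata":{},"output":{"expression":"2+2","value":4},"success":true,"tool":"calculator"}<|output_end|>
```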
nanochat/tokenizer.py
@@ -270,6 +270,7 @@ class RustBPETokenizer:
        - ids: list[int] is a list of token ids of this rendered conversation
        - mask: list[int] of same length, mask = 1 for tokens that the Assistant is expected to train on.
        """
        from nanochat.tools import serialize_tool_call, serialize_tool_result
        # ids, masks that we will return and a helper function to help build them up.
        ids, mask = [], []
        def add_tokens(token_ids, mask_val):
@@ -327,14 +328,29 @@
                if part["type"] == "text":
                    # string part => simply add the tokens
                    add_tokens(value_ids, 1)
                elif part["type"] == "python":
                elif part["type"] in {"python", "tool_call"}:
                    # python tool call => add the tokens inside <|python_start|> and <|python_end|>
                    if part["type"] == "tool_call":
                        payload = serialize_tool_call(
                            tool_name=part["tool_name"],
                            arguments=part.get("arguments") or {},
                        )
                        value_ids = self.encode(payload)
                    add_tokens(python_start, 1)
                    add_tokens(value_ids, 1)
                    add_tokens(python_end, 1)
                elif part["type"] == "python_output":
                elif part["type"] in {"python_output", "tool_result"}:
                    # python output => add the tokens inside <|output_start|> and <|output_end|>
                    # none of these tokens are supervised because the tokens come from Python at test time
                    if part["type"] == "tool_result":
                        payload = serialize_tool_result(
                            tool_name=part["tool_name"],
                            output=part.get("output"),
                            success=part.get("success", True),
                            error=part.get("error"),
                            metadata=part.get("metadata"),
                        )
                        value_ids = self.encode(payload)
                    add_tokens(output_start, 0)
                    add_tokens(value_ids, 0)
                    add_tokens(output_end, 0)
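For reference, a `tool_call` part and its paired `tool_result` as they appear inside an assistant message (these shapes match the seed builder in `scripts/build_tool_datasets.py` below):

```python
{"type": "tool_call", "tool_name": "calculator", "arguments": {"expression": "percent(250,18)"}}
{"type": "tool_result", "tool_name": "calculator", "output": {"expression": "percent(250,18)", "value": 45.0}}
```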
nanochat/tools.py (new file, 509 lines)
@@ -0,0 +1,509 @@
"""
|
||||
Shared tool definitions for nanochat.
|
||||
|
||||
The current tokenizer only has python/output special tokens. To preserve
|
||||
checkpoint compatibility, we reuse those tokens as a generic tool-call and
|
||||
tool-result channel. Legacy "python" calculator payloads still work.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Protocol
|
||||
|
||||
import requests
|
||||
|
||||
TOOL_CALL_START = "<|python_start|>"
|
||||
TOOL_CALL_END = "<|python_end|>"
|
||||
TOOL_RESULT_START = "<|output_start|>"
|
||||
TOOL_RESULT_END = "<|output_end|>"
|
||||
MAX_TOOL_PAYLOAD_CHARS = 4096
|
||||
|
||||
DEFAULT_TOOL_SCHEMA = [
|
||||
{
|
||||
"name": "calculator",
|
||||
"description": "Deterministic scientific calculator for exact arithmetic and common finance formulas.",
|
||||
"arguments": {
|
||||
"expression": "String expression using numbers, operators, and supported functions.",
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "web_search",
|
||||
"description": "Search and fetch web content. Requires a search backend and optionally a page fetch client.",
|
||||
"arguments": {
|
||||
"query": "Search query string.",
|
||||
"top_k": "Maximum number of results to return.",
|
||||
"urls": "Optional explicit URLs to fetch instead of searching.",
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def _compact_json(data: Any) -> str:
|
||||
return json.dumps(data, ensure_ascii=True, separators=(",", ":"), sort_keys=True)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolInvocation:
|
||||
tool_name: str
|
||||
arguments: dict[str, Any]
|
||||
raw_text: str = ""
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolResult:
|
||||
tool_name: str
|
||||
success: bool
|
||||
output: Any = None
|
||||
error: str | None = None
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def to_payload(self) -> str:
|
||||
return _compact_json(
|
||||
{
|
||||
"tool": self.tool_name,
|
||||
"success": self.success,
|
||||
"output": self.output,
|
||||
"error": self.error,
|
||||
"metadata": self.metadata,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class BaseTool:
|
||||
name: str
|
||||
|
||||
def run(self, arguments: dict[str, Any]) -> ToolResult:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class ToolRegistry:
|
||||
def __init__(self, tools: list[BaseTool] | tuple[BaseTool, ...]):
|
||||
self._tools = {tool.name: tool for tool in tools}
|
||||
|
||||
def execute(self, tool_name: str, arguments: dict[str, Any]) -> ToolResult:
|
||||
tool = self._tools.get(tool_name)
|
||||
if tool is None:
|
||||
return ToolResult(tool_name=tool_name, success=False, error=f"Unknown tool: {tool_name}")
|
||||
try:
|
||||
return tool.run(arguments)
|
||||
except Exception as exc: # defensive: tool failures should become model-visible outputs
|
||||
return ToolResult(tool_name=tool_name, success=False, error=str(exc))
|
||||
|
||||
def schema(self) -> list[dict[str, Any]]:
|
||||
return [item for item in DEFAULT_TOOL_SCHEMA if item["name"] in self._tools]
|
||||
|
||||
|
||||
def serialize_tool_call(tool_name: str, arguments: dict[str, Any] | None = None) -> str:
|
||||
payload = {
|
||||
"tool": tool_name,
|
||||
"arguments": arguments or {},
|
||||
}
|
||||
text = _compact_json(payload)
|
||||
return text[:MAX_TOOL_PAYLOAD_CHARS]
|
||||
|
||||
|
||||
def serialize_tool_result(
|
||||
tool_name: str,
|
||||
output: Any = None,
|
||||
*,
|
||||
success: bool = True,
|
||||
error: str | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> str:
|
||||
return ToolResult(
|
||||
tool_name=tool_name,
|
||||
success=success,
|
||||
output=output,
|
||||
error=error,
|
||||
metadata=metadata or {},
|
||||
).to_payload()[:MAX_TOOL_PAYLOAD_CHARS]
|
||||
|
||||
|
||||
def parse_tool_call_payload(text: str) -> ToolInvocation:
|
||||
stripped = text.strip()
|
||||
if not stripped:
|
||||
return ToolInvocation(tool_name="calculator", arguments={"expression": ""}, raw_text=text)
|
||||
try:
|
||||
payload = json.loads(stripped)
|
||||
except json.JSONDecodeError:
|
||||
return ToolInvocation(tool_name="calculator", arguments={"expression": stripped}, raw_text=text)
|
||||
if isinstance(payload, dict):
|
||||
tool_name = payload.get("tool") or payload.get("tool_name") or payload.get("name")
|
||||
arguments = payload.get("arguments") or payload.get("args") or {}
|
||||
if isinstance(tool_name, str) and isinstance(arguments, dict):
|
||||
return ToolInvocation(tool_name=tool_name, arguments=arguments, raw_text=text)
|
||||
return ToolInvocation(tool_name="calculator", arguments={"expression": stripped}, raw_text=text)
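# For example (illustrative, following the fallback rules above):
#   parse_tool_call_payload('{"tool":"web_search","arguments":{"query":"x"}}')
#     -> ToolInvocation(tool_name="web_search", arguments={"query": "x"})
#   parse_tool_call_payload("2+2")   # legacy bare calculator payload
#     -> ToolInvocation(tool_name="calculator", arguments={"expression": "2+2"})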
def parse_tool_result_payload(text: str) -> ToolResult | None:
    stripped = text.strip()
    try:
        payload = json.loads(stripped)
    except json.JSONDecodeError:
        return None
    if not isinstance(payload, dict):
        return None
    tool_name = payload.get("tool")
    if not isinstance(tool_name, str):
        return None
    return ToolResult(
        tool_name=tool_name,
        success=bool(payload.get("success", True)),
        output=payload.get("output"),
        error=payload.get("error"),
        metadata=payload.get("metadata") or {},
    )


def _percent(value: float, rate: float) -> float:
    return value * rate / 100.0


def _percent_change(old: float, new: float) -> float:
    if old == 0:
        raise ValueError("percent_change old value cannot be zero")
    return ((new - old) / old) * 100.0


def _cagr(start: float, end: float, years: float) -> float:
    if start <= 0 or end <= 0 or years <= 0:
        raise ValueError("cagr inputs must be positive")
    return ((end / start) ** (1.0 / years) - 1.0) * 100.0


def _simple_interest(principal: float, annual_rate: float, years: float) -> float:
    return principal * annual_rate / 100.0 * years


def _compound_interest(principal: float, annual_rate: float, periods_per_year: float, years: float) -> float:
    if periods_per_year <= 0:
        raise ValueError("periods_per_year must be positive")
    return principal * (1.0 + annual_rate / 100.0 / periods_per_year) ** (periods_per_year * years)


def _emi(principal: float, annual_rate: float, months: float) -> float:
    if months <= 0:
        raise ValueError("months must be positive")
    monthly_rate = annual_rate / 100.0 / 12.0
    if monthly_rate == 0:
        return principal / months
    growth = (1.0 + monthly_rate) ** months
    return principal * monthly_rate * growth / (growth - 1.0)


ALLOWED_BINOPS = {
    ast.Add: lambda a, b: a + b,
    ast.Sub: lambda a, b: a - b,
    ast.Mult: lambda a, b: a * b,
    ast.Div: lambda a, b: a / b,
    ast.Pow: lambda a, b: a ** b,
    ast.Mod: lambda a, b: a % b,
}
ALLOWED_UNARYOPS = {
    ast.UAdd: lambda a: a,
    ast.USub: lambda a: -a,
}
ALLOWED_NAMES = {
    "pi": math.pi,
    "e": math.e,
    "tau": math.tau,
}
ALLOWED_FUNCTIONS = {
    "abs": abs,
    "round": round,
    "floor": math.floor,
    "ceil": math.ceil,
    "sqrt": math.sqrt,
    "log": math.log,
    "log10": math.log10,
    "exp": math.exp,
    "sin": math.sin,
    "cos": math.cos,
    "tan": math.tan,
    "asin": math.asin,
    "acos": math.acos,
    "atan": math.atan,
    "degrees": math.degrees,
    "radians": math.radians,
    "percent": _percent,
    "percent_change": _percent_change,
    "cagr": _cagr,
    "simple_interest": _simple_interest,
    "compound_interest": _compound_interest,
    "emi": _emi,
}


class _SafeMathEvaluator:
    def __init__(self, expression: str):
        self.expression = expression
        self.node_count = 0

    def eval(self) -> float:
        if len(self.expression) > 512:
            raise ValueError("expression too long")
        parsed = ast.parse(self.expression, mode="eval")
        return self._visit(parsed.body)

    def _visit(self, node: ast.AST) -> Any:
        self.node_count += 1
        if self.node_count > 128:
            raise ValueError("expression too complex")

        if isinstance(node, ast.Constant):
            if isinstance(node.value, (int, float)):
                return node.value
            raise ValueError(f"unsupported constant: {node.value!r}")
        if isinstance(node, ast.Num):  # pragma: no cover - py<3.8 compatibility
            return node.n
        if isinstance(node, ast.BinOp):
            op = ALLOWED_BINOPS.get(type(node.op))
            if op is None:
                raise ValueError(f"unsupported operator: {type(node.op).__name__}")
            return op(self._visit(node.left), self._visit(node.right))
        if isinstance(node, ast.UnaryOp):
            op = ALLOWED_UNARYOPS.get(type(node.op))
            if op is None:
                raise ValueError(f"unsupported unary operator: {type(node.op).__name__}")
            return op(self._visit(node.operand))
        if isinstance(node, ast.Name):
            if node.id not in ALLOWED_NAMES:
                raise ValueError(f"unknown symbol: {node.id}")
            return ALLOWED_NAMES[node.id]
        if isinstance(node, ast.Call):
            if not isinstance(node.func, ast.Name):
                raise ValueError("only direct function calls are allowed")
            fn = ALLOWED_FUNCTIONS.get(node.func.id)
            if fn is None:
                raise ValueError(f"unsupported function: {node.func.id}")
            if node.keywords:
                raise ValueError("keyword arguments are not supported")
            args = [self._visit(arg) for arg in node.args]
            return fn(*args)
        raise ValueError(f"unsupported expression node: {type(node).__name__}")


def _normalize_numeric_output(value: Any) -> Any:
    if isinstance(value, float):
        if not math.isfinite(value):
            raise ValueError("result is not finite")
        return float(f"{value:.12g}")
    return value


class CalculatorTool(BaseTool):
    name = "calculator"

    def run(self, arguments: dict[str, Any]) -> ToolResult:
        expression = str(arguments.get("expression", "")).strip()
        if not expression:
            return ToolResult(tool_name=self.name, success=False, error="Missing expression")
        value = _SafeMathEvaluator(expression).eval()
        return ToolResult(
            tool_name=self.name,
            success=True,
            output={"expression": expression, "value": _normalize_numeric_output(value)},
        )
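# For example (illustrative; values follow the finance helpers above):
#   CalculatorTool().run({"expression": "percent_change(120,150)"}).output
#     -> {"expression": "percent_change(120,150)", "value": 25.0}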
@dataclass
class SearchHit:
    url: str
    title: str = ""
    snippet: str = ""


class SearchBackend(Protocol):
    def search(self, query: str, top_k: int) -> list[SearchHit]:
        raise NotImplementedError


class MockSearchBackend:
    def __init__(self, canned_results: dict[str, list[dict[str, str]]] | None = None):
        self.canned_results = canned_results or {
            "browser rendering markdown": [
                {
                    "url": "https://developers.cloudflare.com/browser-rendering/rest-api/markdown-endpoint/",
                    "title": "Cloudflare markdown endpoint",
                    "snippet": "Extract markdown from a webpage using Cloudflare Browser Rendering.",
                }
            ],
            "nanochat gpt2 speedrun": [
                {
                    "url": "https://github.com/karpathy/nanochat",
                    "title": "karpathy/nanochat",
                    "snippet": "Minimal LLM training harness with pretraining, SFT, RL, and chat UI.",
                }
            ],
        }

    def search(self, query: str, top_k: int) -> list[SearchHit]:
        normalized = query.strip().lower()
        rows = self.canned_results.get(normalized, [])
        return [SearchHit(**row) for row in rows[:top_k]]


class CloudflareBrowserRenderingClient:
    def __init__(
        self,
        *,
        api_token: str | None = None,
        account_id: str | None = None,
        base_url: str = "https://api.cloudflare.com/client/v4",
        timeout: float = 30.0,
        max_retries: int = 3,
    ):
        self.api_token = api_token or os.environ.get("CLOUDFLARE_API_TOKEN")
        self.account_id = account_id or os.environ.get("CLOUDFLARE_ACCOUNT_ID")
        if not self.api_token or not self.account_id:
            raise ValueError("Cloudflare Browser Rendering requires CLOUDFLARE_API_TOKEN and CLOUDFLARE_ACCOUNT_ID")
        self.base_url = base_url.rstrip("/")
        self.timeout = timeout
        self.max_retries = max_retries
        self.session = requests.Session()
        self.session.headers.update(
            {
                "Authorization": f"Bearer {self.api_token}",
                "Content-Type": "application/json",
            }
        )

    def _post(self, endpoint: str, body: dict[str, Any]) -> Any:
        url = f"{self.base_url}/accounts/{self.account_id}/browser-rendering/{endpoint}"
        last_error = None
        for attempt in range(1, self.max_retries + 1):
            response = self.session.post(url, json=body, timeout=self.timeout)
            if response.status_code == 429:
                retry_after = response.headers.get("Retry-After")
                sleep_seconds = float(retry_after) if retry_after else float(attempt)
                last_error = RuntimeError(f"Cloudflare Browser Rendering rate limited on {endpoint}")
                time.sleep(min(sleep_seconds, 5.0))
                continue
            response.raise_for_status()
            payload = response.json()
            if not payload.get("success", False):
                errors = payload.get("errors", [])
                last_error = RuntimeError(f"Cloudflare Browser Rendering request failed: {errors}")
                break
            return payload.get("result")
        if last_error is not None:
            raise last_error
        raise RuntimeError(f"Cloudflare Browser Rendering request failed for {endpoint}")

    def markdown(self, url: str, **options: Any) -> str:
        body = {"url": url}
        body.update(options)
        return self._post("markdown", body)

    def links(self, url: str, **options: Any) -> list[str]:
        body = {"url": url}
        body.update(options)
        return self._post("links", body)

    def json_extract(self, url: str, *, prompt: str | None = None, schema: dict[str, Any] | None = None, **options: Any) -> dict[str, Any]:
        body: dict[str, Any] = {"url": url}
        if prompt is not None:
            body["prompt"] = prompt
        if schema is not None:
            body["schema"] = schema
        body.update(options)
        return self._post("json", body)


class WebSearchTool(BaseTool):
    name = "web_search"

    def __init__(
        self,
        *,
        search_backend: SearchBackend | None = None,
        fetch_client: CloudflareBrowserRenderingClient | None = None,
        max_results: int = 3,
    ):
        self.search_backend = search_backend
        self.fetch_client = fetch_client
        self.max_results = max_results

    def run(self, arguments: dict[str, Any]) -> ToolResult:
        query = str(arguments.get("query", "")).strip()
        requested_urls = arguments.get("urls") or []
        if isinstance(requested_urls, str):
            requested_urls = [requested_urls]
        top_k = int(arguments.get("top_k", self.max_results) or self.max_results)
        top_k = max(1, min(top_k, 8))

        if not query and not requested_urls:
            return ToolResult(tool_name=self.name, success=False, error="Missing query or urls")

        hits: list[SearchHit]
        if requested_urls:
            hits = [SearchHit(url=str(url)) for url in requested_urls[:top_k]]
        else:
            if self.search_backend is None:
                return ToolResult(
                    tool_name=self.name,
                    success=False,
                    error="No search backend configured. Cloudflare Browser Rendering can fetch pages but does not provide public web search by itself.",
                )
            hits = self.search_backend.search(query, top_k)

        results = []
        for hit in hits[:top_k]:
            entry: dict[str, Any] = {"url": hit.url}
            if hit.title:
                entry["title"] = hit.title
            if hit.snippet:
                entry["snippet"] = hit.snippet
            if self.fetch_client is not None:
                try:
                    markdown = self.fetch_client.markdown(hit.url)
                    links = self.fetch_client.links(hit.url)
                    entry["markdown"] = markdown[:4000]
                    entry["links"] = links[:10]
                except Exception as exc:
                    entry["fetch_error"] = str(exc)
            results.append(entry)

        return ToolResult(
            tool_name=self.name,
            success=True,
            output={"query": query, "results": results},
            metadata={
                "search_backend": type(self.search_backend).__name__ if self.search_backend is not None else None,
                "fetch_backend": type(self.fetch_client).__name__ if self.fetch_client is not None else None,
                "num_results": len(results),
            },
        )


def build_default_tool_registry(
    *,
    cloudflare_token: str | None = None,
    cloudflare_account_id: str | None = None,
    search_backend: SearchBackend | None = None,
) -> ToolRegistry:
    fetch_client = None
    if cloudflare_token or os.environ.get("CLOUDFLARE_API_TOKEN"):
        try:
            fetch_client = CloudflareBrowserRenderingClient(
                api_token=cloudflare_token,
                account_id=cloudflare_account_id,
            )
        except Exception:
            fetch_client = None
    registry = ToolRegistry(
        [
            CalculatorTool(),
            WebSearchTool(
                search_backend=search_backend or MockSearchBackend(),
                fetch_client=fetch_client,
            ),
        ]
    )
    return registry
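A minimal offline usage sketch of the registry (assuming `CLOUDFLARE_API_TOKEN` is unset, so there is no fetch client and `web_search` falls back to the `MockSearchBackend` defined above):

```python
from nanochat.tools import build_default_tool_registry

registry = build_default_tool_registry()  # mock search backend, no Cloudflare credentials needed
result = registry.execute("calculator", {"expression": "sqrt(2)"})
print(result.success, result.output)  # True {'expression': 'sqrt(2)', 'value': 1.41421356237}
```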
pyproject.toml
@@ -7,6 +7,7 @@ requires-python = ">=3.10"
dependencies = [
    "datasets>=4.0.0",
    "fastapi>=0.117.1",
    "huggingface_hub>=0.34.4",
    "ipykernel>=7.1.0",
    "kernels>=0.11.7",
    "matplotlib>=3.10.8",
scripts/base_train.py
@@ -33,6 +33,7 @@ from nanochat.checkpoint_manager import save_checkpoint, load_checkpoint
from nanochat.loss_eval import evaluate_bpb
from nanochat.engine import Engine
from nanochat.flash_attention import HAS_FA3
from nanochat.tools import DEFAULT_TOOL_SCHEMA
from scripts.base_eval import evaluate_core
print_banner()

@@ -481,9 +482,12 @@ while True:
            optimizer.state_dict(),  # optimizer state
            {  # metadata saved as json
                "step": step,
                "stage": "base",
                "val_bpb": val_bpb,  # loss at last step
                "model_config": model_config_kwargs,
                "user_config": user_config,  # inputs to the training script
                "tool_schema": DEFAULT_TOOL_SCHEMA,
                "source_hf_repo": os.environ.get("NANOCHAT_SOURCE_HF_REPO"),
                "device_batch_size": args.device_batch_size,
                "max_seq_len": args.max_seq_len,
                "total_batch_size": total_batch_size,
scripts/build_tool_datasets.py (new file, 167 lines)
@@ -0,0 +1,167 @@
"""
|
||||
Build small local tool-use datasets for SFT/eval dry runs.
|
||||
|
||||
This is not intended to replace larger curated corpora. It creates schema-valid
|
||||
seed data so tool routing, tokenization, and evaluation can be tested locally.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
|
||||
|
||||
def write_jsonl(path, rows):
|
||||
os.makedirs(os.path.dirname(path), exist_ok=True)
|
||||
with open(path, "w", encoding="utf-8") as f:
|
||||
for row in rows:
|
||||
f.write(json.dumps(row, ensure_ascii=True) + "\n")
|
||||
|
||||
|
||||
def calculator_conversation(user_text, expression, value, final_text):
|
||||
return [
|
||||
{"role": "user", "content": user_text},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "text", "text": "Let me compute that exactly. "},
|
||||
{"type": "tool_call", "tool_name": "calculator", "arguments": {"expression": expression}},
|
||||
{"type": "tool_result", "tool_name": "calculator", "output": {"expression": expression, "value": value}},
|
||||
{"type": "text", "text": final_text},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def web_search_conversation(user_text, query, url, final_text):
|
||||
return [
|
||||
{"role": "user", "content": user_text},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "text", "text": "I should verify this with a fresh web lookup. "},
|
||||
{"type": "tool_call", "tool_name": "web_search", "arguments": {"query": query, "top_k": 1}},
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_name": "web_search",
|
||||
"output": {
|
||||
"query": query,
|
||||
"results": [
|
||||
{
|
||||
"url": url,
|
||||
"title": "Source result",
|
||||
"snippet": "Fresh page content retrieved for grounding.",
|
||||
}
|
||||
],
|
||||
},
|
||||
},
|
||||
{"type": "text", "text": final_text},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def direct_conversation(user_text, answer_text):
|
||||
return [
|
||||
{"role": "user", "content": user_text},
|
||||
{"role": "assistant", "content": answer_text},
|
||||
]
|
||||
|
||||
|
||||
def build_train_rows():
|
||||
return [
|
||||
direct_conversation(
|
||||
"Explain what gradient descent is in one short paragraph.",
|
||||
"Gradient descent is an optimization method that repeatedly nudges model parameters in the direction that reduces error, using the gradient to decide both direction and magnitude of each update.",
|
||||
),
|
||||
calculator_conversation(
|
||||
"What is 18% of 250?",
|
||||
"percent(250,18)",
|
||||
45.0,
|
||||
"18% of 250 is 45.",
|
||||
),
|
||||
calculator_conversation(
|
||||
"If revenue grew from 120 to 150, what is the percent change?",
|
||||
"percent_change(120,150)",
|
||||
25.0,
|
||||
"The percent change from 120 to 150 is 25%.",
|
||||
),
|
||||
calculator_conversation(
|
||||
"What is the monthly EMI for a 500000 loan at 8% annual interest over 240 months?",
|
||||
"emi(500000,8,240)",
|
||||
4182.198594391402,
|
||||
"The monthly EMI is about 4182.2.",
|
||||
),
|
||||
web_search_conversation(
|
||||
"Find the official nanochat repository and give me the link.",
|
||||
"official nanochat repository",
|
||||
"https://github.com/karpathy/nanochat",
|
||||
"The official nanochat repository is https://github.com/karpathy/nanochat",
|
||||
),
|
||||
web_search_conversation(
|
||||
"What do the Cloudflare Browser Rendering docs say about the markdown endpoint?",
|
||||
"Cloudflare Browser Rendering markdown endpoint docs",
|
||||
"https://developers.cloudflare.com/browser-rendering/rest-api/markdown-endpoint/",
|
||||
"Cloudflare documents a markdown endpoint for extracting rendered page content: https://developers.cloudflare.com/browser-rendering/rest-api/markdown-endpoint/",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def build_eval_rows():
|
||||
return [
|
||||
{
|
||||
"conversation": {
|
||||
"messages": [
|
||||
{"role": "user", "content": "What is 12% of 250?"},
|
||||
{"role": "assistant", "content": "45"},
|
||||
]
|
||||
},
|
||||
"checks": {
|
||||
"must_call": "calculator",
|
||||
"must_not_call": ["web_search"],
|
||||
"answer_contains": ["45"],
|
||||
},
|
||||
},
|
||||
{
|
||||
"conversation": {
|
||||
"messages": [
|
||||
{"role": "user", "content": "Find the official nanochat repository and cite the link."},
|
||||
{"role": "assistant", "content": "https://github.com/karpathy/nanochat"},
|
||||
]
|
||||
},
|
||||
"checks": {
|
||||
"must_call": "web_search",
|
||||
"citation_required": True,
|
||||
"answer_contains": ["github.com/karpathy/nanochat"],
|
||||
},
|
||||
},
|
||||
{
|
||||
"conversation": {
|
||||
"messages": [
|
||||
{"role": "user", "content": "Explain what overfitting means in one sentence."},
|
||||
{"role": "assistant", "content": "Overfitting means a model memorizes training patterns too closely and generalizes poorly to new data."},
|
||||
]
|
||||
},
|
||||
"checks": {
|
||||
"must_not_call": ["web_search"],
|
||||
"answer_contains": ["generalizes poorly"],
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Build seed tool-use datasets")
|
||||
parser.add_argument("--train-out", default="seed_data/tool_sft_seed.jsonl", help="Output JSONL for SFT conversations")
|
||||
parser.add_argument("--eval-out", default="seed_data/tool_eval_seed.jsonl", help="Output JSONL for eval/reward objects")
|
||||
args = parser.parse_args()
|
||||
|
||||
train_rows = build_train_rows()
|
||||
eval_rows = build_eval_rows()
|
||||
write_jsonl(args.train_out, train_rows)
|
||||
write_jsonl(args.eval_out, eval_rows)
|
||||
print(f"Wrote {len(train_rows)} SFT rows to {args.train_out}")
|
||||
print(f"Wrote {len(eval_rows)} eval rows to {args.eval_out}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
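With the defaults above, one run of the builder should produce six SFT rows and three eval rows:

```bash
python -m scripts.build_tool_datasets
# Wrote 6 SFT rows to seed_data/tool_sft_seed.jsonl
# Wrote 3 eval rows to seed_data/tool_eval_seed.jsonl
```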
scripts/chat_eval.py
@@ -22,6 +22,7 @@ from tasks.mmlu import MMLU
from tasks.arc import ARC
from tasks.gsm8k import GSM8K
from tasks.spellingbee import SpellingBee
from tasks.tool_json import ToolJSON

# -----------------------------------------------------------------------------
# Generative evaluation loop (we go one problem at a time, sample, evaluate)

@@ -156,8 +157,10 @@ def run_categorical_eval(task_object, tokenizer, model, batch_size, max_problems
def run_chat_eval(task_name, model, tokenizer, engine,
                  batch_size=1, num_samples=1, max_new_tokens=512, temperature=0.0, top_k=50,
                  max_problems=None):
                  max_problems=None, tool_jsonl=None):
    # Create the evaluation object
    if task_name == "ToolJSON" and not tool_jsonl:
        raise ValueError("ToolJSON evaluation requires --tool-jsonl")
    task_module = {
        'HumanEval': HumanEval,
        'MMLU': partial(MMLU, subset="all", split="test"),

@@ -165,6 +168,7 @@ def run_chat_eval(task_name, model, tokenizer, engine,
        'ARC-Challenge': partial(ARC, subset="ARC-Challenge", split="test"),
        'GSM8K': partial(GSM8K, subset="main", split="test"),
        'SpellingBee': partial(SpellingBee, size=256, split="test"),
        'ToolJSON': partial(ToolJSON, filepath=tool_jsonl),
    }[task_name]
    task_object = task_module()
    # Run the evaluation

@@ -191,6 +195,7 @@ if __name__ == "__main__":
    parser.add_argument('-g', '--model-tag', type=str, default=None, help='Model tag to load')
    parser.add_argument('-s', '--step', type=int, default=None, help='Step to load')
    parser.add_argument('-x', '--max-problems', type=int, default=None, help='Max problems to evaluate')
    parser.add_argument('--tool-jsonl', type=str, default=None, help='Tool JSONL dataset path used when task-name includes ToolJSON')
    parser.add_argument('--device-type', type=str, default='', choices=['cuda', 'cpu', 'mps'], help='Device type for evaluation: cuda|cpu|mps. empty => autodetect')
    args = parser.parse_args()

@@ -224,6 +229,7 @@ if __name__ == "__main__":
            temperature=args.temperature,
            top_k=args.top_k,
            max_problems=args.max_problems,
            tool_jsonl=args.tool_jsonl,
        )
        results[task_name] = acc
        print0(f"{task_name} accuracy: {100 * acc:.2f}%")
scripts/chat_rl.py
@@ -25,6 +25,7 @@ import torch.distributed as dist
from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, DummyWandb, autodetect_device_type
from nanochat.checkpoint_manager import save_checkpoint, load_model
from nanochat.engine import Engine
from nanochat.tools import DEFAULT_TOOL_SCHEMA
from tasks.gsm8k import GSM8K

# -----------------------------------------------------------------------------

@@ -317,7 +318,12 @@ for step in range(num_steps):
            model.state_dict(),
            None,  # note: we don't bother to save the optimizer state
            {
                "step": step,
                "stage": "rl",
                "model_config": model_config_kwargs,
                "user_config": user_config,
                "tool_schema": DEFAULT_TOOL_SCHEMA,
                "source_hf_repo": meta.get("source_hf_repo"),
            }
        )
        print(f"✅ Saved model checkpoint to {checkpoint_dir}")
scripts/chat_sft.py
@@ -31,6 +31,7 @@ from tasks.mmlu import MMLU
from tasks.smoltalk import SmolTalk
from tasks.customjson import CustomJSON
from tasks.spellingbee import SimpleSpelling, SpellingBee
from nanochat.tools import DEFAULT_TOOL_SCHEMA

# -----------------------------------------------------------------------------
# CLI arguments

@@ -66,6 +67,8 @@ parser.add_argument("--chatcore-max-sample", type=int, default=24, help="max pro
# Data mixture
parser.add_argument("--mmlu-epochs", type=int, default=3, help="number of epochs of MMLU in training mixture (teaches Multiple Choice)")
parser.add_argument("--gsm8k-epochs", type=int, default=4, help="number of epochs of GSM8K in training mixture (teaches Math and Tool Use)")
parser.add_argument("--extra-train-jsonl", type=str, default="", help="comma-separated JSONL conversation files to append to the SFT train mixture")
parser.add_argument("--extra-val-jsonl", type=str, default="", help="comma-separated JSONL conversation files to append to the SFT val mixture")
args = parser.parse_args()
user_config = vars(args).copy()
# -----------------------------------------------------------------------------

@@ -162,6 +165,8 @@ for group in optimizer.param_groups:

# SFT data mixture and DataLoader
identity_conversations_filepath = os.path.join(base_dir, "identity_conversations.jsonl")
extra_train_jsonl = [path for path in args.extra_train_jsonl.split(",") if path]
extra_val_jsonl = [path for path in args.extra_val_jsonl.split(",") if path]
train_tasks = [
    SmolTalk(split="train"),  # 460K rows of general conversations
    CustomJSON(filepath=identity_conversations_filepath),  # 1000 rows of synthetic identity conversations

@@ -171,13 +176,17 @@ train_tasks = [
    SimpleSpelling(size=200000, split="train"),  # 200K rows of Simple Spelling (e.g. spell the word 'apple')
    SpellingBee(size=80000, split="train"),  # 80K rows of Spelling Bee (e.g. how many 'r' are in 'strawberry'?)
]
train_tasks.extend(CustomJSON(filepath=path) for path in extra_train_jsonl)
train_dataset = TaskMixture(train_tasks)
print0(f"Training mixture: {len(train_dataset):,} rows (MMLU x{args.mmlu_epochs}, GSM8K x{args.gsm8k_epochs})")
print0(
    f"Training mixture: {len(train_dataset):,} rows "
    f"(MMLU x{args.mmlu_epochs}, GSM8K x{args.gsm8k_epochs}, extra_jsonl={len(extra_train_jsonl)})"
)
val_dataset = TaskMixture([
    SmolTalk(split="test"),  # 24K rows in test set
    MMLU(subset="all", split="test", stop=5200),  # 14K rows in test set, use only 5.2K to match the train ratios
    GSM8K(subset="main", split="test", stop=420),  # 1.32K rows in test set, use only 420 to match the train ratios
])  # total: 24K + 14K + 1.32K ~= 39K rows
] + [CustomJSON(filepath=path) for path in extra_val_jsonl])  # total: 24K + 14K + 1.32K ~= 39K rows plus custom validation
# DataLoader is defined here, it emits inputs, targets : 2D tensors of shape (device_batch_size, max_seq_len)
# A big problem is that we don't know the final num_iterations in advance. So we create
# these two global variables and update them from within the data generator.

@@ -406,6 +415,7 @@ while True:
        optimizer.state_dict(),
        {
            "step": step,
            "stage": "sft",
            "val_bpb": val_bpb,  # loss at last step
            "model_config": {
                "sequence_len": args.max_seq_len,

@@ -417,6 +427,10 @@ while True:
                "window_pattern": model.config.window_pattern,
            },
            "user_config": user_config,  # inputs to the training script
            "tool_schema": DEFAULT_TOOL_SCHEMA,
            "source_hf_repo": meta.get("source_hf_repo"),
            "extra_train_jsonl": extra_train_jsonl,
            "extra_val_jsonl": extra_val_jsonl,
        },
        rank=ddp_rank,
    )
scripts/chat_tool_rl.py (new file, 219 lines)
@@ -0,0 +1,219 @@
"""
|
||||
Lightweight RL stage for tool-use tuning on local ToolJSON datasets.
|
||||
|
||||
1 GPU:
|
||||
python -m scripts.chat_tool_rl --train-jsonl seed_data/tool_eval_seed.jsonl --eval-jsonl seed_data/tool_eval_seed.jsonl
|
||||
|
||||
8 GPUs:
|
||||
torchrun --standalone --nproc_per_node=8 -m scripts.chat_tool_rl -- --train-jsonl seed_data/tool_eval_seed.jsonl --eval-jsonl seed_data/tool_eval_seed.jsonl
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import itertools
|
||||
import os
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import wandb
|
||||
|
||||
from nanochat.checkpoint_manager import load_model, save_checkpoint
|
||||
from nanochat.common import DummyWandb, autodetect_device_type, compute_cleanup, compute_init, get_base_dir, print0
|
||||
from nanochat.engine import Engine
|
||||
from nanochat.tools import DEFAULT_TOOL_SCHEMA
|
||||
from tasks.tool_json import ToolJSON
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser(description="RL tuning on local tool-use JSONL tasks")
|
||||
parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)")
|
||||
parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
|
||||
parser.add_argument("--model-tag", type=str, default=None, help="SFT model tag to load from")
|
||||
parser.add_argument("--model-step", type=int, default=None, help="SFT model step to load from")
|
||||
parser.add_argument("--train-jsonl", type=str, required=True, help="Training ToolJSONL file")
|
||||
parser.add_argument("--eval-jsonl", type=str, default=None, help="Evaluation ToolJSONL file")
|
||||
parser.add_argument("--num-epochs", type=int, default=1, help="Number of epochs over ToolJSON")
|
||||
parser.add_argument("--device-batch-size", type=int, default=8, help="Max batch size per forward pass")
|
||||
parser.add_argument("--examples-per-step", type=int, default=8, help="Examples per optimization step across all ranks")
|
||||
parser.add_argument("--num-samples", type=int, default=8, help="Number of samples per example")
|
||||
parser.add_argument("--max-new-tokens", type=int, default=256, help="Max tokens to generate")
|
||||
parser.add_argument("--temperature", type=float, default=0.7, help="Sampling temperature")
|
||||
parser.add_argument("--top-k", type=int, default=50, help="Top-k sampling (0 disables)")
|
||||
parser.add_argument("--embedding-lr", type=float, default=0.2, help="Embedding LR")
|
||||
parser.add_argument("--unembedding-lr", type=float, default=0.004, help="Unembedding LR")
|
||||
parser.add_argument("--matrix-lr", type=float, default=0.02, help="Matrix LR")
|
||||
parser.add_argument("--weight-decay", type=float, default=0.0, help="Weight decay")
|
||||
parser.add_argument("--init-lr-frac", type=float, default=0.05, help="Initial LR multiplier")
|
||||
parser.add_argument("--eval-every", type=int, default=40, help="Evaluate every N steps")
|
||||
parser.add_argument("--eval-examples", type=int, default=64, help="Maximum eval examples")
|
||||
parser.add_argument("--save-every", type=int, default=40, help="Save every N steps")
|
||||
args = parser.parse_args()
|
||||
user_config = vars(args).copy()
|
||||
|
||||
|
||||
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
|
||||
master_process = ddp_rank == 0
|
||||
|
||||
use_dummy_wandb = args.run == "dummy" or not master_process
|
||||
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-tool-rl", name=args.run, config=user_config)
|
||||
|
||||
model, tokenizer, meta = load_model("sft", device, phase="eval", model_tag=args.model_tag, step=args.model_step)
|
||||
engine = Engine(model, tokenizer)
|
||||
|
||||
train_task = ToolJSON(filepath=args.train_jsonl)
|
||||
eval_task = ToolJSON(filepath=args.eval_jsonl or args.train_jsonl)
|
||||
num_steps = (len(train_task) // args.examples_per_step) * args.num_epochs
|
||||
print0(f"Calculated number of steps: {num_steps}")
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def get_batch():
|
||||
assistant_end = tokenizer.encode_special("<|assistant_end|>")
|
||||
rank_indices = range(ddp_rank, len(train_task), ddp_world_size)
|
||||
for example_idx in itertools.cycle(rank_indices):
|
||||
conversation = train_task[example_idx]
|
||||
tokens = tokenizer.render_for_completion(conversation)
|
||||
prefix_length = len(tokens)
|
||||
model.eval()
|
||||
generated_token_sequences = []
|
||||
masks = []
|
||||
num_sampling_steps = args.num_samples // args.device_batch_size
|
||||
if args.num_samples % args.device_batch_size != 0:
|
||||
raise ValueError("num_samples must be divisible by device_batch_size")
|
||||
for sampling_step in range(num_sampling_steps):
|
||||
seed = hash((step, example_idx, sampling_step)) & 0x7FFFFFFF
|
||||
generated_batch, mask_batch = engine.generate_batch(
|
||||
tokens,
|
||||
num_samples=args.device_batch_size,
|
||||
max_tokens=args.max_new_tokens,
|
||||
temperature=args.temperature,
|
||||
top_k=args.top_k,
|
||||
seed=seed,
|
||||
)
|
||||
generated_token_sequences.extend(generated_batch)
|
||||
masks.extend(mask_batch)
|
||||
|
||||
rewards = []
|
||||
for sample_tokens in generated_token_sequences:
|
||||
generated_text = tokenizer.decode(sample_tokens[prefix_length:])
|
||||
rewards.append(train_task.reward(conversation, generated_text))
|
||||
|
||||
max_length = max(len(seq) for seq in generated_token_sequences)
|
||||
padded_sequences = [seq + [assistant_end] * (max_length - len(seq)) for seq in generated_token_sequences]
|
||||
padded_masks = [mask + [0] * (max_length - len(mask)) for mask in masks]
|
||||
ids = torch.tensor(padded_sequences, dtype=torch.long, device=device)
|
||||
mask_ids = torch.tensor(padded_masks, dtype=torch.long, device=device)
|
||||
inputs = ids[:, :-1]
|
||||
targets = ids[:, 1:].clone()
|
||||
targets[mask_ids[:, 1:] == 0] = -1
|
||||
rewards_tensor = torch.tensor(rewards, dtype=torch.float, device=device)
|
||||
advantages = rewards_tensor - rewards_tensor.mean()
|
||||
yield generated_token_sequences, inputs, targets, rewards_tensor, advantages
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def run_tool_eval(task, max_examples):
|
||||
max_examples = min(max_examples, len(task))
|
||||
total = 0
|
||||
reward_sum = 0.0
|
||||
passed = 0
|
||||
for idx in range(ddp_rank, max_examples, ddp_world_size):
|
||||
conversation = task[idx]
|
||||
tokens = tokenizer.render_for_completion(conversation)
|
||||
generated_sequences, _ = engine.generate_batch(tokens, num_samples=1, max_tokens=args.max_new_tokens, temperature=0.0, top_k=args.top_k)
|
||||
generated_text = tokenizer.decode(generated_sequences[0][len(tokens):])
|
||||
reward_sum += task.reward(conversation, generated_text)
|
||||
passed += task.evaluate(conversation, generated_text)
|
||||
total += 1
|
||||
|
||||
reward_tensor = torch.tensor([reward_sum], dtype=torch.float, device=device)
|
||||
passed_tensor = torch.tensor([passed], dtype=torch.long, device=device)
|
||||
total_tensor = torch.tensor([total], dtype=torch.long, device=device)
|
||||
if ddp:
|
||||
dist.all_reduce(reward_tensor, op=dist.ReduceOp.SUM)
|
||||
dist.all_reduce(passed_tensor, op=dist.ReduceOp.SUM)
|
||||
dist.all_reduce(total_tensor, op=dist.ReduceOp.SUM)
|
||||
total = max(total_tensor.item(), 1)
|
||||
return reward_tensor.item() / total, passed_tensor.item() / total
|
||||
|
||||
|
||||
optimizer = model.setup_optimizer(
|
||||
unembedding_lr=args.unembedding_lr,
|
||||
embedding_lr=args.embedding_lr,
|
||||
matrix_lr=args.matrix_lr,
|
||||
weight_decay=args.weight_decay,
|
||||
)
|
||||
for group in optimizer.param_groups:
|
||||
group["lr"] = group["lr"] * args.init_lr_frac
|
||||
group["initial_lr"] = group["lr"]
|
||||
|
||||
|
||||
def get_lr_multiplier(it):
|
||||
return 1.0 - it / max(num_steps, 1)
|
||||
|
||||
|
||||
assert args.examples_per_step % ddp_world_size == 0, "examples_per_step must be divisible by number of ranks"
|
||||
examples_per_rank = args.examples_per_step // ddp_world_size
|
||||
batch_iterator = get_batch()
|
||||
|
||||
for step in range(num_steps):
|
||||
if step % args.eval_every == 0:
|
||||
model.eval()
|
||||
mean_reward, pass_rate = run_tool_eval(eval_task, args.eval_examples)
|
||||
print0(f"Step {step} | tool_eval_reward={mean_reward:.4f} | tool_eval_pass={pass_rate:.4f}")
|
||||
wandb_run.log({"step": step, "tool_eval_reward": mean_reward, "tool_eval_pass": pass_rate})
|
||||
|
||||
rewards_list = []
|
||||
for example_step in range(examples_per_rank):
|
||||
_, inputs_all, targets_all, rewards_all, advantages_all = next(batch_iterator)
|
||||
model.train()
|
||||
assert inputs_all.size(0) % args.device_batch_size == 0
|
||||
num_passes = inputs_all.size(0) // args.device_batch_size
|
||||
for pass_idx in range(num_passes):
|
||||
b0, b1 = pass_idx * args.device_batch_size, (pass_idx + 1) * args.device_batch_size
|
||||
inputs = inputs_all[b0:b1]
|
||||
targets = targets_all[b0:b1]
|
||||
rewards = rewards_all[b0:b1]
|
||||
advantages = advantages_all[b0:b1]
|
||||
logp = -model(inputs, targets, loss_reduction="none").view_as(inputs)
|
||||
pg_obj = (logp * advantages.unsqueeze(-1)).sum()
|
||||
num_valid = (targets >= 0).sum().clamp(min=1)
|
||||
pg_obj = pg_obj / (num_valid * num_passes * examples_per_rank)
|
||||
loss = -pg_obj
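            # REINFORCE with a mean-reward baseline: advantages = rewards - mean(rewards),
            # so the update raises log-probs of above-average samples and lowers the rest.
            # The division above normalizes over valid target tokens, micro-passes, and
            # examples per rank, making the accumulated backward one averaged objective per step.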
|
||||
loss.backward()
|
||||
print0(
|
||||
f"Step {step}/{num_steps} | Example step {example_step} | Pass {pass_idx} "
|
||||
f"| loss={loss.item():.6f} | reward={rewards.mean().item():.4f}"
|
||||
)
|
||||
rewards_list.append(rewards_all.mean().item())
|
||||
|
||||
lrm = get_lr_multiplier(step)
|
||||
for group in optimizer.param_groups:
|
||||
group["lr"] = group["initial_lr"] * lrm
|
||||
optimizer.step()
|
||||
model.zero_grad(set_to_none=True)
|
||||
wandb_run.log({"step": step, "lrm": lrm, "mean_reward": sum(rewards_list) / max(len(rewards_list), 1)})
|
||||
|
||||
if master_process and ((step > 0 and step % args.save_every == 0) or step == num_steps - 1):
|
||||
base_dir = get_base_dir()
|
||||
output_dirname = args.model_tag if args.model_tag else f"d{model.config.n_layer}"
|
||||
checkpoint_dir = os.path.join(base_dir, "chatrl_checkpoints", output_dirname)
|
||||
save_checkpoint(
|
||||
checkpoint_dir,
|
||||
step,
|
||||
model.state_dict(),
|
||||
None,
|
||||
{
|
||||
"step": step,
|
||||
"stage": "tool_rl",
|
||||
"model_config": model.config.__dict__,
|
||||
"user_config": user_config,
|
||||
"tool_schema": DEFAULT_TOOL_SCHEMA,
|
||||
"source_hf_repo": meta.get("source_hf_repo"),
|
||||
"train_jsonl": args.train_jsonl,
|
||||
"eval_jsonl": args.eval_jsonl,
|
||||
},
|
||||
)
|
||||
print0(f"Saved tool RL checkpoint to {checkpoint_dir}")
|
||||
|
||||
wandb_run.finish()
|
||||
compute_cleanup()
|
||||
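Note for reviewers: the inner update above is a plain REINFORCE-style objective, where per-token log-probabilities are weighted by a per-sequence advantage and normalized by the count of valid target positions and the number of accumulation passes. A minimal self-contained sketch of the same arithmetic (toy tensors; assumes padded positions already contribute zero to `logp`, as with an ignore-index loss):

```python
import torch

# Toy stand-ins: 2 sequences, 4 target positions each; -1 marks padding.
logp = torch.randn(2, 4)                       # log-prob of each sampled token
advantages = torch.tensor([0.5, -0.5])         # per-sequence advantage
targets = torch.tensor([[1, 2, 3, -1], [4, 5, -1, -1]])

pg_obj = (logp * advantages.unsqueeze(-1)).sum()
num_valid = (targets >= 0).sum().clamp(min=1)  # 5 valid positions here
loss = -pg_obj / num_valid                     # ascend reward-weighted log-prob
print(loss)
```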
86
scripts/hf_sync_checkpoint.py
Normal file
@@ -0,0 +1,86 @@
"""
|
||||
Upload native nanochat checkpoints and related artifacts to a Hugging Face repo.
|
||||
|
||||
Examples:
|
||||
python -m scripts.hf_sync_checkpoint --repo-id ManmohanSharma/nanochat-d24 --source base --model-tag d24_hf_import
|
||||
python -m scripts.hf_sync_checkpoint --repo-id ManmohanSharma/nanochat-d24 --source base --model-tag d24_hf_import --step 0
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
|
||||
from huggingface_hub import HfApi
|
||||
|
||||
from nanochat.common import get_base_dir
|
||||
|
||||
|
||||
def resolve_checkpoint_dir(source, model_tag):
|
||||
phase_dir = {
|
||||
"base": "base_checkpoints",
|
||||
"sft": "chatsft_checkpoints",
|
||||
"rl": "chatrl_checkpoints",
|
||||
}[source]
|
||||
return os.path.join(get_base_dir(), phase_dir, model_tag)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Upload native nanochat checkpoints to Hugging Face")
|
||||
parser.add_argument("--repo-id", required=True, help="Destination HF model repo")
|
||||
parser.add_argument("--source", choices=["base", "sft", "rl"], required=True, help="Checkpoint phase")
|
||||
parser.add_argument("--model-tag", required=True, help="Local nanochat model tag")
|
||||
parser.add_argument("--step", type=int, default=None, help="Optional specific step to upload")
|
||||
parser.add_argument("--token-env", default="HF_TOKEN", help="Environment variable containing the HF token")
|
||||
parser.add_argument("--private", type=int, default=0, help="Create the repo as private if it does not exist")
|
||||
parser.add_argument("--repo-subdir", default="native_checkpoints", help="Subdirectory inside the repo")
|
||||
args = parser.parse_args()
|
||||
|
||||
token = os.environ.get(args.token_env)
|
||||
if not token:
|
||||
raise ValueError(f"Missing Hugging Face token in {args.token_env}")
|
||||
|
||||
checkpoint_dir = resolve_checkpoint_dir(args.source, args.model_tag)
|
||||
if not os.path.isdir(checkpoint_dir):
|
||||
raise FileNotFoundError(f"Checkpoint directory not found: {checkpoint_dir}")
|
||||
|
||||
api = HfApi(token=token)
|
||||
api.create_repo(repo_id=args.repo_id, repo_type="model", private=bool(args.private), exist_ok=True)
|
||||
|
||||
if args.step is None:
|
||||
path_in_repo = f"{args.repo_subdir}/{args.source}/{args.model_tag}"
|
||||
api.upload_folder(
|
||||
folder_path=checkpoint_dir,
|
||||
repo_id=args.repo_id,
|
||||
repo_type="model",
|
||||
path_in_repo=path_in_repo,
|
||||
commit_message=f"Upload native {args.source} checkpoint folder for {args.model_tag}",
|
||||
)
|
||||
print(f"Uploaded {checkpoint_dir} to {args.repo_id}:{path_in_repo}")
|
||||
return
|
||||
|
||||
step_str = f"{args.step:06d}"
|
||||
files = [
|
||||
f"model_{step_str}.pt",
|
||||
f"meta_{step_str}.json",
|
||||
]
|
||||
optimizer_pattern = f"optim_{step_str}_"
|
||||
for filename in sorted(os.listdir(checkpoint_dir)):
|
||||
if filename.startswith(optimizer_pattern) and filename.endswith(".pt"):
|
||||
files.append(filename)
|
||||
|
||||
for filename in files:
|
||||
local_path = os.path.join(checkpoint_dir, filename)
|
||||
if not os.path.exists(local_path):
|
||||
continue
|
||||
path_in_repo = f"{args.repo_subdir}/{args.source}/{args.model_tag}/{filename}"
|
||||
api.upload_file(
|
||||
path_or_fileobj=local_path,
|
||||
repo_id=args.repo_id,
|
||||
repo_type="model",
|
||||
path_in_repo=path_in_repo,
|
||||
commit_message=f"Upload {args.source} checkpoint {args.model_tag} step {step_str}",
|
||||
)
|
||||
print(f"Uploaded step {step_str} for {args.model_tag} to {args.repo_id}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
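For context, a step-scoped upload picks up the `model_*`/`meta_*` pair plus any per-rank optimizer shards. A quick sketch of that filename selection (hypothetical directory listing, not a real checkpoint):

```python
listing = ["model_000500.pt", "meta_000500.json", "optim_000500_0.pt",
           "optim_000500_1.pt", "model_000250.pt"]
step_str = f"{500:06d}"
files = [f"model_{step_str}.pt", f"meta_{step_str}.json"]
files += [f for f in sorted(listing)
          if f.startswith(f"optim_{step_str}_") and f.endswith(".pt")]
print(files)
# ['model_000500.pt', 'meta_000500.json', 'optim_000500_0.pt', 'optim_000500_1.pt']
```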
162
scripts/import_hf_checkpoint.py
Normal file
@@ -0,0 +1,162 @@
"""
|
||||
Import a Hugging Face model repo into nanochat's native checkpoint format.
|
||||
|
||||
This is intended for base-model continuation before multi-stage nanochat runs.
|
||||
|
||||
Examples:
|
||||
python -m scripts.import_hf_checkpoint --repo-id ManmohanSharma/nanochat-d24 --model-tag d24_hf_import
|
||||
python -m scripts.import_hf_checkpoint --repo-id ManmohanSharma/nanochat-d24 --local-dir /path/to/snapshot
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
from dataclasses import asdict
|
||||
|
||||
import torch
|
||||
from huggingface_hub import snapshot_download
|
||||
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
from nanochat.checkpoint_manager import save_checkpoint
|
||||
from nanochat.common import get_base_dir
|
||||
from nanochat.gpt import GPT, GPTConfig
|
||||
from nanochat.tokenizer import get_tokenizer
|
||||
from nanochat.tools import DEFAULT_TOOL_SCHEMA
|
||||
|
||||
|
||||
def normalize_hf_state_dict_keys(state_dict):
|
||||
normalized = {}
|
||||
prefixes = ("_orig_mod.", "module.", "model.")
|
||||
for key, value in state_dict.items():
|
||||
normalized_key = key
|
||||
for prefix in prefixes:
|
||||
if normalized_key.startswith(prefix):
|
||||
normalized_key = normalized_key[len(prefix):]
|
||||
normalized[normalized_key] = value
|
||||
return normalized
|
||||
|
||||
|
||||
def infer_gpt_config(hf_config):
|
||||
kwargs = {
|
||||
"sequence_len": getattr(
|
||||
hf_config,
|
||||
"sequence_len",
|
||||
getattr(hf_config, "max_position_embeddings", getattr(hf_config, "n_positions", 2048)),
|
||||
),
|
||||
"vocab_size": getattr(hf_config, "vocab_size"),
|
||||
"n_layer": getattr(
|
||||
hf_config,
|
||||
"n_layer",
|
||||
getattr(hf_config, "num_hidden_layers", getattr(hf_config, "num_layers")),
|
||||
),
|
||||
"n_head": getattr(
|
||||
hf_config,
|
||||
"n_head",
|
||||
getattr(hf_config, "num_attention_heads", None),
|
||||
),
|
||||
"n_kv_head": getattr(
|
||||
hf_config,
|
||||
"n_kv_head",
|
||||
getattr(hf_config, "num_key_value_heads", getattr(hf_config, "num_attention_heads", None)),
|
||||
),
|
||||
"n_embd": getattr(
|
||||
hf_config,
|
||||
"n_embd",
|
||||
getattr(hf_config, "hidden_size", getattr(hf_config, "d_model", None)),
|
||||
),
|
||||
"window_pattern": getattr(hf_config, "window_pattern", "L"),
|
||||
}
|
||||
missing = [key for key, value in kwargs.items() if value is None]
|
||||
if missing:
|
||||
raise ValueError(f"Could not infer nanochat GPTConfig fields from HF config: {missing}")
|
||||
return GPTConfig(**kwargs)
|
||||
|
||||
|
||||
def verify_tokenizer_compatibility(hf_tokenizer, nanochat_tokenizer):
|
||||
hf_vocab = hf_tokenizer.vocab_size
|
||||
local_vocab = nanochat_tokenizer.get_vocab_size()
|
||||
if hf_vocab != local_vocab:
|
||||
raise ValueError(
|
||||
f"Tokenizer vocab mismatch: HF repo has vocab_size={hf_vocab}, "
|
||||
f"local nanochat tokenizer has vocab_size={local_vocab}"
|
||||
)
|
||||
|
||||
|
||||
def load_hf_snapshot(repo_id, revision, token, local_dir):
|
||||
if local_dir is not None:
|
||||
return local_dir
|
||||
return snapshot_download(
|
||||
repo_id=repo_id,
|
||||
revision=revision,
|
||||
token=token,
|
||||
repo_type="model",
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Import HF repo into native nanochat checkpoints")
|
||||
parser.add_argument("--repo-id", required=True, help="HF model repo id, e.g. ManmohanSharma/nanochat-d24")
|
||||
parser.add_argument("--revision", default=None, help="Optional HF revision")
|
||||
parser.add_argument("--local-dir", default=None, help="Use an already-downloaded HF snapshot instead of downloading")
|
||||
parser.add_argument("--token-env", default="HF_TOKEN", help="Environment variable containing the HF token")
|
||||
parser.add_argument("--model-tag", default=None, help="Destination model tag. Defaults to repo name slug")
|
||||
parser.add_argument("--step", type=int, default=0, help="Checkpoint step number to write")
|
||||
parser.add_argument("--source", choices=["base", "sft", "rl"], default="base", help="Destination checkpoint phase")
|
||||
parser.add_argument("--trust-remote-code", type=int, default=1, help="Pass trust_remote_code to Transformers loaders")
|
||||
args = parser.parse_args()
|
||||
|
||||
token = os.environ.get(args.token_env)
|
||||
snapshot_path = load_hf_snapshot(args.repo_id, args.revision, token, args.local_dir)
|
||||
trust_remote_code = bool(args.trust_remote_code)
|
||||
|
||||
hf_config = AutoConfig.from_pretrained(snapshot_path, token=token, trust_remote_code=trust_remote_code)
|
||||
hf_tokenizer = AutoTokenizer.from_pretrained(snapshot_path, token=token, trust_remote_code=trust_remote_code)
|
||||
nanochat_tokenizer = get_tokenizer()
|
||||
verify_tokenizer_compatibility(hf_tokenizer, nanochat_tokenizer)
|
||||
|
||||
local_config = infer_gpt_config(hf_config)
|
||||
with torch.no_grad():
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
snapshot_path,
|
||||
token=token,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
hf_state_dict = normalize_hf_state_dict_keys(model.state_dict())
|
||||
|
||||
with torch.device("meta"):
|
||||
local_model = GPT(local_config)
|
||||
expected_keys = set(local_model.state_dict().keys())
|
||||
provided_keys = set(hf_state_dict.keys())
|
||||
missing = sorted(expected_keys - provided_keys)
|
||||
extra = sorted(provided_keys - expected_keys)
|
||||
if missing or extra:
|
||||
message = [
|
||||
"HF checkpoint keys do not match native nanochat keys after normalization.",
|
||||
f"Missing keys: {missing[:12]}",
|
||||
f"Extra keys: {extra[:12]}",
|
||||
]
|
||||
raise ValueError("\n".join(message))
|
||||
|
||||
model_data = {key: value.detach().cpu() for key, value in hf_state_dict.items()}
|
||||
meta_data = {
|
||||
"model_config": asdict(local_config),
|
||||
"imported_from_hf": True,
|
||||
"source_hf_repo": args.repo_id,
|
||||
"source_hf_revision": args.revision,
|
||||
"tool_schema": DEFAULT_TOOL_SCHEMA,
|
||||
"tokenizer_vocab_size": nanochat_tokenizer.get_vocab_size(),
|
||||
}
|
||||
|
||||
model_tag = args.model_tag or args.repo_id.split("/")[-1].replace("-", "_")
|
||||
base_dir = get_base_dir()
|
||||
phase_dir = {
|
||||
"base": "base_checkpoints",
|
||||
"sft": "chatsft_checkpoints",
|
||||
"rl": "chatrl_checkpoints",
|
||||
}[args.source]
|
||||
checkpoint_dir = os.path.join(base_dir, phase_dir, model_tag)
|
||||
save_checkpoint(checkpoint_dir, args.step, model_data, optimizer_data=None, meta_data=meta_data, rank=0)
|
||||
print(f"Imported {args.repo_id} into {checkpoint_dir} at step {args.step}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
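A quick illustration of the prefix stripping above (toy keys, not from a real checkpoint); note that stacked prefixes like `_orig_mod.module.` are peeled off in sequence:

```python
def normalize_hf_state_dict_keys(state_dict):
    normalized = {}
    prefixes = ("_orig_mod.", "module.", "model.")
    for key, value in state_dict.items():
        normalized_key = key
        for prefix in prefixes:
            if normalized_key.startswith(prefix):
                normalized_key = normalized_key[len(prefix):]
        normalized[normalized_key] = value
    return normalized

toy = {"_orig_mod.module.transformer.wte.weight": 0, "model.lm_head.weight": 1}
print(normalize_hf_state_dict_keys(toy))
# {'transformer.wte.weight': 0, 'lm_head.weight': 1}
```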
60
scripts/verify_external_access.py
Normal file
@@ -0,0 +1,60 @@
"""
|
||||
Verify external service access from environment variables.
|
||||
|
||||
Expected env vars:
|
||||
- HF_TOKEN
|
||||
- CLOUDFLARE_API_TOKEN
|
||||
- CLOUDFLARE_ACCOUNT_ID
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
|
||||
import requests
|
||||
from huggingface_hub import HfApi
|
||||
|
||||
|
||||
def verify_hf(repo_id, token_env):
|
||||
token = os.environ.get(token_env)
|
||||
if not token:
|
||||
return False, f"Missing {token_env}"
|
||||
api = HfApi(token=token)
|
||||
info = api.model_info(repo_id)
|
||||
return True, f"HF access OK: {info.id}"
|
||||
|
||||
|
||||
def verify_cloudflare(token_env):
|
||||
token = os.environ.get(token_env)
|
||||
if not token:
|
||||
return False, f"Missing {token_env}"
|
||||
response = requests.get(
|
||||
"https://api.cloudflare.com/client/v4/user/tokens/verify",
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
timeout=20,
|
||||
)
|
||||
response.raise_for_status()
|
||||
payload = response.json()
|
||||
if not payload.get("success", False):
|
||||
return False, f"Cloudflare verify failed: {payload.get('errors')}"
|
||||
status = payload.get("result", {}).get("status")
|
||||
return True, f"Cloudflare token OK: status={status}"
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Verify Hugging Face and Cloudflare access from env vars")
|
||||
parser.add_argument("--hf-repo-id", default="ManmohanSharma/nanochat-d24", help="HF repo id to verify")
|
||||
parser.add_argument("--hf-token-env", default="HF_TOKEN", help="HF token env var")
|
||||
parser.add_argument("--cloudflare-token-env", default="CLOUDFLARE_API_TOKEN", help="Cloudflare token env var")
|
||||
args = parser.parse_args()
|
||||
|
||||
hf_ok, hf_message = verify_hf(args.hf_repo_id, args.hf_token_env)
|
||||
cf_ok, cf_message = verify_cloudflare(args.cloudflare_token_env)
|
||||
|
||||
print(hf_message)
|
||||
print(cf_message)
|
||||
if not (hf_ok and cf_ok):
|
||||
raise SystemExit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
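The Cloudflare check inspects the standard API response envelope. An illustrative payload (not captured from a live call) showing the fields the script reads:

```python
# Assumed shape of the token-verify response; only these fields are inspected.
payload = {
    "success": True,
    "errors": [],
    "result": {"id": "example-token-id", "status": "active"},
}
assert payload.get("success", False)
print(payload.get("result", {}).get("status"))  # active
```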
3
seed_data/tool_eval_seed.jsonl
Normal file
@@ -0,0 +1,3 @@
{"conversation": {"messages": [{"role": "user", "content": "What is 12% of 250?"}, {"role": "assistant", "content": "45"}]}, "checks": {"must_call": "calculator", "must_not_call": ["web_search"], "answer_contains": ["45"]}}
|
||||
{"conversation": {"messages": [{"role": "user", "content": "Find the official nanochat repository and cite the link."}, {"role": "assistant", "content": "https://github.com/karpathy/nanochat"}]}, "checks": {"must_call": "web_search", "citation_required": true, "answer_contains": ["github.com/karpathy/nanochat"]}}
|
||||
{"conversation": {"messages": [{"role": "user", "content": "Explain what overfitting means in one sentence."}, {"role": "assistant", "content": "Overfitting means a model memorizes training patterns too closely and generalizes poorly to new data."}]}, "checks": {"must_not_call": ["web_search"], "answer_contains": ["generalizes poorly"]}}
|
||||
6
seed_data/tool_sft_seed.jsonl
Normal file
@@ -0,0 +1,6 @@
[{"role": "user", "content": "Explain what gradient descent is in one short paragraph."}, {"role": "assistant", "content": "Gradient descent is an optimization method that repeatedly nudges model parameters in the direction that reduces error, using the gradient to decide both direction and magnitude of each update."}]
|
||||
[{"role": "user", "content": "What is 18% of 250?"}, {"role": "assistant", "content": [{"type": "text", "text": "Let me compute that exactly. "}, {"type": "tool_call", "tool_name": "calculator", "arguments": {"expression": "percent(250,18)"}}, {"type": "tool_result", "tool_name": "calculator", "output": {"expression": "percent(250,18)", "value": 45.0}}, {"type": "text", "text": "18% of 250 is 45."}]}]
|
||||
[{"role": "user", "content": "If revenue grew from 120 to 150, what is the percent change?"}, {"role": "assistant", "content": [{"type": "text", "text": "Let me compute that exactly. "}, {"type": "tool_call", "tool_name": "calculator", "arguments": {"expression": "percent_change(120,150)"}}, {"type": "tool_result", "tool_name": "calculator", "output": {"expression": "percent_change(120,150)", "value": 25.0}}, {"type": "text", "text": "The percent change from 120 to 150 is 25%."}]}]
|
||||
[{"role": "user", "content": "What is the monthly EMI for a 500000 loan at 8% annual interest over 240 months?"}, {"role": "assistant", "content": [{"type": "text", "text": "Let me compute that exactly. "}, {"type": "tool_call", "tool_name": "calculator", "arguments": {"expression": "emi(500000,8,240)"}}, {"type": "tool_result", "tool_name": "calculator", "output": {"expression": "emi(500000,8,240)", "value": 4182.198594391402}}, {"type": "text", "text": "The monthly EMI is about 4182.2."}]}]
|
||||
[{"role": "user", "content": "Find the official nanochat repository and give me the link."}, {"role": "assistant", "content": [{"type": "text", "text": "I should verify this with a fresh web lookup. "}, {"type": "tool_call", "tool_name": "web_search", "arguments": {"query": "official nanochat repository", "top_k": 1}}, {"type": "tool_result", "tool_name": "web_search", "output": {"query": "official nanochat repository", "results": [{"url": "https://github.com/karpathy/nanochat", "title": "Source result", "snippet": "Fresh page content retrieved for grounding."}]}}, {"type": "text", "text": "The official nanochat repository is https://github.com/karpathy/nanochat"}]}]
|
||||
[{"role": "user", "content": "What do the Cloudflare Browser Rendering docs say about the markdown endpoint?"}, {"role": "assistant", "content": [{"type": "text", "text": "I should verify this with a fresh web lookup. "}, {"type": "tool_call", "tool_name": "web_search", "arguments": {"query": "Cloudflare Browser Rendering markdown endpoint docs", "top_k": 1}}, {"type": "tool_result", "tool_name": "web_search", "output": {"query": "Cloudflare Browser Rendering markdown endpoint docs", "results": [{"url": "https://developers.cloudflare.com/browser-rendering/rest-api/markdown-endpoint/", "title": "Source result", "snippet": "Fresh page content retrieved for grounding."}]}}, {"type": "text", "text": "Cloudflare documents a markdown endpoint for extracting rendered page content: https://developers.cloudflare.com/browser-rendering/rest-api/markdown-endpoint/"}]}]
|
||||
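Assistant turns in the SFT seed interleave `text`, `tool_call`, and `tool_result` parts. A minimal sketch of pulling the tool names back out of one row (standard library only; the row is abbreviated from the data above):

```python
import json

row = json.loads('''[
  {"role": "user", "content": "What is 18% of 250?"},
  {"role": "assistant", "content": [
    {"type": "text", "text": "Let me compute that exactly. "},
    {"type": "tool_call", "tool_name": "calculator",
     "arguments": {"expression": "percent(250,18)"}},
    {"type": "tool_result", "tool_name": "calculator",
     "output": {"expression": "percent(250,18)", "value": 45.0}},
    {"type": "text", "text": "18% of 250 is 45."}
  ]}
]''')
calls = [part["tool_name"]
         for message in row
         if isinstance(message.get("content"), list)
         for part in message["content"]
         if part.get("type") == "tool_call"]
print(calls)  # ['calculator']
```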
101
tasks/tool_json.py
Normal file
@@ -0,0 +1,101 @@
"""
|
||||
Local JSONL task for tool-use evaluation and lightweight RL reward shaping.
|
||||
|
||||
Each line should be a JSON object with:
|
||||
{
|
||||
"conversation": {"messages": [...]},
|
||||
"checks": {
|
||||
"must_call": "calculator",
|
||||
"must_not_call": ["web_search"],
|
||||
"answer_contains": ["42"],
|
||||
"citation_required": false
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
|
||||
from nanochat.tools import TOOL_CALL_END, TOOL_CALL_START, parse_tool_call_payload
|
||||
from tasks.common import Task
|
||||
|
||||
|
||||
TOOL_BLOCK_RE = re.compile(re.escape(TOOL_CALL_START) + r"(.*?)" + re.escape(TOOL_CALL_END), re.DOTALL)
|
||||
|
||||
|
||||
class ToolJSON(Task):
|
||||
def __init__(self, filepath, split="eval", **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.filepath = filepath
|
||||
self.split = split
|
||||
self.rows = []
|
||||
if not os.path.exists(filepath):
|
||||
raise FileNotFoundError(f"Tool JSONL dataset not found: {filepath}")
|
||||
with open(filepath, "r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
row = json.loads(line)
|
||||
if "conversation" not in row:
|
||||
raise ValueError(f"Row missing conversation field: {row}")
|
||||
row.setdefault("checks", {})
|
||||
self.rows.append(row)
|
||||
|
||||
@property
|
||||
def eval_type(self):
|
||||
return "generative"
|
||||
|
||||
def num_examples(self):
|
||||
return len(self.rows)
|
||||
|
||||
def get_example(self, index):
|
||||
row = self.rows[index]
|
||||
conversation = dict(row["conversation"])
|
||||
conversation["checks"] = row.get("checks", {})
|
||||
return conversation
|
||||
|
||||
def _tool_calls(self, assistant_response):
|
||||
calls = []
|
||||
for payload in TOOL_BLOCK_RE.findall(assistant_response):
|
||||
invocation = parse_tool_call_payload(payload)
|
||||
calls.append(invocation.tool_name)
|
||||
return calls
|
||||
|
||||
def evaluate(self, conversation, assistant_response):
|
||||
checks = conversation.get("checks", {})
|
||||
score = self.reward(conversation, assistant_response)
|
||||
return int(score >= 0.999)
|
||||
|
||||
def reward(self, conversation, assistant_response):
|
||||
checks = conversation.get("checks", {})
|
||||
total = 0.0
|
||||
passed = 0.0
|
||||
tool_calls = self._tool_calls(assistant_response)
|
||||
|
||||
must_call = checks.get("must_call")
|
||||
if must_call:
|
||||
total += 1.0
|
||||
passed += float(must_call in tool_calls)
|
||||
|
||||
for tool_name in checks.get("must_not_call", []):
|
||||
total += 1.0
|
||||
passed += float(tool_name not in tool_calls)
|
||||
|
||||
for needle in checks.get("answer_contains", []):
|
||||
total += 1.0
|
||||
passed += float(needle in assistant_response)
|
||||
|
||||
answer_regex = checks.get("answer_regex")
|
||||
if answer_regex:
|
||||
total += 1.0
|
||||
passed += float(re.search(answer_regex, assistant_response) is not None)
|
||||
|
||||
if checks.get("citation_required", False):
|
||||
total += 1.0
|
||||
passed += float(("http://" in assistant_response) or ("https://" in assistant_response))
|
||||
|
||||
if total == 0.0:
|
||||
return 0.0
|
||||
return passed / total
|
||||
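The reward gives partial credit per check, and `evaluate` only passes on a near-perfect score. A worked example with a hypothetical response satisfying two of three checks:

```python
# checks: must_call "calculator", must_not_call "web_search", answer_contains "42"
response = "The answer is 42."  # no tool block, so no calls are parsed
tool_calls = []
passed = (float("calculator" in tool_calls)        # 0.0: required call missing
          + float("web_search" not in tool_calls)  # 1.0
          + float("42" in response))               # 1.0
print(passed / 3.0)  # 0.666... -> reward < 0.999, so evaluate() returns 0
```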
133
tests/test_tools.py
Normal file
@@ -0,0 +1,133 @@
import json

import torch

from nanochat.engine import Engine
from nanochat.tools import CalculatorTool, MockSearchBackend, ToolRegistry, WebSearchTool, parse_tool_call_payload, serialize_tool_call
from tasks.tool_json import ToolJSON


class MockConfig:
    n_kv_head = 4
    n_head = 4
    n_embd = 64
    n_layer = 2
    sequence_len = 128


class ByteTokenizer:
    def __init__(self):
        self._special_tokens = {
            "<|python_start|>": 256,
            "<|python_end|>": 257,
            "<|output_start|>": 258,
            "<|output_end|>": 259,
            "<|assistant_end|>": 260,
            "<|bos|>": 261,
        }
        self._bos = 261

    def encode_special(self, s):
        return self._special_tokens[s]

    def get_bos_token_id(self):
        return self._bos

    def encode(self, s, prepend=None):
        tokens = list(s.encode("utf-8"))
        if prepend is not None:
            tokens = [prepend] + tokens
        return tokens

    def decode(self, tokens):
        byte_tokens = [t for t in tokens if t < 256]
        return bytes(byte_tokens).decode("utf-8", errors="replace")


class SequencedModel:
    def __init__(self, sequence, vocab_size=262):
        self.sequence = sequence
        self.vocab_size = vocab_size
        self.config = MockConfig()
        self._device = torch.device("cpu")
        self.call_index = 0

    def get_device(self):
        return self._device

    def forward(self, ids, kv_cache=None):
        if kv_cache is not None:
            kv_cache.advance(ids.shape[1])
        target = self.sequence[min(self.call_index, len(self.sequence) - 1)]
        logits = torch.full((ids.shape[0], ids.shape[1], self.vocab_size), -1e9)
        logits[:, -1, target] = 0.0
        self.call_index += 1
        return logits


def test_parse_tool_call_payload_supports_json_and_legacy():
    legacy = parse_tool_call_payload("6*7")
    assert legacy.tool_name == "calculator"
    assert legacy.arguments == {"expression": "6*7"}

    payload = parse_tool_call_payload(serialize_tool_call("calculator", {"expression": "6*7"}))
    assert payload.tool_name == "calculator"
    assert payload.arguments == {"expression": "6*7"}


def test_calculator_tool_supports_scientific_and_business_functions():
    tool = CalculatorTool()
    trig = tool.run({"expression": "round(sin(pi/2), 4)"})
    assert trig.success
    assert trig.output["value"] == 1.0

    emi = tool.run({"expression": "round(emi(500000,8,240), 2)"})
    assert emi.success
    assert emi.output["value"] == 4182.2


def test_web_search_tool_returns_mock_results():
    tool = WebSearchTool(search_backend=MockSearchBackend())
    result = tool.run({"query": "nanochat gpt2 speedrun", "top_k": 1})
    assert result.success
    assert result.output["results"][0]["url"] == "https://github.com/karpathy/nanochat"


def test_engine_executes_json_calculator_tool_call():
    tokenizer = ByteTokenizer()
    payload = serialize_tool_call("calculator", {"expression": "6*7"})
    sequence = [
        tokenizer.encode_special("<|python_start|>"),
        *tokenizer.encode(payload),
        tokenizer.encode_special("<|python_end|>"),
        tokenizer.encode_special("<|assistant_end|>"),
    ]
    model = SequencedModel(sequence)
    registry = ToolRegistry([CalculatorTool()])
    engine = Engine(model, tokenizer, tools=registry)
    prompt = [tokenizer.get_bos_token_id(), 72]
    results, _ = engine.generate_batch(prompt, num_samples=1, max_tokens=128, temperature=0.0)
    completion = tokenizer.decode(results[0][len(prompt):])
    assert "42" in completion


def test_tool_json_reward_and_eval(tmp_path):
    path = tmp_path / "tool_eval.jsonl"
    row = {
        "conversation": {
            "messages": [
                {"role": "user", "content": "What is 6 times 7?"},
                {"role": "assistant", "content": "42"},
            ]
        },
        "checks": {
            "must_call": "calculator",
            "answer_contains": ["42"],
        },
    }
    path.write_text(json.dumps(row) + "\n", encoding="utf-8")
    task = ToolJSON(filepath=str(path))
    conversation = task[0]
    response = '<|python_start|>{"tool":"calculator","arguments":{"expression":"6*7"}}<|python_end|>42'
    assert task.reward(conversation, response) == 1.0
    assert task.evaluate(conversation, response) == 1