mirror of
https://github.com/karpathy/nanochat.git
synced 2026-05-07 08:19:52 +00:00
feat(sft): add r7 think+tool prep scripts and compose cleanup
- allow assistant list-shaped content in CustomJSON for joint think+tool JSONL - add gen_joint_think_tool, filter_reasoning_jsonl, eval_suite_v2 (think_plus_tool probes) - fix CI: uv sync --no-install-workspace; uv run pytest - remove unused local inference service from compose; document Modal URL in env examples Made-with: Cursor
This commit is contained in:
parent
38cb7f7596
commit
f642cb2eb6
|
|
@ -6,14 +6,14 @@ DATABASE_URL=postgresql+asyncpg://samosachaat_admin:localdev@localhost:5432/samo
|
|||
FRONTEND_PORT=3000
|
||||
AUTH_PORT=8001
|
||||
CHAT_API_PORT=8002
|
||||
INFERENCE_PORT=8003
|
||||
GRAFANA_PORT=3001
|
||||
PROMETHEUS_PORT=9090
|
||||
LOKI_PORT=3100
|
||||
|
||||
AUTH_SERVICE_URL=http://auth:8001
|
||||
CHAT_API_URL=http://chat-api:8002
|
||||
INFERENCE_SERVICE_URL=http://inference:8003
|
||||
# External inference (e.g. Modal generate endpoint base URL — no local inference container)
|
||||
INFERENCE_SERVICE_URL=https://YOUR_WORKSPACE--YOUR_APP-inference-generate.modal.run
|
||||
NEXTAUTH_URL=http://localhost:3000
|
||||
|
||||
GOOGLE_CLIENT_ID=your-google-client-id
|
||||
|
|
|
|||
|
|
@ -18,10 +18,10 @@ COOKIE_DOMAIN=samosachaat.art
|
|||
|
||||
# Chat API
|
||||
AUTH_SERVICE_URL=http://auth:8001
|
||||
INFERENCE_SERVICE_URL=http://inference:8003
|
||||
INFERENCE_SERVICE_URL=https://YOUR_WORKSPACE--YOUR_APP-inference-generate.modal.run
|
||||
CHAT_API_URL=http://chat-api:8002
|
||||
|
||||
# Inference
|
||||
# Optional: only if you run a local inference container (EC2 uses Modal instead)
|
||||
MODEL_STORAGE_PATH=/models
|
||||
DEFAULT_MODEL_TAG=samosachaat-d12
|
||||
NANOCHAT_DTYPE=float32
|
||||
|
|
|
|||
12
.github/workflows/ci.yml
vendored
12
.github/workflows/ci.yml
vendored
|
|
@ -91,10 +91,10 @@ jobs:
|
|||
uses: astral-sh/setup-uv@v4
|
||||
|
||||
- name: Sync deps
|
||||
run: uv sync --no-workspace
|
||||
run: uv sync --no-install-workspace
|
||||
|
||||
- name: Run pytest
|
||||
run: uv run --no-workspace pytest
|
||||
run: uv run pytest
|
||||
|
||||
test-chat-api:
|
||||
name: Chat-API — pytest (postgres service)
|
||||
|
|
@ -131,10 +131,10 @@ jobs:
|
|||
uses: astral-sh/setup-uv@v4
|
||||
|
||||
- name: Sync deps
|
||||
run: uv sync --no-workspace
|
||||
run: uv sync --no-install-workspace
|
||||
|
||||
- name: Run pytest
|
||||
run: uv run --no-workspace pytest
|
||||
run: uv run pytest
|
||||
|
||||
test-inference:
|
||||
name: Inference — pytest
|
||||
|
|
@ -155,10 +155,10 @@ jobs:
|
|||
uses: astral-sh/setup-uv@v4
|
||||
|
||||
- name: Sync deps
|
||||
run: uv sync --no-workspace
|
||||
run: uv sync --no-install-workspace
|
||||
|
||||
- name: Run pytest
|
||||
run: uv run --no-workspace pytest
|
||||
run: uv run pytest
|
||||
|
||||
terraform-validate:
|
||||
name: Terraform — validate
|
||||
|
|
|
|||
|
|
@ -17,10 +17,6 @@ services:
|
|||
build: !reset null
|
||||
image: ${ECR_REGISTRY:-883107058766.dkr.ecr.us-west-2.amazonaws.com}/samosachaat/chat-api:${IMAGE_TAG:-dev-latest}
|
||||
|
||||
inference:
|
||||
build: !reset null
|
||||
image: ${ECR_REGISTRY:-883107058766.dkr.ecr.us-west-2.amazonaws.com}/samosachaat/inference:${IMAGE_TAG:-dev-latest}
|
||||
|
||||
nginx:
|
||||
image: nginx:alpine
|
||||
restart: unless-stopped
|
||||
|
|
|
|||
|
|
@ -59,28 +59,12 @@ services:
|
|||
environment:
|
||||
DATABASE_URL: ${DATABASE_URL:-postgresql+asyncpg://samosachaat_admin:localdev@postgres:5432/samosachaat}
|
||||
AUTH_SERVICE_URL: ${AUTH_SERVICE_URL:-http://auth:8001}
|
||||
INFERENCE_SERVICE_URL: ${INFERENCE_SERVICE_URL:-http://inference:8003}
|
||||
# External inference (Modal, etc.). Set in `.env` — see `.env.example`.
|
||||
INFERENCE_SERVICE_URL: ${INFERENCE_SERVICE_URL}
|
||||
INTERNAL_API_KEY: ${INTERNAL_API_KEY:-}
|
||||
depends_on:
|
||||
- postgres
|
||||
- auth
|
||||
- inference
|
||||
|
||||
inference:
|
||||
build:
|
||||
context: ./services/inference
|
||||
restart: unless-stopped
|
||||
ports:
|
||||
- "${INFERENCE_PORT:-8003}:8003"
|
||||
environment:
|
||||
MODEL_STORAGE_PATH: /models
|
||||
DEFAULT_MODEL_TAG: ${DEFAULT_MODEL_TAG:-samosachaat-d12}
|
||||
NANOCHAT_DTYPE: ${NANOCHAT_DTYPE:-float32}
|
||||
HF_TOKEN: ${HF_TOKEN:-}
|
||||
INTERNAL_API_KEY: ${INTERNAL_API_KEY:-}
|
||||
NUM_WORKERS: ${NUM_WORKERS:-1}
|
||||
volumes:
|
||||
- ./models:/models
|
||||
|
||||
grafana:
|
||||
image: grafana/grafana:latest
|
||||
|
|
|
|||
276
scripts/eval_suite_v2.py
Normal file
276
scripts/eval_suite_v2.py
Normal file
|
|
@ -0,0 +1,276 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Probe-style eval for chat SFT (tools + thinking).
|
||||
|
||||
Example:
|
||||
TAG=d24-sft-r6 STEP=754 SOURCE=sft WITH_TOOLS=1 \\
|
||||
python -m scripts.eval_suite_v2
|
||||
|
||||
Env:
|
||||
TAG — model_tag directory under chatsft_checkpoints (or base_checkpoints if SOURCE=base)
|
||||
STEP — checkpoint step (int)
|
||||
SOURCE — base | sft | rl (default sft)
|
||||
WITH_TOOLS — 1 to run Engine with default tool registry (default 1)
|
||||
DEVICE — optional cuda|cpu|mps override
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
|
||||
from nanochat.common import autodetect_device_type, compute_cleanup, compute_init, print0
|
||||
from nanochat.checkpoint_manager import load_model
|
||||
from nanochat.engine import Engine
|
||||
from nanochat.tools import TOOL_CALL_END, TOOL_CALL_START, build_default_tool_registry, parse_tool_call_payload
|
||||
|
||||
TOOL_BLOCK_RE = re.compile(re.escape(TOOL_CALL_START) + r"(.*?)" + re.escape(TOOL_CALL_END), re.DOTALL)
|
||||
THINK_CLOSE = "</think>"
|
||||
|
||||
# Mirrors services/chat-api thinking-mode prefix
|
||||
SYS_THINK = (
|
||||
"You are samosaChaat, a helpful AI assistant created by Manmohan Sharma. "
|
||||
"You have access to tools: use web_search for facts that may change over time or "
|
||||
"require current information, and use calculator for arithmetic. "
|
||||
"Think step by step inside <think>...</think> tags, "
|
||||
"then give your final answer after the closing tag."
|
||||
)
|
||||
|
||||
|
||||
def _tool_calls(assistant_response: str) -> list[str]:
    """Return the names of all parseable tool invocations in a completion.

    Payloads that fail to parse are skipped silently — the eval is lenient
    about malformed tool blocks and only scores the well-formed ones.
    """
    names: list[str] = []
    for payload in TOOL_BLOCK_RE.findall(assistant_response):
        try:
            invocation = parse_tool_call_payload(payload)
        except Exception:
            continue
        names.append(invocation.tool_name)
    return names
|
||||
|
||||
|
||||
def _after_think(text: str) -> str | None:
    """Return the text after the first </think> tag, or None when it is absent."""
    _, closed, tail = text.partition(THINK_CLOSE)
    return tail if closed else None
|
||||
|
||||
|
||||
def probe_reward(checks: dict, assistant_response: str) -> float:
    """Score a completion against a probe's checks dict.

    Each enabled check contributes 1.0 to the denominator and 1.0 to the
    numerator when it passes; the return value is the fraction passed
    (0.0 when no checks are enabled).

    Supported check keys: must_call, must_not_call, answer_contains,
    answer_regex, citation_required, must_close_think, min_chars_after_think,
    answer_after_think_contains, forbid_answer_needles_inside_think_only
    (with optional _forbid_needles override).
    """
    total = 0.0
    passed = 0.0
    tool_calls = _tool_calls(assistant_response)

    must_call = checks.get("must_call")
    if must_call:
        total += 1.0
        passed += float(must_call in tool_calls)

    for tool_name in checks.get("must_not_call", []):
        total += 1.0
        passed += float(tool_name not in tool_calls)

    for needle in checks.get("answer_contains", []):
        total += 1.0
        passed += float(needle in assistant_response)

    answer_regex = checks.get("answer_regex")
    if answer_regex:
        total += 1.0
        passed += float(re.search(answer_regex, assistant_response) is not None)

    if checks.get("citation_required", False):
        # Any URL anywhere in the response counts as a citation.
        total += 1.0
        passed += float(("http://" in assistant_response) or ("https://" in assistant_response))

    if checks.get("must_close_think", False):
        total += 1.0
        passed += float(THINK_CLOSE in assistant_response)

    min_after = checks.get("min_chars_after_think")
    if min_after:
        total += 1.0
        tail = _after_think(assistant_response)
        passed += float(tail is not None and len(tail.strip()) >= int(min_after))

    for needle in checks.get("answer_after_think_contains", []):
        total += 1.0
        tail = _after_think(assistant_response)
        passed += float(tail is not None and needle in tail)

    if checks.get("forbid_answer_needles_inside_think_only", False):
        # If needles appear only before </think>, fail (recipe trapped in think)
        total += 1.0
        if THINK_CLOSE in assistant_response:
            head, _, tail = assistant_response.partition(THINK_CLOSE)
            needles = checks.get("_forbid_needles", ("Ingredients", "Step 1", "samosa"))
            bad = any(n in head and n not in tail for n in needles)
            passed += float(not bad)
        else:
            # No closing tag at all: check fails.
            passed += 0.0

    if total == 0.0:
        return 0.0
    return passed / total
|
||||
|
||||
|
||||
def default_probes() -> list[dict]:
    """Built-in think+tool probes covering web_search, calculator, and answer placement."""
    return [
        {
            "name": "president_web_search",
            "category": "think_plus_tool",
            "conversation": {
                "messages": [
                    {
                        "role": "user",
                        "content": SYS_THINK + "\n\nWho is the current president of America?",
                    },
                ]
            },
            "checks": {
                "must_call": "web_search",
                "must_close_think": True,
                "min_chars_after_think": 20,
            },
        },
        {
            "name": "mumbai_weather_web_search",
            "category": "think_plus_tool",
            "conversation": {
                "messages": [
                    {
                        "role": "user",
                        "content": SYS_THINK + "\n\nWhat's the weather in Mumbai today?",
                    },
                ]
            },
            "checks": {
                "must_call": "web_search",
                "must_close_think": True,
            },
        },
        {
            # Answer must land *after* the think block, not inside it.
            "name": "samosa_chaat_answer_after_think",
            "category": "think_plus_tool",
            "conversation": {
                "messages": [
                    {
                        "role": "user",
                        "content": SYS_THINK + "\n\nHow do I make samosa chaat?",
                    },
                ]
            },
            "checks": {
                "must_close_think": True,
                "min_chars_after_think": 60,
                "answer_after_think_contains": ["samosa"],
                "forbid_answer_needles_inside_think_only": True,
                "_forbid_needles": ("Ingredients", "Step 1", "yogurt"),
            },
        },
        {
            "name": "tip_calculator",
            "category": "think_plus_tool",
            "conversation": {
                "messages": [
                    {
                        "role": "user",
                        "content": SYS_THINK + "\n\nCalculate an 18% tip on a $60 bill.",
                    },
                ]
            },
            "checks": {
                "must_call": "calculator",
                "must_close_think": True,
                "answer_regex": r"10\.8",
            },
        },
    ]
|
||||
|
||||
|
||||
def run_probes(
    *,
    tokenizer,
    engine: Engine,
    probes: list[dict],
    max_new_tokens: int,
    temperature: float,
    top_k: int,
) -> tuple[float, dict[str, list[tuple[str, float]]]]:
    """Generate one sample per probe and score it with probe_reward.

    Returns (mean reward over all probes, mapping of category -> list of
    (probe name, reward) rows). Prints each completion (truncated) for
    manual inspection.
    """
    by_cat: dict[str, list[tuple[str, float]]] = {}
    scores: list[float] = []

    for probe in probes:
        name = probe.get("name", "unknown")
        category = probe.get("category", "default")
        conv = probe["conversation"]
        checks = probe.get("checks", {})

        encoded = tokenizer.render_for_completion(conv)
        # Fixed seed keeps probe runs comparable across checkpoints.
        results, _ = engine.generate_batch(
            encoded,
            num_samples=1,
            max_tokens=max_new_tokens,
            temperature=temperature,
            top_k=top_k,
            seed=42,
        )
        # Decode only the newly generated suffix, not the prompt.
        completion = tokenizer.decode(results[0][len(encoded) :])
        r = probe_reward(checks, completion)
        scores.append(r)
        by_cat.setdefault(category, []).append((name, r))
        print0(f"[{category}] {name}: reward={r:.3f}\n---\n{completion[:1200]}\n---")

    # max(..., 1) guards against an empty probe list.
    return sum(scores) / max(len(scores), 1), by_cat
|
||||
|
||||
|
||||
def main():
    """Entry point: load the checkpoint named by TAG/STEP/SOURCE and run probes."""
    device_type = os.environ.get("DEVICE", "") or autodetect_device_type()
    ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
    # Only rank 0 evaluates; other ranks exit immediately.
    if ddp_rank != 0:
        return

    tag = os.environ.get("TAG")
    step_s = os.environ.get("STEP")
    source = os.environ.get("SOURCE", "sft")
    if not tag or step_s is None:
        print("Set TAG and STEP in the environment.", file=sys.stderr)
        sys.exit(2)
    step = int(step_s)
    with_tools = os.environ.get("WITH_TOOLS", "1") not in ("0", "false", "False")

    # Built-in probes, optionally extended with one JSON probe per line.
    probes = default_probes()
    extra_path = os.environ.get("EVAL_SUITE_EXTRA_JSONL", "")
    if extra_path and os.path.exists(extra_path):
        with open(extra_path, encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if line:
                    probes.append(json.loads(line))

    model, tokenizer, meta = load_model(source, device, phase="eval", model_tag=tag, step=step)
    tools = build_default_tool_registry() if with_tools else None
    engine = Engine(model, tokenizer, tools=tools)

    mean, by_cat = run_probes(
        tokenizer=tokenizer,
        engine=engine,
        probes=probes,
        max_new_tokens=int(os.environ.get("MAX_NEW_TOKENS", "512")),
        temperature=float(os.environ.get("TEMPERATURE", "0.2")),
        top_k=int(os.environ.get("TOP_K", "50")),
    )

    print0("=" * 60)
    print0(f"Overall mean reward: {mean:.4f}")
    for cat, rows in sorted(by_cat.items()):
        m = sum(r for _, r in rows) / len(rows)
        print0(f" {cat}: {m:.4f} ({len(rows)} probes)")
    compute_cleanup()


if __name__ == "__main__":
    try:
        main()
    except KeyboardInterrupt:
        # Ensure distributed state is torn down even on Ctrl-C.
        compute_cleanup()
|
||||
115
scripts/filter_reasoning_jsonl.py
Normal file
115
scripts/filter_reasoning_jsonl.py
Normal file
|
|
@ -0,0 +1,115 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Filter reasoning SFT JSONL so thinking blocks stay format-clean:
|
||||
|
||||
- Require a closed </think>
|
||||
- Require non-trivial text *after* the closing tag (the answer lives there)
|
||||
- Reject if the model likely put the final answer only inside the think block
|
||||
(heuristic: strong answer markers appear inside but post-think text is tiny)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import re
|
||||
import sys
|
||||
|
||||
THINK_OPEN = "<think>"
|
||||
THINK_CLOSE = "</think>"
|
||||
|
||||
# If these appear inside the thinking span but the tail after </think> is short, drop the row.
|
||||
INSIDE_THINK_ANSWER_HINTS = re.compile(
|
||||
r"(?i)\b(ingredients|instructions|step\s*1|method|preheat|^\s*\d+[\).\s])",
|
||||
re.MULTILINE,
|
||||
)
|
||||
STRONG_TAIL_NEEDLE = re.compile(r"(?i)\b(is|are|equals|result|answer|you can|first,|mix|serve)\b")
|
||||
|
||||
|
||||
def _assistant_text(messages: list) -> str | None:
|
||||
if not messages or messages[-1].get("role") != "assistant":
|
||||
return None
|
||||
c = messages[-1].get("content")
|
||||
if isinstance(c, str):
|
||||
return c
|
||||
if isinstance(c, list):
|
||||
parts = []
|
||||
for p in c:
|
||||
if isinstance(p, dict) and p.get("type") == "text":
|
||||
parts.append(p.get("text", ""))
|
||||
return "".join(parts) if parts else None
|
||||
return None
|
||||
|
||||
|
||||
def _split_think(s: str) -> tuple[str | None, str | None]:
    """Split s into (text inside <think>...</think>, text after the close tag).

    Returns (None, None) when either tag is missing, or when </think>
    only occurs before <think> (an effectively unclosed block).
    """
    if THINK_OPEN not in s or THINK_CLOSE not in s:
        return None, None
    try:
        start = s.index(THINK_OPEN) + len(THINK_OPEN)
        end = s.index(THINK_CLOSE, start)
    except ValueError:
        return None, None
    return s[start:end], s[end + len(THINK_CLOSE) :]
|
||||
|
||||
|
||||
def keep_conversation(messages: list, *, min_tail_chars: int) -> tuple[bool, str]:
    """Decide whether a conversation's assistant turn is think-format clean.

    Returns (keep, reason) where reason is "ok" for kept rows, or a short
    drop-reason tag used for stats reporting.
    """
    text = _assistant_text(messages)
    if not text:
        return False, "no_assistant_string"
    inner, tail = _split_think(text)
    if inner is None:
        return False, "missing_or_unclosed_think"
    tail_stripped = tail.strip() if tail else ""
    if len(tail_stripped) < min_tail_chars:
        return False, "short_tail"
    # Strong answer markers inside the think span plus a short tail suggest
    # the final answer leaked into <think>...</think>.
    if INSIDE_THINK_ANSWER_HINTS.search(inner) and len(tail_stripped) < max(min_tail_chars, 80):
        return False, "answer_leaked_into_think"
    # A very short tail must at least look like an answer sentence.
    if len(tail_stripped) < 40 and not STRONG_TAIL_NEEDLE.search(tail_stripped):
        return False, "weak_tail"
    return True, "ok"
|
||||
|
||||
|
||||
def main():
    """CLI entry: stream input JSONL, keep clean rows, report drop reasons to stderr."""
    parser = argparse.ArgumentParser(description="Filter reasoning JSONL for think-block hygiene")
    parser.add_argument("input_jsonl")
    parser.add_argument("--out", required=True, help="Filtered JSONL output path")
    parser.add_argument("--min-tail-chars", type=int, default=25)
    parser.add_argument("--stats-every", type=int, default=10000)
    args = parser.parse_args()

    kept, total = 0, 0
    reasons: dict[str, int] = {}  # drop-reason tag -> count

    with open(args.input_jsonl, encoding="utf-8") as fin, open(args.out, "w", encoding="utf-8") as fout:
        for line in fin:
            line = line.strip()
            if not line:
                continue
            total += 1
            try:
                messages = json.loads(line)
            except json.JSONDecodeError:
                reasons["bad_json"] = reasons.get("bad_json", 0) + 1
                continue
            # Each line must be a JSON array of messages.
            if not isinstance(messages, list):
                reasons["not_list"] = reasons.get("not_list", 0) + 1
                continue
            ok, reason = keep_conversation(messages, min_tail_chars=args.min_tail_chars)
            if ok:
                fout.write(json.dumps(messages, ensure_ascii=True) + "\n")
                kept += 1
            else:
                reasons[reason] = reasons.get(reason, 0) + 1

            # Periodic progress to stderr (disabled when --stats-every is 0).
            if args.stats_every and total % args.stats_every == 0:
                print(f"... processed {total} lines, kept {kept}", file=sys.stderr)

    print(f"Done. kept {kept}/{total} -> {args.out}", file=sys.stderr)
    # Most common drop reasons first.
    for k, v in sorted(reasons.items(), key=lambda kv: -kv[1]):
        print(f" dropped {k}: {v}", file=sys.stderr)


if __name__ == "__main__":
    main()
|
||||
443
scripts/gen_joint_think_tool.py
Normal file
443
scripts/gen_joint_think_tool.py
Normal file
|
|
@ -0,0 +1,443 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Generate SFT JSONL with joint <think> + tool-call + answer patterns.
|
||||
|
||||
Output lines are JSON arrays of messages (CustomJSON / chat_sft --extra-train-jsonl).
|
||||
Assistant turns use list-shaped content so the tokenizer emits python/output specials.
|
||||
|
||||
Patterns:
|
||||
(a) think -> web_search -> answer
|
||||
(b) think -> direct answer (no tool)
|
||||
(c) think -> calculator -> answer
|
||||
|
||||
Optional: OPENAI_API_KEY + --use-openai to diversify questions (gpt-4o-mini).
|
||||
Without API keys, uses deterministic templates (still valid for SFT).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
import sys
|
||||
import urllib.error
|
||||
import urllib.request
|
||||
|
||||
# Keep in sync with services/chat-api/src/routes/messages.py _SYS_THINK
|
||||
SYS_JOINT = (
|
||||
"You are samosaChaat, a helpful AI assistant created by Manmohan Sharma. "
|
||||
"You have access to tools: use web_search for facts that may change over time or "
|
||||
"require current information, and use calculator for arithmetic. "
|
||||
"Think step by step inside <think>...</think> tags, "
|
||||
"then give your final answer after the closing tag."
|
||||
)
|
||||
|
||||
# Identity anchor (from prior creator SFT — keep wording stable)
|
||||
CREATOR_FACTS = (
|
||||
"Context: samosaChaat is an AI assistant created by Manmohan Sharma. "
|
||||
"If asked who built you, answer with that fact."
|
||||
)
|
||||
|
||||
|
||||
def _compact_json(obj) -> str:
|
||||
return json.dumps(obj, ensure_ascii=True, separators=(",", ":"))
|
||||
|
||||
|
||||
def _web_result(query: str, url: str, title: str, snippet: str) -> dict:
|
||||
return {
|
||||
"query": query,
|
||||
"results": [{"url": url, "title": title, "snippet": snippet}],
|
||||
}
|
||||
|
||||
|
||||
def think_block(*lines: str) -> str:
    """Join the non-empty lines (stripped, space-separated) into a <think> block."""
    cleaned = [line.strip() for line in lines if line.strip()]
    return "<think>\n" + " ".join(cleaned) + "\n</think>"
|
||||
|
||||
|
||||
def conv_think_web_search(user_q: str, think_lines: tuple[str, ...], query: str, result: dict, answer: str):
    """Pattern (a): user question -> think block -> web_search call -> result -> answer.

    Returns a 2-message conversation; the assistant turn uses list-shaped
    content so the tokenizer can emit tool-call specials.
    """
    return [
        {"role": "user", "content": SYS_JOINT + "\n\n" + user_q},
        {
            "role": "assistant",
            "content": [
                # Thinking comes first, before the tool call.
                {"type": "text", "text": think_block(*think_lines) + "\n"},
                {
                    "type": "tool_call",
                    "tool_name": "web_search",
                    "arguments": {"query": query, "top_k": 1},
                },
                {
                    "type": "tool_result",
                    "tool_name": "web_search",
                    "output": result,
                    "success": True,
                },
                # Final user-visible answer after the tool round-trip.
                {"type": "text", "text": answer},
            ],
        },
    ]
|
||||
|
||||
|
||||
def conv_think_direct(user_q: str, think_lines: tuple[str, ...], answer: str):
    """Pattern (b): user question -> think block -> direct answer, no tool call."""
    assistant_text = think_block(*think_lines) + "\n" + answer
    user_turn = {"role": "user", "content": SYS_JOINT + "\n\n" + user_q}
    assistant_turn = {
        "role": "assistant",
        "content": [{"type": "text", "text": assistant_text}],
    }
    return [user_turn, assistant_turn]
|
||||
|
||||
|
||||
def conv_think_calculator(
    user_q: str,
    think_lines: tuple[str, ...],
    expression: str,
    value: float | int,
    answer: str,
):
    """Pattern (c): user question -> think block -> calculator call -> result -> answer."""
    return [
        {"role": "user", "content": SYS_JOINT + "\n\n" + user_q},
        {
            "role": "assistant",
            "content": [
                {"type": "text", "text": think_block(*think_lines) + "\n"},
                {
                    "type": "tool_call",
                    "tool_name": "calculator",
                    "arguments": {"expression": expression},
                },
                {
                    "type": "tool_result",
                    "tool_name": "calculator",
                    # Normalize to float so int and float inputs serialize alike.
                    "output": {"expression": expression, "value": float(value)},
                    "success": True,
                },
                {"type": "text", "text": answer},
            ],
        },
    ]
|
||||
|
||||
|
||||
def _template_rows() -> list[list]:
    """Deterministic seed conversations covering all three joint patterns."""
    rows: list[list] = []

    # --- (a) Web search / time-sensitive ---
    rows.append(
        conv_think_web_search(
            "Who is the current President of the United States?",
            (
                "This changes with elections; I should not rely on memory.",
                "I will search for the latest information.",
            ),
            "current President of the United States 2026",
            _web_result(
                "current President of the United States 2026",
                "https://www.whitehouse.gov/administration/",
                "The Administration",
                "The administration page lists the current officeholder.",
            ),
            "Based on the search, the current President of the United States is the person listed on the official White House administration page (verify on whitehouse.gov for the exact name).",
        )
    )
    rows.append(
        conv_think_web_search(
            "What's the weather in Mumbai today?",
            ("Weather is live data.", "I should look it up."),
            "Mumbai weather today",
            _web_result(
                "Mumbai weather today",
                "https://weather.example/mumbai",
                "Mumbai forecast",
                "Today: warm and humid with a chance of evening showers; highs near 32°C.",
            ),
            "Based on the search, Mumbai today is warm and humid with a chance of evening showers (check a live weather source for exact numbers).",
        )
    )
    rows.append(
        conv_think_web_search(
            "Who is the CEO of OpenAI right now?",
            ("Executive roles change.", "I'll search for the current CEO."),
            "OpenAI CEO 2026",
            _web_result(
                "OpenAI CEO 2026",
                "https://openai.com/",
                "OpenAI",
                "Leadership page names the current chief executive.",
            ),
            "Based on the search, see OpenAI's official leadership page for the current CEO name.",
        )
    )
    rows.append(
        conv_think_web_search(
            "What was the closing value of the S&P 500 index most recently?",
            ("Market numbers are time-sensitive.", "I need a web lookup."),
            "S&P 500 latest close",
            _web_result(
                "S&P 500 latest close",
                "https://www.example-finance.com/sp500",
                "S&P 500",
                "The index closed near 5,200 in the latest reported session (illustrative).",
            ),
            "Based on the search, the latest reported close was near 5,200 — confirm on a live market data site.",
        )
    )

    # --- (b) Direct answer after think (no tool): recipes / static how-to ---
    rows.append(
        conv_think_direct(
            "How do I make samosa chaat at home?",
            (
                "This is a cooking question; I can outline steps without a web search.",
                "I'll keep reasoning brief and put the recipe after the thinking block.",
            ),
            "Crush or chop cooked samosas. Layer with chickpeas, yogurt, chutneys (mint and tamarind), diced onions, tomatoes, chaat masala, and sev. Serve immediately while the samosa is still crisp.",
        )
    )
    rows.append(
        conv_think_direct(
            "What is the capital of France?",
            ("This is stable geographic knowledge.", "No tool is needed."),
            "The capital of France is Paris.",
        )
    )
    rows.append(
        conv_think_direct(
            CREATOR_FACTS + " Who created you?",
            ("The user asked about my creator.", "That is given in the context."),
            "I was created by Manmohan Sharma.",
        )
    )

    # --- (c) Calculator after think ---
    rows.append(
        conv_think_calculator(
            "Calculate an 18% tip on a $60 bill.",
            ("I need an exact percentage.", "I'll use the calculator tool."),
            "percent(60,18)",
            10.8,
            "An 18% tip on $60 is $10.80.",
        )
    )
    rows.append(
        conv_think_calculator(
            "What is the monthly EMI for a ₹500,000 loan at 8% annual interest for 240 months?",
            ("EMI has a standard formula.", "I'll compute with the calculator."),
            "emi(500000,8,240)",
            4182.198594391402,
            "The monthly EMI is about 4182.20.",
        )
    )
    rows.append(
        conv_think_calculator(
            "If revenue grew from 120 to 150, what is the percent change?",
            ("Percent change should be exact.", "Using calculator."),
            "percent_change(120,150)",
            25.0,
            "The percent change from 120 to 150 is 25%.",
        )
    )

    return rows
|
||||
|
||||
|
||||
def _vary_presidents_ceo_weather_sports(rng: random.Random) -> list[list]:
    """Extra templated rows with light randomization (shuffle order only)."""
    rows: list[list] = []
    cities = ["Delhi", "Bengaluru", "London", "New York", "Tokyo"]
    sports = [
        ("latest ICC cricket World Cup winner", "cricket world cup winner"),
        ("Who won the most recent Super Bowl?", "Super Bowl winner latest"),
    ]
    # Weather questions per city -> web_search pattern.
    for city in cities:
        q = f"What's the weather in {city} today?"
        rows.append(
            conv_think_web_search(
                q,
                ("Weather is live.", "Search."),
                f"{city} weather today",
                _web_result(
                    f"{city} weather today",
                    f"https://weather.example/{city.lower()}",
                    f"{city} forecast",
                    f"Today in {city}: partly cloudy, mild breeze (illustrative).",
                ),
                f"Based on the search, today's {city} forecast looks partly cloudy — verify on a live weather service.",
            )
        )
    # Sports results -> web_search pattern.
    for user_q, squery in sports:
        rows.append(
            conv_think_web_search(
                user_q,
                ("Sports results change.", "I'll search."),
                squery,
                _web_result(
                    squery,
                    "https://sports.example/",
                    "Sports",
                    "Official recap lists the winning team for the latest event.",
                ),
                "Based on the search, see the linked recap for the winning team (confirm on a trusted sports source).",
            )
        )
    # Rotating math / finance
    for bill, pct in [(40.0, 15), (85.5, 20), (120.0, 22)]:
        expr = f"percent({bill},{pct})"
        val = round(bill * pct / 100, 2)
        rows.append(
            conv_think_calculator(
                f"Calculate a {pct}% tip on ${bill:g}.",
                ("Exact tip amount.", "Calculator."),
                expr,
                val,
                f"A {pct}% tip on ${bill:g} is ${val:.2f}.",
            )
        )
    rng.shuffle(rows)
    return rows
|
||||
|
||||
|
||||
_OPENAI_URL = "https://api.openai.com/v1/chat/completions"
|
||||
|
||||
|
||||
def _openai_variants(api_key: str, n: int, rng: random.Random) -> list[str]:
    """Return up to n short factual questions for joint training.

    Best-effort: any network, timeout, or response-shape failure logs to
    stderr and returns an empty list instead of raising.
    """
    prompt = (
        "Generate exactly %d diverse, concise user questions that need EITHER a web search, "
        "a calculator, OR a direct answer (no tool). One sentence each, no numbering. "
        "Mix: current events, weather, sports, finance math, and one cooking question. "
        "Output one question per line, no other text."
    ) % n
    body = _compact_json(
        {
            "model": "gpt-4o-mini",
            "temperature": 0.9,
            "messages": [{"role": "user", "content": prompt}],
        }
    )
    req = urllib.request.Request(
        _OPENAI_URL,
        data=body.encode("utf-8"),
        headers={
            "Content-Type": "application/json",
            "Authorization": f"Bearer {api_key}",
        },
        method="POST",
    )
    try:
        with urllib.request.urlopen(req, timeout=120) as resp:
            payload = json.loads(resp.read().decode("utf-8"))
    except (urllib.error.URLError, TimeoutError, json.JSONDecodeError) as exc:
        print(f"OpenAI request failed ({exc}); skipping LLM expansion.", file=sys.stderr)
        return []

    try:
        text = payload["choices"][0]["message"]["content"]
    except (KeyError, IndexError):
        print("Unexpected OpenAI response shape; skipping.", file=sys.stderr)
        return []

    # One question per line; shuffle then truncate to the requested count.
    lines = [ln.strip() for ln in text.splitlines() if ln.strip()]
    rng.shuffle(lines)
    return lines[:n]
|
||||
|
||||
|
||||
def _classify_and_build_question(q: str) -> list | None:
    """Heuristic: map a free-form question to a joint pattern.

    Order: tip-style percentage math -> calculator; time-sensitive keywords
    -> web_search; anything else -> direct answer.
    """
    ql = q.lower()
    # Tip: "18% tip on $60" / "tip on $60 at 18%"
    # The two alternation branches use distinct group names (pct/bill vs
    # pct2/bill2) because Python's re forbids duplicate names.
    tip_m = re.search(
        r"(?P<pct>\d+(?:\.\d+)?)\s*%\s*tip.*?\$(?P<bill>\d+(?:\.\d+)?)|"
        r"\$(?P<bill2>\d+(?:\.\d+)?).*?(?P<pct2>\d+(?:\.\d+)?)\s*%\s*tip",
        q,
        re.I,
    )
    if tip_m:
        pct = float(tip_m.group("pct") or tip_m.group("pct2"))
        bill = float(tip_m.group("bill") or tip_m.group("bill2"))
        expr = f"percent({bill},{pct})"
        val = round(bill * pct / 100, 4)
        return conv_think_calculator(
            q,
            ("I need an exact tip amount.", "Using the calculator."),
            expr,
            val,
            f"Based on the calculation, a {pct:g}% tip on ${bill:g} is ${val:.2f}.",
        )
    if any(
        k in ql
        for k in (
            "weather",
            "president",
            "ceo",
            "who won",
            "score",
            "today",
            "current",
            "latest",
            "price of",
            "stock",
        )
    ):
        # URL slug from the lowercased question, capped at 48 chars.
        slug = re.sub(r"\W+", "-", ql)[:48]
        return conv_think_web_search(
            q,
            ("This is time-sensitive or external.", "Searching."),
            ql[:120],
            _web_result(ql[:120], f"https://example.org/{slug}", "Source", "Snippet summarizes the retrieved page."),
            "Based on the search, verify details on the cited source; the snippet is illustrative.",
        )
    return conv_think_direct(
        q,
        ("I can answer directly.", "No tool needed."),
        "Short answer: provide 2–4 sentences addressing the question without putting the final steps inside the thinking block.",
    )
|
||||
|
||||
|
||||
def main():
    """CLI entry: build template (+optional OpenAI) rows and write JSONL."""
    parser = argparse.ArgumentParser(description="Generate joint think+tool SFT JSONL")
    parser.add_argument("--out", default="seed_data/joint_think_tool.jsonl", help="Output JSONL path")
    parser.add_argument("--target", type=int, default=512, help="Approximate number of lines to write")
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument(
        "--use-openai",
        action="store_true",
        help="If OPENAI_API_KEY is set, add LLM-generated questions (best-effort).",
    )
    parser.add_argument("--openai-extra", type=int, default=64, help="Max extra questions from OpenAI")
    args = parser.parse_args()
    rng = random.Random(args.seed)

    rows = _template_rows() + _vary_presidents_ceo_weather_sports(rng)

    if args.use_openai:
        key = os.environ.get("OPENAI_API_KEY", "")
        if key:
            for q in _openai_variants(key, args.openai_extra, rng):
                built = _classify_and_build_question(q)
                if built:
                    rows.append(built)
        else:
            print("OPENAI_API_KEY not set; --use-openai ignored.", file=sys.stderr)

    # Repeat / shuffle to reach target size
    out_lines: list[list] = []
    while len(out_lines) < args.target:
        rng.shuffle(rows)
        need = args.target - len(out_lines)
        out_lines.extend(rows[: min(len(rows), need)])

    os.makedirs(os.path.dirname(args.out) or ".", exist_ok=True)
    with open(args.out, "w", encoding="utf-8") as f:
        for conv in out_lines[: args.target]:
            f.write(json.dumps(conv, ensure_ascii=True) + "\n")

    print(f"Wrote {args.target} conversations to {args.out}")


if __name__ == "__main__":
    main()
|
||||
|
|
@ -7,6 +7,31 @@ import os
|
|||
import json
|
||||
from tasks.common import Task
|
||||
|
||||
|
||||
def _validate_assistant_content(content, message_index):
|
||||
"""Assistant turns may be a plain string or a list of parts (tools / GSM8K-style)."""
|
||||
if isinstance(content, str):
|
||||
return
|
||||
if not isinstance(content, list):
|
||||
raise AssertionError(f"Message {message_index}: assistant content must be str or list, got {type(content)}")
|
||||
for j, part in enumerate(content):
|
||||
if not isinstance(part, dict):
|
||||
raise AssertionError(f"Message {message_index} part {j}: expected dict, got {type(part)}")
|
||||
ptype = part.get("type")
|
||||
if ptype == "text":
|
||||
assert "text" in part, f"Message {message_index} part {j}: text part missing 'text'"
|
||||
elif ptype in ("tool_call", "python"):
|
||||
assert "text" in part or part.get("tool_name"), (
|
||||
f"Message {message_index} part {j}: tool part needs 'text' or 'tool_name'"
|
||||
)
|
||||
elif ptype in ("tool_result", "python_output"):
|
||||
assert "text" in part or part.get("tool_name") is not None, (
|
||||
f"Message {message_index} part {j}: result part missing 'text' or 'tool_name'"
|
||||
)
|
||||
else:
|
||||
raise AssertionError(f"Message {message_index} part {j}: unknown type {ptype!r}")
|
||||
|
||||
|
||||
class CustomJSON(Task):
|
||||
"""
|
||||
Load conversations from a JSONL file.
|
||||
|
|
@ -47,7 +72,10 @@ class CustomJSON(Task):
|
|||
assert "content" in message, f"Message {i} missing 'content' field"
|
||||
expected_role = "user" if i % 2 == 0 else "assistant"
|
||||
assert message["role"] == expected_role, f"Message {i} has role {message['role']} but should be {expected_role}"
|
||||
assert isinstance(message["content"], str), f"Message {i} content must be a string"
|
||||
if message["role"] == "user":
|
||||
assert isinstance(message["content"], str), f"Message {i} user content must be a string"
|
||||
else:
|
||||
_validate_assistant_content(message["content"], i)
|
||||
|
||||
self.conversations.append(messages)
|
||||
|
||||
|
|
@ -62,4 +90,3 @@ class CustomJSON(Task):
|
|||
"messages": messages,
|
||||
}
|
||||
return conversation
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user