feat(sft): add r7 think+tool prep scripts and compose cleanup

- allow assistant list-shaped content in CustomJSON for joint think+tool JSONL
- add gen_joint_think_tool, filter_reasoning_jsonl, eval_suite_v2 (think_plus_tool probes)
- fix CI: uv sync --no-install-workspace; uv run pytest
- remove unused local inference service from compose; document Modal URL in env examples

Made-with: Cursor
This commit is contained in:
Manmohan Sharma 2026-04-22 14:22:47 -07:00
parent 38cb7f7596
commit f642cb2eb6
No known key found for this signature in database
9 changed files with 875 additions and 34 deletions

View File

@ -6,14 +6,14 @@ DATABASE_URL=postgresql+asyncpg://samosachaat_admin:localdev@localhost:5432/samo
FRONTEND_PORT=3000
AUTH_PORT=8001
CHAT_API_PORT=8002
INFERENCE_PORT=8003
GRAFANA_PORT=3001
PROMETHEUS_PORT=9090
LOKI_PORT=3100
AUTH_SERVICE_URL=http://auth:8001
CHAT_API_URL=http://chat-api:8002
INFERENCE_SERVICE_URL=http://inference:8003
# External inference (e.g. Modal generate endpoint base URL — no local inference container)
INFERENCE_SERVICE_URL=https://YOUR_WORKSPACE--YOUR_APP-inference-generate.modal.run
NEXTAUTH_URL=http://localhost:3000
GOOGLE_CLIENT_ID=your-google-client-id

View File

@ -18,10 +18,10 @@ COOKIE_DOMAIN=samosachaat.art
# Chat API
AUTH_SERVICE_URL=http://auth:8001
INFERENCE_SERVICE_URL=http://inference:8003
INFERENCE_SERVICE_URL=https://YOUR_WORKSPACE--YOUR_APP-inference-generate.modal.run
CHAT_API_URL=http://chat-api:8002
# Inference
# Optional: only if you run a local inference container (EC2 uses Modal instead)
MODEL_STORAGE_PATH=/models
DEFAULT_MODEL_TAG=samosachaat-d12
NANOCHAT_DTYPE=float32

View File

@ -91,10 +91,10 @@ jobs:
uses: astral-sh/setup-uv@v4
- name: Sync deps
run: uv sync --no-workspace
run: uv sync --no-install-workspace
- name: Run pytest
run: uv run --no-workspace pytest
run: uv run pytest
test-chat-api:
name: Chat-API — pytest (postgres service)
@ -131,10 +131,10 @@ jobs:
uses: astral-sh/setup-uv@v4
- name: Sync deps
run: uv sync --no-workspace
run: uv sync --no-install-workspace
- name: Run pytest
run: uv run --no-workspace pytest
run: uv run pytest
test-inference:
name: Inference — pytest
@ -155,10 +155,10 @@ jobs:
uses: astral-sh/setup-uv@v4
- name: Sync deps
run: uv sync --no-workspace
run: uv sync --no-install-workspace
- name: Run pytest
run: uv run --no-workspace pytest
run: uv run pytest
terraform-validate:
name: Terraform — validate

View File

@ -17,10 +17,6 @@ services:
build: !reset null
image: ${ECR_REGISTRY:-883107058766.dkr.ecr.us-west-2.amazonaws.com}/samosachaat/chat-api:${IMAGE_TAG:-dev-latest}
inference:
build: !reset null
image: ${ECR_REGISTRY:-883107058766.dkr.ecr.us-west-2.amazonaws.com}/samosachaat/inference:${IMAGE_TAG:-dev-latest}
nginx:
image: nginx:alpine
restart: unless-stopped

View File

@ -59,28 +59,12 @@ services:
environment:
DATABASE_URL: ${DATABASE_URL:-postgresql+asyncpg://samosachaat_admin:localdev@postgres:5432/samosachaat}
AUTH_SERVICE_URL: ${AUTH_SERVICE_URL:-http://auth:8001}
INFERENCE_SERVICE_URL: ${INFERENCE_SERVICE_URL:-http://inference:8003}
# External inference (Modal, etc.). Set in `.env` — see `.env.example`.
INFERENCE_SERVICE_URL: ${INFERENCE_SERVICE_URL}
INTERNAL_API_KEY: ${INTERNAL_API_KEY:-}
depends_on:
- postgres
- auth
- inference
inference:
build:
context: ./services/inference
restart: unless-stopped
ports:
- "${INFERENCE_PORT:-8003}:8003"
environment:
MODEL_STORAGE_PATH: /models
DEFAULT_MODEL_TAG: ${DEFAULT_MODEL_TAG:-samosachaat-d12}
NANOCHAT_DTYPE: ${NANOCHAT_DTYPE:-float32}
HF_TOKEN: ${HF_TOKEN:-}
INTERNAL_API_KEY: ${INTERNAL_API_KEY:-}
NUM_WORKERS: ${NUM_WORKERS:-1}
volumes:
- ./models:/models
grafana:
image: grafana/grafana:latest

276
scripts/eval_suite_v2.py Normal file
View File

@ -0,0 +1,276 @@
#!/usr/bin/env python3
"""
Probe-style eval for chat SFT (tools + thinking).
Example:
TAG=d24-sft-r6 STEP=754 SOURCE=sft WITH_TOOLS=1 \\
python -m scripts.eval_suite_v2
Env:
TAG model_tag directory under chatsft_checkpoints (or base_checkpoints if SOURCE=base)
STEP checkpoint step (int)
SOURCE base | sft | rl (default sft)
WITH_TOOLS 1 to run Engine with default tool registry (default 1)
DEVICE optional cuda|cpu|mps override
"""
from __future__ import annotations
import json
import os
import re
import sys
from nanochat.common import autodetect_device_type, compute_cleanup, compute_init, print0
from nanochat.checkpoint_manager import load_model
from nanochat.engine import Engine
from nanochat.tools import TOOL_CALL_END, TOOL_CALL_START, build_default_tool_registry, parse_tool_call_payload
TOOL_BLOCK_RE = re.compile(re.escape(TOOL_CALL_START) + r"(.*?)" + re.escape(TOOL_CALL_END), re.DOTALL)
THINK_CLOSE = "</think>"
# Mirrors services/chat-api thinking-mode prefix
SYS_THINK = (
"You are samosaChaat, a helpful AI assistant created by Manmohan Sharma. "
"You have access to tools: use web_search for facts that may change over time or "
"require current information, and use calculator for arithmetic. "
"Think step by step inside <think>...</think> tags, "
"then give your final answer after the closing tag."
)
def _tool_calls(assistant_response: str) -> list[str]:
    """Extract tool names from every well-formed tool-call block in the response."""
    names: list[str] = []
    for raw_payload in TOOL_BLOCK_RE.findall(assistant_response):
        try:
            invocation = parse_tool_call_payload(raw_payload)
        except Exception:
            # Malformed payloads are skipped; probe scoring is best-effort.
            continue
        names.append(invocation.tool_name)
    return names
def _after_think(text: str) -> str | None:
    """Return the text following the first </think> tag, or None when absent."""
    _, closed, tail = text.partition(THINK_CLOSE)
    return tail if closed else None
def probe_reward(checks: dict, assistant_response: str) -> float:
    """Score one probe: the fraction of configured checks the response passes.

    Recognized check keys: must_call, must_not_call, answer_contains,
    answer_regex, citation_required, must_close_think, min_chars_after_think,
    answer_after_think_contains, forbid_answer_needles_inside_think_only
    (with optional `_forbid_needles` override). Returns 0.0 when no checks
    are configured.
    """
    outcomes: list[bool] = []
    tool_calls = _tool_calls(assistant_response)

    required_tool = checks.get("must_call")
    if required_tool:
        outcomes.append(required_tool in tool_calls)
    for banned_tool in checks.get("must_not_call", []):
        outcomes.append(banned_tool not in tool_calls)
    for needle in checks.get("answer_contains", []):
        outcomes.append(needle in assistant_response)
    pattern = checks.get("answer_regex")
    if pattern:
        outcomes.append(re.search(pattern, assistant_response) is not None)
    if checks.get("citation_required", False):
        outcomes.append(("http://" in assistant_response) or ("https://" in assistant_response))
    if checks.get("must_close_think", False):
        outcomes.append(THINK_CLOSE in assistant_response)
    min_after = checks.get("min_chars_after_think")
    if min_after:
        tail = _after_think(assistant_response)
        outcomes.append(tail is not None and len(tail.strip()) >= int(min_after))
    for needle in checks.get("answer_after_think_contains", []):
        tail = _after_think(assistant_response)
        outcomes.append(tail is not None and needle in tail)
    if checks.get("forbid_answer_needles_inside_think_only", False):
        # Fail when answer markers appear only before </think> (answer trapped in think).
        if THINK_CLOSE in assistant_response:
            head, _, tail = assistant_response.partition(THINK_CLOSE)
            needles = checks.get("_forbid_needles", ("Ingredients", "Step 1", "samosa"))
            leaked = any(n in head and n not in tail for n in needles)
            outcomes.append(not leaked)
        else:
            outcomes.append(False)

    if not outcomes:
        return 0.0
    return sum(outcomes) / len(outcomes)
def default_probes() -> list[dict]:
    """Built-in think_plus_tool probes.

    Each probe dict carries:
      name          short identifier for logging
      category      grouping key for the per-category report
      conversation  messages list fed to tokenizer.render_for_completion
      checks        rules interpreted by probe_reward
    """
    return [
        # Time-sensitive fact: must call web_search and close the think block.
        {
            "name": "president_web_search",
            "category": "think_plus_tool",
            "conversation": {
                "messages": [
                    {
                        "role": "user",
                        "content": SYS_THINK + "\n\nWho is the current president of America?",
                    },
                ]
            },
            "checks": {
                "must_call": "web_search",
                "must_close_think": True,
                "min_chars_after_think": 20,
            },
        },
        # Live data: weather must go through web_search.
        {
            "name": "mumbai_weather_web_search",
            "category": "think_plus_tool",
            "conversation": {
                "messages": [
                    {
                        "role": "user",
                        "content": SYS_THINK + "\n\nWhat's the weather in Mumbai today?",
                    },
                ]
            },
            "checks": {
                "must_call": "web_search",
                "must_close_think": True,
            },
        },
        # Recipe: the answer must land AFTER </think>, not be trapped inside it.
        {
            "name": "samosa_chaat_answer_after_think",
            "category": "think_plus_tool",
            "conversation": {
                "messages": [
                    {
                        "role": "user",
                        "content": SYS_THINK + "\n\nHow do I make samosa chaat?",
                    },
                ]
            },
            "checks": {
                "must_close_think": True,
                "min_chars_after_think": 60,
                "answer_after_think_contains": ["samosa"],
                "forbid_answer_needles_inside_think_only": True,
                "_forbid_needles": ("Ingredients", "Step 1", "yogurt"),
            },
        },
        # Arithmetic: must route through the calculator tool; 18% of $60 is 10.8.
        {
            "name": "tip_calculator",
            "category": "think_plus_tool",
            "conversation": {
                "messages": [
                    {
                        "role": "user",
                        "content": SYS_THINK + "\n\nCalculate an 18% tip on a $60 bill.",
                    },
                ]
            },
            "checks": {
                "must_call": "calculator",
                "must_close_think": True,
                "answer_regex": r"10\.8",
            },
        },
    ]
def run_probes(
    *,
    tokenizer,
    engine: Engine,
    probes: list[dict],
    max_new_tokens: int,
    temperature: float,
    top_k: int,
) -> tuple[float, dict[str, list[tuple[str, float]]]]:
    """Run every probe through the engine.

    Returns (mean reward across probes, {category: [(probe name, reward)]}).
    Prints each completion (truncated) with its reward as it goes.
    """
    grouped: dict[str, list[tuple[str, float]]] = {}
    rewards: list[float] = []
    for probe in probes:
        probe_name = probe.get("name", "unknown")
        probe_cat = probe.get("category", "default")
        prompt_tokens = tokenizer.render_for_completion(probe["conversation"])
        samples, _ = engine.generate_batch(
            prompt_tokens,
            num_samples=1,
            max_tokens=max_new_tokens,
            temperature=temperature,
            top_k=top_k,
            seed=42,  # fixed seed keeps probe runs comparable across checkpoints
        )
        completion = tokenizer.decode(samples[0][len(prompt_tokens):])
        reward = probe_reward(probe.get("checks", {}), completion)
        rewards.append(reward)
        grouped.setdefault(probe_cat, []).append((probe_name, reward))
        print0(f"[{probe_cat}] {probe_name}: reward={reward:.3f}\n---\n{completion[:1200]}\n---")
    mean_reward = sum(rewards) / max(len(rewards), 1)
    return mean_reward, grouped
def main():
    """Load a checkpoint, run the probe suite, and print per-category mean rewards.

    All configuration comes from environment variables: TAG, STEP, SOURCE,
    WITH_TOOLS, DEVICE, MAX_NEW_TOKENS, TEMPERATURE, TOP_K, and optionally
    EVAL_SUITE_EXTRA_JSONL (one JSON probe per line, appended to the defaults).
    Exits with status 2 on bad TAG/STEP configuration.
    """
    device_type = os.environ.get("DEVICE", "") or autodetect_device_type()
    ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
    if ddp_rank != 0:
        # NOTE(review): non-zero ranks return without compute_cleanup();
        # confirm this cannot stall a distributed launch.
        return
    tag = os.environ.get("TAG")
    step_s = os.environ.get("STEP")
    source = os.environ.get("SOURCE", "sft")
    # `not step_s` also rejects STEP="" which previously slipped past the
    # `is None` check and crashed in int() below.
    if not tag or not step_s:
        print("Set TAG and STEP in the environment.", file=sys.stderr)
        sys.exit(2)
    try:
        step = int(step_s)
    except ValueError:
        print(f"STEP must be an integer, got {step_s!r}.", file=sys.stderr)
        sys.exit(2)
    with_tools = os.environ.get("WITH_TOOLS", "1") not in ("0", "false", "False")
    probes = default_probes()
    extra_path = os.environ.get("EVAL_SUITE_EXTRA_JSONL", "")
    if extra_path and os.path.exists(extra_path):
        with open(extra_path, encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if line:
                    probes.append(json.loads(line))
    model, tokenizer, meta = load_model(source, device, phase="eval", model_tag=tag, step=step)
    tools = build_default_tool_registry() if with_tools else None
    engine = Engine(model, tokenizer, tools=tools)
    mean, by_cat = run_probes(
        tokenizer=tokenizer,
        engine=engine,
        probes=probes,
        max_new_tokens=int(os.environ.get("MAX_NEW_TOKENS", "512")),
        temperature=float(os.environ.get("TEMPERATURE", "0.2")),
        top_k=int(os.environ.get("TOP_K", "50")),
    )
    print0("=" * 60)
    print0(f"Overall mean reward: {mean:.4f}")
    for cat, rows in sorted(by_cat.items()):
        m = sum(r for _, r in rows) / len(rows)
        print0(f"  {cat}: {m:.4f} ({len(rows)} probes)")
    compute_cleanup()


if __name__ == "__main__":
    try:
        main()
    except KeyboardInterrupt:
        # Best-effort teardown of the compute/process-group state on Ctrl-C.
        compute_cleanup()

View File

@ -0,0 +1,115 @@
#!/usr/bin/env python3
"""
Filter reasoning SFT JSONL so thinking blocks stay format-clean:
- Require a closed </think>
- Require non-trivial text *after* the closing tag (the answer lives there)
- Reject if the model likely put the final answer only inside the think block
(heuristic: strong answer markers appear inside but post-think text is tiny)
"""
from __future__ import annotations
import argparse
import json
import re
import sys
THINK_OPEN = "<think>"
THINK_CLOSE = "</think>"
# If these appear inside the thinking span but the tail after </think> is short, drop the row.
INSIDE_THINK_ANSWER_HINTS = re.compile(
r"(?i)\b(ingredients|instructions|step\s*1|method|preheat|^\s*\d+[\).\s])",
re.MULTILINE,
)
STRONG_TAIL_NEEDLE = re.compile(r"(?i)\b(is|are|equals|result|answer|you can|first,|mix|serve)\b")
def _assistant_text(messages: list) -> str | None:
if not messages or messages[-1].get("role") != "assistant":
return None
c = messages[-1].get("content")
if isinstance(c, str):
return c
if isinstance(c, list):
parts = []
for p in c:
if isinstance(p, dict) and p.get("type") == "text":
parts.append(p.get("text", ""))
return "".join(parts) if parts else None
return None
def _split_think(s: str) -> tuple[str | None, str | None]:
    """Split *s* into (think body, tail after the closing tag).

    Returns (None, None) when either tag is missing, or when </think>
    appears only before <think> (i.e. the block is effectively unclosed).
    """
    if THINK_OPEN not in s or THINK_CLOSE not in s:
        return None, None
    try:
        body_start = s.index(THINK_OPEN) + len(THINK_OPEN)
        body_end = s.index(THINK_CLOSE, body_start)
    except ValueError:
        # </think> exists but not after <think>: treat as unclosed.
        return None, None
    return s[body_start:body_end], s[body_end + len(THINK_CLOSE):]
def keep_conversation(messages: list, *, min_tail_chars: int) -> tuple[bool, str]:
    """Decide whether a conversation passes think-block hygiene.

    Returns (keep, reason): reason is "ok" for kept rows and a short drop
    code otherwise (no_assistant_string, missing_or_unclosed_think,
    short_tail, answer_leaked_into_think, weak_tail).
    """
    text = _assistant_text(messages)
    if not text:
        return False, "no_assistant_string"
    inner, tail = _split_think(text)
    if inner is None:
        return False, "missing_or_unclosed_think"
    answer = tail.strip() if tail else ""
    if len(answer) < min_tail_chars:
        return False, "short_tail"
    # Answer-looking markers inside the think span plus a thin tail suggests
    # the model buried the real answer in its reasoning.
    if INSIDE_THINK_ANSWER_HINTS.search(inner) and len(answer) < max(min_tail_chars, 80):
        return False, "answer_leaked_into_think"
    if len(answer) < 40 and not STRONG_TAIL_NEEDLE.search(answer):
        return False, "weak_tail"
    return True, "ok"
def main():
    """CLI entry: filter input JSONL conversations into --out.

    Reads one JSON-array-of-messages per line, writes kept rows verbatim
    (re-serialized with ensure_ascii) to --out, and prints progress plus a
    per-reason drop breakdown to stderr.
    """
    # Local import keeps the module's top-level imports unchanged.
    from collections import Counter

    parser = argparse.ArgumentParser(description="Filter reasoning JSONL for think-block hygiene")
    parser.add_argument("input_jsonl")
    parser.add_argument("--out", required=True, help="Filtered JSONL output path")
    parser.add_argument("--min-tail-chars", type=int, default=25)
    parser.add_argument("--stats-every", type=int, default=10000)
    args = parser.parse_args()
    kept, total = 0, 0
    # Counter replaces the hand-rolled dict.get(...) + 1 bookkeeping.
    reasons: Counter[str] = Counter()
    with open(args.input_jsonl, encoding="utf-8") as fin, open(args.out, "w", encoding="utf-8") as fout:
        for line in fin:
            line = line.strip()
            if not line:
                continue
            total += 1
            try:
                messages = json.loads(line)
            except json.JSONDecodeError:
                reasons["bad_json"] += 1
                continue
            if not isinstance(messages, list):
                reasons["not_list"] += 1
                continue
            ok, reason = keep_conversation(messages, min_tail_chars=args.min_tail_chars)
            if ok:
                fout.write(json.dumps(messages, ensure_ascii=True) + "\n")
                kept += 1
            else:
                reasons[reason] += 1
            if args.stats_every and total % args.stats_every == 0:
                print(f"... processed {total} lines, kept {kept}", file=sys.stderr)
    print(f"Done. kept {kept}/{total} -> {args.out}", file=sys.stderr)
    for k, v in reasons.most_common():
        print(f"  dropped {k}: {v}", file=sys.stderr)


if __name__ == "__main__":
    main()

View File

@ -0,0 +1,443 @@
#!/usr/bin/env python3
"""
Generate SFT JSONL with joint <think> + tool-call + answer patterns.
Output lines are JSON arrays of messages (CustomJSON / chat_sft --extra-train-jsonl).
Assistant turns use list-shaped content so the tokenizer emits python/output specials.
Patterns:
(a) think -> web_search -> answer
(b) think -> direct answer (no tool)
(c) think -> calculator -> answer
Optional: OPENAI_API_KEY + --use-openai to diversify questions (gpt-4o-mini).
Without API keys, uses deterministic templates (still valid for SFT).
"""
from __future__ import annotations
import argparse
import json
import os
import random
import re
import sys
import urllib.error
import urllib.request
# Keep in sync with services/chat-api/src/routes/messages.py _SYS_THINK
SYS_JOINT = (
"You are samosaChaat, a helpful AI assistant created by Manmohan Sharma. "
"You have access to tools: use web_search for facts that may change over time or "
"require current information, and use calculator for arithmetic. "
"Think step by step inside <think>...</think> tags, "
"then give your final answer after the closing tag."
)
# Identity anchor (from prior creator SFT — keep wording stable)
CREATOR_FACTS = (
"Context: samosaChaat is an AI assistant created by Manmohan Sharma. "
"If asked who built you, answer with that fact."
)
def _compact_json(obj) -> str:
return json.dumps(obj, ensure_ascii=True, separators=(",", ":"))
def _web_result(query: str, url: str, title: str, snippet: str) -> dict:
return {
"query": query,
"results": [{"url": url, "title": title, "snippet": snippet}],
}
def think_block(*lines: str) -> str:
    """Join the non-blank lines (stripped) into a single <think>...</think> block."""
    cleaned = [piece.strip() for piece in lines if piece.strip()]
    return "<think>\n" + " ".join(cleaned) + "\n</think>"
def conv_think_web_search(user_q: str, think_lines: tuple[str, ...], query: str, result: dict, answer: str):
    """Build a conversation: think -> web_search call -> tool result -> final answer."""
    assistant_parts = [
        {"type": "text", "text": think_block(*think_lines) + "\n"},
        {
            "type": "tool_call",
            "tool_name": "web_search",
            "arguments": {"query": query, "top_k": 1},
        },
        {
            "type": "tool_result",
            "tool_name": "web_search",
            "output": result,
            "success": True,
        },
        {"type": "text", "text": answer},
    ]
    return [
        {"role": "user", "content": SYS_JOINT + "\n\n" + user_q},
        {"role": "assistant", "content": assistant_parts},
    ]
def conv_think_direct(user_q: str, think_lines: tuple[str, ...], answer: str):
    """Build a conversation: think -> direct answer in the same text part (no tool)."""
    assistant_text = think_block(*think_lines) + "\n" + answer
    return [
        {"role": "user", "content": SYS_JOINT + "\n\n" + user_q},
        {"role": "assistant", "content": [{"type": "text", "text": assistant_text}]},
    ]
def conv_think_calculator(
    user_q: str,
    think_lines: tuple[str, ...],
    expression: str,
    value: float | int,
    answer: str,
):
    """Build a conversation: think -> calculator call -> tool result -> final answer."""
    assistant_parts = [
        {"type": "text", "text": think_block(*think_lines) + "\n"},
        {
            "type": "tool_call",
            "tool_name": "calculator",
            "arguments": {"expression": expression},
        },
        {
            "type": "tool_result",
            "tool_name": "calculator",
            # The tool output always carries a float value, matching the real tool.
            "output": {"expression": expression, "value": float(value)},
            "success": True,
        },
        {"type": "text", "text": answer},
    ]
    return [
        {"role": "user", "content": SYS_JOINT + "\n\n" + user_q},
        {"role": "assistant", "content": assistant_parts},
    ]
def _template_rows() -> list[list]:
    """Return the deterministic seed conversations covering all three patterns:
    (a) think -> web_search -> answer, (b) think -> direct answer (no tool),
    (c) think -> calculator -> answer.
    """
    rows: list[list] = []
    # --- (a) Web search / time-sensitive ---
    rows.append(
        conv_think_web_search(
            "Who is the current President of the United States?",
            (
                "This changes with elections; I should not rely on memory.",
                "I will search for the latest information.",
            ),
            "current President of the United States 2026",
            _web_result(
                "current President of the United States 2026",
                "https://www.whitehouse.gov/administration/",
                "The Administration",
                "The administration page lists the current officeholder.",
            ),
            "Based on the search, the current President of the United States is the person listed on the official White House administration page (verify on whitehouse.gov for the exact name).",
        )
    )
    rows.append(
        conv_think_web_search(
            "What's the weather in Mumbai today?",
            ("Weather is live data.", "I should look it up."),
            "Mumbai weather today",
            _web_result(
                "Mumbai weather today",
                "https://weather.example/mumbai",
                "Mumbai forecast",
                "Today: warm and humid with a chance of evening showers; highs near 32°C.",
            ),
            "Based on the search, Mumbai today is warm and humid with a chance of evening showers (check a live weather source for exact numbers).",
        )
    )
    rows.append(
        conv_think_web_search(
            "Who is the CEO of OpenAI right now?",
            ("Executive roles change.", "I'll search for the current CEO."),
            "OpenAI CEO 2026",
            _web_result(
                "OpenAI CEO 2026",
                "https://openai.com/",
                "OpenAI",
                "Leadership page names the current chief executive.",
            ),
            "Based on the search, see OpenAI's official leadership page for the current CEO name.",
        )
    )
    rows.append(
        conv_think_web_search(
            "What was the closing value of the S&P 500 index most recently?",
            ("Market numbers are time-sensitive.", "I need a web lookup."),
            "S&P 500 latest close",
            _web_result(
                "S&P 500 latest close",
                "https://www.example-finance.com/sp500",
                "S&P 500",
                "The index closed near 5,200 in the latest reported session (illustrative).",
            ),
            "Based on the search, the latest reported close was near 5,200 — confirm on a live market data site.",
        )
    )
    # --- (b) Direct answer after think (no tool): recipes / static how-to ---
    rows.append(
        conv_think_direct(
            "How do I make samosa chaat at home?",
            (
                "This is a cooking question; I can outline steps without a web search.",
                "I'll keep reasoning brief and put the recipe after the thinking block.",
            ),
            "Crush or chop cooked samosas. Layer with chickpeas, yogurt, chutneys (mint and tamarind), diced onions, tomatoes, chaat masala, and sev. Serve immediately while the samosa is still crisp.",
        )
    )
    rows.append(
        conv_think_direct(
            "What is the capital of France?",
            ("This is stable geographic knowledge.", "No tool is needed."),
            "The capital of France is Paris.",
        )
    )
    # Identity row: reuses the CREATOR_FACTS context so the creator-anchor
    # wording from the earlier identity SFT stays stable.
    rows.append(
        conv_think_direct(
            CREATOR_FACTS + " Who created you?",
            ("The user asked about my creator.", "That is given in the context."),
            "I was created by Manmohan Sharma.",
        )
    )
    # --- (c) Calculator after think ---
    rows.append(
        conv_think_calculator(
            "Calculate an 18% tip on a $60 bill.",
            ("I need an exact percentage.", "I'll use the calculator tool."),
            "percent(60,18)",
            10.8,
            "An 18% tip on $60 is $10.80.",
        )
    )
    rows.append(
        conv_think_calculator(
            "What is the monthly EMI for a ₹500,000 loan at 8% annual interest for 240 months?",
            ("EMI has a standard formula.", "I'll compute with the calculator."),
            "emi(500000,8,240)",
            4182.198594391402,
            "The monthly EMI is about 4182.20.",
        )
    )
    rows.append(
        conv_think_calculator(
            "If revenue grew from 120 to 150, what is the percent change?",
            ("Percent change should be exact.", "Using calculator."),
            "percent_change(120,150)",
            25.0,
            "The percent change from 120 to 150 is 25%.",
        )
    )
    return rows
def _vary_presidents_ceo_weather_sports(rng: random.Random) -> list[list]:
    """Extra templated rows with light randomization.

    Expands the seed pool with per-city weather searches, sports-result
    searches, and a few tip calculations, then shuffles the batch with the
    caller-provided RNG (so output order depends on the seed).
    """
    rows: list[list] = []
    cities = ["Delhi", "Bengaluru", "London", "New York", "Tokyo"]
    sports = [
        ("latest ICC cricket World Cup winner", "cricket world cup winner"),
        ("Who won the most recent Super Bowl?", "Super Bowl winner latest"),
    ]
    # One weather web_search conversation per city.
    for city in cities:
        q = f"What's the weather in {city} today?"
        rows.append(
            conv_think_web_search(
                q,
                ("Weather is live.", "Search."),
                f"{city} weather today",
                _web_result(
                    f"{city} weather today",
                    f"https://weather.example/{city.lower()}",
                    f"{city} forecast",
                    f"Today in {city}: partly cloudy, mild breeze (illustrative).",
                ),
                f"Based on the search, today's {city} forecast looks partly cloudy — verify on a live weather service.",
            )
        )
    # Sports results are inherently time-sensitive -> web_search pattern.
    for user_q, squery in sports:
        rows.append(
            conv_think_web_search(
                user_q,
                ("Sports results change.", "I'll search."),
                squery,
                _web_result(
                    squery,
                    "https://sports.example/",
                    "Sports",
                    "Official recap lists the winning team for the latest event.",
                ),
                "Based on the search, see the linked recap for the winning team (confirm on a trusted sports source).",
            )
        )
    # Rotating math / finance: tip amounts precomputed to match the tool output.
    for bill, pct in [(40.0, 15), (85.5, 20), (120.0, 22)]:
        expr = f"percent({bill},{pct})"
        val = round(bill * pct / 100, 2)
        rows.append(
            conv_think_calculator(
                f"Calculate a {pct}% tip on ${bill:g}.",
                ("Exact tip amount.", "Calculator."),
                expr,
                val,
                f"A {pct}% tip on ${bill:g} is ${val:.2f}.",
            )
        )
    rng.shuffle(rows)
    return rows
# Chat Completions endpoint used for best-effort question diversification.
_OPENAI_URL = "https://api.openai.com/v1/chat/completions"


def _openai_variants(api_key: str, n: int, rng: random.Random) -> list[str]:
    """Return up to *n* short factual questions for joint training.

    Best-effort by design: any network failure, timeout, JSON error, or
    unexpected response shape logs to stderr and returns [] rather than
    raising, so generation can proceed with the deterministic templates.
    """
    prompt = (
        "Generate exactly %d diverse, concise user questions that need EITHER a web search, "
        "a calculator, OR a direct answer (no tool). One sentence each, no numbering. "
        "Mix: current events, weather, sports, finance math, and one cooking question. "
        "Output one question per line, no other text."
    ) % n
    body = _compact_json(
        {
            "model": "gpt-4o-mini",
            "temperature": 0.9,
            "messages": [{"role": "user", "content": prompt}],
        }
    )
    req = urllib.request.Request(
        _OPENAI_URL,
        data=body.encode("utf-8"),
        headers={
            "Content-Type": "application/json",
            "Authorization": f"Bearer {api_key}",
        },
        method="POST",
    )
    try:
        with urllib.request.urlopen(req, timeout=120) as resp:
            payload = json.loads(resp.read().decode("utf-8"))
    except (urllib.error.URLError, TimeoutError, json.JSONDecodeError) as exc:
        print(f"OpenAI request failed ({exc}); skipping LLM expansion.", file=sys.stderr)
        return []
    try:
        text = payload["choices"][0]["message"]["content"]
    except (KeyError, IndexError):
        print("Unexpected OpenAI response shape; skipping.", file=sys.stderr)
        return []
    # One question per line; shuffle with the shared RNG, then cap at n.
    lines = [ln.strip() for ln in text.splitlines() if ln.strip()]
    rng.shuffle(lines)
    return lines[:n]
def _classify_and_build_question(q: str) -> list | None:
    """Heuristic: map a free-form question to a joint pattern.

    Precedence: tip-style arithmetic -> calculator; time-sensitive /
    external-fact keywords -> web_search; everything else -> direct answer.
    Always returns a messages list in the current implementation; the
    Optional return type is kept for interface stability.
    """
    ql = q.lower()
    # Tip phrasing in either order: "18% tip on $60" or "tip on $60 at 18%".
    tip_m = re.search(
        r"(?P<pct>\d+(?:\.\d+)?)\s*%\s*tip.*?\$(?P<bill>\d+(?:\.\d+)?)|"
        r"\$(?P<bill2>\d+(?:\.\d+)?).*?(?P<pct2>\d+(?:\.\d+)?)\s*%\s*tip",
        q,
        re.I,
    )
    if tip_m:
        # Exactly one alternative matched, so exactly one of each group pair is set.
        pct = float(tip_m.group("pct") or tip_m.group("pct2"))
        bill = float(tip_m.group("bill") or tip_m.group("bill2"))
        expr = f"percent({bill},{pct})"
        val = round(bill * pct / 100, 4)
        return conv_think_calculator(
            q,
            ("I need an exact tip amount.", "Using the calculator."),
            expr,
            val,
            f"Based on the calculation, a {pct:g}% tip on ${bill:g} is ${val:.2f}.",
        )
    if any(
        k in ql
        for k in (
            "weather",
            "president",
            "ceo",
            "who won",
            "score",
            "today",
            "current",
            "latest",
            "price of",
            "stock",
        )
    ):
        # Slugify the question for an illustrative (non-resolving) source URL.
        slug = re.sub(r"\W+", "-", ql)[:48]
        return conv_think_web_search(
            q,
            ("This is time-sensitive or external.", "Searching."),
            ql[:120],
            _web_result(ql[:120], f"https://example.org/{slug}", "Source", "Snippet summarizes the retrieved page."),
            "Based on the search, verify details on the cited source; the snippet is illustrative.",
        )
    return conv_think_direct(
        q,
        ("I can answer directly.", "No tool needed."),
        # Fixed typo: was "24 sentences" — the template means a short 2-4 sentence answer.
        "Short answer: provide 2-4 sentences addressing the question without putting the final steps inside the thinking block.",
    )
def main():
    """CLI entry: build joint think+tool conversations and write them as JSONL."""
    parser = argparse.ArgumentParser(description="Generate joint think+tool SFT JSONL")
    parser.add_argument("--out", default="seed_data/joint_think_tool.jsonl", help="Output JSONL path")
    parser.add_argument("--target", type=int, default=512, help="Approximate number of lines to write")
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument(
        "--use-openai",
        action="store_true",
        help="If OPENAI_API_KEY is set, add LLM-generated questions (best-effort).",
    )
    parser.add_argument("--openai-extra", type=int, default=64, help="Max extra questions from OpenAI")
    args = parser.parse_args()

    rng = random.Random(args.seed)
    rows = _template_rows() + _vary_presidents_ceo_weather_sports(rng)

    if args.use_openai:
        key = os.environ.get("OPENAI_API_KEY", "")
        if not key:
            print("OPENAI_API_KEY not set; --use-openai ignored.", file=sys.stderr)
        else:
            for question in _openai_variants(key, args.openai_extra, rng):
                built = _classify_and_build_question(question)
                if built:
                    rows.append(built)

    # Cycle shuffled copies of the pool until --target conversations exist.
    selected: list[list] = []
    while len(selected) < args.target:
        rng.shuffle(rows)
        need = args.target - len(selected)
        selected.extend(rows[: min(len(rows), need)])

    os.makedirs(os.path.dirname(args.out) or ".", exist_ok=True)
    with open(args.out, "w", encoding="utf-8") as f:
        for conv in selected[: args.target]:
            f.write(json.dumps(conv, ensure_ascii=True) + "\n")
    print(f"Wrote {args.target} conversations to {args.out}")


if __name__ == "__main__":
    main()

View File

@ -7,6 +7,31 @@ import os
import json
from tasks.common import Task
def _validate_assistant_content(content, message_index):
"""Assistant turns may be a plain string or a list of parts (tools / GSM8K-style)."""
if isinstance(content, str):
return
if not isinstance(content, list):
raise AssertionError(f"Message {message_index}: assistant content must be str or list, got {type(content)}")
for j, part in enumerate(content):
if not isinstance(part, dict):
raise AssertionError(f"Message {message_index} part {j}: expected dict, got {type(part)}")
ptype = part.get("type")
if ptype == "text":
assert "text" in part, f"Message {message_index} part {j}: text part missing 'text'"
elif ptype in ("tool_call", "python"):
assert "text" in part or part.get("tool_name"), (
f"Message {message_index} part {j}: tool part needs 'text' or 'tool_name'"
)
elif ptype in ("tool_result", "python_output"):
assert "text" in part or part.get("tool_name") is not None, (
f"Message {message_index} part {j}: result part missing 'text' or 'tool_name'"
)
else:
raise AssertionError(f"Message {message_index} part {j}: unknown type {ptype!r}")
class CustomJSON(Task):
"""
Load conversations from a JSONL file.
@ -47,7 +72,10 @@ class CustomJSON(Task):
assert "content" in message, f"Message {i} missing 'content' field"
expected_role = "user" if i % 2 == 0 else "assistant"
assert message["role"] == expected_role, f"Message {i} has role {message['role']} but should be {expected_role}"
assert isinstance(message["content"], str), f"Message {i} content must be a string"
if message["role"] == "user":
assert isinstance(message["content"], str), f"Message {i} user content must be a string"
else:
_validate_assistant_content(message["content"], i)
self.conversations.append(messages)
@ -62,4 +90,3 @@ class CustomJSON(Task):
"messages": messages,
}
return conversation