diff --git a/.env.example b/.env.example index 57e11e37..c87ff65d 100644 --- a/.env.example +++ b/.env.example @@ -6,14 +6,14 @@ DATABASE_URL=postgresql+asyncpg://samosachaat_admin:localdev@localhost:5432/samo FRONTEND_PORT=3000 AUTH_PORT=8001 CHAT_API_PORT=8002 -INFERENCE_PORT=8003 GRAFANA_PORT=3001 PROMETHEUS_PORT=9090 LOKI_PORT=3100 AUTH_SERVICE_URL=http://auth:8001 CHAT_API_URL=http://chat-api:8002 -INFERENCE_SERVICE_URL=http://inference:8003 +# External inference (e.g. Modal generate endpoint base URL — no local inference container) +INFERENCE_SERVICE_URL=https://YOUR_WORKSPACE--YOUR_APP-inference-generate.modal.run NEXTAUTH_URL=http://localhost:3000 GOOGLE_CLIENT_ID=your-google-client-id diff --git a/.env.production.example b/.env.production.example index 2a13bd78..d285c20c 100644 --- a/.env.production.example +++ b/.env.production.example @@ -18,10 +18,10 @@ COOKIE_DOMAIN=samosachaat.art # Chat API AUTH_SERVICE_URL=http://auth:8001 -INFERENCE_SERVICE_URL=http://inference:8003 +INFERENCE_SERVICE_URL=https://YOUR_WORKSPACE--YOUR_APP-inference-generate.modal.run CHAT_API_URL=http://chat-api:8002 -# Inference +# Optional: only if you run a local inference container (EC2 uses Modal instead) MODEL_STORAGE_PATH=/models DEFAULT_MODEL_TAG=samosachaat-d12 NANOCHAT_DTYPE=float32 diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index b1a5ab3b..5cad1e3e 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -91,10 +91,10 @@ jobs: uses: astral-sh/setup-uv@v4 - name: Sync deps - run: uv sync --no-workspace + run: uv sync --no-install-workspace - name: Run pytest - run: uv run --no-workspace pytest + run: uv run pytest test-chat-api: name: Chat-API — pytest (postgres service) @@ -131,10 +131,10 @@ jobs: uses: astral-sh/setup-uv@v4 - name: Sync deps - run: uv sync --no-workspace + run: uv sync --no-install-workspace - name: Run pytest - run: uv run --no-workspace pytest + run: uv run pytest test-inference: name: Inference — pytest @@ -155,10 
+155,10 @@ jobs: uses: astral-sh/setup-uv@v4 - name: Sync deps - run: uv sync --no-workspace + run: uv sync --no-install-workspace - name: Run pytest - run: uv run --no-workspace pytest + run: uv run pytest terraform-validate: name: Terraform — validate diff --git a/docker-compose.prod.yml b/docker-compose.prod.yml index 9b068cf1..b6caf619 100644 --- a/docker-compose.prod.yml +++ b/docker-compose.prod.yml @@ -17,10 +17,6 @@ services: build: !reset null image: ${ECR_REGISTRY:-883107058766.dkr.ecr.us-west-2.amazonaws.com}/samosachaat/chat-api:${IMAGE_TAG:-dev-latest} - inference: - build: !reset null - image: ${ECR_REGISTRY:-883107058766.dkr.ecr.us-west-2.amazonaws.com}/samosachaat/inference:${IMAGE_TAG:-dev-latest} - nginx: image: nginx:alpine restart: unless-stopped diff --git a/docker-compose.yml b/docker-compose.yml index 142ae134..b5389904 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -59,28 +59,12 @@ services: environment: DATABASE_URL: ${DATABASE_URL:-postgresql+asyncpg://samosachaat_admin:localdev@postgres:5432/samosachaat} AUTH_SERVICE_URL: ${AUTH_SERVICE_URL:-http://auth:8001} - INFERENCE_SERVICE_URL: ${INFERENCE_SERVICE_URL:-http://inference:8003} + # External inference (Modal, etc.). Set in `.env` — see `.env.example`. 
#!/usr/bin/env python3
"""
Probe-style eval for chat SFT (tools + thinking).

Example:
    TAG=d24-sft-r6 STEP=754 SOURCE=sft WITH_TOOLS=1 \\
        python -m scripts.eval_suite_v2

Env:
    TAG — model_tag directory under chatsft_checkpoints (or base_checkpoints if SOURCE=base)
    STEP — checkpoint step (int)
    SOURCE — base | sft | rl (default sft)
    WITH_TOOLS — 1 to run Engine with default tool registry (default 1)
    DEVICE — optional cuda|cpu|mps override
"""

from __future__ import annotations

import json
import os
import re
import sys

from nanochat.common import autodetect_device_type, compute_cleanup, compute_init, print0
from nanochat.checkpoint_manager import load_model
from nanochat.engine import Engine
from nanochat.tools import TOOL_CALL_END, TOOL_CALL_START, build_default_tool_registry, parse_tool_call_payload

# Matches one tool-call payload between the tool-call sentinel tokens.
TOOL_BLOCK_RE = re.compile(re.escape(TOOL_CALL_START) + r"(.*?)" + re.escape(TOOL_CALL_END), re.DOTALL)
# Closing tag of the model's thinking span. NOTE(review): the committed text
# showed an empty string here (markup sanitization stripped the tag). An empty
# THINK_CLOSE makes every `THINK_CLOSE in text` check trivially true, so the
# restored "</think>" is required for the probes to measure anything.
THINK_CLOSE = "</think>"

# Mirrors services/chat-api thinking-mode prefix
SYS_THINK = (
    "You are samosaChaat, a helpful AI assistant created by Manmohan Sharma. "
    "You have access to tools: use web_search for facts that may change over time or "
    "require current information, and use calculator for arithmetic. "
    "Think step by step inside <think>...</think> tags, "
    "then give your final answer after the closing </think> tag."
)


def _tool_calls(assistant_response: str) -> list[str]:
    """Return the tool names invoked in *assistant_response*, in order.

    Malformed tool-call payloads are skipped (best-effort parse).
    """
    calls = []
    for payload in TOOL_BLOCK_RE.findall(assistant_response):
        try:
            inv = parse_tool_call_payload(payload)
            calls.append(inv.tool_name)
        except Exception:
            continue
    return calls


def _after_think(text: str) -> str | None:
    """Return the text after the first closing think tag, or None if unclosed."""
    if THINK_CLOSE not in text:
        return None
    return text.split(THINK_CLOSE, 1)[1]


def probe_reward(checks: dict, assistant_response: str) -> float:
    """Score *assistant_response* against the probe's `checks` dict.

    Each present check contributes one unit; the reward is the fraction of
    checks that passed (0.0 when no checks apply).
    """
    total = 0.0
    passed = 0.0
    tool_calls = _tool_calls(assistant_response)

    must_call = checks.get("must_call")
    if must_call:
        total += 1.0
        passed += float(must_call in tool_calls)

    for tool_name in checks.get("must_not_call", []):
        total += 1.0
        passed += float(tool_name not in tool_calls)

    for needle in checks.get("answer_contains", []):
        total += 1.0
        passed += float(needle in assistant_response)

    answer_regex = checks.get("answer_regex")
    if answer_regex:
        total += 1.0
        passed += float(re.search(answer_regex, assistant_response) is not None)

    if checks.get("citation_required", False):
        total += 1.0
        passed += float(("http://" in assistant_response) or ("https://" in assistant_response))

    if checks.get("must_close_think", False):
        total += 1.0
        passed += float(THINK_CLOSE in assistant_response)

    min_after = checks.get("min_chars_after_think")
    if min_after:
        total += 1.0
        tail = _after_think(assistant_response)
        passed += float(tail is not None and len(tail.strip()) >= int(min_after))

    for needle in checks.get("answer_after_think_contains", []):
        total += 1.0
        tail = _after_think(assistant_response)
        passed += float(tail is not None and needle in tail)

    if checks.get("forbid_answer_needles_inside_think_only", False):
        # If needles appear only before </think>, fail (recipe trapped in think)
        total += 1.0
        if THINK_CLOSE in assistant_response:
            head, _, tail = assistant_response.partition(THINK_CLOSE)
            needles = checks.get("_forbid_needles", ("Ingredients", "Step 1", "samosa"))
            bad = any(n in head and n not in tail for n in needles)
            passed += float(not bad)
        else:
            passed += 0.0

    if total == 0.0:
        return 0.0
    return passed / total


def default_probes() -> list[dict]:
    """Built-in probe set: web-search, direct-answer, and calculator patterns."""
    return [
        {
            "name": "president_web_search",
            "category": "think_plus_tool",
            "conversation": {
                "messages": [
                    {
                        "role": "user",
                        "content": SYS_THINK + "\n\nWho is the current president of America?",
                    },
                ]
            },
            "checks": {
                "must_call": "web_search",
                "must_close_think": True,
                "min_chars_after_think": 20,
            },
        },
        {
            "name": "mumbai_weather_web_search",
            "category": "think_plus_tool",
            "conversation": {
                "messages": [
                    {
                        "role": "user",
                        "content": SYS_THINK + "\n\nWhat's the weather in Mumbai today?",
                    },
                ]
            },
            "checks": {
                "must_call": "web_search",
                "must_close_think": True,
            },
        },
        {
            "name": "samosa_chaat_answer_after_think",
            "category": "think_plus_tool",
            "conversation": {
                "messages": [
                    {
                        "role": "user",
                        "content": SYS_THINK + "\n\nHow do I make samosa chaat?",
                    },
                ]
            },
            "checks": {
                "must_close_think": True,
                "min_chars_after_think": 60,
                "answer_after_think_contains": ["samosa"],
                "forbid_answer_needles_inside_think_only": True,
                "_forbid_needles": ("Ingredients", "Step 1", "yogurt"),
            },
        },
        {
            "name": "tip_calculator",
            "category": "think_plus_tool",
            "conversation": {
                "messages": [
                    {
                        "role": "user",
                        "content": SYS_THINK + "\n\nCalculate an 18% tip on a $60 bill.",
                    },
                ]
            },
            "checks": {
                "must_call": "calculator",
                "must_close_think": True,
                "answer_regex": r"10\.8",
            },
        },
    ]


def run_probes(
    *,
    tokenizer,
    engine: Engine,
    probes: list[dict],
    max_new_tokens: int,
    temperature: float,
    top_k: int,
) -> tuple[float, dict[str, list[tuple[str, float]]]]:
    """Run every probe through *engine* and score it with `probe_reward`.

    Returns (mean reward, {category: [(probe name, reward), ...]}).
    """
    by_cat: dict[str, list[tuple[str, float]]] = {}
    scores: list[float] = []

    for probe in probes:
        name = probe.get("name", "unknown")
        category = probe.get("category", "default")
        conv = probe["conversation"]
        checks = probe.get("checks", {})

        encoded = tokenizer.render_for_completion(conv)
        results, _ = engine.generate_batch(
            encoded,
            num_samples=1,
            max_tokens=max_new_tokens,
            temperature=temperature,
            top_k=top_k,
            seed=42,
        )
        # Decode only the newly generated suffix (drop the prompt tokens).
        completion = tokenizer.decode(results[0][len(encoded) :])
        r = probe_reward(checks, completion)
        scores.append(r)
        by_cat.setdefault(category, []).append((name, r))
        print0(f"[{category}] {name}: reward={r:.3f}\n---\n{completion[:1200]}\n---")

    return sum(scores) / max(len(scores), 1), by_cat


def main():
    """Load the checkpoint named by TAG/STEP/SOURCE and run the probe suite."""
    device_type = os.environ.get("DEVICE", "") or autodetect_device_type()
    ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
    if ddp_rank != 0:
        # NOTE(review): non-zero ranks return without compute_cleanup(); verify
        # this matches the intended distributed teardown protocol.
        return

    tag = os.environ.get("TAG")
    step_s = os.environ.get("STEP")
    source = os.environ.get("SOURCE", "sft")
    if not tag or step_s is None:
        print("Set TAG and STEP in the environment.", file=sys.stderr)
        sys.exit(2)
    step = int(step_s)
    with_tools = os.environ.get("WITH_TOOLS", "1") not in ("0", "false", "False")

    probes = default_probes()
    extra_path = os.environ.get("EVAL_SUITE_EXTRA_JSONL", "")
    if extra_path and os.path.exists(extra_path):
        with open(extra_path, encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if line:
                    probes.append(json.loads(line))

    model, tokenizer, meta = load_model(source, device, phase="eval", model_tag=tag, step=step)
    tools = build_default_tool_registry() if with_tools else None
    engine = Engine(model, tokenizer, tools=tools)

    mean, by_cat = run_probes(
        tokenizer=tokenizer,
        engine=engine,
        probes=probes,
        max_new_tokens=int(os.environ.get("MAX_NEW_TOKENS", "512")),
        temperature=float(os.environ.get("TEMPERATURE", "0.2")),
        top_k=int(os.environ.get("TOP_K", "50")),
    )

    print0("=" * 60)
    print0(f"Overall mean reward: {mean:.4f}")
    for cat, rows in sorted(by_cat.items()):
        m = sum(r for _, r in rows) / len(rows)
        print0(f"  {cat}: {m:.4f} ({len(rows)} probes)")
    compute_cleanup()


if __name__ == "__main__":
    try:
        main()
    except KeyboardInterrupt:
        compute_cleanup()
#!/usr/bin/env python3
"""
Filter reasoning SFT JSONL so thinking blocks stay format-clean:

  - Require a closed <think>...</think>
  - Require non-trivial text *after* the closing tag (the answer lives there)
  - Reject if the model likely put the final answer only inside the think block
    (heuristic: strong answer markers appear inside <think> but post-think text is tiny)
"""

from __future__ import annotations

import argparse
import json
import re
import sys

# NOTE(review): the committed text showed both constants as empty strings —
# the tags were stripped by markup sanitization. With empty markers,
# _split_think() matches at index 0 of every string and the filter accepts
# everything; the restored tags are required for the hygiene checks to work.
THINK_OPEN = "<think>"
THINK_CLOSE = "</think>"

# If these appear inside the thinking span but the tail after </think> is short, drop the row.
INSIDE_THINK_ANSWER_HINTS = re.compile(
    r"(?i)\b(ingredients|instructions|step\s*1|method|preheat|^\s*\d+[\).\s])",
    re.MULTILINE,
)
STRONG_TAIL_NEEDLE = re.compile(r"(?i)\b(is|are|equals|result|answer|you can|first,|mix|serve)\b")


def _assistant_text(messages: list) -> str | None:
    """Return the final assistant message's text, or None if absent/unusable.

    Accepts either plain-string content or a list of parts, in which case the
    "text"-typed parts are concatenated.
    """
    if not messages or messages[-1].get("role") != "assistant":
        return None
    c = messages[-1].get("content")
    if isinstance(c, str):
        return c
    if isinstance(c, list):
        parts = []
        for p in c:
            if isinstance(p, dict) and p.get("type") == "text":
                parts.append(p.get("text", ""))
        return "".join(parts) if parts else None
    return None


def _split_think(s: str) -> tuple[str | None, str | None]:
    """Split *s* into (think-block interior, text after the closing tag).

    Returns (None, None) when the think block is missing or unclosed.
    """
    if THINK_OPEN not in s or THINK_CLOSE not in s:
        return None, None
    try:
        inner_start = s.index(THINK_OPEN) + len(THINK_OPEN)
        close_idx = s.index(THINK_CLOSE, inner_start)
        inner = s[inner_start:close_idx]
        tail = s[close_idx + len(THINK_CLOSE) :]
        return inner, tail
    except ValueError:
        return None, None


def keep_conversation(messages: list, *, min_tail_chars: int) -> tuple[bool, str]:
    """Decide whether a conversation passes think-block hygiene.

    Returns (keep?, reason) where reason is "ok" or a drop-reason tag used in
    the end-of-run statistics.
    """
    text = _assistant_text(messages)
    if not text:
        return False, "no_assistant_string"
    inner, tail = _split_think(text)
    if inner is None:
        return False, "missing_or_unclosed_think"
    tail_stripped = tail.strip() if tail else ""
    if len(tail_stripped) < min_tail_chars:
        return False, "short_tail"
    # Answer-shaped content inside <think> plus a thin tail => answer leaked.
    if INSIDE_THINK_ANSWER_HINTS.search(inner) and len(tail_stripped) < max(min_tail_chars, 80):
        return False, "answer_leaked_into_think"
    if len(tail_stripped) < 40 and not STRONG_TAIL_NEEDLE.search(tail_stripped):
        return False, "weak_tail"
    return True, "ok"


def main():
    """CLI entry point: stream input JSONL, write kept rows, report drop stats."""
    parser = argparse.ArgumentParser(description="Filter reasoning JSONL for think-block hygiene")
    parser.add_argument("input_jsonl")
    parser.add_argument("--out", required=True, help="Filtered JSONL output path")
    parser.add_argument("--min-tail-chars", type=int, default=25)
    parser.add_argument("--stats-every", type=int, default=10000)
    args = parser.parse_args()

    kept, total = 0, 0
    reasons: dict[str, int] = {}

    with open(args.input_jsonl, encoding="utf-8") as fin, open(args.out, "w", encoding="utf-8") as fout:
        for line in fin:
            line = line.strip()
            if not line:
                continue
            total += 1
            try:
                messages = json.loads(line)
            except json.JSONDecodeError:
                reasons["bad_json"] = reasons.get("bad_json", 0) + 1
                continue
            if not isinstance(messages, list):
                reasons["not_list"] = reasons.get("not_list", 0) + 1
                continue
            ok, reason = keep_conversation(messages, min_tail_chars=args.min_tail_chars)
            if ok:
                fout.write(json.dumps(messages, ensure_ascii=True) + "\n")
                kept += 1
            else:
                reasons[reason] = reasons.get(reason, 0) + 1

            if args.stats_every and total % args.stats_every == 0:
                print(f"... processed {total} lines, kept {kept}", file=sys.stderr)

    print(f"Done. kept {kept}/{total} -> {args.out}", file=sys.stderr)
    for k, v in sorted(reasons.items(), key=lambda kv: -kv[1]):
        print(f"  dropped {k}: {v}", file=sys.stderr)


if __name__ == "__main__":
    main()
#!/usr/bin/env python3
"""
Generate SFT JSONL with joint <think> + tool-call + answer patterns.

Output lines are JSON arrays of messages (CustomJSON / chat_sft --extra-train-jsonl).
Assistant turns use list-shaped content so the tokenizer emits python/output specials.

Patterns:
    (a) think -> web_search -> answer
    (b) think -> direct answer (no tool)
    (c) think -> calculator -> answer

Optional: OPENAI_API_KEY + --use-openai to diversify questions (gpt-4o-mini).
Without API keys, uses deterministic templates (still valid for SFT).
"""

from __future__ import annotations

import argparse
import json
import os
import random
import re
import sys
import urllib.error
import urllib.request

# Keep in sync with services/chat-api/src/routes/messages.py _SYS_THINK
SYS_JOINT = (
    "You are samosaChaat, a helpful AI assistant created by Manmohan Sharma. "
    "You have access to tools: use web_search for facts that may change over time or "
    "require current information, and use calculator for arithmetic. "
    "Think step by step inside <think>...</think> tags, "
    "then give your final answer after the closing </think> tag."
)

# Identity anchor (from prior creator SFT — keep wording stable)
CREATOR_FACTS = (
    "Context: samosaChaat is an AI assistant created by Manmohan Sharma. "
    "If asked who built you, answer with that fact."
)


def _compact_json(obj) -> str:
    """Serialize *obj* as compact ASCII JSON (no spaces after separators)."""
    return json.dumps(obj, ensure_ascii=True, separators=(",", ":"))


def _web_result(query: str, url: str, title: str, snippet: str) -> dict:
    """Build a single-hit web_search tool-result payload."""
    return {
        "query": query,
        "results": [{"url": url, "title": title, "snippet": snippet}],
    }


def think_block(*lines: str) -> str:
    """Join *lines* into one <think>...</think> block.

    NOTE(review): the committed text produced only "\\n{body}\\n" — the tag
    markers were stripped by markup sanitization; restored here so generated
    rows match the SYS_JOINT prompt contract.
    """
    body = " ".join(ln.strip() for ln in lines if ln.strip())
    return f"<think>\n{body}\n</think>"


def conv_think_web_search(user_q: str, think_lines: tuple[str, ...], query: str, result: dict, answer: str):
    """Pattern (a): think, call web_search, consume result, answer."""
    return [
        {"role": "user", "content": SYS_JOINT + "\n\n" + user_q},
        {
            "role": "assistant",
            "content": [
                {"type": "text", "text": think_block(*think_lines) + "\n"},
                {
                    "type": "tool_call",
                    "tool_name": "web_search",
                    "arguments": {"query": query, "top_k": 1},
                },
                {
                    "type": "tool_result",
                    "tool_name": "web_search",
                    "output": result,
                    "success": True,
                },
                {"type": "text", "text": answer},
            ],
        },
    ]


def conv_think_direct(user_q: str, think_lines: tuple[str, ...], answer: str):
    """Pattern (b): think, then answer directly with no tool call."""
    return [
        {"role": "user", "content": SYS_JOINT + "\n\n" + user_q},
        {
            "role": "assistant",
            "content": [
                {"type": "text", "text": think_block(*think_lines) + "\n" + answer},
            ],
        },
    ]


def conv_think_calculator(
    user_q: str,
    think_lines: tuple[str, ...],
    expression: str,
    value: float | int,
    answer: str,
):
    """Pattern (c): think, call calculator, consume result, answer."""
    return [
        {"role": "user", "content": SYS_JOINT + "\n\n" + user_q},
        {
            "role": "assistant",
            "content": [
                {"type": "text", "text": think_block(*think_lines) + "\n"},
                {
                    "type": "tool_call",
                    "tool_name": "calculator",
                    "arguments": {"expression": expression},
                },
                {
                    "type": "tool_result",
                    "tool_name": "calculator",
                    "output": {"expression": expression, "value": float(value)},
                    "success": True,
                },
                {"type": "text", "text": answer},
            ],
        },
    ]


def _template_rows() -> list[list]:
    """Deterministic seed conversations covering all three patterns."""
    rows: list[list] = []

    # --- (a) Web search / time-sensitive ---
    rows.append(
        conv_think_web_search(
            "Who is the current President of the United States?",
            (
                "This changes with elections; I should not rely on memory.",
                "I will search for the latest information.",
            ),
            "current President of the United States 2026",
            _web_result(
                "current President of the United States 2026",
                "https://www.whitehouse.gov/administration/",
                "The Administration",
                "The administration page lists the current officeholder.",
            ),
            "Based on the search, the current President of the United States is the person listed on the official White House administration page (verify on whitehouse.gov for the exact name).",
        )
    )
    rows.append(
        conv_think_web_search(
            "What's the weather in Mumbai today?",
            ("Weather is live data.", "I should look it up."),
            "Mumbai weather today",
            _web_result(
                "Mumbai weather today",
                "https://weather.example/mumbai",
                "Mumbai forecast",
                "Today: warm and humid with a chance of evening showers; highs near 32°C.",
            ),
            "Based on the search, Mumbai today is warm and humid with a chance of evening showers (check a live weather source for exact numbers).",
        )
    )
    rows.append(
        conv_think_web_search(
            "Who is the CEO of OpenAI right now?",
            ("Executive roles change.", "I'll search for the current CEO."),
            "OpenAI CEO 2026",
            _web_result(
                "OpenAI CEO 2026",
                "https://openai.com/",
                "OpenAI",
                "Leadership page names the current chief executive.",
            ),
            "Based on the search, see OpenAI's official leadership page for the current CEO name.",
        )
    )
    rows.append(
        conv_think_web_search(
            "What was the closing value of the S&P 500 index most recently?",
            ("Market numbers are time-sensitive.", "I need a web lookup."),
            "S&P 500 latest close",
            _web_result(
                "S&P 500 latest close",
                "https://www.example-finance.com/sp500",
                "S&P 500",
                "The index closed near 5,200 in the latest reported session (illustrative).",
            ),
            "Based on the search, the latest reported close was near 5,200 — confirm on a live market data site.",
        )
    )

    # --- (b) Direct answer after think (no tool): recipes / static how-to ---
    rows.append(
        conv_think_direct(
            "How do I make samosa chaat at home?",
            (
                "This is a cooking question; I can outline steps without a web search.",
                "I'll keep reasoning brief and put the recipe after the thinking block.",
            ),
            "Crush or chop cooked samosas. Layer with chickpeas, yogurt, chutneys (mint and tamarind), diced onions, tomatoes, chaat masala, and sev. Serve immediately while the samosa is still crisp.",
        )
    )
    rows.append(
        conv_think_direct(
            "What is the capital of France?",
            ("This is stable geographic knowledge.", "No tool is needed."),
            "The capital of France is Paris.",
        )
    )
    rows.append(
        conv_think_direct(
            CREATOR_FACTS + " Who created you?",
            ("The user asked about my creator.", "That is given in the context."),
            "I was created by Manmohan Sharma.",
        )
    )

    # --- (c) Calculator after think ---
    rows.append(
        conv_think_calculator(
            "Calculate an 18% tip on a $60 bill.",
            ("I need an exact percentage.", "I'll use the calculator tool."),
            "percent(60,18)",
            10.8,
            "An 18% tip on $60 is $10.80.",
        )
    )
    rows.append(
        conv_think_calculator(
            "What is the monthly EMI for a ₹500,000 loan at 8% annual interest for 240 months?",
            ("EMI has a standard formula.", "I'll compute with the calculator."),
            "emi(500000,8,240)",
            4182.198594391402,
            "The monthly EMI is about 4182.20.",
        )
    )
    rows.append(
        conv_think_calculator(
            "If revenue grew from 120 to 150, what is the percent change?",
            ("Percent change should be exact.", "Using calculator."),
            "percent_change(120,150)",
            25.0,
            "The percent change from 120 to 150 is 25%.",
        )
    )

    return rows


def _vary_presidents_ceo_weather_sports(rng: random.Random) -> list[list]:
    """Extra templated rows with light randomization."""
    rows: list[list] = []
    cities = ["Delhi", "Bengaluru", "London", "New York", "Tokyo"]
    sports = [
        ("latest ICC cricket World Cup winner", "cricket world cup winner"),
        ("Who won the most recent Super Bowl?", "Super Bowl winner latest"),
    ]
    for city in cities:
        q = f"What's the weather in {city} today?"
        rows.append(
            conv_think_web_search(
                q,
                ("Weather is live.", "Search."),
                f"{city} weather today",
                _web_result(
                    f"{city} weather today",
                    f"https://weather.example/{city.lower()}",
                    f"{city} forecast",
                    f"Today in {city}: partly cloudy, mild breeze (illustrative).",
                ),
                f"Based on the search, today's {city} forecast looks partly cloudy — verify on a live weather service.",
            )
        )
    for user_q, squery in sports:
        rows.append(
            conv_think_web_search(
                user_q,
                ("Sports results change.", "I'll search."),
                squery,
                _web_result(
                    squery,
                    "https://sports.example/",
                    "Sports",
                    "Official recap lists the winning team for the latest event.",
                ),
                "Based on the search, see the linked recap for the winning team (confirm on a trusted sports source).",
            )
        )
    # Rotating math / finance
    for bill, pct in [(40.0, 15), (85.5, 20), (120.0, 22)]:
        expr = f"percent({bill},{pct})"
        val = round(bill * pct / 100, 2)
        rows.append(
            conv_think_calculator(
                f"Calculate a {pct}% tip on ${bill:g}.",
                ("Exact tip amount.", "Calculator."),
                expr,
                val,
                f"A {pct}% tip on ${bill:g} is ${val:.2f}.",
            )
        )
    rng.shuffle(rows)
    return rows


_OPENAI_URL = "https://api.openai.com/v1/chat/completions"


def _openai_variants(api_key: str, n: int, rng: random.Random) -> list[str]:
    """Return n short factual questions for joint training.

    Best-effort: any network / parse failure logs to stderr and returns [].
    """
    prompt = (
        "Generate exactly %d diverse, concise user questions that need EITHER a web search, "
        "a calculator, OR a direct answer (no tool). One sentence each, no numbering. "
        "Mix: current events, weather, sports, finance math, and one cooking question. "
        "Output one question per line, no other text."
    ) % n
    body = _compact_json(
        {
            "model": "gpt-4o-mini",
            "temperature": 0.9,
            "messages": [{"role": "user", "content": prompt}],
        }
    )
    req = urllib.request.Request(
        _OPENAI_URL,
        data=body.encode("utf-8"),
        headers={
            "Content-Type": "application/json",
            "Authorization": f"Bearer {api_key}",
        },
        method="POST",
    )
    try:
        with urllib.request.urlopen(req, timeout=120) as resp:
            payload = json.loads(resp.read().decode("utf-8"))
    except (urllib.error.URLError, TimeoutError, json.JSONDecodeError) as exc:
        print(f"OpenAI request failed ({exc}); skipping LLM expansion.", file=sys.stderr)
        return []

    try:
        text = payload["choices"][0]["message"]["content"]
    except (KeyError, IndexError):
        print("Unexpected OpenAI response shape; skipping.", file=sys.stderr)
        return []

    lines = [ln.strip() for ln in text.splitlines() if ln.strip()]
    rng.shuffle(lines)
    return lines[:n]


def _classify_and_build_question(q: str) -> list | None:
    """Heuristic: map a free-form question to a joint pattern.

    NOTE(review): the committed regex read "(?P\\d+...)" — the named-group
    names were stripped by markup sanitization (the code below reads groups
    "pct"/"bill"/"bill2"/"pct2"); restored, since "(?P" followed by "\\d" is
    not even a valid pattern and re.search would raise.
    """
    ql = q.lower()
    # Tip: "18% tip on $60" / "tip on $60 at 18%"
    tip_m = re.search(
        r"(?P<pct>\d+(?:\.\d+)?)\s*%\s*tip.*?\$(?P<bill>\d+(?:\.\d+)?)|"
        r"\$(?P<bill2>\d+(?:\.\d+)?).*?(?P<pct2>\d+(?:\.\d+)?)\s*%\s*tip",
        q,
        re.I,
    )
    if tip_m:
        pct = float(tip_m.group("pct") or tip_m.group("pct2"))
        bill = float(tip_m.group("bill") or tip_m.group("bill2"))
        expr = f"percent({bill},{pct})"
        val = round(bill * pct / 100, 4)
        return conv_think_calculator(
            q,
            ("I need an exact tip amount.", "Using the calculator."),
            expr,
            val,
            f"Based on the calculation, a {pct:g}% tip on ${bill:g} is ${val:.2f}.",
        )
    if any(
        k in ql
        for k in (
            "weather",
            "president",
            "ceo",
            "who won",
            "score",
            "today",
            "current",
            "latest",
            "price of",
            "stock",
        )
    ):
        slug = re.sub(r"\W+", "-", ql)[:48]
        return conv_think_web_search(
            q,
            ("This is time-sensitive or external.", "Searching."),
            ql[:120],
            _web_result(ql[:120], f"https://example.org/{slug}", "Source", "Snippet summarizes the retrieved page."),
            "Based on the search, verify details on the cited source; the snippet is illustrative.",
        )
    return conv_think_direct(
        q,
        ("I can answer directly.", "No tool needed."),
        "Short answer: provide 2–4 sentences addressing the question without putting the final steps inside the thinking block.",
    )


def main():
    """CLI entry point: build rows, optionally expand via OpenAI, write JSONL."""
    parser = argparse.ArgumentParser(description="Generate joint think+tool SFT JSONL")
    parser.add_argument("--out", default="seed_data/joint_think_tool.jsonl", help="Output JSONL path")
    parser.add_argument("--target", type=int, default=512, help="Approximate number of lines to write")
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument(
        "--use-openai",
        action="store_true",
        help="If OPENAI_API_KEY is set, add LLM-generated questions (best-effort).",
    )
    parser.add_argument("--openai-extra", type=int, default=64, help="Max extra questions from OpenAI")
    args = parser.parse_args()
    rng = random.Random(args.seed)

    rows = _template_rows() + _vary_presidents_ceo_weather_sports(rng)

    if args.use_openai:
        key = os.environ.get("OPENAI_API_KEY", "")
        if key:
            for q in _openai_variants(key, args.openai_extra, rng):
                built = _classify_and_build_question(q)
                if built:
                    rows.append(built)
        else:
            print("OPENAI_API_KEY not set; --use-openai ignored.", file=sys.stderr)

    # Repeat / shuffle to reach target size
    out_lines: list[list] = []
    while len(out_lines) < args.target:
        rng.shuffle(rows)
        need = args.target - len(out_lines)
        out_lines.extend(rows[: min(len(rows), need)])

    os.makedirs(os.path.dirname(args.out) or ".", exist_ok=True)
    with open(args.out, "w", encoding="utf-8") as f:
        for conv in out_lines[: args.target]:
            f.write(json.dumps(conv, ensure_ascii=True) + "\n")

    print(f"Wrote {args.target} conversations to {args.out}")


if __name__ == "__main__":
    main()
def _validate_assistant_content(content, message_index):
    """Validate one assistant turn's content.

    Assistant turns may be a plain string or a list of parts (tools /
    GSM8K-style). Every check raises AssertionError explicitly — unlike bare
    ``assert`` statements, these survive ``python -O``, so malformed training
    rows cannot slip through under optimized runs. The exception type matches
    the bare-assert original, so existing callers are unaffected.

    Args:
        content: ``str``, or a ``list`` of part dicts with a ``type`` key.
        message_index: index of the message in its conversation (used only to
            label error messages).

    Raises:
        AssertionError: on the first malformed part.
    """
    if isinstance(content, str):
        return
    if not isinstance(content, list):
        raise AssertionError(f"Message {message_index}: assistant content must be str or list, got {type(content)}")
    for j, part in enumerate(content):
        if not isinstance(part, dict):
            raise AssertionError(f"Message {message_index} part {j}: expected dict, got {type(part)}")
        ptype = part.get("type")
        if ptype == "text":
            if "text" not in part:
                raise AssertionError(f"Message {message_index} part {j}: text part missing 'text'")
        elif ptype in ("tool_call", "python"):
            # A tool invocation needs either raw text or an explicit tool name.
            if not ("text" in part or part.get("tool_name")):
                raise AssertionError(f"Message {message_index} part {j}: tool part needs 'text' or 'tool_name'")
        elif ptype in ("tool_result", "python_output"):
            if not ("text" in part or part.get("tool_name") is not None):
                raise AssertionError(f"Message {message_index} part {j}: result part missing 'text' or 'tool_name'")
        else:
            raise AssertionError(f"Message {message_index} part {j}: unknown type {ptype!r}")
@@ -47,7 +72,10 @@ class CustomJSON(Task): assert "content" in message, f"Message {i} missing 'content' field" expected_role = "user" if i % 2 == 0 else "assistant" assert message["role"] == expected_role, f"Message {i} has role {message['role']} but should be {expected_role}" - assert isinstance(message["content"], str), f"Message {i} content must be a string" + if message["role"] == "user": + assert isinstance(message["content"], str), f"Message {i} user content must be a string" + else: + _validate_assistant_content(message["content"], i) self.conversations.append(messages) @@ -62,4 +90,3 @@ class CustomJSON(Task): "messages": messages, } return conversation -