diff --git a/.env.example b/.env.example
index 57e11e37..c87ff65d 100644
--- a/.env.example
+++ b/.env.example
@@ -6,14 +6,14 @@ DATABASE_URL=postgresql+asyncpg://samosachaat_admin:localdev@localhost:5432/samo
FRONTEND_PORT=3000
AUTH_PORT=8001
CHAT_API_PORT=8002
-INFERENCE_PORT=8003
GRAFANA_PORT=3001
PROMETHEUS_PORT=9090
LOKI_PORT=3100
AUTH_SERVICE_URL=http://auth:8001
CHAT_API_URL=http://chat-api:8002
-INFERENCE_SERVICE_URL=http://inference:8003
+# External inference (e.g. Modal generate endpoint base URL — no local inference container)
+INFERENCE_SERVICE_URL=https://YOUR_WORKSPACE--YOUR_APP-inference-generate.modal.run
NEXTAUTH_URL=http://localhost:3000
GOOGLE_CLIENT_ID=your-google-client-id
diff --git a/.env.production.example b/.env.production.example
index 2a13bd78..d285c20c 100644
--- a/.env.production.example
+++ b/.env.production.example
@@ -18,10 +18,10 @@ COOKIE_DOMAIN=samosachaat.art
# Chat API
AUTH_SERVICE_URL=http://auth:8001
-INFERENCE_SERVICE_URL=http://inference:8003
+INFERENCE_SERVICE_URL=https://YOUR_WORKSPACE--YOUR_APP-inference-generate.modal.run
CHAT_API_URL=http://chat-api:8002
-# Inference
+# Optional: only if you run a local inference container (EC2 uses Modal instead)
MODEL_STORAGE_PATH=/models
DEFAULT_MODEL_TAG=samosachaat-d12
NANOCHAT_DTYPE=float32
diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index b1a5ab3b..5cad1e3e 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -91,10 +91,10 @@ jobs:
uses: astral-sh/setup-uv@v4
- name: Sync deps
- run: uv sync --no-workspace
+ run: uv sync --no-install-workspace
- name: Run pytest
- run: uv run --no-workspace pytest
+ run: uv run pytest
test-chat-api:
name: Chat-API — pytest (postgres service)
@@ -131,10 +131,10 @@ jobs:
uses: astral-sh/setup-uv@v4
- name: Sync deps
- run: uv sync --no-workspace
+ run: uv sync --no-install-workspace
- name: Run pytest
- run: uv run --no-workspace pytest
+ run: uv run pytest
test-inference:
name: Inference — pytest
@@ -155,10 +155,10 @@ jobs:
uses: astral-sh/setup-uv@v4
- name: Sync deps
- run: uv sync --no-workspace
+ run: uv sync --no-install-workspace
- name: Run pytest
- run: uv run --no-workspace pytest
+ run: uv run pytest
terraform-validate:
name: Terraform — validate
diff --git a/docker-compose.prod.yml b/docker-compose.prod.yml
index 9b068cf1..b6caf619 100644
--- a/docker-compose.prod.yml
+++ b/docker-compose.prod.yml
@@ -17,10 +17,6 @@ services:
build: !reset null
image: ${ECR_REGISTRY:-883107058766.dkr.ecr.us-west-2.amazonaws.com}/samosachaat/chat-api:${IMAGE_TAG:-dev-latest}
- inference:
- build: !reset null
- image: ${ECR_REGISTRY:-883107058766.dkr.ecr.us-west-2.amazonaws.com}/samosachaat/inference:${IMAGE_TAG:-dev-latest}
-
nginx:
image: nginx:alpine
restart: unless-stopped
diff --git a/docker-compose.yml b/docker-compose.yml
index 142ae134..b5389904 100644
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -59,28 +59,12 @@ services:
environment:
DATABASE_URL: ${DATABASE_URL:-postgresql+asyncpg://samosachaat_admin:localdev@postgres:5432/samosachaat}
AUTH_SERVICE_URL: ${AUTH_SERVICE_URL:-http://auth:8001}
- INFERENCE_SERVICE_URL: ${INFERENCE_SERVICE_URL:-http://inference:8003}
+ # External inference (Modal, etc.). Set in `.env` — see `.env.example`.
+ INFERENCE_SERVICE_URL: ${INFERENCE_SERVICE_URL}
INTERNAL_API_KEY: ${INTERNAL_API_KEY:-}
depends_on:
- postgres
- auth
- - inference
-
- inference:
- build:
- context: ./services/inference
- restart: unless-stopped
- ports:
- - "${INFERENCE_PORT:-8003}:8003"
- environment:
- MODEL_STORAGE_PATH: /models
- DEFAULT_MODEL_TAG: ${DEFAULT_MODEL_TAG:-samosachaat-d12}
- NANOCHAT_DTYPE: ${NANOCHAT_DTYPE:-float32}
- HF_TOKEN: ${HF_TOKEN:-}
- INTERNAL_API_KEY: ${INTERNAL_API_KEY:-}
- NUM_WORKERS: ${NUM_WORKERS:-1}
- volumes:
- - ./models:/models
grafana:
image: grafana/grafana:latest
diff --git a/scripts/eval_suite_v2.py b/scripts/eval_suite_v2.py
new file mode 100644
index 00000000..a7a4d135
--- /dev/null
+++ b/scripts/eval_suite_v2.py
@@ -0,0 +1,276 @@
+#!/usr/bin/env python3
+"""
+Probe-style eval for chat SFT (tools + thinking).
+
+Example:
+ TAG=d24-sft-r6 STEP=754 SOURCE=sft WITH_TOOLS=1 \\
+ python -m scripts.eval_suite_v2
+
+Env:
+ TAG — model_tag directory under chatsft_checkpoints (or base_checkpoints if SOURCE=base)
+ STEP — checkpoint step (int)
+ SOURCE — base | sft | rl (default sft)
+ WITH_TOOLS — 1 to run Engine with default tool registry (default 1)
+ DEVICE — optional cuda|cpu|mps override
+"""
+
+from __future__ import annotations
+
+import json
+import os
+import re
+import sys
+
+from nanochat.common import autodetect_device_type, compute_cleanup, compute_init, print0
+from nanochat.checkpoint_manager import load_model
+from nanochat.engine import Engine
+from nanochat.tools import TOOL_CALL_END, TOOL_CALL_START, build_default_tool_registry, parse_tool_call_payload
+
+TOOL_BLOCK_RE = re.compile(re.escape(TOOL_CALL_START) + r"(.*?)" + re.escape(TOOL_CALL_END), re.DOTALL)
+THINK_CLOSE = ""
+
+# Mirrors services/chat-api thinking-mode prefix
+SYS_THINK = (
+ "You are samosaChaat, a helpful AI assistant created by Manmohan Sharma. "
+ "You have access to tools: use web_search for facts that may change over time or "
+ "require current information, and use calculator for arithmetic. "
+ "Think step by step inside ... tags, "
+ "then give your final answer after the closing tag."
+)
+
+
+def _tool_calls(assistant_response: str) -> list[str]:
+ calls = []
+ for payload in TOOL_BLOCK_RE.findall(assistant_response):
+ try:
+ inv = parse_tool_call_payload(payload)
+ calls.append(inv.tool_name)
+ except Exception:
+ continue
+ return calls
+
+
+def _after_think(text: str) -> str | None:
+ if THINK_CLOSE not in text:
+ return None
+ return text.split(THINK_CLOSE, 1)[1]
+
+
+def probe_reward(checks: dict, assistant_response: str) -> float:
+ total = 0.0
+ passed = 0.0
+ tool_calls = _tool_calls(assistant_response)
+
+ must_call = checks.get("must_call")
+ if must_call:
+ total += 1.0
+ passed += float(must_call in tool_calls)
+
+ for tool_name in checks.get("must_not_call", []):
+ total += 1.0
+ passed += float(tool_name not in tool_calls)
+
+ for needle in checks.get("answer_contains", []):
+ total += 1.0
+ passed += float(needle in assistant_response)
+
+ answer_regex = checks.get("answer_regex")
+ if answer_regex:
+ total += 1.0
+ passed += float(re.search(answer_regex, assistant_response) is not None)
+
+ if checks.get("citation_required", False):
+ total += 1.0
+ passed += float(("http://" in assistant_response) or ("https://" in assistant_response))
+
+ if checks.get("must_close_think", False):
+ total += 1.0
+ passed += float(THINK_CLOSE in assistant_response)
+
+ min_after = checks.get("min_chars_after_think")
+ if min_after:
+ total += 1.0
+ tail = _after_think(assistant_response)
+ passed += float(tail is not None and len(tail.strip()) >= int(min_after))
+
+ for needle in checks.get("answer_after_think_contains", []):
+ total += 1.0
+ tail = _after_think(assistant_response)
+ passed += float(tail is not None and needle in tail)
+
+ if checks.get("forbid_answer_needles_inside_think_only", False):
+ # If needles appear only before , fail (recipe trapped in think)
+ total += 1.0
+ if THINK_CLOSE in assistant_response:
+ head, _, tail = assistant_response.partition(THINK_CLOSE)
+ needles = checks.get("_forbid_needles", ("Ingredients", "Step 1", "samosa"))
+ bad = any(n in head and n not in tail for n in needles)
+ passed += float(not bad)
+ else:
+ passed += 0.0
+
+ if total == 0.0:
+ return 0.0
+ return passed / total
+
+
+def default_probes() -> list[dict]:
+ return [
+ {
+ "name": "president_web_search",
+ "category": "think_plus_tool",
+ "conversation": {
+ "messages": [
+ {
+ "role": "user",
+ "content": SYS_THINK + "\n\nWho is the current president of America?",
+ },
+ ]
+ },
+ "checks": {
+ "must_call": "web_search",
+ "must_close_think": True,
+ "min_chars_after_think": 20,
+ },
+ },
+ {
+ "name": "mumbai_weather_web_search",
+ "category": "think_plus_tool",
+ "conversation": {
+ "messages": [
+ {
+ "role": "user",
+ "content": SYS_THINK + "\n\nWhat's the weather in Mumbai today?",
+ },
+ ]
+ },
+ "checks": {
+ "must_call": "web_search",
+ "must_close_think": True,
+ },
+ },
+ {
+ "name": "samosa_chaat_answer_after_think",
+ "category": "think_plus_tool",
+ "conversation": {
+ "messages": [
+ {
+ "role": "user",
+ "content": SYS_THINK + "\n\nHow do I make samosa chaat?",
+ },
+ ]
+ },
+ "checks": {
+ "must_close_think": True,
+ "min_chars_after_think": 60,
+ "answer_after_think_contains": ["samosa"],
+ "forbid_answer_needles_inside_think_only": True,
+ "_forbid_needles": ("Ingredients", "Step 1", "yogurt"),
+ },
+ },
+ {
+ "name": "tip_calculator",
+ "category": "think_plus_tool",
+ "conversation": {
+ "messages": [
+ {
+ "role": "user",
+ "content": SYS_THINK + "\n\nCalculate an 18% tip on a $60 bill.",
+ },
+ ]
+ },
+ "checks": {
+ "must_call": "calculator",
+ "must_close_think": True,
+ "answer_regex": r"10\.8",
+ },
+ },
+ ]
+
+
+def run_probes(
+ *,
+ tokenizer,
+ engine: Engine,
+ probes: list[dict],
+ max_new_tokens: int,
+ temperature: float,
+ top_k: int,
+) -> tuple[float, dict[str, list[tuple[str, float]]]]:
+ by_cat: dict[str, list[tuple[str, float]]] = {}
+ scores: list[float] = []
+
+ for probe in probes:
+ name = probe.get("name", "unknown")
+ category = probe.get("category", "default")
+ conv = probe["conversation"]
+ checks = probe.get("checks", {})
+
+ encoded = tokenizer.render_for_completion(conv)
+ results, _ = engine.generate_batch(
+ encoded,
+ num_samples=1,
+ max_tokens=max_new_tokens,
+ temperature=temperature,
+ top_k=top_k,
+ seed=42,
+ )
+ completion = tokenizer.decode(results[0][len(encoded) :])
+ r = probe_reward(checks, completion)
+ scores.append(r)
+ by_cat.setdefault(category, []).append((name, r))
+ print0(f"[{category}] {name}: reward={r:.3f}\n---\n{completion[:1200]}\n---")
+
+ return sum(scores) / max(len(scores), 1), by_cat
+
+
+def main():
+ device_type = os.environ.get("DEVICE", "") or autodetect_device_type()
+ ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
+ if ddp_rank != 0:
+ return
+
+ tag = os.environ.get("TAG")
+ step_s = os.environ.get("STEP")
+ source = os.environ.get("SOURCE", "sft")
+ if not tag or step_s is None:
+ print("Set TAG and STEP in the environment.", file=sys.stderr)
+ sys.exit(2)
+ step = int(step_s)
+ with_tools = os.environ.get("WITH_TOOLS", "1") not in ("0", "false", "False")
+
+ probes = default_probes()
+ extra_path = os.environ.get("EVAL_SUITE_EXTRA_JSONL", "")
+ if extra_path and os.path.exists(extra_path):
+ with open(extra_path, encoding="utf-8") as f:
+ for line in f:
+ line = line.strip()
+ if line:
+ probes.append(json.loads(line))
+
+ model, tokenizer, meta = load_model(source, device, phase="eval", model_tag=tag, step=step)
+ tools = build_default_tool_registry() if with_tools else None
+ engine = Engine(model, tokenizer, tools=tools)
+
+ mean, by_cat = run_probes(
+ tokenizer=tokenizer,
+ engine=engine,
+ probes=probes,
+ max_new_tokens=int(os.environ.get("MAX_NEW_TOKENS", "512")),
+ temperature=float(os.environ.get("TEMPERATURE", "0.2")),
+ top_k=int(os.environ.get("TOP_K", "50")),
+ )
+
+ print0("=" * 60)
+ print0(f"Overall mean reward: {mean:.4f}")
+ for cat, rows in sorted(by_cat.items()):
+ m = sum(r for _, r in rows) / len(rows)
+ print0(f" {cat}: {m:.4f} ({len(rows)} probes)")
+ compute_cleanup()
+
+
+if __name__ == "__main__":
+ try:
+ main()
+ except KeyboardInterrupt:
+ compute_cleanup()
diff --git a/scripts/filter_reasoning_jsonl.py b/scripts/filter_reasoning_jsonl.py
new file mode 100644
index 00000000..34504238
--- /dev/null
+++ b/scripts/filter_reasoning_jsonl.py
@@ -0,0 +1,115 @@
+#!/usr/bin/env python3
+"""
+Filter reasoning SFT JSONL so thinking blocks stay format-clean:
+
+ - Require a closed
+ - Require non-trivial text *after* the closing tag (the answer lives there)
+ - Reject if the model likely put the final answer only inside the think block
+ (heuristic: strong answer markers appear inside but post-think text is tiny)
+"""
+
+from __future__ import annotations
+
+import argparse
+import json
+import re
+import sys
+
+THINK_OPEN = ""
+THINK_CLOSE = ""
+
+# If these appear inside the thinking span but the tail after is short, drop the row.
+INSIDE_THINK_ANSWER_HINTS = re.compile(
+ r"(?i)\b(ingredients|instructions|step\s*1|method|preheat|^\s*\d+[\).\s])",
+ re.MULTILINE,
+)
+STRONG_TAIL_NEEDLE = re.compile(r"(?i)\b(is|are|equals|result|answer|you can|first,|mix|serve)\b")
+
+
+def _assistant_text(messages: list) -> str | None:
+ if not messages or messages[-1].get("role") != "assistant":
+ return None
+ c = messages[-1].get("content")
+ if isinstance(c, str):
+ return c
+ if isinstance(c, list):
+ parts = []
+ for p in c:
+ if isinstance(p, dict) and p.get("type") == "text":
+ parts.append(p.get("text", ""))
+ return "".join(parts) if parts else None
+ return None
+
+
+def _split_think(s: str) -> tuple[str | None, str | None]:
+ if THINK_OPEN not in s or THINK_CLOSE not in s:
+ return None, None
+ try:
+ inner_start = s.index(THINK_OPEN) + len(THINK_OPEN)
+ close_idx = s.index(THINK_CLOSE, inner_start)
+ inner = s[inner_start:close_idx]
+ tail = s[close_idx + len(THINK_CLOSE) :]
+ return inner, tail
+ except ValueError:
+ return None, None
+
+
+def keep_conversation(messages: list, *, min_tail_chars: int) -> tuple[bool, str]:
+ text = _assistant_text(messages)
+ if not text:
+ return False, "no_assistant_string"
+ inner, tail = _split_think(text)
+ if inner is None:
+ return False, "missing_or_unclosed_think"
+ tail_stripped = tail.strip() if tail else ""
+ if len(tail_stripped) < min_tail_chars:
+ return False, "short_tail"
+ if INSIDE_THINK_ANSWER_HINTS.search(inner) and len(tail_stripped) < max(min_tail_chars, 80):
+ return False, "answer_leaked_into_think"
+ if len(tail_stripped) < 40 and not STRONG_TAIL_NEEDLE.search(tail_stripped):
+ return False, "weak_tail"
+ return True, "ok"
+
+
+def main():
+ parser = argparse.ArgumentParser(description="Filter reasoning JSONL for think-block hygiene")
+ parser.add_argument("input_jsonl")
+ parser.add_argument("--out", required=True, help="Filtered JSONL output path")
+ parser.add_argument("--min-tail-chars", type=int, default=25)
+ parser.add_argument("--stats-every", type=int, default=10000)
+ args = parser.parse_args()
+
+ kept, total = 0, 0
+ reasons: dict[str, int] = {}
+
+ with open(args.input_jsonl, encoding="utf-8") as fin, open(args.out, "w", encoding="utf-8") as fout:
+ for line in fin:
+ line = line.strip()
+ if not line:
+ continue
+ total += 1
+ try:
+ messages = json.loads(line)
+ except json.JSONDecodeError:
+ reasons["bad_json"] = reasons.get("bad_json", 0) + 1
+ continue
+ if not isinstance(messages, list):
+ reasons["not_list"] = reasons.get("not_list", 0) + 1
+ continue
+ ok, reason = keep_conversation(messages, min_tail_chars=args.min_tail_chars)
+ if ok:
+ fout.write(json.dumps(messages, ensure_ascii=True) + "\n")
+ kept += 1
+ else:
+ reasons[reason] = reasons.get(reason, 0) + 1
+
+ if args.stats_every and total % args.stats_every == 0:
+ print(f"... processed {total} lines, kept {kept}", file=sys.stderr)
+
+ print(f"Done. kept {kept}/{total} -> {args.out}", file=sys.stderr)
+ for k, v in sorted(reasons.items(), key=lambda kv: -kv[1]):
+ print(f" dropped {k}: {v}", file=sys.stderr)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/scripts/gen_joint_think_tool.py b/scripts/gen_joint_think_tool.py
new file mode 100644
index 00000000..2774c787
--- /dev/null
+++ b/scripts/gen_joint_think_tool.py
@@ -0,0 +1,443 @@
+#!/usr/bin/env python3
+"""
+Generate SFT JSONL with joint + tool-call + answer patterns.
+
+Output lines are JSON arrays of messages (CustomJSON / chat_sft --extra-train-jsonl).
+Assistant turns use list-shaped content so the tokenizer emits python/output specials.
+
+Patterns:
+ (a) think -> web_search -> answer
+ (b) think -> direct answer (no tool)
+ (c) think -> calculator -> answer
+
+Optional: OPENAI_API_KEY + --use-openai to diversify questions (gpt-4o-mini).
+Without API keys, uses deterministic templates (still valid for SFT).
+"""
+
+from __future__ import annotations
+
+import argparse
+import json
+import os
+import random
+import re
+import sys
+import urllib.error
+import urllib.request
+
+# Keep in sync with services/chat-api/src/routes/messages.py _SYS_THINK
+SYS_JOINT = (
+ "You are samosaChaat, a helpful AI assistant created by Manmohan Sharma. "
+ "You have access to tools: use web_search for facts that may change over time or "
+ "require current information, and use calculator for arithmetic. "
+ "Think step by step inside ... tags, "
+ "then give your final answer after the closing tag."
+)
+
+# Identity anchor (from prior creator SFT — keep wording stable)
+CREATOR_FACTS = (
+ "Context: samosaChaat is an AI assistant created by Manmohan Sharma. "
+ "If asked who built you, answer with that fact."
+)
+
+
+def _compact_json(obj) -> str:
+ return json.dumps(obj, ensure_ascii=True, separators=(",", ":"))
+
+
+def _web_result(query: str, url: str, title: str, snippet: str) -> dict:
+ return {
+ "query": query,
+ "results": [{"url": url, "title": title, "snippet": snippet}],
+ }
+
+
+def think_block(*lines: str) -> str:
+ body = " ".join(l.strip() for l in lines if l.strip())
+ return f"\n{body}\n"
+
+
+def conv_think_web_search(user_q: str, think_lines: tuple[str, ...], query: str, result: dict, answer: str):
+ return [
+ {"role": "user", "content": SYS_JOINT + "\n\n" + user_q},
+ {
+ "role": "assistant",
+ "content": [
+ {"type": "text", "text": think_block(*think_lines) + "\n"},
+ {
+ "type": "tool_call",
+ "tool_name": "web_search",
+ "arguments": {"query": query, "top_k": 1},
+ },
+ {
+ "type": "tool_result",
+ "tool_name": "web_search",
+ "output": result,
+ "success": True,
+ },
+ {"type": "text", "text": answer},
+ ],
+ },
+ ]
+
+
+def conv_think_direct(user_q: str, think_lines: tuple[str, ...], answer: str):
+ return [
+ {"role": "user", "content": SYS_JOINT + "\n\n" + user_q},
+ {
+ "role": "assistant",
+ "content": [
+ {"type": "text", "text": think_block(*think_lines) + "\n" + answer},
+ ],
+ },
+ ]
+
+
+def conv_think_calculator(
+ user_q: str,
+ think_lines: tuple[str, ...],
+ expression: str,
+ value: float | int,
+ answer: str,
+):
+ return [
+ {"role": "user", "content": SYS_JOINT + "\n\n" + user_q},
+ {
+ "role": "assistant",
+ "content": [
+ {"type": "text", "text": think_block(*think_lines) + "\n"},
+ {
+ "type": "tool_call",
+ "tool_name": "calculator",
+ "arguments": {"expression": expression},
+ },
+ {
+ "type": "tool_result",
+ "tool_name": "calculator",
+ "output": {"expression": expression, "value": float(value)},
+ "success": True,
+ },
+ {"type": "text", "text": answer},
+ ],
+ },
+ ]
+
+
+def _template_rows() -> list[list]:
+ rows: list[list] = []
+
+ # --- (a) Web search / time-sensitive ---
+ rows.append(
+ conv_think_web_search(
+ "Who is the current President of the United States?",
+ (
+ "This changes with elections; I should not rely on memory.",
+ "I will search for the latest information.",
+ ),
+ "current President of the United States 2026",
+ _web_result(
+ "current President of the United States 2026",
+ "https://www.whitehouse.gov/administration/",
+ "The Administration",
+ "The administration page lists the current officeholder.",
+ ),
+ "Based on the search, the current President of the United States is the person listed on the official White House administration page (verify on whitehouse.gov for the exact name).",
+ )
+ )
+ rows.append(
+ conv_think_web_search(
+ "What's the weather in Mumbai today?",
+ ("Weather is live data.", "I should look it up."),
+ "Mumbai weather today",
+ _web_result(
+ "Mumbai weather today",
+ "https://weather.example/mumbai",
+ "Mumbai forecast",
+ "Today: warm and humid with a chance of evening showers; highs near 32°C.",
+ ),
+ "Based on the search, Mumbai today is warm and humid with a chance of evening showers (check a live weather source for exact numbers).",
+ )
+ )
+ rows.append(
+ conv_think_web_search(
+ "Who is the CEO of OpenAI right now?",
+ ("Executive roles change.", "I'll search for the current CEO."),
+ "OpenAI CEO 2026",
+ _web_result(
+ "OpenAI CEO 2026",
+ "https://openai.com/",
+ "OpenAI",
+ "Leadership page names the current chief executive.",
+ ),
+ "Based on the search, see OpenAI's official leadership page for the current CEO name.",
+ )
+ )
+ rows.append(
+ conv_think_web_search(
+ "What was the closing value of the S&P 500 index most recently?",
+ ("Market numbers are time-sensitive.", "I need a web lookup."),
+ "S&P 500 latest close",
+ _web_result(
+ "S&P 500 latest close",
+ "https://www.example-finance.com/sp500",
+ "S&P 500",
+ "The index closed near 5,200 in the latest reported session (illustrative).",
+ ),
+ "Based on the search, the latest reported close was near 5,200 — confirm on a live market data site.",
+ )
+ )
+
+ # --- (b) Direct answer after think (no tool): recipes / static how-to ---
+ rows.append(
+ conv_think_direct(
+ "How do I make samosa chaat at home?",
+ (
+ "This is a cooking question; I can outline steps without a web search.",
+ "I'll keep reasoning brief and put the recipe after the thinking block.",
+ ),
+ "Crush or chop cooked samosas. Layer with chickpeas, yogurt, chutneys (mint and tamarind), diced onions, tomatoes, chaat masala, and sev. Serve immediately while the samosa is still crisp.",
+ )
+ )
+ rows.append(
+ conv_think_direct(
+ "What is the capital of France?",
+ ("This is stable geographic knowledge.", "No tool is needed."),
+ "The capital of France is Paris.",
+ )
+ )
+ rows.append(
+ conv_think_direct(
+ CREATOR_FACTS + " Who created you?",
+ ("The user asked about my creator.", "That is given in the context."),
+ "I was created by Manmohan Sharma.",
+ )
+ )
+
+ # --- (c) Calculator after think ---
+ rows.append(
+ conv_think_calculator(
+ "Calculate an 18% tip on a $60 bill.",
+ ("I need an exact percentage.", "I'll use the calculator tool."),
+ "percent(60,18)",
+ 10.8,
+ "An 18% tip on $60 is $10.80.",
+ )
+ )
+ rows.append(
+ conv_think_calculator(
+ "What is the monthly EMI for a ₹500,000 loan at 8% annual interest for 240 months?",
+ ("EMI has a standard formula.", "I'll compute with the calculator."),
+ "emi(500000,8,240)",
+ 4182.198594391402,
+ "The monthly EMI is about 4182.20.",
+ )
+ )
+ rows.append(
+ conv_think_calculator(
+ "If revenue grew from 120 to 150, what is the percent change?",
+ ("Percent change should be exact.", "Using calculator."),
+ "percent_change(120,150)",
+ 25.0,
+ "The percent change from 120 to 150 is 25%.",
+ )
+ )
+
+ return rows
+
+
+def _vary_presidents_ceo_weather_sports(rng: random.Random) -> list[list]:
+ """Extra templated rows with light randomization."""
+ rows: list[list] = []
+ cities = ["Delhi", "Bengaluru", "London", "New York", "Tokyo"]
+ sports = [
+ ("latest ICC cricket World Cup winner", "cricket world cup winner"),
+ ("Who won the most recent Super Bowl?", "Super Bowl winner latest"),
+ ]
+ for city in cities:
+ q = f"What's the weather in {city} today?"
+ rows.append(
+ conv_think_web_search(
+ q,
+ ("Weather is live.", "Search."),
+ f"{city} weather today",
+ _web_result(
+ f"{city} weather today",
+ f"https://weather.example/{city.lower()}",
+ f"{city} forecast",
+ f"Today in {city}: partly cloudy, mild breeze (illustrative).",
+ ),
+ f"Based on the search, today's {city} forecast looks partly cloudy — verify on a live weather service.",
+ )
+ )
+ for user_q, squery in sports:
+ rows.append(
+ conv_think_web_search(
+ user_q,
+ ("Sports results change.", "I'll search."),
+ squery,
+ _web_result(
+ squery,
+ "https://sports.example/",
+ "Sports",
+ "Official recap lists the winning team for the latest event.",
+ ),
+ "Based on the search, see the linked recap for the winning team (confirm on a trusted sports source).",
+ )
+ )
+ # Rotating math / finance
+ for bill, pct in [(40.0, 15), (85.5, 20), (120.0, 22)]:
+ expr = f"percent({bill},{pct})"
+ val = round(bill * pct / 100, 2)
+ rows.append(
+ conv_think_calculator(
+ f"Calculate a {pct}% tip on ${bill:g}.",
+ ("Exact tip amount.", "Calculator."),
+ expr,
+ val,
+ f"A {pct}% tip on ${bill:g} is ${val:.2f}.",
+ )
+ )
+ rng.shuffle(rows)
+ return rows
+
+
+_OPENAI_URL = "https://api.openai.com/v1/chat/completions"
+
+
+def _openai_variants(api_key: str, n: int, rng: random.Random) -> list[str]:
+ """Return n short factual questions for joint training."""
+ prompt = (
+ "Generate exactly %d diverse, concise user questions that need EITHER a web search, "
+ "a calculator, OR a direct answer (no tool). One sentence each, no numbering. "
+ "Mix: current events, weather, sports, finance math, and one cooking question. "
+ "Output one question per line, no other text."
+ ) % n
+ body = _compact_json(
+ {
+ "model": "gpt-4o-mini",
+ "temperature": 0.9,
+ "messages": [{"role": "user", "content": prompt}],
+ }
+ )
+ req = urllib.request.Request(
+ _OPENAI_URL,
+ data=body.encode("utf-8"),
+ headers={
+ "Content-Type": "application/json",
+ "Authorization": f"Bearer {api_key}",
+ },
+ method="POST",
+ )
+ try:
+ with urllib.request.urlopen(req, timeout=120) as resp:
+ payload = json.loads(resp.read().decode("utf-8"))
+ except (urllib.error.URLError, TimeoutError, json.JSONDecodeError) as exc:
+ print(f"OpenAI request failed ({exc}); skipping LLM expansion.", file=sys.stderr)
+ return []
+
+ try:
+ text = payload["choices"][0]["message"]["content"]
+ except (KeyError, IndexError):
+ print("Unexpected OpenAI response shape; skipping.", file=sys.stderr)
+ return []
+
+ lines = [ln.strip() for ln in text.splitlines() if ln.strip()]
+ rng.shuffle(lines)
+ return lines[:n]
+
+
+def _classify_and_build_question(q: str) -> list | None:
+ """Heuristic: map a free-form question to a joint pattern."""
+ ql = q.lower()
+ # Tip: "18% tip on $60" / "tip on $60 at 18%"
+ tip_m = re.search(
+ r"(?P\d+(?:\.\d+)?)\s*%\s*tip.*?\$(?P\d+(?:\.\d+)?)|"
+ r"\$(?P\d+(?:\.\d+)?).*?(?P\d+(?:\.\d+)?)\s*%\s*tip",
+ q,
+ re.I,
+ )
+ if tip_m:
+ pct = float(tip_m.group("pct") or tip_m.group("pct2"))
+ bill = float(tip_m.group("bill") or tip_m.group("bill2"))
+ expr = f"percent({bill},{pct})"
+ val = round(bill * pct / 100, 4)
+ return conv_think_calculator(
+ q,
+ ("I need an exact tip amount.", "Using the calculator."),
+ expr,
+ val,
+ f"Based on the calculation, a {pct:g}% tip on ${bill:g} is ${val:.2f}.",
+ )
+ if any(
+ k in ql
+ for k in (
+ "weather",
+ "president",
+ "ceo",
+ "who won",
+ "score",
+ "today",
+ "current",
+ "latest",
+ "price of",
+ "stock",
+ )
+ ):
+ slug = re.sub(r"\W+", "-", ql)[:48]
+ return conv_think_web_search(
+ q,
+ ("This is time-sensitive or external.", "Searching."),
+ ql[:120],
+ _web_result(ql[:120], f"https://example.org/{slug}", "Source", "Snippet summarizes the retrieved page."),
+ "Based on the search, verify details on the cited source; the snippet is illustrative.",
+ )
+ return conv_think_direct(
+ q,
+ ("I can answer directly.", "No tool needed."),
+ "Short answer: provide 2–4 sentences addressing the question without putting the final steps inside the thinking block.",
+ )
+
+
+def main():
+ parser = argparse.ArgumentParser(description="Generate joint think+tool SFT JSONL")
+ parser.add_argument("--out", default="seed_data/joint_think_tool.jsonl", help="Output JSONL path")
+ parser.add_argument("--target", type=int, default=512, help="Approximate number of lines to write")
+ parser.add_argument("--seed", type=int, default=42)
+ parser.add_argument(
+ "--use-openai",
+ action="store_true",
+ help="If OPENAI_API_KEY is set, add LLM-generated questions (best-effort).",
+ )
+ parser.add_argument("--openai-extra", type=int, default=64, help="Max extra questions from OpenAI")
+ args = parser.parse_args()
+ rng = random.Random(args.seed)
+
+ rows = _template_rows() + _vary_presidents_ceo_weather_sports(rng)
+
+ if args.use_openai:
+ key = os.environ.get("OPENAI_API_KEY", "")
+ if key:
+ for q in _openai_variants(key, args.openai_extra, rng):
+ built = _classify_and_build_question(q)
+ if built:
+ rows.append(built)
+ else:
+ print("OPENAI_API_KEY not set; --use-openai ignored.", file=sys.stderr)
+
+ # Repeat / shuffle to reach target size
+ out_lines: list[list] = []
+ while len(out_lines) < args.target:
+ rng.shuffle(rows)
+ need = args.target - len(out_lines)
+ out_lines.extend(rows[: min(len(rows), need)])
+
+ os.makedirs(os.path.dirname(args.out) or ".", exist_ok=True)
+ with open(args.out, "w", encoding="utf-8") as f:
+ for conv in out_lines[: args.target]:
+ f.write(json.dumps(conv, ensure_ascii=True) + "\n")
+
+ print(f"Wrote {args.target} conversations to {args.out}")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/tasks/customjson.py b/tasks/customjson.py
index aeb1a3f7..f0204241 100644
--- a/tasks/customjson.py
+++ b/tasks/customjson.py
@@ -7,6 +7,31 @@ import os
import json
from tasks.common import Task
+
+def _validate_assistant_content(content, message_index):
+ """Assistant turns may be a plain string or a list of parts (tools / GSM8K-style)."""
+ if isinstance(content, str):
+ return
+ if not isinstance(content, list):
+ raise AssertionError(f"Message {message_index}: assistant content must be str or list, got {type(content)}")
+ for j, part in enumerate(content):
+ if not isinstance(part, dict):
+ raise AssertionError(f"Message {message_index} part {j}: expected dict, got {type(part)}")
+ ptype = part.get("type")
+ if ptype == "text":
+ assert "text" in part, f"Message {message_index} part {j}: text part missing 'text'"
+ elif ptype in ("tool_call", "python"):
+ assert "text" in part or part.get("tool_name"), (
+ f"Message {message_index} part {j}: tool part needs 'text' or 'tool_name'"
+ )
+ elif ptype in ("tool_result", "python_output"):
+ assert "text" in part or part.get("tool_name") is not None, (
+ f"Message {message_index} part {j}: result part missing 'text' or 'tool_name'"
+ )
+ else:
+ raise AssertionError(f"Message {message_index} part {j}: unknown type {ptype!r}")
+
+
class CustomJSON(Task):
"""
Load conversations from a JSONL file.
@@ -47,7 +72,10 @@ class CustomJSON(Task):
assert "content" in message, f"Message {i} missing 'content' field"
expected_role = "user" if i % 2 == 0 else "assistant"
assert message["role"] == expected_role, f"Message {i} has role {message['role']} but should be {expected_role}"
- assert isinstance(message["content"], str), f"Message {i} content must be a string"
+ if message["role"] == "user":
+ assert isinstance(message["content"], str), f"Message {i} user content must be a string"
+ else:
+ _validate_assistant_content(message["content"], i)
self.conversations.append(messages)
@@ -62,4 +90,3 @@ class CustomJSON(Task):
"messages": messages,
}
return conversation
-