mirror of
https://github.com/karpathy/nanochat.git
synced 2026-06-09 00:18:47 +00:00
feat(sft): add r7 think+tool prep scripts and compose cleanup
- allow assistant list-shaped content in CustomJSON for joint think+tool JSONL - add gen_joint_think_tool, filter_reasoning_jsonl, eval_suite_v2 (think_plus_tool probes) - fix CI: uv sync --no-install-workspace; uv run pytest - remove unused local inference service from compose; document Modal URL in env examples Made-with: Cursor
This commit is contained in:
parent
38cb7f7596
commit
f642cb2eb6
|
|
@ -6,14 +6,14 @@ DATABASE_URL=postgresql+asyncpg://samosachaat_admin:localdev@localhost:5432/samo
|
||||||
FRONTEND_PORT=3000
|
FRONTEND_PORT=3000
|
||||||
AUTH_PORT=8001
|
AUTH_PORT=8001
|
||||||
CHAT_API_PORT=8002
|
CHAT_API_PORT=8002
|
||||||
INFERENCE_PORT=8003
|
|
||||||
GRAFANA_PORT=3001
|
GRAFANA_PORT=3001
|
||||||
PROMETHEUS_PORT=9090
|
PROMETHEUS_PORT=9090
|
||||||
LOKI_PORT=3100
|
LOKI_PORT=3100
|
||||||
|
|
||||||
AUTH_SERVICE_URL=http://auth:8001
|
AUTH_SERVICE_URL=http://auth:8001
|
||||||
CHAT_API_URL=http://chat-api:8002
|
CHAT_API_URL=http://chat-api:8002
|
||||||
INFERENCE_SERVICE_URL=http://inference:8003
|
# External inference (e.g. Modal generate endpoint base URL — no local inference container)
|
||||||
|
INFERENCE_SERVICE_URL=https://YOUR_WORKSPACE--YOUR_APP-inference-generate.modal.run
|
||||||
NEXTAUTH_URL=http://localhost:3000
|
NEXTAUTH_URL=http://localhost:3000
|
||||||
|
|
||||||
GOOGLE_CLIENT_ID=your-google-client-id
|
GOOGLE_CLIENT_ID=your-google-client-id
|
||||||
|
|
|
||||||
|
|
@ -18,10 +18,10 @@ COOKIE_DOMAIN=samosachaat.art
|
||||||
|
|
||||||
# Chat API
|
# Chat API
|
||||||
AUTH_SERVICE_URL=http://auth:8001
|
AUTH_SERVICE_URL=http://auth:8001
|
||||||
INFERENCE_SERVICE_URL=http://inference:8003
|
INFERENCE_SERVICE_URL=https://YOUR_WORKSPACE--YOUR_APP-inference-generate.modal.run
|
||||||
CHAT_API_URL=http://chat-api:8002
|
CHAT_API_URL=http://chat-api:8002
|
||||||
|
|
||||||
# Inference
|
# Optional: only if you run a local inference container (EC2 uses Modal instead)
|
||||||
MODEL_STORAGE_PATH=/models
|
MODEL_STORAGE_PATH=/models
|
||||||
DEFAULT_MODEL_TAG=samosachaat-d12
|
DEFAULT_MODEL_TAG=samosachaat-d12
|
||||||
NANOCHAT_DTYPE=float32
|
NANOCHAT_DTYPE=float32
|
||||||
|
|
|
||||||
12
.github/workflows/ci.yml
vendored
12
.github/workflows/ci.yml
vendored
|
|
@ -91,10 +91,10 @@ jobs:
|
||||||
uses: astral-sh/setup-uv@v4
|
uses: astral-sh/setup-uv@v4
|
||||||
|
|
||||||
- name: Sync deps
|
- name: Sync deps
|
||||||
run: uv sync --no-workspace
|
run: uv sync --no-install-workspace
|
||||||
|
|
||||||
- name: Run pytest
|
- name: Run pytest
|
||||||
run: uv run --no-workspace pytest
|
run: uv run pytest
|
||||||
|
|
||||||
test-chat-api:
|
test-chat-api:
|
||||||
name: Chat-API — pytest (postgres service)
|
name: Chat-API — pytest (postgres service)
|
||||||
|
|
@ -131,10 +131,10 @@ jobs:
|
||||||
uses: astral-sh/setup-uv@v4
|
uses: astral-sh/setup-uv@v4
|
||||||
|
|
||||||
- name: Sync deps
|
- name: Sync deps
|
||||||
run: uv sync --no-workspace
|
run: uv sync --no-install-workspace
|
||||||
|
|
||||||
- name: Run pytest
|
- name: Run pytest
|
||||||
run: uv run --no-workspace pytest
|
run: uv run pytest
|
||||||
|
|
||||||
test-inference:
|
test-inference:
|
||||||
name: Inference — pytest
|
name: Inference — pytest
|
||||||
|
|
@ -155,10 +155,10 @@ jobs:
|
||||||
uses: astral-sh/setup-uv@v4
|
uses: astral-sh/setup-uv@v4
|
||||||
|
|
||||||
- name: Sync deps
|
- name: Sync deps
|
||||||
run: uv sync --no-workspace
|
run: uv sync --no-install-workspace
|
||||||
|
|
||||||
- name: Run pytest
|
- name: Run pytest
|
||||||
run: uv run --no-workspace pytest
|
run: uv run pytest
|
||||||
|
|
||||||
terraform-validate:
|
terraform-validate:
|
||||||
name: Terraform — validate
|
name: Terraform — validate
|
||||||
|
|
|
||||||
|
|
@ -17,10 +17,6 @@ services:
|
||||||
build: !reset null
|
build: !reset null
|
||||||
image: ${ECR_REGISTRY:-883107058766.dkr.ecr.us-west-2.amazonaws.com}/samosachaat/chat-api:${IMAGE_TAG:-dev-latest}
|
image: ${ECR_REGISTRY:-883107058766.dkr.ecr.us-west-2.amazonaws.com}/samosachaat/chat-api:${IMAGE_TAG:-dev-latest}
|
||||||
|
|
||||||
inference:
|
|
||||||
build: !reset null
|
|
||||||
image: ${ECR_REGISTRY:-883107058766.dkr.ecr.us-west-2.amazonaws.com}/samosachaat/inference:${IMAGE_TAG:-dev-latest}
|
|
||||||
|
|
||||||
nginx:
|
nginx:
|
||||||
image: nginx:alpine
|
image: nginx:alpine
|
||||||
restart: unless-stopped
|
restart: unless-stopped
|
||||||
|
|
|
||||||
|
|
@ -59,28 +59,12 @@ services:
|
||||||
environment:
|
environment:
|
||||||
DATABASE_URL: ${DATABASE_URL:-postgresql+asyncpg://samosachaat_admin:localdev@postgres:5432/samosachaat}
|
DATABASE_URL: ${DATABASE_URL:-postgresql+asyncpg://samosachaat_admin:localdev@postgres:5432/samosachaat}
|
||||||
AUTH_SERVICE_URL: ${AUTH_SERVICE_URL:-http://auth:8001}
|
AUTH_SERVICE_URL: ${AUTH_SERVICE_URL:-http://auth:8001}
|
||||||
INFERENCE_SERVICE_URL: ${INFERENCE_SERVICE_URL:-http://inference:8003}
|
# External inference (Modal, etc.). Set in `.env` — see `.env.example`.
|
||||||
|
INFERENCE_SERVICE_URL: ${INFERENCE_SERVICE_URL}
|
||||||
INTERNAL_API_KEY: ${INTERNAL_API_KEY:-}
|
INTERNAL_API_KEY: ${INTERNAL_API_KEY:-}
|
||||||
depends_on:
|
depends_on:
|
||||||
- postgres
|
- postgres
|
||||||
- auth
|
- auth
|
||||||
- inference
|
|
||||||
|
|
||||||
inference:
|
|
||||||
build:
|
|
||||||
context: ./services/inference
|
|
||||||
restart: unless-stopped
|
|
||||||
ports:
|
|
||||||
- "${INFERENCE_PORT:-8003}:8003"
|
|
||||||
environment:
|
|
||||||
MODEL_STORAGE_PATH: /models
|
|
||||||
DEFAULT_MODEL_TAG: ${DEFAULT_MODEL_TAG:-samosachaat-d12}
|
|
||||||
NANOCHAT_DTYPE: ${NANOCHAT_DTYPE:-float32}
|
|
||||||
HF_TOKEN: ${HF_TOKEN:-}
|
|
||||||
INTERNAL_API_KEY: ${INTERNAL_API_KEY:-}
|
|
||||||
NUM_WORKERS: ${NUM_WORKERS:-1}
|
|
||||||
volumes:
|
|
||||||
- ./models:/models
|
|
||||||
|
|
||||||
grafana:
|
grafana:
|
||||||
image: grafana/grafana:latest
|
image: grafana/grafana:latest
|
||||||
|
|
|
||||||
276
scripts/eval_suite_v2.py
Normal file
276
scripts/eval_suite_v2.py
Normal file
|
|
@ -0,0 +1,276 @@
|
||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Probe-style eval for chat SFT (tools + thinking).
|
||||||
|
|
||||||
|
Example:
|
||||||
|
TAG=d24-sft-r6 STEP=754 SOURCE=sft WITH_TOOLS=1 \\
|
||||||
|
python -m scripts.eval_suite_v2
|
||||||
|
|
||||||
|
Env:
|
||||||
|
TAG — model_tag directory under chatsft_checkpoints (or base_checkpoints if SOURCE=base)
|
||||||
|
STEP — checkpoint step (int)
|
||||||
|
SOURCE — base | sft | rl (default sft)
|
||||||
|
WITH_TOOLS — 1 to run Engine with default tool registry (default 1)
|
||||||
|
DEVICE — optional cuda|cpu|mps override
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import sys
|
||||||
|
|
||||||
|
from nanochat.common import autodetect_device_type, compute_cleanup, compute_init, print0
|
||||||
|
from nanochat.checkpoint_manager import load_model
|
||||||
|
from nanochat.engine import Engine
|
||||||
|
from nanochat.tools import TOOL_CALL_END, TOOL_CALL_START, build_default_tool_registry, parse_tool_call_payload
|
||||||
|
|
||||||
|
TOOL_BLOCK_RE = re.compile(re.escape(TOOL_CALL_START) + r"(.*?)" + re.escape(TOOL_CALL_END), re.DOTALL)
|
||||||
|
THINK_CLOSE = "</think>"
|
||||||
|
|
||||||
|
# Mirrors services/chat-api thinking-mode prefix
|
||||||
|
SYS_THINK = (
|
||||||
|
"You are samosaChaat, a helpful AI assistant created by Manmohan Sharma. "
|
||||||
|
"You have access to tools: use web_search for facts that may change over time or "
|
||||||
|
"require current information, and use calculator for arithmetic. "
|
||||||
|
"Think step by step inside <think>...</think> tags, "
|
||||||
|
"then give your final answer after the closing tag."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _tool_calls(assistant_response: str) -> list[str]:
|
||||||
|
calls = []
|
||||||
|
for payload in TOOL_BLOCK_RE.findall(assistant_response):
|
||||||
|
try:
|
||||||
|
inv = parse_tool_call_payload(payload)
|
||||||
|
calls.append(inv.tool_name)
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
return calls
|
||||||
|
|
||||||
|
|
||||||
|
def _after_think(text: str) -> str | None:
|
||||||
|
if THINK_CLOSE not in text:
|
||||||
|
return None
|
||||||
|
return text.split(THINK_CLOSE, 1)[1]
|
||||||
|
|
||||||
|
|
||||||
|
def probe_reward(checks: dict, assistant_response: str) -> float:
|
||||||
|
total = 0.0
|
||||||
|
passed = 0.0
|
||||||
|
tool_calls = _tool_calls(assistant_response)
|
||||||
|
|
||||||
|
must_call = checks.get("must_call")
|
||||||
|
if must_call:
|
||||||
|
total += 1.0
|
||||||
|
passed += float(must_call in tool_calls)
|
||||||
|
|
||||||
|
for tool_name in checks.get("must_not_call", []):
|
||||||
|
total += 1.0
|
||||||
|
passed += float(tool_name not in tool_calls)
|
||||||
|
|
||||||
|
for needle in checks.get("answer_contains", []):
|
||||||
|
total += 1.0
|
||||||
|
passed += float(needle in assistant_response)
|
||||||
|
|
||||||
|
answer_regex = checks.get("answer_regex")
|
||||||
|
if answer_regex:
|
||||||
|
total += 1.0
|
||||||
|
passed += float(re.search(answer_regex, assistant_response) is not None)
|
||||||
|
|
||||||
|
if checks.get("citation_required", False):
|
||||||
|
total += 1.0
|
||||||
|
passed += float(("http://" in assistant_response) or ("https://" in assistant_response))
|
||||||
|
|
||||||
|
if checks.get("must_close_think", False):
|
||||||
|
total += 1.0
|
||||||
|
passed += float(THINK_CLOSE in assistant_response)
|
||||||
|
|
||||||
|
min_after = checks.get("min_chars_after_think")
|
||||||
|
if min_after:
|
||||||
|
total += 1.0
|
||||||
|
tail = _after_think(assistant_response)
|
||||||
|
passed += float(tail is not None and len(tail.strip()) >= int(min_after))
|
||||||
|
|
||||||
|
for needle in checks.get("answer_after_think_contains", []):
|
||||||
|
total += 1.0
|
||||||
|
tail = _after_think(assistant_response)
|
||||||
|
passed += float(tail is not None and needle in tail)
|
||||||
|
|
||||||
|
if checks.get("forbid_answer_needles_inside_think_only", False):
|
||||||
|
# If needles appear only before </think>, fail (recipe trapped in think)
|
||||||
|
total += 1.0
|
||||||
|
if THINK_CLOSE in assistant_response:
|
||||||
|
head, _, tail = assistant_response.partition(THINK_CLOSE)
|
||||||
|
needles = checks.get("_forbid_needles", ("Ingredients", "Step 1", "samosa"))
|
||||||
|
bad = any(n in head and n not in tail for n in needles)
|
||||||
|
passed += float(not bad)
|
||||||
|
else:
|
||||||
|
passed += 0.0
|
||||||
|
|
||||||
|
if total == 0.0:
|
||||||
|
return 0.0
|
||||||
|
return passed / total
|
||||||
|
|
||||||
|
|
||||||
|
def default_probes() -> list[dict]:
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"name": "president_web_search",
|
||||||
|
"category": "think_plus_tool",
|
||||||
|
"conversation": {
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": SYS_THINK + "\n\nWho is the current president of America?",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"checks": {
|
||||||
|
"must_call": "web_search",
|
||||||
|
"must_close_think": True,
|
||||||
|
"min_chars_after_think": 20,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "mumbai_weather_web_search",
|
||||||
|
"category": "think_plus_tool",
|
||||||
|
"conversation": {
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": SYS_THINK + "\n\nWhat's the weather in Mumbai today?",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"checks": {
|
||||||
|
"must_call": "web_search",
|
||||||
|
"must_close_think": True,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "samosa_chaat_answer_after_think",
|
||||||
|
"category": "think_plus_tool",
|
||||||
|
"conversation": {
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": SYS_THINK + "\n\nHow do I make samosa chaat?",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"checks": {
|
||||||
|
"must_close_think": True,
|
||||||
|
"min_chars_after_think": 60,
|
||||||
|
"answer_after_think_contains": ["samosa"],
|
||||||
|
"forbid_answer_needles_inside_think_only": True,
|
||||||
|
"_forbid_needles": ("Ingredients", "Step 1", "yogurt"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "tip_calculator",
|
||||||
|
"category": "think_plus_tool",
|
||||||
|
"conversation": {
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": SYS_THINK + "\n\nCalculate an 18% tip on a $60 bill.",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"checks": {
|
||||||
|
"must_call": "calculator",
|
||||||
|
"must_close_think": True,
|
||||||
|
"answer_regex": r"10\.8",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def run_probes(
|
||||||
|
*,
|
||||||
|
tokenizer,
|
||||||
|
engine: Engine,
|
||||||
|
probes: list[dict],
|
||||||
|
max_new_tokens: int,
|
||||||
|
temperature: float,
|
||||||
|
top_k: int,
|
||||||
|
) -> tuple[float, dict[str, list[tuple[str, float]]]]:
|
||||||
|
by_cat: dict[str, list[tuple[str, float]]] = {}
|
||||||
|
scores: list[float] = []
|
||||||
|
|
||||||
|
for probe in probes:
|
||||||
|
name = probe.get("name", "unknown")
|
||||||
|
category = probe.get("category", "default")
|
||||||
|
conv = probe["conversation"]
|
||||||
|
checks = probe.get("checks", {})
|
||||||
|
|
||||||
|
encoded = tokenizer.render_for_completion(conv)
|
||||||
|
results, _ = engine.generate_batch(
|
||||||
|
encoded,
|
||||||
|
num_samples=1,
|
||||||
|
max_tokens=max_new_tokens,
|
||||||
|
temperature=temperature,
|
||||||
|
top_k=top_k,
|
||||||
|
seed=42,
|
||||||
|
)
|
||||||
|
completion = tokenizer.decode(results[0][len(encoded) :])
|
||||||
|
r = probe_reward(checks, completion)
|
||||||
|
scores.append(r)
|
||||||
|
by_cat.setdefault(category, []).append((name, r))
|
||||||
|
print0(f"[{category}] {name}: reward={r:.3f}\n---\n{completion[:1200]}\n---")
|
||||||
|
|
||||||
|
return sum(scores) / max(len(scores), 1), by_cat
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
device_type = os.environ.get("DEVICE", "") or autodetect_device_type()
|
||||||
|
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
|
||||||
|
if ddp_rank != 0:
|
||||||
|
return
|
||||||
|
|
||||||
|
tag = os.environ.get("TAG")
|
||||||
|
step_s = os.environ.get("STEP")
|
||||||
|
source = os.environ.get("SOURCE", "sft")
|
||||||
|
if not tag or step_s is None:
|
||||||
|
print("Set TAG and STEP in the environment.", file=sys.stderr)
|
||||||
|
sys.exit(2)
|
||||||
|
step = int(step_s)
|
||||||
|
with_tools = os.environ.get("WITH_TOOLS", "1") not in ("0", "false", "False")
|
||||||
|
|
||||||
|
probes = default_probes()
|
||||||
|
extra_path = os.environ.get("EVAL_SUITE_EXTRA_JSONL", "")
|
||||||
|
if extra_path and os.path.exists(extra_path):
|
||||||
|
with open(extra_path, encoding="utf-8") as f:
|
||||||
|
for line in f:
|
||||||
|
line = line.strip()
|
||||||
|
if line:
|
||||||
|
probes.append(json.loads(line))
|
||||||
|
|
||||||
|
model, tokenizer, meta = load_model(source, device, phase="eval", model_tag=tag, step=step)
|
||||||
|
tools = build_default_tool_registry() if with_tools else None
|
||||||
|
engine = Engine(model, tokenizer, tools=tools)
|
||||||
|
|
||||||
|
mean, by_cat = run_probes(
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
engine=engine,
|
||||||
|
probes=probes,
|
||||||
|
max_new_tokens=int(os.environ.get("MAX_NEW_TOKENS", "512")),
|
||||||
|
temperature=float(os.environ.get("TEMPERATURE", "0.2")),
|
||||||
|
top_k=int(os.environ.get("TOP_K", "50")),
|
||||||
|
)
|
||||||
|
|
||||||
|
print0("=" * 60)
|
||||||
|
print0(f"Overall mean reward: {mean:.4f}")
|
||||||
|
for cat, rows in sorted(by_cat.items()):
|
||||||
|
m = sum(r for _, r in rows) / len(rows)
|
||||||
|
print0(f" {cat}: {m:.4f} ({len(rows)} probes)")
|
||||||
|
compute_cleanup()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
try:
|
||||||
|
main()
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
compute_cleanup()
|
||||||
115
scripts/filter_reasoning_jsonl.py
Normal file
115
scripts/filter_reasoning_jsonl.py
Normal file
|
|
@ -0,0 +1,115 @@
|
||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Filter reasoning SFT JSONL so thinking blocks stay format-clean:
|
||||||
|
|
||||||
|
- Require a closed </think>
|
||||||
|
- Require non-trivial text *after* the closing tag (the answer lives there)
|
||||||
|
- Reject if the model likely put the final answer only inside the think block
|
||||||
|
(heuristic: strong answer markers appear inside but post-think text is tiny)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import re
|
||||||
|
import sys
|
||||||
|
|
||||||
|
THINK_OPEN = "<think>"
|
||||||
|
THINK_CLOSE = "</think>"
|
||||||
|
|
||||||
|
# If these appear inside the thinking span but the tail after </think> is short, drop the row.
|
||||||
|
INSIDE_THINK_ANSWER_HINTS = re.compile(
|
||||||
|
r"(?i)\b(ingredients|instructions|step\s*1|method|preheat|^\s*\d+[\).\s])",
|
||||||
|
re.MULTILINE,
|
||||||
|
)
|
||||||
|
STRONG_TAIL_NEEDLE = re.compile(r"(?i)\b(is|are|equals|result|answer|you can|first,|mix|serve)\b")
|
||||||
|
|
||||||
|
|
||||||
|
def _assistant_text(messages: list) -> str | None:
|
||||||
|
if not messages or messages[-1].get("role") != "assistant":
|
||||||
|
return None
|
||||||
|
c = messages[-1].get("content")
|
||||||
|
if isinstance(c, str):
|
||||||
|
return c
|
||||||
|
if isinstance(c, list):
|
||||||
|
parts = []
|
||||||
|
for p in c:
|
||||||
|
if isinstance(p, dict) and p.get("type") == "text":
|
||||||
|
parts.append(p.get("text", ""))
|
||||||
|
return "".join(parts) if parts else None
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _split_think(s: str) -> tuple[str | None, str | None]:
|
||||||
|
if THINK_OPEN not in s or THINK_CLOSE not in s:
|
||||||
|
return None, None
|
||||||
|
try:
|
||||||
|
inner_start = s.index(THINK_OPEN) + len(THINK_OPEN)
|
||||||
|
close_idx = s.index(THINK_CLOSE, inner_start)
|
||||||
|
inner = s[inner_start:close_idx]
|
||||||
|
tail = s[close_idx + len(THINK_CLOSE) :]
|
||||||
|
return inner, tail
|
||||||
|
except ValueError:
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
|
||||||
|
def keep_conversation(messages: list, *, min_tail_chars: int) -> tuple[bool, str]:
|
||||||
|
text = _assistant_text(messages)
|
||||||
|
if not text:
|
||||||
|
return False, "no_assistant_string"
|
||||||
|
inner, tail = _split_think(text)
|
||||||
|
if inner is None:
|
||||||
|
return False, "missing_or_unclosed_think"
|
||||||
|
tail_stripped = tail.strip() if tail else ""
|
||||||
|
if len(tail_stripped) < min_tail_chars:
|
||||||
|
return False, "short_tail"
|
||||||
|
if INSIDE_THINK_ANSWER_HINTS.search(inner) and len(tail_stripped) < max(min_tail_chars, 80):
|
||||||
|
return False, "answer_leaked_into_think"
|
||||||
|
if len(tail_stripped) < 40 and not STRONG_TAIL_NEEDLE.search(tail_stripped):
|
||||||
|
return False, "weak_tail"
|
||||||
|
return True, "ok"
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(description="Filter reasoning JSONL for think-block hygiene")
|
||||||
|
parser.add_argument("input_jsonl")
|
||||||
|
parser.add_argument("--out", required=True, help="Filtered JSONL output path")
|
||||||
|
parser.add_argument("--min-tail-chars", type=int, default=25)
|
||||||
|
parser.add_argument("--stats-every", type=int, default=10000)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
kept, total = 0, 0
|
||||||
|
reasons: dict[str, int] = {}
|
||||||
|
|
||||||
|
with open(args.input_jsonl, encoding="utf-8") as fin, open(args.out, "w", encoding="utf-8") as fout:
|
||||||
|
for line in fin:
|
||||||
|
line = line.strip()
|
||||||
|
if not line:
|
||||||
|
continue
|
||||||
|
total += 1
|
||||||
|
try:
|
||||||
|
messages = json.loads(line)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
reasons["bad_json"] = reasons.get("bad_json", 0) + 1
|
||||||
|
continue
|
||||||
|
if not isinstance(messages, list):
|
||||||
|
reasons["not_list"] = reasons.get("not_list", 0) + 1
|
||||||
|
continue
|
||||||
|
ok, reason = keep_conversation(messages, min_tail_chars=args.min_tail_chars)
|
||||||
|
if ok:
|
||||||
|
fout.write(json.dumps(messages, ensure_ascii=True) + "\n")
|
||||||
|
kept += 1
|
||||||
|
else:
|
||||||
|
reasons[reason] = reasons.get(reason, 0) + 1
|
||||||
|
|
||||||
|
if args.stats_every and total % args.stats_every == 0:
|
||||||
|
print(f"... processed {total} lines, kept {kept}", file=sys.stderr)
|
||||||
|
|
||||||
|
print(f"Done. kept {kept}/{total} -> {args.out}", file=sys.stderr)
|
||||||
|
for k, v in sorted(reasons.items(), key=lambda kv: -kv[1]):
|
||||||
|
print(f" dropped {k}: {v}", file=sys.stderr)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
443
scripts/gen_joint_think_tool.py
Normal file
443
scripts/gen_joint_think_tool.py
Normal file
|
|
@ -0,0 +1,443 @@
|
||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Generate SFT JSONL with joint <think> + tool-call + answer patterns.
|
||||||
|
|
||||||
|
Output lines are JSON arrays of messages (CustomJSON / chat_sft --extra-train-jsonl).
|
||||||
|
Assistant turns use list-shaped content so the tokenizer emits python/output specials.
|
||||||
|
|
||||||
|
Patterns:
|
||||||
|
(a) think -> web_search -> answer
|
||||||
|
(b) think -> direct answer (no tool)
|
||||||
|
(c) think -> calculator -> answer
|
||||||
|
|
||||||
|
Optional: OPENAI_API_KEY + --use-openai to diversify questions (gpt-4o-mini).
|
||||||
|
Without API keys, uses deterministic templates (still valid for SFT).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
import re
|
||||||
|
import sys
|
||||||
|
import urllib.error
|
||||||
|
import urllib.request
|
||||||
|
|
||||||
|
# Keep in sync with services/chat-api/src/routes/messages.py _SYS_THINK
|
||||||
|
SYS_JOINT = (
|
||||||
|
"You are samosaChaat, a helpful AI assistant created by Manmohan Sharma. "
|
||||||
|
"You have access to tools: use web_search for facts that may change over time or "
|
||||||
|
"require current information, and use calculator for arithmetic. "
|
||||||
|
"Think step by step inside <think>...</think> tags, "
|
||||||
|
"then give your final answer after the closing tag."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Identity anchor (from prior creator SFT — keep wording stable)
|
||||||
|
CREATOR_FACTS = (
|
||||||
|
"Context: samosaChaat is an AI assistant created by Manmohan Sharma. "
|
||||||
|
"If asked who built you, answer with that fact."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _compact_json(obj) -> str:
|
||||||
|
return json.dumps(obj, ensure_ascii=True, separators=(",", ":"))
|
||||||
|
|
||||||
|
|
||||||
|
def _web_result(query: str, url: str, title: str, snippet: str) -> dict:
|
||||||
|
return {
|
||||||
|
"query": query,
|
||||||
|
"results": [{"url": url, "title": title, "snippet": snippet}],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def think_block(*lines: str) -> str:
|
||||||
|
body = " ".join(l.strip() for l in lines if l.strip())
|
||||||
|
return f"<think>\n{body}\n</think>"
|
||||||
|
|
||||||
|
|
||||||
|
def conv_think_web_search(user_q: str, think_lines: tuple[str, ...], query: str, result: dict, answer: str):
|
||||||
|
return [
|
||||||
|
{"role": "user", "content": SYS_JOINT + "\n\n" + user_q},
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": think_block(*think_lines) + "\n"},
|
||||||
|
{
|
||||||
|
"type": "tool_call",
|
||||||
|
"tool_name": "web_search",
|
||||||
|
"arguments": {"query": query, "top_k": 1},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "tool_result",
|
||||||
|
"tool_name": "web_search",
|
||||||
|
"output": result,
|
||||||
|
"success": True,
|
||||||
|
},
|
||||||
|
{"type": "text", "text": answer},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def conv_think_direct(user_q: str, think_lines: tuple[str, ...], answer: str):
|
||||||
|
return [
|
||||||
|
{"role": "user", "content": SYS_JOINT + "\n\n" + user_q},
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": think_block(*think_lines) + "\n" + answer},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def conv_think_calculator(
|
||||||
|
user_q: str,
|
||||||
|
think_lines: tuple[str, ...],
|
||||||
|
expression: str,
|
||||||
|
value: float | int,
|
||||||
|
answer: str,
|
||||||
|
):
|
||||||
|
return [
|
||||||
|
{"role": "user", "content": SYS_JOINT + "\n\n" + user_q},
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": think_block(*think_lines) + "\n"},
|
||||||
|
{
|
||||||
|
"type": "tool_call",
|
||||||
|
"tool_name": "calculator",
|
||||||
|
"arguments": {"expression": expression},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "tool_result",
|
||||||
|
"tool_name": "calculator",
|
||||||
|
"output": {"expression": expression, "value": float(value)},
|
||||||
|
"success": True,
|
||||||
|
},
|
||||||
|
{"type": "text", "text": answer},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def _template_rows() -> list[list]:
|
||||||
|
rows: list[list] = []
|
||||||
|
|
||||||
|
# --- (a) Web search / time-sensitive ---
|
||||||
|
rows.append(
|
||||||
|
conv_think_web_search(
|
||||||
|
"Who is the current President of the United States?",
|
||||||
|
(
|
||||||
|
"This changes with elections; I should not rely on memory.",
|
||||||
|
"I will search for the latest information.",
|
||||||
|
),
|
||||||
|
"current President of the United States 2026",
|
||||||
|
_web_result(
|
||||||
|
"current President of the United States 2026",
|
||||||
|
"https://www.whitehouse.gov/administration/",
|
||||||
|
"The Administration",
|
||||||
|
"The administration page lists the current officeholder.",
|
||||||
|
),
|
||||||
|
"Based on the search, the current President of the United States is the person listed on the official White House administration page (verify on whitehouse.gov for the exact name).",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
rows.append(
|
||||||
|
conv_think_web_search(
|
||||||
|
"What's the weather in Mumbai today?",
|
||||||
|
("Weather is live data.", "I should look it up."),
|
||||||
|
"Mumbai weather today",
|
||||||
|
_web_result(
|
||||||
|
"Mumbai weather today",
|
||||||
|
"https://weather.example/mumbai",
|
||||||
|
"Mumbai forecast",
|
||||||
|
"Today: warm and humid with a chance of evening showers; highs near 32°C.",
|
||||||
|
),
|
||||||
|
"Based on the search, Mumbai today is warm and humid with a chance of evening showers (check a live weather source for exact numbers).",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
rows.append(
|
||||||
|
conv_think_web_search(
|
||||||
|
"Who is the CEO of OpenAI right now?",
|
||||||
|
("Executive roles change.", "I'll search for the current CEO."),
|
||||||
|
"OpenAI CEO 2026",
|
||||||
|
_web_result(
|
||||||
|
"OpenAI CEO 2026",
|
||||||
|
"https://openai.com/",
|
||||||
|
"OpenAI",
|
||||||
|
"Leadership page names the current chief executive.",
|
||||||
|
),
|
||||||
|
"Based on the search, see OpenAI's official leadership page for the current CEO name.",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
rows.append(
|
||||||
|
conv_think_web_search(
|
||||||
|
"What was the closing value of the S&P 500 index most recently?",
|
||||||
|
("Market numbers are time-sensitive.", "I need a web lookup."),
|
||||||
|
"S&P 500 latest close",
|
||||||
|
_web_result(
|
||||||
|
"S&P 500 latest close",
|
||||||
|
"https://www.example-finance.com/sp500",
|
||||||
|
"S&P 500",
|
||||||
|
"The index closed near 5,200 in the latest reported session (illustrative).",
|
||||||
|
),
|
||||||
|
"Based on the search, the latest reported close was near 5,200 — confirm on a live market data site.",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# --- (b) Direct answer after think (no tool): recipes / static how-to ---
|
||||||
|
rows.append(
|
||||||
|
conv_think_direct(
|
||||||
|
"How do I make samosa chaat at home?",
|
||||||
|
(
|
||||||
|
"This is a cooking question; I can outline steps without a web search.",
|
||||||
|
"I'll keep reasoning brief and put the recipe after the thinking block.",
|
||||||
|
),
|
||||||
|
"Crush or chop cooked samosas. Layer with chickpeas, yogurt, chutneys (mint and tamarind), diced onions, tomatoes, chaat masala, and sev. Serve immediately while the samosa is still crisp.",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
rows.append(
|
||||||
|
conv_think_direct(
|
||||||
|
"What is the capital of France?",
|
||||||
|
("This is stable geographic knowledge.", "No tool is needed."),
|
||||||
|
"The capital of France is Paris.",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
rows.append(
|
||||||
|
conv_think_direct(
|
||||||
|
CREATOR_FACTS + " Who created you?",
|
||||||
|
("The user asked about my creator.", "That is given in the context."),
|
||||||
|
"I was created by Manmohan Sharma.",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# --- (c) Calculator after think ---
|
||||||
|
rows.append(
|
||||||
|
conv_think_calculator(
|
||||||
|
"Calculate an 18% tip on a $60 bill.",
|
||||||
|
("I need an exact percentage.", "I'll use the calculator tool."),
|
||||||
|
"percent(60,18)",
|
||||||
|
10.8,
|
||||||
|
"An 18% tip on $60 is $10.80.",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
rows.append(
|
||||||
|
conv_think_calculator(
|
||||||
|
"What is the monthly EMI for a ₹500,000 loan at 8% annual interest for 240 months?",
|
||||||
|
("EMI has a standard formula.", "I'll compute with the calculator."),
|
||||||
|
"emi(500000,8,240)",
|
||||||
|
4182.198594391402,
|
||||||
|
"The monthly EMI is about 4182.20.",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
rows.append(
|
||||||
|
conv_think_calculator(
|
||||||
|
"If revenue grew from 120 to 150, what is the percent change?",
|
||||||
|
("Percent change should be exact.", "Using calculator."),
|
||||||
|
"percent_change(120,150)",
|
||||||
|
25.0,
|
||||||
|
"The percent change from 120 to 150 is 25%.",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return rows
|
||||||
|
|
||||||
|
|
||||||
|
def _vary_presidents_ceo_weather_sports(rng: random.Random) -> list[list]:
|
||||||
|
"""Extra templated rows with light randomization."""
|
||||||
|
rows: list[list] = []
|
||||||
|
cities = ["Delhi", "Bengaluru", "London", "New York", "Tokyo"]
|
||||||
|
sports = [
|
||||||
|
("latest ICC cricket World Cup winner", "cricket world cup winner"),
|
||||||
|
("Who won the most recent Super Bowl?", "Super Bowl winner latest"),
|
||||||
|
]
|
||||||
|
for city in cities:
|
||||||
|
q = f"What's the weather in {city} today?"
|
||||||
|
rows.append(
|
||||||
|
conv_think_web_search(
|
||||||
|
q,
|
||||||
|
("Weather is live.", "Search."),
|
||||||
|
f"{city} weather today",
|
||||||
|
_web_result(
|
||||||
|
f"{city} weather today",
|
||||||
|
f"https://weather.example/{city.lower()}",
|
||||||
|
f"{city} forecast",
|
||||||
|
f"Today in {city}: partly cloudy, mild breeze (illustrative).",
|
||||||
|
),
|
||||||
|
f"Based on the search, today's {city} forecast looks partly cloudy — verify on a live weather service.",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
for user_q, squery in sports:
|
||||||
|
rows.append(
|
||||||
|
conv_think_web_search(
|
||||||
|
user_q,
|
||||||
|
("Sports results change.", "I'll search."),
|
||||||
|
squery,
|
||||||
|
_web_result(
|
||||||
|
squery,
|
||||||
|
"https://sports.example/",
|
||||||
|
"Sports",
|
||||||
|
"Official recap lists the winning team for the latest event.",
|
||||||
|
),
|
||||||
|
"Based on the search, see the linked recap for the winning team (confirm on a trusted sports source).",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
# Rotating math / finance
|
||||||
|
for bill, pct in [(40.0, 15), (85.5, 20), (120.0, 22)]:
|
||||||
|
expr = f"percent({bill},{pct})"
|
||||||
|
val = round(bill * pct / 100, 2)
|
||||||
|
rows.append(
|
||||||
|
conv_think_calculator(
|
||||||
|
f"Calculate a {pct}% tip on ${bill:g}.",
|
||||||
|
("Exact tip amount.", "Calculator."),
|
||||||
|
expr,
|
||||||
|
val,
|
||||||
|
f"A {pct}% tip on ${bill:g} is ${val:.2f}.",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
rng.shuffle(rows)
|
||||||
|
return rows
|
||||||
|
|
||||||
|
|
||||||
|
_OPENAI_URL = "https://api.openai.com/v1/chat/completions"
|
||||||
|
|
||||||
|
|
||||||
|
def _openai_variants(api_key: str, n: int, rng: random.Random) -> list[str]:
|
||||||
|
"""Return n short factual questions for joint training."""
|
||||||
|
prompt = (
|
||||||
|
"Generate exactly %d diverse, concise user questions that need EITHER a web search, "
|
||||||
|
"a calculator, OR a direct answer (no tool). One sentence each, no numbering. "
|
||||||
|
"Mix: current events, weather, sports, finance math, and one cooking question. "
|
||||||
|
"Output one question per line, no other text."
|
||||||
|
) % n
|
||||||
|
body = _compact_json(
|
||||||
|
{
|
||||||
|
"model": "gpt-4o-mini",
|
||||||
|
"temperature": 0.9,
|
||||||
|
"messages": [{"role": "user", "content": prompt}],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
req = urllib.request.Request(
|
||||||
|
_OPENAI_URL,
|
||||||
|
data=body.encode("utf-8"),
|
||||||
|
headers={
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"Authorization": f"Bearer {api_key}",
|
||||||
|
},
|
||||||
|
method="POST",
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
with urllib.request.urlopen(req, timeout=120) as resp:
|
||||||
|
payload = json.loads(resp.read().decode("utf-8"))
|
||||||
|
except (urllib.error.URLError, TimeoutError, json.JSONDecodeError) as exc:
|
||||||
|
print(f"OpenAI request failed ({exc}); skipping LLM expansion.", file=sys.stderr)
|
||||||
|
return []
|
||||||
|
|
||||||
|
try:
|
||||||
|
text = payload["choices"][0]["message"]["content"]
|
||||||
|
except (KeyError, IndexError):
|
||||||
|
print("Unexpected OpenAI response shape; skipping.", file=sys.stderr)
|
||||||
|
return []
|
||||||
|
|
||||||
|
lines = [ln.strip() for ln in text.splitlines() if ln.strip()]
|
||||||
|
rng.shuffle(lines)
|
||||||
|
return lines[:n]
|
||||||
|
|
||||||
|
|
||||||
|
def _classify_and_build_question(q: str) -> list | None:
|
||||||
|
"""Heuristic: map a free-form question to a joint pattern."""
|
||||||
|
ql = q.lower()
|
||||||
|
# Tip: "18% tip on $60" / "tip on $60 at 18%"
|
||||||
|
tip_m = re.search(
|
||||||
|
r"(?P<pct>\d+(?:\.\d+)?)\s*%\s*tip.*?\$(?P<bill>\d+(?:\.\d+)?)|"
|
||||||
|
r"\$(?P<bill2>\d+(?:\.\d+)?).*?(?P<pct2>\d+(?:\.\d+)?)\s*%\s*tip",
|
||||||
|
q,
|
||||||
|
re.I,
|
||||||
|
)
|
||||||
|
if tip_m:
|
||||||
|
pct = float(tip_m.group("pct") or tip_m.group("pct2"))
|
||||||
|
bill = float(tip_m.group("bill") or tip_m.group("bill2"))
|
||||||
|
expr = f"percent({bill},{pct})"
|
||||||
|
val = round(bill * pct / 100, 4)
|
||||||
|
return conv_think_calculator(
|
||||||
|
q,
|
||||||
|
("I need an exact tip amount.", "Using the calculator."),
|
||||||
|
expr,
|
||||||
|
val,
|
||||||
|
f"Based on the calculation, a {pct:g}% tip on ${bill:g} is ${val:.2f}.",
|
||||||
|
)
|
||||||
|
if any(
|
||||||
|
k in ql
|
||||||
|
for k in (
|
||||||
|
"weather",
|
||||||
|
"president",
|
||||||
|
"ceo",
|
||||||
|
"who won",
|
||||||
|
"score",
|
||||||
|
"today",
|
||||||
|
"current",
|
||||||
|
"latest",
|
||||||
|
"price of",
|
||||||
|
"stock",
|
||||||
|
)
|
||||||
|
):
|
||||||
|
slug = re.sub(r"\W+", "-", ql)[:48]
|
||||||
|
return conv_think_web_search(
|
||||||
|
q,
|
||||||
|
("This is time-sensitive or external.", "Searching."),
|
||||||
|
ql[:120],
|
||||||
|
_web_result(ql[:120], f"https://example.org/{slug}", "Source", "Snippet summarizes the retrieved page."),
|
||||||
|
"Based on the search, verify details on the cited source; the snippet is illustrative.",
|
||||||
|
)
|
||||||
|
return conv_think_direct(
|
||||||
|
q,
|
||||||
|
("I can answer directly.", "No tool needed."),
|
||||||
|
"Short answer: provide 2–4 sentences addressing the question without putting the final steps inside the thinking block.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(description="Generate joint think+tool SFT JSONL")
|
||||||
|
parser.add_argument("--out", default="seed_data/joint_think_tool.jsonl", help="Output JSONL path")
|
||||||
|
parser.add_argument("--target", type=int, default=512, help="Approximate number of lines to write")
|
||||||
|
parser.add_argument("--seed", type=int, default=42)
|
||||||
|
parser.add_argument(
|
||||||
|
"--use-openai",
|
||||||
|
action="store_true",
|
||||||
|
help="If OPENAI_API_KEY is set, add LLM-generated questions (best-effort).",
|
||||||
|
)
|
||||||
|
parser.add_argument("--openai-extra", type=int, default=64, help="Max extra questions from OpenAI")
|
||||||
|
args = parser.parse_args()
|
||||||
|
rng = random.Random(args.seed)
|
||||||
|
|
||||||
|
rows = _template_rows() + _vary_presidents_ceo_weather_sports(rng)
|
||||||
|
|
||||||
|
if args.use_openai:
|
||||||
|
key = os.environ.get("OPENAI_API_KEY", "")
|
||||||
|
if key:
|
||||||
|
for q in _openai_variants(key, args.openai_extra, rng):
|
||||||
|
built = _classify_and_build_question(q)
|
||||||
|
if built:
|
||||||
|
rows.append(built)
|
||||||
|
else:
|
||||||
|
print("OPENAI_API_KEY not set; --use-openai ignored.", file=sys.stderr)
|
||||||
|
|
||||||
|
# Repeat / shuffle to reach target size
|
||||||
|
out_lines: list[list] = []
|
||||||
|
while len(out_lines) < args.target:
|
||||||
|
rng.shuffle(rows)
|
||||||
|
need = args.target - len(out_lines)
|
||||||
|
out_lines.extend(rows[: min(len(rows), need)])
|
||||||
|
|
||||||
|
os.makedirs(os.path.dirname(args.out) or ".", exist_ok=True)
|
||||||
|
with open(args.out, "w", encoding="utf-8") as f:
|
||||||
|
for conv in out_lines[: args.target]:
|
||||||
|
f.write(json.dumps(conv, ensure_ascii=True) + "\n")
|
||||||
|
|
||||||
|
print(f"Wrote {args.target} conversations to {args.out}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
|
|
@ -7,6 +7,31 @@ import os
|
||||||
import json
|
import json
|
||||||
from tasks.common import Task
|
from tasks.common import Task
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_assistant_content(content, message_index):
|
||||||
|
"""Assistant turns may be a plain string or a list of parts (tools / GSM8K-style)."""
|
||||||
|
if isinstance(content, str):
|
||||||
|
return
|
||||||
|
if not isinstance(content, list):
|
||||||
|
raise AssertionError(f"Message {message_index}: assistant content must be str or list, got {type(content)}")
|
||||||
|
for j, part in enumerate(content):
|
||||||
|
if not isinstance(part, dict):
|
||||||
|
raise AssertionError(f"Message {message_index} part {j}: expected dict, got {type(part)}")
|
||||||
|
ptype = part.get("type")
|
||||||
|
if ptype == "text":
|
||||||
|
assert "text" in part, f"Message {message_index} part {j}: text part missing 'text'"
|
||||||
|
elif ptype in ("tool_call", "python"):
|
||||||
|
assert "text" in part or part.get("tool_name"), (
|
||||||
|
f"Message {message_index} part {j}: tool part needs 'text' or 'tool_name'"
|
||||||
|
)
|
||||||
|
elif ptype in ("tool_result", "python_output"):
|
||||||
|
assert "text" in part or part.get("tool_name") is not None, (
|
||||||
|
f"Message {message_index} part {j}: result part missing 'text' or 'tool_name'"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise AssertionError(f"Message {message_index} part {j}: unknown type {ptype!r}")
|
||||||
|
|
||||||
|
|
||||||
class CustomJSON(Task):
|
class CustomJSON(Task):
|
||||||
"""
|
"""
|
||||||
Load conversations from a JSONL file.
|
Load conversations from a JSONL file.
|
||||||
|
|
@ -47,7 +72,10 @@ class CustomJSON(Task):
|
||||||
assert "content" in message, f"Message {i} missing 'content' field"
|
assert "content" in message, f"Message {i} missing 'content' field"
|
||||||
expected_role = "user" if i % 2 == 0 else "assistant"
|
expected_role = "user" if i % 2 == 0 else "assistant"
|
||||||
assert message["role"] == expected_role, f"Message {i} has role {message['role']} but should be {expected_role}"
|
assert message["role"] == expected_role, f"Message {i} has role {message['role']} but should be {expected_role}"
|
||||||
assert isinstance(message["content"], str), f"Message {i} content must be a string"
|
if message["role"] == "user":
|
||||||
|
assert isinstance(message["content"], str), f"Message {i} user content must be a string"
|
||||||
|
else:
|
||||||
|
_validate_assistant_content(message["content"], i)
|
||||||
|
|
||||||
self.conversations.append(messages)
|
self.conversations.append(messages)
|
||||||
|
|
||||||
|
|
@ -62,4 +90,3 @@ class CustomJSON(Task):
|
||||||
"messages": messages,
|
"messages": messages,
|
||||||
}
|
}
|
||||||
return conversation
|
return conversation
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user