nanochat/synth-data-pipeline/4_judge_and_save.py

"""
Stage 4: Judge conversations and save top candidates.
This script:
1. Loads raw conversations from output/conversations_raw.jsonl
2. Uses Gemini to judge the quality of each conversation against four pass/fail criteria
3. Filters to conversations that pass all four criteria
4. Saves all judged conversations, plus the top 1000 passing ones in NanoChat format
"""
import asyncio
import logfire
from dotenv import load_dotenv
from src.synth_data_pipeline.agents import conversation_judge
from src.synth_data_pipeline.models import (
Conversation,
JudgedConversation,
NanoChatConversation,
NanoChatMessage,
)
from src.synth_data_pipeline.config import (
PATHS,
STAGE_CONFIGS,
FULL_PARAMS,
)
from src.synth_data_pipeline.utils import (
load_jsonl,
save_jsonl,
process_with_concurrency,
print_statistics,
)
# Load environment variables
load_dotenv()
# Configure logging
logfire.configure(scrubbing=False)
logfire.instrument_pydantic_ai()
# Get configuration for this stage
config = STAGE_CONFIGS["stage4_judging"]
# Load judging agent definition
judge_prompt_template = conversation_judge.get_prompt_template()
judge_agent = conversation_judge.build_agent(config)
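# The prompt template and judge agent come from src.synth_data_pipeline.agents. The
# stage4_judging config is assumed to carry the judge's model/settings as well as
# max_concurrent (the only field referenced directly here); the agent is expected to
# return a structured judgment via result.output (see judge_conversation below).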
async def judge_conversation(conversation: Conversation) -> JudgedConversation:
"""
Judge the quality of a conversation.
Args:
conversation: Conversation object to judge
Returns:
JudgedConversation with quality scores
"""
# Format conversation for judging
conv_text = "\n\n".join([
f"{msg.role.upper()}: {msg.content}"
for msg in conversation.messages
])
# Format source Q&A pairs for fact-checking
source_qa_text = "\n\n".join([
f"Q: {qa.question}\nA: {qa.answer}"
for qa in conversation.source_qa_pairs
])
# Format the prompt
prompt_text = judge_prompt_template.prompt.format(
conversation=conv_text,
source_qa=source_qa_text if source_qa_text else "No source Q&A available"
)
# Judge using the agent
result = await judge_agent.run(prompt_text)
judgment = result.output
return JudgedConversation(
conversation=conversation,
judgment=judgment
)
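# Judgment fields consumed later in this script (names taken from their usage below;
# the actual schema is defined by the judge agent's output model):
#   factually_accurate, natural_conversation, on_topic, adds_value -> bool
#   overall_pass -> bool (treated here as "all four criteria met")
#   feedback -> str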
def conversation_to_nanochat(conversation: Conversation) -> NanoChatConversation:
"""
Convert a Conversation to NanoChat format.
Args:
conversation: Conversation object
Returns:
NanoChatConversation (just messages array)
"""
messages = [
NanoChatMessage(role=msg.role, content=msg.content)
for msg in conversation.messages
]
return NanoChatConversation(messages=messages)
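# Each NanoChat-format line is just the messages array, e.g. (illustrative):
#   {"messages": [{"role": "user", "content": "..."},
#                 {"role": "assistant", "content": "..."}]}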
def save_top_conversations_nanochat(
judged_conversations: list[JudgedConversation],
output_path: str,
top_k: int = 1000,
    min_score: float | None = None
):
"""
Save top K conversations in NanoChat format.
Args:
judged_conversations: List of judged conversations
output_path: Path to output JSONL file
top_k: Number of top conversations to save
        min_score: Minimum score threshold (optional; currently unused with boolean pass/fail judgments)
"""
# Filter to only passing conversations
passing_conversations = [
jc for jc in judged_conversations
if jc.judgment.overall_pass
]
# Sort by number of criteria passed (for ordering within passing conversations)
def count_passes(jc):
return sum([
jc.judgment.factually_accurate,
jc.judgment.natural_conversation,
jc.judgment.on_topic,
jc.judgment.adds_value
])
sorted_conversations = sorted(
passing_conversations,
key=count_passes,
reverse=True
)
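    # If overall_pass means "all four criteria passed" (as the summaries below assume),
    # count_passes is 4 for every passing conversation and this sort is a stable no-op;
    # it only becomes meaningful if overall_pass can diverge from the individual flags.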
# Note: min_score parameter is ignored with bool-only system
# Take top K
top_conversations = sorted_conversations[:top_k]
# Convert to NanoChat format and save
nanochat_convs = [
conversation_to_nanochat(jc.conversation)
for jc in top_conversations
]
save_jsonl(nanochat_convs, output_path)
# Log statistics
print(f"\nTop {len(top_conversations)} passing conversations selected")
print(f" All passed: factually_accurate AND natural AND on_topic AND adds_value")
def print_quality_statistics(judged_conversations: list[JudgedConversation]):
"""Print quality statistics for all judged conversations."""
if not judged_conversations:
return
total = len(judged_conversations)
passing = sum(1 for jc in judged_conversations if jc.judgment.overall_pass)
factual_pass = sum(1 for jc in judged_conversations if jc.judgment.factually_accurate)
natural_pass = sum(1 for jc in judged_conversations if jc.judgment.natural_conversation)
ontopic_pass = sum(1 for jc in judged_conversations if jc.judgment.on_topic)
value_pass = sum(1 for jc in judged_conversations if jc.judgment.adds_value)
print("\n" + "="*80)
print("QUALITY STATISTICS (All Conversations)")
print("="*80)
print(f"Total conversations judged: {total}")
print(f"Overall PASS (all 4 criteria): {passing} ({passing/total*100:.1f}%)")
print(f"\nIndividual criteria:")
print(f" Factually accurate : {factual_pass}/{total} ({factual_pass/total*100:.1f}%)")
print(f" Natural conversation: {natural_pass}/{total} ({natural_pass/total*100:.1f}%)")
print(f" On topic : {ontopic_pass}/{total} ({ontopic_pass/total*100:.1f}%)")
print(f" Adds value : {value_pass}/{total} ({value_pass/total*100:.1f}%)")
print("="*80 + "\n")
async def main(
    input_file: str | None = None,
    judged_output: str | None = None,
    nanochat_output: str | None = None,
    max_concurrent: int | None = None,
    top_k: int | None = None,
    min_score: float | None = None
):
"""
Main function to judge conversations and save top K.
Args:
input_file: Path to raw conversations JSONL file (default from config)
judged_output: Path to save all judged conversations (default from config)
nanochat_output: Path to save top K in NanoChat format (default from config)
max_concurrent: Maximum concurrent API calls (default from config)
top_k: Number of top conversations to save (default from config)
        min_score: Minimum quality score threshold (default from config; currently unused by the boolean judging)
"""
# Use defaults from config if not specified
input_file = input_file or PATHS.stage3_conversations_raw
judged_output = judged_output or PATHS.stage4_conversations_judged
nanochat_output = nanochat_output or PATHS.stage7_conversations_final
max_concurrent = max_concurrent or config.max_concurrent
top_k = top_k or FULL_PARAMS.top_k
min_score = min_score or FULL_PARAMS.min_quality_score
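    # judged_output keeps every conversation together with its judgment (useful for
    # later inspection or re-filtering); nanochat_output receives only the top-K
    # passing conversations, stripped down to the messages-only NanoChat format.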
logfire.info("Starting conversation judging", input_file=input_file)
# Load conversations
conversations = load_jsonl(input_file, model_class=Conversation)
logfire.info(f"Loaded {len(conversations)} conversations")
# Judge conversations
with logfire.span("judge_conversations"):
judged_conversations = await process_with_concurrency(
conversations,
judge_conversation,
max_concurrent=max_concurrent,
desc="Judging conversations"
)
logfire.info(f"Judged {len(judged_conversations)} conversations")
# Save all judged conversations
save_jsonl(judged_conversations, judged_output)
# Print statistics
print_quality_statistics(judged_conversations)
# Save top K in NanoChat format
save_top_conversations_nanochat(
judged_conversations,
nanochat_output,
top_k,
min_score
)
# Print sample of a passing conversation
passing_convs = [jc for jc in judged_conversations if jc.judgment.overall_pass]
if passing_convs:
print("\n" + "="*80)
print("SAMPLE PASSING CONVERSATION:")
print("="*80)
sample = passing_convs[0]
print(f"Overall: PASS (all 4 criteria met)")
print(f"Feedback: {sample.judgment.feedback}")
print("\nConversation:")
for msg in sample.conversation.messages:
print(f"\n{msg.role.upper()}: {msg.content[:200]}...")
print("="*80 + "\n")
if __name__ == "__main__":
asyncio.run(main())
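# Example invocation from the pipeline directory (assuming dependencies and a .env with
# API keys are in place; the exact runner may differ per project setup):
#   python 4_judge_and_save.py
# There is no CLI argument parsing here: to override input_file, top_k, etc., edit the
# asyncio.run(main()) call above or the defaults in src.synth_data_pipeline.config.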