nanochat/synth-data-pipeline/2_validate_qa.py
2025-10-26 19:24:22 +00:00

144 lines
4.3 KiB
Python

"""
Stage 2: Validate Q&A pairs for quality and accuracy.
This script:
1. Loads Q&A pairs from Stage 1
2. Uses Gemini 2.5 Flash Lite to validate each pair
3. Filters out pairs that fail validation
4. Saves validated pairs to output
"""
import asyncio
import logfire
from dotenv import load_dotenv
from src.synth_data_pipeline.agents import qa_validator
from src.synth_data_pipeline.models import QAPair, QAValidation, ValidatedQAPair
from src.synth_data_pipeline.config import PATHS, STAGE_CONFIGS, FULL_PARAMS
from src.synth_data_pipeline.utils import (
load_jsonl,
save_jsonl,
process_with_concurrency,
)
# Load environment variables (API keys for the model provider / Logfire) from .env
load_dotenv()
# Configure logging; scrubbing disabled so prompt/answer text shows up verbatim in traces
logfire.configure(scrubbing=False)
logfire.instrument_pydantic_ai()
# Get configuration for this stage (model, concurrency limits, etc.)
config = STAGE_CONFIGS["stage2_qa_validation"]
# Load validation agent definition: the prompt template and the configured agent
validation_prompt_template = qa_validator.get_prompt_template()
validation_agent = qa_validator.build_agent(config)
async def validate_qa_pair(qa_pair: QAPair) -> ValidatedQAPair:
    """
    Run a single Q&A pair through the validation agent.

    Args:
        qa_pair: QAPair object to validate

    Returns:
        ValidatedQAPair bundling the original pair with the agent's verdict
    """
    # Fill the prompt template with the pair and its surrounding context
    prompt_fields = {
        "question": qa_pair.question,
        "answer": qa_pair.answer,
        "source_text": qa_pair.source_text,
        "context_before": qa_pair.context_before,
        "context_after": qa_pair.context_after,
    }
    rendered_prompt = validation_prompt_template.prompt.format(**prompt_fields)
    # Ask the model for a structured validation result
    agent_result = await validation_agent.run(rendered_prompt)
    return ValidatedQAPair(qa_pair=qa_pair, validation=agent_result.output)
async def main(
    input_file: str = None,
    output_file: str = None,
    max_concurrent: int = None
):
    """
    Validate all Q&A pairs from Stage 1 and persist the results.

    Writes two outputs: the full set of pairs annotated with validation
    results, and a pared-down file of only the passing QAPairs for Stage 3.

    Args:
        input_file: Path to input JSONL file (default from config)
        output_file: Path to output JSONL file (default from config)
        max_concurrent: Maximum concurrent API calls (default from config)
    """
    # Use defaults from config if not specified
    input_file = input_file or PATHS.stage1_qa_pairs
    output_file = output_file or PATHS.stage2_qa_validated
    max_concurrent = max_concurrent or config.max_concurrent
    logfire.info("Starting Q&A validation", input_file=input_file)
    # Load Q&A pairs
    qa_pairs = load_jsonl(input_file, model_class=QAPair)
    logfire.info(f"Loaded {len(qa_pairs)} Q&A pairs")
    # Validate all pairs concurrently (bounded by max_concurrent)
    with logfire.span("validate_qa_pairs"):
        validated_pairs = await process_with_concurrency(
            qa_pairs,
            validate_qa_pair,
            max_concurrent=max_concurrent,
            desc="Validating Q&A pairs"
        )
    # Count passed/failed
    passed = [vp for vp in validated_pairs if vp.validation.passed]
    failed = [vp for vp in validated_pairs if not vp.validation.passed]
    # BUGFIX: the original divided by len(validated_pairs) unconditionally,
    # raising ZeroDivisionError when the input file is empty.
    total = len(validated_pairs)
    rejection_rate = 100 * len(failed) / total if total else 0.0
    logfire.info(
        f"Validation complete: {len(passed)} passed, {len(failed)} failed "
        f"({rejection_rate:.1f}% rejection rate)"
    )
    # Save all validated pairs (with validation results)
    save_jsonl(validated_pairs, output_file)
    # Also save just the passed Q&A pairs (without validation metadata) for next stage
    passed_qa_pairs = [vp.qa_pair for vp in passed]
    if output_file == PATHS.stage2_qa_validated:
        passed_output = PATHS.stage2_qa_validated_passed
    else:
        # str() guards against a pathlib.Path being passed as output_file:
        # Path.replace() is a filesystem rename, not suffix substitution.
        passed_output = str(output_file).replace('.jsonl', '_passed.jsonl')
    save_jsonl(passed_qa_pairs, passed_output)
    logfire.info(f"Saved {len(passed_qa_pairs)} passed Q&A pairs to {passed_output}")
    # Print a sample so a human can eyeball the validation quality
    if validated_pairs:
        print("\n" + "="*80)
        print("VALIDATION SAMPLE:")
        print("="*80)
        sample = validated_pairs[0]
        print(f"Question: {sample.qa_pair.question}")
        print(f"Answer: {sample.qa_pair.answer[:100]}...")
        print(f"\nValidation:")
        print(f"  uses_source_fact: {sample.validation.uses_source_fact}")
        print(f"  realistic_question: {sample.validation.realistic_question}")
        print(f"  sensible_answer: {sample.validation.sensible_answer}")
        print(f"  PASSED: {sample.validation.passed}")
        print(f"  Feedback: {sample.validation.feedback}")
        print("="*80 + "\n")
if __name__ == "__main__":
    # Run this pipeline stage end-to-end when executed as a script
    asyncio.run(main())