mirror of
https://github.com/karpathy/nanochat.git
synced 2025-12-06 04:12:13 +00:00
144 lines
4.3 KiB
Python
144 lines
4.3 KiB
Python
"""
|
|
Stage 2: Validate Q&A pairs for quality and accuracy.
|
|
|
|
This script:
|
|
1. Loads Q&A pairs from Stage 1
|
|
2. Uses Gemini 2.5 Flash Lite to validate each pair
|
|
3. Filters out pairs that fail validation
|
|
4. Saves validated pairs to output
|
|
"""
|
|
|
|
import asyncio
|
|
|
|
import logfire
|
|
from dotenv import load_dotenv
|
|
|
|
from src.synth_data_pipeline.agents import qa_validator
|
|
|
|
from src.synth_data_pipeline.models import QAPair, QAValidation, ValidatedQAPair
|
|
from src.synth_data_pipeline.config import PATHS, STAGE_CONFIGS, FULL_PARAMS
|
|
from src.synth_data_pipeline.utils import (
|
|
load_jsonl,
|
|
save_jsonl,
|
|
process_with_concurrency,
|
|
)
|
|
|
|
# Load environment variables
|
|
load_dotenv()
|
|
|
|
# Configure logging
|
|
logfire.configure(scrubbing=False)
|
|
logfire.instrument_pydantic_ai()
|
|
|
|
# Get configuration for this stage
|
|
config = STAGE_CONFIGS["stage2_qa_validation"]
|
|
|
|
# Load validation agent definition
|
|
validation_prompt_template = qa_validator.get_prompt_template()
|
|
validation_agent = qa_validator.build_agent(config)
|
|
|
|
|
|
async def validate_qa_pair(qa_pair: QAPair) -> ValidatedQAPair:
|
|
"""
|
|
Validate a Q&A pair using Gemini.
|
|
|
|
Args:
|
|
qa_pair: QAPair object to validate
|
|
|
|
Returns:
|
|
ValidatedQAPair with validation result
|
|
"""
|
|
# Format the prompt
|
|
prompt_text = validation_prompt_template.prompt.format(
|
|
question=qa_pair.question,
|
|
answer=qa_pair.answer,
|
|
source_text=qa_pair.source_text,
|
|
context_before=qa_pair.context_before,
|
|
context_after=qa_pair.context_after,
|
|
)
|
|
|
|
# Validate using the agent
|
|
result = await validation_agent.run(prompt_text)
|
|
validation = result.output
|
|
|
|
return ValidatedQAPair(
|
|
qa_pair=qa_pair,
|
|
validation=validation
|
|
)
|
|
|
|
|
|
async def main(
|
|
input_file: str = None,
|
|
output_file: str = None,
|
|
max_concurrent: int = None
|
|
):
|
|
"""
|
|
Main function to validate Q&A pairs.
|
|
|
|
Args:
|
|
input_file: Path to input JSONL file (default from config)
|
|
output_file: Path to output JSONL file (default from config)
|
|
max_concurrent: Maximum concurrent API calls (default from config)
|
|
"""
|
|
# Use defaults from config if not specified
|
|
input_file = input_file or PATHS.stage1_qa_pairs
|
|
output_file = output_file or PATHS.stage2_qa_validated
|
|
max_concurrent = max_concurrent or config.max_concurrent
|
|
|
|
logfire.info("Starting Q&A validation", input_file=input_file)
|
|
|
|
# Load Q&A pairs
|
|
qa_pairs = load_jsonl(input_file, model_class=QAPair)
|
|
logfire.info(f"Loaded {len(qa_pairs)} Q&A pairs")
|
|
|
|
# Validate all pairs
|
|
with logfire.span("validate_qa_pairs"):
|
|
validated_pairs = await process_with_concurrency(
|
|
qa_pairs,
|
|
validate_qa_pair,
|
|
max_concurrent=max_concurrent,
|
|
desc="Validating Q&A pairs"
|
|
)
|
|
|
|
# Count passed/failed
|
|
passed = [vp for vp in validated_pairs if vp.validation.passed]
|
|
failed = [vp for vp in validated_pairs if not vp.validation.passed]
|
|
|
|
logfire.info(
|
|
f"Validation complete: {len(passed)} passed, {len(failed)} failed "
|
|
f"({100 * len(failed) / len(validated_pairs):.1f}% rejection rate)"
|
|
)
|
|
|
|
# Save all validated pairs (with validation results)
|
|
save_jsonl(validated_pairs, output_file)
|
|
|
|
# Also save just the passed Q&A pairs (without validation metadata) for next stage
|
|
passed_qa_pairs = [vp.qa_pair for vp in passed]
|
|
if output_file == PATHS.stage2_qa_validated:
|
|
passed_output = PATHS.stage2_qa_validated_passed
|
|
else:
|
|
passed_output = output_file.replace('.jsonl', '_passed.jsonl')
|
|
|
|
save_jsonl(passed_qa_pairs, passed_output)
|
|
logfire.info(f"Saved {len(passed_qa_pairs)} passed Q&A pairs to {passed_output}")
|
|
|
|
# Print sample
|
|
if validated_pairs:
|
|
print("\n" + "="*80)
|
|
print("VALIDATION SAMPLE:")
|
|
print("="*80)
|
|
sample = validated_pairs[0]
|
|
print(f"Question: {sample.qa_pair.question}")
|
|
print(f"Answer: {sample.qa_pair.answer[:100]}...")
|
|
print(f"\nValidation:")
|
|
print(f" uses_source_fact: {sample.validation.uses_source_fact}")
|
|
print(f" realistic_question: {sample.validation.realistic_question}")
|
|
print(f" sensible_answer: {sample.validation.sensible_answer}")
|
|
print(f" PASSED: {sample.validation.passed}")
|
|
print(f" Feedback: {sample.validation.feedback}")
|
|
print("="*80 + "\n")
|
|
|
|
|
|
if __name__ == "__main__":
|
|
asyncio.run(main())
|