"""
|
|
Trial run script to validate the pipeline with a small dataset.
|
|
|
|
This script:
|
|
1. Runs all 3 stages with limited data
|
|
2. Validates prompts and logic
|
|
3. Prints samples for manual inspection
|
|
"""
|
|
|
|
import asyncio
import json
from pathlib import Path

import logfire

# Import the shared data models used by every stage
from src.synth_data_pipeline.models import QAPair, Conversation, JudgedConversation
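# The trial code below only touches a handful of fields on these models (their
# full definitions live in src/synth_data_pipeline/models.py): QAPair.question,
# .answer, .difficulty and .categories; Conversation.messages and .metadata
# (style, user_persona, num_turns); and JudgedConversation.judgment with its
# four pass/fail booleans plus feedback and issues.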

# We'll use the actual script functions
import sys
sys.path.append(str(Path(__file__).parent))
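# The numbered stage scripts (1_extract_qa.py, 2_validate_qa.py, ...) are
# expected to sit next to this file; because their names start with a digit
# they cannot be imported with a plain import statement, so each trial
# function loads them lazily via importlib.import_module.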


TRIAL_QA_CHUNK_LIMIT = 10
TRIAL_FINAL_CONVERSATIONS = "output/trial_conversations_final.jsonl"
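# Trial-size knobs: TRIAL_QA_CHUNK_LIMIT caps how many source chunks stage 1
# processes, and TRIAL_FINAL_CONVERSATIONS is where stage 4 writes the final
# nanochat-format output. Every trial artifact uses a trial_ prefix under
# output/ so it is easy to tell apart from a full pipeline run.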


async def trial_extract_qa():
    """Trial run of Q&A extraction with 10 chunks."""
    print("\n" + "="*80)
    print("STAGE 1: Q&A EXTRACTION (Trial with 10 chunks)")
    print("="*80)

    from importlib import import_module
    stage1 = import_module('1_extract_qa')

    # Run with trial parameters (uses STAGE_CONFIGS for max_concurrent)
    await stage1.main(
        input_file="data/swap_facts.md",
        output_file="output/trial_qa_pairs.jsonl",
        limit=TRIAL_QA_CHUNK_LIMIT,
    )

    # Load and show results
    qa_pairs = []
    with open("output/trial_qa_pairs.jsonl", 'r') as f:
        for line in f:
            qa_pairs.append(QAPair.model_validate_json(line))

    print(f"\n✓ Generated {len(qa_pairs)} Q&A pairs")

    # Show first 3
    for i, qa in enumerate(qa_pairs[:3], 1):
        print(f"\n--- Q&A Pair {i} ---")
        print(f"Q: {qa.question}")
        print(f"A: {qa.answer[:150]}...")
        print(f"Difficulty: {qa.difficulty}")
        print(f"Categories: {', '.join(qa.categories)}")

    return len(qa_pairs)


async def trial_validate_qa(num_qa_pairs: int):
    """Trial run of Q&A validation."""
    print("\n" + "="*80)
    print("STAGE 2: Q&A VALIDATION (Trial with production configs)")
    print("="*80)

    from importlib import import_module
    stage2 = import_module('2_validate_qa')

    await stage2.main(
        input_file="output/trial_qa_pairs.jsonl",
        output_file="output/trial_qa_validated.jsonl",
    )

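    # Stage 2 appears to write its pass/fail splits alongside output_file;
    # only the *_passed.jsonl split is loaded here.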
    passed_pairs = []
    with open("output/trial_qa_validated_passed.jsonl", 'r') as f:
        for line in f:
            passed_pairs.append(QAPair.model_validate_json(line))

    print(f"\n✓ Validated {num_qa_pairs} Q&A pairs")
    print(f"✓ {len(passed_pairs)} passed validation")

    for i, qa in enumerate(passed_pairs[:3], 1):
        print(f"\n--- Passed Q&A {i} ---")
        print(f"Q: {qa.question}")
        print(f"A: {qa.answer[:150]}...")

    return len(passed_pairs)


async def trial_generate_conversations(num_valid_pairs: int):
    """Trial run of conversation generation using the production configuration."""
    print("\n" + "="*80)
    print("STAGE 3: CONVERSATION GENERATION (Trial with production configs)")
    print("="*80)

    from importlib import import_module
    stage3 = import_module('3_generate_conversations')

    # Run with trial parameters (uses STAGE_CONFIGS for max_concurrent)
    await stage3.main(
        qa_file="output/trial_qa_validated_passed.jsonl",
        output_file="output/trial_conversations_raw.jsonl",
    )

    # Load and show results
    conversations = []
    with open("output/trial_conversations_raw.jsonl", 'r') as f:
        for line in f:
            conversations.append(Conversation.model_validate_json(line))

    print(f"\n✓ Valid Q&A pairs available: {num_valid_pairs}")
    print(f"✓ Generated {len(conversations)} conversations")

    # Show first 2
    for i, conv in enumerate(conversations[:2], 1):
        print(f"\n--- Conversation {i} ---")
        print(f"Style: {conv.metadata.style}")
        print(f"Persona: {conv.metadata.user_persona}")
        print(f"Turns: {conv.metadata.num_turns}")
        print("\nMessages:")
        for msg in conv.messages:
            print(f" {msg.role.upper()}: {msg.content[:100]}...")

    return len(conversations)


async def trial_judge_conversations(num_conversations: int):
    """Trial run of judging all conversations."""
    print("\n" + "="*80)
    print("STAGE 4: JUDGING & SELECTION (Trial with all conversations)")
    print("="*80)

    from importlib import import_module
    stage4 = import_module('4_judge_and_save')

    # Judge all and save top K (uses STAGE_CONFIGS for max_concurrent)
    await stage4.main(
        input_file="output/trial_conversations_raw.jsonl",
        judged_output="output/trial_conversations_judged.jsonl",
        nanochat_output=TRIAL_FINAL_CONVERSATIONS,
    )

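    # judged_output keeps every conversation together with its judgment for
    # inspection; nanochat_output (TRIAL_FINAL_CONVERSATIONS) is assumed to hold
    # only the selected top-K conversations in nanochat's training format.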
    # Load and show results
    judged = []
    with open("output/trial_conversations_judged.jsonl", 'r') as f:
        for line in f:
            judged.append(JudgedConversation.model_validate_json(line))

    print(f"\n✓ Judged {len(judged)} conversations")

    # Guard against an empty judged file so the percentages below can't divide by zero
    if not judged:
        print("✗ No judged conversations found; skipping statistics")
        return 0

    # Show pass/fail statistics (bool-based system)
    total = len(judged)
    passing = sum(1 for jc in judged if jc.judgment.overall_pass)
    factual_pass = sum(1 for jc in judged if jc.judgment.factually_accurate)
    natural_pass = sum(1 for jc in judged if jc.judgment.natural_conversation)
    ontopic_pass = sum(1 for jc in judged if jc.judgment.on_topic)
    value_pass = sum(1 for jc in judged if jc.judgment.adds_value)

    print("\nQuality statistics:")
    print(f" Overall PASS (all 4 criteria): {passing}/{total} ({passing/total*100:.1f}%)")
    print("\nIndividual criteria:")
    print(f" Factually accurate  : {factual_pass}/{total} ({factual_pass/total*100:.1f}%)")
    print(f" Natural conversation: {natural_pass}/{total} ({natural_pass/total*100:.1f}%)")
    print(f" On topic            : {ontopic_pass}/{total} ({ontopic_pass/total*100:.1f}%)")
    print(f" Adds value          : {value_pass}/{total} ({value_pass/total*100:.1f}%)")

    # Show sample passing and failing conversations
    passing_convs = [jc for jc in judged if jc.judgment.overall_pass]
    failing_convs = [jc for jc in judged if not jc.judgment.overall_pass]

    if passing_convs:
        sample = passing_convs[0]
        print("\n--- Sample PASSING Conversation ---")
        print(f"Feedback: {sample.judgment.feedback}")

    if failing_convs:
        sample = failing_convs[0]
        print("\n--- Sample FAILING Conversation ---")
        print("Failed criteria: ", end="")
        failed = []
        if not sample.judgment.factually_accurate: failed.append("factual")
        if not sample.judgment.natural_conversation: failed.append("natural")
        if not sample.judgment.on_topic: failed.append("on-topic")
        if not sample.judgment.adds_value: failed.append("adds-value")
        print(", ".join(failed))
        print(f"Feedback: {sample.judgment.feedback}")
        if sample.judgment.issues:
            print(f"Issues: {', '.join(sample.judgment.issues)}")

    return len(judged)


def validate_output_formats():
    """Validate that output files match expected formats."""
    print("\n" + "="*80)
    print("VALIDATION: Checking output formats")
    print("="*80)

    checks = {
        "Q&A pairs JSONL": "output/trial_qa_pairs.jsonl",
        "Raw conversations JSONL": "output/trial_conversations_raw.jsonl",
        "Judged conversations JSONL": "output/trial_conversations_judged.jsonl",
        "NanoChat format JSONL": TRIAL_FINAL_CONVERSATIONS,
    }

    all_valid = True

    for name, path in checks.items():
        if not Path(path).exists():
            print(f"✗ {name}: FILE NOT FOUND")
            all_valid = False
            continue

        try:
            with open(path, 'r') as f:
                lines = f.readlines()
            if not lines:
                print(f"✗ {name}: EMPTY FILE")
                all_valid = False
                continue

            # Try parsing first line as JSON
            json.loads(lines[0])
            print(f"✓ {name}: Valid ({len(lines)} entries)")

        except Exception as e:
            print(f"✗ {name}: {e}")
            all_valid = False

    return all_valid


async def main():
    """Run the complete trial pipeline."""
    print("="*80)
    print("SYNTHETIC DATA PIPELINE - TRIAL RUN")
    print("="*80)
    print("\nThis will:")
    print(f" 1. Extract up to {TRIAL_QA_CHUNK_LIMIT} chunks worth of Q&A pairs")
    print(" 2. Validate those Q&A pairs with the production agent")
    print(" 3. Generate conversations using the same production configuration")
    print(" 4. Judge all conversations and select the configured top K")
    print()

    # Configure logfire without sending to cloud (for trial runs)
    logfire.configure(send_to_logfire=False)

    try:
        # Stage 1: Extract Q&A
        num_qa = await trial_extract_qa()

        # Stage 2: Validate Q&A
        num_valid = await trial_validate_qa(num_qa)

        # Stage 3: Generate conversations
        num_conv = await trial_generate_conversations(num_valid)

        # Stage 4: Judge and select
        num_judged = await trial_judge_conversations(num_conv)

        # Validate formats
        all_valid = validate_output_formats()

        # Final summary
        print("\n" + "="*80)
        print("TRIAL RUN COMPLETE")
        print("="*80)
        print(f"✓ Q&A pairs extracted: {num_qa}")
        print(f"✓ Q&A pairs passed validation: {num_valid}")
        print(f"✓ Conversations generated: {num_conv}")
        print(f"✓ Conversations judged: {num_judged}")
        print(f"✓ Output formats valid: {'YES' if all_valid else 'NO'}")
        print()

        if all_valid:
            print("🎉 Trial run successful! You can now run the full pipeline.")
            print()
            print("Next steps:")
            print(" 1. Review the trial outputs in the output/ directory")
            print(" 2. Adjust prompts if needed")
            print(" 3. Run full pipeline:")
            print("    - uv run 1_extract_qa.py")
            print("    - uv run 2_validate_qa.py")
            print("    - uv run 3_generate_conversations.py")
            print("    - uv run 4_judge_and_save.py")
        else:
            print("⚠️ Some validations failed. Please review the errors above.")

    except Exception as e:
        print(f"\n❌ Trial run failed: {e}")
        import traceback
        traceback.print_exc()


if __name__ == "__main__":
    asyncio.run(main())