"""
|
|
Stage 3: Generate conversations from validated Q&A pairs.
|
|
|
|
This script:
|
|
1. Loads validated Q&A pairs from output/qa_pairs_validated_passed.jsonl (or the provided path)
|
|
2. Samples different conversation configurations
|
|
3. Uses Gemini to generate natural conversations
|
|
4. Saves results to output/conversations_raw.jsonl
|
|
"""

import asyncio
import random

import logfire
from dotenv import load_dotenv

from src.synth_data_pipeline.agents import conversation_generator
from src.synth_data_pipeline.models import QAPair, Conversation
from src.synth_data_pipeline.config import (
    PATHS,
    STAGE_CONFIGS,
    FULL_PARAMS,
)
from src.synth_data_pipeline.sampling import (
    stratified_sample_configs,
    load_system_prompts_from_files,
    set_random_seed,
)
from src.synth_data_pipeline.utils import (
    load_jsonl,
    save_jsonl,
    process_with_concurrency,
)

# Load environment variables
load_dotenv()

# Configure Logfire logging/tracing and instrument the pydantic-ai agents
logfire.configure(scrubbing=False)
logfire.instrument_pydantic_ai()

# Get configuration for this stage
config = STAGE_CONFIGS["stage3_conversation_generation"]

# Load the conversation-generation prompt template and build the agent
conv_prompt_template = conversation_generator.get_prompt_template()
conv_agent = conversation_generator.build_agent(config)
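# Note: `config` is this stage's STAGE_CONFIGS entry; besides whatever the agent
# builder consumes (model settings, etc., assumed), it also supplies
# `config.max_concurrent`, which main() uses as the default concurrency below.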


async def generate_conversation(
    qa_pairs: list[QAPair],
    config_dict: dict,
) -> Conversation:
    """
    Generate a conversation from a single sampled configuration.

    Args:
        qa_pairs: All available Q&A pairs
        config_dict: Configuration dict with num_turns, style, persona, system_prompt

    Returns:
        Conversation object with the generated messages and its source Q&A pairs
    """
    # Sample Q&A pairs for this conversation
    num_turns = config_dict["num_turns"]

    # Filter Q&A by topic relevance to the system prompt (pragmatic keyword matching)
    system_prompt_name = (
        config_dict["system_prompt"].name
        if hasattr(config_dict["system_prompt"], "name")
        else "helpful"
    )

    # Simple persona matching: sales/solutions prompts shouldn't get governance/company info Q&A
    if "sales" in system_prompt_name.lower() or "solutions" in system_prompt_name.lower():
        # Prefer product/feature Q&A; avoid company registration/director questions
        avoid_keywords = ["appointed", "director", "registered", "incorporation", "companies house"]
        filtered_qa = [
            qa for qa in qa_pairs
            if not any(kw in qa.question.lower() for kw in avoid_keywords)
        ]
        qa_pool = filtered_qa if filtered_qa else qa_pairs
    else:
        qa_pool = qa_pairs

    # One source Q&A pair per intended turn (fewer if the pool is smaller than num_turns)
    sampled_qa = random.sample(qa_pool, min(num_turns, len(qa_pool)))

    # Format the sampled Q&A pairs for the prompt
    qa_text = "\n\n".join([
        f"Q: {qa.question}\nA: {qa.answer}\nCategories: {', '.join(qa.categories)}"
        for qa in sampled_qa
    ])

    # Get the system prompt text
    system_prompt = config_dict["system_prompt"].template

    # Fill in the conversation-generation prompt template
    prompt_text = conv_prompt_template.prompt.format(
        num_turns=num_turns,
        style=config_dict["style"],
        user_persona=config_dict["persona"].description,
        system_prompt=system_prompt,
        qa_pairs=qa_text,
    )

    # Generate the conversation using the agent
    result = await conv_agent.run(prompt_text)
    conversation = result.output

    # Attach the source Q&A pairs to the conversation for downstream fact-checking
    conversation.source_qa_pairs = sampled_qa

    return conversation


async def main(
    qa_file: str = None,
    output_file: str = None,
    num_conversations: int = None,
    max_concurrent: int = None,
):
    """
    Generate conversations from validated Q&A pairs and save them as JSONL.

    Args:
        qa_file: Path to Q&A pairs JSONL file (default from config)
        output_file: Path to output JSONL file (default from config)
        num_conversations: Number of conversations to generate (default from config)
        max_concurrent: Maximum concurrent API calls (default from config)
    """
    # Use defaults from config if not specified
    qa_file = qa_file or PATHS.stage2_qa_validated_passed
    output_file = output_file or PATHS.stage3_conversations_raw
    max_concurrent = max_concurrent or config.max_concurrent

    # Set the random seed if one is specified, for reproducible sampling
    if FULL_PARAMS.random_seed is not None:
        set_random_seed(FULL_PARAMS.random_seed)

    logfire.info("Starting conversation generation", qa_file=qa_file)

    # Load the validated Q&A pairs
    qa_pairs = load_jsonl(qa_file, model_class=QAPair)
    logfire.info(f"Loaded {len(qa_pairs)} Q&A pairs")

    # Determine how many conversations to generate based on the available Q&A pairs:
    # the requested count is capped at len(qa_pairs) * conversations_per_qa when that
    # ratio is configured
    requested_conversations = num_conversations or FULL_PARAMS.num_conversations
    auto_cap = 0
    if FULL_PARAMS.conversations_per_qa:
        auto_cap = len(qa_pairs) * FULL_PARAMS.conversations_per_qa

    target_conversations = requested_conversations
    if auto_cap:
        target_conversations = min(requested_conversations, auto_cap)
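    # Worked example (illustrative numbers): with 200 validated pairs and
    # conversations_per_qa == 2, auto_cap is 400, so a request for 1,000
    # conversations is reduced to 400; with conversations_per_qa unset, the
    # requested count is used as-is.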

    if target_conversations == 0:
        logfire.warning("No conversations generated because no Q&A pairs are available.")
        return

    logfire.info(
        "Conversation target determined",
        requested=requested_conversations,
        auto_cap=auto_cap,
        final=target_conversations,
    )

    # Load system prompt templates from files (for runtime flexibility)
    system_prompts_from_files = load_system_prompts_from_files()
    logfire.info(f"Loaded {len(system_prompts_from_files)} system prompts from files")

    # Sample conversation configurations
    configs = stratified_sample_configs(
        target_conversations,
        ensure_coverage=True,
    )
    logfire.info(f"Sampled {len(configs)} conversation configurations")

    # Generate conversations concurrently
    with logfire.span("generate_conversations"):
        # Closure binding qa_pairs, so each worker only needs the per-conversation config
        async def generate_fn(config_dict):
            return await generate_conversation(qa_pairs, config_dict)

        conversations = await process_with_concurrency(
            configs,
            generate_fn,
            max_concurrent=max_concurrent,
            desc="Generating conversations",
        )

    logfire.info(f"Generated {len(conversations)} conversations")

    # Save results
    save_jsonl(conversations, output_file)

    # Print sample for inspection
    if conversations:
        print("\n" + "=" * 80)
        print("SAMPLE CONVERSATION:")
        print("=" * 80)
        sample = conversations[0]
        for msg in sample.messages:
            print(f"{msg.role.upper()}: {msg.content[:200]}...")
            print()
        print(f"Metadata: {sample.metadata}")
        print("=" * 80 + "\n")


if __name__ == "__main__":
    asyncio.run(main())
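    # To override the config-driven defaults ad hoc, pass arguments explicitly,
    # e.g. (illustrative values): asyncio.run(main(num_conversations=50, max_concurrent=5))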