diff --git a/synth-data-pipeline/src/synth_data_pipeline/__init__.py b/synth-data-pipeline/src/synth_data_pipeline/__init__.py
index 55a66f2..33e8f50 100644
--- a/synth-data-pipeline/src/synth_data_pipeline/__init__.py
+++ b/synth-data-pipeline/src/synth_data_pipeline/__init__.py
@@ -2,7 +2,89 @@
 import textprompts
 
+# Core models
+from .models import (
+    QAPair,
+    QAPairBatch,
+    QAValidation,
+    ValidatedQAPair,
+    Message,
+    Conversation,
+    ConversationMetadata,
+    JudgmentScore,
+    JudgedConversation,
+    EmbeddedConversation,
+    UniqueConversation,
+    NanoChatMessage,
+    NanoChatConversation,
+)
+
+# Configuration
+from .config import (
+    APIConfig,
+    FilePaths,
+    PipelineParams,
+    Persona,
+    SystemPromptTemplate,
+    PATHS,
+    FULL_PARAMS,
+    STAGE_CONFIGS,
+)
+
+# Utilities
+from .utils import (
+    load_jsonl,
+    save_jsonl,
+    parse_markdown_chunks,
+    process_with_concurrency,
+    calculate_overall_score,
+    print_sample,
+    print_statistics,
+)
+
+# Agents (as submodule)
+from . import agents
+
 __version__ = "0.1.0"
 
 # Set strict metadata requirement for all prompts globally
-textprompts.set_metadata('strict')
+textprompts.set_metadata("strict")
+
+__all__ = [
+    # Version
+    "__version__",
+    # Models
+    "QAPair",
+    "QAPairBatch",
+    "QAValidation",
+    "ValidatedQAPair",
+    "Message",
+    "Conversation",
+    "ConversationMetadata",
+    "JudgmentScore",
+    "JudgedConversation",
+    "EmbeddedConversation",
+    "UniqueConversation",
+    "NanoChatMessage",
+    "NanoChatConversation",
+    # Config classes
+    "APIConfig",
+    "FilePaths",
+    "PipelineParams",
+    "Persona",
+    "SystemPromptTemplate",
+    # Config instances
+    "PATHS",
+    "FULL_PARAMS",
+    "STAGE_CONFIGS",
+    # Utils
+    "load_jsonl",
+    "save_jsonl",
+    "parse_markdown_chunks",
+    "process_with_concurrency",
+    "calculate_overall_score",
+    "print_sample",
+    "print_statistics",
+    # Submodules
+    "agents",
+]
diff --git a/synth-data-pipeline/src/synth_data_pipeline/agents/conversation_judge.py b/synth-data-pipeline/src/synth_data_pipeline/agents/conversation_judge.py
index 98e916f..c03b8bf 100644
--- a/synth-data-pipeline/src/synth_data_pipeline/agents/conversation_judge.py
+++ b/synth-data-pipeline/src/synth_data_pipeline/agents/conversation_judge.py
@@ -5,7 +5,9 @@
 from src.synth_data_pipeline.config import APIConfig
 from src.synth_data_pipeline.models import JudgmentScore
 
 PROMPT_NAME = "conversation_judge"
-SYSTEM_PROMPT = "You are an expert evaluator of training data quality for language models."
+SYSTEM_PROMPT = (
+    "You are an expert evaluator of training data quality for language models."
+)
 
 
 def build_agent(config: APIConfig, *, api_key: str | None = None):
diff --git a/synth-data-pipeline/src/synth_data_pipeline/config.py b/synth-data-pipeline/src/synth_data_pipeline/config.py
index dafc134..64e03da 100644
--- a/synth-data-pipeline/src/synth_data_pipeline/config.py
+++ b/synth-data-pipeline/src/synth_data_pipeline/config.py
@@ -18,7 +18,7 @@ DIFFICULTY_LEVELS = ["basic", "intermediate", "advanced"]
 DIFFICULTY_DESCRIPTIONS = {
     "basic": "Simple factual questions requiring basic recall",
     "intermediate": "Questions requiring understanding and reasoning",
-    "advanced": "Complex technical or multi-faceted questions"
+    "advanced": "Complex technical or multi-faceted questions",
 }
@@ -31,7 +31,7 @@ CONVERSATION_STYLES = ["formal", "casual", "technical"]
 STYLE_DESCRIPTIONS = {
     "formal": "Professional language, complete sentences, no slang",
     "casual": "Friendly tone, can use contractions, conversational",
-    "technical": "Uses technical terminology, assumes some expertise"
+    "technical": "Uses technical terminology, assumes some expertise",
 }
@@ -39,9 +39,11 @@ STYLE_DESCRIPTIONS = {
 # ============================================================================
 # User Personas
 # ============================================================================
+
 @dataclass
 class Persona:
     """Definition of a user persona."""
+
     name: str
     description: str
     typical_questions: List[str] = field(default_factory=list)
@@ -56,9 +58,9 @@ PERSONAS = {
             "API integration details",
             "Technical specifications",
             "SDK usage",
-            "Error handling"
+            "Error handling",
         ],
-        formality="technical"
+        formality="technical",
     ),
     "product_manager": Persona(
         name="product_manager",
@@ -67,9 +69,9 @@ PERSONAS = {
             "Feature comparisons",
             "Roadmap questions",
             "Use cases",
-            "ROI analysis"
+            "ROI analysis",
         ],
-        formality="formal"
+        formality="formal",
     ),
     "cs_agent": Persona(
         name="cs_agent",
@@ -78,9 +80,9 @@ PERSONAS = {
             "Setup instructions",
             "Troubleshooting",
             "Configuration options",
-            "Best practices"
+            "Best practices",
         ],
-        formality="neutral"
+        formality="neutral",
     ),
     "executive": Persona(
         name="executive",
@@ -89,9 +91,9 @@ PERSONAS = {
             "Business value",
             "Competitive advantages",
             "Pricing strategy",
-            "Scalability"
+            "Scalability",
         ],
-        formality="formal"
+        formality="formal",
     ),
     "operations": Persona(
         name="operations",
@@ -100,9 +102,9 @@ PERSONAS = {
             "Integration capabilities",
             "Workflow automation",
             "Performance metrics",
-            "SLA guarantees"
+            "SLA guarantees",
        ],
-        formality="neutral"
+        formality="neutral",
     ),
     "finance": Persona(
         name="finance",
@@ -111,9 +113,9 @@ PERSONAS = {
             "Tax compliance",
             "Financial reporting",
             "Audit trails",
-            "Cost structure"
+            "Cost structure",
         ],
-        formality="formal"
+        formality="formal",
     ),
 }
@@ -122,9 +124,11 @@ PERSONAS = {
 # ============================================================================
 # System Prompt Templates
 # ============================================================================
+
 @dataclass
 class SystemPromptTemplate:
     """Definition of a system prompt template."""
+
     name: str
     description: str
     template: str
@@ -138,35 +142,35 @@ SYSTEM_PROMPT_TEMPLATES = {
         description="Helpful and friendly assistant",
         template="You are a helpful AI assistant with expertise in SWAP Commerce's e-commerce platform and services. You provide accurate, friendly, and detailed answers to questions about SWAP Commerce's products, features, integrations, and pricing.",
         verbosity="detailed",
-        use_case="general"
+        use_case="general",
     ),
     "concise": SystemPromptTemplate(
         name="concise",
         description="Brief and to-the-point responses",
         template="You are a SWAP Commerce expert providing clear, concise answers. Focus on key information without unnecessary detail.",
         verbosity="concise",
-        use_case="quick_reference"
+        use_case="quick_reference",
     ),
     "technical": SystemPromptTemplate(
         name="technical",
         description="Technical expert for developers",
         template="You are a technical expert on SWAP Commerce's platform. You provide detailed technical information about APIs, integrations, implementation, and system architecture. You assume the user has technical knowledge.",
         verbosity="detailed",
-        use_case="developer"
+        use_case="developer",
     ),
     "detailed": SystemPromptTemplate(
         name="detailed",
         description="Comprehensive explanations",
         template="You are a comprehensive SWAP Commerce expert who provides thorough, well-explained answers with examples, context, and relevant details. You ensure users fully understand the topic.",
         verbosity="detailed",
-        use_case="learning"
+        use_case="learning",
     ),
     "sales": SystemPromptTemplate(
         name="sales",
         description="Sales and solutions focused",
         template="You are a SWAP Commerce solutions consultant helping potential customers understand how SWAP Commerce can solve their e-commerce challenges. You're knowledgeable about features, benefits, and competitive advantages.",
         verbosity="balanced",
-        use_case="sales"
+        use_case="sales",
     ),
 }
@@ -211,7 +215,7 @@ DEFAULT_TURN_DISTRIBUTION = {
 TURN_DISTRIBUTIONS = {
     "default": DEFAULT_TURN_DISTRIBUTION,
     "short": {1: 0.6, 2: 0.3, 3: 0.1},  # Mostly short conversations
-    "long": {2: 0.2, 3: 0.4, 4: 0.4},   # Longer conversations
+    "long": {2: 0.2, 3: 0.4, 4: 0.4},  # Longer conversations
     "balanced": {1: 0.25, 2: 0.25, 3: 0.25, 4: 0.25},  # Equal distribution
 }
@@ -224,10 +228,10 @@ USER_EMOTIONS = ["professional", "happy", "frustrated", "impatient", "confused"]
 USER_EMOTION_DISTRIBUTION = {
     "professional": 0.50,  # Most common - neutral, business-like
-    "happy": 0.15,         # Positive, enthusiastic
-    "frustrated": 0.15,    # Having issues, needs help
-    "impatient": 0.10,     # Wants quick answers
-    "confused": 0.10,      # Unclear about something
+    "happy": 0.15,  # Positive, enthusiastic
+    "frustrated": 0.15,  # Having issues, needs help
+    "impatient": 0.10,  # Wants quick answers
+    "confused": 0.10,  # Unclear about something
 }
@@ -235,7 +239,7 @@ EMOTION_DESCRIPTIONS = {
     "happy": "Positive, enthusiastic. May express excitement about features or capabilities.",
     "frustrated": "Experiencing issues or challenges. May express mild annoyance or urgency.",
     "impatient": "Wants quick, direct answers. Brief messages, may skip pleasantries.",
-    "confused": "Unclear about concepts or features. May ask for clarification or examples."
+    "confused": "Unclear about concepts or features. May ask for clarification or examples.",
 }
@@ -246,15 +250,15 @@ INPUT_MODALITIES = ["standard", "typed_on_phone", "voice_dictated"]
 INPUT_MODALITY_DISTRIBUTION = {
-    "standard": 0.70,            # Normal typing on computer
-    "typed_on_phone": 0.20,      # Mobile typing - autocorrect errors, brevity
-    "voice_dictated": 0.10,      # Voice-to-text - filler words, natural speech
+    "standard": 0.70,  # Normal typing on computer
+    "typed_on_phone": 0.20,  # Mobile typing - autocorrect errors, brevity
+    "voice_dictated": 0.10,  # Voice-to-text - filler words, natural speech
 }
 
 MODALITY_DESCRIPTIONS = {
     "standard": "Standard computer typing. Clean text, proper formatting.",
     "typed_on_phone": "Mobile device typing. May have autocorrect errors, abbreviations, shorter messages.",
-    "voice_dictated": "Voice-to-text transcription. May include 'um', 'uh', natural speech patterns, occasional transcription errors."
+    "voice_dictated": "Voice-to-text transcription. May include 'um', 'uh', natural speech patterns, occasional transcription errors.",
 }
@@ -265,15 +269,15 @@ MODALITY_DESCRIPTIONS = {
 TEXT_VARIATIONS = ["standard", "all_lowercase", "no_punctuation"]
 
 TEXT_VARIATION_DISTRIBUTION = {
-    "standard": 0.80,            # Normal capitalization and punctuation
-    "all_lowercase": 0.15,       # all lowercase (casual/mobile)
-    "no_punctuation": 0.05,      # missing punctuation (rushed/mobile)
+    "standard": 0.80,  # Normal capitalization and punctuation
+    "all_lowercase": 0.15,  # all lowercase (casual/mobile)
+    "no_punctuation": 0.05,  # missing punctuation (rushed/mobile)
 }
 
 VARIATION_DESCRIPTIONS = {
     "standard": "Standard capitalization and punctuation.",
     "all_lowercase": "All lowercase letters (common in casual or mobile communication).",
-    "no_punctuation": "Missing or minimal punctuation (rushed typing or informal style)."
+    "no_punctuation": "Missing or minimal punctuation (rushed typing or informal style).",
 }
@@ -293,9 +297,11 @@ QUALITY_WEIGHTS = {
 # ============================================================================
 # API Configuration
 # ============================================================================
+
 @dataclass
 class APIConfig:
     """Configuration for API calls."""
+
     model: str = "gemini-2.5-flash-lite"
     max_concurrent: int = 10
     temperature: float = 0.9  # Higher for generation, lower for judging
@@ -337,9 +343,11 @@ STAGE_CONFIGS = {
 # ============================================================================
 # File Paths
 # ============================================================================
+
 @dataclass
 class FilePaths:
     """Standard file paths for the pipeline."""
+
     data_dir: str = "data"
     prompts_dir: str = "prompts"
     output_dir: str = "output"
@@ -374,6 +382,7 @@ PATHS = FilePaths()
 # ============================================================================
 # Pipeline Parameters
 # ============================================================================
+
 @dataclass
 class PipelineParams:
     """Parameters for the full pipeline run."""
@@ -389,10 +398,18 @@ class PipelineParams:
     # Stage 3: Conversation Generation
     num_conversations: int = 2000
     conversations_per_qa: int = 10
-    turn_distribution: Dict[int, float] = field(default_factory=lambda: DEFAULT_TURN_DISTRIBUTION)
-    emotion_distribution: Dict[str, float] = field(default_factory=lambda: USER_EMOTION_DISTRIBUTION)
-    modality_distribution: Dict[str, float] = field(default_factory=lambda: INPUT_MODALITY_DISTRIBUTION)
-    variation_distribution: Dict[str, float] = field(default_factory=lambda: TEXT_VARIATION_DISTRIBUTION)
+    turn_distribution: Dict[int, float] = field(
+        default_factory=lambda: DEFAULT_TURN_DISTRIBUTION
+    )
+    emotion_distribution: Dict[str, float] = field(
+        default_factory=lambda: USER_EMOTION_DISTRIBUTION
+    )
+    modality_distribution: Dict[str, float] = field(
+        default_factory=lambda: INPUT_MODALITY_DISTRIBUTION
+    )
+    variation_distribution: Dict[str, float] = field(
+        default_factory=lambda: TEXT_VARIATION_DISTRIBUTION
+    )
 
     # Stage 4: Judging
     min_quality_score: float = 5.0  # Minimum acceptable score
diff --git a/synth-data-pipeline/src/synth_data_pipeline/embedding_utils.py b/synth-data-pipeline/src/synth_data_pipeline/embedding_utils.py
index 6d59db2..5305479 100644
--- a/synth-data-pipeline/src/synth_data_pipeline/embedding_utils.py
+++ b/synth-data-pipeline/src/synth_data_pipeline/embedding_utils.py
@@ -20,7 +20,7 @@ async def batch_embed(
     model: str = "text-embedding-3-small",
     dimensions: int = 1024,
     batch_size: int = 100,
-    max_concurrent: int = 20
+    max_concurrent: int = 20,
 ) -> List[np.ndarray]:
     """
     Generate embeddings for a list of texts using OpenAI API.
@@ -39,7 +39,7 @@ async def batch_embed(
     logfire.info(f"Embedding {len(texts)} texts in batches of {batch_size}")
 
     # Split into batches
-    batches = [texts[i:i + batch_size] for i in range(0, len(texts), batch_size)]
+    batches = [texts[i : i + batch_size] for i in range(0, len(texts), batch_size)]
 
     # Semaphore for rate limiting
     semaphore = asyncio.Semaphore(max_concurrent)
@@ -49,9 +49,7 @@ async def batch_embed(
         async with semaphore:
             try:
                 response = await client.embeddings.create(
-                    model=model,
-                    input=batch,
-                    dimensions=dimensions
+                    model=model, input=batch, dimensions=dimensions
                 )
                 return [item.embedding for item in response.data]
             except Exception as e:
@@ -112,7 +110,7 @@ def compute_similarity(emb1: np.ndarray, emb2: np.ndarray) -> float:
 def greedy_deduplicate(
     embeddings: List[np.ndarray],
     scores: List[float],
-    similarity_threshold: float = 0.95
+    similarity_threshold: float = 0.95,
 ) -> List[int]:
     """
     Greedy deduplication: keep highest-scoring items, remove similar duplicates.
@@ -170,8 +168,8 @@ def conversation_to_text(messages: List[dict], max_chars: int = 24000) -> str:
     """
     parts = []
     for msg in messages:
-        role = msg.get('role', 'unknown').upper()
-        content = msg.get('content', '')
+        role = msg.get("role", "unknown").upper()
+        content = msg.get("content", "")
         parts.append(f"{role}: {content}")
 
     full_text = "\n\n".join(parts)
diff --git a/synth-data-pipeline/src/synth_data_pipeline/models.py b/synth-data-pipeline/src/synth_data_pipeline/models.py
index 58c7ff4..ee8d13d 100644
--- a/synth-data-pipeline/src/synth_data_pipeline/models.py
+++ b/synth-data-pipeline/src/synth_data_pipeline/models.py
@@ -17,26 +17,18 @@ class QAPair(BaseModel):
     question: str = Field(
         description="A natural question that could be asked about this topic"
     )
-    answer: str = Field(
-        description="The accurate answer grounded in the source text"
-    )
+    answer: str = Field(description="The accurate answer grounded in the source text")
     source_text: str = Field(
         description="The specific text chunk this Q&A was generated from"
     )
-    context_before: str = Field(
-        default="",
-        description="Preceding lines for context"
-    )
-    context_after: str = Field(
-        default="",
-        description="Following lines for context"
-    )
+    context_before: str = Field(default="", description="Preceding lines for context")
+    context_after: str = Field(default="", description="Following lines for context")
     difficulty: Literal["basic", "intermediate", "advanced"] = Field(
         description="The difficulty level of this question"
     )
     categories: list[str] = Field(
         default_factory=list,
-        description="Topic categories (e.g., 'pricing', 'features', 'integrations')"
+        description="Topic categories (e.g., 'pricing', 'features', 'integrations')",
     )
@@ -65,23 +57,15 @@ class QAValidation(BaseModel):
     sensible_answer: bool = Field(
         description="Is the answer appropriate and sensible for the question?"
     )
-    passed: bool = Field(
-        description="Overall pass (all three bools must be True)"
-    )
-    feedback: str = Field(
-        description="Brief explanation of validation result"
-    )
+    passed: bool = Field(description="Overall pass (all three bools must be True)")
+    feedback: str = Field(description="Brief explanation of validation result")
 
 
 class ValidatedQAPair(BaseModel):
     """A Q&A pair with its validation result."""
 
-    qa_pair: QAPair = Field(
-        description="The Q&A pair being validated"
-    )
-    validation: QAValidation = Field(
-        description="The validation result"
-    )
+    qa_pair: QAPair = Field(description="The Q&A pair being validated")
+    validation: QAValidation = Field(description="The validation result")
 
 
 # ============================================================================
@@ -95,9 +79,7 @@ class Message(BaseModel):
     role: Literal["system", "user", "assistant"] = Field(
         description="The role of the message sender"
     )
-    content: str = Field(
-        description="The message content"
-    )
+    content: str = Field(description="The message content")
 
 
 class ConversationMetadata(BaseModel):
@@ -112,28 +94,24 @@ class ConversationMetadata(BaseModel):
     user_persona: str = Field(
         description="The persona/role of the user (e.g., 'developer', 'business owner')"
     )
-    user_emotion: Literal["professional", "happy", "frustrated", "impatient", "confused"] = Field(
-        default="professional",
-        description="The emotional state of the user"
-    )
+    user_emotion: Literal[
+        "professional", "happy", "frustrated", "impatient", "confused"
+    ] = Field(default="professional", description="The emotional state of the user")
     input_modality: Literal["standard", "typed_on_phone", "voice_dictated"] = Field(
-        default="standard",
-        description="How the user is inputting their messages"
+        default="standard", description="How the user is inputting their messages"
     )
     text_variation: Literal["standard", "all_lowercase", "no_punctuation"] = Field(
         default="standard",
-        description="Text formatting variation applied to user messages"
+        description="Text formatting variation applied to user messages",
     )
     source_qa_ids: list[int] = Field(
         default_factory=list,
-        description="Indices of Q&A pairs used to generate this conversation"
-    )
-    difficulty: str = Field(
-        description="Overall difficulty level"
+        description="Indices of Q&A pairs used to generate this conversation",
     )
+    difficulty: str = Field(description="Overall difficulty level")
     categories: list[str] = Field(
         default_factory=list,
-        description="Topic categories covered in this conversation"
+        description="Topic categories covered in this conversation",
     )
@@ -148,7 +126,7 @@ class Conversation(BaseModel):
     )
     source_qa_pairs: list[QAPair] = Field(
         default_factory=list,
-        description="The Q&A pairs used to generate this conversation (for fact-checking)"
+        description="The Q&A pairs used to generate this conversation (for fact-checking)",
     )
@@ -175,24 +153,17 @@ class JudgmentScore(BaseModel):
     overall_pass: bool = Field(
         description="TRUE only if ALL four criteria above are TRUE"
     )
-    feedback: str = Field(
-        description="Brief explanation of judgment (1-2 sentences)"
-    )
+    feedback: str = Field(description="Brief explanation of judgment (1-2 sentences)")
     issues: list[str] = Field(
-        default_factory=list,
-        description="Specific problems found (if any)"
+        default_factory=list, description="Specific problems found (if any)"
     )
 
 
 class JudgedConversation(BaseModel):
     """A conversation with its quality judgment."""
 
-    conversation: Conversation = Field(
-        description="The conversation being judged"
-    )
-    judgment: JudgmentScore = Field(
-        description="The quality judgment scores"
-    )
+    conversation: Conversation = Field(description="The conversation being judged")
+    judgment: JudgmentScore = Field(description="The quality judgment scores")
 
 
 # ============================================================================
@@ -203,18 +174,12 @@ class JudgedConversation(BaseModel):
 class EmbeddedConversation(BaseModel):
     """A judged conversation with its embedding."""
 
-    conversation: Conversation = Field(
-        description="The conversation"
-    )
-    judgment: JudgmentScore = Field(
-        description="The quality judgment"
-    )
+    conversation: Conversation = Field(description="The conversation")
+    judgment: JudgmentScore = Field(description="The quality judgment")
     embedding: list[float] = Field(
         description="Conversation embedding (1024 dimensions)"
     )
-    text_preview: str = Field(
-        description="First 200 characters for debugging"
-    )
+    text_preview: str = Field(description="First 200 characters for debugging")
 
 
 # ============================================================================
@@ -225,12 +190,8 @@ class EmbeddedConversation(BaseModel):
 class UniqueConversation(BaseModel):
     """A conversation marked as unique after deduplication."""
 
-    conversation: Conversation = Field(
-        description="The conversation"
-    )
-    judgment: JudgmentScore = Field(
-        description="The quality judgment"
-    )
+    conversation: Conversation = Field(description="The conversation")
+    judgment: JudgmentScore = Field(description="The quality judgment")
 
     # Note: embedding removed to save space after dedup
diff --git a/synth-data-pipeline/src/synth_data_pipeline/sampling.py b/synth-data-pipeline/src/synth_data_pipeline/sampling.py
index 9eebef9..7acd0d0 100644
--- a/synth-data-pipeline/src/synth_data_pipeline/sampling.py
+++ b/synth-data-pipeline/src/synth_data_pipeline/sampling.py
@@ -11,11 +11,8 @@ from .config import (
     SYSTEM_PROMPT_TEMPLATES,
     CONVERSATION_STYLES,
     DEFAULT_TURN_DISTRIBUTION,
-    USER_EMOTIONS,
     USER_EMOTION_DISTRIBUTION,
-    INPUT_MODALITIES,
     INPUT_MODALITY_DISTRIBUTION,
-    TEXT_VARIATIONS,
     TEXT_VARIATION_DISTRIBUTION,
     Persona,
     SystemPromptTemplate,
@@ -137,10 +134,7 @@ def sample_system_prompt_by_use_case(use_case: str) -> SystemPromptTemplate:
     Returns:
         A SystemPromptTemplate with matching use case
     """
-    matching = [
-        s for s in SYSTEM_PROMPT_TEMPLATES.values()
-        if s.use_case == use_case
-    ]
+    matching = [s for s in SYSTEM_PROMPT_TEMPLATES.values() if s.use_case == use_case]
     return random.choice(matching) if matching else sample_system_prompt()
@@ -172,7 +166,9 @@ def sample_balanced_config(
     # Sample persona
     if prefer_technical:
-        persona = PERSONAS.get("developer") if random.random() < 0.5 else sample_persona()
+        persona = (
+            PERSONAS.get("developer") if random.random() < 0.5 else sample_persona()
+        )
     else:
         persona = sample_persona()
@@ -193,7 +189,9 @@ def sample_balanced_config(
     }
 
 
-def load_system_prompts_from_files(prompts_dir: str = "src/synth_data_pipeline/prompts/system_prompts") -> Dict[str, str]:
+def load_system_prompts_from_files(
+    prompts_dir: str = "src/synth_data_pipeline/prompts/system_prompts",
+) -> Dict[str, str]:
     """
     Load system prompt templates from text files.
@@ -208,13 +206,15 @@
     for prompt_file in prompts_path.glob("*.txt"):
         name = prompt_file.stem
-        with open(prompt_file, 'r', encoding='utf-8') as f:
+        with open(prompt_file, "r", encoding="utf-8") as f:
             system_prompts[name] = f.read().strip()
 
     return system_prompts
 
 
-def load_personas_from_files(personas_dir: str = "src/synth_data_pipeline/prompts/personas") -> Dict[str, str]:
+def load_personas_from_files(
+    personas_dir: str = "src/synth_data_pipeline/prompts/personas",
+) -> Dict[str, str]:
     """
     Load persona descriptions from text files.
@@ -229,7 +229,7 @@ def load_personas_from_files(personas_dir: str = "src/synth_data_pipeline/prompt
     for persona_file in personas_path.glob("*.txt"):
         name = persona_file.stem
-        with open(persona_file, 'r', encoding='utf-8') as f:
+        with open(persona_file, "r", encoding="utf-8") as f:
             personas[name] = f.read().strip()
 
     return personas
@@ -272,10 +272,8 @@ def sample_multiple_configs(
 # Sampling Strategies
 # ============================================================================
 
-def stratified_sample_configs(
-    n: int,
-    ensure_coverage: bool = True
-) -> List[Dict]:
+
+def stratified_sample_configs(n: int, ensure_coverage: bool = True) -> List[Dict]:
     """
     Sample configurations with stratified sampling to ensure diversity.
@@ -292,15 +290,17 @@ def stratified_sample_configs(n: int, ensure_coverage: bool = True) -> List[Dict
     # First, ensure we have at least one of each persona-style combination
     for persona in PERSONAS.values():
         for style in CONVERSATION_STYLES:
-            configs.append({
-                "num_turns": sample_num_turns(),
-                "style": style,
-                "persona": persona,
-                "system_prompt": sample_system_prompt(),
-                "user_emotion": sample_emotion(),
-                "input_modality": sample_input_modality(),
-                "text_variation": sample_text_variation(),
-            })
+            configs.append(
+                {
+                    "num_turns": sample_num_turns(),
+                    "style": style,
+                    "persona": persona,
+                    "system_prompt": sample_system_prompt(),
+                    "user_emotion": sample_emotion(),
+                    "input_modality": sample_input_modality(),
+                    "text_variation": sample_text_variation(),
+                }
+            )
 
     # Fill remaining with random samples
     remaining = n - len(configs)
diff --git a/synth-data-pipeline/src/synth_data_pipeline/utils.py b/synth-data-pipeline/src/synth_data_pipeline/utils.py
index 15c15c9..dd60b1f 100644
--- a/synth-data-pipeline/src/synth_data_pipeline/utils.py
+++ b/synth-data-pipeline/src/synth_data_pipeline/utils.py
@@ -10,15 +10,15 @@ from typing import List, TypeVar, Callable, Awaitable
 import logfire
 from tqdm.asyncio import tqdm_asyncio
 
-T = TypeVar('T')
-R = TypeVar('R')
+T = TypeVar("T")
+R = TypeVar("R")
 
 
 async def process_with_concurrency(
     items: List[T],
     process_fn: Callable[[T], Awaitable[R]],
     max_concurrent: int = 10,
-    desc: str = "Processing"
+    desc: str = "Processing",
 ) -> List[R]:
     """
     Process items concurrently with a semaphore to limit concurrency.
@@ -61,12 +61,12 @@ def save_jsonl(items: List, output_path: str | Path):
     output_path = Path(output_path)
     output_path.parent.mkdir(parents=True, exist_ok=True)
 
-    with open(output_path, 'w', encoding='utf-8') as f:
+    with open(output_path, "w", encoding="utf-8") as f:
         for item in items:
-            if hasattr(item, 'model_dump_json'):
-                f.write(item.model_dump_json() + '\n')
+            if hasattr(item, "model_dump_json"):
+                f.write(item.model_dump_json() + "\n")
             else:
-                f.write(json.dumps(item) + '\n')
+                f.write(json.dumps(item) + "\n")
 
     logfire.info(f"Saved {len(items)} items to {output_path}")
@@ -83,7 +83,7 @@ def load_jsonl(file_path: str | Path, model_class=None) -> List:
         List of items
     """
     items = []
-    with open(file_path, 'r', encoding='utf-8') as f:
+    with open(file_path, "r", encoding="utf-8") as f:
         for line in f:
             if model_class:
                 items.append(model_class.model_validate_json(line))
@@ -92,10 +92,7 @@ def load_jsonl(file_path: str | Path, model_class=None) -> List:
     return items
 
 
-def parse_markdown_chunks(
-    file_path: str | Path,
-    context_lines: int = 3
-) -> List[dict]:
+def parse_markdown_chunks(file_path: str | Path, context_lines: int = 3) -> List[dict]:
     """
     Parse markdown file and create chunks with context.
@@ -106,7 +103,7 @@
     Returns:
         List of dicts with 'source_text', 'context_before', 'context_after'
     """
-    with open(file_path, 'r', encoding='utf-8') as f:
+    with open(file_path, "r", encoding="utf-8") as f:
         lines = f.readlines()
 
     chunks = []
@@ -116,28 +113,30 @@
         line = lines[i].strip()
 
         # Skip empty lines and metadata
-        if not line or line.startswith('---') or line.startswith('**As of:**'):
+        if not line or line.startswith("---") or line.startswith("**As of:**"):
             i += 1
             continue
 
         # Process bullets and significant lines
-        if line.startswith('*') or line.startswith('#') or len(line) > 50:
+        if line.startswith("*") or line.startswith("#") or len(line) > 50:
             # Get context before
             context_start = max(0, i - context_lines)
-            context_before = ''.join(lines[context_start:i]).strip()
+            context_before = "".join(lines[context_start:i]).strip()
 
             # Get the main text (current line)
             source_text = line
 
             # Get context after
             context_end = min(len(lines), i + context_lines + 1)
-            context_after = ''.join(lines[i + 1:context_end]).strip()
+            context_after = "".join(lines[i + 1 : context_end]).strip()
 
-            chunks.append({
-                'source_text': source_text,
-                'context_before': context_before,
-                'context_after': context_after,
-            })
+            chunks.append(
+                {
+                    "source_text": source_text,
+                    "context_before": context_before,
+                    "context_after": context_after,
+                }
+            )
 
         i += 1
@@ -149,7 +148,7 @@ def calculate_overall_score(
     naturalness: float,
     relevance: float,
     diversity: float,
-    weights: dict = None
+    weights: dict = None,
 ) -> float:
     """
     Calculate overall quality score from individual metrics.
@@ -166,13 +165,14 @@ def calculate_overall_score(
     """
     if weights is None:
         from .config import QUALITY_WEIGHTS
+
         weights = QUALITY_WEIGHTS
 
     overall = (
-        factual_accuracy * weights.get("factual_accuracy", 0.35) +
-        naturalness * weights.get("naturalness", 0.25) +
-        relevance * weights.get("relevance", 0.25) +
-        diversity * weights.get("diversity", 0.15)
+        factual_accuracy * weights.get("factual_accuracy", 0.35)
+        + naturalness * weights.get("naturalness", 0.25)
+        + relevance * weights.get("relevance", 0.25)
+        + diversity * weights.get("diversity", 0.15)
     )
 
     return round(overall, 2)
@@ -186,11 +186,11 @@ def print_sample(item, title: str = "SAMPLE"):
         item: Item to print (conversation, Q&A, etc.)
         title: Title for the sample section
     """
-    print("\n" + "="*80)
+    print("\n" + "=" * 80)
     print(title)
-    print("="*80)
+    print("=" * 80)
 
-    if hasattr(item, 'model_dump'):
+    if hasattr(item, "model_dump"):
         # Pydantic model
         print(json.dumps(item.model_dump(), indent=2))
     elif isinstance(item, dict):
@@ -198,7 +198,7 @@ def print_sample(item, title: str = "SAMPLE"):
     else:
         print(item)
 
-    print("="*80 + "\n")
+    print("=" * 80 + "\n")
 
 
 def print_statistics(scores: List[float], metric_name: str = "Score"):