Mirror of https://github.com/karpathy/nanochat.git, synced 2025-12-06 04:12:13 +00:00

Commit: 7eac69487b ("formatting")
Parent: 6bfc1f8f53

@@ -2,7 +2,89 @@
 import textprompts
+
+# Core models
+from .models import (
+    QAPair,
+    QAPairBatch,
+    QAValidation,
+    ValidatedQAPair,
+    Message,
+    Conversation,
+    ConversationMetadata,
+    JudgmentScore,
+    JudgedConversation,
+    EmbeddedConversation,
+    UniqueConversation,
+    NanoChatMessage,
+    NanoChatConversation,
+)
+
+# Configuration
+from .config import (
+    APIConfig,
+    FilePaths,
+    PipelineParams,
+    Persona,
+    SystemPromptTemplate,
+    PATHS,
+    FULL_PARAMS,
+    STAGE_CONFIGS,
+)
+
+# Utilities
+from .utils import (
+    load_jsonl,
+    save_jsonl,
+    parse_markdown_chunks,
+    process_with_concurrency,
+    calculate_overall_score,
+    print_sample,
+    print_statistics,
+)
+
+# Agents (as submodule)
+from . import agents
+
+__version__ = "0.1.0"
 
 # Set strict metadata requirement for all prompts globally
-textprompts.set_metadata('strict')
+textprompts.set_metadata("strict")
+
+__all__ = [
+    # Version
+    "__version__",
+    # Models
+    "QAPair",
+    "QAPairBatch",
+    "QAValidation",
+    "ValidatedQAPair",
+    "Message",
+    "Conversation",
+    "ConversationMetadata",
+    "JudgmentScore",
+    "JudgedConversation",
+    "EmbeddedConversation",
+    "UniqueConversation",
+    "NanoChatMessage",
+    "NanoChatConversation",
+    # Config classes
+    "APIConfig",
+    "FilePaths",
+    "PipelineParams",
+    "Persona",
+    "SystemPromptTemplate",
+    # Config instances
+    "PATHS",
+    "FULL_PARAMS",
+    "STAGE_CONFIGS",
+    # Utils
+    "load_jsonl",
+    "save_jsonl",
+    "parse_markdown_chunks",
+    "process_with_concurrency",
+    "calculate_overall_score",
+    "print_sample",
+    "print_statistics",
+    # Submodules
+    "agents",
+]
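The hunk above turns the package __init__ into the public API surface, re-exporting the models, config, and utilities and pinning them in __all__. A usage sketch (the src.synth_data_pipeline import path is assumed from imports seen later in this diff; the field values are hypothetical):

    from src.synth_data_pipeline import QAPair, PATHS, load_jsonl

    qa = QAPair(
        question="What integrations does the platform support?",  # hypothetical
        answer="See the integrations page for the current list.",  # hypothetical
        source_text="(chunk from the source document)",
        difficulty="basic",
    )
    print(PATHS.data_dir)  # "data", per the FilePaths defaults in config.py
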
@@ -5,7 +5,9 @@ from src.synth_data_pipeline.config import APIConfig
 from src.synth_data_pipeline.models import JudgmentScore
 
 PROMPT_NAME = "conversation_judge"
-SYSTEM_PROMPT = "You are an expert evaluator of training data quality for language models."
+SYSTEM_PROMPT = (
+    "You are an expert evaluator of training data quality for language models."
+)
 
 
 def build_agent(config: APIConfig, *, api_key: str | None = None):

@@ -18,7 +18,7 @@ DIFFICULTY_LEVELS = ["basic", "intermediate", "advanced"]
 DIFFICULTY_DESCRIPTIONS = {
     "basic": "Simple factual questions requiring basic recall",
     "intermediate": "Questions requiring understanding and reasoning",
-    "advanced": "Complex technical or multi-faceted questions"
+    "advanced": "Complex technical or multi-faceted questions",
 }
 
 
@@ -31,7 +31,7 @@ CONVERSATION_STYLES = ["formal", "casual", "technical"]
 STYLE_DESCRIPTIONS = {
     "formal": "Professional language, complete sentences, no slang",
     "casual": "Friendly tone, can use contractions, conversational",
-    "technical": "Uses technical terminology, assumes some expertise"
+    "technical": "Uses technical terminology, assumes some expertise",
 }
 
 
@@ -39,9 +39,11 @@ STYLE_DESCRIPTIONS = {
 # User Personas
 # ============================================================================
 
+
 @dataclass
 class Persona:
     """Definition of a user persona."""
+
     name: str
     description: str
     typical_questions: List[str] = field(default_factory=list)
@@ -56,9 +58,9 @@ PERSONAS = {
             "API integration details",
             "Technical specifications",
             "SDK usage",
-            "Error handling"
+            "Error handling",
         ],
-        formality="technical"
+        formality="technical",
     ),
     "product_manager": Persona(
         name="product_manager",
@@ -67,9 +69,9 @@ PERSONAS = {
             "Feature comparisons",
             "Roadmap questions",
             "Use cases",
-            "ROI analysis"
+            "ROI analysis",
         ],
-        formality="formal"
+        formality="formal",
     ),
     "cs_agent": Persona(
         name="cs_agent",
@@ -78,9 +80,9 @@ PERSONAS = {
             "Setup instructions",
             "Troubleshooting",
             "Configuration options",
-            "Best practices"
+            "Best practices",
         ],
-        formality="neutral"
+        formality="neutral",
     ),
     "executive": Persona(
         name="executive",
@@ -89,9 +91,9 @@ PERSONAS = {
             "Business value",
             "Competitive advantages",
             "Pricing strategy",
-            "Scalability"
+            "Scalability",
         ],
-        formality="formal"
+        formality="formal",
     ),
     "operations": Persona(
         name="operations",
@@ -100,9 +102,9 @@ PERSONAS = {
             "Integration capabilities",
             "Workflow automation",
             "Performance metrics",
-            "SLA guarantees"
+            "SLA guarantees",
         ],
-        formality="neutral"
+        formality="neutral",
     ),
     "finance": Persona(
         name="finance",
@@ -111,9 +113,9 @@ PERSONAS = {
             "Tax compliance",
             "Financial reporting",
             "Audit trails",
-            "Cost structure"
+            "Cost structure",
         ],
-        formality="formal"
+        formality="formal",
     ),
 }
 
@@ -122,9 +124,11 @@ PERSONAS = {
 # System Prompt Templates
 # ============================================================================
 
+
 @dataclass
 class SystemPromptTemplate:
     """Definition of a system prompt template."""
+
     name: str
     description: str
     template: str
@@ -138,35 +142,35 @@ SYSTEM_PROMPT_TEMPLATES = {
         description="Helpful and friendly assistant",
         template="You are a helpful AI assistant with expertise in SWAP Commerce's e-commerce platform and services. You provide accurate, friendly, and detailed answers to questions about SWAP Commerce's products, features, integrations, and pricing.",
         verbosity="detailed",
-        use_case="general"
+        use_case="general",
     ),
     "concise": SystemPromptTemplate(
         name="concise",
         description="Brief and to-the-point responses",
         template="You are a SWAP Commerce expert providing clear, concise answers. Focus on key information without unnecessary detail.",
         verbosity="concise",
-        use_case="quick_reference"
+        use_case="quick_reference",
     ),
     "technical": SystemPromptTemplate(
         name="technical",
         description="Technical expert for developers",
         template="You are a technical expert on SWAP Commerce's platform. You provide detailed technical information about APIs, integrations, implementation, and system architecture. You assume the user has technical knowledge.",
         verbosity="detailed",
-        use_case="developer"
+        use_case="developer",
     ),
     "detailed": SystemPromptTemplate(
         name="detailed",
         description="Comprehensive explanations",
         template="You are a comprehensive SWAP Commerce expert who provides thorough, well-explained answers with examples, context, and relevant details. You ensure users fully understand the topic.",
         verbosity="detailed",
-        use_case="learning"
+        use_case="learning",
     ),
     "sales": SystemPromptTemplate(
         name="sales",
         description="Sales and solutions focused",
         template="You are a SWAP Commerce solutions consultant helping potential customers understand how SWAP Commerce can solve their e-commerce challenges. You're knowledgeable about features, benefits, and competitive advantages.",
         verbosity="balanced",
-        use_case="sales"
+        use_case="sales",
     ),
 }
 
@@ -211,7 +215,7 @@ DEFAULT_TURN_DISTRIBUTION = {
 TURN_DISTRIBUTIONS = {
     "default": DEFAULT_TURN_DISTRIBUTION,
     "short": {1: 0.6, 2: 0.3, 3: 0.1},  # Mostly short conversations
-    "long": {2: 0.2, 3: 0.4, 4: 0.4},   # Longer conversations
+    "long": {2: 0.2, 3: 0.4, 4: 0.4},  # Longer conversations
     "balanced": {1: 0.25, 2: 0.25, 3: 0.25, 4: 0.25},  # Equal distribution
 }
 
@@ -224,10 +228,10 @@ USER_EMOTIONS = ["professional", "happy", "frustrated", "impatient", "confused"]
 
 USER_EMOTION_DISTRIBUTION = {
     "professional": 0.50,  # Most common - neutral, business-like
-    "happy": 0.15,         # Positive, enthusiastic
-    "frustrated": 0.15,    # Having issues, needs help
-    "impatient": 0.10,     # Wants quick answers
-    "confused": 0.10,      # Unclear about something
+    "happy": 0.15,  # Positive, enthusiastic
+    "frustrated": 0.15,  # Having issues, needs help
+    "impatient": 0.10,  # Wants quick answers
+    "confused": 0.10,  # Unclear about something
 }
 
 EMOTION_DESCRIPTIONS = {
@@ -235,7 +239,7 @@ EMOTION_DESCRIPTIONS = {
     "happy": "Positive, enthusiastic. May express excitement about features or capabilities.",
     "frustrated": "Experiencing issues or challenges. May express mild annoyance or urgency.",
     "impatient": "Wants quick, direct answers. Brief messages, may skip pleasantries.",
-    "confused": "Unclear about concepts or features. May ask for clarification or examples."
+    "confused": "Unclear about concepts or features. May ask for clarification or examples.",
 }
 
 
@@ -246,15 +250,15 @@ EMOTION_DESCRIPTIONS = {
 INPUT_MODALITIES = ["standard", "typed_on_phone", "voice_dictated"]
 
 INPUT_MODALITY_DISTRIBUTION = {
-    "standard": 0.70,        # Normal typing on computer
-    "typed_on_phone": 0.20,  # Mobile typing - autocorrect errors, brevity
-    "voice_dictated": 0.10,  # Voice-to-text - filler words, natural speech
+    "standard": 0.70,  # Normal typing on computer
+    "typed_on_phone": 0.20,  # Mobile typing - autocorrect errors, brevity
+    "voice_dictated": 0.10,  # Voice-to-text - filler words, natural speech
 }
 
 MODALITY_DESCRIPTIONS = {
     "standard": "Standard computer typing. Clean text, proper formatting.",
     "typed_on_phone": "Mobile device typing. May have autocorrect errors, abbreviations, shorter messages.",
-    "voice_dictated": "Voice-to-text transcription. May include 'um', 'uh', natural speech patterns, occasional transcription errors."
+    "voice_dictated": "Voice-to-text transcription. May include 'um', 'uh', natural speech patterns, occasional transcription errors.",
 }
 
 
@@ -265,15 +269,15 @@ MODALITY_DESCRIPTIONS = {
 TEXT_VARIATIONS = ["standard", "all_lowercase", "no_punctuation"]
 
 TEXT_VARIATION_DISTRIBUTION = {
-    "standard": 0.80,        # Normal capitalization and punctuation
-    "all_lowercase": 0.15,   # all lowercase (casual/mobile)
-    "no_punctuation": 0.05,  # missing punctuation (rushed/mobile)
+    "standard": 0.80,  # Normal capitalization and punctuation
+    "all_lowercase": 0.15,  # all lowercase (casual/mobile)
+    "no_punctuation": 0.05,  # missing punctuation (rushed/mobile)
 }
 
 VARIATION_DESCRIPTIONS = {
     "standard": "Standard capitalization and punctuation.",
     "all_lowercase": "All lowercase letters (common in casual or mobile communication).",
-    "no_punctuation": "Missing or minimal punctuation (rushed typing or informal style)."
+    "no_punctuation": "Missing or minimal punctuation (rushed typing or informal style).",
 }
 
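Each *_DISTRIBUTION constant above maps an option to its sampling probability (the dicts shown sum to 1.0), so drawing one option is a weighted sample. A minimal sketch of how such a dict can be consumed with the standard library (the helper name is illustrative, not part of the module):

    import random

    def sample_option(distribution):
        # random.choices takes parallel sequences of options and weights
        options, weights = zip(*distribution.items())
        return random.choices(options, weights=weights, k=1)[0]

    emotion = sample_option(
        {"professional": 0.50, "happy": 0.15, "frustrated": 0.15, "impatient": 0.10, "confused": 0.10}
    )
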
@@ -293,9 +297,11 @@ QUALITY_WEIGHTS = {
 # API Configuration
 # ============================================================================
 
+
 @dataclass
 class APIConfig:
     """Configuration for API calls."""
+
     model: str = "gemini-2.5-flash-lite"
     max_concurrent: int = 10
     temperature: float = 0.9  # Higher for generation, lower for judging
@@ -337,9 +343,11 @@ STAGE_CONFIGS = {
 # File Paths
 # ============================================================================
 
+
 @dataclass
 class FilePaths:
     """Standard file paths for the pipeline."""
+
     data_dir: str = "data"
     prompts_dir: str = "prompts"
     output_dir: str = "output"
@@ -374,6 +382,7 @@ PATHS = FilePaths()
 # Pipeline Parameters
 # ============================================================================
 
+
 @dataclass
 class PipelineParams:
     """Parameters for the full pipeline run."""
@@ -389,10 +398,18 @@ class PipelineParams:
     # Stage 3: Conversation Generation
     num_conversations: int = 2000
     conversations_per_qa: int = 10
-    turn_distribution: Dict[int, float] = field(default_factory=lambda: DEFAULT_TURN_DISTRIBUTION)
-    emotion_distribution: Dict[str, float] = field(default_factory=lambda: USER_EMOTION_DISTRIBUTION)
-    modality_distribution: Dict[str, float] = field(default_factory=lambda: INPUT_MODALITY_DISTRIBUTION)
-    variation_distribution: Dict[str, float] = field(default_factory=lambda: TEXT_VARIATION_DISTRIBUTION)
+    turn_distribution: Dict[int, float] = field(
+        default_factory=lambda: DEFAULT_TURN_DISTRIBUTION
+    )
+    emotion_distribution: Dict[str, float] = field(
+        default_factory=lambda: USER_EMOTION_DISTRIBUTION
+    )
+    modality_distribution: Dict[str, float] = field(
+        default_factory=lambda: INPUT_MODALITY_DISTRIBUTION
+    )
+    variation_distribution: Dict[str, float] = field(
+        default_factory=lambda: TEXT_VARIATION_DISTRIBUTION
+    )
 
     # Stage 4: Judging
     min_quality_score: float = 5.0  # Minimum acceptable score
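The field(default_factory=...) pattern that gets line-wrapped above exists because dataclasses reject mutable defaults such as dicts; the lambda defers the lookup to instantiation time. A stripped-down illustration (names are hypothetical):

    from dataclasses import dataclass, field
    from typing import Dict

    SHARED_DEFAULT = {1: 0.5, 2: 0.5}

    @dataclass
    class Params:
        # dist: Dict[int, float] = SHARED_DEFAULT   # ValueError: mutable default
        dist: Dict[int, float] = field(default_factory=lambda: SHARED_DEFAULT)

    p = Params()
    assert p.dist == SHARED_DEFAULT  # note: the same object, not a copy
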
@@ -20,7 +20,7 @@ async def batch_embed(
     model: str = "text-embedding-3-small",
     dimensions: int = 1024,
     batch_size: int = 100,
-    max_concurrent: int = 20
+    max_concurrent: int = 20,
 ) -> List[np.ndarray]:
     """
     Generate embeddings for a list of texts using OpenAI API.
@@ -39,7 +39,7 @@ async def batch_embed(
     logfire.info(f"Embedding {len(texts)} texts in batches of {batch_size}")
 
     # Split into batches
-    batches = [texts[i:i + batch_size] for i in range(0, len(texts), batch_size)]
+    batches = [texts[i : i + batch_size] for i in range(0, len(texts), batch_size)]
 
     # Semaphore for rate limiting
     semaphore = asyncio.Semaphore(max_concurrent)
@@ -49,9 +49,7 @@ async def batch_embed(
         async with semaphore:
             try:
                 response = await client.embeddings.create(
-                    model=model,
-                    input=batch,
-                    dimensions=dimensions
+                    model=model, input=batch, dimensions=dimensions
                 )
                 return [item.embedding for item in response.data]
             except Exception as e:
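batch_embed combines two patterns visible above: slice the inputs into fixed-size batches, then gate each request behind an asyncio.Semaphore so at most max_concurrent calls are in flight. A self-contained sketch with a stubbed embedding call (all names here are illustrative, not the module's API):

    import asyncio

    async def fake_embed(batch):
        await asyncio.sleep(0.01)  # stand-in for the real API call
        return [[0.0] * 4 for _ in batch]

    async def fan_out(texts, batch_size=100, max_concurrent=20):
        batches = [texts[i : i + batch_size] for i in range(0, len(texts), batch_size)]
        semaphore = asyncio.Semaphore(max_concurrent)

        async def worker(batch):
            async with semaphore:  # waits while max_concurrent workers hold a slot
                return await fake_embed(batch)

        # gather returns results in input order regardless of completion order
        per_batch = await asyncio.gather(*(worker(b) for b in batches))
        return [emb for batch in per_batch for emb in batch]

    embeddings = asyncio.run(fan_out([f"text {i}" for i in range(250)]))
    assert len(embeddings) == 250
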
@@ -112,7 +110,7 @@ def compute_similarity(emb1: np.ndarray, emb2: np.ndarray) -> float:
 def greedy_deduplicate(
     embeddings: List[np.ndarray],
     scores: List[float],
-    similarity_threshold: float = 0.95
+    similarity_threshold: float = 0.95,
 ) -> List[int]:
     """
     Greedy deduplication: keep highest-scoring items, remove similar duplicates.
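The greedy strategy named in the docstring is: visit candidates from highest score down, and keep one only if it is not too similar to anything already kept. A compact sketch of that logic (a plain dot product stands in for compute_similarity and assumes normalized embeddings; this is not necessarily the module's exact implementation):

    import numpy as np
    from typing import List

    def greedy_dedup_sketch(
        embeddings: List[np.ndarray],
        scores: List[float],
        similarity_threshold: float = 0.95,
    ) -> List[int]:
        order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
        kept: List[int] = []
        for i in order:
            # keep i only if no already-kept item is a near-duplicate
            if all(
                float(np.dot(embeddings[i], embeddings[j])) < similarity_threshold
                for j in kept
            ):
                kept.append(i)
        return kept
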
@@ -170,8 +168,8 @@ def conversation_to_text(messages: List[dict], max_chars: int = 24000) -> str:
     """
     parts = []
     for msg in messages:
-        role = msg.get('role', 'unknown').upper()
-        content = msg.get('content', '')
+        role = msg.get("role", "unknown").upper()
+        content = msg.get("content", "")
         parts.append(f"{role}: {content}")
 
     full_text = "\n\n".join(parts)

@@ -17,26 +17,18 @@ class QAPair(BaseModel):
     question: str = Field(
         description="A natural question that could be asked about this topic"
     )
-    answer: str = Field(
-        description="The accurate answer grounded in the source text"
-    )
+    answer: str = Field(description="The accurate answer grounded in the source text")
     source_text: str = Field(
         description="The specific text chunk this Q&A was generated from"
     )
-    context_before: str = Field(
-        default="",
-        description="Preceding lines for context"
-    )
-    context_after: str = Field(
-        default="",
-        description="Following lines for context"
-    )
+    context_before: str = Field(default="", description="Preceding lines for context")
+    context_after: str = Field(default="", description="Following lines for context")
     difficulty: Literal["basic", "intermediate", "advanced"] = Field(
         description="The difficulty level of this question"
     )
     categories: list[str] = Field(
         default_factory=list,
-        description="Topic categories (e.g., 'pricing', 'features', 'integrations')"
+        description="Topic categories (e.g., 'pricing', 'features', 'integrations')",
     )
 
 
@@ -65,23 +57,15 @@ class QAValidation(BaseModel):
     sensible_answer: bool = Field(
         description="Is the answer appropriate and sensible for the question?"
     )
-    passed: bool = Field(
-        description="Overall pass (all three bools must be True)"
-    )
-    feedback: str = Field(
-        description="Brief explanation of validation result"
-    )
+    passed: bool = Field(description="Overall pass (all three bools must be True)")
+    feedback: str = Field(description="Brief explanation of validation result")
 
 
 class ValidatedQAPair(BaseModel):
     """A Q&A pair with its validation result."""
 
-    qa_pair: QAPair = Field(
-        description="The Q&A pair being validated"
-    )
-    validation: QAValidation = Field(
-        description="The validation result"
-    )
+    qa_pair: QAPair = Field(description="The Q&A pair being validated")
+    validation: QAValidation = Field(description="The validation result")
 
 
 # ============================================================================
@@ -95,9 +79,7 @@ class Message(BaseModel):
     role: Literal["system", "user", "assistant"] = Field(
         description="The role of the message sender"
     )
-    content: str = Field(
-        description="The message content"
-    )
+    content: str = Field(description="The message content")
 
 
 class ConversationMetadata(BaseModel):
@@ -112,28 +94,24 @@ class ConversationMetadata(BaseModel):
     user_persona: str = Field(
         description="The persona/role of the user (e.g., 'developer', 'business owner')"
     )
-    user_emotion: Literal["professional", "happy", "frustrated", "impatient", "confused"] = Field(
-        default="professional",
-        description="The emotional state of the user"
-    )
+    user_emotion: Literal[
+        "professional", "happy", "frustrated", "impatient", "confused"
+    ] = Field(default="professional", description="The emotional state of the user")
     input_modality: Literal["standard", "typed_on_phone", "voice_dictated"] = Field(
-        default="standard",
-        description="How the user is inputting their messages"
+        default="standard", description="How the user is inputting their messages"
     )
     text_variation: Literal["standard", "all_lowercase", "no_punctuation"] = Field(
         default="standard",
-        description="Text formatting variation applied to user messages"
+        description="Text formatting variation applied to user messages",
     )
     source_qa_ids: list[int] = Field(
         default_factory=list,
-        description="Indices of Q&A pairs used to generate this conversation"
+        description="Indices of Q&A pairs used to generate this conversation",
     )
-    difficulty: str = Field(
-        description="Overall difficulty level"
-    )
+    difficulty: str = Field(description="Overall difficulty level")
     categories: list[str] = Field(
         default_factory=list,
-        description="Topic categories covered in this conversation"
+        description="Topic categories covered in this conversation",
     )
 
 
@@ -148,7 +126,7 @@ class Conversation(BaseModel):
     )
     source_qa_pairs: list[QAPair] = Field(
         default_factory=list,
-        description="The Q&A pairs used to generate this conversation (for fact-checking)"
+        description="The Q&A pairs used to generate this conversation (for fact-checking)",
     )
 
 
@@ -175,24 +153,17 @@ class JudgmentScore(BaseModel):
     overall_pass: bool = Field(
         description="TRUE only if ALL four criteria above are TRUE"
     )
-    feedback: str = Field(
-        description="Brief explanation of judgment (1-2 sentences)"
-    )
+    feedback: str = Field(description="Brief explanation of judgment (1-2 sentences)")
     issues: list[str] = Field(
-        default_factory=list,
-        description="Specific problems found (if any)"
+        default_factory=list, description="Specific problems found (if any)"
     )
 
 
 class JudgedConversation(BaseModel):
     """A conversation with its quality judgment."""
 
-    conversation: Conversation = Field(
-        description="The conversation being judged"
-    )
-    judgment: JudgmentScore = Field(
-        description="The quality judgment scores"
-    )
+    conversation: Conversation = Field(description="The conversation being judged")
+    judgment: JudgmentScore = Field(description="The quality judgment scores")
 
 
 # ============================================================================
@@ -203,18 +174,12 @@ class JudgedConversation(BaseModel):
 class EmbeddedConversation(BaseModel):
     """A judged conversation with its embedding."""
 
-    conversation: Conversation = Field(
-        description="The conversation"
-    )
-    judgment: JudgmentScore = Field(
-        description="The quality judgment"
-    )
+    conversation: Conversation = Field(description="The conversation")
+    judgment: JudgmentScore = Field(description="The quality judgment")
     embedding: list[float] = Field(
         description="Conversation embedding (1024 dimensions)"
     )
-    text_preview: str = Field(
-        description="First 200 characters for debugging"
-    )
+    text_preview: str = Field(description="First 200 characters for debugging")
 
 
 # ============================================================================
@@ -225,12 +190,8 @@ class EmbeddedConversation(BaseModel):
 class UniqueConversation(BaseModel):
     """A conversation marked as unique after deduplication."""
 
-    conversation: Conversation = Field(
-        description="The conversation"
-    )
-    judgment: JudgmentScore = Field(
-        description="The quality judgment"
-    )
+    conversation: Conversation = Field(description="The conversation")
+    judgment: JudgmentScore = Field(description="The quality judgment")
     # Note: embedding removed to save space after dedup
 
@@ -11,11 +11,8 @@ from .config import (
     SYSTEM_PROMPT_TEMPLATES,
     CONVERSATION_STYLES,
     DEFAULT_TURN_DISTRIBUTION,
-    USER_EMOTIONS,
     USER_EMOTION_DISTRIBUTION,
-    INPUT_MODALITIES,
     INPUT_MODALITY_DISTRIBUTION,
-    TEXT_VARIATIONS,
     TEXT_VARIATION_DISTRIBUTION,
     Persona,
     SystemPromptTemplate,
@@ -137,10 +134,7 @@ def sample_system_prompt_by_use_case(use_case: str) -> SystemPromptTemplate:
     Returns:
         A SystemPromptTemplate with matching use case
     """
-    matching = [
-        s for s in SYSTEM_PROMPT_TEMPLATES.values()
-        if s.use_case == use_case
-    ]
+    matching = [s for s in SYSTEM_PROMPT_TEMPLATES.values() if s.use_case == use_case]
     return random.choice(matching) if matching else sample_system_prompt()
 
 
@@ -172,7 +166,9 @@ def sample_balanced_config(
 
     # Sample persona
     if prefer_technical:
-        persona = PERSONAS.get("developer") if random.random() < 0.5 else sample_persona()
+        persona = (
+            PERSONAS.get("developer") if random.random() < 0.5 else sample_persona()
+        )
     else:
         persona = sample_persona()
 
@@ -193,7 +189,9 @@ def sample_balanced_config(
     }
 
 
-def load_system_prompts_from_files(prompts_dir: str = "src/synth_data_pipeline/prompts/system_prompts") -> Dict[str, str]:
+def load_system_prompts_from_files(
+    prompts_dir: str = "src/synth_data_pipeline/prompts/system_prompts",
+) -> Dict[str, str]:
     """
     Load system prompt templates from text files.
 
@@ -208,13 +206,15 @@ def load_system_prompts_from_files(prompts_dir: str = "src/synth_data_pipeline/p
 
     for prompt_file in prompts_path.glob("*.txt"):
         name = prompt_file.stem
-        with open(prompt_file, 'r', encoding='utf-8') as f:
+        with open(prompt_file, "r", encoding="utf-8") as f:
            system_prompts[name] = f.read().strip()
 
     return system_prompts
 
 
-def load_personas_from_files(personas_dir: str = "src/synth_data_pipeline/prompts/personas") -> Dict[str, str]:
+def load_personas_from_files(
+    personas_dir: str = "src/synth_data_pipeline/prompts/personas",
+) -> Dict[str, str]:
     """
     Load persona descriptions from text files.
 
@@ -229,7 +229,7 @@ def load_personas_from_files(personas_dir: str = "src/synth_data_pipeline/prompt
 
     for persona_file in personas_path.glob("*.txt"):
         name = persona_file.stem
-        with open(persona_file, 'r', encoding='utf-8') as f:
+        with open(persona_file, "r", encoding="utf-8") as f:
             personas[name] = f.read().strip()
 
     return personas
@@ -272,10 +272,8 @@ def sample_multiple_configs(
 # Sampling Strategies
 # ============================================================================
 
-def stratified_sample_configs(
-    n: int,
-    ensure_coverage: bool = True
-) -> List[Dict]:
+
+def stratified_sample_configs(n: int, ensure_coverage: bool = True) -> List[Dict]:
     """
     Sample configurations with stratified sampling to ensure diversity.
 
@@ -292,15 +290,17 @@ def stratified_sample_configs(
     # First, ensure we have at least one of each persona-style combination
     for persona in PERSONAS.values():
         for style in CONVERSATION_STYLES:
-            configs.append({
-                "num_turns": sample_num_turns(),
-                "style": style,
-                "persona": persona,
-                "system_prompt": sample_system_prompt(),
-                "user_emotion": sample_emotion(),
-                "input_modality": sample_input_modality(),
-                "text_variation": sample_text_variation(),
-            })
+            configs.append(
+                {
+                    "num_turns": sample_num_turns(),
+                    "style": style,
+                    "persona": persona,
+                    "system_prompt": sample_system_prompt(),
+                    "user_emotion": sample_emotion(),
+                    "input_modality": sample_input_modality(),
+                    "text_variation": sample_text_variation(),
+                }
+            )
 
     # Fill remaining with random samples
     remaining = n - len(configs)
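As the nested loop above shows, stratified_sample_configs seeds one config per persona-style cell before topping up with random samples, so the grid is fully covered. A usage sketch (counts assume the six personas and three conversation styles defined in config.py):

    configs = stratified_sample_configs(100)
    # 6 personas x 3 styles = 18 guaranteed combinations; the rest are random fill
    assert len(configs) >= 18
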
@@ -10,15 +10,15 @@ from typing import List, TypeVar, Callable, Awaitable
 import logfire
 from tqdm.asyncio import tqdm_asyncio
 
-T = TypeVar('T')
-R = TypeVar('R')
+T = TypeVar("T")
+R = TypeVar("R")
 
 
 async def process_with_concurrency(
     items: List[T],
     process_fn: Callable[[T], Awaitable[R]],
     max_concurrent: int = 10,
-    desc: str = "Processing"
+    desc: str = "Processing",
 ) -> List[R]:
     """
     Process items concurrently with a semaphore to limit concurrency.
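Per its signature and docstring, process_with_concurrency applies an async function to every item while a semaphore caps parallelism. A hypothetical usage sketch (assuming results come back in input order, as tqdm_asyncio.gather provides):

    import asyncio

    from src.synth_data_pipeline import process_with_concurrency

    async def double(x: int) -> int:
        await asyncio.sleep(0.01)  # stand-in for a model/API call
        return 2 * x

    async def main():
        results = await process_with_concurrency(
            list(range(20)), double, max_concurrent=5, desc="Doubling"
        )
        print(results[:3])  # [0, 2, 4]

    asyncio.run(main())
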
@@ -61,12 +61,12 @@ def save_jsonl(items: List, output_path: str | Path):
     output_path = Path(output_path)
     output_path.parent.mkdir(parents=True, exist_ok=True)
 
-    with open(output_path, 'w', encoding='utf-8') as f:
+    with open(output_path, "w", encoding="utf-8") as f:
         for item in items:
-            if hasattr(item, 'model_dump_json'):
-                f.write(item.model_dump_json() + '\n')
+            if hasattr(item, "model_dump_json"):
+                f.write(item.model_dump_json() + "\n")
             else:
-                f.write(json.dumps(item) + '\n')
+                f.write(json.dumps(item) + "\n")
 
     logfire.info(f"Saved {len(items)} items to {output_path}")
 
@@ -83,7 +83,7 @@ def load_jsonl(file_path: str | Path, model_class=None) -> List:
         List of items
     """
     items = []
-    with open(file_path, 'r', encoding='utf-8') as f:
+    with open(file_path, "r", encoding="utf-8") as f:
         for line in f:
             if model_class:
                 items.append(model_class.model_validate_json(line))
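save_jsonl and load_jsonl are symmetric: one JSON document per line, with pydantic models serialized via model_dump_json on the way out and re-validated via model_validate_json when model_class is given. A round-trip sketch (the path and field values are hypothetical):

    from src.synth_data_pipeline import QAPair, load_jsonl, save_jsonl

    pairs = [QAPair(question="Q?", answer="A.", source_text="chunk", difficulty="basic")]
    save_jsonl(pairs, "output/qa_pairs.jsonl")

    loaded = load_jsonl("output/qa_pairs.jsonl", model_class=QAPair)
    assert loaded[0].question == "Q?"
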
@@ -92,10 +92,7 @@ def load_jsonl(file_path: str | Path, model_class=None) -> List:
     return items
 
 
-def parse_markdown_chunks(
-    file_path: str | Path,
-    context_lines: int = 3
-) -> List[dict]:
+def parse_markdown_chunks(file_path: str | Path, context_lines: int = 3) -> List[dict]:
     """
     Parse markdown file and create chunks with context.
 
@@ -106,7 +103,7 @@ def parse_markdown_chunks(
     Returns:
         List of dicts with 'source_text', 'context_before', 'context_after'
     """
-    with open(file_path, 'r', encoding='utf-8') as f:
+    with open(file_path, "r", encoding="utf-8") as f:
         lines = f.readlines()
 
     chunks = []
@@ -116,28 +113,30 @@ def parse_markdown_chunks(
         line = lines[i].strip()
 
         # Skip empty lines and metadata
-        if not line or line.startswith('---') or line.startswith('**As of:**'):
+        if not line or line.startswith("---") or line.startswith("**As of:**"):
             i += 1
             continue
 
         # Process bullets and significant lines
-        if line.startswith('*') or line.startswith('#') or len(line) > 50:
+        if line.startswith("*") or line.startswith("#") or len(line) > 50:
             # Get context before
             context_start = max(0, i - context_lines)
-            context_before = ''.join(lines[context_start:i]).strip()
+            context_before = "".join(lines[context_start:i]).strip()
 
             # Get the main text (current line)
             source_text = line
 
             # Get context after
             context_end = min(len(lines), i + context_lines + 1)
-            context_after = ''.join(lines[i + 1:context_end]).strip()
+            context_after = "".join(lines[i + 1 : context_end]).strip()
 
-            chunks.append({
-                'source_text': source_text,
-                'context_before': context_before,
-                'context_after': context_after,
-            })
+            chunks.append(
+                {
+                    "source_text": source_text,
+                    "context_before": context_before,
+                    "context_after": context_after,
+                }
+            )
 
         i += 1
 
@@ -149,7 +148,7 @@ def calculate_overall_score(
     naturalness: float,
     relevance: float,
     diversity: float,
-    weights: dict = None
+    weights: dict = None,
 ) -> float:
     """
     Calculate overall quality score from individual metrics.
@@ -166,13 +165,14 @@ def calculate_overall_score(
     """
     if weights is None:
         from .config import QUALITY_WEIGHTS
+
         weights = QUALITY_WEIGHTS
 
     overall = (
-        factual_accuracy * weights.get("factual_accuracy", 0.35) +
-        naturalness * weights.get("naturalness", 0.25) +
-        relevance * weights.get("relevance", 0.25) +
-        diversity * weights.get("diversity", 0.15)
+        factual_accuracy * weights.get("factual_accuracy", 0.35)
+        + naturalness * weights.get("naturalness", 0.25)
+        + relevance * weights.get("relevance", 0.25)
+        + diversity * weights.get("diversity", 0.15)
     )
 
     return round(overall, 2)
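Since the .get fallbacks sum to 1.0 (0.35 + 0.25 + 0.25 + 0.15), the overall score is a convex combination of the four metrics. A worked example with illustrative metric values:

    from src.synth_data_pipeline import calculate_overall_score

    # 8 * 0.35 + 7 * 0.25 + 9 * 0.25 + 6 * 0.15 = 2.8 + 1.75 + 2.25 + 0.9 = 7.7
    assert calculate_overall_score(8, 7, 9, 6) == 7.7
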
@@ -186,11 +186,11 @@ def print_sample(item, title: str = "SAMPLE"):
         item: Item to print (conversation, Q&A, etc.)
         title: Title for the sample section
     """
-    print("\n" + "="*80)
+    print("\n" + "=" * 80)
     print(title)
-    print("="*80)
+    print("=" * 80)
 
-    if hasattr(item, 'model_dump'):
+    if hasattr(item, "model_dump"):
         # Pydantic model
         print(json.dumps(item.model_dump(), indent=2))
     elif isinstance(item, dict):
@@ -198,7 +198,7 @@ def print_sample(item, title: str = "SAMPLE"):
     else:
         print(item)
 
-    print("="*80 + "\n")
+    print("=" * 80 + "\n")
 
 
 def print_statistics(scores: List[float], metric_name: str = "Score"):