formatting

svilupp 2025-10-26 19:28:04 +00:00
parent 6bfc1f8f53
commit 7eac69487b
7 changed files with 229 additions and 169 deletions

View File

@@ -2,7 +2,89 @@
import textprompts
# Core models
from .models import (
QAPair,
QAPairBatch,
QAValidation,
ValidatedQAPair,
Message,
Conversation,
ConversationMetadata,
JudgmentScore,
JudgedConversation,
EmbeddedConversation,
UniqueConversation,
NanoChatMessage,
NanoChatConversation,
)
# Configuration
from .config import (
APIConfig,
FilePaths,
PipelineParams,
Persona,
SystemPromptTemplate,
PATHS,
FULL_PARAMS,
STAGE_CONFIGS,
)
# Utilities
from .utils import (
load_jsonl,
save_jsonl,
parse_markdown_chunks,
process_with_concurrency,
calculate_overall_score,
print_sample,
print_statistics,
)
# Agents (as submodule)
from . import agents
__version__ = "0.1.0"
# Set strict metadata requirement for all prompts globally
textprompts.set_metadata('strict')
textprompts.set_metadata("strict")
__all__ = [
# Version
"__version__",
# Models
"QAPair",
"QAPairBatch",
"QAValidation",
"ValidatedQAPair",
"Message",
"Conversation",
"ConversationMetadata",
"JudgmentScore",
"JudgedConversation",
"EmbeddedConversation",
"UniqueConversation",
"NanoChatMessage",
"NanoChatConversation",
# Config classes
"APIConfig",
"FilePaths",
"PipelineParams",
"Persona",
"SystemPromptTemplate",
# Config instances
"PATHS",
"FULL_PARAMS",
"STAGE_CONFIGS",
# Utils
"load_jsonl",
"save_jsonl",
"parse_markdown_chunks",
"process_with_concurrency",
"calculate_overall_score",
"print_sample",
"print_statistics",
# Submodules
"agents",
]

View File

@@ -5,7 +5,9 @@ from src.synth_data_pipeline.config import APIConfig
from src.synth_data_pipeline.models import JudgmentScore
PROMPT_NAME = "conversation_judge"
SYSTEM_PROMPT = "You are an expert evaluator of training data quality for language models."
SYSTEM_PROMPT = (
"You are an expert evaluator of training data quality for language models."
)
def build_agent(config: APIConfig, *, api_key: str | None = None):

View File

@@ -18,7 +18,7 @@ DIFFICULTY_LEVELS = ["basic", "intermediate", "advanced"]
DIFFICULTY_DESCRIPTIONS = {
"basic": "Simple factual questions requiring basic recall",
"intermediate": "Questions requiring understanding and reasoning",
"advanced": "Complex technical or multi-faceted questions"
"advanced": "Complex technical or multi-faceted questions",
}
@@ -31,7 +31,7 @@ CONVERSATION_STYLES = ["formal", "casual", "technical"]
STYLE_DESCRIPTIONS = {
"formal": "Professional language, complete sentences, no slang",
"casual": "Friendly tone, can use contractions, conversational",
"technical": "Uses technical terminology, assumes some expertise"
"technical": "Uses technical terminology, assumes some expertise",
}
@@ -39,9 +39,11 @@ STYLE_DESCRIPTIONS = {
# User Personas
# ============================================================================
@dataclass
class Persona:
"""Definition of a user persona."""
name: str
description: str
typical_questions: List[str] = field(default_factory=list)
@@ -56,9 +58,9 @@ PERSONAS = {
"API integration details",
"Technical specifications",
"SDK usage",
"Error handling"
"Error handling",
],
formality="technical"
formality="technical",
),
"product_manager": Persona(
name="product_manager",
@@ -67,9 +69,9 @@ PERSONAS = {
"Feature comparisons",
"Roadmap questions",
"Use cases",
"ROI analysis"
"ROI analysis",
],
formality="formal"
formality="formal",
),
"cs_agent": Persona(
name="cs_agent",
@@ -78,9 +80,9 @@ PERSONAS = {
"Setup instructions",
"Troubleshooting",
"Configuration options",
"Best practices"
"Best practices",
],
formality="neutral"
formality="neutral",
),
"executive": Persona(
name="executive",
@@ -89,9 +91,9 @@ PERSONAS = {
"Business value",
"Competitive advantages",
"Pricing strategy",
"Scalability"
"Scalability",
],
formality="formal"
formality="formal",
),
"operations": Persona(
name="operations",
@@ -100,9 +102,9 @@ PERSONAS = {
"Integration capabilities",
"Workflow automation",
"Performance metrics",
"SLA guarantees"
"SLA guarantees",
],
formality="neutral"
formality="neutral",
),
"finance": Persona(
name="finance",
@@ -111,9 +113,9 @@ PERSONAS = {
"Tax compliance",
"Financial reporting",
"Audit trails",
"Cost structure"
"Cost structure",
],
formality="formal"
formality="formal",
),
}
@@ -122,9 +124,11 @@ PERSONAS = {
# System Prompt Templates
# ============================================================================
@dataclass
class SystemPromptTemplate:
"""Definition of a system prompt template."""
name: str
description: str
template: str
@@ -138,35 +142,35 @@ SYSTEM_PROMPT_TEMPLATES = {
description="Helpful and friendly assistant",
template="You are a helpful AI assistant with expertise in SWAP Commerce's e-commerce platform and services. You provide accurate, friendly, and detailed answers to questions about SWAP Commerce's products, features, integrations, and pricing.",
verbosity="detailed",
use_case="general"
use_case="general",
),
"concise": SystemPromptTemplate(
name="concise",
description="Brief and to-the-point responses",
template="You are a SWAP Commerce expert providing clear, concise answers. Focus on key information without unnecessary detail.",
verbosity="concise",
use_case="quick_reference"
use_case="quick_reference",
),
"technical": SystemPromptTemplate(
name="technical",
description="Technical expert for developers",
template="You are a technical expert on SWAP Commerce's platform. You provide detailed technical information about APIs, integrations, implementation, and system architecture. You assume the user has technical knowledge.",
verbosity="detailed",
use_case="developer"
use_case="developer",
),
"detailed": SystemPromptTemplate(
name="detailed",
description="Comprehensive explanations",
template="You are a comprehensive SWAP Commerce expert who provides thorough, well-explained answers with examples, context, and relevant details. You ensure users fully understand the topic.",
verbosity="detailed",
use_case="learning"
use_case="learning",
),
"sales": SystemPromptTemplate(
name="sales",
description="Sales and solutions focused",
template="You are a SWAP Commerce solutions consultant helping potential customers understand how SWAP Commerce can solve their e-commerce challenges. You're knowledgeable about features, benefits, and competitive advantages.",
verbosity="balanced",
use_case="sales"
use_case="sales",
),
}
@@ -211,7 +215,7 @@ DEFAULT_TURN_DISTRIBUTION = {
TURN_DISTRIBUTIONS = {
"default": DEFAULT_TURN_DISTRIBUTION,
"short": {1: 0.6, 2: 0.3, 3: 0.1}, # Mostly short conversations
"long": {2: 0.2, 3: 0.4, 4: 0.4}, # Longer conversations
"long": {2: 0.2, 3: 0.4, 4: 0.4}, # Longer conversations
"balanced": {1: 0.25, 2: 0.25, 3: 0.25, 4: 0.25}, # Equal distribution
}
@@ -224,10 +228,10 @@ USER_EMOTIONS = ["professional", "happy", "frustrated", "impatient", "confused"]
USER_EMOTION_DISTRIBUTION = {
"professional": 0.50, # Most common - neutral, business-like
"happy": 0.15, # Positive, enthusiastic
"frustrated": 0.15, # Having issues, needs help
"impatient": 0.10, # Wants quick answers
"confused": 0.10, # Unclear about something
"happy": 0.15, # Positive, enthusiastic
"frustrated": 0.15, # Having issues, needs help
"impatient": 0.10, # Wants quick answers
"confused": 0.10, # Unclear about something
}
EMOTION_DESCRIPTIONS = {
@@ -235,7 +239,7 @@ EMOTION_DESCRIPTIONS = {
"happy": "Positive, enthusiastic. May express excitement about features or capabilities.",
"frustrated": "Experiencing issues or challenges. May express mild annoyance or urgency.",
"impatient": "Wants quick, direct answers. Brief messages, may skip pleasantries.",
"confused": "Unclear about concepts or features. May ask for clarification or examples."
"confused": "Unclear about concepts or features. May ask for clarification or examples.",
}
@@ -246,15 +250,15 @@ EMOTION_DESCRIPTIONS = {
INPUT_MODALITIES = ["standard", "typed_on_phone", "voice_dictated"]
INPUT_MODALITY_DISTRIBUTION = {
"standard": 0.70, # Normal typing on computer
"typed_on_phone": 0.20, # Mobile typing - autocorrect errors, brevity
"voice_dictated": 0.10, # Voice-to-text - filler words, natural speech
"standard": 0.70, # Normal typing on computer
"typed_on_phone": 0.20, # Mobile typing - autocorrect errors, brevity
"voice_dictated": 0.10, # Voice-to-text - filler words, natural speech
}
MODALITY_DESCRIPTIONS = {
"standard": "Standard computer typing. Clean text, proper formatting.",
"typed_on_phone": "Mobile device typing. May have autocorrect errors, abbreviations, shorter messages.",
"voice_dictated": "Voice-to-text transcription. May include 'um', 'uh', natural speech patterns, occasional transcription errors."
"voice_dictated": "Voice-to-text transcription. May include 'um', 'uh', natural speech patterns, occasional transcription errors.",
}
@@ -265,15 +269,15 @@ MODALITY_DESCRIPTIONS = {
TEXT_VARIATIONS = ["standard", "all_lowercase", "no_punctuation"]
TEXT_VARIATION_DISTRIBUTION = {
"standard": 0.80, # Normal capitalization and punctuation
"all_lowercase": 0.15, # all lowercase (casual/mobile)
"no_punctuation": 0.05, # missing punctuation (rushed/mobile)
"standard": 0.80, # Normal capitalization and punctuation
"all_lowercase": 0.15, # all lowercase (casual/mobile)
"no_punctuation": 0.05, # missing punctuation (rushed/mobile)
}
VARIATION_DESCRIPTIONS = {
"standard": "Standard capitalization and punctuation.",
"all_lowercase": "All lowercase letters (common in casual or mobile communication).",
"no_punctuation": "Missing or minimal punctuation (rushed typing or informal style)."
"no_punctuation": "Missing or minimal punctuation (rushed typing or informal style).",
}
@@ -293,9 +297,11 @@ QUALITY_WEIGHTS = {
# API Configuration
# ============================================================================
@dataclass
class APIConfig:
"""Configuration for API calls."""
model: str = "gemini-2.5-flash-lite"
max_concurrent: int = 10
temperature: float = 0.9 # Higher for generation, lower for judging
@@ -337,9 +343,11 @@ STAGE_CONFIGS = {
# File Paths
# ============================================================================
@dataclass
class FilePaths:
"""Standard file paths for the pipeline."""
data_dir: str = "data"
prompts_dir: str = "prompts"
output_dir: str = "output"
@@ -374,6 +382,7 @@ PATHS = FilePaths()
# Pipeline Parameters
# ============================================================================
@dataclass
class PipelineParams:
"""Parameters for the full pipeline run."""
@@ -389,10 +398,18 @@ class PipelineParams:
# Stage 3: Conversation Generation
num_conversations: int = 2000
conversations_per_qa: int = 10
turn_distribution: Dict[int, float] = field(default_factory=lambda: DEFAULT_TURN_DISTRIBUTION)
emotion_distribution: Dict[str, float] = field(default_factory=lambda: USER_EMOTION_DISTRIBUTION)
modality_distribution: Dict[str, float] = field(default_factory=lambda: INPUT_MODALITY_DISTRIBUTION)
variation_distribution: Dict[str, float] = field(default_factory=lambda: TEXT_VARIATION_DISTRIBUTION)
turn_distribution: Dict[int, float] = field(
default_factory=lambda: DEFAULT_TURN_DISTRIBUTION
)
emotion_distribution: Dict[str, float] = field(
default_factory=lambda: USER_EMOTION_DISTRIBUTION
)
modality_distribution: Dict[str, float] = field(
default_factory=lambda: INPUT_MODALITY_DISTRIBUTION
)
variation_distribution: Dict[str, float] = field(
default_factory=lambda: TEXT_VARIATION_DISTRIBUTION
)
# Stage 4: Judging
min_quality_score: float = 5.0 # Minimum acceptable score

View File

@@ -20,7 +20,7 @@ async def batch_embed(
model: str = "text-embedding-3-small",
dimensions: int = 1024,
batch_size: int = 100,
max_concurrent: int = 20
max_concurrent: int = 20,
) -> List[np.ndarray]:
"""
Generate embeddings for a list of texts using OpenAI API.
@@ -39,7 +39,7 @@ async def batch_embed(
logfire.info(f"Embedding {len(texts)} texts in batches of {batch_size}")
# Split into batches
batches = [texts[i:i + batch_size] for i in range(0, len(texts), batch_size)]
batches = [texts[i : i + batch_size] for i in range(0, len(texts), batch_size)]
# Semaphore for rate limiting
semaphore = asyncio.Semaphore(max_concurrent)
@@ -49,9 +49,7 @@
async with semaphore:
try:
response = await client.embeddings.create(
model=model,
input=batch,
dimensions=dimensions
model=model, input=batch, dimensions=dimensions
)
return [item.embedding for item in response.data]
except Exception as e:
@@ -112,7 +110,7 @@ def compute_similarity(emb1: np.ndarray, emb2: np.ndarray) -> float:
def greedy_deduplicate(
embeddings: List[np.ndarray],
scores: List[float],
similarity_threshold: float = 0.95
similarity_threshold: float = 0.95,
) -> List[int]:
"""
Greedy deduplication: keep highest-scoring items, remove similar duplicates.
@@ -170,8 +168,8 @@ def conversation_to_text(messages: List[dict], max_chars: int = 24000) -> str:
"""
parts = []
for msg in messages:
role = msg.get('role', 'unknown').upper()
content = msg.get('content', '')
role = msg.get("role", "unknown").upper()
content = msg.get("content", "")
parts.append(f"{role}: {content}")
full_text = "\n\n".join(parts)

View File

@@ -17,26 +17,18 @@ class QAPair(BaseModel):
question: str = Field(
description="A natural question that could be asked about this topic"
)
answer: str = Field(
description="The accurate answer grounded in the source text"
)
answer: str = Field(description="The accurate answer grounded in the source text")
source_text: str = Field(
description="The specific text chunk this Q&A was generated from"
)
context_before: str = Field(
default="",
description="Preceding lines for context"
)
context_after: str = Field(
default="",
description="Following lines for context"
)
context_before: str = Field(default="", description="Preceding lines for context")
context_after: str = Field(default="", description="Following lines for context")
difficulty: Literal["basic", "intermediate", "advanced"] = Field(
description="The difficulty level of this question"
)
categories: list[str] = Field(
default_factory=list,
description="Topic categories (e.g., 'pricing', 'features', 'integrations')"
description="Topic categories (e.g., 'pricing', 'features', 'integrations')",
)
@@ -65,23 +57,15 @@ class QAValidation(BaseModel):
sensible_answer: bool = Field(
description="Is the answer appropriate and sensible for the question?"
)
passed: bool = Field(
description="Overall pass (all three bools must be True)"
)
feedback: str = Field(
description="Brief explanation of validation result"
)
passed: bool = Field(description="Overall pass (all three bools must be True)")
feedback: str = Field(description="Brief explanation of validation result")
class ValidatedQAPair(BaseModel):
"""A Q&A pair with its validation result."""
qa_pair: QAPair = Field(
description="The Q&A pair being validated"
)
validation: QAValidation = Field(
description="The validation result"
)
qa_pair: QAPair = Field(description="The Q&A pair being validated")
validation: QAValidation = Field(description="The validation result")
# ============================================================================
@@ -95,9 +79,7 @@ class Message(BaseModel):
role: Literal["system", "user", "assistant"] = Field(
description="The role of the message sender"
)
content: str = Field(
description="The message content"
)
content: str = Field(description="The message content")
class ConversationMetadata(BaseModel):
@@ -112,28 +94,24 @@ class ConversationMetadata(BaseModel):
user_persona: str = Field(
description="The persona/role of the user (e.g., 'developer', 'business owner')"
)
user_emotion: Literal["professional", "happy", "frustrated", "impatient", "confused"] = Field(
default="professional",
description="The emotional state of the user"
)
user_emotion: Literal[
"professional", "happy", "frustrated", "impatient", "confused"
] = Field(default="professional", description="The emotional state of the user")
input_modality: Literal["standard", "typed_on_phone", "voice_dictated"] = Field(
default="standard",
description="How the user is inputting their messages"
default="standard", description="How the user is inputting their messages"
)
text_variation: Literal["standard", "all_lowercase", "no_punctuation"] = Field(
default="standard",
description="Text formatting variation applied to user messages"
description="Text formatting variation applied to user messages",
)
source_qa_ids: list[int] = Field(
default_factory=list,
description="Indices of Q&A pairs used to generate this conversation"
)
difficulty: str = Field(
description="Overall difficulty level"
description="Indices of Q&A pairs used to generate this conversation",
)
difficulty: str = Field(description="Overall difficulty level")
categories: list[str] = Field(
default_factory=list,
description="Topic categories covered in this conversation"
description="Topic categories covered in this conversation",
)
@@ -148,7 +126,7 @@ class Conversation(BaseModel):
)
source_qa_pairs: list[QAPair] = Field(
default_factory=list,
description="The Q&A pairs used to generate this conversation (for fact-checking)"
description="The Q&A pairs used to generate this conversation (for fact-checking)",
)
@@ -175,24 +153,17 @@ class JudgmentScore(BaseModel):
overall_pass: bool = Field(
description="TRUE only if ALL four criteria above are TRUE"
)
feedback: str = Field(
description="Brief explanation of judgment (1-2 sentences)"
)
feedback: str = Field(description="Brief explanation of judgment (1-2 sentences)")
issues: list[str] = Field(
default_factory=list,
description="Specific problems found (if any)"
default_factory=list, description="Specific problems found (if any)"
)
class JudgedConversation(BaseModel):
"""A conversation with its quality judgment."""
conversation: Conversation = Field(
description="The conversation being judged"
)
judgment: JudgmentScore = Field(
description="The quality judgment scores"
)
conversation: Conversation = Field(description="The conversation being judged")
judgment: JudgmentScore = Field(description="The quality judgment scores")
# ============================================================================
@@ -203,18 +174,12 @@ class JudgedConversation(BaseModel):
class EmbeddedConversation(BaseModel):
"""A judged conversation with its embedding."""
conversation: Conversation = Field(
description="The conversation"
)
judgment: JudgmentScore = Field(
description="The quality judgment"
)
conversation: Conversation = Field(description="The conversation")
judgment: JudgmentScore = Field(description="The quality judgment")
embedding: list[float] = Field(
description="Conversation embedding (1024 dimensions)"
)
text_preview: str = Field(
description="First 200 characters for debugging"
)
text_preview: str = Field(description="First 200 characters for debugging")
# ============================================================================
@@ -225,12 +190,8 @@ class EmbeddedConversation(BaseModel):
class UniqueConversation(BaseModel):
"""A conversation marked as unique after deduplication."""
conversation: Conversation = Field(
description="The conversation"
)
judgment: JudgmentScore = Field(
description="The quality judgment"
)
conversation: Conversation = Field(description="The conversation")
judgment: JudgmentScore = Field(description="The quality judgment")
# Note: embedding removed to save space after dedup

View File

@@ -11,11 +11,8 @@ from .config import (
SYSTEM_PROMPT_TEMPLATES,
CONVERSATION_STYLES,
DEFAULT_TURN_DISTRIBUTION,
USER_EMOTIONS,
USER_EMOTION_DISTRIBUTION,
INPUT_MODALITIES,
INPUT_MODALITY_DISTRIBUTION,
TEXT_VARIATIONS,
TEXT_VARIATION_DISTRIBUTION,
Persona,
SystemPromptTemplate,
@@ -137,10 +134,7 @@ def sample_system_prompt_by_use_case(use_case: str) -> SystemPromptTemplate:
Returns:
A SystemPromptTemplate with matching use case
"""
matching = [
s for s in SYSTEM_PROMPT_TEMPLATES.values()
if s.use_case == use_case
]
matching = [s for s in SYSTEM_PROMPT_TEMPLATES.values() if s.use_case == use_case]
return random.choice(matching) if matching else sample_system_prompt()
@@ -172,7 +166,7 @@ def sample_balanced_config(
# Sample persona
if prefer_technical:
persona = PERSONAS.get("developer") if random.random() < 0.5 else sample_persona()
persona = (
PERSONAS.get("developer") if random.random() < 0.5 else sample_persona()
)
else:
persona = sample_persona()
@@ -193,7 +189,9 @@
}
def load_system_prompts_from_files(prompts_dir: str = "src/synth_data_pipeline/prompts/system_prompts") -> Dict[str, str]:
def load_system_prompts_from_files(
prompts_dir: str = "src/synth_data_pipeline/prompts/system_prompts",
) -> Dict[str, str]:
"""
Load system prompt templates from text files.
@@ -208,13 +206,15 @@ def load_system_prompts_from_files(prompts_dir: str = "src/synth_data_pipeline/p
for prompt_file in prompts_path.glob("*.txt"):
name = prompt_file.stem
with open(prompt_file, 'r', encoding='utf-8') as f:
with open(prompt_file, "r", encoding="utf-8") as f:
system_prompts[name] = f.read().strip()
return system_prompts
def load_personas_from_files(personas_dir: str = "src/synth_data_pipeline/prompts/personas") -> Dict[str, str]:
def load_personas_from_files(
personas_dir: str = "src/synth_data_pipeline/prompts/personas",
) -> Dict[str, str]:
"""
Load persona descriptions from text files.
@@ -229,7 +229,7 @@ def load_personas_from_files(personas_dir: str = "src/synth_data_pipeline/prompt
for persona_file in personas_path.glob("*.txt"):
name = persona_file.stem
with open(persona_file, 'r', encoding='utf-8') as f:
with open(persona_file, "r", encoding="utf-8") as f:
personas[name] = f.read().strip()
return personas
@@ -272,10 +272,8 @@ def sample_multiple_configs(
# Sampling Strategies
# ============================================================================
def stratified_sample_configs(
n: int,
ensure_coverage: bool = True
) -> List[Dict]:
def stratified_sample_configs(n: int, ensure_coverage: bool = True) -> List[Dict]:
"""
Sample configurations with stratified sampling to ensure diversity.
@@ -292,15 +290,17 @@
# First, ensure we have at least one of each persona-style combination
for persona in PERSONAS.values():
for style in CONVERSATION_STYLES:
configs.append({
"num_turns": sample_num_turns(),
"style": style,
"persona": persona,
"system_prompt": sample_system_prompt(),
"user_emotion": sample_emotion(),
"input_modality": sample_input_modality(),
"text_variation": sample_text_variation(),
})
configs.append(
{
"num_turns": sample_num_turns(),
"style": style,
"persona": persona,
"system_prompt": sample_system_prompt(),
"user_emotion": sample_emotion(),
"input_modality": sample_input_modality(),
"text_variation": sample_text_variation(),
}
)
# Fill remaining with random samples
remaining = n - len(configs)

View File

@@ -10,15 +10,15 @@ from typing import List, TypeVar, Callable, Awaitable
import logfire
from tqdm.asyncio import tqdm_asyncio
T = TypeVar('T')
R = TypeVar('R')
T = TypeVar("T")
R = TypeVar("R")
async def process_with_concurrency(
items: List[T],
process_fn: Callable[[T], Awaitable[R]],
max_concurrent: int = 10,
desc: str = "Processing"
desc: str = "Processing",
) -> List[R]:
"""
Process items concurrently with a semaphore to limit concurrency.
@@ -61,12 +61,12 @@ def save_jsonl(items: List, output_path: str | Path):
output_path = Path(output_path)
output_path.parent.mkdir(parents=True, exist_ok=True)
with open(output_path, 'w', encoding='utf-8') as f:
with open(output_path, "w", encoding="utf-8") as f:
for item in items:
if hasattr(item, 'model_dump_json'):
f.write(item.model_dump_json() + '\n')
if hasattr(item, "model_dump_json"):
f.write(item.model_dump_json() + "\n")
else:
f.write(json.dumps(item) + '\n')
f.write(json.dumps(item) + "\n")
logfire.info(f"Saved {len(items)} items to {output_path}")
@@ -83,7 +83,7 @@ def load_jsonl(file_path: str | Path, model_class=None) -> List:
List of items
"""
items = []
with open(file_path, 'r', encoding='utf-8') as f:
with open(file_path, "r", encoding="utf-8") as f:
for line in f:
if model_class:
items.append(model_class.model_validate_json(line))
@@ -92,10 +92,7 @@ def load_jsonl(file_path: str | Path, model_class=None) -> List:
return items
def parse_markdown_chunks(
file_path: str | Path,
context_lines: int = 3
) -> List[dict]:
def parse_markdown_chunks(file_path: str | Path, context_lines: int = 3) -> List[dict]:
"""
Parse markdown file and create chunks with context.
@@ -106,7 +103,7 @@
Returns:
List of dicts with 'source_text', 'context_before', 'context_after'
"""
with open(file_path, 'r', encoding='utf-8') as f:
with open(file_path, "r", encoding="utf-8") as f:
lines = f.readlines()
chunks = []
@@ -116,28 +113,30 @@
line = lines[i].strip()
# Skip empty lines and metadata
if not line or line.startswith('---') or line.startswith('**As of:**'):
if not line or line.startswith("---") or line.startswith("**As of:**"):
i += 1
continue
# Process bullets and significant lines
if line.startswith('*') or line.startswith('#') or len(line) > 50:
if line.startswith("*") or line.startswith("#") or len(line) > 50:
# Get context before
context_start = max(0, i - context_lines)
context_before = ''.join(lines[context_start:i]).strip()
context_before = "".join(lines[context_start:i]).strip()
# Get the main text (current line)
source_text = line
# Get context after
context_end = min(len(lines), i + context_lines + 1)
context_after = ''.join(lines[i + 1:context_end]).strip()
context_after = "".join(lines[i + 1 : context_end]).strip()
chunks.append({
'source_text': source_text,
'context_before': context_before,
'context_after': context_after,
})
chunks.append(
{
"source_text": source_text,
"context_before": context_before,
"context_after": context_after,
}
)
i += 1
@@ -149,7 +148,7 @@ def calculate_overall_score(
naturalness: float,
relevance: float,
diversity: float,
weights: dict = None
weights: dict = None,
) -> float:
"""
Calculate overall quality score from individual metrics.
@@ -166,13 +165,14 @@
"""
if weights is None:
from .config import QUALITY_WEIGHTS
weights = QUALITY_WEIGHTS
overall = (
factual_accuracy * weights.get("factual_accuracy", 0.35) +
naturalness * weights.get("naturalness", 0.25) +
relevance * weights.get("relevance", 0.25) +
diversity * weights.get("diversity", 0.15)
factual_accuracy * weights.get("factual_accuracy", 0.35)
+ naturalness * weights.get("naturalness", 0.25)
+ relevance * weights.get("relevance", 0.25)
+ diversity * weights.get("diversity", 0.15)
)
return round(overall, 2)
@@ -186,11 +186,11 @@ def print_sample(item, title: str = "SAMPLE"):
item: Item to print (conversation, Q&A, etc.)
title: Title for the sample section
"""
print("\n" + "="*80)
print("\n" + "=" * 80)
print(title)
print("="*80)
print("=" * 80)
if hasattr(item, 'model_dump'):
if hasattr(item, "model_dump"):
# Pydantic model
print(json.dumps(item.model_dump(), indent=2))
elif isinstance(item, dict):
@@ -198,7 +198,7 @@
else:
print(item)
print("="*80 + "\n")
print("=" * 80 + "\n")
def print_statistics(scores: List[float], metric_name: str = "Score"):