# nanochat/synth-data-pipeline/6_deduplicate.py

"""
Stage 6: Deduplicate conversations based on embedding similarity.
This script:
1. Loads embedded conversations from Stage 5
2. L2-normalizes embeddings
3. Computes pairwise cosine similarity
4. Removes duplicates above similarity threshold
5. Saves unique conversations
"""
import asyncio

import logfire
import numpy as np

from src.synth_data_pipeline.config import FULL_PARAMS, PATHS
from src.synth_data_pipeline.embedding_utils import (
    greedy_deduplicate,
    l2_normalize,
)
from src.synth_data_pipeline.models import EmbeddedConversation, UniqueConversation
from src.synth_data_pipeline.utils import load_jsonl, save_jsonl

# Configure logging
logfire.configure(scrubbing=False)
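

# For reference, a minimal sketch of what the imported l2_normalize is assumed
# to do (illustrative name below; the real implementation lives in
# src/synth_data_pipeline/embedding_utils.py):
def _l2_normalize_sketch(embeddings: list[np.ndarray]) -> np.ndarray:
    """Stack embeddings into one matrix and scale each row to unit L2 norm."""
    matrix = np.stack(embeddings).astype(np.float32)
    norms = np.linalg.norm(matrix, axis=1, keepdims=True)
    return matrix / np.maximum(norms, 1e-12)  # guard against all-zero rows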


async def main(
    input_file: str | None = None,
    output_file: str | None = None,
    similarity_threshold: float | None = None,
):
"""
Main function to deduplicate conversations.
Args:
input_file: Path to input JSONL file (default from config)
output_file: Path to output JSONL file (default from config)
similarity_threshold: Similarity threshold for deduplication (default from config)
"""
    # Fall back to config defaults; compare against None explicitly so that
    # falsy-but-valid arguments (e.g. a threshold of 0.0) are not overridden
    if input_file is None:
        input_file = PATHS.stage5_conversations_embedded
    if output_file is None:
        output_file = PATHS.stage6_conversations_unique
    if similarity_threshold is None:
        similarity_threshold = FULL_PARAMS.dedup_similarity_threshold
logfire.info(
"Starting deduplication",
input_file=input_file,
similarity_threshold=similarity_threshold
)
# Load embedded conversations
embedded_convs = load_jsonl(input_file, model_class=EmbeddedConversation)
logfire.info(f"Loaded {len(embedded_convs)} embedded conversations")
# Extract embeddings and scores
embeddings = [np.array(ec.embedding, dtype=np.float32) for ec in embedded_convs]
scores = [ec.judgment.overall_score for ec in embedded_convs]
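    # The judge scores ride along so the deduplicator can prefer the
    # higher-rated member of each near-duplicate pair (see sketch below)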
# L2-normalize embeddings for cosine similarity
with logfire.span("normalize_embeddings"):
normalized_embeddings = l2_normalize(embeddings)
logfire.info("Normalized embeddings for cosine similarity")
# Deduplicate
with logfire.span("deduplicate"):
kept_indices = greedy_deduplicate(
normalized_embeddings,
scores,
similarity_threshold=similarity_threshold
)
# Create unique conversations (without embeddings to save space)
unique_convs = []
for idx in kept_indices:
ec = embedded_convs[idx]
unique_conv = UniqueConversation(
conversation=ec.conversation,
judgment=ec.judgment
)
unique_convs.append(unique_conv)
# Save results
save_jsonl(unique_convs, output_file)
# Statistics
total = len(embedded_convs)
kept = len(unique_convs)
removed = total - kept
removal_rate = 100 * removed / total if total > 0 else 0
logfire.info(
f"Deduplication complete: {kept} kept, {removed} removed ({removal_rate:.1f}%)"
)
# Print statistics
print("\n" + "="*80)
print("DEDUPLICATION STATISTICS:")
print("="*80)
print(f"Total conversations: {total}")
print(f"Unique conversations: {kept}")
print(f"Duplicates removed: {removed} ({removal_rate:.1f}%)")
print(f"Similarity threshold: {similarity_threshold}")
print("="*80 + "\n")
# Print score statistics for unique conversations
unique_scores = [uc.judgment.overall_score for uc in unique_convs]
if unique_scores:
print("Unique conversation scores:")
print(f" Average: {np.mean(unique_scores):.2f}")
print(f" Min: {np.min(unique_scores):.2f}")
print(f" Max: {np.max(unique_scores):.2f}")
print("="*80 + "\n")


if __name__ == "__main__":
    asyncio.run(main())