mirror of https://github.com/karpathy/nanochat.git
synced 2025-12-06 04:12:13 +00:00

update synth-data-pipe

commit 6bfc1f8f53 (parent 0197c1cd3c)
89  MODAL_DEPLOY.md  Normal file

@@ -0,0 +1,89 @@
# Modal Deployment Guide

This guide explains how to deploy and run nanochat on Modal's serverless infrastructure.

## Quick Start

1. **Install and authenticate with Modal:**

   ```bash
   uv run modal setup
   ```

2. **Choose your workflow:**

## Files Overview

### `modal_smoke.py` - Quick Image Setup & Testing

This script sets up the Modal image and runs quick smoke tests to verify everything works. Use this to:
- Build the Modal image with all dependencies
- Test that the environment is correctly configured
- Quickly validate your setup before running longer training jobs

```bash
uv run modal run modal_smoke.py
```

### `modal_speedrun.py` - Full Training Pipeline

This script runs the complete nanochat training pipeline on Modal. You can:
- Run the entire pipeline (base training → midtraining → SFT)
- Run only SFT if you already have a base model
- Train on 8xB200 GPUs (completes in under 2 hours)

```bash
# Full training pipeline
uv run modal run modal_speedrun.py

# SFT only (requires existing base model)
uv run modal run modal_speedrun.py --mode sft
```

### `modal_serve.py` - Interactive Chat Interface

This script deploys the web UI so you can chat with your trained model.

```bash
# Development mode (stays running while terminal is open)
uv run modal serve modal_serve.py

# Production deployment (runs independently)
uv run modal deploy modal_serve.py
```

Modal will print a URL - visit it in your browser to chat with your model!

## Model Loading Behavior

**Important:** Under the hood, the `load_model()` function automatically picks:
- The **highest model variant** available
- The **checkpoint with the most steps**

To override this behavior, explicitly specify the model tag and step in the `sys.argv` configuration of the Modal script you're using.
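
For example, a minimal sketch of pinning a specific model and step inside `modal_serve.py` (the `--model-tag` and `--step` flag names are assumptions for illustration; check the argument parser in `scripts/chat_web.py` for the exact options your version exposes):

```python
# Inside fastapi_app() in modal_serve.py, extend the mocked argv.
# '--model-tag' and '--step' are illustrative flag names, not confirmed options.
sys.argv = [
    'chat_web.py',
    '--source', 'sft',
    '--model-tag', 'd20',   # pin a specific model variant (hypothetical flag)
    '--step', '650',        # pin a specific checkpoint step (hypothetical flag)
]
```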

## Volume Structure

All scripts use the same `nanochat-data` volume with this structure:

```
/data/.cache/nanochat/
├── chatsft_checkpoints/   # SFT checkpoints
├── mid_checkpoints/       # Midtraining checkpoints
├── base_checkpoints/      # Base model checkpoints
└── tokenizer/             # Trained tokenizer
```

## Monitoring & Debugging

```bash
# View logs
modal app logs nanochat-serve

# Check volume contents
modal volume ls nanochat-data /.cache/nanochat

# Download checkpoints
modal volume get nanochat-data /.cache/nanochat/chatsft_checkpoints ./checkpoints
```

Visit the Modal dashboard at https://modal.com/apps for more details.
10  README.md

@@ -4,6 +4,16 @@

> The best ChatGPT that $100 can buy.

## Fork Highlights

This is a fork of [Andrej Karpathy's incredible nanochat project](https://github.com/karpathy/nanochat). We've added:

- **Modal deployment support**: Scripts to deploy and run training on Modal with 8xB200 GPUs (completing the full training pipeline in under 2 hours). See [MODAL_DEPLOY.md](MODAL_DEPLOY.md) for details.
- **Synthetic data pipeline**: A flexible pipeline for generating synthetic training data to experiment with the SFT (Supervised Fine-Tuning) stage. See the `synth-data-pipeline/` directory.
- **Auto-generated documentation**: Quick overview of the system architecture and components in [DOCS.md](DOCS.md).

---

This repo is a full-stack implementation of an LLM like ChatGPT in a single, clean, minimal, hackable, dependency-lite codebase. nanochat is designed to run on a single 8XH100 node via scripts like [speedrun.sh](speedrun.sh), that run the entire pipeline start to end. This includes tokenization, pretraining, finetuning, evaluation, inference, and web serving over a simple UI so that you can talk to your own LLM just like ChatGPT. nanochat will become the capstone project of the course LLM101n being developed by Eureka Labs.

## Talk to it
168  d32_setup/modal_d32_serve.py  Normal file

@@ -0,0 +1,168 @@
"""
|
||||
Serve the d32 model from HuggingFace using Modal.
|
||||
|
||||
This script is specifically designed to work with the uploaded d32 model
|
||||
which has its own tokenizer separate from your trained d20 model.
|
||||
|
||||
Usage:
|
||||
modal serve modal_d32_serve.py
|
||||
"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
import modal
|
||||
|
||||
APP_NAME = "nanochat-d32-serve"
|
||||
VOLUME_NAME = "nanochat-data"
|
||||
|
||||
app = modal.App(APP_NAME)
|
||||
vol = modal.Volume.from_name(VOLUME_NAME, create_if_missing=True)
|
||||
|
||||
# Get the local directory path
|
||||
LOCAL_DIR = Path(__file__).parent
|
||||
|
||||
# Build image with nanochat code
|
||||
image = (
|
||||
modal.Image.debian_slim(python_version="3.11")
|
||||
.apt_install("curl", "build-essential", "pkg-config", "unzip")
|
||||
.add_local_dir("dev", "/nanochat/dev", copy=True)
|
||||
.add_local_dir("nanochat", "/nanochat/nanochat", copy=True)
|
||||
.add_local_dir("rustbpe", "/nanochat/rustbpe", copy=True)
|
||||
.add_local_dir("scripts", "/nanochat/scripts", copy=True)
|
||||
.add_local_dir("tasks", "/nanochat/tasks", copy=True)
|
||||
.add_local_dir("tests", "/nanochat/tests", copy=True)
|
||||
.add_local_file("pyproject.toml", "/nanochat/pyproject.toml", copy=True)
|
||||
.add_local_file(".python-version", "/nanochat/.python-version", copy=True)
|
||||
.add_local_file("README.md", "/nanochat/README.md", copy=True)
|
||||
.add_local_file("LICENSE", "/nanochat/LICENSE", copy=True)
|
||||
.workdir("/nanochat")
|
||||
.run_commands(
|
||||
"curl -LsSf https://astral.sh/uv/install.sh | sh",
|
||||
"curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y --default-toolchain stable",
|
||||
)
|
||||
.env({"PATH": "/root/.cargo/bin:/root/.local/bin:$PATH"})
|
||||
.uv_sync(extras=["gpu"])
|
||||
.run_commands(
|
||||
"uv run maturin develop --release --manifest-path rustbpe/Cargo.toml",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@app.function(
|
||||
image=image,
|
||||
volumes={"/data": vol},
|
||||
gpu="H100:1", # Single H100 for serving
|
||||
timeout=60 * 60,
|
||||
container_idle_timeout=300,
|
||||
)
|
||||
def chat_d32(prompt: str, temperature: float = 0.6, top_k: int = 50) -> str:
|
||||
"""
|
||||
Chat with the d32 model.
|
||||
|
||||
Args:
|
||||
prompt: User prompt/question
|
||||
temperature: Sampling temperature (default: 0.6)
|
||||
top_k: Top-k sampling parameter (default: 50)
|
||||
|
||||
Returns:
|
||||
Model's response as a string
|
||||
"""
|
||||
import sys
|
||||
import torch
|
||||
from contextlib import nullcontext
|
||||
|
||||
# Add nanochat to path
|
||||
sys.path.insert(0, '/nanochat')
|
||||
|
||||
# Import after adding to path
|
||||
from nanochat.common import get_base_dir, autodetect_device_type, compute_init
|
||||
from nanochat.checkpoint_manager import build_model
|
||||
from nanochat.tokenizer import RustBPETokenizer
|
||||
from nanochat.engine import Engine
|
||||
|
||||
# Setup environment to point to d32's tokenizer
|
||||
DATA = Path("/data")
|
||||
BASE_DIR = DATA / ".cache" / "nanochat"
|
||||
|
||||
# CRITICAL: Override the base dir so it uses our volume
|
||||
os.environ["NANOCHAT_BASE_DIR"] = str(BASE_DIR)
|
||||
|
||||
# Setup device
|
||||
device_type = autodetect_device_type()
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
|
||||
ptdtype = torch.bfloat16
|
||||
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
|
||||
|
||||
# Load the d32 model
|
||||
checkpoint_dir = BASE_DIR / "chatsft_checkpoints" / "d32"
|
||||
step = 650 # The uploaded checkpoint is at step 650
|
||||
|
||||
print(f"Loading d32 model from {checkpoint_dir} at step {step}")
|
||||
model, _, meta = build_model(str(checkpoint_dir), step, device, phase="eval")
|
||||
|
||||
# Load the d32-specific tokenizer
|
||||
tokenizer_dir = BASE_DIR / "tokenizer_d32"
|
||||
print(f"Loading d32 tokenizer from {tokenizer_dir}")
|
||||
tokenizer = RustBPETokenizer.from_directory(str(tokenizer_dir))
|
||||
|
||||
# Verify vocab size matches
|
||||
assert tokenizer.get_vocab_size() == model.config.vocab_size, \
|
||||
f"Tokenizer vocab size {tokenizer.get_vocab_size()} != model vocab size {model.config.vocab_size}"
|
||||
|
||||
# Create engine
|
||||
engine = Engine(model, tokenizer)
|
||||
|
||||
# Special tokens
|
||||
bos = tokenizer.get_bos_token_id()
|
||||
user_start = tokenizer.encode_special("<|user_start|>")
|
||||
user_end = tokenizer.encode_special("<|user_end|>")
|
||||
assistant_start = tokenizer.encode_special("<|assistant_start|>")
|
||||
assistant_end = tokenizer.encode_special("<|assistant_end|>")
|
||||
|
||||
# Build conversation tokens
|
||||
conversation_tokens = [bos]
|
||||
conversation_tokens.append(user_start)
|
||||
conversation_tokens.extend(tokenizer.encode(prompt))
|
||||
conversation_tokens.append(user_end)
|
||||
conversation_tokens.append(assistant_start)
|
||||
|
||||
# Generate response
|
||||
conversation_tokens = torch.tensor(conversation_tokens, dtype=torch.long, device=device).unsqueeze(0)
|
||||
|
||||
with torch.no_grad(), autocast_ctx:
|
||||
generated, _ = engine.generate(
|
||||
conversation_tokens,
|
||||
max_new_tokens=2048,
|
||||
temperature=temperature,
|
||||
top_k=top_k,
|
||||
stop_tokens=[assistant_end],
|
||||
)
|
||||
|
||||
# Decode response (skip the prompt)
|
||||
response_tokens = generated[0, conversation_tokens.size(1):].tolist()
|
||||
response = tokenizer.decode(response_tokens)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
@app.local_entrypoint()
|
||||
def main(prompt: str = "What is the capital of France?"):
|
||||
"""
|
||||
Test the d32 model with a prompt.
|
||||
|
||||
Args:
|
||||
prompt: The prompt to send to the model
|
||||
|
||||
Examples:
|
||||
modal run modal_d32_serve.py
|
||||
modal run modal_d32_serve.py --prompt "Explain quantum computing"
|
||||
"""
|
||||
print(f"\n{'='*80}")
|
||||
print(f"🤖 NanoChat d32 Model")
|
||||
print(f"{'='*80}\n")
|
||||
print(f"Prompt: {prompt}\n")
|
||||
|
||||
response = chat_d32.remote(prompt)
|
||||
|
||||
print(f"Response: {response}")
|
||||
print(f"\n{'='*80}\n")
|
||||
156  d32_setup/modal_d32_setup.py  Normal file

@@ -0,0 +1,156 @@
"""
|
||||
Setup script to organize the uploaded d32 model from HuggingFace.
|
||||
|
||||
This script moves the d32 model files into the proper nanochat directory structure
|
||||
so it can be used alongside your trained d20 model.
|
||||
|
||||
The d32 model comes with its own tokenizer, so we need to set up a separate
|
||||
environment variable or parameter to switch between d20 and d32.
|
||||
|
||||
Directory structure after setup:
|
||||
.cache/nanochat/
|
||||
├── chatsft_checkpoints/
|
||||
│ ├── d20/ # Your trained model
|
||||
│ │ ├── model_*.pt
|
||||
│ │ └── meta_*.json
|
||||
│ └── d32/ # Karpathy's d32 model
|
||||
│ ├── model_000650.pt
|
||||
│ └── meta_000650.json
|
||||
└── tokenizer_d32/ # d32's tokenizer (separate from your d20 tokenizer)
|
||||
├── tokenizer.pkl
|
||||
└── token_bytes.pt
|
||||
"""
|
||||
|
||||
import modal
|
||||
from pathlib import Path
|
||||
import shutil
|
||||
|
||||
APP_NAME = "nanochat-d32-setup"
|
||||
VOLUME_NAME = "nanochat-data"
|
||||
|
||||
app = modal.App(APP_NAME)
|
||||
vol = modal.Volume.from_name(VOLUME_NAME, create_if_missing=True)
|
||||
|
||||
image = modal.Image.debian_slim(python_version="3.11")
|
||||
|
||||
@app.function(
|
||||
image=image,
|
||||
volumes={"/data": vol},
|
||||
timeout=60 * 10, # 10 minutes
|
||||
max_inputs=1,
|
||||
)
|
||||
def setup_d32_model():
|
||||
"""
|
||||
Organize the uploaded d32 model files into the proper nanochat directory structure.
|
||||
|
||||
This moves:
|
||||
- model_000650.pt -> chatsft_checkpoints/d32/model_000650.pt
|
||||
- meta_000650.json -> chatsft_checkpoints/d32/meta_000650.json
|
||||
- tokenizer.pkl -> tokenizer_d32/tokenizer.pkl
|
||||
- token_bytes.pt -> tokenizer_d32/token_bytes.pt
|
||||
"""
|
||||
DATA = Path("/data")
|
||||
BASE_DIR = DATA / ".cache" / "nanochat"
|
||||
|
||||
# Source: where the d32 files were uploaded
|
||||
UPLOADED_DIR = BASE_DIR / "chatsft_uploaded"
|
||||
|
||||
# Destinations
|
||||
D32_CHECKPOINT_DIR = BASE_DIR / "chatsft_checkpoints" / "d32"
|
||||
D32_TOKENIZER_DIR = BASE_DIR / "tokenizer_d32"
|
||||
|
||||
print(f"\n{'='*80}")
|
||||
print(f"🔧 Setting up d32 model in nanochat directory structure")
|
||||
print(f"{'='*80}\n")
|
||||
|
||||
# Check if uploaded files exist
|
||||
if not UPLOADED_DIR.exists():
|
||||
print(f"❌ Error: Upload directory not found: {UPLOADED_DIR}")
|
||||
print(f" Please run modal_d32_upload.py first!")
|
||||
return
|
||||
|
||||
uploaded_files = list(UPLOADED_DIR.iterdir())
|
||||
print(f"📦 Found {len(uploaded_files)} files in {UPLOADED_DIR}:")
|
||||
for f in uploaded_files:
|
||||
print(f" - {f.name}")
|
||||
|
||||
# Create destination directories
|
||||
D32_CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)
|
||||
D32_TOKENIZER_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Move model checkpoint and metadata
|
||||
print(f"\n📁 Setting up model checkpoint in {D32_CHECKPOINT_DIR}...")
|
||||
model_file = UPLOADED_DIR / "model_000650.pt"
|
||||
meta_file = UPLOADED_DIR / "meta_000650.json"
|
||||
|
||||
if model_file.exists():
|
||||
dest = D32_CHECKPOINT_DIR / "model_000650.pt"
|
||||
print(f" Copying {model_file.name} -> {dest}")
|
||||
shutil.copy2(model_file, dest)
|
||||
else:
|
||||
print(f" ⚠️ Warning: {model_file.name} not found")
|
||||
|
||||
if meta_file.exists():
|
||||
dest = D32_CHECKPOINT_DIR / "meta_000650.json"
|
||||
print(f" Copying {meta_file.name} -> {dest}")
|
||||
shutil.copy2(meta_file, dest)
|
||||
else:
|
||||
print(f" ⚠️ Warning: {meta_file.name} not found")
|
||||
|
||||
# Move tokenizer files
|
||||
print(f"\n🔤 Setting up d32 tokenizer in {D32_TOKENIZER_DIR}...")
|
||||
tokenizer_file = UPLOADED_DIR / "tokenizer.pkl"
|
||||
token_bytes_file = UPLOADED_DIR / "token_bytes.pt"
|
||||
|
||||
if tokenizer_file.exists():
|
||||
dest = D32_TOKENIZER_DIR / "tokenizer.pkl"
|
||||
print(f" Copying {tokenizer_file.name} -> {dest}")
|
||||
shutil.copy2(tokenizer_file, dest)
|
||||
else:
|
||||
print(f" ⚠️ Warning: {tokenizer_file.name} not found")
|
||||
|
||||
if token_bytes_file.exists():
|
||||
dest = D32_TOKENIZER_DIR / "token_bytes.pt"
|
||||
print(f" Copying {token_bytes_file.name} -> {dest}")
|
||||
shutil.copy2(token_bytes_file, dest)
|
||||
else:
|
||||
print(f" ⚠️ Warning: {token_bytes_file.name} not found")
|
||||
|
||||
# Commit changes to volume
|
||||
vol.commit()
|
||||
|
||||
print(f"\n{'='*80}")
|
||||
print(f"✅ d32 model setup complete!")
|
||||
print(f"{'='*80}\n")
|
||||
|
||||
print("\n📋 Directory structure:")
|
||||
print(f"\nYour d20 model:")
|
||||
print(f" Model: .cache/nanochat/chatsft_checkpoints/d20/")
|
||||
print(f" Tokenizer: .cache/nanochat/tokenizer/")
|
||||
|
||||
print(f"\nKarpathy's d32 model:")
|
||||
print(f" Model: .cache/nanochat/chatsft_checkpoints/d32/")
|
||||
print(f" Tokenizer: .cache/nanochat/tokenizer_d32/")
|
||||
|
||||
print("\n⚠️ IMPORTANT:")
|
||||
print(" Each model MUST use its own tokenizer!")
|
||||
print(" You cannot mix d20 model with d32 tokenizer or vice versa.")
|
||||
print("\n To use d32, you'll need to modify scripts to:")
|
||||
print(" 1. Set model_tag='d32' to load the d32 checkpoint")
|
||||
print(" 2. Modify get_tokenizer() to load from tokenizer_d32/")
|
||||
print(" OR use a separate script that handles d32 specifically.")


@app.local_entrypoint()
def main():
    """
    Organize uploaded d32 model files into proper nanochat directory structure.

    Usage:
        modal run modal_d32_setup.py

    This should be run AFTER modal_d32_upload.py completes.
    """
    print("\n🚀 Setting up d32 model directory structure...")
    setup_d32_model.remote()
    print("\n✅ Setup complete!")
104  d32_setup/modal_d32_upload.py  Normal file

@@ -0,0 +1,104 @@
import modal
from pathlib import Path

APP_NAME = "nanochat-d32-upload"
VOLUME_NAME = "nanochat-data"

app = modal.App(APP_NAME)
vol = modal.Volume.from_name(VOLUME_NAME, create_if_missing=True)

# Image with huggingface-hub for downloading files
image = (
    modal.Image.debian_slim(python_version="3.11")
    .pip_install("huggingface-hub")
)

@app.function(
    image=image,
    volumes={"/data": vol},
    timeout=60 * 60 * 2,  # 2 hours for downloading ~7.25GB
    max_inputs=1,
)
def upload_d32_checkpoint():
    """
    Download nanochat-d32 model files from Hugging Face and sync them into Modal volume.

    Downloads from: https://huggingface.co/karpathy/nanochat-d32/tree/main
    Target path in volume: .cache/nanochat/chatsft_uploaded

    Files to download:
    - model_000650.pt (7.25 GB) - PyTorch model checkpoint
    - meta_000650.json (263 B) - Metadata for checkpoint
    - tokenizer.pkl (846 kB) - Tokenizer pickle file
    - token_bytes.pt (264 kB) - Token bytes
    - README.md (526 B) - Documentation
    """
    from huggingface_hub import hf_hub_download
    import shutil

    DATA = Path("/data")
    TARGET_DIR = DATA / ".cache" / "nanochat" / "chatsft_uploaded"
    TARGET_DIR.mkdir(parents=True, exist_ok=True)

    REPO_ID = "karpathy/nanochat-d32"
    FILES_TO_DOWNLOAD = [
        "model_000650.pt",   # 7.25 GB - main model checkpoint
        "meta_000650.json",  # 263 B - metadata
        "tokenizer.pkl",     # 846 kB - tokenizer
        "token_bytes.pt",    # 264 kB - token bytes
        "README.md",         # 526 B - documentation
    ]

    print(f"\n{'='*80}")
    print(f"📥 Downloading nanochat-d32 files from Hugging Face")
    print(f"   Repository: {REPO_ID}")
    print(f"   Target: {TARGET_DIR}")
    print(f"{'='*80}\n")

    for filename in FILES_TO_DOWNLOAD:
        print(f"\n📦 Downloading {filename}...")

        # Download file to HF cache
        cached_path = hf_hub_download(
            repo_id=REPO_ID,
            filename=filename,
            repo_type="model",
        )

        # Copy to our target directory in the volume
        target_path = TARGET_DIR / filename
        print(f"   Copying to {target_path}...")
        shutil.copy2(cached_path, target_path)
        print(f"   ✓ {filename} synced")

    # Commit changes to volume
    vol.commit()

    print(f"\n{'='*80}")
    print(f"✅ All files downloaded and synced to volume!")
    print(f"{'='*80}\n")

    # List final directory contents
    print("\n📋 Final directory contents:")
    for file in sorted(TARGET_DIR.iterdir()):
        size_mb = file.stat().st_size / (1024 * 1024)
        print(f"   {file.name:25s} ({size_mb:8.2f} MB)")

    print(f"\nFiles available at: /data/.cache/nanochat/chatsft_uploaded")


# Local entrypoint
@app.local_entrypoint()
def main():
    """
    Download nanochat-d32 checkpoint from Hugging Face to Modal volume.

    Usage:
        modal run modal_d32_upload.py

    This will download ~7.25GB of model files and sync them to:
        {VOLUME_NAME}/.cache/nanochat/chatsft_uploaded/
    """
    print("\n🚀 Starting nanochat-d32 upload to Modal volume...")
    upload_d32_checkpoint.remote()
    print("\n✅ Upload complete! Files are now in the Modal volume.")
119  modal_serve.py  Normal file

@@ -0,0 +1,119 @@
"""
|
||||
Modal deployment for nanochat - serves the existing chat_web.py FastAPI app on H100.
|
||||
|
||||
Usage:
|
||||
modal deploy modal_serve.py
|
||||
|
||||
This will:
|
||||
1. Build a container image with PyTorch, FastAPI, and the nanochat module
|
||||
2. Load the best available checkpoint (from sft by default)
|
||||
3. Serve the chat UI and API endpoints from scripts/chat_web.py
|
||||
|
||||
The web UI will be available at the URL printed by Modal after deployment.
|
||||
|
||||
Note: Before deploying, upload your model checkpoints to the volume.
|
||||
"""

import modal
from pathlib import Path

APP_NAME = "nanochat-serve"
VOLUME_NAME = "nanochat-data"  # Reuse the same volume as modal_speedrun.py

app = modal.App(APP_NAME)

# Reuse volume from modal_speedrun (or create if missing)
vol = modal.Volume.from_name(VOLUME_NAME, create_if_missing=True)

# Get the local directory path
LOCAL_DIR = Path(__file__).parent

# Build Modal image with identical environment to modal_speedrun.py
# This ensures consistency between training and serving
image = (
    modal.Image.debian_slim(python_version="3.11")
    .apt_install("curl", "build-essential", "pkg-config", "unzip")
    .add_local_dir("dev", "/nanochat/dev", copy=True)
    .add_local_dir("nanochat", "/nanochat/nanochat", copy=True)
    .add_local_dir("rustbpe", "/nanochat/rustbpe", copy=True)
    .add_local_dir("scripts", "/nanochat/scripts", copy=True)
    .add_local_dir("tasks", "/nanochat/tasks", copy=True)
    .add_local_dir("tests", "/nanochat/tests", copy=True)
    .add_local_file("pyproject.toml", "/nanochat/pyproject.toml", copy=True)
    .add_local_file(".python-version", "/nanochat/.python-version", copy=True)
    .add_local_file("README.md", "/nanochat/README.md", copy=True)
    .add_local_file("LICENSE", "/nanochat/LICENSE", copy=True)
    .workdir("/nanochat")
    .run_commands(
        # Install uv (Python package manager)
        "curl -LsSf https://astral.sh/uv/install.sh | sh",
        # Install Rust and set default toolchain
        "curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y --default-toolchain stable",
    )
    .env({"PATH": "/root/.cargo/bin:/root/.local/bin:$PATH"})
    .uv_sync(extras=["gpu"])
    .run_commands(
        # Build the Rust tokenizer (the slow part)
        "uv run maturin develop --release --manifest-path rustbpe/Cargo.toml",
    )
)


@app.function(
    image=image,
    gpu="H100",
    volumes={"/data": vol},
    timeout=3600,  # 1 hour timeout
    scaledown_window=300,  # Keep alive for 5 min after last request
)
@modal.asgi_app()
def fastapi_app():
    """
    Import and return the FastAPI app from chat_web.py.

    This reuses all the existing logic: endpoints, streaming, validation, etc.
    The only difference is we run on Modal infrastructure with H100 GPU.
    """
    import sys
    import os

    # Set base directory to where checkpoints are mounted (same as modal_speedrun)
    BASE_DIR = "/data/.cache/nanochat"
    os.environ['NANOCHAT_BASE_DIR'] = BASE_DIR

    # Mock the command-line arguments that chat_web.py expects
    sys.argv = [
        'chat_web.py',
        '--num-gpus', '1',        # Single GPU (Modal handles scaling)
        '--source', 'sft',        # Load from sft checkpoints
        '--temperature', '0.8',   # Default temperature
        '--top-k', '50',          # Default top-k
        '--max-tokens', '512',    # Default max tokens
        '--device-type', 'cuda',  # Use CUDA
        '--dtype', 'bfloat16',    # Use bfloat16 for efficiency
    ]

    # Import the FastAPI app from chat_web
    # This will trigger model loading via the lifespan context manager
    from scripts.chat_web import app

    print(f"✅ NanoChat server initialized!")
    print(f"   Checkpoint directory: {BASE_DIR}")
    print(f"   GPU: H100 x 1")

    return app


# Convenience local entrypoint for testing
@app.local_entrypoint()
def main():
    """
    Deploy the nanochat serving endpoint.

    This is just a convenience wrapper. You can also run:
        modal deploy modal_serve.py
    """
    print("Deploying nanochat serving endpoint...")
    print(f"Using volume: {VOLUME_NAME}")
    print(f"GPU: H100 x 1")
    print("\nThe app will be available at the URL printed by Modal.")
pyproject.toml

@@ -59,9 +59,9 @@ name = "pytorch-cpu"

url = "https://download.pytorch.org/whl/cpu"
explicit = true

[[tool.uv.index]]
name = "pytorch-cu128"
url = "https://download.pytorch.org/whl/cu128"

[[tool.uv.index]]
name = "pytorch-cu128"
url = "https://download.pytorch.org/whl/cu128"
explicit = true

[project.optional-dependencies]
3  synth-data-pipeline/.gitignore  vendored  Normal file

@@ -0,0 +1,3 @@

data/**
output/**
.env
1  synth-data-pipeline/.python-version  Normal file

@@ -0,0 +1 @@

3.13
146  synth-data-pipeline/1_extract_qa.py  Normal file

@@ -0,0 +1,146 @@
"""
|
||||
Stage 1: Extract Q&A pairs from SWAP Commerce documentation.
|
||||
|
||||
This script:
|
||||
1. Parses swap_facts.md and chunks by bullet points/sections
|
||||
2. Uses Gemini 2.5 Flash to generate Q&A pairs with context
|
||||
3. Saves results to output/qa_pairs.jsonl
|
||||
"""

import asyncio

import logfire
from dotenv import load_dotenv

from src.synth_data_pipeline.agents import qa_extractor

from src.synth_data_pipeline.models import QAPair, QAPairBatch
from src.synth_data_pipeline.config import PATHS, STAGE_CONFIGS, FULL_PARAMS
from src.synth_data_pipeline.utils import (
    parse_markdown_chunks,
    process_with_concurrency,
    save_jsonl,
)

# Load environment variables
load_dotenv()

# Configure logging
logfire.configure(scrubbing=False)
logfire.instrument_pydantic_ai()

# Get configuration for this stage
config = STAGE_CONFIGS["stage1_qa_extraction"]

# Load Q&A extraction agent definition
qa_prompt_template = qa_extractor.get_prompt_template()
qa_agent = qa_extractor.build_agent(config)


async def generate_qa_batch(chunk: dict) -> list[QAPair]:
    """
    Generate 3 Q&A pairs from a text chunk using Gemini.

    Args:
        chunk: Dict with source_text, context_before, context_after

    Returns:
        List of 3 QAPair objects
    """
    # Format the prompt
    prompt_text = qa_prompt_template.prompt.format(
        source_text=chunk['source_text'],
        context_before=chunk['context_before'],
        context_after=chunk['context_after'],
    )

    # Generate batch of 3 Q&A pairs using the agent
    result = await qa_agent.run(prompt_text)
    batch = result.output

    # Return the list of QAPair objects
    return batch.qa_pairs


async def main(
    input_file: str = None,
    output_file: str = None,
    max_concurrent: int = None,
    limit: int = None
):
    """
    Main function to extract Q&A pairs from documentation.

    Args:
        input_file: Path to input markdown file (default from config)
        output_file: Path to output JSONL file (default from config)
        max_concurrent: Maximum concurrent API calls (default from config)
        limit: Limit number of chunks to process (None = no limit)
    """
    # Use defaults from config if not specified
    input_file = input_file or PATHS.source_facts
    output_file = output_file or PATHS.stage1_qa_pairs
    max_concurrent = max_concurrent or config.max_concurrent
    limit = limit or FULL_PARAMS.qa_chunk_limit

    logfire.info("Starting Q&A extraction", input_file=input_file)

    # Parse the markdown file into chunks
    chunks = parse_markdown_chunks(input_file, FULL_PARAMS.qa_chunk_context_lines)
    logfire.info(f"Parsed {len(chunks)} chunks from {input_file}")

    # Limit chunks if specified (for testing)
    if limit:
        chunks = chunks[:limit]
        logfire.info(f"Limited to {limit} chunks")

    # Generate Q&A pairs (3 per chunk)
    with logfire.span("generate_qa_batches"):
        qa_batches = await process_with_concurrency(
            chunks,
            generate_qa_batch,
            max_concurrent=max_concurrent,
            desc="Generating Q&A batches"
        )

    # Flatten the batches into individual QA pairs
    qa_pairs = []
    for batch in qa_batches:
        qa_pairs.extend(batch)

    logfire.info(f"Generated {len(qa_pairs)} Q&A pairs from {len(chunks)} chunks ({len(qa_pairs)/len(chunks):.1f} per chunk)")

    # Filter out Q&A pairs with future dates to avoid hallucination issues
    from datetime import datetime
    today = datetime.now()
    filtered_qa = []
    for qa in qa_pairs:
        # Simple check: if answer or question mentions a date in 2025 or later, skip it
        # This is a pragmatic filter - could be made more sophisticated
        text = qa.answer + " " + qa.question
        has_future_date = any(year in text for year in ["2025", "2026", "2027", "2028"])
        if not has_future_date:
            filtered_qa.append(qa)

    if len(filtered_qa) < len(qa_pairs):
        logfire.info(f"Filtered out {len(qa_pairs) - len(filtered_qa)} Q&A pairs with future dates")
    qa_pairs = filtered_qa

    # Save results
    save_jsonl(qa_pairs, output_file)

    # Print sample for inspection
    if qa_pairs:
        print("\n" + "="*80)
        print("SAMPLE Q&A PAIR:")
        print("="*80)
        sample = qa_pairs[0]
        print(f"Question: {sample.question}")
        print(f"Answer: {sample.answer}")
        print(f"Difficulty: {sample.difficulty}")
        print(f"Categories: {', '.join(sample.categories)}")
        print("="*80 + "\n")


if __name__ == "__main__":
    asyncio.run(main())
143  synth-data-pipeline/2_validate_qa.py  Normal file

@@ -0,0 +1,143 @@
"""
|
||||
Stage 2: Validate Q&A pairs for quality and accuracy.
|
||||
|
||||
This script:
|
||||
1. Loads Q&A pairs from Stage 1
|
||||
2. Uses Gemini 2.5 Flash Lite to validate each pair
|
||||
3. Filters out pairs that fail validation
|
||||
4. Saves validated pairs to output
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
|
||||
import logfire
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from src.synth_data_pipeline.agents import qa_validator
|
||||
|
||||
from src.synth_data_pipeline.models import QAPair, QAValidation, ValidatedQAPair
|
||||
from src.synth_data_pipeline.config import PATHS, STAGE_CONFIGS, FULL_PARAMS
|
||||
from src.synth_data_pipeline.utils import (
|
||||
load_jsonl,
|
||||
save_jsonl,
|
||||
process_with_concurrency,
|
||||
)
|
||||
|
||||
# Load environment variables
|
||||
load_dotenv()
|
||||
|
||||
# Configure logging
|
||||
logfire.configure(scrubbing=False)
|
||||
logfire.instrument_pydantic_ai()
|
||||
|
||||
# Get configuration for this stage
|
||||
config = STAGE_CONFIGS["stage2_qa_validation"]
|
||||
|
||||
# Load validation agent definition
|
||||
validation_prompt_template = qa_validator.get_prompt_template()
|
||||
validation_agent = qa_validator.build_agent(config)
|
||||
|
||||
|
||||
async def validate_qa_pair(qa_pair: QAPair) -> ValidatedQAPair:
|
||||
"""
|
||||
Validate a Q&A pair using Gemini.
|
||||
|
||||
Args:
|
||||
qa_pair: QAPair object to validate
|
||||
|
||||
Returns:
|
||||
ValidatedQAPair with validation result
|
||||
"""
|
||||
# Format the prompt
|
||||
prompt_text = validation_prompt_template.prompt.format(
|
||||
question=qa_pair.question,
|
||||
answer=qa_pair.answer,
|
||||
source_text=qa_pair.source_text,
|
||||
context_before=qa_pair.context_before,
|
||||
context_after=qa_pair.context_after,
|
||||
)
|
||||
|
||||
# Validate using the agent
|
||||
result = await validation_agent.run(prompt_text)
|
||||
validation = result.output
|
||||
|
||||
return ValidatedQAPair(
|
||||
qa_pair=qa_pair,
|
||||
validation=validation
|
||||
)
|
||||
|
||||
|
||||
async def main(
|
||||
input_file: str = None,
|
||||
output_file: str = None,
|
||||
max_concurrent: int = None
|
||||
):
|
||||
"""
|
||||
Main function to validate Q&A pairs.
|
||||
|
||||
Args:
|
||||
input_file: Path to input JSONL file (default from config)
|
||||
output_file: Path to output JSONL file (default from config)
|
||||
max_concurrent: Maximum concurrent API calls (default from config)
|
||||
"""
|
||||
# Use defaults from config if not specified
|
||||
input_file = input_file or PATHS.stage1_qa_pairs
|
||||
output_file = output_file or PATHS.stage2_qa_validated
|
||||
max_concurrent = max_concurrent or config.max_concurrent
|
||||
|
||||
logfire.info("Starting Q&A validation", input_file=input_file)
|
||||
|
||||
# Load Q&A pairs
|
||||
qa_pairs = load_jsonl(input_file, model_class=QAPair)
|
||||
logfire.info(f"Loaded {len(qa_pairs)} Q&A pairs")
|
||||
|
||||
# Validate all pairs
|
||||
with logfire.span("validate_qa_pairs"):
|
||||
validated_pairs = await process_with_concurrency(
|
||||
qa_pairs,
|
||||
validate_qa_pair,
|
||||
max_concurrent=max_concurrent,
|
||||
desc="Validating Q&A pairs"
|
||||
)
|
||||
|
||||
# Count passed/failed
|
||||
passed = [vp for vp in validated_pairs if vp.validation.passed]
|
||||
failed = [vp for vp in validated_pairs if not vp.validation.passed]
|
||||
|
||||
logfire.info(
|
||||
f"Validation complete: {len(passed)} passed, {len(failed)} failed "
|
||||
f"({100 * len(failed) / len(validated_pairs):.1f}% rejection rate)"
|
||||
)
|
||||
|
||||
# Save all validated pairs (with validation results)
|
||||
save_jsonl(validated_pairs, output_file)
|
||||
|
||||
# Also save just the passed Q&A pairs (without validation metadata) for next stage
|
||||
passed_qa_pairs = [vp.qa_pair for vp in passed]
|
||||
if output_file == PATHS.stage2_qa_validated:
|
||||
passed_output = PATHS.stage2_qa_validated_passed
|
||||
else:
|
||||
passed_output = output_file.replace('.jsonl', '_passed.jsonl')
|
||||
|
||||
save_jsonl(passed_qa_pairs, passed_output)
|
||||
logfire.info(f"Saved {len(passed_qa_pairs)} passed Q&A pairs to {passed_output}")
|
||||
|
||||
# Print sample
|
||||
if validated_pairs:
|
||||
print("\n" + "="*80)
|
||||
print("VALIDATION SAMPLE:")
|
||||
print("="*80)
|
||||
sample = validated_pairs[0]
|
||||
print(f"Question: {sample.qa_pair.question}")
|
||||
print(f"Answer: {sample.qa_pair.answer[:100]}...")
|
||||
print(f"\nValidation:")
|
||||
print(f" uses_source_fact: {sample.validation.uses_source_fact}")
|
||||
print(f" realistic_question: {sample.validation.realistic_question}")
|
||||
print(f" sensible_answer: {sample.validation.sensible_answer}")
|
||||
print(f" PASSED: {sample.validation.passed}")
|
||||
print(f" Feedback: {sample.validation.feedback}")
|
||||
print("="*80 + "\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
204  synth-data-pipeline/3_generate_conversations.py  Normal file

@@ -0,0 +1,204 @@
"""
|
||||
Stage 3: Generate conversations from validated Q&A pairs.
|
||||
|
||||
This script:
|
||||
1. Loads validated Q&A pairs from output/qa_pairs_validated_passed.jsonl (or the provided path)
|
||||
2. Samples different conversation configurations
|
||||
3. Uses Gemini to generate natural conversations
|
||||
4. Saves results to output/conversations_raw.jsonl
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import random
|
||||
|
||||
import logfire
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from src.synth_data_pipeline.agents import conversation_generator
|
||||
|
||||
from src.synth_data_pipeline.models import QAPair, Conversation
|
||||
from src.synth_data_pipeline.config import (
|
||||
PATHS,
|
||||
STAGE_CONFIGS,
|
||||
FULL_PARAMS,
|
||||
)
|
||||
from src.synth_data_pipeline.sampling import (
|
||||
stratified_sample_configs,
|
||||
load_system_prompts_from_files,
|
||||
set_random_seed,
|
||||
)
|
||||
from src.synth_data_pipeline.utils import (
|
||||
load_jsonl,
|
||||
save_jsonl,
|
||||
process_with_concurrency,
|
||||
)
|
||||
|
||||
# Load environment variables
|
||||
load_dotenv()
|
||||
|
||||
# Configure logging
|
||||
logfire.configure(scrubbing=False)
|
||||
logfire.instrument_pydantic_ai()
|
||||
|
||||
# Get configuration for this stage
|
||||
config = STAGE_CONFIGS["stage3_conversation_generation"]
|
||||
|
||||
# Load conversation generation agent definition
|
||||
conv_prompt_template = conversation_generator.get_prompt_template()
|
||||
conv_agent = conversation_generator.build_agent(config)
|
||||
|
||||
|
||||
async def generate_conversation(
|
||||
qa_pairs: list[QAPair],
|
||||
config_dict: dict
|
||||
) -> Conversation:
|
||||
"""
|
||||
Generate a conversation from a configuration.
|
||||
|
||||
Args:
|
||||
qa_pairs: All available Q&A pairs
|
||||
config_dict: Configuration dict with num_turns, style, persona, system_prompt
|
||||
|
||||
Returns:
|
||||
Conversation object
|
||||
"""
|
||||
# Sample Q&A pairs for this conversation
|
||||
num_turns = config_dict["num_turns"]
|
||||
|
||||
# Filter Q&A by topic relevance to system prompt (pragmatic matching)
|
||||
system_prompt_name = config_dict["system_prompt"].name if hasattr(config_dict["system_prompt"], "name") else "helpful"
|
||||
|
||||
# Simple persona matching: sales/solutions prompts shouldn't get governance/company info Q&A
|
||||
if "sales" in system_prompt_name.lower() or "solutions" in system_prompt_name.lower():
|
||||
# Prefer product/feature Q&A, avoid company registration/director questions
|
||||
avoid_keywords = ["appointed", "director", "registered", "incorporation", "companies house"]
|
||||
filtered_qa = [qa for qa in qa_pairs if not any(kw in qa.question.lower() for kw in avoid_keywords)]
|
||||
qa_pool = filtered_qa if filtered_qa else qa_pairs
|
||||
else:
|
||||
qa_pool = qa_pairs
|
||||
|
||||
sampled_qa = random.sample(qa_pool, min(num_turns, len(qa_pool)))
|
||||
|
||||
# Format Q&A pairs for the prompt
|
||||
qa_text = "\n\n".join([
|
||||
f"Q: {qa.question}\nA: {qa.answer}\nCategories: {', '.join(qa.categories)}"
|
||||
for qa in sampled_qa
|
||||
])
|
||||
|
||||
# Get system prompt text
|
||||
system_prompt = config_dict["system_prompt"].template
|
||||
|
||||
# Format the prompt
|
||||
prompt_text = conv_prompt_template.prompt.format(
|
||||
num_turns=num_turns,
|
||||
style=config_dict["style"],
|
||||
user_persona=config_dict["persona"].description,
|
||||
system_prompt=system_prompt,
|
||||
qa_pairs=qa_text,
|
||||
)
|
||||
|
||||
# Generate conversation using the agent
|
||||
result = await conv_agent.run(prompt_text)
|
||||
conversation = result.output
|
||||
|
||||
# Add source Q&A pairs to conversation for fact-checking
|
||||
conversation.source_qa_pairs = sampled_qa
|
||||
|
||||
return conversation
|
||||
|
||||
|
||||
async def main(
|
||||
qa_file: str = None,
|
||||
output_file: str = None,
|
||||
num_conversations: int = None,
|
||||
max_concurrent: int = None,
|
||||
):
|
||||
"""
|
||||
Main function to generate conversations from Q&A pairs.
|
||||
|
||||
Args:
|
||||
qa_file: Path to Q&A pairs JSONL file (default from config)
|
||||
output_file: Path to output JSONL file (default from config)
|
||||
num_conversations: Number of conversations to generate (default from config)
|
||||
max_concurrent: Maximum concurrent API calls (default from config)
|
||||
"""
|
||||
# Use defaults from config if not specified
|
||||
qa_file = qa_file or PATHS.stage2_qa_validated_passed
|
||||
output_file = output_file or PATHS.stage3_conversations_raw
|
||||
max_concurrent = max_concurrent or config.max_concurrent
|
||||
|
||||
# Set random seed if specified
|
||||
if FULL_PARAMS.random_seed is not None:
|
||||
set_random_seed(FULL_PARAMS.random_seed)
|
||||
|
||||
logfire.info("Starting conversation generation", qa_file=qa_file)
|
||||
|
||||
# Load Q&A pairs
|
||||
qa_pairs = load_jsonl(qa_file, model_class=QAPair)
|
||||
logfire.info(f"Loaded {len(qa_pairs)} Q&A pairs")
|
||||
|
||||
# Determine how many conversations to generate based on available QA pairs
|
||||
requested_conversations = num_conversations or FULL_PARAMS.num_conversations
|
||||
auto_cap = 0
|
||||
if FULL_PARAMS.conversations_per_qa:
|
||||
auto_cap = len(qa_pairs) * FULL_PARAMS.conversations_per_qa
|
||||
|
||||
target_conversations = requested_conversations
|
||||
if auto_cap:
|
||||
target_conversations = min(requested_conversations, auto_cap)
|
||||
|
||||
if target_conversations == 0:
|
||||
logfire.warning("No conversations generated because no Q&A pairs are available.")
|
||||
return
|
||||
|
||||
logfire.info(
|
||||
"Conversation target determined",
|
||||
requested=requested_conversations,
|
||||
auto_cap=auto_cap,
|
||||
final=target_conversations,
|
||||
)
|
||||
|
||||
# Load system prompt templates from files (for runtime flexibility)
|
||||
system_prompts_from_files = load_system_prompts_from_files()
|
||||
logfire.info(f"Loaded {len(system_prompts_from_files)} system prompts from files")
|
||||
|
||||
# Sample conversation configurations
|
||||
configs = stratified_sample_configs(
|
||||
target_conversations,
|
||||
ensure_coverage=True
|
||||
)
|
||||
logfire.info(f"Sampled {len(configs)} conversation configurations")
|
||||
|
||||
# Generate conversations
|
||||
with logfire.span("generate_conversations"):
|
||||
# Create a closure that includes qa_pairs
|
||||
async def generate_fn(config_dict):
|
||||
return await generate_conversation(qa_pairs, config_dict)
|
||||
|
||||
conversations = await process_with_concurrency(
|
||||
configs,
|
||||
generate_fn,
|
||||
max_concurrent=max_concurrent,
|
||||
desc="Generating conversations"
|
||||
)
|
||||
|
||||
logfire.info(f"Generated {len(conversations)} conversations")
|
||||
|
||||
# Save results
|
||||
save_jsonl(conversations, output_file)
|
||||
|
||||
# Print sample for inspection
|
||||
if conversations:
|
||||
print("\n" + "="*80)
|
||||
print("SAMPLE CONVERSATION:")
|
||||
print("="*80)
|
||||
sample = conversations[0]
|
||||
for msg in sample.messages:
|
||||
print(f"{msg.role.upper()}: {msg.content[:200]}...")
|
||||
print()
|
||||
print(f"Metadata: {sample.metadata}")
|
||||
print("="*80 + "\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
259  synth-data-pipeline/4_judge_and_save.py  Normal file

@@ -0,0 +1,259 @@
"""
|
||||
Stage 4: Judge conversations and save top candidates.
|
||||
|
||||
This script:
|
||||
1. Loads raw conversations from output/conversations_raw.jsonl
|
||||
2. Uses Gemini to judge quality of each conversation
|
||||
3. Ranks by quality score
|
||||
4. Saves all judged conversations and top 1000 in NanoChat format
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
|
||||
import logfire
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from src.synth_data_pipeline.agents import conversation_judge
|
||||
|
||||
from src.synth_data_pipeline.models import (
|
||||
Conversation,
|
||||
JudgedConversation,
|
||||
JudgmentScore,
|
||||
NanoChatConversation,
|
||||
NanoChatMessage,
|
||||
)
|
||||
from src.synth_data_pipeline.config import (
|
||||
PATHS,
|
||||
STAGE_CONFIGS,
|
||||
FULL_PARAMS,
|
||||
)
|
||||
from src.synth_data_pipeline.utils import (
|
||||
load_jsonl,
|
||||
save_jsonl,
|
||||
process_with_concurrency,
|
||||
print_statistics,
|
||||
)
|
||||
|
||||
# Load environment variables
|
||||
load_dotenv()
|
||||
|
||||
# Configure logging
|
||||
logfire.configure(scrubbing=False)
|
||||
logfire.instrument_pydantic_ai()
|
||||
|
||||
# Get configuration for this stage
|
||||
config = STAGE_CONFIGS["stage4_judging"]
|
||||
|
||||
# Load judging agent definition
|
||||
judge_prompt_template = conversation_judge.get_prompt_template()
|
||||
judge_agent = conversation_judge.build_agent(config)
|
||||
|
||||
|
||||
async def judge_conversation(conversation: Conversation) -> JudgedConversation:
|
||||
"""
|
||||
Judge the quality of a conversation.
|
||||
|
||||
Args:
|
||||
conversation: Conversation object to judge
|
||||
|
||||
Returns:
|
||||
JudgedConversation with quality scores
|
||||
"""
|
||||
# Format conversation for judging
|
||||
conv_text = "\n\n".join([
|
||||
f"{msg.role.upper()}: {msg.content}"
|
||||
for msg in conversation.messages
|
||||
])
|
||||
|
||||
# Format source Q&A pairs for fact-checking
|
||||
source_qa_text = "\n\n".join([
|
||||
f"Q: {qa.question}\nA: {qa.answer}"
|
||||
for qa in conversation.source_qa_pairs
|
||||
])
|
||||
|
||||
# Format the prompt
|
||||
prompt_text = judge_prompt_template.prompt.format(
|
||||
conversation=conv_text,
|
||||
source_qa=source_qa_text if source_qa_text else "No source Q&A available"
|
||||
)
|
||||
|
||||
# Judge using the agent
|
||||
result = await judge_agent.run(prompt_text)
|
||||
judgment = result.output
|
||||
|
||||
return JudgedConversation(
|
||||
conversation=conversation,
|
||||
judgment=judgment
|
||||
)
|
||||
|
||||
|
||||
def conversation_to_nanochat(conversation: Conversation) -> NanoChatConversation:
|
||||
"""
|
||||
Convert a Conversation to NanoChat format.
|
||||
|
||||
Args:
|
||||
conversation: Conversation object
|
||||
|
||||
Returns:
|
||||
NanoChatConversation (just messages array)
|
||||
"""
|
||||
messages = [
|
||||
NanoChatMessage(role=msg.role, content=msg.content)
|
||||
for msg in conversation.messages
|
||||
]
|
||||
return NanoChatConversation(messages=messages)
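
# When serialized to JSONL for SFT, each NanoChatConversation becomes one line
# shaped roughly like this (illustrative content, not real data):
#   {"messages": [{"role": "user", "content": "..."},
#                 {"role": "assistant", "content": "..."}]}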


def save_top_conversations_nanochat(
    judged_conversations: list[JudgedConversation],
    output_path: str,
    top_k: int = 1000,
    min_score: float = None
):
    """
    Save top K conversations in NanoChat format.

    Args:
        judged_conversations: List of judged conversations
        output_path: Path to output JSONL file
        top_k: Number of top conversations to save
        min_score: Minimum score threshold (optional)
    """
    # Filter to only passing conversations
    passing_conversations = [
        jc for jc in judged_conversations
        if jc.judgment.overall_pass
    ]

    # Sort by number of criteria passed (for ordering within passing conversations)
    def count_passes(jc):
        return sum([
            jc.judgment.factually_accurate,
            jc.judgment.natural_conversation,
            jc.judgment.on_topic,
            jc.judgment.adds_value
        ])

    sorted_conversations = sorted(
        passing_conversations,
        key=count_passes,
        reverse=True
    )

    # Note: min_score parameter is ignored with bool-only system

    # Take top K
    top_conversations = sorted_conversations[:top_k]

    # Convert to NanoChat format and save
    nanochat_convs = [
        conversation_to_nanochat(jc.conversation)
        for jc in top_conversations
    ]
    save_jsonl(nanochat_convs, output_path)

    # Log statistics
    print(f"\nTop {len(top_conversations)} passing conversations selected")
    print(f"  All passed: factually_accurate AND natural AND on_topic AND adds_value")


def print_quality_statistics(judged_conversations: list[JudgedConversation]):
    """Print quality statistics for all judged conversations."""
    if not judged_conversations:
        return

    total = len(judged_conversations)
    passing = sum(1 for jc in judged_conversations if jc.judgment.overall_pass)
    factual_pass = sum(1 for jc in judged_conversations if jc.judgment.factually_accurate)
    natural_pass = sum(1 for jc in judged_conversations if jc.judgment.natural_conversation)
    ontopic_pass = sum(1 for jc in judged_conversations if jc.judgment.on_topic)
    value_pass = sum(1 for jc in judged_conversations if jc.judgment.adds_value)

    print("\n" + "="*80)
    print("QUALITY STATISTICS (All Conversations)")
    print("="*80)
    print(f"Total conversations judged: {total}")
    print(f"Overall PASS (all 4 criteria): {passing} ({passing/total*100:.1f}%)")
    print(f"\nIndividual criteria:")
    print(f"  Factually accurate  : {factual_pass}/{total} ({factual_pass/total*100:.1f}%)")
    print(f"  Natural conversation: {natural_pass}/{total} ({natural_pass/total*100:.1f}%)")
    print(f"  On topic            : {ontopic_pass}/{total} ({ontopic_pass/total*100:.1f}%)")
    print(f"  Adds value          : {value_pass}/{total} ({value_pass/total*100:.1f}%)")
    print("="*80 + "\n")


async def main(
    input_file: str = None,
    judged_output: str = None,
    nanochat_output: str = None,
    max_concurrent: int = None,
    top_k: int = None,
    min_score: float = None
):
    """
    Main function to judge conversations and save top K.

    Args:
        input_file: Path to raw conversations JSONL file (default from config)
        judged_output: Path to save all judged conversations (default from config)
        nanochat_output: Path to save top K in NanoChat format (default from config)
        max_concurrent: Maximum concurrent API calls (default from config)
        top_k: Number of top conversations to save (default from config)
        min_score: Minimum quality score threshold (default from config)
    """
    # Use defaults from config if not specified
    input_file = input_file or PATHS.stage3_conversations_raw
    judged_output = judged_output or PATHS.stage4_conversations_judged
    nanochat_output = nanochat_output or PATHS.stage7_conversations_final
    max_concurrent = max_concurrent or config.max_concurrent
    top_k = top_k or FULL_PARAMS.top_k
    min_score = min_score or FULL_PARAMS.min_quality_score

    logfire.info("Starting conversation judging", input_file=input_file)

    # Load conversations
    conversations = load_jsonl(input_file, model_class=Conversation)
    logfire.info(f"Loaded {len(conversations)} conversations")

    # Judge conversations
    with logfire.span("judge_conversations"):
        judged_conversations = await process_with_concurrency(
            conversations,
            judge_conversation,
            max_concurrent=max_concurrent,
            desc="Judging conversations"
        )

    logfire.info(f"Judged {len(judged_conversations)} conversations")

    # Save all judged conversations
    save_jsonl(judged_conversations, judged_output)

    # Print statistics
    print_quality_statistics(judged_conversations)

    # Save top K in NanoChat format
    save_top_conversations_nanochat(
        judged_conversations,
        nanochat_output,
        top_k,
        min_score
    )

    # Print sample of a passing conversation
    passing_convs = [jc for jc in judged_conversations if jc.judgment.overall_pass]
    if passing_convs:
        print("\n" + "="*80)
        print("SAMPLE PASSING CONVERSATION:")
        print("="*80)
        sample = passing_convs[0]
        print(f"Overall: PASS (all 4 criteria met)")
        print(f"Feedback: {sample.judgment.feedback}")
        print("\nConversation:")
        for msg in sample.conversation.messages:
            print(f"\n{msg.role.upper()}: {msg.content[:200]}...")
        print("="*80 + "\n")


if __name__ == "__main__":
    asyncio.run(main())
108  synth-data-pipeline/5_embed_conversations.py  Normal file

@@ -0,0 +1,108 @@
"""
|
||||
Stage 5: Embed conversations using OpenAI embeddings.
|
||||
|
||||
This script:
|
||||
1. Loads judged conversations from Stage 4
|
||||
2. Converts each conversation to text
|
||||
3. Generates embeddings using OpenAI API
|
||||
4. Saves conversations with embeddings
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
|
||||
import logfire
|
||||
from dotenv import load_dotenv
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
from src.synth_data_pipeline.models import JudgedConversation, EmbeddedConversation
|
||||
from src.synth_data_pipeline.config import PATHS, FULL_PARAMS
|
||||
from src.synth_data_pipeline.utils import load_jsonl, save_jsonl
|
||||
from src.synth_data_pipeline.embedding_utils import (
|
||||
batch_embed,
|
||||
conversation_to_text,
|
||||
)
|
||||
|
||||
# Load environment variables
|
||||
load_dotenv()
|
||||
|
||||
# Configure logging
|
||||
logfire.configure(scrubbing=False)
|
||||
|
||||
|
||||
async def main(
|
||||
input_file: str = None,
|
||||
output_file: str = None
|
||||
):
|
||||
"""
|
||||
Main function to embed conversations.
|
||||
|
||||
Args:
|
||||
input_file: Path to input JSONL file (default from config)
|
||||
output_file: Path to output JSONL file (default from config)
|
||||
"""
|
||||
# Use defaults from config if not specified
|
||||
input_file = input_file or PATHS.stage4_conversations_judged
|
||||
output_file = output_file or PATHS.stage5_conversations_embedded
|
||||
|
||||
logfire.info("Starting conversation embedding", input_file=input_file)
|
||||
|
||||
# Load judged conversations
|
||||
judged_convs = load_jsonl(input_file, model_class=JudgedConversation)
|
||||
logfire.info(f"Loaded {len(judged_convs)} judged conversations")
|
||||
|
||||
# Convert conversations to text
|
||||
texts = []
|
||||
for jc in judged_convs:
|
||||
# Convert messages to text
|
||||
text = conversation_to_text(
|
||||
[msg.model_dump() for msg in jc.conversation.messages],
|
||||
max_chars=FULL_PARAMS.embedding_max_chars
|
||||
)
|
||||
texts.append(text)
|
||||
|
||||
# Initialize OpenAI client
|
||||
client = AsyncOpenAI(api_key=os.getenv("OPENAI_API_KEY"))
|
||||
|
||||
# Generate embeddings
|
||||
with logfire.span("generate_embeddings"):
|
||||
embeddings = await batch_embed(
|
||||
texts,
|
||||
client,
|
||||
model=FULL_PARAMS.embedding_model,
|
||||
dimensions=FULL_PARAMS.embedding_dimensions,
|
||||
batch_size=FULL_PARAMS.embedding_batch_size,
|
||||
max_concurrent=20
|
||||
)
|
||||
|
||||
logfire.info(f"Generated {len(embeddings)} embeddings")
|
||||
|
||||
# Create embedded conversations
|
||||
embedded_convs = []
|
||||
for jc, emb, text in zip(judged_convs, embeddings, texts):
|
||||
embedded_conv = EmbeddedConversation(
|
||||
conversation=jc.conversation,
|
||||
judgment=jc.judgment,
|
||||
embedding=emb.tolist(), # Convert numpy array to list
|
||||
text_preview=text[:200] # First 200 chars for debugging
|
||||
)
|
||||
embedded_convs.append(embedded_conv)
|
||||
|
||||
# Save results
|
||||
save_jsonl(embedded_convs, output_file)
|
||||
logfire.info(f"Saved {len(embedded_convs)} embedded conversations")
|
||||
|
||||
# Print sample
|
||||
if embedded_convs:
|
||||
print("\n" + "="*80)
|
||||
print("EMBEDDING SAMPLE:")
|
||||
print("="*80)
|
||||
sample = embedded_convs[0]
|
||||
print(f"Conversation preview: {sample.text_preview}...")
|
||||
print(f"Embedding dimensions: {len(sample.embedding)}")
|
||||
print(f"Quality score: {sample.judgment.overall_score:.2f}")
|
||||
print("="*80 + "\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
119  synth-data-pipeline/6_deduplicate.py  Normal file

@@ -0,0 +1,119 @@
"""
|
||||
Stage 6: Deduplicate conversations based on embedding similarity.
|
||||
|
||||
This script:
|
||||
1. Loads embedded conversations from Stage 5
|
||||
2. L2-normalizes embeddings
|
||||
3. Computes pairwise cosine similarity
|
||||
4. Removes duplicates above similarity threshold
|
||||
5. Saves unique conversations
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import numpy as np
|
||||
|
||||
import logfire
|
||||
|
||||
from src.synth_data_pipeline.models import EmbeddedConversation, UniqueConversation
|
||||
from src.synth_data_pipeline.config import PATHS, FULL_PARAMS
|
||||
from src.synth_data_pipeline.utils import load_jsonl, save_jsonl
|
||||
from src.synth_data_pipeline.embedding_utils import (
|
||||
l2_normalize,
|
||||
greedy_deduplicate,
|
||||
)
|
||||
|
||||
# Configure logging
|
||||
logfire.configure(scrubbing=False)
|
||||
|
||||
|
||||
async def main(
|
||||
input_file: str = None,
|
||||
output_file: str = None,
|
||||
similarity_threshold: float = None
|
||||
):
|
||||
"""
|
||||
Main function to deduplicate conversations.
|
||||
|
||||
Args:
|
||||
input_file: Path to input JSONL file (default from config)
|
||||
output_file: Path to output JSONL file (default from config)
|
||||
similarity_threshold: Similarity threshold for deduplication (default from config)
|
||||
"""
|
||||
# Use defaults from config if not specified
|
||||
input_file = input_file or PATHS.stage5_conversations_embedded
|
||||
output_file = output_file or PATHS.stage6_conversations_unique
|
||||
similarity_threshold = similarity_threshold or FULL_PARAMS.dedup_similarity_threshold
|
||||
|
||||
logfire.info(
|
||||
"Starting deduplication",
|
||||
input_file=input_file,
|
||||
similarity_threshold=similarity_threshold
|
||||
)
|
||||
|
||||
# Load embedded conversations
|
||||
embedded_convs = load_jsonl(input_file, model_class=EmbeddedConversation)
|
||||
logfire.info(f"Loaded {len(embedded_convs)} embedded conversations")
|
||||
|
||||
# Extract embeddings and scores
|
||||
embeddings = [np.array(ec.embedding, dtype=np.float32) for ec in embedded_convs]
|
||||
scores = [ec.judgment.overall_score for ec in embedded_convs]
|
||||
|
||||
# L2-normalize embeddings for cosine similarity
|
||||
with logfire.span("normalize_embeddings"):
|
||||
normalized_embeddings = l2_normalize(embeddings)
|
||||
|
||||
logfire.info("Normalized embeddings for cosine similarity")
|
||||
|
||||
# Deduplicate
|
||||
with logfire.span("deduplicate"):
|
||||
kept_indices = greedy_deduplicate(
|
||||
normalized_embeddings,
|
||||
scores,
|
||||
similarity_threshold=similarity_threshold
|
||||
)
|
||||
|
||||
# Create unique conversations (without embeddings to save space)
|
||||
unique_convs = []
|
||||
for idx in kept_indices:
|
||||
ec = embedded_convs[idx]
|
||||
unique_conv = UniqueConversation(
|
||||
conversation=ec.conversation,
|
||||
judgment=ec.judgment
|
||||
)
|
||||
unique_convs.append(unique_conv)
|
||||
|
||||
# Save results
|
||||
save_jsonl(unique_convs, output_file)
|
||||
|
||||
# Statistics
|
||||
total = len(embedded_convs)
|
||||
kept = len(unique_convs)
|
||||
removed = total - kept
|
||||
removal_rate = 100 * removed / total if total > 0 else 0
|
||||
|
||||
logfire.info(
|
||||
f"Deduplication complete: {kept} kept, {removed} removed ({removal_rate:.1f}%)"
|
||||
)
|
||||
|
||||
# Print statistics
|
||||
print("\n" + "="*80)
|
||||
print("DEDUPLICATION STATISTICS:")
|
||||
print("="*80)
|
||||
print(f"Total conversations: {total}")
|
||||
print(f"Unique conversations: {kept}")
|
||||
print(f"Duplicates removed: {removed} ({removal_rate:.1f}%)")
|
||||
print(f"Similarity threshold: {similarity_threshold}")
|
||||
print("="*80 + "\n")
|
||||
|
||||
# Print score statistics for unique conversations
|
||||
unique_scores = [uc.judgment.overall_score for uc in unique_convs]
|
||||
if unique_scores:
|
||||
print("Unique conversation scores:")
|
||||
print(f" Average: {np.mean(unique_scores):.2f}")
|
||||
print(f" Min: {np.min(unique_scores):.2f}")
|
||||
print(f" Max: {np.max(unique_scores):.2f}")
|
||||
print("="*80 + "\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
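# Illustrative sketch (not part of the pipeline): how the similarity threshold
# behaves on L2-normalized vectors, where cosine similarity reduces to a dot
# product. The two vectors below are made up purely for demonstration.
import numpy as np

from src.synth_data_pipeline.embedding_utils import compute_similarity, l2_normalize

a, b = l2_normalize([np.array([1.0, 2.0, 3.0]), np.array([1.0, 2.1, 3.0])])
print(compute_similarity(a, b))  # ~0.9997 -> treated as a duplicate at the 0.95 threshold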
134
synth-data-pipeline/7_select_top.py
Normal file
@@ -0,0 +1,134 @@
"""
|
||||
Stage 7: Select top K conversations and convert to NanoChat format.
|
||||
|
||||
This script:
|
||||
1. Loads unique conversations from Stage 6
|
||||
2. Sorts by quality score
|
||||
3. Selects top K conversations
|
||||
4. Converts to NanoChat format (messages only)
|
||||
5. Saves final dataset
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
|
||||
import logfire
|
||||
|
||||
from src.synth_data_pipeline.models import UniqueConversation, NanoChatConversation, NanoChatMessage
|
||||
from src.synth_data_pipeline.config import PATHS, FULL_PARAMS
|
||||
from src.synth_data_pipeline.utils import load_jsonl, save_jsonl
|
||||
|
||||
# Configure logging
|
||||
logfire.configure(scrubbing=False)
|
||||
|
||||
|
||||
def conversation_to_nanochat(unique_conv: UniqueConversation) -> NanoChatConversation:
|
||||
"""
|
||||
Convert a UniqueConversation to NanoChat format.
|
||||
|
||||
Args:
|
||||
unique_conv: UniqueConversation object
|
||||
|
||||
Returns:
|
||||
NanoChatConversation (messages only)
|
||||
"""
|
||||
nanochat_messages = [
|
||||
NanoChatMessage(role=msg.role, content=msg.content)
|
||||
for msg in unique_conv.conversation.messages
|
||||
]
|
||||
|
||||
return NanoChatConversation(messages=nanochat_messages)
|
||||
|
||||
|
||||
async def main(
|
||||
input_file: str = None,
|
||||
output_file: str = None,
|
||||
top_k: int = None,
|
||||
min_score: float = None
|
||||
):
|
||||
"""
|
||||
Main function to select top K conversations.
|
||||
|
||||
Args:
|
||||
input_file: Path to input JSONL file (default from config)
|
||||
output_file: Path to output JSONL file (default from config)
|
||||
top_k: Number of top conversations to select (default from config)
|
||||
min_score: Minimum quality score threshold (default from config)
|
||||
"""
|
||||
# Use defaults from config if not specified
|
||||
input_file = input_file or PATHS.stage6_conversations_unique
|
||||
output_file = output_file or PATHS.stage7_conversations_final
|
||||
top_k = top_k or FULL_PARAMS.top_k
|
||||
min_score = min_score or FULL_PARAMS.min_quality_score
|
||||
|
||||
logfire.info(
|
||||
"Starting top-K selection",
|
||||
input_file=input_file,
|
||||
top_k=top_k,
|
||||
min_score=min_score
|
||||
)
|
||||
|
||||
# Load unique conversations
|
||||
unique_convs = load_jsonl(input_file, model_class=UniqueConversation)
|
||||
logfire.info(f"Loaded {len(unique_convs)} unique conversations")
|
||||
|
||||
# Filter by minimum score if specified
|
||||
if min_score is not None:
|
||||
filtered_convs = [
|
||||
uc for uc in unique_convs
|
||||
if uc.judgment.overall_score >= min_score
|
||||
]
|
||||
logfire.info(
|
||||
f"Filtered to {len(filtered_convs)} conversations with score >= {min_score}"
|
||||
)
|
||||
else:
|
||||
filtered_convs = unique_convs
|
||||
|
||||
# Sort by quality score (descending)
|
||||
sorted_convs = sorted(
|
||||
filtered_convs,
|
||||
key=lambda uc: uc.judgment.overall_score,
|
||||
reverse=True
|
||||
)
|
||||
|
||||
# Select top K
|
||||
top_convs = sorted_convs[:top_k]
|
||||
logfire.info(f"Selected top {len(top_convs)} conversations")
|
||||
|
||||
# Convert to NanoChat format
|
||||
nanochat_convs = [conversation_to_nanochat(uc) for uc in top_convs]
|
||||
|
||||
# Save results
|
||||
save_jsonl(nanochat_convs, output_file)
|
||||
logfire.info(f"Saved {len(nanochat_convs)} conversations in NanoChat format")
|
||||
|
||||
# Print statistics
|
||||
print("\n" + "="*80)
|
||||
print("TOP-K SELECTION STATISTICS:")
|
||||
print("="*80)
|
||||
print(f"Total unique conversations: {len(unique_convs)}")
|
||||
print(f"After minimum score filter: {len(filtered_convs)}")
|
||||
print(f"Top K selected: {len(top_convs)}")
|
||||
print("="*80 + "\n")
|
||||
|
||||
if top_convs:
|
||||
scores = [uc.judgment.overall_score for uc in top_convs]
|
||||
print("Selected conversation scores:")
|
||||
print(f" Average: {sum(scores) / len(scores):.2f}")
|
||||
print(f" Min: {min(scores):.2f}")
|
||||
print(f" Max: {max(scores):.2f}")
|
||||
print("="*80 + "\n")
|
||||
|
||||
# Show best conversation
|
||||
best = top_convs[0]
|
||||
print("BEST CONVERSATION:")
|
||||
print("="*80)
|
||||
print(f"Score: {best.judgment.overall_score:.2f}")
|
||||
print(f"Feedback: {best.judgment.feedback}")
|
||||
print("\nMessages:")
|
||||
for msg in best.conversation.messages:
|
||||
print(f" {msg.role.upper()}: {msg.content[:100]}...")
|
||||
print("="*80 + "\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
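# Illustrative sketch (an assumption, not the verified on-disk format): roughly
# what one line of the final NanoChat-format dataset looks like after Stage 7,
# given that NanoChatConversation holds only role/content messages. The exact
# serialization depends on save_jsonl; the message text below is made up.
import json

example_record = {
    "messages": [
        {"role": "user", "content": "Does SWAP Commerce handle cross-border returns?"},
        {"role": "assistant", "content": "...answer grounded in the validated Q&A pairs..."},
    ]
}
print(json.dumps(example_record, ensure_ascii=False))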
129
synth-data-pipeline/Makefile
Normal file
@@ -0,0 +1,129 @@
.PHONY: help setup trial clean stage1 stage2 stage3 stage4 stage5 stage6 stage7 full
|
||||
|
||||
# Default target
|
||||
help:
|
||||
@echo "Synthetic Data Pipeline - Makefile Commands"
|
||||
@echo "============================================"
|
||||
@echo ""
|
||||
@echo "Setup & Testing:"
|
||||
@echo " make setup - Install dependencies"
|
||||
@echo " make trial - Run trial pipeline (small dataset)"
|
||||
@echo ""
|
||||
@echo "Main Commands:"
|
||||
@echo " make full - Run full pipeline with deduplication (all stages)"
|
||||
@echo ""
|
||||
@echo "Individual Stages:"
|
||||
@echo " make stage1 - Extract Q&A pairs from documentation"
|
||||
@echo " make stage2 - Validate extracted Q&A pairs"
|
||||
@echo " make stage3 - Generate conversations from validated Q&A"
|
||||
@echo " make stage4 - Judge conversations and save passing ones"
|
||||
@echo " make stage5 - Generate embeddings for deduplication"
|
||||
@echo " make stage6 - Deduplicate similar conversations"
|
||||
@echo " make stage7 - Select top-K final conversations"
|
||||
@echo ""
|
||||
@echo "Utilities:"
|
||||
@echo " make clean - Remove generated outputs"
|
||||
@echo " make clean-trial - Remove trial outputs only"
|
||||
@echo " make stats - Show statistics from latest run"
|
||||
@echo ""
|
||||
@echo "Development:"
|
||||
@echo " make lint - Run code formatting checks"
|
||||
@echo " make format - Format code with ruff"
|
||||
@echo ""
|
||||
|
||||
# Setup
|
||||
setup:
|
||||
@echo "Installing dependencies..."
|
||||
uv sync
|
||||
@echo "✓ Dependencies installed"
|
||||
|
||||
# Testing
|
||||
trial:
|
||||
@echo "Running trial pipeline (small dataset)..."
|
||||
uv run trial_run.py
|
||||
|
||||
# Pipeline Stages
|
||||
stage1:
|
||||
@echo "Stage 1: Extracting Q&A pairs..."
|
||||
uv run 1_extract_qa.py
|
||||
|
||||
stage2:
|
||||
@echo "Stage 2: Validating Q&A pairs..."
|
||||
uv run 2_validate_qa.py
|
||||
|
||||
stage3:
|
||||
@echo "Stage 3: Generating conversations..."
|
||||
uv run 3_generate_conversations.py
|
||||
|
||||
stage4:
|
||||
@echo "Stage 4: Judging and saving passing conversations..."
|
||||
uv run 4_judge_and_save.py
|
||||
|
||||
stage5:
|
||||
@echo "Stage 5: Generating embeddings for deduplication..."
|
||||
uv run 5_embed_conversations.py
|
||||
|
||||
stage6:
|
||||
@echo "Stage 6: Deduplicating similar conversations..."
|
||||
uv run 6_deduplicate.py
|
||||
|
||||
stage7:
|
||||
@echo "Stage 7: Selecting top-K conversations..."
|
||||
uv run 7_select_top.py
|
||||
|
||||
# Run full pipeline (all stages with deduplication)
|
||||
full: stage1 stage2 stage3 stage4 stage5 stage6 stage7
|
||||
@echo ""
|
||||
@echo "============================================"
|
||||
@echo "✓ Full pipeline completed!"
|
||||
@echo "============================================"
|
||||
@echo ""
|
||||
@echo "Final outputs:"
|
||||
@echo " - output/qa_pairs.jsonl"
|
||||
@echo " - output/conversations_raw.jsonl"
|
||||
@echo " - output/conversations_judged.jsonl"
|
||||
@echo " - output/conversations_deduplicated.jsonl"
|
||||
@echo " - output/conversations_final.jsonl <-- Use this for training"
|
||||
@echo ""
|
||||
|
||||
# Cleaning
|
||||
clean:
|
||||
@echo "Removing all generated outputs..."
|
||||
rm -rf output/
|
||||
@echo "✓ Cleaned"
|
||||
|
||||
clean-trial:
|
||||
@echo "Removing trial outputs..."
|
||||
rm -f output/trial_*
|
||||
@echo "✓ Trial outputs cleaned"
|
||||
|
||||
# Statistics
|
||||
stats:
|
||||
@echo "Pipeline Statistics:"
|
||||
@echo "==================="
|
||||
@if [ -f output/qa_pairs.jsonl ]; then \
|
||||
echo "Q&A Pairs: $$(wc -l < output/qa_pairs.jsonl) pairs"; \
|
||||
fi
|
||||
@if [ -f output/conversations_raw.jsonl ]; then \
|
||||
echo "Raw Conversations: $$(wc -l < output/conversations_raw.jsonl) conversations"; \
|
||||
fi
|
||||
@if [ -f output/conversations_judged.jsonl ]; then \
|
||||
echo "Judged Conversations: $$(wc -l < output/conversations_judged.jsonl) conversations"; \
|
||||
fi
|
||||
@if [ -f output/conversations_final.jsonl ]; then \
|
||||
echo "Top Conversations: $$(wc -l < output/conversations_final.jsonl) conversations"; \
|
||||
fi
|
||||
|
||||
# Development
|
||||
lint:
|
||||
@echo "Running linting checks..."
|
||||
uv run ruff check src/ || true
|
||||
|
||||
format:
|
||||
@echo "Formatting code..."
|
||||
uv run ruff format src/ || true
|
||||
@echo "✓ Code formatted"
|
||||
|
||||
# Quick iteration workflow
|
||||
quick: clean-trial trial
|
||||
@echo "✓ Quick iteration complete"
|
||||
97
synth-data-pipeline/README.md
Normal file
@@ -0,0 +1,97 @@
# Synthetic Data Pipeline

Generate synthetic SWAP Commerce conversations for NanoChat fine-tuning. The workflow is split into small scripts so you can rerun or customize each stage independently, but every script is executed the same way: `uv run <script>.py`.

## Prerequisites

1. **Environment variables** – create a `.env` file in the project root:
   ```bash
   GOOGLE_API_KEY=your_gemini_key
   OPENAI_API_KEY=your_openai_key
   LOGFIRE_TOKEN=optional_logfire_token
   ```
2. **Source material** – place the facts you want to model at `data/swap_facts.md`.

## Pipeline Stages

Each stage reads from `output/` (if the previous step has already run) and writes its own JSONL artifact. Run them sequentially with `uv run`.

| Stage | Script | Purpose | Output |
| --- | --- | --- | --- |
| 1 | `1_extract_qa.py` | Chunk the facts file and create 3 grounded Q&A pairs per chunk. Filters out future dates. | `output/qa_pairs.jsonl` |
| 2 | `2_validate_qa.py` | Filter hallucinated or low-quality pairs. Saves full validation metadata and a `_passed` file. | `output/qa_pairs_validated.jsonl` + `output/qa_pairs_validated_passed.jsonl` |
| 3 | `3_generate_conversations.py` | Sample personas/styles and turn Q&A seeds into synthetic dialogues. Matches topics to personas. | `output/conversations_raw.jsonl` |
| 4 | `4_judge_and_save.py` | Judge conversations using the **bool-only rubric** (factual, natural, on-topic, adds-value). Only passing conversations advance. | `output/conversations_judged.jsonl`, `output/conversations_final.jsonl` |
| 5 | `5_embed_conversations.py` | Create OpenAI embeddings for deduplication. | `output/conversations_embedded.jsonl` |
| 6 | `6_deduplicate.py` | Remove near-duplicates via cosine similarity. | `output/conversations_unique.jsonl` |
| 7 | `7_select_top.py` | Export the final training set of top-K conversations. | `output/conversations_final.jsonl` |

### Quick Start

**Run the full pipeline** (all 7 stages with deduplication):
```bash
make full
```

**Or run a quick trial** on a small dataset first:
```bash
make trial
```

### Running stages individually

If you prefer to run stages one at a time:

```bash
uv run 1_extract_qa.py
uv run 2_validate_qa.py
uv run 3_generate_conversations.py
uv run 4_judge_and_save.py
uv run 5_embed_conversations.py
uv run 6_deduplicate.py
uv run 7_select_top.py
```

Or use the Makefile shortcuts: `make stage1`, `make stage2`, etc.

### Trial run (smoke test)

To test the pipeline on a small subset (~30 Q&A pairs, ~280 conversations):

```bash
make trial
# or
uv run trial_run.py
```

This produces `output/trial_*.jsonl` files plus quality statistics showing pass rates for each bool criterion.
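
If you want to recompute those pass rates yourself, here is a minimal sketch, assuming each line of the judged trial file is a JSON object whose `judgment` field carries the four bool criteria from the judge rubric (the exact record structure is an assumption):

```python
import json
from collections import Counter

criteria = ("factually_accurate", "natural_conversation", "on_topic", "adds_value")
counts, total = Counter(), 0
with open("output/trial_conversations_judged.jsonl") as f:
    for line in f:
        judgment = json.loads(line)["judgment"]
        total += 1
        for key in criteria:
            counts[key] += bool(judgment.get(key))

for key in criteria:
    print(f"{key}: {counts[key]}/{total} passed")
```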

## Makefile Commands

Run `make help` to see all available commands:
- `make trial` - Quick test on small dataset
- `make full` - Run complete pipeline with all 7 stages
- `make stage1` through `make stage7` - Run individual stages
- `make clean` - Remove all generated outputs
- `make stats` - Show pipeline statistics

## Repository structure

```
src/
└── synth_data_pipeline/
    ├── agents/                 # Model-facing helpers + prompts for each stage
    │   ├── base.py
    │   ├── qa_extractor.py
    │   ├── qa_validator.py
    │   ├── conversation_generator.py
    │   ├── conversation_judge.py
    │   └── prompts/*.txt
    ├── config.py               # Shared constants and file paths
    ├── models.py               # Pydantic schemas for every artifact
    ├── sampling.py             # Persona/style sampling utilities
    ├── utils.py                # IO + concurrency helpers
    └── prompts/{system_prompts,personas}
```

Workflow scripts (e.g., `1_extract_qa.py`, `3_generate_conversations.py`, etc.) live in the repo root next to `trial_run.py`. All generated data lands under `output/`.
22
synth-data-pipeline/pyproject.toml
Normal file
@@ -0,0 +1,22 @@
[project]
name = "synth-data-pipeline"
version = "0.1.0"
description = "Synthetic data generation pipeline for NanoChat fine-tuning"
readme = "README.md"
requires-python = ">=3.11"
dependencies = [
    "logfire>=4.14.2",
    "pydantic-ai>=1.6.0",
    "python-dotenv>=1.1.1",
    "textprompts>=1.0",
    "tqdm>=4.66.0",
    "openai>=1.0.0",
    "numpy>=1.24.0",
]

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[tool.hatch.build.targets.wheel]
packages = ["src/synth_data_pipeline"]
8
synth-data-pipeline/src/synth_data_pipeline/__init__.py
Normal file
8
synth-data-pipeline/src/synth_data_pipeline/__init__.py
Normal file
|
|
@ -0,0 +1,8 @@
|
|||
"""Synthetic Data Pipeline for NanoChat Fine-Tuning."""
|
||||
|
||||
import textprompts
|
||||
|
||||
__version__ = "0.1.0"
|
||||
|
||||
# Set strict metadata requirement for all prompts globally
|
||||
textprompts.set_metadata('strict')
|
||||
|
|
@ -0,0 +1,10 @@
|
|||
"""Agent definitions for the synthetic data pipeline."""
|
||||
|
||||
from . import qa_extractor, qa_validator, conversation_generator, conversation_judge
|
||||
|
||||
__all__ = [
|
||||
"qa_extractor",
|
||||
"qa_validator",
|
||||
"conversation_generator",
|
||||
"conversation_judge",
|
||||
]
|
||||
57
synth-data-pipeline/src/synth_data_pipeline/agents/base.py
Normal file
57
synth-data-pipeline/src/synth_data_pipeline/agents/base.py
Normal file
|
|
@ -0,0 +1,57 @@
|
|||
"""Shared helpers for agent construction and prompt loading."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Type
|
||||
|
||||
import textprompts
|
||||
from pydantic_ai import Agent
|
||||
from pydantic_ai.models.google import GoogleModel, GoogleModelSettings
|
||||
from pydantic_ai.providers.google import GoogleProvider
|
||||
|
||||
from src.synth_data_pipeline.config import APIConfig
|
||||
|
||||
PROMPTS_DIR = Path(__file__).with_name("prompts")
|
||||
|
||||
|
||||
def get_prompt_path(name: str) -> Path:
|
||||
"""Return the absolute path to an agent prompt file."""
|
||||
return PROMPTS_DIR / f"{name}.txt"
|
||||
|
||||
|
||||
def load_prompt_template(name: str):
|
||||
"""Load a textprompts template by agent name."""
|
||||
return textprompts.load_prompt(str(get_prompt_path(name)))
|
||||
|
||||
|
||||
def build_google_agent(
|
||||
api_config: APIConfig,
|
||||
*,
|
||||
system_prompt: str,
|
||||
output_type: Type,
|
||||
api_key: str | None = None,
|
||||
) -> Agent:
|
||||
"""Construct a Google Gemini-backed agent for a stage."""
|
||||
api_key = api_key or os.getenv("GOOGLE_API_KEY")
|
||||
if not api_key:
|
||||
raise RuntimeError("GOOGLE_API_KEY is not set")
|
||||
|
||||
model_settings = GoogleModelSettings(
|
||||
gemini_thinking_config={"thinking_budget": api_config.thinking_budget},
|
||||
temperature=api_config.temperature,
|
||||
timeout=api_config.timeout,
|
||||
)
|
||||
|
||||
model = GoogleModel(
|
||||
api_config.model,
|
||||
provider=GoogleProvider(api_key=api_key),
|
||||
)
|
||||
|
||||
return Agent(
|
||||
model,
|
||||
system_prompt=system_prompt,
|
||||
model_settings=model_settings,
|
||||
output_type=output_type,
|
||||
)
|
||||
|
|
@ -0,0 +1,23 @@
|
|||
"""Agent definition for conversation generation."""
|
||||
|
||||
from src.synth_data_pipeline.agents.base import build_google_agent, load_prompt_template
|
||||
from src.synth_data_pipeline.config import APIConfig
|
||||
from src.synth_data_pipeline.models import Conversation
|
||||
|
||||
PROMPT_NAME = "conversation_generator"
|
||||
SYSTEM_PROMPT = "You are an expert at creating natural, realistic conversations."
|
||||
|
||||
|
||||
def build_agent(config: APIConfig, *, api_key: str | None = None):
|
||||
"""Return a configured conversation generation agent."""
|
||||
return build_google_agent(
|
||||
config,
|
||||
system_prompt=SYSTEM_PROMPT,
|
||||
output_type=Conversation,
|
||||
api_key=api_key,
|
||||
)
|
||||
|
||||
|
||||
def get_prompt_template():
|
||||
"""Load the conversation generation prompt template."""
|
||||
return load_prompt_template(PROMPT_NAME)
|
||||
|
|
@ -0,0 +1,23 @@
|
|||
"""Agent definition for conversation quality judging."""
|
||||
|
||||
from src.synth_data_pipeline.agents.base import build_google_agent, load_prompt_template
|
||||
from src.synth_data_pipeline.config import APIConfig
|
||||
from src.synth_data_pipeline.models import JudgmentScore
|
||||
|
||||
PROMPT_NAME = "conversation_judge"
|
||||
SYSTEM_PROMPT = "You are an expert evaluator of training data quality for language models."
|
||||
|
||||
|
||||
def build_agent(config: APIConfig, *, api_key: str | None = None):
|
||||
"""Return a configured judging agent."""
|
||||
return build_google_agent(
|
||||
config,
|
||||
system_prompt=SYSTEM_PROMPT,
|
||||
output_type=JudgmentScore,
|
||||
api_key=api_key,
|
||||
)
|
||||
|
||||
|
||||
def get_prompt_template():
|
||||
"""Load the conversation judging prompt template."""
|
||||
return load_prompt_template(PROMPT_NAME)
|
||||
|
|
@ -0,0 +1,43 @@
|
|||
---
|
||||
title = "Conversation Generation"
|
||||
version = "1.0.0"
|
||||
description = "Generate natural user-assistant conversations from Q&A pairs"
|
||||
---
|
||||
You are an expert at creating natural, realistic conversations.
|
||||
|
||||
Generate a conversation between a user and an AI assistant about SWAP Commerce.
|
||||
|
||||
CONFIGURATION:
|
||||
- Number of turns: {num_turns}
|
||||
- Conversation style: {style}
|
||||
- User persona: {user_persona}
|
||||
- System prompt: {system_prompt}
|
||||
|
||||
SOURCE Q&A PAIRS:
|
||||
{qa_pairs}
|
||||
|
||||
Create a natural conversation where:
|
||||
1. DON'T just ask the Q&A questions verbatim - make them conversational and contextual
|
||||
2. Weave Q&A content into realistic dialogue (e.g., "I'm evaluating platforms and wondering...")
|
||||
3. Combine multiple related Q&A pairs creatively when possible
|
||||
4. The assistant provides helpful, accurate answers based ONLY on the Q&A content
|
||||
5. The conversation style matches the specified style ({style})
|
||||
6. The user's tone/language matches their persona ({user_persona})
|
||||
7. Follow-up questions (if multi-turn) relate logically to previous exchanges
|
||||
|
||||
Style guidelines:
|
||||
- FORMAL: Professional language, complete sentences, no slang
|
||||
- CASUAL: Friendly tone, can use contractions, conversational
|
||||
- TECHNICAL: Uses technical terminology, assumes some expertise
|
||||
|
||||
IMPORTANT FOR UNIQUENESS:
|
||||
- Add realistic context/motivation for questions (not just "What is X?")
|
||||
- Vary conversation patterns (not all Q&A need to be direct questions)
|
||||
- Make it specific and interesting - avoid generic exchanges
|
||||
- The conversation should teach something concrete about SWAP Commerce
|
||||
|
||||
The conversation should feel natural and realistic, not like reading from a script.
|
||||
|
||||
Return the structured output with:
|
||||
- messages: array of {{role: "system"/"user"/"assistant", content: "..."}}
|
||||
- metadata: {{num_turns, style, user_persona, source_qa_ids, difficulty, categories}}
|
||||
|
|
@ -0,0 +1,47 @@
|
|||
---
|
||||
title = "Conversation Quality Judging"
|
||||
version = "1.0.0"
|
||||
description = "Judge the quality of generated conversations"
|
||||
---
|
||||
You are an expert evaluator of training data quality for language models.
|
||||
|
||||
Evaluate the following conversation about SWAP Commerce:
|
||||
|
||||
{conversation}
|
||||
|
||||
SOURCE Q&A PAIRS (use these to verify factual accuracy):
|
||||
{source_qa}
|
||||
|
||||
Evaluate the conversation using these CLEAR YES/NO criteria:
|
||||
|
||||
1. FACTUALLY_ACCURATE (bool):
|
||||
PASS = All facts match SOURCE Q&A PAIRS above
|
||||
PASS = No hallucinations or invented details
|
||||
FAIL = Any fact contradicts source OR is made up
|
||||
|
||||
Important: Facts in the source Q&A are CORRECT - use them to verify claims
|
||||
|
||||
2. NATURAL_CONVERSATION (bool):
|
||||
PASS = Sounds like real human conversation
|
||||
PASS = Messages flow smoothly, natural transitions
|
||||
FAIL = Robotic, awkward, or unrealistic dialogue
|
||||
|
||||
3. ON_TOPIC (bool):
|
||||
PASS = Relevant to SWAP Commerce
|
||||
PASS = Would be useful for training a SWAP Commerce assistant
|
||||
FAIL = Off-topic or irrelevant content
|
||||
|
||||
4. ADDS_VALUE (bool):
|
||||
PASS = Covers topic in specific, interesting, or unique way
|
||||
PASS = Not just generic questions with simple answers
|
||||
FAIL = Generic, repetitive, or adds no unique insight
|
||||
|
||||
OVERALL_PASS (bool):
|
||||
TRUE = ALL four criteria above are TRUE
|
||||
FALSE = ANY criterion is FALSE
|
||||
|
||||
Provide:
|
||||
- Brief feedback (1-2 sentences) explaining your judgment
|
||||
- List specific issues found (if any)
|
||||
|
||||
Return the structured output with fields: factually_accurate, natural_conversation, on_topic, adds_value, overall_pass, feedback, issues.
|
||||
|
|
@ -0,0 +1,59 @@
|
|||
---
|
||||
title = "Q&A Extraction from Documentation"
|
||||
version = "3.0.0"
|
||||
description = "Extract 3 diverse question-answer pairs from SWAP Commerce documentation"
|
||||
---
|
||||
You are an expert at creating high-quality, DIVERSE Q&A pairs about SWAP Commerce's e-commerce platform and business.
|
||||
|
||||
CRITICAL: Generate EXACTLY 3 different Q&A pairs from this text chunk. Each pair must be unique and diverse.
|
||||
|
||||
IMPORTANT CONTEXT:
|
||||
- We are generating training data for an AI assistant about SWAP Commerce
|
||||
- The MAIN TEXT is the primary focus for your Q&A pairs
|
||||
- CONTEXT BEFORE/AFTER provides surrounding information (3 lines before and 3 lines after the focus line)
|
||||
- Use context to understand the topic, but base your Q&A pairs primarily on the MAIN TEXT
|
||||
|
||||
CONTEXT BEFORE (3 lines):
|
||||
{context_before}
|
||||
|
||||
>>> MAIN TEXT (FOCUS ON THIS):
|
||||
{source_text}
|
||||
|
||||
CONTEXT AFTER (3 lines):
|
||||
{context_after}
|
||||
|
||||
For EACH of the 3 Q&A pairs, generate a question that:
|
||||
1. Sounds natural and human-like (as if a real user would ask it)
|
||||
2. Is clearly answerable from the MAIN TEXT
|
||||
3. Is specific and focused (not too broad or vague)
|
||||
4. VARIES in style - use different question types:
|
||||
- Factual: "What is...?" "Who...?" "When...?"
|
||||
- How-to: "How do I...?" "How can...?"
|
||||
- Comparison: "What's the difference between...?"
|
||||
- Explanatory: "Why...?" "What does... mean?"
|
||||
- Practical: "Can I...?" "Is it possible to...?"
|
||||
5. Uses different phrasings and vocabulary (avoid repetitive patterns)
|
||||
6. AVOID overly simple questions with one-word answers (e.g., "What is the website?" → "swap-commerce.com")
|
||||
|
||||
Generate an answer that:
|
||||
1. Is factually accurate based on the MAIN TEXT
|
||||
2. Is complete and helpful
|
||||
3. Maintains professional but friendly tone
|
||||
4. Includes specific details when relevant (numbers, names, dates, etc.)
|
||||
5. VARIES in length and structure (some concise, some detailed)
|
||||
|
||||
Also determine:
|
||||
- Difficulty level:
|
||||
* basic (simple factual recall)
|
||||
* intermediate (requires understanding/reasoning)
|
||||
* advanced (technical/complex concepts)
|
||||
- Categories: what topics does this cover (e.g., pricing, features, integrations, compliance, company_info, funding, etc.)
|
||||
|
||||
CRITICAL REQUIREMENTS:
|
||||
1. Generate EXACTLY 3 Q&A pairs
|
||||
2. Each pair must be DIVERSE (different question types, different aspects of the text)
|
||||
3. Avoid repetitive patterns across the 3 pairs
|
||||
4. Use varied vocabulary and sentence structures
|
||||
5. Each pair should focus on different information from the MAIN TEXT
|
||||
|
||||
Return a list of exactly 3 Q&A pairs, each with fields: question, answer, source_text, context_before, context_after, difficulty, categories.
|
||||
|
|
@ -0,0 +1,52 @@
|
|||
---
|
||||
title = "Q&A Pair Validation"
|
||||
version = "1.0.0"
|
||||
description = "Validate extracted Q&A pairs for quality and accuracy"
|
||||
---
|
||||
You are an expert validator ensuring high-quality training data.
|
||||
|
||||
Validate the following Q&A pair about SWAP Commerce:
|
||||
|
||||
**Question:** {question}
|
||||
|
||||
**Answer:** {answer}
|
||||
|
||||
**Source Text:** {source_text}
|
||||
|
||||
**Context Before:** {context_before}
|
||||
|
||||
**Context After:** {context_after}
|
||||
|
||||
Evaluate this Q&A pair on THREE critical criteria:
|
||||
|
||||
1. **uses_source_fact** (bool):
|
||||
- Does the answer CORRECTLY use facts from the source text?
|
||||
- Are there any hallucinations or made-up information?
|
||||
- Is all information grounded in the provided source?
|
||||
- Return TRUE only if the answer is 100% factually accurate based on the source
|
||||
- Return FALSE if there are ANY hallucinations, incorrect facts, or unsupported claims
|
||||
|
||||
2. **realistic_question** (bool):
|
||||
- Would a real person actually ask this question?
|
||||
- Is it natural and human-like?
|
||||
- Is it specific enough to be useful (not too vague)?
|
||||
- Is it too generic or templated?
|
||||
- Return TRUE if this feels like a genuine user query
|
||||
- Return FALSE if it's awkward, too formal, or clearly AI-generated
|
||||
|
||||
3. **sensible_answer** (bool):
|
||||
- Is the answer appropriate for the question asked?
|
||||
- Does it actually answer what was asked?
|
||||
- Is the answer complete (not cut off or incomplete)?
|
||||
- Is the level of detail appropriate?
|
||||
- Return TRUE if the answer properly addresses the question
|
||||
- Return FALSE if it's off-topic, incomplete, or inappropriate
|
||||
|
||||
**Overall Pass/Fail:**
|
||||
- Set `passed = true` ONLY if ALL THREE booleans are true
|
||||
- Set `passed = false` if ANY criterion fails
|
||||
|
||||
**Feedback:**
|
||||
Provide a brief (1-2 sentence) explanation of your validation decision. If failed, specify which criterion/criteria failed and why.
|
||||
|
||||
Return the structured output with fields: uses_source_fact, realistic_question, sensible_answer, passed, feedback.
|
||||
|
|
@ -0,0 +1,25 @@
|
|||
"""Agent definition for Q&A extraction."""
|
||||
|
||||
from src.synth_data_pipeline.agents.base import build_google_agent, load_prompt_template
|
||||
from src.synth_data_pipeline.config import APIConfig
|
||||
from src.synth_data_pipeline.models import QAPairBatch
|
||||
|
||||
PROMPT_NAME = "qa_extractor"
|
||||
SYSTEM_PROMPT = (
|
||||
"You are an expert at creating high-quality, diverse Q&A pairs from documentation."
|
||||
)
|
||||
|
||||
|
||||
def build_agent(config: APIConfig, *, api_key: str | None = None):
|
||||
"""Return a configured Q&A extraction agent."""
|
||||
return build_google_agent(
|
||||
config,
|
||||
system_prompt=SYSTEM_PROMPT,
|
||||
output_type=QAPairBatch,
|
||||
api_key=api_key,
|
||||
)
|
||||
|
||||
|
||||
def get_prompt_template():
|
||||
"""Load the Q&A extraction prompt template."""
|
||||
return load_prompt_template(PROMPT_NAME)
|
||||
|
|
@ -0,0 +1,23 @@
|
|||
"""Agent definition for Q&A validation."""
|
||||
|
||||
from src.synth_data_pipeline.agents.base import build_google_agent, load_prompt_template
|
||||
from src.synth_data_pipeline.config import APIConfig
|
||||
from src.synth_data_pipeline.models import QAValidation
|
||||
|
||||
PROMPT_NAME = "qa_validator"
|
||||
SYSTEM_PROMPT = "You are an expert validator ensuring high-quality training data."
|
||||
|
||||
|
||||
def build_agent(config: APIConfig, *, api_key: str | None = None):
|
||||
"""Return a configured Q&A validation agent."""
|
||||
return build_google_agent(
|
||||
config,
|
||||
system_prompt=SYSTEM_PROMPT,
|
||||
output_type=QAValidation,
|
||||
api_key=api_key,
|
||||
)
|
||||
|
||||
|
||||
def get_prompt_template():
|
||||
"""Load the Q&A validation prompt template."""
|
||||
return load_prompt_template(PROMPT_NAME)
|
||||
426
synth-data-pipeline/src/synth_data_pipeline/config.py
Normal file
@@ -0,0 +1,426 @@
"""
|
||||
Configuration and constants for the synthetic data pipeline.
|
||||
|
||||
This module centralizes all variations, tags, and configuration options
|
||||
for generating diverse training data.
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Literal
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Difficulty Levels
|
||||
# ============================================================================
|
||||
|
||||
DIFFICULTY_LEVELS = ["basic", "intermediate", "advanced"]
|
||||
|
||||
DIFFICULTY_DESCRIPTIONS = {
|
||||
"basic": "Simple factual questions requiring basic recall",
|
||||
"intermediate": "Questions requiring understanding and reasoning",
|
||||
"advanced": "Complex technical or multi-faceted questions"
|
||||
}
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Conversation Styles
|
||||
# ============================================================================
|
||||
|
||||
CONVERSATION_STYLES = ["formal", "casual", "technical"]
|
||||
|
||||
STYLE_DESCRIPTIONS = {
|
||||
"formal": "Professional language, complete sentences, no slang",
|
||||
"casual": "Friendly tone, can use contractions, conversational",
|
||||
"technical": "Uses technical terminology, assumes some expertise"
|
||||
}
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# User Personas
|
||||
# ============================================================================
|
||||
|
||||
@dataclass
|
||||
class Persona:
|
||||
"""Definition of a user persona."""
|
||||
name: str
|
||||
description: str
|
||||
typical_questions: List[str] = field(default_factory=list)
|
||||
formality: Literal["formal", "casual", "neutral"] = "neutral"
|
||||
|
||||
|
||||
PERSONAS = {
|
||||
"developer": Persona(
|
||||
name="developer",
|
||||
description="Software developer or engineer evaluating SWAP Commerce's APIs and technical implementation",
|
||||
typical_questions=[
|
||||
"API integration details",
|
||||
"Technical specifications",
|
||||
"SDK usage",
|
||||
"Error handling"
|
||||
],
|
||||
formality="technical"
|
||||
),
|
||||
"product_manager": Persona(
|
||||
name="product_manager",
|
||||
description="Product manager researching SWAP Commerce features, capabilities, and business value",
|
||||
typical_questions=[
|
||||
"Feature comparisons",
|
||||
"Roadmap questions",
|
||||
"Use cases",
|
||||
"ROI analysis"
|
||||
],
|
||||
formality="formal"
|
||||
),
|
||||
"cs_agent": Persona(
|
||||
name="cs_agent",
|
||||
description="Customer success or support agent learning about SWAP Commerce to help customers",
|
||||
typical_questions=[
|
||||
"Setup instructions",
|
||||
"Troubleshooting",
|
||||
"Configuration options",
|
||||
"Best practices"
|
||||
],
|
||||
formality="neutral"
|
||||
),
|
||||
"executive": Persona(
|
||||
name="executive",
|
||||
description="Business executive or decision-maker evaluating SWAP Commerce for strategic fit and ROI",
|
||||
typical_questions=[
|
||||
"Business value",
|
||||
"Competitive advantages",
|
||||
"Pricing strategy",
|
||||
"Scalability"
|
||||
],
|
||||
formality="formal"
|
||||
),
|
||||
"operations": Persona(
|
||||
name="operations",
|
||||
description="Operations or logistics manager interested in SWAP Commerce's operational features and integrations",
|
||||
typical_questions=[
|
||||
"Integration capabilities",
|
||||
"Workflow automation",
|
||||
"Performance metrics",
|
||||
"SLA guarantees"
|
||||
],
|
||||
formality="neutral"
|
||||
),
|
||||
"finance": Persona(
|
||||
name="finance",
|
||||
description="Finance or accounting professional interested in tax compliance, pricing, and financial aspects",
|
||||
typical_questions=[
|
||||
"Tax compliance",
|
||||
"Financial reporting",
|
||||
"Audit trails",
|
||||
"Cost structure"
|
||||
],
|
||||
formality="formal"
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# System Prompt Templates
|
||||
# ============================================================================
|
||||
|
||||
@dataclass
|
||||
class SystemPromptTemplate:
|
||||
"""Definition of a system prompt template."""
|
||||
name: str
|
||||
description: str
|
||||
template: str
|
||||
verbosity: Literal["concise", "balanced", "detailed"] = "balanced"
|
||||
use_case: str = "general"
|
||||
|
||||
|
||||
SYSTEM_PROMPT_TEMPLATES = {
|
||||
"helpful": SystemPromptTemplate(
|
||||
name="helpful",
|
||||
description="Helpful and friendly assistant",
|
||||
template="You are a helpful AI assistant with expertise in SWAP Commerce's e-commerce platform and services. You provide accurate, friendly, and detailed answers to questions about SWAP Commerce's products, features, integrations, and pricing.",
|
||||
verbosity="detailed",
|
||||
use_case="general"
|
||||
),
|
||||
"concise": SystemPromptTemplate(
|
||||
name="concise",
|
||||
description="Brief and to-the-point responses",
|
||||
template="You are a SWAP Commerce expert providing clear, concise answers. Focus on key information without unnecessary detail.",
|
||||
verbosity="concise",
|
||||
use_case="quick_reference"
|
||||
),
|
||||
"technical": SystemPromptTemplate(
|
||||
name="technical",
|
||||
description="Technical expert for developers",
|
||||
template="You are a technical expert on SWAP Commerce's platform. You provide detailed technical information about APIs, integrations, implementation, and system architecture. You assume the user has technical knowledge.",
|
||||
verbosity="detailed",
|
||||
use_case="developer"
|
||||
),
|
||||
"detailed": SystemPromptTemplate(
|
||||
name="detailed",
|
||||
description="Comprehensive explanations",
|
||||
template="You are a comprehensive SWAP Commerce expert who provides thorough, well-explained answers with examples, context, and relevant details. You ensure users fully understand the topic.",
|
||||
verbosity="detailed",
|
||||
use_case="learning"
|
||||
),
|
||||
"sales": SystemPromptTemplate(
|
||||
name="sales",
|
||||
description="Sales and solutions focused",
|
||||
template="You are a SWAP Commerce solutions consultant helping potential customers understand how SWAP Commerce can solve their e-commerce challenges. You're knowledgeable about features, benefits, and competitive advantages.",
|
||||
verbosity="balanced",
|
||||
use_case="sales"
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Topic Categories
|
||||
# ============================================================================
|
||||
|
||||
TOPIC_CATEGORIES = [
|
||||
"pricing",
|
||||
"features",
|
||||
"integrations",
|
||||
"api",
|
||||
"compliance",
|
||||
"tax",
|
||||
"shipping",
|
||||
"returns",
|
||||
"tracking",
|
||||
"inventory",
|
||||
"operations",
|
||||
"company_info",
|
||||
"funding",
|
||||
"customers",
|
||||
"partnerships",
|
||||
"security",
|
||||
]
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Conversation Length Distribution
|
||||
# ============================================================================
|
||||
|
||||
# Default distribution of conversation lengths (num_turns)
|
||||
DEFAULT_TURN_DISTRIBUTION = {
|
||||
1: 0.20, # 20% single-turn
|
||||
2: 0.35, # 35% two-turn
|
||||
3: 0.35, # 35% three-turn
|
||||
4: 0.10, # 10% four-turn
|
||||
}
|
||||
|
||||
# Alternative distributions for different use cases
|
||||
TURN_DISTRIBUTIONS = {
|
||||
"default": DEFAULT_TURN_DISTRIBUTION,
|
||||
"short": {1: 0.6, 2: 0.3, 3: 0.1}, # Mostly short conversations
|
||||
"long": {2: 0.2, 3: 0.4, 4: 0.4}, # Longer conversations
|
||||
"balanced": {1: 0.25, 2: 0.25, 3: 0.25, 4: 0.25}, # Equal distribution
|
||||
}
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# User Emotions
|
||||
# ============================================================================
|
||||
|
||||
USER_EMOTIONS = ["professional", "happy", "frustrated", "impatient", "confused"]
|
||||
|
||||
USER_EMOTION_DISTRIBUTION = {
|
||||
"professional": 0.50, # Most common - neutral, business-like
|
||||
"happy": 0.15, # Positive, enthusiastic
|
||||
"frustrated": 0.15, # Having issues, needs help
|
||||
"impatient": 0.10, # Wants quick answers
|
||||
"confused": 0.10, # Unclear about something
|
||||
}
|
||||
|
||||
EMOTION_DESCRIPTIONS = {
|
||||
"professional": "Neutral, business-like tone. Formal language, clear questions.",
|
||||
"happy": "Positive, enthusiastic. May express excitement about features or capabilities.",
|
||||
"frustrated": "Experiencing issues or challenges. May express mild annoyance or urgency.",
|
||||
"impatient": "Wants quick, direct answers. Brief messages, may skip pleasantries.",
|
||||
"confused": "Unclear about concepts or features. May ask for clarification or examples."
|
||||
}
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Input Modalities
|
||||
# ============================================================================
|
||||
|
||||
INPUT_MODALITIES = ["standard", "typed_on_phone", "voice_dictated"]
|
||||
|
||||
INPUT_MODALITY_DISTRIBUTION = {
|
||||
"standard": 0.70, # Normal typing on computer
|
||||
"typed_on_phone": 0.20, # Mobile typing - autocorrect errors, brevity
|
||||
"voice_dictated": 0.10, # Voice-to-text - filler words, natural speech
|
||||
}
|
||||
|
||||
MODALITY_DESCRIPTIONS = {
|
||||
"standard": "Standard computer typing. Clean text, proper formatting.",
|
||||
"typed_on_phone": "Mobile device typing. May have autocorrect errors, abbreviations, shorter messages.",
|
||||
"voice_dictated": "Voice-to-text transcription. May include 'um', 'uh', natural speech patterns, occasional transcription errors."
|
||||
}
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Text Variations
|
||||
# ============================================================================
|
||||
|
||||
TEXT_VARIATIONS = ["standard", "all_lowercase", "no_punctuation"]
|
||||
|
||||
TEXT_VARIATION_DISTRIBUTION = {
|
||||
"standard": 0.80, # Normal capitalization and punctuation
|
||||
"all_lowercase": 0.15, # all lowercase (casual/mobile)
|
||||
"no_punctuation": 0.05, # missing punctuation (rushed/mobile)
|
||||
}
|
||||
|
||||
VARIATION_DESCRIPTIONS = {
|
||||
"standard": "Standard capitalization and punctuation.",
|
||||
"all_lowercase": "All lowercase letters (common in casual or mobile communication).",
|
||||
"no_punctuation": "Missing or minimal punctuation (rushed typing or informal style)."
|
||||
}
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Quality Scoring Weights
|
||||
# ============================================================================
|
||||
|
||||
QUALITY_WEIGHTS = {
|
||||
"factual_accuracy": 0.35,
|
||||
"naturalness": 0.25,
|
||||
"relevance": 0.25,
|
||||
"diversity": 0.15,
|
||||
}
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# API Configuration
|
||||
# ============================================================================
|
||||
|
||||
@dataclass
|
||||
class APIConfig:
|
||||
"""Configuration for API calls."""
|
||||
model: str = "gemini-2.5-flash-lite"
|
||||
max_concurrent: int = 10
|
||||
temperature: float = 0.9 # Higher for generation, lower for judging
|
||||
thinking_budget: int = 0
|
||||
timeout: int = 60
|
||||
|
||||
|
||||
# Default configurations for each stage
|
||||
STAGE_CONFIGS = {
|
||||
"stage1_qa_extraction": APIConfig(
|
||||
model="gemini-2.5-flash",
|
||||
temperature=0.8,
|
||||
max_concurrent=30,
|
||||
),
|
||||
"stage2_qa_validation": APIConfig(
|
||||
model="gemini-2.5-flash-lite",
|
||||
temperature=0.0, # Low for validation consistency
|
||||
max_concurrent=50, # Can be faster since lighter model
|
||||
),
|
||||
"stage3_conversation_generation": APIConfig(
|
||||
model="gemini-2.5-flash",
|
||||
temperature=0.9, # High diversity
|
||||
max_concurrent=30,
|
||||
),
|
||||
"stage4_judging": APIConfig(
|
||||
model="gemini-2.5-flash",
|
||||
temperature=0.0, # Low for consistency
|
||||
max_concurrent=50,
|
||||
),
|
||||
"stage5_embedding": APIConfig(
|
||||
model="text-embedding-3-small", # OpenAI model
|
||||
temperature=0, # Not applicable to embeddings
|
||||
max_concurrent=20, # High concurrency for batch processing
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# File Paths
|
||||
# ============================================================================
|
||||
|
||||
@dataclass
|
||||
class FilePaths:
|
||||
"""Standard file paths for the pipeline."""
|
||||
data_dir: str = "data"
|
||||
prompts_dir: str = "prompts"
|
||||
output_dir: str = "output"
|
||||
|
||||
# Input files
|
||||
source_facts: str = "data/swap_facts.md"
|
||||
|
||||
# Stage outputs (full pipeline)
|
||||
stage1_qa_pairs: str = "output/qa_pairs.jsonl"
|
||||
stage2_qa_validated: str = "output/qa_pairs_validated.jsonl"
|
||||
stage2_qa_validated_passed: str = "output/qa_pairs_validated_passed.jsonl"
|
||||
stage3_conversations_raw: str = "output/conversations_raw.jsonl"
|
||||
stage4_conversations_judged: str = "output/conversations_judged.jsonl"
|
||||
stage5_conversations_embedded: str = "output/conversations_embedded.jsonl"
|
||||
stage6_conversations_unique: str = "output/conversations_unique.jsonl"
|
||||
stage7_conversations_final: str = "output/conversations_final.jsonl"
|
||||
|
||||
# Trial outputs
|
||||
trial_qa_pairs: str = "output/trial_qa_pairs.jsonl"
|
||||
trial_qa_validated: str = "output/trial_qa_validated.jsonl"
|
||||
trial_conversations_raw: str = "output/trial_conversations_raw.jsonl"
|
||||
trial_conversations_judged: str = "output/trial_conversations_judged.jsonl"
|
||||
trial_conversations_embedded: str = "output/trial_conversations_embedded.jsonl"
|
||||
trial_conversations_unique: str = "output/trial_conversations_unique.jsonl"
|
||||
trial_conversations_final: str = "output/trial_conversations_final.jsonl"
|
||||
|
||||
|
||||
PATHS = FilePaths()
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Pipeline Parameters
|
||||
# ============================================================================
|
||||
|
||||
@dataclass
|
||||
class PipelineParams:
|
||||
"""Parameters for the full pipeline run."""
|
||||
|
||||
# Stage 1: Q&A Extraction
|
||||
qa_chunk_context_lines: int = 3
|
||||
qa_pairs_per_chunk: int = 3 # Generate 3 Q&A pairs per chunk
|
||||
qa_chunk_limit: int | None = None # None = no limit on chunks
|
||||
|
||||
# Stage 2: Q&A Validation
|
||||
qa_validation_enabled: bool = True
|
||||
|
||||
# Stage 3: Conversation Generation
|
||||
num_conversations: int = 2000
|
||||
conversations_per_qa: int = 10
|
||||
turn_distribution: Dict[int, float] = field(default_factory=lambda: DEFAULT_TURN_DISTRIBUTION)
|
||||
emotion_distribution: Dict[str, float] = field(default_factory=lambda: USER_EMOTION_DISTRIBUTION)
|
||||
modality_distribution: Dict[str, float] = field(default_factory=lambda: INPUT_MODALITY_DISTRIBUTION)
|
||||
variation_distribution: Dict[str, float] = field(default_factory=lambda: TEXT_VARIATION_DISTRIBUTION)
|
||||
|
||||
# Stage 4: Judging
|
||||
min_quality_score: float = 5.0 # Minimum acceptable score
|
||||
|
||||
# Stage 5: Embedding
|
||||
embedding_model: str = "text-embedding-3-small"
|
||||
embedding_dimensions: int = 1024
|
||||
embedding_batch_size: int = 100
|
||||
embedding_max_chars: int = 24000
|
||||
|
||||
# Stage 6: Deduplication
|
||||
dedup_enabled: bool = True
|
||||
dedup_similarity_threshold: float = 0.95 # 95% similarity
|
||||
|
||||
# Stage 7: Selection
|
||||
top_k: int = 1000
|
||||
|
||||
# General
|
||||
max_concurrent: int = 10
|
||||
random_seed: int | None = 42 # For reproducibility
|
||||
|
||||
|
||||
# Default parameters for full runs
|
||||
FULL_PARAMS = PipelineParams(
|
||||
qa_chunk_limit=None,
|
||||
qa_pairs_per_chunk=3,
|
||||
num_conversations=2000,
|
||||
top_k=1000,
|
||||
max_concurrent=10,
|
||||
dedup_enabled=True,
|
||||
)
|
||||
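# Illustrative sketch of how the distribution dicts above are meant to be
# consumed. The real sampling helpers live in sampling.py (not shown here),
# so treat this as an assumption about usage rather than the pipeline's code.
import random


def sample_weighted(distribution: dict):
    """Draw one key with probability proportional to its weight."""
    keys = list(distribution)
    weights = [distribution[k] for k in keys]
    return random.choices(keys, weights=weights, k=1)[0]


# e.g. num_turns=2, emotion="professional", modality="standard"
num_turns = sample_weighted(DEFAULT_TURN_DISTRIBUTION)
emotion = sample_weighted(USER_EMOTION_DISTRIBUTION)
modality = sample_weighted(INPUT_MODALITY_DISTRIBUTION)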
183
synth-data-pipeline/src/synth_data_pipeline/embedding_utils.py
Normal file
@@ -0,0 +1,183 @@
"""
|
||||
Embedding utilities for conversation deduplication.
|
||||
|
||||
This module provides functions for:
|
||||
1. Generating embeddings using OpenAI API
|
||||
2. Computing cosine similarity between conversations
|
||||
3. Deduplicating based on similarity threshold
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import List
|
||||
import numpy as np
|
||||
from openai import AsyncOpenAI
|
||||
import logfire
|
||||
|
||||
|
||||
async def batch_embed(
|
||||
texts: List[str],
|
||||
client: AsyncOpenAI,
|
||||
model: str = "text-embedding-3-small",
|
||||
dimensions: int = 1024,
|
||||
batch_size: int = 100,
|
||||
max_concurrent: int = 20
|
||||
) -> List[np.ndarray]:
|
||||
"""
|
||||
Generate embeddings for a list of texts using OpenAI API.
|
||||
|
||||
Args:
|
||||
texts: List of text strings to embed
|
||||
client: AsyncOpenAI client instance
|
||||
model: OpenAI embedding model name
|
||||
dimensions: Number of dimensions for embeddings
|
||||
batch_size: Number of texts per API call
|
||||
max_concurrent: Maximum concurrent API calls
|
||||
|
||||
Returns:
|
||||
List of numpy arrays (embeddings)
|
||||
"""
|
||||
logfire.info(f"Embedding {len(texts)} texts in batches of {batch_size}")
|
||||
|
||||
# Split into batches
|
||||
batches = [texts[i:i + batch_size] for i in range(0, len(texts), batch_size)]
|
||||
|
||||
# Semaphore for rate limiting
|
||||
semaphore = asyncio.Semaphore(max_concurrent)
|
||||
|
||||
async def embed_batch(batch: List[str]) -> List[List[float]]:
|
||||
"""Embed a single batch of texts."""
|
||||
async with semaphore:
|
||||
try:
|
||||
response = await client.embeddings.create(
|
||||
model=model,
|
||||
input=batch,
|
||||
dimensions=dimensions
|
||||
)
|
||||
return [item.embedding for item in response.data]
|
||||
except Exception as e:
|
||||
logfire.error(f"Error embedding batch: {e}")
|
||||
raise
|
||||
|
||||
# Process all batches concurrently
|
||||
all_embeddings = []
|
||||
tasks = [embed_batch(batch) for batch in batches]
|
||||
|
||||
batch_results = await asyncio.gather(*tasks)
|
||||
|
||||
# Flatten results
|
||||
for batch_embeds in batch_results:
|
||||
all_embeddings.extend(batch_embeds)
|
||||
|
||||
# Convert to numpy arrays
|
||||
embeddings_np = [np.array(emb, dtype=np.float32) for emb in all_embeddings]
|
||||
|
||||
logfire.info(f"Generated {len(embeddings_np)} embeddings")
|
||||
return embeddings_np
|
||||
|
||||
|
||||
def l2_normalize(embeddings: List[np.ndarray]) -> List[np.ndarray]:
|
||||
"""
|
||||
L2-normalize embeddings for cosine similarity via dot product.
|
||||
|
||||
Args:
|
||||
embeddings: List of numpy arrays
|
||||
|
||||
Returns:
|
||||
List of L2-normalized numpy arrays
|
||||
"""
|
||||
normalized = []
|
||||
for emb in embeddings:
|
||||
norm = np.linalg.norm(emb)
|
||||
if norm > 0:
|
||||
normalized.append(emb / norm)
|
||||
else:
|
||||
normalized.append(emb)
|
||||
return normalized
|
||||
|
||||
|
||||
def compute_similarity(emb1: np.ndarray, emb2: np.ndarray) -> float:
|
||||
"""
|
||||
Compute cosine similarity between two L2-normalized embeddings.
|
||||
|
||||
Args:
|
||||
emb1: First embedding (L2-normalized)
|
||||
emb2: Second embedding (L2-normalized)
|
||||
|
||||
Returns:
|
||||
Cosine similarity (0-1, where 1 is identical)
|
||||
"""
|
||||
return float(np.dot(emb1, emb2))
|
||||
|
||||
|
||||
def greedy_deduplicate(
|
||||
embeddings: List[np.ndarray],
|
||||
scores: List[float],
|
||||
similarity_threshold: float = 0.95
|
||||
) -> List[int]:
|
||||
"""
|
||||
Greedy deduplication: keep highest-scoring items, remove similar duplicates.
|
||||
|
||||
Args:
|
||||
embeddings: List of L2-normalized embeddings
|
||||
scores: List of quality scores (higher is better)
|
||||
similarity_threshold: Similarity threshold for considering items duplicates
|
||||
|
||||
Returns:
|
||||
List of indices to KEEP (deduplicated)
|
||||
"""
|
||||
# Create indices sorted by score (descending)
|
||||
sorted_indices = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
|
||||
|
||||
kept_indices = []
|
||||
kept_embeddings = []
|
||||
|
||||
for idx in sorted_indices:
|
||||
emb = embeddings[idx]
|
||||
|
||||
# Check similarity against all kept items
|
||||
is_duplicate = False
|
||||
for kept_emb in kept_embeddings:
|
||||
similarity = compute_similarity(emb, kept_emb)
|
||||
if similarity >= similarity_threshold:
|
||||
is_duplicate = True
|
||||
break
|
||||
|
||||
if not is_duplicate:
|
||||
kept_indices.append(idx)
|
||||
kept_embeddings.append(emb)
|
||||
|
||||
# Return indices in original order
|
||||
kept_indices.sort()
|
||||
|
||||
logfire.info(
|
||||
f"Deduplication: {len(kept_indices)} kept, {len(embeddings) - len(kept_indices)} removed "
|
||||
f"({100 * (len(embeddings) - len(kept_indices)) / len(embeddings):.1f}% removed)"
|
||||
)
|
||||
|
||||
return kept_indices
|
||||
|
||||
|
||||
def conversation_to_text(messages: List[dict], max_chars: int = 24000) -> str:
|
||||
"""
|
||||
Convert conversation messages to a single text string for embedding.
|
||||
|
||||
Args:
|
||||
messages: List of message dicts with 'role' and 'content'
|
||||
max_chars: Maximum characters to include
|
||||
|
||||
Returns:
|
||||
Concatenated text string
|
||||
"""
|
||||
parts = []
|
||||
for msg in messages:
|
||||
role = msg.get('role', 'unknown').upper()
|
||||
content = msg.get('content', '')
|
||||
parts.append(f"{role}: {content}")
|
||||
|
||||
full_text = "\n\n".join(parts)
|
||||
|
||||
# Truncate if too long
|
||||
if len(full_text) > max_chars:
|
||||
full_text = full_text[:max_chars]
|
||||
|
||||
return full_text
|
||||
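# Illustrative end-to-end sketch of the helpers above, roughly mirroring how
# stages 5 and 6 use them. The messages are made up, and an OPENAI_API_KEY is
# assumed to be set in the environment; this is a usage sketch, not pipeline code.
import asyncio
import os

from openai import AsyncOpenAI


async def _demo() -> None:
    messages = [
        {"role": "user", "content": "What does SWAP Commerce do?"},
        {"role": "assistant", "content": "..."},
    ]
    text = conversation_to_text(messages)
    client = AsyncOpenAI(api_key=os.getenv("OPENAI_API_KEY"))
    embeddings = await batch_embed([text, text], client)      # two identical texts
    normalized = l2_normalize(embeddings)
    kept = greedy_deduplicate(normalized, scores=[0.9, 0.8])  # -> [0], the duplicate is dropped
    print(kept)


if __name__ == "__main__":
    asyncio.run(_demo())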
275
synth-data-pipeline/src/synth_data_pipeline/models.py
Normal file
@@ -0,0 +1,275 @@
"""
|
||||
Pydantic models for the synthetic data generation pipeline.
|
||||
"""
|
||||
|
||||
from typing import Literal
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Stage 1: Q&A Extraction Models
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class QAPair(BaseModel):
|
||||
"""A question-answer pair extracted from source documentation."""
|
||||
|
||||
question: str = Field(
|
||||
description="A natural question that could be asked about this topic"
|
||||
)
|
||||
answer: str = Field(
|
||||
description="The accurate answer grounded in the source text"
|
||||
)
|
||||
source_text: str = Field(
|
||||
description="The specific text chunk this Q&A was generated from"
|
||||
)
|
||||
context_before: str = Field(
|
||||
default="",
|
||||
description="Preceding lines for context"
|
||||
)
|
||||
context_after: str = Field(
|
||||
default="",
|
||||
description="Following lines for context"
|
||||
)
|
||||
difficulty: Literal["basic", "intermediate", "advanced"] = Field(
|
||||
description="The difficulty level of this question"
|
||||
)
|
||||
categories: list[str] = Field(
|
||||
default_factory=list,
|
||||
description="Topic categories (e.g., 'pricing', 'features', 'integrations')"
|
||||
)
|
||||
|
||||
|
||||
class QAPairBatch(BaseModel):
|
||||
"""A batch of 3 Q&A pairs generated from a single chunk."""
|
||||
|
||||
qa_pairs: list[QAPair] = Field(
|
||||
description="Exactly 3 diverse Q&A pairs from the same source chunk"
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Stage 2: Q&A Validation Models
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class QAValidation(BaseModel):
|
||||
"""Validation result for a Q&A pair."""
|
||||
|
||||
uses_source_fact: bool = Field(
|
||||
description="Does the Q&A correctly use facts from the source text (no hallucinations)?"
|
||||
)
|
||||
realistic_question: bool = Field(
|
||||
description="Is this a question a real person would ask?"
|
||||
)
|
||||
sensible_answer: bool = Field(
|
||||
description="Is the answer appropriate and sensible for the question?"
|
||||
)
|
||||
passed: bool = Field(
|
||||
description="Overall pass (all three bools must be True)"
|
||||
)
|
||||
feedback: str = Field(
|
||||
description="Brief explanation of validation result"
|
||||
)
|
||||
|
||||
|
||||
class ValidatedQAPair(BaseModel):
|
||||
"""A Q&A pair with its validation result."""
|
||||
|
||||
qa_pair: QAPair = Field(
|
||||
description="The Q&A pair being validated"
|
||||
)
|
||||
validation: QAValidation = Field(
|
||||
description="The validation result"
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Stage 3: Conversation Generation Models
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class Message(BaseModel):
|
||||
"""A single message in a conversation."""
|
||||
|
||||
role: Literal["system", "user", "assistant"] = Field(
|
||||
description="The role of the message sender"
|
||||
)
|
||||
content: str = Field(
|
||||
description="The message content"
|
||||
)
|
||||
|
||||
|
||||
class ConversationMetadata(BaseModel):
|
||||
"""Metadata about how a conversation was generated."""
|
||||
|
||||
num_turns: int = Field(
|
||||
description="Number of user-assistant turns (not counting system message)"
|
||||
)
|
||||
style: Literal["formal", "casual", "technical"] = Field(
|
||||
description="The conversation style"
|
||||
)
|
||||
user_persona: str = Field(
|
||||
description="The persona/role of the user (e.g., 'developer', 'business owner')"
|
||||
)
|
||||
user_emotion: Literal["professional", "happy", "frustrated", "impatient", "confused"] = Field(
|
||||
default="professional",
|
||||
description="The emotional state of the user"
|
||||
)
|
||||
input_modality: Literal["standard", "typed_on_phone", "voice_dictated"] = Field(
|
||||
default="standard",
|
||||
description="How the user is inputting their messages"
|
||||
)
|
||||
text_variation: Literal["standard", "all_lowercase", "no_punctuation"] = Field(
|
||||
default="standard",
|
||||
description="Text formatting variation applied to user messages"
|
||||
)
|
||||
source_qa_ids: list[int] = Field(
|
||||
default_factory=list,
|
||||
description="Indices of Q&A pairs used to generate this conversation"
|
||||
)
|
||||
difficulty: str = Field(
|
||||
description="Overall difficulty level"
|
||||
)
|
||||
categories: list[str] = Field(
|
||||
default_factory=list,
|
||||
description="Topic categories covered in this conversation"
|
||||
)
|
||||
|
||||
|
||||
class Conversation(BaseModel):
|
||||
"""A complete conversation with metadata."""
|
||||
|
||||
messages: list[Message] = Field(
|
||||
description="The conversation messages (system, user, assistant)"
|
||||
)
|
||||
metadata: ConversationMetadata = Field(
|
||||
description="Metadata about this conversation"
|
||||
)
|
||||
source_qa_pairs: list[QAPair] = Field(
|
||||
default_factory=list,
|
||||
description="The Q&A pairs used to generate this conversation (for fact-checking)"
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Stage 4: Judging Models
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class JudgmentScore(BaseModel):
|
||||
"""Quality judgment for a conversation using clear YES/NO rubrics."""
|
||||
|
||||
factually_accurate: bool = Field(
|
||||
description="PASS: All facts match source Q&A, no hallucinations or invented details"
|
||||
)
|
||||
natural_conversation: bool = Field(
|
||||
description="PASS: Sounds human, flows naturally, realistic interaction"
|
||||
)
|
||||
on_topic: bool = Field(
|
||||
description="PASS: Relevant to SWAP Commerce, would be useful for training"
|
||||
)
|
||||
adds_value: bool = Field(
|
||||
description="PASS: Not generic/repetitive, covers topic in specific/interesting way"
|
||||
)
|
||||
overall_pass: bool = Field(
|
||||
description="TRUE only if ALL four criteria above are TRUE"
|
||||
)
|
||||
feedback: str = Field(
|
||||
description="Brief explanation of judgment (1-2 sentences)"
|
||||
)
|
||||
issues: list[str] = Field(
|
||||
default_factory=list,
|
||||
description="Specific problems found (if any)"
|
||||
)
|
||||
|
||||
|
||||
class JudgedConversation(BaseModel):
|
||||
"""A conversation with its quality judgment."""
|
||||
|
||||
conversation: Conversation = Field(
|
||||
description="The conversation being judged"
|
||||
)
|
||||
judgment: JudgmentScore = Field(
|
||||
description="The quality judgment scores"
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Stage 5: Embedding Models
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class EmbeddedConversation(BaseModel):
|
||||
"""A judged conversation with its embedding."""
|
||||
|
||||
conversation: Conversation = Field(
|
||||
description="The conversation"
|
||||
)
|
||||
judgment: JudgmentScore = Field(
|
||||
description="The quality judgment"
|
||||
)
|
||||
embedding: list[float] = Field(
|
||||
description="Conversation embedding (1024 dimensions)"
|
||||
)
|
||||
text_preview: str = Field(
|
||||
description="First 200 characters for debugging"
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Stage 6: Deduplication Models
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class UniqueConversation(BaseModel):
|
||||
"""A conversation marked as unique after deduplication."""
|
||||
|
||||
conversation: Conversation = Field(
|
||||
description="The conversation"
|
||||
)
|
||||
judgment: JudgmentScore = Field(
|
||||
description="The quality judgment"
|
||||
)
|
||||
# Note: embedding removed to save space after dedup
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Stage 7: Final Output Format (NanoChat compatible)
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class NanoChatMessage(BaseModel):
|
||||
"""A message in NanoChat format."""
|
||||
|
||||
role: Literal["system", "user", "assistant"]
|
||||
content: str
|
||||
|
||||
|
||||
class NanoChatConversation(BaseModel):
|
||||
"""NanoChat training format - just the messages array."""
|
||||
|
||||
messages: list[NanoChatMessage]
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Prompt Generation Models (for structured LLM outputs)
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class QAGenerationRequest(BaseModel):
|
||||
"""Input for Q&A generation from a text chunk."""
|
||||
|
||||
chunk: str
|
||||
context_before: str = ""
|
||||
context_after: str = ""
|
||||
|
||||
|
||||
class ConversationGenerationRequest(BaseModel):
|
||||
"""Input for conversation generation."""
|
||||
|
||||
qa_pairs: list[QAPair]
|
||||
num_turns: int = Field(ge=1, le=4, description="Number of conversation turns")
|
||||
style: Literal["formal", "casual", "technical"]
|
||||
user_persona: str
|
||||
system_prompt_template: str
|
||||
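
A short sketch of how these stage models nest, assuming the package is importable as `synth_data_pipeline` (the import path and all field values here are illustrative, not taken from the pipeline):

```python
from synth_data_pipeline.models import (
    Conversation, ConversationMetadata, Message,
    NanoChatConversation, NanoChatMessage,
)

conv = Conversation(
    messages=[
        Message(role="system", content="You are a SWAP Commerce expert."),
        Message(role="user", content="Do you support cross-border returns?"),
        Message(role="assistant", content="Yes, returns are handled by the platform."),
    ],
    metadata=ConversationMetadata(
        num_turns=1,
        style="casual",
        user_persona="business owner",
        difficulty="basic",
    ),
)

# Stage 7 keeps only the messages array in NanoChat's training format.
nano = NanoChatConversation(
    messages=[NanoChatMessage(role=m.role, content=m.content) for m in conv.messages]
)
print(nano.model_dump_json())
```
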
@@ -0,0 +1 @@
Customer success or support agent learning about SWAP Commerce to help customers

@@ -0,0 +1 @@
Software developer or engineer evaluating SWAP Commerce's APIs and technical implementation

@@ -0,0 +1 @@
Business executive or decision-maker evaluating SWAP Commerce for strategic fit and ROI

@@ -0,0 +1 @@
Finance or accounting professional interested in tax compliance, pricing, and financial aspects

@@ -0,0 +1 @@
Operations or logistics manager interested in SWAP Commerce's operational features and integrations

@@ -0,0 +1 @@
Product manager researching SWAP Commerce features, capabilities, and business value

@@ -0,0 +1 @@
You are a SWAP Commerce expert providing clear, concise answers. Focus on key information without unnecessary detail.

@@ -0,0 +1 @@
You are a comprehensive SWAP Commerce expert who provides thorough, well-explained answers with examples, context, and relevant details. You ensure users fully understand the topic.

@@ -0,0 +1 @@
You are a helpful AI assistant with expertise in SWAP Commerce's e-commerce platform and services. You provide accurate, friendly, and detailed answers to questions about SWAP Commerce's products, features, integrations, and pricing.

@@ -0,0 +1,3 @@
You are a SWAP Commerce solutions consultant helping potential customers understand how SWAP Commerce can solve their e-commerce challenges. You're knowledgeable about features, benefits, and competitive advantages.

If asked about internal company matters (board appointments, corporate registration, etc.), politely note that you focus on product solutions and can help with questions about features and benefits.

@@ -0,0 +1,3 @@
You are a technical expert on SWAP Commerce's platform. You provide detailed technical information about APIs, integrations, implementation, and system architecture. You assume the user has technical knowledge.

If asked about non-technical matters (corporate governance, board appointments, etc.), suggest those questions would be better directed to company information resources.
315 synth-data-pipeline/src/synth_data_pipeline/sampling.py Normal file
@@ -0,0 +1,315 @@
"""
Sampling utilities for generating diverse conversation configurations.
"""

import random
from typing import List, Dict
from pathlib import Path

from .config import (
    PERSONAS,
    SYSTEM_PROMPT_TEMPLATES,
    CONVERSATION_STYLES,
    DEFAULT_TURN_DISTRIBUTION,
    USER_EMOTIONS,
    USER_EMOTION_DISTRIBUTION,
    INPUT_MODALITIES,
    INPUT_MODALITY_DISTRIBUTION,
    TEXT_VARIATIONS,
    TEXT_VARIATION_DISTRIBUTION,
    Persona,
    SystemPromptTemplate,
)


def sample_persona() -> Persona:
    """Sample a random user persona."""
    return random.choice(list(PERSONAS.values()))


def sample_system_prompt() -> SystemPromptTemplate:
    """Sample a random system prompt template."""
    return random.choice(list(SYSTEM_PROMPT_TEMPLATES.values()))


def sample_style() -> str:
    """Sample a random conversation style."""
    return random.choice(CONVERSATION_STYLES)


def sample_num_turns(distribution: Dict[int, float] = None) -> int:
    """
    Sample number of conversation turns based on a distribution.

    Args:
        distribution: Dict mapping num_turns -> probability.
            If None, uses DEFAULT_TURN_DISTRIBUTION.

    Returns:
        Number of turns (1-4)
    """
    if distribution is None:
        distribution = DEFAULT_TURN_DISTRIBUTION

    turns = list(distribution.keys())
    weights = list(distribution.values())
    return random.choices(turns, weights=weights)[0]


def sample_emotion(distribution: Dict[str, float] = None) -> str:
    """
    Sample user emotion based on a distribution.

    Args:
        distribution: Dict mapping emotion -> probability.
            If None, uses USER_EMOTION_DISTRIBUTION.

    Returns:
        Emotion string: "professional", "happy", "frustrated", "impatient", or "confused"
    """
    if distribution is None:
        distribution = USER_EMOTION_DISTRIBUTION

    emotions = list(distribution.keys())
    weights = list(distribution.values())
    return random.choices(emotions, weights=weights)[0]


def sample_input_modality(distribution: Dict[str, float] = None) -> str:
    """
    Sample input modality based on a distribution.

    Args:
        distribution: Dict mapping modality -> probability.
            If None, uses INPUT_MODALITY_DISTRIBUTION.

    Returns:
        Modality string: "standard", "typed_on_phone", or "voice_dictated"
    """
    if distribution is None:
        distribution = INPUT_MODALITY_DISTRIBUTION

    modalities = list(distribution.keys())
    weights = list(distribution.values())
    return random.choices(modalities, weights=weights)[0]


def sample_text_variation(distribution: Dict[str, float] = None) -> str:
    """
    Sample text variation based on a distribution.

    Args:
        distribution: Dict mapping variation -> probability.
            If None, uses TEXT_VARIATION_DISTRIBUTION.

    Returns:
        Variation string: "standard", "all_lowercase", or "no_punctuation"
    """
    if distribution is None:
        distribution = TEXT_VARIATION_DISTRIBUTION

    variations = list(distribution.keys())
    weights = list(distribution.values())
    return random.choices(variations, weights=weights)[0]


def sample_persona_by_formality(formality: str) -> Persona:
    """
    Sample a persona matching a specific formality level.

    Args:
        formality: "formal", "casual", or "neutral"

    Returns:
        A Persona with matching formality
    """
    matching = [p for p in PERSONAS.values() if p.formality == formality]
    return random.choice(matching) if matching else sample_persona()


def sample_system_prompt_by_use_case(use_case: str) -> SystemPromptTemplate:
    """
    Sample a system prompt matching a specific use case.

    Args:
        use_case: e.g., "developer", "sales", "general"

    Returns:
        A SystemPromptTemplate with matching use case
    """
    matching = [
        s for s in SYSTEM_PROMPT_TEMPLATES.values()
        if s.use_case == use_case
    ]
    return random.choice(matching) if matching else sample_system_prompt()


def sample_balanced_config(
    prefer_long_conversations: bool = False,
    prefer_technical: bool = False,
) -> Dict:
    """
    Sample a balanced conversation configuration.

    Args:
        prefer_long_conversations: If True, bias towards longer conversations
        prefer_technical: If True, bias towards technical personas/prompts

    Returns:
        Dict with sampled configuration
    """
    # Sample turn distribution
    if prefer_long_conversations:
        num_turns = sample_num_turns({2: 0.2, 3: 0.4, 4: 0.4})
    else:
        num_turns = sample_num_turns()

    # Sample style
    if prefer_technical:
        style = "technical" if random.random() < 0.6 else sample_style()
    else:
        style = sample_style()

    # Sample persona
    if prefer_technical:
        persona = PERSONAS.get("developer") if random.random() < 0.5 else sample_persona()
    else:
        persona = sample_persona()

    # Sample system prompt - match use case to persona if possible
    if prefer_technical:
        system_prompt = sample_system_prompt_by_use_case("developer")
    else:
        system_prompt = sample_system_prompt()

    return {
        "num_turns": num_turns,
        "style": style,
        "persona": persona,
        "system_prompt": system_prompt,
        "user_emotion": sample_emotion(),
        "input_modality": sample_input_modality(),
        "text_variation": sample_text_variation(),
    }


def load_system_prompts_from_files(prompts_dir: str = "src/synth_data_pipeline/prompts/system_prompts") -> Dict[str, str]:
    """
    Load system prompt templates from text files.

    Args:
        prompts_dir: Directory containing prompt text files

    Returns:
        Dict mapping template name to prompt text
    """
    prompts_path = Path(prompts_dir)
    system_prompts = {}

    for prompt_file in prompts_path.glob("*.txt"):
        name = prompt_file.stem
        with open(prompt_file, 'r', encoding='utf-8') as f:
            system_prompts[name] = f.read().strip()

    return system_prompts


def load_personas_from_files(personas_dir: str = "src/synth_data_pipeline/prompts/personas") -> Dict[str, str]:
    """
    Load persona descriptions from text files.

    Args:
        personas_dir: Directory containing persona text files

    Returns:
        Dict mapping persona name to description
    """
    personas_path = Path(personas_dir)
    personas = {}

    for persona_file in personas_path.glob("*.txt"):
        name = persona_file.stem
        with open(persona_file, 'r', encoding='utf-8') as f:
            personas[name] = f.read().strip()

    return personas


def set_random_seed(seed: int):
    """Set random seed for reproducibility."""
    random.seed(seed)


def sample_multiple_configs(
    n: int,
    distribution: Dict[int, float] = None,
    prefer_long_conversations: bool = False,
    prefer_technical: bool = False,
) -> List[Dict]:
    """
    Sample multiple conversation configurations.

    Args:
        n: Number of configurations to sample
        distribution: Turn count distribution
        prefer_long_conversations: Bias towards longer conversations
        prefer_technical: Bias towards technical content

    Returns:
        List of configuration dicts
    """
    configs = []
    for _ in range(n):
        config = sample_balanced_config(
            prefer_long_conversations=prefer_long_conversations,
            prefer_technical=prefer_technical,
        )
        configs.append(config)
    return configs


# ============================================================================
# Sampling Strategies
# ============================================================================

def stratified_sample_configs(
    n: int,
    ensure_coverage: bool = True
) -> List[Dict]:
    """
    Sample configurations with stratified sampling to ensure diversity.

    Args:
        n: Total number of configurations to sample
        ensure_coverage: If True, ensure all personas/styles are represented

    Returns:
        List of configuration dicts with guaranteed diversity
    """
    configs = []

    if ensure_coverage and n >= len(PERSONAS) * len(CONVERSATION_STYLES):
        # First, ensure we have at least one of each persona-style combination
        for persona in PERSONAS.values():
            for style in CONVERSATION_STYLES:
                configs.append({
                    "num_turns": sample_num_turns(),
                    "style": style,
                    "persona": persona,
                    "system_prompt": sample_system_prompt(),
                    "user_emotion": sample_emotion(),
                    "input_modality": sample_input_modality(),
                    "text_variation": sample_text_variation(),
                })

        # Fill remaining with random samples
        remaining = n - len(configs)
        for _ in range(remaining):
            configs.append(sample_balanced_config())
    else:
        # Just sample randomly
        configs = sample_multiple_configs(n)

    # Shuffle to avoid patterns
    random.shuffle(configs)
    return configs[:n]
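
A hedged sketch of driving these samplers from a generation script (the `synth_data_pipeline.sampling` import path is an assumption; PERSONAS and the distribution constants live in config.py, which is not part of this diff):

```python
from synth_data_pipeline.sampling import (
    sample_balanced_config, set_random_seed, stratified_sample_configs,
)

set_random_seed(42)  # make the sampled configurations reproducible

# One ad-hoc configuration, biased towards technical personas and prompts.
config = sample_balanced_config(prefer_technical=True)
print(config["style"], config["num_turns"], config["user_emotion"])

# A batch of 200 configurations with every persona/style pair covered at least once.
configs = stratified_sample_configs(200, ensure_coverage=True)
```
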
220 synth-data-pipeline/src/synth_data_pipeline/utils.py Normal file
@@ -0,0 +1,220 @@
"""
Common utilities for the synthetic data pipeline.
"""

import asyncio
import json
from pathlib import Path
from typing import List, TypeVar, Callable, Awaitable

import logfire
from tqdm.asyncio import tqdm_asyncio

T = TypeVar('T')
R = TypeVar('R')


async def process_with_concurrency(
    items: List[T],
    process_fn: Callable[[T], Awaitable[R]],
    max_concurrent: int = 10,
    desc: str = "Processing"
) -> List[R]:
    """
    Process items concurrently with a semaphore to limit concurrency.

    Args:
        items: List of items to process
        process_fn: Async function to process each item
        max_concurrent: Maximum number of concurrent operations
        desc: Description for progress bar

    Returns:
        List of results (None entries filtered out)
    """
    semaphore = asyncio.Semaphore(max_concurrent)

    async def bounded_process(item: T) -> R | None:
        async with semaphore:
            try:
                return await process_fn(item)
            except Exception as e:
                logfire.error(f"Error processing item: {e}", item=item)
                return None

    # Process all items with progress bar
    tasks = [bounded_process(item) for item in items]
    results = await tqdm_asyncio.gather(*tasks, desc=desc)

    # Filter out None results (errors)
    return [r for r in results if r is not None]


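
For illustration, a self-contained way to exercise this helper (the `fake_llm_call` coroutine is invented; in the pipeline `process_fn` wraps an LLM or agent call):

```python
import asyncio

async def fake_llm_call(item: int) -> int:
    await asyncio.sleep(0.1)  # stand-in for network latency
    if item == 3:
        raise ValueError("boom")  # logged via logfire.error and filtered out as None
    return item * item

async def demo():
    results = await process_with_concurrency(
        list(range(10)), fake_llm_call, max_concurrent=4, desc="Squaring"
    )
    print(len(results))  # 9: the one failed item is dropped

asyncio.run(demo())
```
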
def save_jsonl(items: List, output_path: str | Path):
    """
    Save items to JSONL file.

    Args:
        items: List of items (must have model_dump_json method or be dicts)
        output_path: Path to output JSONL file
    """
    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    with open(output_path, 'w', encoding='utf-8') as f:
        for item in items:
            if hasattr(item, 'model_dump_json'):
                f.write(item.model_dump_json() + '\n')
            else:
                f.write(json.dumps(item) + '\n')

    logfire.info(f"Saved {len(items)} items to {output_path}")


def load_jsonl(file_path: str | Path, model_class=None) -> List:
    """
    Load items from JSONL file.

    Args:
        file_path: Path to JSONL file
        model_class: Optional Pydantic model class to validate/parse items

    Returns:
        List of items
    """
    items = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            if model_class:
                items.append(model_class.model_validate_json(line))
            else:
                items.append(json.loads(line))
    return items


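
A quick round-trip sketch with one of the Pydantic models from models.py (import path and file path are illustrative):

```python
from synth_data_pipeline.models import QAPair

pairs = [
    QAPair(
        question="What does SWAP Commerce handle?",
        answer="Cross-border e-commerce operations such as returns and shipping.",
        source_text="SWAP Commerce handles cross-border operations.",
        difficulty="basic",
        categories=["features"],
    )
]

save_jsonl(pairs, "output/example_qa.jsonl")             # one JSON object per line
loaded = load_jsonl("output/example_qa.jsonl", QAPair)   # parsed back into QAPair models
assert loaded[0].question == pairs[0].question
```
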
def parse_markdown_chunks(
    file_path: str | Path,
    context_lines: int = 3
) -> List[dict]:
    """
    Parse markdown file and create chunks with context.

    Args:
        file_path: Path to the markdown file
        context_lines: Number of lines before/after to include as context

    Returns:
        List of dicts with 'source_text', 'context_before', 'context_after'
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        lines = f.readlines()

    chunks = []
    i = 0

    while i < len(lines):
        line = lines[i].strip()

        # Skip empty lines and metadata
        if not line or line.startswith('---') or line.startswith('**As of:**'):
            i += 1
            continue

        # Process bullets and significant lines
        if line.startswith('*') or line.startswith('#') or len(line) > 50:
            # Get context before
            context_start = max(0, i - context_lines)
            context_before = ''.join(lines[context_start:i]).strip()

            # Get the main text (current line)
            source_text = line

            # Get context after
            context_end = min(len(lines), i + context_lines + 1)
            context_after = ''.join(lines[i + 1:context_end]).strip()

            chunks.append({
                'source_text': source_text,
                'context_before': context_before,
                'context_after': context_after,
            })

        i += 1

    return chunks


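
A small sketch of what the chunker keeps and skips (the markdown content is invented):

```python
from pathlib import Path

# Headings and bullets become chunks; the "**As of:**" metadata line and
# short plain lines (50 characters or fewer) are skipped.
sample = (
    "# Shipping\n"
    "**As of:** 2024\n"
    "* SWAP Commerce offers tracked international shipping.\n"
    "ok\n"
)
path = Path("/tmp/sample_facts.md")
path.write_text(sample)

for chunk in parse_markdown_chunks(path, context_lines=1):
    print(repr(chunk["source_text"]))
# -> '# Shipping' and the bullet line; the other two lines are filtered out.
```
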
def calculate_overall_score(
    factual_accuracy: float,
    naturalness: float,
    relevance: float,
    diversity: float,
    weights: dict = None
) -> float:
    """
    Calculate overall quality score from individual metrics.

    Args:
        factual_accuracy: Score 0-10
        naturalness: Score 0-10
        relevance: Score 0-10
        diversity: Score 0-10
        weights: Optional custom weights (defaults from config)

    Returns:
        Weighted overall score
    """
    if weights is None:
        from .config import QUALITY_WEIGHTS
        weights = QUALITY_WEIGHTS

    overall = (
        factual_accuracy * weights.get("factual_accuracy", 0.35) +
        naturalness * weights.get("naturalness", 0.25) +
        relevance * weights.get("relevance", 0.25) +
        diversity * weights.get("diversity", 0.15)
    )

    return round(overall, 2)


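
A worked example of the weighting, passing the weights explicitly so the sketch does not depend on config.QUALITY_WEIGHTS (the values mirror the fallback defaults above):

```python
score = calculate_overall_score(
    factual_accuracy=9.0,
    naturalness=8.0,
    relevance=7.0,
    diversity=6.0,
    weights={"factual_accuracy": 0.35, "naturalness": 0.25,
             "relevance": 0.25, "diversity": 0.15},
)
# 9.0*0.35 + 8.0*0.25 + 7.0*0.25 + 6.0*0.15 = 3.15 + 2.00 + 1.75 + 0.90 = 7.80
print(score)  # 7.8
```
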
def print_sample(item, title: str = "SAMPLE"):
    """
    Print a sample item for inspection.

    Args:
        item: Item to print (conversation, Q&A, etc.)
        title: Title for the sample section
    """
    print("\n" + "="*80)
    print(title)
    print("="*80)

    if hasattr(item, 'model_dump'):
        # Pydantic model
        print(json.dumps(item.model_dump(), indent=2))
    elif isinstance(item, dict):
        print(json.dumps(item, indent=2))
    else:
        print(item)

    print("="*80 + "\n")


def print_statistics(scores: List[float], metric_name: str = "Score"):
    """
    Print statistics for a list of scores.

    Args:
        scores: List of numeric scores
        metric_name: Name of the metric being measured
    """
    if not scores:
        print(f"No {metric_name} data available")
        return

    avg = sum(scores) / len(scores)
    min_val = min(scores)
    max_val = max(scores)

    print(f"{metric_name:20s}: avg={avg:5.2f}, min={min_val:5.2f}, max={max_val:5.2f}")
299 synth-data-pipeline/trial_run.py Normal file
@@ -0,0 +1,299 @@
"""
Trial run script to validate the pipeline with a small dataset.

This script:
1. Runs all 4 stages with limited data
2. Validates prompts and logic
3. Prints samples for manual inspection
"""

import asyncio
import json
from pathlib import Path

import logfire

# Import the shared Pydantic models from the pipeline package
from src.synth_data_pipeline.models import QAPair, Conversation, JudgedConversation

# We'll use the actual script functions
import sys
sys.path.append(str(Path(__file__).parent))


TRIAL_QA_CHUNK_LIMIT = 10
TRIAL_FINAL_CONVERSATIONS = "output/trial_conversations_final.jsonl"


async def trial_extract_qa():
    """Trial run of Q&A extraction with 10 chunks."""
    print("\n" + "="*80)
    print("STAGE 1: Q&A EXTRACTION (Trial with 10 chunks)")
    print("="*80)

    from importlib import import_module
    stage1 = import_module('1_extract_qa')

    # Run with trial parameters (uses STAGE_CONFIGS for max_concurrent)
    await stage1.main(
        input_file="data/swap_facts.md",
        output_file="output/trial_qa_pairs.jsonl",
        limit=TRIAL_QA_CHUNK_LIMIT,
    )

    # Load and show results
    qa_pairs = []
    with open("output/trial_qa_pairs.jsonl", 'r') as f:
        for line in f:
            qa_pairs.append(QAPair.model_validate_json(line))

    print(f"\n✓ Generated {len(qa_pairs)} Q&A pairs")

    # Show first 3
    for i, qa in enumerate(qa_pairs[:3], 1):
        print(f"\n--- Q&A Pair {i} ---")
        print(f"Q: {qa.question}")
        print(f"A: {qa.answer[:150]}...")
        print(f"Difficulty: {qa.difficulty}")
        print(f"Categories: {', '.join(qa.categories)}")

    return len(qa_pairs)


async def trial_validate_qa(num_qa_pairs: int):
    """Trial run of Q&A validation."""
    print("\n" + "="*80)
    print("STAGE 2: Q&A VALIDATION (Trial with production configs)")
    print("="*80)

    from importlib import import_module
    stage2 = import_module('2_validate_qa')

    await stage2.main(
        input_file="output/trial_qa_pairs.jsonl",
        output_file="output/trial_qa_validated.jsonl",
    )

    passed_pairs = []
    with open("output/trial_qa_validated_passed.jsonl", 'r') as f:
        for line in f:
            passed_pairs.append(QAPair.model_validate_json(line))

    print(f"\n✓ Validated {num_qa_pairs} Q&A pairs")
    print(f"✓ {len(passed_pairs)} passed validation")

    for i, qa in enumerate(passed_pairs[:3], 1):
        print(f"\n--- Passed Q&A {i} ---")
        print(f"Q: {qa.question}")
        print(f"A: {qa.answer[:150]}...")

    return len(passed_pairs)


async def trial_generate_conversations(num_valid_pairs: int):
    """Trial run of conversation generation using the production configs."""
    print("\n" + "="*80)
    print("STAGE 3: CONVERSATION GENERATION (Trial with production configs)")
    print("="*80)

    from importlib import import_module
    stage3 = import_module('3_generate_conversations')

    # Run with trial parameters (uses STAGE_CONFIGS for max_concurrent)
    await stage3.main(
        qa_file="output/trial_qa_validated_passed.jsonl",
        output_file="output/trial_conversations_raw.jsonl",
    )

    # Load and show results
    conversations = []
    with open("output/trial_conversations_raw.jsonl", 'r') as f:
        for line in f:
            conversations.append(Conversation.model_validate_json(line))

    print(f"\n✓ Valid Q&A pairs available: {num_valid_pairs}")
    print(f"✓ Generated {len(conversations)} conversations")

    # Show first 2
    for i, conv in enumerate(conversations[:2], 1):
        print(f"\n--- Conversation {i} ---")
        print(f"Style: {conv.metadata.style}")
        print(f"Persona: {conv.metadata.user_persona}")
        print(f"Turns: {conv.metadata.num_turns}")
        print("\nMessages:")
        for msg in conv.messages:
            print(f"  {msg.role.upper()}: {msg.content[:100]}...")

    return len(conversations)


async def trial_judge_conversations(num_conversations: int):
    """Trial run of judging all conversations."""
    print("\n" + "="*80)
    print("STAGE 4: JUDGING & SELECTION (Trial with all conversations)")
    print("="*80)

    from importlib import import_module
    stage4 = import_module('4_judge_and_save')

    # Judge all and save top K (uses STAGE_CONFIGS for max_concurrent)
    await stage4.main(
        input_file="output/trial_conversations_raw.jsonl",
        judged_output="output/trial_conversations_judged.jsonl",
        nanochat_output=TRIAL_FINAL_CONVERSATIONS,
    )

    # Load and show results
    judged = []
    with open("output/trial_conversations_judged.jsonl", 'r') as f:
        for line in f:
            judged.append(JudgedConversation.model_validate_json(line))

    print(f"\n✓ Judged {len(judged)} conversations")

    # Show pass/fail statistics (bool-based system)
    total = len(judged)
    passing = sum(1 for jc in judged if jc.judgment.overall_pass)
    factual_pass = sum(1 for jc in judged if jc.judgment.factually_accurate)
    natural_pass = sum(1 for jc in judged if jc.judgment.natural_conversation)
    ontopic_pass = sum(1 for jc in judged if jc.judgment.on_topic)
    value_pass = sum(1 for jc in judged if jc.judgment.adds_value)

    print(f"\nQuality statistics:")
    print(f"  Overall PASS (all 4 criteria): {passing}/{total} ({passing/total*100:.1f}%)")
    print(f"\nIndividual criteria:")
    print(f"  Factually accurate  : {factual_pass}/{total} ({factual_pass/total*100:.1f}%)")
    print(f"  Natural conversation: {natural_pass}/{total} ({natural_pass/total*100:.1f}%)")
    print(f"  On topic            : {ontopic_pass}/{total} ({ontopic_pass/total*100:.1f}%)")
    print(f"  Adds value          : {value_pass}/{total} ({value_pass/total*100:.1f}%)")

    # Show sample passing and failing conversations
    passing_convs = [jc for jc in judged if jc.judgment.overall_pass]
    failing_convs = [jc for jc in judged if not jc.judgment.overall_pass]

    if passing_convs:
        sample = passing_convs[0]
        print(f"\n--- Sample PASSING Conversation ---")
        print(f"Feedback: {sample.judgment.feedback}")

    if failing_convs:
        sample = failing_convs[0]
        print(f"\n--- Sample FAILING Conversation ---")
        print(f"Failed criteria: ", end="")
        failed = []
        if not sample.judgment.factually_accurate: failed.append("factual")
        if not sample.judgment.natural_conversation: failed.append("natural")
        if not sample.judgment.on_topic: failed.append("on-topic")
        if not sample.judgment.adds_value: failed.append("adds-value")
        print(", ".join(failed))
        print(f"Feedback: {sample.judgment.feedback}")
        if sample.judgment.issues:
            print(f"Issues: {', '.join(sample.judgment.issues)}")

    return len(judged)


def validate_output_formats():
    """Validate that output files match expected formats."""
    print("\n" + "="*80)
    print("VALIDATION: Checking output formats")
    print("="*80)

    checks = {
        "Q&A pairs JSONL": "output/trial_qa_pairs.jsonl",
        "Raw conversations JSONL": "output/trial_conversations_raw.jsonl",
        "Judged conversations JSONL": "output/trial_conversations_judged.jsonl",
        "NanoChat format JSONL": TRIAL_FINAL_CONVERSATIONS,
    }

    all_valid = True

    for name, path in checks.items():
        if not Path(path).exists():
            print(f"✗ {name}: FILE NOT FOUND")
            all_valid = False
            continue

        try:
            with open(path, 'r') as f:
                lines = f.readlines()
                if not lines:
                    print(f"✗ {name}: EMPTY FILE")
                    all_valid = False
                    continue

                # Try parsing first line as JSON
                json.loads(lines[0])
                print(f"✓ {name}: Valid ({len(lines)} entries)")

        except Exception as e:
            print(f"✗ {name}: {e}")
            all_valid = False

    return all_valid


async def main():
    """Run the complete trial pipeline."""
    print("="*80)
    print("SYNTHETIC DATA PIPELINE - TRIAL RUN")
    print("="*80)
    print("\nThis will:")
    print(f"  1. Extract up to {TRIAL_QA_CHUNK_LIMIT} chunks worth of Q&A pairs")
    print("  2. Validate those Q&A pairs with the production agent")
    print("  3. Generate conversations using the same production configuration")
    print("  4. Judge all conversations and select the configured top K")
    print()

    # Configure logfire without sending to cloud (for trial runs)
    logfire.configure(send_to_logfire=False)

    try:
        # Stage 1: Extract Q&A
        num_qa = await trial_extract_qa()

        # Stage 2: Validate Q&A
        num_valid = await trial_validate_qa(num_qa)

        # Stage 3: Generate conversations
        num_conv = await trial_generate_conversations(num_valid)

        # Stage 4: Judge and select
        num_judged = await trial_judge_conversations(num_conv)

        # Validate formats
        all_valid = validate_output_formats()

        # Final summary
        print("\n" + "="*80)
        print("TRIAL RUN COMPLETE")
        print("="*80)
        print(f"✓ Q&A pairs extracted: {num_qa}")
        print(f"✓ Q&A pairs passed validation: {num_valid}")
        print(f"✓ Conversations generated: {num_conv}")
        print(f"✓ Conversations judged: {num_judged}")
        print(f"✓ Output formats valid: {'YES' if all_valid else 'NO'}")
        print()

        if all_valid:
            print("🎉 Trial run successful! You can now run the full pipeline.")
            print()
            print("Next steps:")
            print("  1. Review the trial outputs in output/ directory")
            print("  2. Adjust prompts if needed")
            print("  3. Run full pipeline:")
            print("     - uv run 1_extract_qa.py")
            print("     - uv run 2_validate_qa.py")
            print("     - uv run 3_generate_conversations.py")
            print("     - uv run 4_judge_and_save.py")
        else:
            print("⚠️  Some validations failed. Please review the errors above.")

    except Exception as e:
        print(f"\n❌ Trial run failed: {e}")
        import traceback
        traceback.print_exc()


if __name__ == "__main__":
    asyncio.run(main())
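
If only a single stage needs a sanity check, the trial coroutines can also be run individually; a sketch, assuming it is invoked from the synth-data-pipeline directory so the relative data/ and output/ paths resolve:

```python
import asyncio

import trial_run

# Stage 1 only: extract Q&A pairs from the first TRIAL_QA_CHUNK_LIMIT chunks.
asyncio.run(trial_run.trial_extract_qa())
```
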
3143 synth-data-pipeline/uv.lock Normal file
File diff suppressed because it is too large
92 uv.lock
@@ -371,7 +371,7 @@ name = "exceptiongroup"
version = "1.3.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "typing-extensions", marker = "python_full_version < '3.11' or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
|
||||
{ name = "typing-extensions", marker = "python_full_version < '3.12' or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/0b/9f/a65090624ecf468cdca03533906e7c69ed7588582240cfe7cc9e770b50eb/exceptiongroup-1.3.0.tar.gz", hash = "sha256:b241f5885f560bc56a59ee63ca4c6a8bfa46ae4ad651af316d4e81817bb9fd88", size = 29749, upload-time = "2025-05-10T17:42:51.123Z" }
|
||||
wheels = [
|
@@ -1107,7 +1107,7 @@ name = "nvidia-cudnn-cu12"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "nvidia-cublas-cu12", marker = "(sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
|
||||
{ name = "nvidia-cublas-cu12", marker = "extra == 'extra-8-nanochat-gpu'" },
|
||||
]
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/fa/41/e79269ce215c857c935fd86bcfe91a451a584dfc27f1e068f568b9ad1ab7/nvidia_cudnn_cu12-9.10.2.21-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:c9132cc3f8958447b4910a1720036d9eff5928cc3179b0a51fb6d167c6cc87d8", size = 705026878, upload-time = "2025-06-06T21:52:51.348Z" },
@@ -1120,7 +1120,7 @@ name = "nvidia-cufft-cu12"
version = "11.3.3.83"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "nvidia-nvjitlink-cu12", marker = "(sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
|
||||
{ name = "nvidia-nvjitlink-cu12", marker = "extra == 'extra-8-nanochat-gpu'" },
|
||||
]
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/60/bc/7771846d3a0272026c416fbb7e5f4c1f146d6d80704534d0b187dd6f4800/nvidia_cufft_cu12-11.3.3.83-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:848ef7224d6305cdb2a4df928759dca7b1201874787083b6e7550dd6765ce69a", size = 193109211, upload-time = "2025-03-07T01:44:56.873Z" },
@@ -1152,9 +1152,9 @@ name = "nvidia-cusolver-cu12"
version = "11.7.3.90"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "nvidia-cublas-cu12", marker = "(sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
|
||||
{ name = "nvidia-cusparse-cu12", marker = "(sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
|
||||
{ name = "nvidia-nvjitlink-cu12", marker = "(sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
|
||||
{ name = "nvidia-cublas-cu12", marker = "extra == 'extra-8-nanochat-gpu'" },
|
||||
{ name = "nvidia-cusparse-cu12", marker = "extra == 'extra-8-nanochat-gpu'" },
|
||||
{ name = "nvidia-nvjitlink-cu12", marker = "extra == 'extra-8-nanochat-gpu'" },
|
||||
]
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/c8/32/f7cd6ce8a7690544d084ea21c26e910a97e077c9b7f07bf5de623ee19981/nvidia_cusolver_cu12-11.7.3.90-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:db9ed69dbef9715071232caa9b69c52ac7de3a95773c2db65bdba85916e4e5c0", size = 267229841, upload-time = "2025-03-07T01:46:54.356Z" },
@@ -1167,7 +1167,7 @@ name = "nvidia-cusparse-cu12"
version = "12.5.8.93"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "nvidia-nvjitlink-cu12", marker = "(sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
|
||||
{ name = "nvidia-nvjitlink-cu12", marker = "extra == 'extra-8-nanochat-gpu'" },
|
||||
]
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/bc/f7/cd777c4109681367721b00a106f491e0d0d15cfa1fd59672ce580ce42a97/nvidia_cusparse_cu12-12.5.8.93-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:9b6c161cb130be1a07a27ea6923df8141f3c295852f4b260c65f18f3e0a091dc", size = 288117129, upload-time = "2025-03-07T01:47:40.407Z" },
@@ -1941,41 +1941,51 @@ wheels = [

[[package]]
|
||||
name = "tomli"
|
||||
version = "2.2.1"
|
||||
version = "2.3.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/18/87/302344fed471e44a87289cf4967697d07e532f2421fdaf868a303cbae4ff/tomli-2.2.1.tar.gz", hash = "sha256:cd45e1dc79c835ce60f7404ec8119f2eb06d38b1deba146f07ced3bbc44505ff", size = 17175, upload-time = "2024-11-27T22:38:36.873Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/52/ed/3f73f72945444548f33eba9a87fc7a6e969915e7b1acc8260b30e1f76a2f/tomli-2.3.0.tar.gz", hash = "sha256:64be704a875d2a59753d80ee8a533c3fe183e3f06807ff7dc2232938ccb01549", size = 17392, upload-time = "2025-10-08T22:01:47.119Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/43/ca/75707e6efa2b37c77dadb324ae7d9571cb424e61ea73fad7c56c2d14527f/tomli-2.2.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:678e4fa69e4575eb77d103de3df8a895e1591b48e740211bd1067378c69e8249", size = 131077, upload-time = "2024-11-27T22:37:54.956Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/c7/16/51ae563a8615d472fdbffc43a3f3d46588c264ac4f024f63f01283becfbb/tomli-2.2.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:023aa114dd824ade0100497eb2318602af309e5a55595f76b626d6d9f3b7b0a6", size = 123429, upload-time = "2024-11-27T22:37:56.698Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/f1/dd/4f6cd1e7b160041db83c694abc78e100473c15d54620083dbd5aae7b990e/tomli-2.2.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ece47d672db52ac607a3d9599a9d48dcb2f2f735c6c2d1f34130085bb12b112a", size = 226067, upload-time = "2024-11-27T22:37:57.63Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/a9/6b/c54ede5dc70d648cc6361eaf429304b02f2871a345bbdd51e993d6cdf550/tomli-2.2.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6972ca9c9cc9f0acaa56a8ca1ff51e7af152a9f87fb64623e31d5c83700080ee", size = 236030, upload-time = "2024-11-27T22:37:59.344Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/1f/47/999514fa49cfaf7a92c805a86c3c43f4215621855d151b61c602abb38091/tomli-2.2.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c954d2250168d28797dd4e3ac5cf812a406cd5a92674ee4c8f123c889786aa8e", size = 240898, upload-time = "2024-11-27T22:38:00.429Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/73/41/0a01279a7ae09ee1573b423318e7934674ce06eb33f50936655071d81a24/tomli-2.2.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:8dd28b3e155b80f4d54beb40a441d366adcfe740969820caf156c019fb5c7ec4", size = 229894, upload-time = "2024-11-27T22:38:02.094Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/55/18/5d8bc5b0a0362311ce4d18830a5d28943667599a60d20118074ea1b01bb7/tomli-2.2.1-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:e59e304978767a54663af13c07b3d1af22ddee3bb2fb0618ca1593e4f593a106", size = 245319, upload-time = "2024-11-27T22:38:03.206Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/92/a3/7ade0576d17f3cdf5ff44d61390d4b3febb8a9fc2b480c75c47ea048c646/tomli-2.2.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:33580bccab0338d00994d7f16f4c4ec25b776af3ffaac1ed74e0b3fc95e885a8", size = 238273, upload-time = "2024-11-27T22:38:04.217Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/72/6f/fa64ef058ac1446a1e51110c375339b3ec6be245af9d14c87c4a6412dd32/tomli-2.2.1-cp311-cp311-win32.whl", hash = "sha256:465af0e0875402f1d226519c9904f37254b3045fc5084697cefb9bdde1ff99ff", size = 98310, upload-time = "2024-11-27T22:38:05.908Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/6a/1c/4a2dcde4a51b81be3530565e92eda625d94dafb46dbeb15069df4caffc34/tomli-2.2.1-cp311-cp311-win_amd64.whl", hash = "sha256:2d0f2fdd22b02c6d81637a3c95f8cd77f995846af7414c5c4b8d0545afa1bc4b", size = 108309, upload-time = "2024-11-27T22:38:06.812Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/52/e1/f8af4c2fcde17500422858155aeb0d7e93477a0d59a98e56cbfe75070fd0/tomli-2.2.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:4a8f6e44de52d5e6c657c9fe83b562f5f4256d8ebbfe4ff922c495620a7f6cea", size = 132762, upload-time = "2024-11-27T22:38:07.731Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/03/b8/152c68bb84fc00396b83e7bbddd5ec0bd3dd409db4195e2a9b3e398ad2e3/tomli-2.2.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:8d57ca8095a641b8237d5b079147646153d22552f1c637fd3ba7f4b0b29167a8", size = 123453, upload-time = "2024-11-27T22:38:09.384Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/c8/d6/fc9267af9166f79ac528ff7e8c55c8181ded34eb4b0e93daa767b8841573/tomli-2.2.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4e340144ad7ae1533cb897d406382b4b6fede8890a03738ff1683af800d54192", size = 233486, upload-time = "2024-11-27T22:38:10.329Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/5c/51/51c3f2884d7bab89af25f678447ea7d297b53b5a3b5730a7cb2ef6069f07/tomli-2.2.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:db2b95f9de79181805df90bedc5a5ab4c165e6ec3fe99f970d0e302f384ad222", size = 242349, upload-time = "2024-11-27T22:38:11.443Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ab/df/bfa89627d13a5cc22402e441e8a931ef2108403db390ff3345c05253935e/tomli-2.2.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:40741994320b232529c802f8bc86da4e1aa9f413db394617b9a256ae0f9a7f77", size = 252159, upload-time = "2024-11-27T22:38:13.099Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/9e/6e/fa2b916dced65763a5168c6ccb91066f7639bdc88b48adda990db10c8c0b/tomli-2.2.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:400e720fe168c0f8521520190686ef8ef033fb19fc493da09779e592861b78c6", size = 237243, upload-time = "2024-11-27T22:38:14.766Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/b4/04/885d3b1f650e1153cbb93a6a9782c58a972b94ea4483ae4ac5cedd5e4a09/tomli-2.2.1-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:02abe224de6ae62c19f090f68da4e27b10af2b93213d36cf44e6e1c5abd19fdd", size = 259645, upload-time = "2024-11-27T22:38:15.843Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/9c/de/6b432d66e986e501586da298e28ebeefd3edc2c780f3ad73d22566034239/tomli-2.2.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:b82ebccc8c8a36f2094e969560a1b836758481f3dc360ce9a3277c65f374285e", size = 244584, upload-time = "2024-11-27T22:38:17.645Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/1c/9a/47c0449b98e6e7d1be6cbac02f93dd79003234ddc4aaab6ba07a9a7482e2/tomli-2.2.1-cp312-cp312-win32.whl", hash = "sha256:889f80ef92701b9dbb224e49ec87c645ce5df3fa2cc548664eb8a25e03127a98", size = 98875, upload-time = "2024-11-27T22:38:19.159Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ef/60/9b9638f081c6f1261e2688bd487625cd1e660d0a85bd469e91d8db969734/tomli-2.2.1-cp312-cp312-win_amd64.whl", hash = "sha256:7fc04e92e1d624a4a63c76474610238576942d6b8950a2d7f908a340494e67e4", size = 109418, upload-time = "2024-11-27T22:38:20.064Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/04/90/2ee5f2e0362cb8a0b6499dc44f4d7d48f8fff06d28ba46e6f1eaa61a1388/tomli-2.2.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:f4039b9cbc3048b2416cc57ab3bda989a6fcf9b36cf8937f01a6e731b64f80d7", size = 132708, upload-time = "2024-11-27T22:38:21.659Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/c0/ec/46b4108816de6b385141f082ba99e315501ccd0a2ea23db4a100dd3990ea/tomli-2.2.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:286f0ca2ffeeb5b9bd4fcc8d6c330534323ec51b2f52da063b11c502da16f30c", size = 123582, upload-time = "2024-11-27T22:38:22.693Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/a0/bd/b470466d0137b37b68d24556c38a0cc819e8febe392d5b199dcd7f578365/tomli-2.2.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a92ef1a44547e894e2a17d24e7557a5e85a9e1d0048b0b5e7541f76c5032cb13", size = 232543, upload-time = "2024-11-27T22:38:24.367Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/d9/e5/82e80ff3b751373f7cead2815bcbe2d51c895b3c990686741a8e56ec42ab/tomli-2.2.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9316dc65bed1684c9a98ee68759ceaed29d229e985297003e494aa825ebb0281", size = 241691, upload-time = "2024-11-27T22:38:26.081Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/05/7e/2a110bc2713557d6a1bfb06af23dd01e7dde52b6ee7dadc589868f9abfac/tomli-2.2.1-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e85e99945e688e32d5a35c1ff38ed0b3f41f43fad8df0bdf79f72b2ba7bc5272", size = 251170, upload-time = "2024-11-27T22:38:27.921Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/64/7b/22d713946efe00e0adbcdfd6d1aa119ae03fd0b60ebed51ebb3fa9f5a2e5/tomli-2.2.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:ac065718db92ca818f8d6141b5f66369833d4a80a9d74435a268c52bdfa73140", size = 236530, upload-time = "2024-11-27T22:38:29.591Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/38/31/3a76f67da4b0cf37b742ca76beaf819dca0ebef26d78fc794a576e08accf/tomli-2.2.1-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:d920f33822747519673ee656a4b6ac33e382eca9d331c87770faa3eef562aeb2", size = 258666, upload-time = "2024-11-27T22:38:30.639Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/07/10/5af1293da642aded87e8a988753945d0cf7e00a9452d3911dd3bb354c9e2/tomli-2.2.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:a198f10c4d1b1375d7687bc25294306e551bf1abfa4eace6650070a5c1ae2744", size = 243954, upload-time = "2024-11-27T22:38:31.702Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/5b/b9/1ed31d167be802da0fc95020d04cd27b7d7065cc6fbefdd2f9186f60d7bd/tomli-2.2.1-cp313-cp313-win32.whl", hash = "sha256:d3f5614314d758649ab2ab3a62d4f2004c825922f9e370b29416484086b264ec", size = 98724, upload-time = "2024-11-27T22:38:32.837Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/c7/32/b0963458706accd9afcfeb867c0f9175a741bf7b19cd424230714d722198/tomli-2.2.1-cp313-cp313-win_amd64.whl", hash = "sha256:a38aa0308e754b0e3c67e344754dff64999ff9b513e691d0e786265c93583c69", size = 109383, upload-time = "2024-11-27T22:38:34.455Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/6e/c2/61d3e0f47e2b74ef40a68b9e6ad5984f6241a942f7cd3bbfbdbd03861ea9/tomli-2.2.1-py3-none-any.whl", hash = "sha256:cb55c73c5f4408779d0cf3eef9f762b9c9f147a77de7b258bef0a5628adc85cc", size = 14257, upload-time = "2024-11-27T22:38:35.385Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/b3/2e/299f62b401438d5fe1624119c723f5d877acc86a4c2492da405626665f12/tomli-2.3.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:88bd15eb972f3664f5ed4b57c1634a97153b4bac4479dcb6a495f41921eb7f45", size = 153236, upload-time = "2025-10-08T22:01:00.137Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/86/7f/d8fffe6a7aefdb61bced88fcb5e280cfd71e08939da5894161bd71bea022/tomli-2.3.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:883b1c0d6398a6a9d29b508c331fa56adbcdff647f6ace4dfca0f50e90dfd0ba", size = 148084, upload-time = "2025-10-08T22:01:01.63Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/47/5c/24935fb6a2ee63e86d80e4d3b58b222dafaf438c416752c8b58537c8b89a/tomli-2.3.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d1381caf13ab9f300e30dd8feadb3de072aeb86f1d34a8569453ff32a7dea4bf", size = 234832, upload-time = "2025-10-08T22:01:02.543Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/89/da/75dfd804fc11e6612846758a23f13271b76d577e299592b4371a4ca4cd09/tomli-2.3.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a0e285d2649b78c0d9027570d4da3425bdb49830a6156121360b3f8511ea3441", size = 242052, upload-time = "2025-10-08T22:01:03.836Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/70/8c/f48ac899f7b3ca7eb13af73bacbc93aec37f9c954df3c08ad96991c8c373/tomli-2.3.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:0a154a9ae14bfcf5d8917a59b51ffd5a3ac1fd149b71b47a3a104ca4edcfa845", size = 239555, upload-time = "2025-10-08T22:01:04.834Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ba/28/72f8afd73f1d0e7829bfc093f4cb98ce0a40ffc0cc997009ee1ed94ba705/tomli-2.3.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:74bf8464ff93e413514fefd2be591c3b0b23231a77f901db1eb30d6f712fc42c", size = 245128, upload-time = "2025-10-08T22:01:05.84Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/b6/eb/a7679c8ac85208706d27436e8d421dfa39d4c914dcf5fa8083a9305f58d9/tomli-2.3.0-cp311-cp311-win32.whl", hash = "sha256:00b5f5d95bbfc7d12f91ad8c593a1659b6387b43f054104cda404be6bda62456", size = 96445, upload-time = "2025-10-08T22:01:06.896Z" },
{ url = "https://files.pythonhosted.org/packages/0a/fe/3d3420c4cb1ad9cb462fb52967080575f15898da97e21cb6f1361d505383/tomli-2.3.0-cp311-cp311-win_amd64.whl", hash = "sha256:4dc4ce8483a5d429ab602f111a93a6ab1ed425eae3122032db7e9acf449451be", size = 107165, upload-time = "2025-10-08T22:01:08.107Z" },
{ url = "https://files.pythonhosted.org/packages/ff/b7/40f36368fcabc518bb11c8f06379a0fd631985046c038aca08c6d6a43c6e/tomli-2.3.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:d7d86942e56ded512a594786a5ba0a5e521d02529b3826e7761a05138341a2ac", size = 154891, upload-time = "2025-10-08T22:01:09.082Z" },
{ url = "https://files.pythonhosted.org/packages/f9/3f/d9dd692199e3b3aab2e4e4dd948abd0f790d9ded8cd10cbaae276a898434/tomli-2.3.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:73ee0b47d4dad1c5e996e3cd33b8a76a50167ae5f96a2607cbe8cc773506ab22", size = 148796, upload-time = "2025-10-08T22:01:10.266Z" },
{ url = "https://files.pythonhosted.org/packages/60/83/59bff4996c2cf9f9387a0f5a3394629c7efa5ef16142076a23a90f1955fa/tomli-2.3.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:792262b94d5d0a466afb5bc63c7daa9d75520110971ee269152083270998316f", size = 242121, upload-time = "2025-10-08T22:01:11.332Z" },
{ url = "https://files.pythonhosted.org/packages/45/e5/7c5119ff39de8693d6baab6c0b6dcb556d192c165596e9fc231ea1052041/tomli-2.3.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:4f195fe57ecceac95a66a75ac24d9d5fbc98ef0962e09b2eddec5d39375aae52", size = 250070, upload-time = "2025-10-08T22:01:12.498Z" },
{ url = "https://files.pythonhosted.org/packages/45/12/ad5126d3a278f27e6701abde51d342aa78d06e27ce2bb596a01f7709a5a2/tomli-2.3.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:e31d432427dcbf4d86958c184b9bfd1e96b5b71f8eb17e6d02531f434fd335b8", size = 245859, upload-time = "2025-10-08T22:01:13.551Z" },
{ url = "https://files.pythonhosted.org/packages/fb/a1/4d6865da6a71c603cfe6ad0e6556c73c76548557a8d658f9e3b142df245f/tomli-2.3.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:7b0882799624980785240ab732537fcfc372601015c00f7fc367c55308c186f6", size = 250296, upload-time = "2025-10-08T22:01:14.614Z" },
{ url = "https://files.pythonhosted.org/packages/a0/b7/a7a7042715d55c9ba6e8b196d65d2cb662578b4d8cd17d882d45322b0d78/tomli-2.3.0-cp312-cp312-win32.whl", hash = "sha256:ff72b71b5d10d22ecb084d345fc26f42b5143c5533db5e2eaba7d2d335358876", size = 97124, upload-time = "2025-10-08T22:01:15.629Z" },
{ url = "https://files.pythonhosted.org/packages/06/1e/f22f100db15a68b520664eb3328fb0ae4e90530887928558112c8d1f4515/tomli-2.3.0-cp312-cp312-win_amd64.whl", hash = "sha256:1cb4ed918939151a03f33d4242ccd0aa5f11b3547d0cf30f7c74a408a5b99878", size = 107698, upload-time = "2025-10-08T22:01:16.51Z" },
{ url = "https://files.pythonhosted.org/packages/89/48/06ee6eabe4fdd9ecd48bf488f4ac783844fd777f547b8d1b61c11939974e/tomli-2.3.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:5192f562738228945d7b13d4930baffda67b69425a7f0da96d360b0a3888136b", size = 154819, upload-time = "2025-10-08T22:01:17.964Z" },
{ url = "https://files.pythonhosted.org/packages/f1/01/88793757d54d8937015c75dcdfb673c65471945f6be98e6a0410fba167ed/tomli-2.3.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:be71c93a63d738597996be9528f4abe628d1adf5e6eb11607bc8fe1a510b5dae", size = 148766, upload-time = "2025-10-08T22:01:18.959Z" },
{ url = "https://files.pythonhosted.org/packages/42/17/5e2c956f0144b812e7e107f94f1cc54af734eb17b5191c0bbfb72de5e93e/tomli-2.3.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c4665508bcbac83a31ff8ab08f424b665200c0e1e645d2bd9ab3d3e557b6185b", size = 240771, upload-time = "2025-10-08T22:01:20.106Z" },
{ url = "https://files.pythonhosted.org/packages/d5/f4/0fbd014909748706c01d16824eadb0307115f9562a15cbb012cd9b3512c5/tomli-2.3.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:4021923f97266babc6ccab9f5068642a0095faa0a51a246a6a02fccbb3514eaf", size = 248586, upload-time = "2025-10-08T22:01:21.164Z" },
{ url = "https://files.pythonhosted.org/packages/30/77/fed85e114bde5e81ecf9bc5da0cc69f2914b38f4708c80ae67d0c10180c5/tomli-2.3.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:a4ea38c40145a357d513bffad0ed869f13c1773716cf71ccaa83b0fa0cc4e42f", size = 244792, upload-time = "2025-10-08T22:01:22.417Z" },
{ url = "https://files.pythonhosted.org/packages/55/92/afed3d497f7c186dc71e6ee6d4fcb0acfa5f7d0a1a2878f8beae379ae0cc/tomli-2.3.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:ad805ea85eda330dbad64c7ea7a4556259665bdf9d2672f5dccc740eb9d3ca05", size = 248909, upload-time = "2025-10-08T22:01:23.859Z" },
{ url = "https://files.pythonhosted.org/packages/f8/84/ef50c51b5a9472e7265ce1ffc7f24cd4023d289e109f669bdb1553f6a7c2/tomli-2.3.0-cp313-cp313-win32.whl", hash = "sha256:97d5eec30149fd3294270e889b4234023f2c69747e555a27bd708828353ab606", size = 96946, upload-time = "2025-10-08T22:01:24.893Z" },
{ url = "https://files.pythonhosted.org/packages/b2/b7/718cd1da0884f281f95ccfa3a6cc572d30053cba64603f79d431d3c9b61b/tomli-2.3.0-cp313-cp313-win_amd64.whl", hash = "sha256:0c95ca56fbe89e065c6ead5b593ee64b84a26fca063b5d71a1122bf26e533999", size = 107705, upload-time = "2025-10-08T22:01:26.153Z" },
{ url = "https://files.pythonhosted.org/packages/19/94/aeafa14a52e16163008060506fcb6aa1949d13548d13752171a755c65611/tomli-2.3.0-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:cebc6fe843e0733ee827a282aca4999b596241195f43b4cc371d64fc6639da9e", size = 154244, upload-time = "2025-10-08T22:01:27.06Z" },
{ url = "https://files.pythonhosted.org/packages/db/e4/1e58409aa78eefa47ccd19779fc6f36787edbe7d4cd330eeeedb33a4515b/tomli-2.3.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:4c2ef0244c75aba9355561272009d934953817c49f47d768070c3c94355c2aa3", size = 148637, upload-time = "2025-10-08T22:01:28.059Z" },
{ url = "https://files.pythonhosted.org/packages/26/b6/d1eccb62f665e44359226811064596dd6a366ea1f985839c566cd61525ae/tomli-2.3.0-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c22a8bf253bacc0cf11f35ad9808b6cb75ada2631c2d97c971122583b129afbc", size = 241925, upload-time = "2025-10-08T22:01:29.066Z" },
{ url = "https://files.pythonhosted.org/packages/70/91/7cdab9a03e6d3d2bb11beae108da5bdc1c34bdeb06e21163482544ddcc90/tomli-2.3.0-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0eea8cc5c5e9f89c9b90c4896a8deefc74f518db5927d0e0e8d4a80953d774d0", size = 249045, upload-time = "2025-10-08T22:01:31.98Z" },
{ url = "https://files.pythonhosted.org/packages/15/1b/8c26874ed1f6e4f1fcfeb868db8a794cbe9f227299402db58cfcc858766c/tomli-2.3.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:b74a0e59ec5d15127acdabd75ea17726ac4c5178ae51b85bfe39c4f8a278e879", size = 245835, upload-time = "2025-10-08T22:01:32.989Z" },
{ url = "https://files.pythonhosted.org/packages/fd/42/8e3c6a9a4b1a1360c1a2a39f0b972cef2cc9ebd56025168c4137192a9321/tomli-2.3.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:b5870b50c9db823c595983571d1296a6ff3e1b88f734a4c8f6fc6188397de005", size = 253109, upload-time = "2025-10-08T22:01:34.052Z" },
{ url = "https://files.pythonhosted.org/packages/22/0c/b4da635000a71b5f80130937eeac12e686eefb376b8dee113b4a582bba42/tomli-2.3.0-cp314-cp314-win32.whl", hash = "sha256:feb0dacc61170ed7ab602d3d972a58f14ee3ee60494292d384649a3dc38ef463", size = 97930, upload-time = "2025-10-08T22:01:35.082Z" },
{ url = "https://files.pythonhosted.org/packages/b9/74/cb1abc870a418ae99cd5c9547d6bce30701a954e0e721821df483ef7223c/tomli-2.3.0-cp314-cp314-win_amd64.whl", hash = "sha256:b273fcbd7fc64dc3600c098e39136522650c49bca95df2d11cf3b626422392c8", size = 107964, upload-time = "2025-10-08T22:01:36.057Z" },
{ url = "https://files.pythonhosted.org/packages/54/78/5c46fff6432a712af9f792944f4fcd7067d8823157949f4e40c56b8b3c83/tomli-2.3.0-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:940d56ee0410fa17ee1f12b817b37a4d4e4dc4d27340863cc67236c74f582e77", size = 163065, upload-time = "2025-10-08T22:01:37.27Z" },
{ url = "https://files.pythonhosted.org/packages/39/67/f85d9bd23182f45eca8939cd2bc7050e1f90c41f4a2ecbbd5963a1d1c486/tomli-2.3.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:f85209946d1fe94416debbb88d00eb92ce9cd5266775424ff81bc959e001acaf", size = 159088, upload-time = "2025-10-08T22:01:38.235Z" },
{ url = "https://files.pythonhosted.org/packages/26/5a/4b546a0405b9cc0659b399f12b6adb750757baf04250b148d3c5059fc4eb/tomli-2.3.0-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:a56212bdcce682e56b0aaf79e869ba5d15a6163f88d5451cbde388d48b13f530", size = 268193, upload-time = "2025-10-08T22:01:39.712Z" },
{ url = "https://files.pythonhosted.org/packages/42/4f/2c12a72ae22cf7b59a7fe75b3465b7aba40ea9145d026ba41cb382075b0e/tomli-2.3.0-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c5f3ffd1e098dfc032d4d3af5c0ac64f6d286d98bc148698356847b80fa4de1b", size = 275488, upload-time = "2025-10-08T22:01:40.773Z" },
{ url = "https://files.pythonhosted.org/packages/92/04/a038d65dbe160c3aa5a624e93ad98111090f6804027d474ba9c37c8ae186/tomli-2.3.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:5e01decd096b1530d97d5d85cb4dff4af2d8347bd35686654a004f8dea20fc67", size = 272669, upload-time = "2025-10-08T22:01:41.824Z" },
{ url = "https://files.pythonhosted.org/packages/be/2f/8b7c60a9d1612a7cbc39ffcca4f21a73bf368a80fc25bccf8253e2563267/tomli-2.3.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:8a35dd0e643bb2610f156cca8db95d213a90015c11fee76c946aa62b7ae7e02f", size = 279709, upload-time = "2025-10-08T22:01:43.177Z" },
{ url = "https://files.pythonhosted.org/packages/7e/46/cc36c679f09f27ded940281c38607716c86cf8ba4a518d524e349c8b4874/tomli-2.3.0-cp314-cp314t-win32.whl", hash = "sha256:a1f7f282fe248311650081faafa5f4732bdbfef5d45fe3f2e702fbc6f2d496e0", size = 107563, upload-time = "2025-10-08T22:01:44.233Z" },
{ url = "https://files.pythonhosted.org/packages/84/ff/426ca8683cf7b753614480484f6437f568fd2fda2edbdf57a2d3d8b27a0b/tomli-2.3.0-cp314-cp314t-win_amd64.whl", hash = "sha256:70a251f8d4ba2d9ac2542eecf008b3c8a9fc5c3f9f02c56a9d7952612be2fdba", size = 119756, upload-time = "2025-10-08T22:01:45.234Z" },
{ url = "https://files.pythonhosted.org/packages/77/b8/0135fadc89e73be292b473cb820b4f5a08197779206b33191e801feeae40/tomli-2.3.0-py3-none-any.whl", hash = "sha256:e95b1af3c5b07d9e643909b5abbec77cd9f1217e6d0bca72b0234736b9fb1f1b", size = 14408, upload-time = "2025-10-08T22:01:46.04Z" },
]
[[package]]
@@ -2176,7 +2186,7 @@ name = "triton"
version = "3.4.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "setuptools", marker = "(sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "setuptools", marker = "extra == 'extra-8-nanochat-gpu'" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/62/ee/0ee5f64a87eeda19bbad9bc54ae5ca5b98186ed00055281fd40fb4beb10e/triton-3.4.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7ff2785de9bc02f500e085420273bb5cc9c9bb767584a4aa28d6e360cec70128", size = 155430069, upload-time = "2025-07-30T19:58:21.715Z" },