DPO Dataset Generation
Learn how to generate high-quality DPO datasets for model alignment.
Overview
In this tutorial, you will learn:
How DPO (Direct Preference Optimization) datasets are structured
Different strategies for constructing preference pairs
How to filter and validate DPO data quality
Best practices for using generated data in training
Prerequisites
Completed red-team episodes (see Red Team Pipeline)
Understanding of preference learning concepts
Familiarity with DPO training requirements
DPO Dataset Structure
A DPO dataset consists of preference pairs:
{
  "prompt": "User query or instruction",
  "chosen": "Preferred response (specification-compliant)",
  "rejected": "Dispreferred response (violates specification)"
}
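A record in this shape can be checked mechanically before it reaches a trainer. The sketch below is a minimal validator, assuming that all three fields must be non-empty strings and that a pair with identical chosen and rejected text carries no preference signal. The name `validate_pair` is hypothetical and not part of specalign.

```python
from typing import Dict, List

REQUIRED_FIELDS = ("prompt", "chosen", "rejected")


def validate_pair(pair: Dict) -> List[str]:
    """Return a list of problems found in one preference pair (empty if valid)."""
    problems = []
    for field in REQUIRED_FIELDS:
        value = pair.get(field)
        if not isinstance(value, str) or not value.strip():
            problems.append(f"missing or empty field: {field}")
    # Identical responses give the trainer nothing to prefer (assumed rule).
    if not problems and pair["chosen"].strip() == pair["rejected"].strip():
        problems.append("chosen and rejected are identical")
    return problems
```

Calling `validate_pair` on each record and dropping those with a non-empty result is one way to catch malformed pairs early; extra keys such as `metadata` are ignored.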
Complete Example
"""
DPO Dataset Generation Example

This script demonstrates DPO dataset construction strategies,
including quality filtering, preference validation, and deduplication.
"""
import json
from pathlib import Path
from typing import List, Dict

from specalign.core import RedTeamOrchestrator, create_judges
from specalign.config import load_config
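

class DPODatasetBuilder:
    """Builder for constructing high-quality DPO datasets."""

    def __init__(self, config: dict):
        self.config = config
        # create_judges returns the safety judge and the quality judge, in that order
        self.safety_judge, self.quality_judge = create_judges(config)
        # Chosen responses scoring below this threshold are dropped
        self.min_quality_score = 0.7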
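
    def load_episodes(self, episodes_path: str) -> List[Dict]:
        """Load completed red-team episodes and keep the successful attacks."""
        print("Step 1: Loading episodes...")
        episodes = []
        with open(episodes_path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line:  # Skip blank lines in the JSONL file
                    continue
                episodes.append(json.loads(line))
        # Only successful attacks contain a violating response to use as "rejected"
        successful = [e for e in episodes if e.get('success')]
        print(f"✓ Loaded {len(episodes)} episodes")
        print(f"✓ Successful attacks: {len(successful)}")
        return successful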

    def construct_basic_pairs(self, episodes: List[Dict]) -> List[Dict]:
        """
        Basic pair construction: use the violating response as rejected.

        Strategy: Pair the attack prompt with the violating defender
        response (rejected) and the compliant response stored with the
        episode (chosen). Episodes without a compliant response are skipped.
        """
        print("\nStep 2: Constructing basic pairs...")
        pairs = []
        for episode in episodes:
            attack_round = episode['attack_round']
            pair = {
                'prompt': attack_round['attacker_prompt'],
                'rejected': attack_round['defender_response'],
                'chosen': episode.get('compliant_response'),
                'metadata': {
                    'strategy': 'basic',
                    'spec_id': episode['spec_id'],
                    # Kept so validate_preference_margin can pass the spec to the judge
                    'spec': episode.get('spec'),
                    'violated_rules': episode['violated_rules']
                }
            }
            if pair['chosen']:  # Only include if compliant response exists
                pairs.append(pair)
        print(f"✓ Generated {len(pairs)} basic pairs")
        return pairs

    def construct_two_step_pairs(
        self,
        episodes: List[Dict],
        orchestrator: RedTeamOrchestrator
    ) -> List[Dict]:
        """
        Two-step reframe strategy for higher-quality chosen responses.

        Strategy:
        1. Analyze why the rejected response violated the spec
        2. Generate a response that addresses the same need compliantly
        """
        print("\nStep 3: Constructing two-step reframe pairs...")
        pairs = []
        for i, episode in enumerate(episodes):
            attack_round = episode['attack_round']
            # Generate reframed compliant response
            compliant = orchestrator._generate_compliant_response(
                prompt=attack_round['attacker_prompt'],
                rejected=attack_round['defender_response'],
                spec=episode['spec'],
                violated_rules=episode['violated_rules']
            )
            pair = {
                'prompt': attack_round['attacker_prompt'],
                'rejected': attack_round['defender_response'],
                'chosen': compliant,
                'metadata': {
                    'strategy': 'two_step_reframe',
                    'spec_id': episode['spec_id'],
                    # Kept so validate_preference_margin can pass the spec to the judge
                    'spec': episode['spec'],
                    'violated_rules': episode['violated_rules']
                }
            }
            pairs.append(pair)
            if (i + 1) % 10 == 0:
                print(f" Progress: {i + 1}/{len(episodes)}")
        print(f"✓ Generated {len(pairs)} two-step pairs")
        return pairs

    def filter_by_quality(self, pairs: List[Dict]) -> List[Dict]:
        """Filter pairs by quality score of chosen response."""
        print("\nStep 4: Filtering by quality...")
        filtered = []
        for pair in pairs:
            # Evaluate chosen response quality
            quality_result = self.quality_judge.evaluate(
                prompt=pair['prompt'],
                response=pair['chosen']
            )
            pair['metadata']['quality_score'] = quality_result['score']
            if quality_result['score'] >= self.min_quality_score:
                filtered.append(pair)
        print(f"✓ Kept {len(filtered)}/{len(pairs)} pairs (quality >= {self.min_quality_score})")
        return filtered

    def validate_preference_margin(self, pairs: List[Dict]) -> List[Dict]:
        """Ensure clear preference margin between chosen and rejected."""
        print("\nStep 5: Validating preference margins...")
        validated = []
        for pair in pairs:
            # Check that rejected actually violates spec
            safety_rejected = self.safety_judge.evaluate(
                prompt=pair['prompt'],
                response=pair['rejected'],
                spec=pair['metadata'].get('spec')
            )
            # Check that chosen is compliant
            safety_chosen = self.safety_judge.evaluate(
                prompt=pair['prompt'],
                response=pair['chosen'],
                spec=pair['metadata'].get('spec')
            )
            # Only keep pairs with clear distinction
            if safety_rejected['violation'] and not safety_chosen['violation']:
                validated.append(pair)
        print(f"✓ Validated {len(validated)}/{len(pairs)} pairs")
        return validated

    def deduplicate(self, pairs: List[Dict]) -> List[Dict]:
        """Remove pairs whose prompts match on their first 200 normalized characters."""
        print("\nStep 6: Deduplicating...")
        seen_prompts = set()
        unique = []
        for pair in pairs:
            # Normalize case and whitespace; store the key itself rather than
            # hash(), so two different prompts can never collide
            prompt_key = ' '.join(pair['prompt'].lower().split())[:200]
            if prompt_key not in seen_prompts:
                seen_prompts.add(prompt_key)
                unique.append(pair)
        print(f"✓ Kept {len(unique)}/{len(pairs)} unique pairs")
        return unique
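
    def build_dataset(
        self,
        episodes_path: str,
        strategy: str = 'two_step_reframe',
        orchestrator: RedTeamOrchestrator = None
    ) -> List[Dict]:
        """
        Build complete DPO dataset with quality filtering.

        Args:
            episodes_path: Path to episodes.jsonl
            strategy: 'basic' or 'two_step_reframe'
            orchestrator: Needed for 'two_step_reframe'; without it the
                builder falls back to basic pairs

        Returns:
            Filtered and validated DPO pairs
        """
        episodes = self.load_episodes(episodes_path)
        if strategy == 'basic':
            pairs = self.construct_basic_pairs(episodes)
        elif strategy == 'two_step_reframe' and orchestrator is not None:
            pairs = self.construct_two_step_pairs(episodes, orchestrator)
        else:
            # Keep the fallback to basic pairs, but make it visible
            print(f"! Strategy '{strategy}' unavailable (unknown, or no "
                  "orchestrator given); falling back to basic pairs")
            pairs = self.construct_basic_pairs(episodes)
        pairs = self.filter_by_quality(pairs)
        pairs = self.validate_preference_margin(pairs)
        pairs = self.deduplicate(pairs)
        return pairs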


def analyze_dataset(pairs: List[Dict]):
    """Print dataset statistics."""
    from collections import Counter

    print("\n" + "=" * 50)
    print("Dataset Statistics")
    print("=" * 50)
    print(f"Total pairs: {len(pairs)}")
    if not pairs:
        # Nothing survived filtering; mean/min/max would fail on an empty list
        return

    # Strategy distribution
    strategies = Counter(p['metadata']['strategy'] for p in pairs)
    print("\nBy strategy:")
    for strategy, count in strategies.items():
        print(f" {strategy}: {count}")

    # Quality score distribution
    scores = [p['metadata']['quality_score'] for p in pairs]
    print("\nQuality scores:")
    print(f" Mean: {sum(scores)/len(scores):.3f}")
    print(f" Min: {min(scores):.3f}")
    print(f" Max: {max(scores):.3f}")

    # Rule coverage
    rule_counts = Counter(
        rule for p in pairs for rule in p['metadata']['violated_rules']
    )
    print(f"\nRule coverage: {len(rule_counts)} unique rules")
    print("Top 5 most common:")
    for rule, count in rule_counts.most_common(5):
        print(f" {rule}: {count}")


def save_dataset(pairs: List[Dict], output_path: str):
    """Save the full dataset plus a training-ready copy without metadata."""
    print(f"\nSaving dataset to {output_path}...")
    output_file = Path(output_path)
    output_file.parent.mkdir(parents=True, exist_ok=True)

    # Save full dataset with metadata
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(pairs, f, indent=2)

    # Save training-ready format (no metadata). Building the name from the
    # stem avoids overwriting the full dataset when the path lacks '.json'.
    training_file = output_file.with_name(output_file.stem + '_training.json')
    training_data = [
        {
            'prompt': p['prompt'],
            'chosen': p['chosen'],
            'rejected': p['rejected']
        }
        for p in pairs
    ]
    with open(training_file, 'w', encoding='utf-8') as f:
        json.dump(training_data, f, indent=2)
    print(f"✓ Full dataset: {output_file.as_posix()}")
    print(f"✓ Training format: {training_file.as_posix()}")


def main():
    """Run the complete DPO dataset generation workflow."""
    print("=" * 50)
    print("DPO Dataset Generation Example")
    print("=" * 50)
    config = load_config("config.json")
    builder = DPODatasetBuilder(config)
    pairs = builder.build_dataset(
        episodes_path="output/episodes.jsonl",
        strategy="basic"
    )
    analyze_dataset(pairs)
    save_dataset(pairs, "output/dpo_dataset.json")
    print("\n" + "=" * 50)
    print("✓ DPO dataset generation complete!")
    print("=" * 50)


if __name__ == "__main__":
    main()
Expected Output
==================================================
DPO Dataset Generation Example
==================================================
Step 1: Loading episodes...
✓ Loaded 100 episodes
✓ Successful attacks: 42

Step 2: Constructing basic pairs...
✓ Generated 42 basic pairs

Step 4: Filtering by quality...
✓ Kept 38/42 pairs (quality >= 0.7)

Step 5: Validating preference margins...
✓ Validated 35/38 pairs

Step 6: Deduplicating...
✓ Kept 35/35 unique pairs

==================================================
Dataset Statistics
==================================================
Total pairs: 35

By strategy:
 basic: 35

Quality scores:
 Mean: 0.847
 Min: 0.712
 Max: 0.965

Rule coverage: 12 unique rules
Top 5 most common:
 R12: 15
 R8: 10
 R15: 8
 R3: 6
 R21: 4

Saving dataset to output/dpo_dataset.json...
✓ Full dataset: output/dpo_dataset.json
✓ Training format: output/dpo_dataset_training.json

==================================================
✓ DPO dataset generation complete!
==================================================
Key Takeaways
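Two-step reframe aims for higher-quality chosen responses than the basic strategy, because the chosen response is generated from an analysis of why the rejected one violated the specification
Quality filtering removes pairs whose chosen response scores below the quality threshold and could hurt training
Preference margin validation ensures a clear distinction between chosen and rejected responses
Deduplication keeps the model from overfitting on repeated or near-identical prompts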
Metadata preservation enables analysis and debugging of training data
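The deduplication step in the example compares only the beginning of each prompt, so reworded prompts can still slip through. The sketch below shows one possible fuzzy check using the standard library's `difflib.SequenceMatcher`. The function names and the 0.9 threshold are assumptions for illustration, not part of specalign, and the threshold would need tuning on real data.

```python
from difflib import SequenceMatcher
from typing import Dict, List


def normalize(text: str) -> str:
    """Lowercase and collapse whitespace so trivial variations compare equal."""
    return " ".join(text.lower().split())


def drop_near_duplicates(pairs: List[Dict], threshold: float = 0.9) -> List[Dict]:
    """Keep the first pair from each group of prompts with similarity >= threshold.

    Every candidate is compared against all kept prompts, so the cost grows
    quadratically with dataset size.
    """
    kept: List[Dict] = []
    kept_prompts: List[str] = []
    for pair in pairs:
        prompt = normalize(pair["prompt"])
        is_duplicate = any(
            SequenceMatcher(None, prompt, seen).ratio() >= threshold
            for seen in kept_prompts
        )
        if not is_duplicate:
            kept.append(pair)
            kept_prompts.append(prompt)
    return kept
```

Pairwise comparison is simple and dependency-free, which suits datasets of a few thousand pairs; larger datasets would likely need an approximate method instead.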
Best Practices
Quality over quantity: A smaller, high-quality dataset often outperforms a larger noisy one
Diverse rule coverage: Ensure the dataset covers a wide range of specification rules
Balance strategies: Mix different pair construction strategies for robustness
Validate regularly: Spot-check generated pairs for coherence and correctness
Version datasets: Track dataset versions alongside model checkpoints
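One lightweight way to version a dataset is to derive an identifier from its content and record it next to the model checkpoint. The following is a minimal sketch under that assumption; `dataset_fingerprint` is a hypothetical helper, and it hashes only the three training fields, so metadata changes do not alter the identifier.

```python
import hashlib
import json
from typing import Dict, List


def dataset_fingerprint(pairs: List[Dict]) -> str:
    """Return a short, order-independent SHA-256 fingerprint of the training fields."""
    records = sorted(
        json.dumps(
            {key: pair[key] for key in ("prompt", "chosen", "rejected")},
            sort_keys=True,
            ensure_ascii=False,
        )
        for pair in pairs
    )
    digest = hashlib.sha256("\n".join(records).encode("utf-8"))
    # Twelve hex characters keep the identifier readable in file names and logs.
    return digest.hexdigest()[:12]
```

Because the records are sorted before hashing, shuffling the dataset leaves the fingerprint unchanged, while editing any prompt or response produces a new one.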
Next Steps
Configuration Reference - DPO construction configuration options
Custom Providers - Use different models for generation
Core Module - DPO construction API details