DPO Dataset Generation
======================

Learn how to generate high-quality DPO datasets for model alignment.

Overview
--------

In this tutorial, you will learn:

- How DPO (Direct Preference Optimization) datasets are structured
- Different strategies for constructing preference pairs
- How to filter and validate DPO data quality
- Best practices for using generated data in training

Prerequisites
-------------

- Completed red-team episodes (see :doc:`red_team_pipeline`)
- Understanding of preference learning concepts
- Familiarity with DPO training requirements

DPO Dataset Structure
---------------------

A DPO dataset is a list of preference pairs. Each pair holds one prompt and two responses to it:

.. code-block:: json

    {
      "prompt": "User query or instruction",
      "chosen": "Preferred response (specification-compliant)",
      "rejected": "Dispreferred response (violates specification)"
    }

Complete Example
----------------

.. code-block:: python

    """
    DPO Dataset Generation Example

    This script demonstrates DPO dataset construction strategies,
    including quality filtering, preference-margin validation,
    and deduplication.
    """

    import json
    from collections import Counter
    from typing import Dict, List, Optional

    from specalign.core import RedTeamOrchestrator, create_judges
    from specalign.config import load_config


    class DPODatasetBuilder:
        """Builder for constructing high-quality DPO datasets."""

        def __init__(self, config: dict):
            self.config = config
            self.safety_judge, self.quality_judge = create_judges(config)
            # Minimum quality score a chosen response needs to be kept
            self.min_quality_score = 0.7

        def load_episodes(self, episodes_path: str) -> List[Dict]:
            """Load completed red-team episodes and keep the successful attacks."""
            print("Step 1: Loading episodes...")

            episodes = []
            with open(episodes_path, 'r') as f:
                for line in f:
                    if line.strip():  # Skip blank lines in the JSONL file
                        episodes.append(json.loads(line))

            successful = [e for e in episodes if e['success']]
            print(f"✓ Loaded {len(episodes)} episodes")
            print(f"✓ Successful attacks: {len(successful)}")

            return successful

        def construct_basic_pairs(self, episodes: List[Dict]) -> List[Dict]:
            """
            Basic pair construction: use violating response as rejected.

            Strategy: Pair the attack prompt with the violating defender
            response (rejected) and the compliant response stored with
            the episode (chosen).
            """
            print("\nStep 2: Constructing basic pairs...")

            pairs = []
            for episode in episodes:
                attack_round = episode['attack_round']

                pair = {
                    'prompt': attack_round['attacker_prompt'],
                    'rejected': attack_round['defender_response'],
                    'chosen': episode.get('compliant_response'),
                    'metadata': {
                        'strategy': 'basic',
                        'spec_id': episode['spec_id'],
                        'violated_rules': episode['violated_rules']
                    }
                }

                if pair['chosen']:  # Only include if compliant response exists
                    pairs.append(pair)

            print(f"✓ Generated {len(pairs)} basic pairs")
            return pairs

        def construct_two_step_pairs(
            self,
            episodes: List[Dict],
            orchestrator: RedTeamOrchestrator
        ) -> List[Dict]:
            """
            Two-step reframe strategy for higher quality chosen responses.

            Strategy:
            1. Analyze why the rejected response violated the spec
            2. Generate a response that addresses the same need compliantly
            """
            print("\nStep 3: Constructing two-step reframe pairs...")

            pairs = []
            for i, episode in enumerate(episodes):
                attack_round = episode['attack_round']

                # Generate reframed compliant response
                compliant = orchestrator._generate_compliant_response(
                    prompt=attack_round['attacker_prompt'],
                    rejected=attack_round['defender_response'],
                    spec=episode['spec'],
                    violated_rules=episode['violated_rules']
                )

                pair = {
                    'prompt': attack_round['attacker_prompt'],
                    'rejected': attack_round['defender_response'],
                    'chosen': compliant,
                    'metadata': {
                        'strategy': 'two_step_reframe',
                        'spec_id': episode['spec_id'],
                        'violated_rules': episode['violated_rules']
                    }
                }
                pairs.append(pair)

                if (i + 1) % 10 == 0:
                    print(f"  Progress: {i + 1}/{len(episodes)}")

            print(f"✓ Generated {len(pairs)} two-step pairs")
            return pairs

        def filter_by_quality(self, pairs: List[Dict]) -> List[Dict]:
            """Filter pairs by quality score of chosen response."""
            print("\nStep 4: Filtering by quality...")

            filtered = []
            for pair in pairs:
                # Evaluate chosen response quality
                quality_result = self.quality_judge.evaluate(
                    prompt=pair['prompt'],
                    response=pair['chosen']
                )

                # Record the score so analyze_dataset() can report it later
                pair['metadata']['quality_score'] = quality_result['score']

                if quality_result['score'] >= self.min_quality_score:
                    filtered.append(pair)

            print(f"✓ Kept {len(filtered)}/{len(pairs)} pairs (quality >= {self.min_quality_score})")
            return filtered

        def validate_preference_margin(self, pairs: List[Dict]) -> List[Dict]:
            """Ensure clear preference margin between chosen and rejected."""
            print("\nStep 5: Validating preference margins...")

            validated = []
            for pair in pairs:
                # Note: the pair metadata built above stores only 'spec_id',
                # so .get('spec') returns None unless pair construction
                # also stores the spec itself.
                spec = pair['metadata'].get('spec')

                # Check that rejected actually violates spec
                safety_rejected = self.safety_judge.evaluate(
                    prompt=pair['prompt'],
                    response=pair['rejected'],
                    spec=spec
                )

                # Check that chosen is compliant
                safety_chosen = self.safety_judge.evaluate(
                    prompt=pair['prompt'],
                    response=pair['chosen'],
                    spec=spec
                )

                # Only keep pairs with clear distinction
                if safety_rejected['violation'] and not safety_chosen['violation']:
                    validated.append(pair)

            print(f"✓ Validated {len(validated)}/{len(pairs)} pairs")
            return validated

        def deduplicate(self, pairs: List[Dict]) -> List[Dict]:
            """Remove pairs whose prompts share the same first 200 characters."""
            print("\nStep 6: Deduplicating...")

            seen_prompts = set()
            unique = []

            for pair in pairs:
                # Compare prompt prefixes; the first occurrence wins
                prompt_key = pair['prompt'][:200]
                if prompt_key not in seen_prompts:
                    seen_prompts.add(prompt_key)
                    unique.append(pair)

            print(f"✓ Kept {len(unique)}/{len(pairs)} unique pairs")
            return unique

        def build_dataset(
            self,
            episodes_path: str,
            strategy: str = 'two_step_reframe',
            orchestrator: Optional[RedTeamOrchestrator] = None
        ) -> List[Dict]:
            """
            Build complete DPO dataset with quality filtering.

            Args:
                episodes_path: Path to episodes.jsonl
                strategy: 'basic' or 'two_step_reframe'
                orchestrator: Needed for 'two_step_reframe'; without it
                    the builder falls back to basic pairs

            Returns:
                Filtered and validated DPO pairs
            """
            episodes = self.load_episodes(episodes_path)

            if strategy == 'two_step_reframe' and orchestrator is not None:
                pairs = self.construct_two_step_pairs(episodes, orchestrator)
            else:
                if strategy != 'basic':
                    print("Note: falling back to basic pairs")
                pairs = self.construct_basic_pairs(episodes)

            pairs = self.filter_by_quality(pairs)
            pairs = self.validate_preference_margin(pairs)
            pairs = self.deduplicate(pairs)

            return pairs


    def analyze_dataset(pairs: List[Dict]):
        """Print dataset statistics."""
        print("\n" + "=" * 50)
        print("Dataset Statistics")
        print("=" * 50)

        print(f"Total pairs: {len(pairs)}")
        if not pairs:
            return  # Nothing to summarize (also avoids dividing by zero)

        # Strategy distribution
        strategies = Counter(p['metadata']['strategy'] for p in pairs)
        print("\nBy strategy:")
        for strategy, count in strategies.items():
            print(f"  {strategy}: {count}")

        # Quality score distribution
        scores = [p['metadata']['quality_score'] for p in pairs]
        print("\nQuality scores:")
        print(f"  Mean: {sum(scores)/len(scores):.3f}")
        print(f"  Min: {min(scores):.3f}")
        print(f"  Max: {max(scores):.3f}")

        # Rule coverage
        all_rules = []
        for p in pairs:
            all_rules.extend(p['metadata']['violated_rules'])

        rule_counts = Counter(all_rules)
        print(f"\nRule coverage: {len(rule_counts)} unique rules")
        print("Top 5 most common:")
        for rule, count in rule_counts.most_common(5):
            print(f"  {rule}: {count}")


    def save_dataset(pairs: List[Dict], output_path: str):
        """Save dataset in training-ready format."""
        print(f"\nSaving dataset to {output_path}...")

        # Save full dataset with metadata
        with open(output_path, 'w') as f:
            json.dump(pairs, f, indent=2)

        # Save training-ready format (no metadata)
        training_path = output_path.replace('.json', '_training.json')
        training_data = [
            {
                'prompt': p['prompt'],
                'chosen': p['chosen'],
                'rejected': p['rejected']
            }
            for p in pairs
        ]
        with open(training_path, 'w') as f:
            json.dump(training_data, f, indent=2)

        print(f"✓ Full dataset: {output_path}")
        print(f"✓ Training format: {training_path}")


    def main():
        """Run the complete DPO dataset generation workflow."""
        print("=" * 50)
        print("DPO Dataset Generation Example")
        print("=" * 50)

        config = load_config("config.json")
        builder = DPODatasetBuilder(config)

        pairs = builder.build_dataset(
            episodes_path="output/episodes.jsonl",
            strategy="basic"
        )

        analyze_dataset(pairs)
        save_dataset(pairs, "output/dpo_dataset.json")

        print("\n" + "=" * 50)
        print("✓ DPO dataset generation complete!")
        print("=" * 50)


    if __name__ == "__main__":
        main()

Expected Output
---------------

Running the script with the ``basic`` strategy prints output similar to the following. Step 3 does not appear because no two-step pairs are constructed, and the counts depend on your episodes.

.. code-block:: text

    ==================================================
    DPO Dataset Generation Example
    ==================================================
    Step 1: Loading episodes...
    ✓ Loaded 100 episodes
    ✓ Successful attacks: 42

    Step 2: Constructing basic pairs...
    ✓ Generated 42 basic pairs

    Step 4: Filtering by quality...
    ✓ Kept 38/42 pairs (quality >= 0.7)

    Step 5: Validating preference margins...
    ✓ Validated 35/38 pairs

    Step 6: Deduplicating...
    ✓ Kept 35/35 unique pairs

    ==================================================
    Dataset Statistics
    ==================================================
    Total pairs: 35

    By strategy:
      basic: 35

    Quality scores:
      Mean: 0.847
      Min: 0.712
      Max: 0.965

    Rule coverage: 12 unique rules
    Top 5 most common:
      R12: 15
      R8: 10
      R15: 8
      R3: 6
      R21: 4

    Saving dataset to output/dpo_dataset.json...
    ✓ Full dataset: output/dpo_dataset.json
    ✓ Training format: output/dpo_dataset_training.json

    ==================================================
    ✓ DPO dataset generation complete!
    ==================================================

Key Takeaways
-------------

1. **Two-step reframe** is designed to produce higher-quality chosen responses than the basic strategy, at the cost of an extra generation call per episode
2. **Quality filtering** removes pairs whose chosen response scores below the threshold (0.7 in this example) and could hurt training
3. **Preference margin validation** keeps only pairs where the rejected response violates the specification and the chosen response does not
4. **Deduplication** reduces the risk of the model overfitting to near-identical prompts
5. **Metadata preservation** enables analysis and debugging of training data

Best Practices
--------------

1. **Quality over quantity**: A smaller, high-quality dataset often outperforms a larger noisy one
2. **Diverse rule coverage**: Ensure the dataset covers a wide range of specification rules
3. **Balance strategies**: Mix different pair construction strategies for robustness
4. **Validate regularly**: Spot-check generated pairs for coherence and correctness
5. **Version datasets**: Track dataset versions alongside model checkpoints

Next Steps
----------

- :doc:`../user_guide/configuration` - DPO construction configuration options
- :doc:`custom_providers` - Use different models for generation
- :doc:`../api_reference/core` - DPO construction API details
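
Part of the "validate regularly" practice listed under Best Practices can be automated with a structural check on the training-ready records. The following is a minimal sketch, not part of ``specalign``: ``find_invalid_pairs`` and ``REQUIRED_KEYS`` are hypothetical names introduced here, and the sketch assumes records in the ``prompt`` / ``chosen`` / ``rejected`` format that ``save_dataset`` writes.

```python
from typing import Dict, List

# Keys every training-ready DPO record is assumed to carry (hypothetical helper constant)
REQUIRED_KEYS = ("prompt", "chosen", "rejected")


def find_invalid_pairs(pairs: List[Dict]) -> List[str]:
    """Return human-readable problems found in DPO training records."""
    problems = []
    for index, pair in enumerate(pairs):
        # Each required field must be a non-empty string
        for key in REQUIRED_KEYS:
            value = pair.get(key)
            if not isinstance(value, str) or not value.strip():
                problems.append(f"pair {index}: '{key}' is missing or empty")

        # A pair carries no preference signal if both responses are the same text
        chosen = pair.get("chosen")
        rejected = pair.get("rejected")
        if (
            isinstance(chosen, str)
            and isinstance(rejected, str)
            and chosen.strip()
            and chosen.strip() == rejected.strip()
        ):
            problems.append(f"pair {index}: 'chosen' and 'rejected' are identical")
    return problems


# Illustrative records only; real input would be the *_training.json file
sample = [
    {"prompt": "p1", "chosen": "safe answer", "rejected": "unsafe answer"},
    {"prompt": "p2", "chosen": "same", "rejected": "same"},
    {"prompt": "p3", "chosen": "", "rejected": "unsafe answer"},
]

for problem in find_invalid_pairs(sample):
    print(problem)
```

A check like this only catches structural problems. It does not replace the judge-based filtering in the example or manual spot-checks for coherence.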