DPO Dataset Generation

Learn how to generate high-quality DPO datasets for model alignment.

Overview

In this tutorial, you will learn:

  • How DPO (Direct Preference Optimization) datasets are structured

  • Different strategies for constructing preference pairs

  • How to filter and validate DPO data quality

  • Best practices for using generated data in training

Prerequisites

  • Completed red-team episodes (see Red Team Pipeline)

  • Understanding of preference learning concepts

  • Familiarity with DPO training requirements

DPO Dataset Structure

A DPO dataset consists of preference pairs:

{
  "prompt": "User query or instruction",
  "chosen": "Preferred response (specification-compliant)",
  "rejected": "Dispreferred response (violates specification)"
}

Complete Example

"""
DPO Dataset Generation Example

This script demonstrates advanced DPO dataset construction strategies,
including quality filtering and data augmentation.
"""

import json
from pathlib import Path
from typing import List, Dict, Tuple

from specalign.core import RedTeamOrchestrator, create_judges
from specalign.config import load_config


class DPODatasetBuilder:
    """Builder for constructing high-quality DPO datasets."""

    def __init__(self, config: dict):
        self.config = config
        self.safety_judge, self.quality_judge = create_judges(config)
        self.min_quality_score = 0.7

    def load_episodes(self, episodes_path: str) -> List[Dict]:
        """Load completed red-team episodes."""
        print("Step 1: Loading episodes...")

        episodes = []
        with open(episodes_path, 'r') as f:
            for line in f:
                episodes.append(json.loads(line))

        successful = [e for e in episodes if e['success']]
        print(f"✓ Loaded {len(episodes)} episodes")
        print(f"✓ Successful attacks: {len(successful)}")

        return successful

    def construct_basic_pairs(self, episodes: List[Dict]) -> List[Dict]:
        """
        Basic pair construction: use violating response as rejected.

        Strategy: Pair the attack prompt with the violating defender
        response (rejected) and a newly generated compliant response (chosen).
        """
        print("\nStep 2: Constructing basic pairs...")

        pairs = []
        for episode in episodes:
            attack_round = episode['attack_round']

            pair = {
                'prompt': attack_round['attacker_prompt'],
                'rejected': attack_round['defender_response'],
                'chosen': episode.get('compliant_response'),
                'metadata': {
                    'strategy': 'basic',
                    'spec_id': episode['spec_id'],
                    'violated_rules': episode['violated_rules']
                }
            }

            if pair['chosen']:  # Only include if compliant response exists
                pairs.append(pair)

        print(f"✓ Generated {len(pairs)} basic pairs")
        return pairs

    def construct_two_step_pairs(
        self,
        episodes: List[Dict],
        orchestrator: RedTeamOrchestrator
    ) -> List[Dict]:
        """
        Two-step reframe strategy for higher quality chosen responses.

        Strategy:
        1. Analyze why the rejected response violated the spec
        2. Generate a response that addresses the same need compliantly
        """
        print("\nStep 3: Constructing two-step reframe pairs...")

        pairs = []
        for i, episode in enumerate(episodes):
            attack_round = episode['attack_round']

            # Generate reframed compliant response
            compliant = orchestrator._generate_compliant_response(
                prompt=attack_round['attacker_prompt'],
                rejected=attack_round['defender_response'],
                spec=episode['spec'],
                violated_rules=episode['violated_rules']
            )

            pair = {
                'prompt': attack_round['attacker_prompt'],
                'rejected': attack_round['defender_response'],
                'chosen': compliant,
                'metadata': {
                    'strategy': 'two_step_reframe',
                    'spec_id': episode['spec_id'],
                    'violated_rules': episode['violated_rules']
                }
            }

            pairs.append(pair)

            if (i + 1) % 10 == 0:
                print(f"  Progress: {i + 1}/{len(episodes)}")

        print(f"✓ Generated {len(pairs)} two-step pairs")
        return pairs

    def filter_by_quality(self, pairs: List[Dict]) -> List[Dict]:
        """Filter pairs by quality score of chosen response."""
        print("\nStep 4: Filtering by quality...")

        filtered = []
        for pair in pairs:
            # Evaluate chosen response quality
            quality_result = self.quality_judge.evaluate(
                prompt=pair['prompt'],
                response=pair['chosen']
            )

            pair['metadata']['quality_score'] = quality_result['score']

            if quality_result['score'] >= self.min_quality_score:
                filtered.append(pair)

        print(f"✓ Kept {len(filtered)}/{len(pairs)} pairs (quality >= {self.min_quality_score})")
        return filtered

    def validate_preference_margin(self, pairs: List[Dict]) -> List[Dict]:
        """Ensure clear preference margin between chosen and rejected."""
        print("\nStep 5: Validating preference margins...")

        validated = []
        for pair in pairs:
            # Check that rejected actually violates spec
            safety_rejected = self.safety_judge.evaluate(
                prompt=pair['prompt'],
                response=pair['rejected'],
                spec=pair['metadata'].get('spec')
            )

            # Check that chosen is compliant
            safety_chosen = self.safety_judge.evaluate(
                prompt=pair['prompt'],
                response=pair['chosen'],
                spec=pair['metadata'].get('spec')
            )

            # Only keep pairs with clear distinction
            if safety_rejected['violation'] and not safety_chosen['violation']:
                validated.append(pair)

        print(f"✓ Validated {len(validated)}/{len(pairs)} pairs")
        return validated

    def deduplicate(self, pairs: List[Dict]) -> List[Dict]:
        """Remove duplicate or near-duplicate pairs."""
        print("\nStep 6: Deduplicating...")

        seen_prompts = set()
        unique = []

        for pair in pairs:
            prompt_hash = hash(pair['prompt'][:200])
            if prompt_hash not in seen_prompts:
                seen_prompts.add(prompt_hash)
                unique.append(pair)

        print(f"✓ Kept {len(unique)}/{len(pairs)} unique pairs")
        return unique

    def build_dataset(
        self,
        episodes_path: str,
        strategy: str = 'two_step_reframe'
    ) -> List[Dict]:
        """
        Build complete DPO dataset with quality filtering.

        Args:
            episodes_path: Path to episodes.jsonl
            strategy: 'basic' or 'two_step_reframe'

        Returns:
            Filtered and validated DPO pairs
        """
        episodes = self.load_episodes(episodes_path)

        if strategy == 'basic':
            pairs = self.construct_basic_pairs(episodes)
        else:
            # Would need orchestrator for two-step
            pairs = self.construct_basic_pairs(episodes)

        pairs = self.filter_by_quality(pairs)
        pairs = self.validate_preference_margin(pairs)
        pairs = self.deduplicate(pairs)

        return pairs


def analyze_dataset(pairs: List[Dict]):
    """Print dataset statistics."""
    print("\n" + "=" * 50)
    print("Dataset Statistics")
    print("=" * 50)

    # Strategy distribution
    strategies = {}
    for p in pairs:
        s = p['metadata']['strategy']
        strategies[s] = strategies.get(s, 0) + 1

    print(f"Total pairs: {len(pairs)}")
    print(f"\nBy strategy:")
    for strategy, count in strategies.items():
        print(f"  {strategy}: {count}")

    # Quality score distribution
    scores = [p['metadata']['quality_score'] for p in pairs]
    print(f"\nQuality scores:")
    print(f"  Mean: {sum(scores)/len(scores):.3f}")
    print(f"  Min:  {min(scores):.3f}")
    print(f"  Max:  {max(scores):.3f}")

    # Rule coverage
    all_rules = []
    for p in pairs:
        all_rules.extend(p['metadata']['violated_rules'])

    from collections import Counter
    rule_counts = Counter(all_rules)

    print(f"\nRule coverage: {len(rule_counts)} unique rules")
    print("Top 5 most common:")
    for rule, count in rule_counts.most_common(5):
        print(f"  {rule}: {count}")


def save_dataset(pairs: List[Dict], output_path: str):
    """Save dataset in training-ready format."""
    print(f"\nSaving dataset to {output_path}...")

    # Save full dataset with metadata
    with open(output_path, 'w') as f:
        json.dump(pairs, f, indent=2)

    # Save training-ready format (no metadata)
    training_path = output_path.replace('.json', '_training.json')
    training_data = [
        {
            'prompt': p['prompt'],
            'chosen': p['chosen'],
            'rejected': p['rejected']
        }
        for p in pairs
    ]

    with open(training_path, 'w') as f:
        json.dump(training_data, f, indent=2)

    print(f"✓ Full dataset: {output_path}")
    print(f"✓ Training format: {training_path}")


def main():
    """Run the complete DPO dataset generation workflow."""
    print("=" * 50)
    print("DPO Dataset Generation Example")
    print("=" * 50)

    config = load_config("config.json")

    builder = DPODatasetBuilder(config)
    pairs = builder.build_dataset(
        episodes_path="output/episodes.jsonl",
        strategy="basic"
    )

    analyze_dataset(pairs)
    save_dataset(pairs, "output/dpo_dataset.json")

    print("\n" + "=" * 50)
    print("✓ DPO dataset generation complete!")
    print("=" * 50)


if __name__ == "__main__":
    main()

Expected Output

==================================================
DPO Dataset Generation Example
==================================================
Step 1: Loading episodes...
✓ Loaded 100 episodes
✓ Successful attacks: 42

Step 2: Constructing basic pairs...
✓ Generated 42 basic pairs

Step 4: Filtering by quality...
✓ Kept 38/42 pairs (quality >= 0.7)

Step 5: Validating preference margins...
✓ Validated 35/38 pairs

Step 6: Deduplicating...
✓ Kept 35/35 unique pairs

==================================================
Dataset Statistics
==================================================
Total pairs: 35

By strategy:
  basic: 35

Quality scores:
  Mean: 0.847
  Min:  0.712
  Max:  0.965

Rule coverage: 12 unique rules
Top 5 most common:
  R12: 15
  R8: 10
  R15: 8
  R3: 6
  R21: 4

Saving dataset to output/dpo_dataset.json...
✓ Full dataset: output/dpo_dataset.json
✓ Training format: output/dpo_dataset_training.json

==================================================
✓ DPO dataset generation complete!
==================================================

Key Takeaways

  1. Two-step reframe produces higher quality chosen responses than simple alternatives

  2. Quality filtering removes low-quality pairs that could hurt training

  3. Preference margin validation ensures clear distinction between chosen/rejected

  4. Deduplication prevents model from overfitting on similar examples

  5. Metadata preservation enables analysis and debugging of training data

Best Practices

  1. Quality over quantity: A smaller, high-quality dataset often outperforms a larger noisy one

  2. Diverse rule coverage: Ensure the dataset covers a wide range of specification rules

  3. Balance strategies: Mix different pair construction strategies for robustness

  4. Validate regularly: Spot-check generated pairs for coherence and correctness

  5. Version datasets: Track dataset versions alongside model checkpoints

Next Steps