Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions docs/source/dpo_trainer.md
Original file line number Diff line number Diff line change
Expand Up @@ -295,3 +295,7 @@ dpo_trainer = DPOTrainer(
## DataCollatorForPreference

[[autodoc]] trainer.dpo_trainer.DataCollatorForPreference

## FDivergenceType

[[autodoc]] trainer.dpo_trainer.FDivergenceType
27 changes: 26 additions & 1 deletion trl/trainer/dpo_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,37 @@


class FDivergenceType(Enum):
"""Types of f-divergence functions for DPO loss regularization.

Attributes:
REVERSE_KL: Reverse KL divergence.
JS_DIVERGENCE: Jensen-Shannon divergence.
ALPHA_DIVERGENCE: Alpha divergence.

Examples:
```python
>>> from trl.trainer.dpo_config import DPOConfig, FDivergenceType

>>> config = DPOConfig(
... f_divergence_type=FDivergenceType.ALPHA_DIVERGENCE,
... f_alpha_divergence_coef=0.5, # used only with ALPHA_DIVERGENCE
... )
```
"""

REVERSE_KL = "reverse_kl"
JS_DIVERGENCE = "js_divergence"
ALPHA_DIVERGENCE = "alpha_divergence"


class FDivergenceConstants:
"""Constants for f-divergence types and their parameters.

Attributes:
ALPHA_DIVERGENCE_COEF_KEY (`str`): Key for the alpha divergence coefficient.
ALPHA_DIVERGENCE_COEF_DEFAULT (`float`): Default value for the alpha divergence coefficient.
"""

ALPHA_DIVERGENCE_COEF_KEY = "alpha_divergence_coef"
ALPHA_DIVERGENCE_COEF_DEFAULT = 1.0

Expand Down Expand Up @@ -140,7 +165,7 @@ class DPOConfig(TrainingArguments):
Parameter controlling the deviation from the reference model. Higher β means less deviation from the
reference model. For the IPO loss (`loss_type="ipo"`), β is the regularization parameter denoted by τ in
the [paper](https://huggingface.co/papers/2310.12036).
f_divergence_type (`str`, *optional*, defaults to `FDivergenceType.REVERSE_KL`):
f_divergence_type ([`FDivergenceType`], *optional*, defaults to `FDivergenceType.REVERSE_KL`):
Type of f-divergence regularization function to compute divergence between policy and reference model.
f_alpha_divergence_coef (`float`, *optional*, defaults to `1.0`):
α coefficient in the α-divergence u^-α regularization function for DPO loss.
Expand Down
Loading