@@ -1015,6 +1015,29 @@ def dpo_loss(
        Log probabilities of the reference model for the chosen responses. Shape: `(batch_size,)`.
    ref_rejected_logps (`torch.FloatTensor`):
        Log probabilities of the reference model for the rejected responses. Shape: `(batch_size,)`.
+   loss_type (`str`, defaults to `"sigmoid"`):
+       The type of loss to compute. One of:
+       - `"sigmoid"`: Sigmoid loss from the original [DPO](https://huggingface.co/papers/2305.18290) paper.
+       - `"hinge"`: Hinge loss on the normalized likelihood from the
+         [SLiC](https://huggingface.co/papers/2305.10425) paper.
+       - `"ipo"`: IPO loss from the [IPO](https://huggingface.co/papers/2310.12036) paper.
+       - `"exo_pair"`: Pairwise EXO loss from the [EXO](https://huggingface.co/papers/2402.00856) paper.
+       - `"nca_pair"`: Pairwise NCA loss from the [NCA](https://huggingface.co/papers/2402.05369) paper.
+       - `"robust"`: Unbiased estimate of the DPO loss that is robust to preference noise from the [Robust
+         DPO](https://huggingface.co/papers/2403.00409) paper.
+       - `"bco_pair"`: Pairwise BCO loss from the [BCO](https://huggingface.co/papers/2404.04656) paper.
+       - `"sppo_hard"`: SPPO loss with hard label from the [SPPO](https://huggingface.co/papers/2405.00675)
+         paper.
+       - `"aot"`: AOT loss for paired datasets from the [AOT](https://huggingface.co/papers/2406.05882) paper.
+       - `"aot_pair"`: AOT loss for unpaired datasets from the [AOT](https://huggingface.co/papers/2406.05882)
+         paper.
+       - `"discopop"`: DiscoPOP (a.k.a. Log-Ratio Modulated Loss, LRML) loss from the
+         [DiscoPOP](https://huggingface.co/papers/2406.08414) paper.
+       - `"apo_zero"`: APO-zero loss from the [APO](https://huggingface.co/papers/2408.06266) paper.
+       - `"apo_down"`: APO-down loss from the [APO](https://huggingface.co/papers/2408.06266) paper.
+       - `"sft"`: Negative log-likelihood loss (standard supervised fine-tuning loss).
+   model_output (`dict[str, torch.FloatTensor]`, *optional*):
+       The output of the model's forward pass. This is used to compute auxiliary losses if enabled.

Returns:
    A tuple of three tensors: `(losses, chosen_rewards, rejected_rewards)`. The losses tensor contains the DPO
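For orientation, below is a minimal sketch of what the default `"sigmoid"` branch computes from these log-probability arguments, following the standard DPO formulation; the other `loss_type` options replace the `-logsigmoid` term with different functions of the same log-ratio margin. The helper name `sigmoid_dpo_loss_sketch` and the `beta` temperature are illustrative assumptions, not identifiers from this diff.

```python
import torch
import torch.nn.functional as F


def sigmoid_dpo_loss_sketch(
    chosen_logps: torch.FloatTensor,        # log pi_theta(y_chosen | x), shape (batch_size,)
    rejected_logps: torch.FloatTensor,      # log pi_theta(y_rejected | x), shape (batch_size,)
    ref_chosen_logps: torch.FloatTensor,    # log pi_ref(y_chosen | x), shape (batch_size,)
    ref_rejected_logps: torch.FloatTensor,  # log pi_ref(y_rejected | x), shape (batch_size,)
    beta: float = 0.1,                      # illustrative temperature; not defined in this diff
):
    # Implicit rewards are beta-scaled log-ratios between the policy and the reference model.
    chosen_rewards = beta * (chosen_logps - ref_chosen_logps)
    rejected_rewards = beta * (rejected_logps - ref_rejected_logps)

    # Sigmoid (Bradley-Terry) DPO objective: -log sigmoid(reward margin), per example.
    losses = -F.logsigmoid(chosen_rewards - rejected_rewards)

    # Mirrors the documented return contract: (losses, chosen_rewards, rejected_rewards).
    return losses, chosen_rewards.detach(), rejected_rewards.detach()
```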