From db6868c60f6964dddc3244e3826e1f65162f5129 Mon Sep 17 00:00:00 2001
From: ginkyenglee
Date: Thu, 4 Sep 2025 07:00:08 +0000
Subject: [PATCH 1/2] Fix: ignore precompute_ref_log_probs when
 use_liger_loss=True

---
 trl/trainer/dpo_trainer.py | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py
index 2cacb26e3f..3ee0072de0 100644
--- a/trl/trainer/dpo_trainer.py
+++ b/trl/trainer/dpo_trainer.py
@@ -362,6 +362,7 @@ def __init__(
                     "You set `use_liger_loss=True` but the loss type is not from `[sigmoid, apo_zero, apo_down, sppo_hard, nca_pair`. "
                     "Please set `loss_type='[sigmoid | apo_zero | apo_down | sppo_hard | nca_pair]'` to use the liger kernel."
                 )
+
             self.dpo_loss_fn = LigerFusedLinearDPOLoss(
                 ignore_index=args.label_pad_token_id,
                 beta=args.beta,
@@ -389,6 +390,13 @@ def __init__(
         self.max_length = args.max_length
         self.truncation_mode = args.truncation_mode
         self.precompute_ref_log_probs = args.precompute_ref_log_probs
+        if args.use_liger_loss and self.precompute_ref_log_probs:
+            logger.warning(
+                "You set `use_liger_loss=True`, but also enabled `precompute_ref_log_probs`. "
+                "The `precompute_ref_log_probs` setting will be ignored."
+            )
+            self.precompute_ref_log_probs = False
+
         self.use_logits_to_keep = args.use_logits_to_keep

         if args.padding_free:

From dfa4a26083b30a1d53a979c01c126fd417b480ca Mon Sep 17 00:00:00 2001
From: ginkyenglee
Date: Mon, 8 Sep 2025 02:29:12 +0000
Subject: [PATCH 2/2] DPOConfig: disable precompute_ref_log_probs in
 __post_init__ when use_liger_loss=True; revert trainer change

---
 trl/trainer/dpo_config.py  | 6 ++++++
 trl/trainer/dpo_trainer.py | 8 --------
 2 files changed, 6 insertions(+), 8 deletions(-)

diff --git a/trl/trainer/dpo_config.py b/trl/trainer/dpo_config.py
index befb0d2921..7e029dba48 100644
--- a/trl/trainer/dpo_config.py
+++ b/trl/trainer/dpo_config.py
@@ -464,4 +464,10 @@ def __post_init__(self):
                     f"Length of loss_weights list ({self.loss_weights}) must match number of loss types "
                     f"({loss_types})."
                 )
+
+        # If Liger loss is enabled, precomputing ref log probs is not used.
+        # Force-disable it at config time to avoid wasted work upstream.
+        if self.use_liger_loss and self.precompute_ref_log_probs:
+            self.precompute_ref_log_probs = False
+
         super().__post_init__()
diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py
index cc9fd278d6..a99394624e 100644
--- a/trl/trainer/dpo_trainer.py
+++ b/trl/trainer/dpo_trainer.py
@@ -362,7 +362,6 @@ def __init__(
                     "You set `use_liger_loss=True` but the loss type is not from `[sigmoid, apo_zero, apo_down, sppo_hard, nca_pair`. "
                     "Please set `loss_type='[sigmoid | apo_zero | apo_down | sppo_hard | nca_pair]'` to use the liger kernel."
                 )
-
             self.dpo_loss_fn = LigerFusedLinearDPOLoss(
                 ignore_index=args.label_pad_token_id,
                 beta=args.beta,
@@ -390,13 +389,6 @@ def __init__(
         self.max_length = args.max_length
         self.truncation_mode = args.truncation_mode
         self.precompute_ref_log_probs = args.precompute_ref_log_probs
-        if args.use_liger_loss and self.precompute_ref_log_probs:
-            logger.warning(
-                "You set `use_liger_loss=True`, but also enabled `precompute_ref_log_probs`. "
-                "The `precompute_ref_log_probs` setting will be ignored."
-            )
-            self.precompute_ref_log_probs = False
-
         self.use_logits_to_keep = args.use_logits_to_keep

         if args.padding_free: