Skip to content

Commit e63588a

Browse files
Tavish9, qgallouedec, shirinyamani
authored
🏁 Refactor reference model initialization in GRPOTrainer (#3575)
Co-authored-by: Quentin Gallouédec <[email protected]>
Co-authored-by: Shirin Yamani <[email protected]>
1 parent d9d25a7 commit e63588a

File tree

1 file changed

+3
-6
lines changed

1 file changed

+3
-6
lines changed

trl/trainer/grpo_trainer.py

Lines changed: 3 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -43,15 +43,14 @@
4343
TrainerCallback,
4444
is_wandb_available,
4545
)
46-
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
4746
from transformers.trainer_utils import seed_worker
4847
from transformers.utils import is_datasets_available, is_peft_available, is_rich_available
4948

5049
from ..data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template
5150
from ..extras.profiling import profiling_context, profiling_decorator
5251
from ..extras.vllm_client import VLLMClient
5352
from ..import_utils import is_liger_kernel_available, is_vllm_available
54-
from ..models import create_reference_model, prepare_deepspeed, prepare_fsdp, unwrap_model_for_generation
53+
from ..models import prepare_deepspeed, prepare_fsdp, unwrap_model_for_generation
5554
from ..models.utils import _ForwardRedirection
5655
from .callbacks import SyncRefModelCallback
5756
from .grpo_config import GRPOConfig
@@ -557,15 +556,13 @@ def __init__(
557556
if self.beta == 0.0:
558557
# If beta is 0.0, the reference model is not needed
559558
self.ref_model = None
560-
elif is_deepspeed_zero3_enabled() or self.is_fsdp_enabled:
561-
self.ref_model = AutoModelForCausalLM.from_pretrained(model_id, **model_init_kwargs)
562559
elif is_peft_model(model):
563560
# If PEFT is used, the reference model is not needed since the adapter can be disabled
564561
# to revert to the initial model.
565562
self.ref_model = None
566563
else:
567-
# If PEFT configuration is not provided, create a reference model based on the initial model.
568-
self.ref_model = create_reference_model(model)
564+
# For deepspeed, fsdp or non-distributed models, create a reference model from scratch
565+
self.ref_model = AutoModelForCausalLM.from_pretrained(model_id, **model_init_kwargs)
569566

570567
# Disable dropout in the models
571568
if args.disable_dropout:

0 commit comments

Comments
 (0)