|
43 | 43 | TrainerCallback,
|
44 | 44 | is_wandb_available,
|
45 | 45 | )
|
46 |
| -from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled |
47 | 46 | from transformers.trainer_utils import seed_worker
|
48 | 47 | from transformers.utils import is_datasets_available, is_peft_available, is_rich_available
|
49 | 48 |
|
50 | 49 | from ..data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template
|
51 | 50 | from ..extras.profiling import profiling_context, profiling_decorator
|
52 | 51 | from ..extras.vllm_client import VLLMClient
|
53 | 52 | from ..import_utils import is_liger_kernel_available, is_vllm_available
|
54 |
| -from ..models import create_reference_model, prepare_deepspeed, prepare_fsdp, unwrap_model_for_generation |
| 53 | +from ..models import prepare_deepspeed, prepare_fsdp, unwrap_model_for_generation |
55 | 54 | from ..models.utils import _ForwardRedirection
|
56 | 55 | from .callbacks import SyncRefModelCallback
|
57 | 56 | from .grpo_config import GRPOConfig
|
@@ -557,15 +556,13 @@ def __init__(
|
557 | 556 | if self.beta == 0.0:
|
558 | 557 | # If beta is 0.0, the reference model is not needed
|
559 | 558 | self.ref_model = None
|
560 |
| - elif is_deepspeed_zero3_enabled() or self.is_fsdp_enabled: |
561 |
| - self.ref_model = AutoModelForCausalLM.from_pretrained(model_id, **model_init_kwargs) |
562 | 559 | elif is_peft_model(model):
|
563 | 560 | # If PEFT is used, the reference model is not needed since the adapter can be disabled
|
564 | 561 | # to revert to the initial model.
|
565 | 562 | self.ref_model = None
|
566 | 563 | else:
|
567 |
| - # If PEFT configuration is not provided, create a reference model based on the initial model. |
568 |
| - self.ref_model = create_reference_model(model) |
| 564 | + # For deepspeed, fsdp or non-distributed models, create a reference model from scratch |
| 565 | + self.ref_model = AutoModelForCausalLM.from_pretrained(model_id, **model_init_kwargs) |
569 | 566 |
|
570 | 567 | # Disable dropout in the models
|
571 | 568 | if args.disable_dropout:
|
|
0 commit comments