|
12 | 12 | # See the License for the specific language governing permissions and
|
13 | 13 | # limitations under the License.
|
14 | 14 | import warnings
|
| 15 | +from functools import wraps |
15 | 16 | from typing import Callable, Dict, List, Optional, Tuple, Union
|
16 | 17 |
|
17 | 18 | import torch
|
|
30 | 31 |
|
31 | 32 | from ..core import PPODecorators
|
32 | 33 | from ..import_utils import is_peft_available
|
| 34 | +from .utils import trl_sanitze_kwargs_for_tagging |
33 | 35 |
|
34 | 36 |
|
35 | 37 | if is_peft_available():
|
@@ -58,6 +60,8 @@ class IterativeSFTTrainer(Trainer):
|
58 | 60 | **optimize_device_cache ** (`bool`, *optional*, defaults to `False`) -- Optimize CUDA cache for slightly more memory-efficient training.
|
59 | 61 | """
|
60 | 62 |
|
| 63 | + _tag_names = ["trl", "iterative-sft"] |
| 64 | + |
61 | 65 | def __init__(
|
62 | 66 | self,
|
63 | 67 | model: Optional[PreTrainedModel] = None,
|
@@ -365,3 +369,18 @@ def _maybe_log_save_evaluate(self):
|
365 | 369 | self._globalstep_last_logged = self.state.global_step
|
366 | 370 |
|
367 | 371 | self.log(logs)
|
| 372 | + |
| 373 | + @wraps(Trainer.push_to_hub) |
| 374 | + def push_to_hub( |
| 375 | + self, |
| 376 | + commit_message: Optional[str] = "End of training", |
| 377 | + blocking: bool = True, |
| 378 | + token: Optional[str] = None, |
| 379 | + **kwargs, |
| 380 | + ) -> str: |
| 381 | +        """ |
| 382 | +        Overwrite the `push_to_hub` method in order to force-add the tag "iterative-sft" when pushing the |
| 383 | +        model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details. |
| 384 | +        """ |
| 385 | + kwargs = trl_sanitze_kwargs_for_tagging(model=self.model, tag_names=self._tag_names, kwargs=kwargs) |
| 386 | + return super().push_to_hub(commit_message=commit_message, blocking=blocking, token=token, **kwargs) |
0 commit comments