Skip to content

Commit a01d9d7

Browse files
committed
properly tag models when pushed to hub
1 parent bbdef00 commit a01d9d7

11 files changed

+141
-19
lines changed

trl/trainer/bco_trainer.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1428,11 +1428,16 @@ def log(self, logs: Dict[str, float]) -> None:
14281428
return super().log(logs)
14291429

14301430
@wraps(Trainer.push_to_hub)
1431-
def push_to_hub(self, commit_message: Optional[str] = "End of training", blocking: bool = True, **kwargs) -> str:
1431+
def push_to_hub(
1432+
self,
1433+
commit_message: Optional[str] = "End of training",
1434+
blocking: bool = True,
1435+
token: Optional[str] = None,
1436+
**kwargs,
1437+
) -> str:
14321438
"""
14331439
Overwrite the `push_to_hub` method in order to force-add the tag "bco" when pushing the
14341440
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
14351441
"""
14361442
kwargs = trl_sanitze_kwargs_for_tagging(model=self.model, tag_names=self._tag_names, kwargs=kwargs)
1437-
1438-
return super().push_to_hub(commit_message=commit_message, blocking=blocking, **kwargs)
1443+
return super().push_to_hub(commit_message=commit_message, blocking=blocking, token=token, **kwargs)

trl/trainer/cpo_trainer.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -966,11 +966,16 @@ def _shift_right(self, input_ids):
966966
return shifted_input_ids
967967

968968
@wraps(Trainer.push_to_hub)
969-
def push_to_hub(self, commit_message: Optional[str] = "End of training", blocking: bool = True, **kwargs) -> str:
969+
def push_to_hub(
970+
self,
971+
commit_message: Optional[str] = "End of training",
972+
blocking: bool = True,
973+
token: Optional[str] = None,
974+
**kwargs,
975+
) -> str:
970976
"""
971977
Overwrite the `push_to_hub` method in order to force-add the tag "cpo" when pushing the
972978
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
973979
"""
974980
kwargs = trl_sanitze_kwargs_for_tagging(model=self.model, tag_names=self._tag_names, kwargs=kwargs)
975-
976-
return super().push_to_hub(commit_message=commit_message, blocking=blocking, **kwargs)
981+
return super().push_to_hub(commit_message=commit_message, blocking=blocking, token=token, **kwargs)

trl/trainer/dpo_trainer.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1611,11 +1611,16 @@ def log(self, logs: Dict[str, float]) -> None:
16111611
return super().log(logs)
16121612

16131613
@wraps(Trainer.push_to_hub)
1614-
def push_to_hub(self, commit_message: Optional[str] = "End of training", blocking: bool = True, **kwargs) -> str:
1614+
def push_to_hub(
1615+
self,
1616+
commit_message: Optional[str] = "End of training",
1617+
blocking: bool = True,
1618+
token: Optional[str] = None,
1619+
**kwargs,
1620+
) -> str:
16151621
"""
16161622
Overwrite the `push_to_hub` method in order to force-add the tag "dpo" when pushing the
16171623
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
16181624
"""
16191625
kwargs = trl_sanitze_kwargs_for_tagging(model=self.model, tag_names=self._tag_names, kwargs=kwargs)
1620-
1621-
return super().push_to_hub(commit_message=commit_message, blocking=blocking, **kwargs)
1626+
return super().push_to_hub(commit_message=commit_message, blocking=blocking, token=token, **kwargs)

trl/trainer/iterative_sft_trainer.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
import warnings
15+
from functools import wraps
1516
from typing import Callable, Dict, List, Optional, Tuple, Union
1617

1718
import torch
@@ -30,6 +31,7 @@
3031

3132
from ..core import PPODecorators
3233
from ..import_utils import is_peft_available
34+
from .utils import trl_sanitze_kwargs_for_tagging
3335

3436

3537
if is_peft_available():
@@ -58,6 +60,8 @@ class IterativeSFTTrainer(Trainer):
5860
**optimize_device_cache ** (`bool`, *optional*, defaults to `False`) -- Optimize CUDA cache for slightly more memory-efficient training.
5961
"""
6062

63+
_tag_names = ["trl", "iterative-sft"]
64+
6165
def __init__(
6266
self,
6367
model: Optional[PreTrainedModel] = None,
@@ -365,3 +369,18 @@ def _maybe_log_save_evaluate(self):
365369
self._globalstep_last_logged = self.state.global_step
366370

367371
self.log(logs)
372+
373+
@wraps(Trainer.push_to_hub)
374+
def push_to_hub(
375+
self,
376+
commit_message: Optional[str] = "End of training",
377+
blocking: bool = True,
378+
token: Optional[str] = None,
379+
**kwargs,
380+
) -> str:
381+
"""
382+
Overwrite the `push_to_hub` method in order to force-add the tag "dpo" when pushing the
383+
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
384+
"""
385+
kwargs = trl_sanitze_kwargs_for_tagging(model=self.model, tag_names=self._tag_names, kwargs=kwargs)
386+
return super().push_to_hub(commit_message=commit_message, blocking=blocking, token=token, **kwargs)

trl/trainer/kto_trainer.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1381,11 +1381,16 @@ def log(self, logs: Dict[str, float]) -> None:
13811381
return super().log(logs)
13821382

13831383
@wraps(Trainer.push_to_hub)
1384-
def push_to_hub(self, commit_message: Optional[str] = "End of training", blocking: bool = True, **kwargs) -> str:
1384+
def push_to_hub(
1385+
self,
1386+
commit_message: Optional[str] = "End of training",
1387+
blocking: bool = True,
1388+
token: Optional[str] = None,
1389+
**kwargs,
1390+
) -> str:
13851391
"""
13861392
Overwrite the `push_to_hub` method in order to force-add the tag "kto" when pushing the
13871393
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
13881394
"""
13891395
kwargs = trl_sanitze_kwargs_for_tagging(model=self.model, tag_names=self._tag_names, kwargs=kwargs)
1390-
1391-
return super().push_to_hub(commit_message=commit_message, blocking=blocking, **kwargs)
1396+
return super().push_to_hub(commit_message=commit_message, blocking=blocking, token=token, **kwargs)

trl/trainer/online_dpo_trainer.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import os
44
import time
55
from collections import defaultdict
6+
from functools import wraps
67
from typing import Dict, List, Optional, Tuple, Union
78

89
import numpy as np
@@ -40,12 +41,15 @@
4041
truncate_response,
4142
)
4243
from .online_dpo_config import OnlineDPOConfig
44+
from .utils import trl_sanitze_kwargs_for_tagging
4345

4446

4547
INVALID_LOGPROB = 1.0
4648

4749

4850
class OnlineDPOTrainer(Trainer):
51+
_tag_names = ["trl", "online-dpo"]
52+
4953
def __init__(
5054
self,
5155
config: OnlineDPOConfig,
@@ -570,3 +574,18 @@ def generate_completions(self, sampling: bool = False):
570574

571575
if wandb.run is not None:
572576
wandb.log({"completions": wandb.Table(dataframe=df)})
577+
578+
@wraps(Trainer.push_to_hub)
579+
def push_to_hub(
580+
self,
581+
commit_message: Optional[str] = "End of training",
582+
blocking: bool = True,
583+
token: Optional[str] = None,
584+
**kwargs,
585+
) -> str:
586+
"""
587+
Overwrite the `push_to_hub` method in order to force-add the tag "online-dpo" when pushing the
588+
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
589+
"""
590+
kwargs = trl_sanitze_kwargs_for_tagging(model=self.model, tag_names=self._tag_names, kwargs=kwargs)
591+
return super().push_to_hub(commit_message=commit_message, blocking=blocking, token=token, **kwargs)

trl/trainer/orpo_trainer.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -971,11 +971,16 @@ def _shift_right(self, input_ids):
971971
return shifted_input_ids
972972

973973
@wraps(Trainer.push_to_hub)
974-
def push_to_hub(self, commit_message: Optional[str] = "End of training", blocking: bool = True, **kwargs) -> str:
974+
def push_to_hub(
975+
self,
976+
commit_message: Optional[str] = "End of training",
977+
blocking: bool = True,
978+
token: Optional[str] = None,
979+
**kwargs,
980+
) -> str:
975981
"""
976982
Overwrite the `push_to_hub` method in order to force-add the tag "orpo" when pushing the
977983
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
978984
"""
979985
kwargs = trl_sanitze_kwargs_for_tagging(model=self.model, tag_names=self._tag_names, kwargs=kwargs)
980-
981-
return super().push_to_hub(commit_message=commit_message, blocking=blocking, **kwargs)
986+
return super().push_to_hub(commit_message=commit_message, blocking=blocking, token=token, **kwargs)

trl/trainer/ppov2_trainer.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import os
44
import time
55
from collections import defaultdict
6+
from functools import wraps
67
from typing import Dict, List, Optional, Tuple, Union
78

89
import numpy as np
@@ -41,6 +42,7 @@
4142
truncate_response,
4243
)
4344
from .ppov2_config import PPOv2Config
45+
from .utils import trl_sanitze_kwargs_for_tagging
4446

4547

4648
INVALID_LOGPROB = 1.0
@@ -64,6 +66,8 @@ def forward(self, **kwargs):
6466

6567

6668
class PPOv2Trainer(Trainer):
69+
_tag_names = ["trl", "ppo"]
70+
6771
def __init__(
6872
self,
6973
config: PPOv2Config,
@@ -607,3 +611,18 @@ def generate_completions(self, sampling: bool = False):
607611

608612
if wandb.run is not None:
609613
wandb.log({"completions": wandb.Table(dataframe=df)})
614+
615+
@wraps(Trainer.push_to_hub)
616+
def push_to_hub(
617+
self,
618+
commit_message: Optional[str] = "End of training",
619+
blocking: bool = True,
620+
token: Optional[str] = None,
621+
**kwargs,
622+
) -> str:
623+
"""
624+
Overwrite the `push_to_hub` method in order to force-add the tag "ppo" when pushing the
625+
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
626+
"""
627+
kwargs = trl_sanitze_kwargs_for_tagging(model=self.model, tag_names=self._tag_names, kwargs=kwargs)
628+
return super().push_to_hub(commit_message=commit_message, blocking=blocking, token=token, **kwargs)

trl/trainer/reward_trainer.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import warnings
1616
from collections import defaultdict
1717
from dataclasses import FrozenInstanceError, replace
18+
from functools import wraps
1819
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
1920

2021
import pandas as pd
@@ -29,7 +30,7 @@
2930

3031
from ..import_utils import is_peft_available
3132
from .reward_config import RewardConfig
32-
from .utils import RewardDataCollatorWithPadding, compute_accuracy, print_rich_table
33+
from .utils import RewardDataCollatorWithPadding, compute_accuracy, print_rich_table, trl_sanitze_kwargs_for_tagging
3334

3435

3536
if is_peft_available():
@@ -317,3 +318,18 @@ def visualize_samples(self, num_print_samples: int):
317318

318319
if wandb.run is not None:
319320
wandb.log({"completions": wandb.Table(dataframe=df)})
321+
322+
@wraps(Trainer.push_to_hub)
323+
def push_to_hub(
324+
self,
325+
commit_message: Optional[str] = "End of training",
326+
blocking: bool = True,
327+
token: Optional[str] = None,
328+
**kwargs,
329+
) -> str:
330+
"""
331+
Overwrite the `push_to_hub` method in order to force-add the tag "reward-trainer" when pushing the
332+
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
333+
"""
334+
kwargs = trl_sanitze_kwargs_for_tagging(model=self.model, tag_names=self._tag_names, kwargs=kwargs)
335+
return super().push_to_hub(commit_message=commit_message, blocking=blocking, token=token, **kwargs)

trl/trainer/rloo_trainer.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import os
44
import time
55
from collections import defaultdict
6+
from functools import wraps
67
from typing import Dict, List, Optional, Tuple, Union
78

89
import numpy as np
@@ -40,12 +41,15 @@
4041
truncate_response,
4142
)
4243
from .rloo_config import RLOOConfig
44+
from .utils import trl_sanitze_kwargs_for_tagging
4345

4446

4547
INVALID_LOGPROB = 1.0
4648

4749

4850
class RLOOTrainer(Trainer):
51+
_tag_names = ["trl", "rloo"]
52+
4953
def __init__(
5054
self,
5155
config: RLOOConfig,
@@ -505,3 +509,18 @@ def generate_completions(self, sampling: bool = False):
505509

506510
if wandb.run is not None:
507511
wandb.log({"completions": wandb.Table(dataframe=df)})
512+
513+
@wraps(Trainer.push_to_hub)
514+
def push_to_hub(
515+
self,
516+
commit_message: Optional[str] = "End of training",
517+
blocking: bool = True,
518+
token: Optional[str] = None,
519+
**kwargs,
520+
) -> str:
521+
"""
522+
Overwrite the `push_to_hub` method in order to force-add the tag "rloo" when pushing the
523+
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
524+
"""
525+
kwargs = trl_sanitze_kwargs_for_tagging(model=self.model, tag_names=self._tag_names, kwargs=kwargs)
526+
return super().push_to_hub(commit_message=commit_message, blocking=blocking, token=token, **kwargs)

0 commit comments

Comments
 (0)