File tree Expand file tree Collapse file tree 2 files changed +20
-6
lines changed Expand file tree Collapse file tree 2 files changed +20
-6
lines changed Original file line number Diff line number Diff line change @@ -132,13 +132,13 @@ def __post_init__(self):
132
132
133
133
# Set Liger kernel flags if using a HF trainer, and if so, don't do Liger
134
134
# patch ourselves.
135
- # TODO(OPE-1117): Clean up this logic after upgrading to trl 0.16.
136
135
if self .model .enable_liger_kernel :
137
- if trainer_type == TrainerType .TRL_SFT :
138
- self .training .trainer_kwargs ["use_liger" ] = True
139
- self .training .trainer_kwargs ["use_liger_kernel" ] = True
140
- self .model .enable_liger_kernel = False
141
- elif trainer_type in (TrainerType .TRL_DPO , TrainerType .HF ):
136
+ if trainer_type in (
137
+ TrainerType .TRL_SFT ,
138
+ TrainerType .TRL_DPO ,
139
+ TrainerType .TRL_GRPO ,
140
+ TrainerType .HF ,
141
+ ):
142
142
self .training .trainer_kwargs ["use_liger_kernel" ] = True
143
143
self .model .enable_liger_kernel = False
144
144
elif trainer_type == TrainerType .OUMI :
Original file line number Diff line number Diff line change @@ -446,6 +446,20 @@ def test_train_text_1gpu_24gb(
446
446
save_steps = 5 ,
447
447
is_lora = True ,
448
448
),
449
+ TrainTestConfig (
450
+ test_name = "train_text_llama3_1_8b_trl_sft_full" ,
451
+ config_path = (
452
+ get_configs_dir ()
453
+ / "recipes"
454
+ / "llama3_1"
455
+ / "sft"
456
+ / "8b_full"
457
+ / "train.yaml"
458
+ ),
459
+ trainer_type = TrainerType .TRL_SFT ,
460
+ max_steps = 5 ,
461
+ save_steps = 5 ,
462
+ ),
449
463
],
450
464
ids = get_train_test_id_fn ,
451
465
)
You can’t perform that action at this time.
0 commit comments