Skip to content

Commit 0be6323

Browse files
authored
[tiny] Remove use_liger argument (#1779)
1 parent 1a6534d commit 0be6323

File tree

2 files changed

+20
-6
lines changed

2 files changed

+20
-6
lines changed

src/oumi/core/configs/training_config.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -132,13 +132,13 @@ def __post_init__(self):
132132

133133
# Set Liger kernel flags if using a HF trainer, and if so, don't do Liger
134134
# patch ourselves.
135-
# TODO(OPE-1117): Clean up this logic after upgrading to trl 0.16.
136135
if self.model.enable_liger_kernel:
137-
if trainer_type == TrainerType.TRL_SFT:
138-
self.training.trainer_kwargs["use_liger"] = True
139-
self.training.trainer_kwargs["use_liger_kernel"] = True
140-
self.model.enable_liger_kernel = False
141-
elif trainer_type in (TrainerType.TRL_DPO, TrainerType.HF):
136+
if trainer_type in (
137+
TrainerType.TRL_SFT,
138+
TrainerType.TRL_DPO,
139+
TrainerType.TRL_GRPO,
140+
TrainerType.HF,
141+
):
142142
self.training.trainer_kwargs["use_liger_kernel"] = True
143143
self.model.enable_liger_kernel = False
144144
elif trainer_type == TrainerType.OUMI:

tests/e2e/test_train_e2e.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -446,6 +446,20 @@ def test_train_text_1gpu_24gb(
446446
save_steps=5,
447447
is_lora=True,
448448
),
449+
TrainTestConfig(
450+
test_name="train_text_llama3_1_8b_trl_sft_full",
451+
config_path=(
452+
get_configs_dir()
453+
/ "recipes"
454+
/ "llama3_1"
455+
/ "sft"
456+
/ "8b_full"
457+
/ "train.yaml"
458+
),
459+
trainer_type=TrainerType.TRL_SFT,
460+
max_steps=5,
461+
save_steps=5,
462+
),
449463
],
450464
ids=get_train_test_id_fn,
451465
)

0 commit comments

Comments
 (0)