diff --git a/tests/test_dpo_trainer.py b/tests/test_dpo_trainer.py index 1d1e94fcf9..1d4860c1c5 100644 --- a/tests/test_dpo_trainer.py +++ b/tests/test_dpo_trainer.py @@ -33,12 +33,18 @@ from transformers.testing_utils import ( get_device_properties, require_liger_kernel, - require_torch_gpu_if_bnb_not_multi_backend_enabled, ) from trl import DPOConfig, DPOTrainer, FDivergenceType -from .testing_utils import TrlTestCase, require_bitsandbytes, require_no_wandb, require_peft, require_vision +from .testing_utils import ( + TrlTestCase, + require_bitsandbytes, + require_no_wandb, + require_peft, + require_torch_gpu_if_bnb_not_multi_backend_enabled, + require_vision, +) if is_vision_available(): diff --git a/tests/test_peft_models.py b/tests/test_peft_models.py index 508ad17556..ff50972834 100644 --- a/tests/test_peft_models.py +++ b/tests/test_peft_models.py @@ -16,12 +16,11 @@ import torch from transformers import AutoModelForCausalLM -from transformers.testing_utils import require_torch_gpu_if_bnb_not_multi_backend_enabled from transformers.utils import is_peft_available from trl import AutoModelForCausalLMWithValueHead -from .testing_utils import TrlTestCase, require_peft +from .testing_utils import TrlTestCase, require_peft, require_torch_gpu_if_bnb_not_multi_backend_enabled if is_peft_available(): diff --git a/tests/testing_utils.py b/tests/testing_utils.py index d012d26881..fb2dec9cff 100644 --- a/tests/testing_utils.py +++ b/tests/testing_utils.py @@ -47,6 +47,21 @@ ) +def is_bitsandbytes_multi_backend_available() -> bool: + if is_bitsandbytes_available(): + import bitsandbytes as bnb + + return "multi_backend" in getattr(bnb, "features", set()) + return False + + +# Function ported from transformers.testing_utils before transformers#41283 +require_torch_gpu_if_bnb_not_multi_backend_enabled = pytest.mark.skipif( + not is_bitsandbytes_multi_backend_available() and not torch_device == "cuda", + reason="test requires bitsandbytes multi-backend enabled or 'cuda' torch device", +) + + class RandomBinaryJudge(BaseBinaryJudge): """ Random binary judge, for testing purposes.