Commit 45ee98b

Replace unittest with pytest (#4188)
1 parent 3800a6e commit 45ee98b

37 files changed (+1475, -1746 lines)
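The conversion follows the same mechanical pattern across all 37 files: test classes are renamed to the Test* prefix that pytest collects by default, unittest's setUp/tearDown hooks become pytest's setup_method/teardown_method (which need no super() call), and self.assert* helpers become bare assert statements. A minimal before/after sketch of the pattern, using illustrative names rather than code from the diff:

# Before: unittest style
import unittest

class ExampleTester(unittest.TestCase):
    def setUp(self):
        super().setUp()
        self.value = 42

    def tearDown(self):
        super().tearDown()

    def test_value(self):
        self.assertEqual(self.value, 42)
        self.assertIsNotNone(self.value)

# After: pytest style (plain class, xunit-style hooks, bare asserts)
class TestExample:
    def setup_method(self):
        # pytest calls this before each test method; no super() chain needed
        self.value = 42

    def teardown_method(self):
        # pytest calls this after each test method
        pass

    def test_value(self):
        assert self.value == 42
        assert self.value is not None

A bonus of bare assert is pytest's assertion rewriting: failures print the evaluated operands, so the custom messages carried over from assertEqual and friends stay optional.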

tests/slow/test_dpo_slow.py

Lines changed: 9 additions & 11 deletions
@@ -21,12 +21,12 @@
 from datasets import load_dataset
 from parameterized import parameterized
 from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
-from transformers.testing_utils import backend_empty_cache, require_peft, require_torch_accelerator, torch_device
+from transformers.testing_utils import backend_empty_cache, require_torch_accelerator, torch_device
 from transformers.utils import is_peft_available

 from trl import DPOConfig, DPOTrainer

-from ..testing_utils import TrlTestCase, require_bitsandbytes
+from ..testing_utils import TrlTestCase, require_bitsandbytes, require_peft
 from .testing_constants import DPO_LOSS_TYPES, DPO_PRECOMPUTE_LOGITS, GRADIENT_CHECKPOINTING_KWARGS, MODELS_TO_TEST


@@ -37,9 +37,8 @@
 @pytest.mark.slow
 @require_torch_accelerator
 @require_peft
-class DPOTrainerSlowTester(TrlTestCase):
-    def setUp(self):
-        super().setUp()
+class TestDPOTrainerSlow(TrlTestCase):
+    def setup_method(self):
         self.dataset = load_dataset("trl-internal-testing/zen", "standard_preference")
         self.peft_config = LoraConfig(
             lora_alpha=16,
@@ -50,11 +49,10 @@ def setUp(self):
         )
         self.max_length = 128

-    def tearDown(self):
+    def teardown_method(self):
         gc.collect()
         backend_empty_cache(torch_device)
         gc.collect()
-        super().tearDown()

     @parameterized.expand(list(itertools.product(MODELS_TO_TEST, DPO_LOSS_TYPES, DPO_PRECOMPUTE_LOGITS)))
     def test_dpo_bare_model(self, model_id, loss_type, pre_compute_logits):
@@ -151,8 +149,8 @@ def test_dpo_peft_model(self, model_id, loss_type, pre_compute_logits, gradient_
             peft_config=self.peft_config,
         )

-        self.assertIsInstance(trainer.model, PeftModel)
-        self.assertIsNone(trainer.ref_model)
+        assert isinstance(trainer.model, PeftModel)
+        assert trainer.ref_model is None

         # train the model
         trainer.train()
@@ -215,8 +213,8 @@ def test_dpo_peft_model_qlora(self, model_id, loss_type, pre_compute_logits, gra
             peft_config=self.peft_config,
         )

-        self.assertIsInstance(trainer.model, PeftModel)
-        self.assertIsNone(trainer.ref_model)
+        assert isinstance(trainer.model, PeftModel)
+        assert trainer.ref_model is None

         # train the model
         trainer.train()
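Note that require_peft now comes from TRL's own ..testing_utils module rather than transformers.testing_utils, a change repeated in the files below. The helper's definition is not part of this diff; a typical pytest-native requirement marker of this kind would look roughly like the sketch below (hypothetical, not the actual TRL implementation):

import pytest
from transformers.utils import is_peft_available

# Hypothetical sketch of a skip-if-missing marker; the real helper in
# tests/testing_utils.py may differ.
require_peft = pytest.mark.skipif(not is_peft_available(), reason="test requires peft")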

tests/slow/test_grpo_slow.py

Lines changed: 23 additions & 30 deletions
@@ -35,7 +35,6 @@
     backend_empty_cache,
     require_flash_attn,
     require_liger_kernel,
-    require_peft,
     require_torch_accelerator,
     torch_device,
 )
@@ -44,7 +43,7 @@
 from trl import GRPOConfig, GRPOTrainer
 from trl.trainer.utils import get_kbit_device_map

-from ..testing_utils import TrlTestCase, require_bitsandbytes, require_vllm
+from ..testing_utils import TrlTestCase, require_bitsandbytes, require_peft, require_vllm
 from .testing_constants import MODELS_TO_TEST


@@ -54,18 +53,16 @@

 @pytest.mark.slow
 @require_torch_accelerator
-class GRPOTrainerSlowTester(TrlTestCase):
-    def setUp(self):
-        super().setUp()
+class TestGRPOTrainerSlow(TrlTestCase):
+    def setup_method(self):
         self.train_dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
         self.eval_dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="test")
         self.max_length = 128

-    def tearDown(self):
+    def teardown_method(self):
         gc.collect()
         backend_empty_cache(torch_device)
         gc.collect()
-        super().tearDown()

     @parameterized.expand(MODELS_TO_TEST)
     @require_liger_kernel
@@ -103,7 +100,7 @@ def test_training_with_liger_grpo_loss(self, model_name):

         for n, param in previous_trainable_params.items():
             new_param = model.get_parameter(n)
-            self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
+            assert not torch.equal(param, new_param), f"Parameter {n} has not changed."

         release_memory(model, trainer)

@@ -153,20 +150,20 @@ def test_training_with_liger_grpo_loss_and_peft(self, model_name):
         # Verify PEFT adapter is properly initialized
         from peft import PeftModel

-        self.assertTrue(isinstance(trainer.model, PeftModel), "Model should be wrapped with PEFT")
+        assert isinstance(trainer.model, PeftModel), "Model should be wrapped with PEFT"

         # Store adapter weights before training
         previous_trainable_params = {
             n: param.clone() for n, param in trainer.model.named_parameters() if param.requires_grad
         }
-        self.assertTrue(len(previous_trainable_params) > 0, "No trainable parameters found in PEFT model")
+        assert len(previous_trainable_params) > 0, "No trainable parameters found in PEFT model"

         trainer.train()

         # Verify adapter weights have changed after training
         for n, param in previous_trainable_params.items():
             new_param = trainer.model.get_parameter(n)
-            self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
+            assert not torch.equal(param, new_param), f"Parameter {n} has not changed."

         release_memory(model, trainer)

@@ -199,12 +196,12 @@ def test_training_with_transformers_paged(self, model_name):

         trainer.train()

-        self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+        assert trainer.state.log_history[-1]["train_loss"] is not None

         # Check that the params have changed
         for n, param in previous_trainable_params.items():
             new_param = model.get_parameter(n)
-            self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
+            assert not torch.equal(param, new_param), f"Parameter {n} has not changed."

         release_memory(model, trainer)

@@ -310,13 +307,13 @@ def reward_func(prompts, completions, **kwargs):
                 peft_config=lora_config,
             )

-            self.assertIsInstance(trainer.model, PeftModel)
+            assert isinstance(trainer.model, PeftModel)

             previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

             trainer.train()

-            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+            assert trainer.state.log_history[-1]["train_loss"] is not None

             # Check that LoRA parameters have changed
             # For VLM models, we're more permissive about which parameters can change
@@ -328,7 +325,7 @@ def reward_func(prompts, completions, **kwargs):
                         lora_params_changed = True

             # At least some LoRA parameters should have changed during training
-            self.assertTrue(lora_params_changed, "No LoRA parameters were updated during training.")
+            assert lora_params_changed, "No LoRA parameters were updated during training."

         except torch.OutOfMemoryError as e:
             self.skipTest(f"Skipping VLM training test due to insufficient GPU memory: {e}")
@@ -378,8 +375,8 @@ def test_vlm_processor_vllm_colocate_mode(self):
         processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM-Instruct", use_fast=True, padding_side="left")

         # Verify processor has both required attributes for VLM detection
-        self.assertTrue(hasattr(processor, "tokenizer"))
-        self.assertTrue(hasattr(processor, "image_processor"))
+        assert hasattr(processor, "tokenizer")
+        assert hasattr(processor, "image_processor")

         def dummy_reward_func(completions, **kwargs):
             return [1.0] * len(completions)
@@ -438,16 +435,14 @@ def dummy_reward_func(completions, **kwargs):
                 )

                 # Should detect VLM processor correctly and allow vLLM
-                self.assertTrue(trainer.use_vllm, "vLLM should be enabled for VLM processors in colocate mode")
-                self.assertEqual(trainer.vllm_mode, "colocate", "Should use colocate mode")
+                assert trainer.use_vllm, "vLLM should be enabled for VLM processors in colocate mode"
+                assert trainer.vllm_mode == "colocate", "Should use colocate mode"

                 # Check if signature columns were set properly
                 if trainer._signature_columns is not None:
                     # Should include 'image' in signature columns for VLM processors
-                    self.assertIn(
-                        "image",
-                        trainer._signature_columns,
-                        "Should include 'image' in signature columns for VLM",
+                    assert "image" in trainer._signature_columns, (
+                        "Should include 'image' in signature columns for VLM"
                     )

                 # Should not emit any warnings about VLM incompatibility
@@ -457,10 +452,8 @@ def dummy_reward_func(completions, **kwargs):
                     if "does not support VLMs" in str(w_item.message)
                     or "not compatible" in str(w_item.message).lower()
                 ]
-                self.assertEqual(
-                    len(incompatibility_warnings),
-                    0,
-                    f"Should not emit VLM incompatibility warnings, but got: {incompatibility_warnings}",
+                assert len(incompatibility_warnings) == 0, (
+                    f"Should not emit VLM incompatibility warnings, but got: {incompatibility_warnings}"
                 )

                 # Test passes if we get this far without exceptions
@@ -525,12 +518,12 @@ def test_training_vllm(self):

             trainer.train()

-            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+            assert trainer.state.log_history[-1]["train_loss"] is not None

             # Check that the params have changed
             for n, param in previous_trainable_params.items():
                 new_param = trainer.model.get_parameter(n)
-                self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
+                assert not torch.equal(param, new_param), f"Parameter {n} has not changed."

         except Exception as e:
             # If vLLM fails to initialize due to hardware constraints or other issues, that's expected
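The multi-line self.assertIn and self.assertEqual calls above become assert statements whose message is wrapped in parentheses so the statement can span several lines. The parentheses must enclose only the message: parenthesizing the whole condition-plus-message pair would create a two-element tuple, which is always truthy, so the assert could never fail. A self-contained illustration:

warnings_seen: list[str] = []

# Correct: parentheses wrap the message only, letting the assert span lines
assert len(warnings_seen) == 0, (
    f"Should not emit warnings, but got: {warnings_seen}"
)

# Pitfall (never fails): this asserts a non-empty tuple, which is always truthy
# assert (len(warnings_seen) == 0, f"got: {warnings_seen}")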

tests/slow/test_sft_slow.py

Lines changed: 10 additions & 13 deletions
@@ -24,7 +24,6 @@
 from transformers.testing_utils import (
     backend_empty_cache,
     require_liger_kernel,
-    require_peft,
     require_torch_accelerator,
     require_torch_multi_accelerator,
     torch_device,
@@ -33,7 +32,7 @@

 from trl import SFTConfig, SFTTrainer

-from ..testing_utils import TrlTestCase, require_bitsandbytes
+from ..testing_utils import TrlTestCase, require_bitsandbytes, require_peft
 from .testing_constants import DEVICE_MAP_OPTIONS, GRADIENT_CHECKPOINTING_KWARGS, MODELS_TO_TEST, PACKING_OPTIONS


@@ -44,9 +43,8 @@
 @pytest.mark.slow
 @require_torch_accelerator
 @require_peft
-class SFTTrainerSlowTester(TrlTestCase):
-    def setUp(self):
-        super().setUp()
+class TestSFTTrainerSlow(TrlTestCase):
+    def setup_method(self):
         self.train_dataset = load_dataset("stanfordnlp/imdb", split="train[:10%]")
         self.eval_dataset = load_dataset("stanfordnlp/imdb", split="test[:10%]")
         self.max_length = 128
@@ -58,11 +56,10 @@ def setUp(self):
             task_type="CAUSAL_LM",
         )

-    def tearDown(self):
+    def teardown_method(self):
         gc.collect()
         backend_empty_cache(torch_device)
         gc.collect()
-        super().tearDown()

     @parameterized.expand(list(itertools.product(MODELS_TO_TEST, PACKING_OPTIONS)))
     def test_sft_trainer_str(self, model_name, packing):
@@ -148,7 +145,7 @@ def test_sft_trainer_peft(self, model_name, packing):
             peft_config=self.peft_config,
         )

-        self.assertIsInstance(trainer.model, PeftModel)
+        assert isinstance(trainer.model, PeftModel)

         trainer.train()

@@ -252,7 +249,7 @@ def test_sft_trainer_transformers_mp_gc_peft(self, model_name, packing, gradient
             peft_config=self.peft_config,
         )

-        self.assertIsInstance(trainer.model, PeftModel)
+        assert isinstance(trainer.model, PeftModel)

         trainer.train()

@@ -332,7 +329,7 @@ def test_sft_trainer_transformers_mp_gc_peft_qlora(self, model_name, packing, gr
             peft_config=self.peft_config,
         )

-        self.assertIsInstance(trainer.model, PeftModel)
+        assert isinstance(trainer.model, PeftModel)

         trainer.train()

@@ -372,7 +369,7 @@ def test_sft_trainer_with_chat_format_qlora(self, model_name, packing):
             peft_config=self.peft_config,
         )

-        self.assertIsInstance(trainer.model, PeftModel)
+        assert isinstance(trainer.model, PeftModel)

         trainer.train()

@@ -447,11 +444,11 @@ def test_train_offloading(self, model_name, packing):
         trainer.train()

         # Check that the training loss is not None
-        self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+        assert trainer.state.log_history[-1]["train_loss"] is not None

         # Check the params have changed
         for n, param in previous_trainable_params.items():
             new_param = trainer.model.get_parameter(n)
-            self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
+            assert not torch.allclose(param, new_param), f"Parameter {n} has not changed"

         release_memory(trainer.model, trainer)
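These suites keep parameterized.expand for their data-driven cases, which pytest collects on plain classes without trouble. For comparison, the pytest-native way to expand the itertools.product grid above would be pytest.mark.parametrize; a sketch with illustrative values:

import itertools

import pytest

MODELS_TO_TEST = ["model-a", "model-b"]  # illustrative stand-ins
PACKING_OPTIONS = [True, False]

class TestExample:
    @pytest.mark.parametrize(
        "model_name, packing",
        list(itertools.product(MODELS_TO_TEST, PACKING_OPTIONS)),
    )
    def test_combo(self, model_name, packing):
        # runs once per (model_name, packing) pair: four cases here
        assert isinstance(model_name, str)
        assert isinstance(packing, bool)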

tests/test_activation_offloading.py

Lines changed: 7 additions & 8 deletions
@@ -16,12 +16,12 @@
 import torch
 from torch import nn
 from transformers import AutoModelForCausalLM
-from transformers.testing_utils import require_peft, require_torch_accelerator, torch_device
+from transformers.testing_utils import require_torch_accelerator, torch_device
 from transformers.utils import is_peft_available

 from trl.models.activation_offloading import NoOpManager, OffloadActivations

-from .testing_utils import TrlTestCase
+from .testing_utils import TrlTestCase, require_peft


 if is_peft_available():
@@ -72,9 +72,8 @@ def test_offloading_with_peft_models(self) -> None:
         for name_orig, grad_orig in grads_original:
             for name_param, param in model.named_parameters():
                 if name_param == name_orig and param.requires_grad and param.grad is not None:
-                    self.assertTrue(
-                        torch.allclose(grad_orig, param.grad, rtol=1e-4, atol=1e-5),
-                        f"Gradient mismatch for {name_orig}",
+                    assert torch.allclose(grad_orig, param.grad, rtol=1e-4, atol=1e-5), (
+                        f"Gradient mismatch for {name_orig}"
                     )

     @require_torch_accelerator
@@ -105,7 +104,7 @@ def test_noop_manager_with_offloading(self):

         # Gradients should match as NoOpManager should have prevented offloading
         for g1, g2 in zip(grads1, grads2):
-            self.assertTrue(torch.allclose(g1, g2, rtol=1e-4, atol=1e-5))
+            assert torch.allclose(g1, g2, rtol=1e-4, atol=1e-5)

     @require_torch_accelerator
     def test_min_offload_size(self):
@@ -152,6 +151,6 @@ def test_real_hf_model(self):
         grads2 = [p.grad.clone() for p in model.parameters()]

         # Check outputs and gradients match
-        self.assertTrue(torch.allclose(out1, out2, rtol=1e-5))
+        assert torch.allclose(out1, out2, rtol=1e-5)
         for g1, g2 in zip(grads1, grads2):
-            self.assertTrue(torch.allclose(g1, g2, rtol=1e-5))
+            assert torch.allclose(g1, g2, rtol=1e-5)
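The gradient and output checks in this file compare floating-point tensors, so they rely on torch.allclose with explicit tolerances rather than exact equality: activation recomputation and offloading round-trips introduce tiny numerical drift. A small standalone example of why the tolerance matters:

import torch

a = torch.tensor([1.0, 2.0])
b = a + 1e-6  # tiny drift, e.g. from recomputing activations

assert not torch.equal(a, b)  # exact equality fails under drift
assert torch.allclose(a, b, rtol=1e-4, atol=1e-5)  # tolerance check passes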
