    backend_empty_cache,
    require_flash_attn,
    require_liger_kernel,
-    require_peft,
    require_torch_accelerator,
    torch_device,
)

from trl import GRPOConfig, GRPOTrainer
from trl.trainer.utils import get_kbit_device_map

-from ..testing_utils import TrlTestCase, require_bitsandbytes, require_vllm
+from ..testing_utils import TrlTestCase, require_bitsandbytes, require_peft, require_vllm
from .testing_constants import MODELS_TO_TEST


@pytest.mark.slow
@require_torch_accelerator
-class GRPOTrainerSlowTester(TrlTestCase):
-    def setUp(self):
-        super().setUp()
+class TestGRPOTrainerSlow(TrlTestCase):
+    def setup_method(self):
        self.train_dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
        self.eval_dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="test")
        self.max_length = 128

-    def tearDown(self):
+    def teardown_method(self):
        gc.collect()
        backend_empty_cache(torch_device)
        gc.collect()
-        super().tearDown()

    @parameterized.expand(MODELS_TO_TEST)
    @require_liger_kernel
@@ -103,7 +100,7 @@ def test_training_with_liger_grpo_loss(self, model_name):

        for n, param in previous_trainable_params.items():
            new_param = model.get_parameter(n)
-            self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
+            assert not torch.equal(param, new_param), f"Parameter {n} has not changed."

        release_memory(model, trainer)

@@ -153,20 +150,20 @@ def test_training_with_liger_grpo_loss_and_peft(self, model_name):
        # Verify PEFT adapter is properly initialized
        from peft import PeftModel

-        self.assertTrue(isinstance(trainer.model, PeftModel), "Model should be wrapped with PEFT")
+        assert isinstance(trainer.model, PeftModel), "Model should be wrapped with PEFT"

        # Store adapter weights before training
        previous_trainable_params = {
            n: param.clone() for n, param in trainer.model.named_parameters() if param.requires_grad
        }
-        self.assertTrue(len(previous_trainable_params) > 0, "No trainable parameters found in PEFT model")
+        assert len(previous_trainable_params) > 0, "No trainable parameters found in PEFT model"

        trainer.train()

        # Verify adapter weights have changed after training
        for n, param in previous_trainable_params.items():
            new_param = trainer.model.get_parameter(n)
-            self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
+            assert not torch.equal(param, new_param), f"Parameter {n} has not changed."

        release_memory(model, trainer)

@@ -199,12 +196,12 @@ def test_training_with_transformers_paged(self, model_name):

        trainer.train()

-        self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+        assert trainer.state.log_history[-1]["train_loss"] is not None

        # Check that the params have changed
        for n, param in previous_trainable_params.items():
            new_param = model.get_parameter(n)
-            self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
+            assert not torch.equal(param, new_param), f"Parameter {n} has not changed."

        release_memory(model, trainer)

@@ -310,13 +307,13 @@ def reward_func(prompts, completions, **kwargs):
                peft_config=lora_config,
            )

-            self.assertIsInstance(trainer.model, PeftModel)
+            assert isinstance(trainer.model, PeftModel)

            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

            trainer.train()

-            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+            assert trainer.state.log_history[-1]["train_loss"] is not None

            # Check that LoRA parameters have changed
            # For VLM models, we're more permissive about which parameters can change
@@ -328,7 +325,7 @@ def reward_func(prompts, completions, **kwargs):
                    lora_params_changed = True

            # At least some LoRA parameters should have changed during training
-            self.assertTrue(lora_params_changed, "No LoRA parameters were updated during training.")
+            assert lora_params_changed, "No LoRA parameters were updated during training."

        except torch.OutOfMemoryError as e:
            self.skipTest(f"Skipping VLM training test due to insufficient GPU memory: {e}")
@@ -378,8 +375,8 @@ def test_vlm_processor_vllm_colocate_mode(self):
        processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM-Instruct", use_fast=True, padding_side="left")

        # Verify processor has both required attributes for VLM detection
-        self.assertTrue(hasattr(processor, "tokenizer"))
-        self.assertTrue(hasattr(processor, "image_processor"))
+        assert hasattr(processor, "tokenizer")
+        assert hasattr(processor, "image_processor")

        def dummy_reward_func(completions, **kwargs):
            return [1.0] * len(completions)
@@ -438,16 +435,14 @@ def dummy_reward_func(completions, **kwargs):
            )

            # Should detect VLM processor correctly and allow vLLM
-            self.assertTrue(trainer.use_vllm, "vLLM should be enabled for VLM processors in colocate mode")
-            self.assertEqual(trainer.vllm_mode, "colocate", "Should use colocate mode")
+            assert trainer.use_vllm, "vLLM should be enabled for VLM processors in colocate mode"
+            assert trainer.vllm_mode == "colocate", "Should use colocate mode"

            # Check if signature columns were set properly
            if trainer._signature_columns is not None:
                # Should include 'image' in signature columns for VLM processors
-                self.assertIn(
-                    "image",
-                    trainer._signature_columns,
-                    "Should include 'image' in signature columns for VLM",
+                assert "image" in trainer._signature_columns, (
+                    "Should include 'image' in signature columns for VLM"
                )

            # Should not emit any warnings about VLM incompatibility
@@ -457,10 +452,8 @@ def dummy_reward_func(completions, **kwargs):
                if "does not support VLMs" in str(w_item.message)
                or "not compatible" in str(w_item.message).lower()
            ]
-            self.assertEqual(
-                len(incompatibility_warnings),
-                0,
-                f"Should not emit VLM incompatibility warnings, but got: {incompatibility_warnings}",
+            assert len(incompatibility_warnings) == 0, (
+                f"Should not emit VLM incompatibility warnings, but got: {incompatibility_warnings}"
            )

            # Test passes if we get this far without exceptions
@@ -525,12 +518,12 @@ def test_training_vllm(self):

            trainer.train()

-            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+            assert trainer.state.log_history[-1]["train_loss"] is not None

            # Check that the params have changed
            for n, param in previous_trainable_params.items():
                new_param = trainer.model.get_parameter(n)
-                self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
+                assert not torch.equal(param, new_param), f"Parameter {n} has not changed."

        except Exception as e:
            # If vLLM fails to initialize due to hardware constraints or other issues, that's expected