File tree Expand file tree Collapse file tree 3 files changed +9
-9
lines changed Expand file tree Collapse file tree 3 files changed +9
-9
lines changed Original file line number Diff line number Diff line change @@ -1027,7 +1027,7 @@ def test_training_with_mask_truncated_completions_all_masked(self):
10271027 new_param = trainer .model .get_parameter (n )
10281028 assert torch .equal (param , new_param ), f"Parameter { n } has changed."
10291029
1030- def test_warning_raised_all_rewards_none (self ):
1030+ def test_warning_raised_all_rewards_none (self , caplog ):
10311031 """Test that a proper warning is raised when all rewards are None."""
10321032 dataset = load_dataset ("trl-internal-testing/zen" , "standard_prompt_only" , split = "train" )
10331033
@@ -1050,11 +1050,11 @@ def always_none_reward_func(completions, **kwargs):
10501050 train_dataset = dataset ,
10511051 )
10521052
1053- with self . assertLogs ( " trl.trainer.grpo_trainer", level = "WARNING" ) as cm :
1053+ with caplog . at_level ( "WARNING" , logger = " trl.trainer.grpo_trainer") :
10541054 trainer .train ()
10551055
10561056 expected_warning = "All reward functions returned None for the following kwargs:"
1057- assert expected_warning in cm . output [ 0 ]
1057+ assert expected_warning in caplog . text
10581058
10591059 def test_training_num_generations_larger_than_batch_size (self ):
10601060 dataset = load_dataset ("trl-internal-testing/zen" , "standard_prompt_only" , split = "train" )
Original file line number Diff line number Diff line change @@ -882,7 +882,7 @@ def test_training_with_mask_truncated_completions_all_masked(self):
882882 new_param = trainer .model .get_parameter (n )
883883 assert torch .equal (param , new_param ), f"Parameter { n } has changed."
884884
885- def test_warning_raised_all_rewards_none (self ):
885+ def test_warning_raised_all_rewards_none (self , caplog ):
886886 """Test that a proper warning is raised when all rewards are None."""
887887 dataset = load_dataset ("trl-internal-testing/zen" , "standard_prompt_only" , split = "train" )
888888
@@ -905,11 +905,11 @@ def always_none_reward_func(completions, **kwargs):
905905 train_dataset = dataset ,
906906 )
907907
908- with self . assertLogs ( " trl.trainer.rloo_trainer", level = "WARNING" ) as cm :
908+ with caplog . at_level ( "WARNING" , logger = " trl.trainer.rloo_trainer") :
909909 trainer .train ()
910910
911911 expected_warning = "All reward functions returned None for the following kwargs:"
912- assert expected_warning in cm . output [ 0 ]
912+ assert expected_warning in caplog . text
913913
914914 def test_training_num_generations_larger_than_batch_size (self ):
915915 dataset = load_dataset ("trl-internal-testing/zen" , "standard_prompt_only" , split = "train" )
Original file line number Diff line number Diff line change @@ -422,7 +422,7 @@ def test_token_classification_task_with_ignored_tokens_1(self):
422422 result = compute_accuracy (eval_pred )
423423 assert round (abs (result ["accuracy" ] - expected_accuracy ), 7 ) == 0
424424
425- def test_rewards_comparison_task (self ):
425+ def test_rewards_comparison_task (self , caplog ):
426426 eval_pred = (
427427 np .array (
428428 [
@@ -435,15 +435,15 @@ def test_rewards_comparison_task(self):
435435 )
436436 expected_accuracy = 0.5 # 1 match, 1 mismatch, 1 equal (ignored)
437437
438- with self . assertLogs ( " trl.trainer.utils", level = "WARNING" ) as cm :
438+ with caplog . at_level ( "WARNING" , logger = " trl.trainer.utils") :
439439 result = compute_accuracy (eval_pred )
440440
441441 assert round (abs (result ["accuracy" ] - expected_accuracy ), 7 ) == 0
442442 expected_warning = (
443443 "There are 1 out of 3 instances where the predictions for both options are equal. "
444444 "These instances are ignored in the accuracy computation."
445445 )
446- assert expected_warning in cm . output [ 0 ]
446+ assert expected_warning in caplog . text
447447
448448
449449class TestFlushLeft (TrlTestCase ):
You can’t perform that action at this time.
0 commit comments