Skip to content

Commit 309e76d

Browse files
Replace assertLogs
1 parent 01d1262 commit 309e76d

File tree

3 files changed

+9
-9
lines changed

3 files changed

+9
-9
lines changed

tests/test_grpo_trainer.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1027,7 +1027,7 @@ def test_training_with_mask_truncated_completions_all_masked(self):
10271027
new_param = trainer.model.get_parameter(n)
10281028
assert torch.equal(param, new_param), f"Parameter {n} has changed."
10291029

1030-
def test_warning_raised_all_rewards_none(self):
1030+
def test_warning_raised_all_rewards_none(self, caplog):
10311031
"""Test that a proper warning is raised when all rewards are None."""
10321032
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
10331033

@@ -1050,11 +1050,11 @@ def always_none_reward_func(completions, **kwargs):
10501050
train_dataset=dataset,
10511051
)
10521052

1053-
with self.assertLogs("trl.trainer.grpo_trainer", level="WARNING") as cm:
1053+
with caplog.at_level("WARNING", logger="trl.trainer.grpo_trainer"):
10541054
trainer.train()
10551055

10561056
expected_warning = "All reward functions returned None for the following kwargs:"
1057-
assert expected_warning in cm.output[0]
1057+
assert expected_warning in caplog.text
10581058

10591059
def test_training_num_generations_larger_than_batch_size(self):
10601060
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

tests/test_rloo_trainer.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -882,7 +882,7 @@ def test_training_with_mask_truncated_completions_all_masked(self):
882882
new_param = trainer.model.get_parameter(n)
883883
assert torch.equal(param, new_param), f"Parameter {n} has changed."
884884

885-
def test_warning_raised_all_rewards_none(self):
885+
def test_warning_raised_all_rewards_none(self, caplog):
886886
"""Test that a proper warning is raised when all rewards are None."""
887887
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
888888

@@ -905,11 +905,11 @@ def always_none_reward_func(completions, **kwargs):
905905
train_dataset=dataset,
906906
)
907907

908-
with self.assertLogs("trl.trainer.rloo_trainer", level="WARNING") as cm:
908+
with caplog.at_level("WARNING", logger="trl.trainer.rloo_trainer"):
909909
trainer.train()
910910

911911
expected_warning = "All reward functions returned None for the following kwargs:"
912-
assert expected_warning in cm.output[0]
912+
assert expected_warning in caplog.text
913913

914914
def test_training_num_generations_larger_than_batch_size(self):
915915
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

tests/test_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -422,7 +422,7 @@ def test_token_classification_task_with_ignored_tokens_1(self):
422422
result = compute_accuracy(eval_pred)
423423
assert round(abs(result["accuracy"] - expected_accuracy), 7) == 0
424424

425-
def test_rewards_comparison_task(self):
425+
def test_rewards_comparison_task(self, caplog):
426426
eval_pred = (
427427
np.array(
428428
[
@@ -435,15 +435,15 @@ def test_rewards_comparison_task(self):
435435
)
436436
expected_accuracy = 0.5 # 1 match, 1 mismatch, 1 equal (ignored)
437437

438-
with self.assertLogs("trl.trainer.utils", level="WARNING") as cm:
438+
with caplog.at_level("WARNING", logger="trl.trainer.utils"):
439439
result = compute_accuracy(eval_pred)
440440

441441
assert round(abs(result["accuracy"] - expected_accuracy), 7) == 0
442442
expected_warning = (
443443
"There are 1 out of 3 instances where the predictions for both options are equal. "
444444
"These instances are ignored in the accuracy computation."
445445
)
446-
assert expected_warning in cm.output[0]
446+
assert expected_warning in caplog.text
447447

448448

449449
class TestFlushLeft(TrlTestCase):

0 commit comments

Comments
 (0)