Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 108 additions & 0 deletions tests/unit/rl_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import unittest
import pytest
from unittest import mock
from types import SimpleNamespace

evaluate_rl = pytest.importorskip(
Expand Down Expand Up @@ -330,5 +331,112 @@ def test_without_hash(self):
self.assertIsNone(utils_rl.extract_hash_answer(""))


class TestCheckAnswer(unittest.TestCase):
    """Tests for utils_rl.check_answer.

    Every test scores a single prompt/completion/answer triple and compares
    the returned one-element list with the matching reward or penalty
    attribute of the config created in setUp.
    """

    def setUp(self):
        # Rebuilt before every test; test_dataset_specific_normalization
        # overwrites dataset_name on this object.
        self.config = _make_config()

    @pytest.mark.cpu_only
    def test_exact_match(self):
        """Test when the guess exactly matches the answer."""
        scores = utils_rl.check_answer(
            prompts=[""],
            completions=["<reasoning>r</reasoning><answer>42</answer>"],
            answer=["42"],
            tmvp_config=self.config,
        )
        self.assertEqual(scores, [self.config.reward_exact_format_match])

    @pytest.mark.cpu_only
    def test_whitespace_match(self):
        """Test when the guess matches the answer after stripping whitespace."""
        scores = utils_rl.check_answer(
            prompts=[""],
            # The guess is padded with spaces inside the answer tags.
            completions=["<reasoning>r</reasoning><answer> 42 </answer>"],
            answer=["42"],
            tmvp_config=self.config,
        )
        self.assertEqual(scores, [self.config.reward_white_space_format_match])

    @pytest.mark.cpu_only
    def test_ratio_high_match(self):
        """Test when the guess is within the high ratio threshold (0.9 to 1.1)."""
        scores = utils_rl.check_answer(
            prompts=[""],
            # guess / answer = 10.5 / 10 = 1.05
            completions=["<reasoning>r</reasoning><answer>10.5</answer>"],
            answer=["10"],
            tmvp_config=self.config,
        )
        self.assertEqual(scores, [self.config.reward_ratio_guess_to_answer_high])

    @pytest.mark.cpu_only
    def test_ratio_low_match(self):
        """Test when the guess is within the low ratio threshold (0.8 to 1.2)."""
        scores = utils_rl.check_answer(
            prompts=[""],
            # guess / answer = 11.5 / 10 = 1.15
            completions=["<reasoning>r</reasoning><answer>11.5</answer>"],
            answer=["10"],
            tmvp_config=self.config,
        )
        self.assertEqual(scores, [self.config.reward_ratio_guess_to_answer_low])

    @pytest.mark.cpu_only
    def test_incorrect_answer(self):
        """Test when the guess is outside the acceptable ratio thresholds."""
        scores = utils_rl.check_answer(
            prompts=[""],
            # guess / answer = 15 / 10 = 1.5
            completions=["<reasoning>r</reasoning><answer>15</answer>"],
            answer=["10"],
            tmvp_config=self.config,
        )
        self.assertEqual(scores, [self.config.penalty_incorrect_answer])

    @pytest.mark.cpu_only
    def test_no_format_match(self):
        """Test when the completion does not match the expected format."""
        scores = utils_rl.check_answer(
            prompts=[""],
            completions=["Just some random text without the right tags."],
            answer=["42"],
            tmvp_config=self.config,
        )
        self.assertEqual(scores, [0])

    @pytest.mark.cpu_only
    @mock.patch.object(utils_rl, "math_verify_func")
    @mock.patch.object(utils_rl, "parse")
    def test_dataset_specific_normalization(self, mock_parse, mock_math_verify):
        """Test that specific datasets trigger normalize_final_answer.

        Patch decorators are applied bottom-up, so the mock for ``parse`` is
        the first injected argument and the one for ``math_verify_func`` is
        the second.
        """
        # Mock math_verify and parse to fail, so we only rely on exact string
        # match after normalization.
        mock_math_verify.return_value = [0.0]
        mock_parse.side_effect = Exception("parse failed")

        complex_guess = r"x=$\text{\textbf{\overline{\boxed{\frac12 \sqrt3}}}}$"
        expected_normalized = r"\frac{1}{2} \sqrt{3}"
        # The same completion is scored under every dataset name below.
        completion = f"<reasoning>r</reasoning><answer>{complex_guess}</answer>"

        # These datasets should trigger normalization and match exactly.
        for dataset in ["DAPO-Math-17k", "OpenMathInstruct-2"]:
            # subTest reports which dataset failed and lets the remaining
            # datasets still run instead of aborting on the first failure.
            with self.subTest(dataset=dataset):
                self.config.dataset_name = dataset
                scores = utils_rl.check_answer(
                    prompts=[""],
                    completions=[completion],
                    answer=[expected_normalized],
                    tmvp_config=self.config,
                )
                self.assertEqual(scores, [self.config.reward_exact_format_match])

        # Other datasets should not normalize, leading to a mismatch/parse error.
        self.config.dataset_name = "gsm8k"
        scores = utils_rl.check_answer(
            prompts=[""],
            completions=[completion],
            answer=[expected_normalized],
            tmvp_config=self.config,
        )
        self.assertEqual(scores, [self.config.penalty_incorrect_format])


# Run this module's tests when the file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
Loading