
Commit 54d4f6b

🎁 Reward submodule (#3430)
1 parent 05bc43e commit 54d4f6b

File tree: 6 files changed, +198 −4 lines changed


docs/source/_toctree.yml

Lines changed: 2 additions & 0 deletions

@@ -107,6 +107,8 @@
     title: Callbacks
   - local: data_utils
     title: Data Utilities
+  - local: rewards
+    title: Reward Functions
   - local: script_utils
     title: Script Utilities
   - local: others

docs/source/rewards.md

Lines changed: 9 additions & 0 deletions

@@ -0,0 +1,9 @@
+# Reward Functions
+
+This module contains some useful reward functions, primarily intended for use with the [`GRPOTrainer`].
+
+## Format rewards
+
+### think_format_reward
+
+[[autodoc]] rewards.think_format_reward
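For orientation, here is a minimal sketch of wiring this reward into [`GRPOTrainer`]; the model id and dataset below are illustrative and not taken from this commit:

```python
from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer
from trl.rewards import think_format_reward

# Any prompt dataset works here; "trl-lib/tldr" is just an example.
dataset = load_dataset("trl-lib/tldr", split="train")

trainer = GRPOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",      # illustrative model id
    reward_funcs=think_format_reward,      # reward function from this new module
    args=GRPOConfig(output_dir="grpo-think-format"),
    train_dataset=dataset,
)
trainer.train()
```

`reward_funcs` also accepts a list, so a format reward like this one can be combined with a task-specific reward function or a reward model.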

tests/test_rewards.py

Lines changed: 65 additions & 0 deletions

@@ -0,0 +1,65 @@
+# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+from trl.rewards import think_format_reward
+
+
+class ThinkFormatRewardTester(unittest.TestCase):
+    def test_valid_format(self):
+        completions = [
+            "<think>This is my reasoning.</think>This is my answer.",  # Simple, one-line reasoning
+            "<think>\nThis is my reasoning.\n</think>\nThis is my answer.",  # Multiline reasoning
+            "<think>\nThis is\nmy reasoning.\n</think>\nThis is my answer.",  # Multiline reasoning
+            "<think>\nThis is <some tag> my reasoning.</think>\nThis is my answer.",  # Reasoning including other tags
+            "<think></think>\nThis is my answer.",  # Empty reasoning
+        ]
+        completions = [[{"content": completion}] for completion in completions]
+        expected_rewards = [1.0, 1.0, 1.0, 1.0, 1.0]  # All should be valid
+        rewards = think_format_reward(completions)
+        self.assertEqual(rewards, expected_rewards)
+
+    def test_invalid_format(self):
+        completions = [
+            "<think>\nThis is my reasoning.\nThis is my answer.",  # No closing </think>
+            "<think>This is my reasoning.\nThis is my answer.",  # No closing </think>
+            "This is my reasoning. This is my answer.",  # No <think> tags
+            "This is my reasoning.\nThis is my answer.",  # No <think> tags
+            "This is my reasoning.</think>\nThis is my answer.",  # No opening <think>
+            "This is my reasoning.</think>This is my answer.",  # No opening <think>
+            "This<think>is my reasoning.</think>\nThis is my answer.",  # <think> tag in the middle
+            "<think>This is<think>my reasoning.</think></think>This is my answer.",  # Nested <think> tags
+            "<think>This is</think>\nmy\n<think>reasoning.</think>\nThis is my answer.",  # Multiline <think>
+        ]
+        completions = [[{"content": completion}] for completion in completions]
+        expected_rewards = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]  # All should be invalid
+        rewards = think_format_reward(completions)
+        self.assertEqual(rewards, expected_rewards)
+
+    def test_mixed_format(self):
+        completions = [
+            "<think>This is my reasoning.</think>This is my answer.",  # Valid
+            "<think>\nThis is my reasoning.\n</think>\nThis is my answer.",  # Valid
+            "<think>This is my reasoning.\nThis is my answer.",  # Invalid
+            "This is my reasoning. This is my answer.",  # Invalid
+        ]
+        completions = [[{"content": completion}] for completion in completions]
+        expected_rewards = [1.0, 1.0, 0.0, 0.0]
+        rewards = think_format_reward(completions)
+        self.assertEqual(rewards, expected_rewards)
+
+
+if __name__ == "__main__":
+    unittest.main()

trl/rewards/__init__.py

Lines changed: 32 additions & 0 deletions

@@ -0,0 +1,32 @@
+# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import sys
+from typing import TYPE_CHECKING
+
+from ..import_utils import _LazyModule
+
+
+_import_structure = {
+    "format_rewards": ["think_format_reward"],
+}
+
+
+if TYPE_CHECKING:
+    from .format_rewards import think_format_reward
+
+
+else:
+    sys.modules[__name__] = _LazyModule(__name__, __file__, _import_structure, module_spec=__spec__)
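The lazy-module registration above keeps `import trl.rewards` cheap: the `format_rewards` submodule is only imported the first time one of its names is accessed. A small sketch of the resulting behavior (assuming a TRL install that includes this commit):

```python
import trl.rewards

# Attribute access triggers the actual import of trl.rewards.format_rewards
# through the _LazyModule object registered in sys.modules.
reward_fn = trl.rewards.think_format_reward

print(reward_fn([[{"content": "<think>ok</think>done"}]]))  # -> [1.0]
```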

trl/rewards/format_rewards.py

Lines changed: 49 additions & 0 deletions

@@ -0,0 +1,49 @@
+# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import re
+
+
+def think_format_reward(completions: list[list[dict[str, str]]], **kwargs) -> list[float]:
+    r"""
+    Reward function that checks if the reasoning process is enclosed within `"<think>"` and `"</think>"` tags. The
+    function returns a reward of 1.0 if the format is correct, otherwise 0.0.
+
+    Args:
+        completions (`list[list[dict[str, str]]]`):
+            List of completions to be evaluated. Each completion must be a list of one message, i.e. a dictionary
+            containing the key `"content"` with the value being the text of the completion.
+        **kwargs:
+            Additional keyword arguments. This function does not use them, but they are required in the function
+            signature to ensure compatibility with trainers like [`GRPOTrainer`].
+
+    Returns:
+        `list[float]`:
+            A list of rewards, where each reward is 1.0 if the completion matches the expected format, otherwise 0.0.
+
+    Example:
+    ```python
+    >>> from trl.rewards import think_format_reward
+    >>> completions = [
+    ...     [{"content": "<think>\nThis is my reasoning.\n</think>\nThis is my answer."}],
+    ...     [{"content": "<think>\nThis is my reasoning.\nThis is my answer."}],
+    ... ]
+    >>> think_format_reward(completions)
+    [1.0, 0.0]
+    ```
+    """
+    pattern = r"^<think>(?!.*<think>)(.*?)</think>.*$"
+    completion_contents = [completion[0]["content"] for completion in completions]
+    matches = [re.match(pattern, content, re.DOTALL | re.MULTILINE) for content in completion_contents]
+    return [1.0 if match else 0.0 for match in matches]
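To make the pattern's behavior concrete, the same check can be reproduced standalone; this snippet is an illustrative sketch, not part of the commit:

```python
import re

# Same pattern as think_format_reward: the completion must start with <think>,
# the negative lookahead (?!.*<think>) rejects any second opening tag, and the
# lazy (.*?) captures everything up to the first closing </think>.
pattern = r"^<think>(?!.*<think>)(.*?)</think>.*$"


def check(text: str) -> float:
    return 1.0 if re.match(pattern, text, re.DOTALL | re.MULTILINE) else 0.0


print(check("<think>reasoning</think>answer"))          # 1.0: well-formed
print(check("<think></think>answer"))                   # 1.0: empty reasoning is allowed
print(check("<think>a<think>b</think></think>answer"))  # 0.0: nested/repeated <think>
print(check("reasoning</think>answer"))                 # 0.0: no opening tag
```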

trl/scripts/grpo.py

Lines changed: 41 additions & 4 deletions

@@ -13,13 +13,20 @@
 # limitations under the License.
 
 import argparse
+import importlib
 from dataclasses import dataclass, field
 from typing import Optional
 
 from datasets import load_dataset
 from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer
 
 from trl import GRPOConfig, GRPOTrainer, ModelConfig, ScriptArguments, TrlParser, get_peft_config
+from trl.rewards import think_format_reward
+
+
+reward_funcs_registry = {
+    "think_format_reward": think_format_reward,
+}
 
 
 @dataclass
@@ -28,9 +35,12 @@ class GRPOScriptArguments(ScriptArguments):
     Script arguments for the GRPO training script.
 
     Args:
-        reward_model_name_or_path (`str` or `None`):
+        reward_model_name_or_path (`str` or `None`, *optional*, defaults to `None`):
             Reward model id of a pretrained model hosted inside a model repo on huggingface.co or local path to a
             directory containing model weights saved using [`~transformers.PreTrainedModel.save_pretrained`].
+        reward_funcs (`list[str]` or `None`, *optional*, defaults to `None`):
+            Reward functions to use. It can be either one of `"think_format_reward"`; or a dotted import path
+            (e.g., `'my_lib.rewards.custom_reward'`).
     """
 
     reward_model_name_or_path: Optional[str] = field(
@@ -40,6 +50,13 @@ class GRPOScriptArguments(ScriptArguments):
             "local path to a directory containing model weights saved using `PreTrainedModel.save_pretrained`."
         },
     )
+    reward_funcs: Optional[list[str]] = field(
+        default=None,
+        metadata={
+            "help": "Reward functions to use. It can be either one of 'think_format_reward'; or a dotted "
+            "import path. (e.g., 'my_lib.rewards.custom_reward')."
+        },
+    )
 
 
 def main(script_args, training_args, model_args):
@@ -50,9 +67,29 @@ def main(script_args, training_args, model_args):
     tokenizer = AutoTokenizer.from_pretrained(
         model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
     )
-    reward_model = AutoModelForSequenceClassification.from_pretrained(
-        script_args.reward_model_name_or_path, trust_remote_code=model_args.trust_remote_code, num_labels=1
-    )
+
+    # Get the reward models and functions
+    reward_funcs = []
+    if script_args.reward_model_name_or_path:
+        reward_model = AutoModelForSequenceClassification.from_pretrained(
+            script_args.reward_model_name_or_path, trust_remote_code=model_args.trust_remote_code, num_labels=1
+        )
+        reward_funcs.append(reward_model)
+
+    if script_args.reward_funcs:
+        for func_name in script_args.reward_funcs:
+            if func_name in reward_funcs_registry:
+                reward_funcs.append(reward_funcs_registry[func_name])
+            elif "." in func_name:
+                module_path, func_name = func_name.rsplit(".", 1)
+                module = importlib.import_module(module_path)
+                reward_func = getattr(module, func_name)
+                reward_funcs.append(reward_func)
+            else:
+                raise ValueError(
+                    f"Could not load reward function '{func_name}'. Expected one of "
+                    f"{list(reward_funcs_registry.keys())} or a valid import path."
+                )
 
     # Load the dataset
     dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
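A reward function referenced by dotted import path only has to follow the same `(completions, **kwargs) -> list[float]` convention as `think_format_reward`. A hypothetical `my_lib/rewards.py` (module and function names are illustrative only) might look like:

```python
def custom_reward(completions, **kwargs):
    """Toy reward sketch: favor shorter answers, clipped to the range [0, 1]."""
    contents = [completion[0]["content"] for completion in completions]
    return [max(0.0, 1.0 - len(content) / 1000.0) for content in contents]
```

Such a function could then be selected through the new `reward_funcs` script argument (e.g. `--reward_funcs think_format_reward my_lib.rewards.custom_reward`), on its own or alongside a reward model given via `--reward_model_name_or_path`.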
