
Commit da209f8

qgallouedec, jue-jue-zi, singing-cat, and kashif authored
🎁 RewardTrainer refactor (#4093)
Co-authored-by: juejuezi <[email protected]>
Co-authored-by: Yi Shi <[email protected]>
Co-authored-by: Kashif Rasul <[email protected]>
1 parent ebb8899 commit da209f8

Showing 19 changed files with 1,966 additions and 513 deletions.

README.md

Lines changed: 2 additions & 12 deletions
````diff
@@ -136,23 +136,13 @@ trainer.train()
 Here is a basic example of how to use the [`RewardTrainer`](https://huggingface.co/docs/trl/reward_trainer):
 
 ```python
-from trl import RewardConfig, RewardTrainer
+from trl import RewardTrainer
 from datasets import load_dataset
-from transformers import AutoModelForSequenceClassification, AutoTokenizer
-
-tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
-model = AutoModelForSequenceClassification.from_pretrained(
-    "Qwen/Qwen2.5-0.5B-Instruct", num_labels=1
-)
-model.config.pad_token_id = tokenizer.pad_token_id
 
 dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
 
-training_args = RewardConfig(output_dir="Qwen2.5-0.5B-Reward", per_device_train_batch_size=2)
 trainer = RewardTrainer(
-    args=training_args,
-    model=model,
-    processing_class=tokenizer,
+    model="Qwen/Qwen2.5-0.5B-Instruct",
     train_dataset=dataset,
 )
 trainer.train()
````
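Applied in full, the simplified README example reads:

```python
from trl import RewardTrainer
from datasets import load_dataset

dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")

trainer = RewardTrainer(
    model="Qwen/Qwen2.5-0.5B-Instruct",
    train_dataset=dataset,
)
trainer.train()
```

The tokenizer, `AutoModelForSequenceClassification` head, and explicit `RewardConfig` from the old snippet are presumably handled by `RewardTrainer` defaults when a model name string is passed.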

docs/source/clis.md

Lines changed: 97 additions & 1 deletion
````diff
@@ -9,6 +9,7 @@ Currently supported commands are:
 - `trl dpo`: fine-tune a LLM with DPO
 - `trl grpo`: fine-tune a LLM with GRPO
 - `trl kto`: fine-tune a LLM with KTO
+- `trl reward`: train a Reward Model
 - `trl rloo`: fine-tune a LLM with RLOO
 - `trl sft`: fine-tune a LLM with SFT
 
@@ -41,6 +42,15 @@ trl dpo \
     --dataset_name anthropic/hh-rlhf
 ```
 
+</hfoption>
+<hfoption id="Reward">
+
+```bash
+trl reward \
+    --model_name_or_path Qwen/Qwen2.5-0.5B \
+    --dataset_name trl-lib/ultrafeedback_binarized
+```
+
 </hfoption>
 </hfoptions>
 
@@ -78,6 +88,21 @@ Launch with:
 trl dpo --config dpo_config.yaml
 ```
 
+</hfoption>
+<hfoption id="Reward">
+
+```yaml
+# reward_config.yaml
+model_name_or_path: Qwen/Qwen2.5-0.5B
+dataset_name: trl-lib/ultrafeedback_binarized
+```
+
+Launch with:
+
+```bash
+trl reward --config reward_config.yaml
+```
+
 </hfoption>
 </hfoptions>
 
@@ -138,6 +163,33 @@ Launch with:
 ```bash
 trl dpo --config dpo_config.yaml
 ```
+
+</hfoption>
+<hfoption id="Reward inline">
+
+```bash
+trl reward \
+    --model_name_or_path Qwen/Qwen2.5-0.5B \
+    --dataset_name trl-lib/ultrafeedback_binarized \
+    --num_processes 4
+```
+
+</hfoption>
+<hfoption id="Reward w/ config file">
+
+```yaml
+# reward_config.yaml
+model_name_or_path: Qwen/Qwen2.5-0.5B
+dataset_name: trl-lib/ultrafeedback_binarized
+num_processes: 4
+```
+
+Launch with:
+
+```bash
+trl reward --config reward_config.yaml
+```
+
 </hfoption>
 </hfoptions>
 
@@ -217,14 +269,41 @@ Launch with:
 ```bash
 trl dpo --config dpo_config.yaml
 ```
+
+</hfoption>
+<hfoption id="Reward inline">
+
+```bash
+trl reward \
+    --model_name_or_path Qwen/Qwen2.5-0.5B \
+    --dataset_name trl-lib/ultrafeedback_binarized \
+    --accelerate_config zero2  # or path/to/my/accelerate/config.yaml
+```
+
+</hfoption>
+<hfoption id="Reward w/ config file">
+
+```yaml
+# reward_config.yaml
+model_name_or_path: Qwen/Qwen2.5-0.5B
+dataset_name: trl-lib/ultrafeedback_binarized
+accelerate_config: zero2  # or path/to/my/accelerate/config.yaml
+```
+
+Launch with:
+
+```bash
+trl reward --config reward_config.yaml
+```
+
 </hfoption>
 </hfoptions>
 
 ### Using dataset mixtures
 
 You can use dataset mixtures to combine multiple datasets into a single training dataset. This is useful for training on diverse data sources or when you want to mix different types of data.
 
-<hfoptions id="accelerate_config">
+<hfoptions id="dataset_mixtures">
 <hfoption id="SFT">
 
 ```yaml
@@ -258,6 +337,23 @@ Launch with:
 trl dpo --config dpo_config.yaml
 ```
 
+</hfoption>
+<hfoption id="Reward">
+
+```yaml
+# reward_config.yaml
+model_name_or_path: Qwen/Qwen2.5-0.5B
+datasets:
+  - path: trl-lib/tldr-preference
+  - path: trl-lib/lm-human-preferences-sentiment
+```
+
+Launch with:
+
+```bash
+trl reward --config reward_config.yaml
+```
+
 </hfoption>
 </hfoptions>
````

docs/source/paper_index.md

Lines changed: 50 additions & 0 deletions
````diff
@@ -533,3 +533,53 @@ training_args = CPOConfig(
     ...
 )
 ```
+
+## Reward Modeling
+
+Papers relating to the [`RewardTrainer`]
+
+### Helping or Herding? Reward Model Ensembles Mitigate but do not Eliminate Reward Hacking
+
+**📜 Paper**: https://huggingface.co/papers/2312.09244
+
+This paper proposed an auxiliary loss function designed to directly learn a centered reward model. This auxiliary loss minimizes the squared sum of the rewards, encouraging the model to naturally produce mean-zero outputs and thereby resolving the issue of underdetermination.
+
+$$
+\mathcal{L}(\theta) = - \mathbb{E}_{(x,y^+,y^-) \sim \mathcal{D}} \left[ \log \sigma(r_\theta(x, y^+) - r_\theta(x, y^-)) \textcolor{red}{- \eta \cdot (r_\theta(x, y^+) + r_\theta(x, y^-))^2} \right].
+$$
+
+To use this auxiliary loss with [`RewardTrainer`], you can use the `center_rewards_coefficient` argument in [`RewardConfig`] as follows:
+
+```python
+from trl import RewardConfig
+
+training_args = RewardConfig(
+    center_rewards_coefficient=0.01,  # η in the paper
+    ...
+)
+```
+
````
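For intuition, the centered loss above can be written out directly. The following is an illustrative sketch of the formula, not the actual `RewardTrainer` implementation; the function name and tensor arguments are hypothetical.

```python
import torch.nn.functional as F

def centered_pairwise_loss(chosen_rewards, rejected_rewards, center_rewards_coefficient=0.01):
    # Bradley-Terry term: -log sigmoid(r(x, y+) - r(x, y-))
    bt_term = -F.logsigmoid(chosen_rewards - rejected_rewards)
    # Auxiliary centering term: eta * (r(x, y+) + r(x, y-))^2 keeps rewards near zero mean
    center_term = center_rewards_coefficient * (chosen_rewards + rejected_rewards) ** 2
    return (bt_term + center_term).mean()
```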
````diff
+### Llama 2: Open Foundation and Fine-Tuned Chat Models
+
+**📜 Paper**: https://huggingface.co/papers/2307.09288
+
+In this paper, the authors propose to leverage the fact that their preference ratings are decomposed on a four-point scale (e.g., _significantly better_) to provide more informative feedback to the reward model. This is done by adding a margin to the loss function, which encourages the reward model to assign larger gaps in scores for pairs with higher preference ratings.
+
+$$
+\mathcal{L}(\theta) = - \mathbb{E}_{(x,y^+,y^-,\textcolor{red}{m}) \sim \mathcal{D}} \left[ \log \sigma(r_\theta(x, y^+) - r_\theta(x, y^-) \textcolor{red}{- m}) \right].
+$$
+
+You can add a margin to the loss by adding a `margin` column to the dataset. The following example shows how to set up the "Margin Small" setting of the paper.
+
+```python
+def add_margin(example):
+    preference_to_margin = {
+        "significantly better": 1.0,
+        "better": 2.0/3.0,
+        "slightly better": 1.0/3.0,
+        "negligibly better / unsure": 0.0,
+    }
+    return {"margin": preference_to_margin[example["preference_label"]]}
+
+dataset = dataset.map(add_margin)
+```
````
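Written out in the same illustrative style, the margin simply shifts the reward gap before the sigmoid. Again, this is a sketch of the formula rather than TRL's internal code, and the function name is hypothetical.

```python
import torch.nn.functional as F

def margin_pairwise_loss(chosen_rewards, rejected_rewards, margin):
    # -log sigmoid(r(x, y+) - r(x, y-) - m): a larger margin m demands a larger reward gap
    return -F.logsigmoid(chosen_rewards - rejected_rewards - margin).mean()
```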

docs/source/quickstart.md

Lines changed: 20 additions & 1 deletion
````diff
@@ -1,6 +1,6 @@
 # Quickstart
 
-TRL is a comprehensive library for post-training foundation models using techniques like Supervised Fine-Tuning (SFT), Group Relative Policy Optimization (GRPO), Direct Preference Optimization (DPO).
+TRL is a comprehensive library for post-training foundation models using techniques like Supervised Fine-Tuning (SFT), Group Relative Policy Optimization (GRPO), Direct Preference Optimization (DPO).
 
 ## Quick Examples
 
@@ -51,6 +51,21 @@ trainer = DPOTrainer(
 trainer.train()
 ```
 
+### Reward Modeling
+
+```python
+from trl import RewardTrainer
+from datasets import load_dataset
+
+dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
+
+trainer = RewardTrainer(
+    model="Qwen/Qwen2.5-0.5B-Instruct",
+    train_dataset=dataset,
+)
+trainer.train()
+```
+
 ## Command Line Interface
 
 Skip the code entirely - train directly from your terminal:
@@ -63,6 +78,10 @@ trl sft --model_name_or_path Qwen/Qwen2.5-0.5B \
 # DPO: Align with preferences
 trl dpo --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \
     --dataset_name trl-lib/ultrafeedback_binarized
+
+# Reward: Train a reward model
+trl reward --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \
+    --dataset_name trl-lib/ultrafeedback_binarized
 ```
 
 ## What's Next?
````
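Once trained, the reward model is a standard sequence-classification checkpoint (the pre-refactor README built it with `AutoModelForSequenceClassification` and `num_labels=1`), so scoring a completion should look roughly like the sketch below. The checkpoint path is hypothetical and the saved format is an assumption, not documented behavior in this commit.

```python
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Hypothetical path to a checkpoint saved by RewardTrainer
checkpoint = "Qwen2.5-0.5B-Reward"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=1)

prompt = "What color is the sky?"
completion = "The sky is blue."
inputs = tokenizer(prompt + completion, return_tensors="pt")
with torch.no_grad():
    reward = model(**inputs).logits[0, 0].item()  # scalar score for this prompt-completion pair
print(reward)
```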
