
Commit d1bf560

Authored by vaelev, kashif, sergiopaniego, and qgallouedec
⚖️ Add vLLM server mode and VLM support to OnlineDPOTrainer (#3783)
Co-authored-by: Kashif Rasul <[email protected]>
Co-authored-by: Sergio Paniego Blanco <[email protected]>
Co-authored-by: Quentin Gallouédec <[email protected]>
Co-authored-by: Quentin Gallouédec <[email protected]>
1 parent 19f9b9e commit d1bf560

18 files changed (+1739 −273 lines)

docs/source/example_overview.md

Lines changed: 2 additions & 1 deletion
@@ -40,7 +40,6 @@ Scripts can be used as examples of how to use TRL trainers. They are located in
 | [`examples/scripts/cpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/cpo.py) | This script shows how to use the [`CPOTrainer`] to fine-tune a model to increase helpfulness and harmlessness using the [Anthropic/hh-rlhf](https://huggingface.co/datasets/Anthropic/hh-rlhf) dataset. |
 | [`examples/scripts/ddpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ddpo.py) | This script shows how to use the [`DDPOTrainer`] to fine-tune a stable diffusion model using reinforcement learning. |
 | [`trl/scripts/dpo.py`](https://github.com/huggingface/trl/blob/main/trl/scripts/dpo.py) | This script shows how to use the [`DPOTrainer`] to fine-tune a model. |
-| [`examples/scripts/dpo_online.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/dpo_online.py) | This script shows how to use the [`OnlineDPOTrainer`] to fine-tune a model. |
 | [`examples/scripts/dpo_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/dpo_vlm.py) | This script shows how to use the [`DPOTrainer`] to fine-tune a Vision Language Model to reduce hallucinations using the [openbmb/RLAIF-V-Dataset](https://huggingface.co/datasets/openbmb/RLAIF-V-Dataset) dataset. |
 | [`examples/scripts/evals/judge_tldr.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/evals/judge_tldr.py) | This script shows how to use [`HfPairwiseJudge`] or [`OpenAIPairwiseJudge`] to judge model generations. |
 | [`examples/scripts/gkd.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/gkd.py) | This script shows how to use the [`GKDTrainer`] to fine-tune a model. |
@@ -51,6 +50,8 @@ Scripts can be used as examples of how to use TRL trainers. They are located in
 | [`examples/scripts/kto.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/kto.py) | This script shows how to use the [`KTOTrainer`] to fine-tune a model. |
 | [`examples/scripts/mpo_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/mpo_vlm.py) | This script shows how to use MPO via the [`DPOTrainer`] to align a model based on preferences using the [HuggingFaceH4/rlaif-v_formatted](https://huggingface.co/datasets/HuggingFaceH4/rlaif-v_formatted) dataset and a set of loss weights. |
 | [`examples/scripts/nash_md.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/nash_md.py) | This script shows how to use the [`NashMDTrainer`] to fine-tune a model. |
+| [`examples/scripts/online_dpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/online_dpo.py) | This script shows how to use the [`OnlineDPOTrainer`] to fine-tune a model. |
+| [`examples/scripts/online_dpo_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/online_dpo_vlm.py) | This script shows how to use the [`OnlineDPOTrainer`] to fine-tune a Vision Language Model. |
 | [`examples/scripts/orpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/orpo.py) | This script shows how to use the [`ORPOTrainer`] to fine-tune a model to increase helpfulness and harmlessness using the [Anthropic/hh-rlhf](https://huggingface.co/datasets/Anthropic/hh-rlhf) dataset. |
 | [`examples/scripts/ppo/ppo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ppo/ppo.py) | This script shows how to use the [`PPOTrainer`] to fine-tune a model to improve its ability to continue text with positive sentiment or physically descriptive language. |
 | [`examples/scripts/ppo/ppo_tldr.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ppo/ppo_tldr.py) | This script shows how to use the [`PPOTrainer`] to fine-tune a model to improve its ability to generate TL;DR summaries. |

docs/source/nash_md_trainer.md

Lines changed: 1 addition & 1 deletion
@@ -81,7 +81,7 @@ Instead of a judge, you can chose to use a reward model -- see [Reward Bench](ht
 trainer = NashMDTrainer(
     ...
 -   judge=judge,
-+   reward_model=reward_model,
++   reward_funcs=reward_model,
 )
 ```

docs/source/online_dpo_trainer.md

Lines changed: 1 addition & 1 deletion
@@ -84,7 +84,7 @@ Instead of a judge, you can chose to use a reward model -- see [Reward Bench](ht
 trainer = OnlineDPOTrainer(
     ...
 -   judge=judge,
-+   reward_model=reward_model,
++   reward_funcs=reward_model,
 +   reward_processing_class=reward_tokenizer,
     ...
 )
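Put together, the documented Online DPO setup after this change looks roughly like the minimal sketch below. It reuses the SFT and reward checkpoints from examples/scripts/online_dpo.py further down; only the keyword names (`reward_funcs`, `reward_processing_classes`) come from this commit, the surrounding boilerplate is illustrative rather than a canonical recipe.

# Minimal sketch of the renamed keyword arguments (illustrative, not part of this commit).
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer
from trl import OnlineDPOConfig, OnlineDPOTrainer

model = AutoModelForCausalLM.from_pretrained("trl-lib/pythia-1b-deduped-tldr-sft")
tokenizer = AutoTokenizer.from_pretrained("trl-lib/pythia-1b-deduped-tldr-sft")
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

reward_model = AutoModelForSequenceClassification.from_pretrained("trl-lib/pythia-1b-deduped-tldr-rm")
reward_tokenizer = AutoTokenizer.from_pretrained("trl-lib/pythia-1b-deduped-tldr-rm")

dataset = load_dataset("trl-lib/tldr", split="train")

training_args = OnlineDPOConfig(output_dir="pythia-1b-tldr-online-dpo", missing_eos_penalty=1.0)
trainer = OnlineDPOTrainer(
    model=model,
    reward_funcs=reward_model,                   # was: reward_model=reward_model
    args=training_args,
    train_dataset=dataset,
    processing_class=tokenizer,
    reward_processing_classes=reward_tokenizer,  # was: reward_processing_class=reward_tokenizer
)
trainer.train()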

docs/source/xpo_trainer.md

Lines changed: 1 addition & 1 deletion
@@ -80,7 +80,7 @@ Instead of a judge, you can chose to use a reward model -- see [Reward Bench](ht
 trainer = XPOTrainer(
     ...
 -   judge=judge,
-+   reward_model=reward_model,
++   reward_funcs=reward_model,
 )
 ```

examples/scripts/nash_md.py

Lines changed: 1 addition & 1 deletion
@@ -134,7 +134,7 @@
     trainer = NashMDTrainer(
         model=model,
         ref_model=ref_model,
-        reward_model=reward_model,
+        reward_funcs=reward_model,
         judge=judge,
         args=training_args,
         train_dataset=dataset[script_args.dataset_train_split],

examples/scripts/dpo_online.py renamed to examples/scripts/online_dpo.py

Lines changed: 4 additions & 4 deletions
@@ -24,7 +24,7 @@
 """
 Usage:

-python examples/scripts/dpo_online.py \
+python examples/scripts/online_dpo.py \
     --model_name_or_path trl-lib/pythia-1b-deduped-tldr-sft \
     --reward_model_path trl-lib/pythia-1b-deduped-tldr-rm \
     --dataset_name trl-lib/tldr \
@@ -36,7 +36,7 @@
     --missing_eos_penalty 1.0

 With LoRA:
-python examples/scripts/dpo_online.py \
+python examples/scripts/online_dpo.py \
     --model_name_or_path trl-lib/pythia-1b-deduped-tldr-sft \
     --reward_model_path trl-lib/pythia-1b-deduped-tldr-rm \
     --dataset_name trl-lib/tldr \
@@ -138,13 +138,13 @@

     trainer = OnlineDPOTrainer(
         model=model,
-        reward_model=reward_model,
+        reward_funcs=reward_model,
         judge=judge,
         args=training_args,
         train_dataset=dataset[script_args.dataset_train_split],
         eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
         processing_class=tokenizer,
-        reward_processing_class=reward_tokenizer,
+        reward_processing_classes=reward_tokenizer,
         peft_config=get_peft_config(model_args),
     )
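The headline feature of this commit, vLLM server mode, is driven by the `--use_vllm` and `--vllm_mode server` flags shown in examples/scripts/online_dpo_vlm.py below. In config form that corresponds roughly to the sketch here: the two vLLM fields mirror those CLI flags, the other values are illustrative, and it assumes a vLLM server has already been launched separately (for example with TRL's `trl vllm-serve --model <model_name>` helper).

# Rough sketch: enable vLLM-backed generation in server mode for OnlineDPOTrainer.
# `use_vllm` and `vllm_mode` mirror the CLI flags used by online_dpo_vlm.py; the rest
# is an illustrative assumption, not taken from this commit.
from trl import OnlineDPOConfig

training_args = OnlineDPOConfig(
    output_dir="online-dpo-vllm-server",
    use_vllm=True,       # generate completions with vLLM instead of transformers' generate()
    vllm_mode="server",  # query an externally launched vLLM server rather than colocating one
    max_new_tokens=1024,
)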

examples/scripts/online_dpo_vlm.py

Lines changed: 267 additions & 0 deletions
@@ -0,0 +1,267 @@
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# /// script
# dependencies = [
#     "trl @ git+https://github.com/huggingface/trl.git",
#     "peft",
#     "math-verify",
#     "latex2sympy2_extended",
#     "trackio",
#     "torchvision",
#     "kernels",
# ]
# ///

"""
pip install math_verify

# For Qwen/Qwen2.5-VL-3B-Instruct
accelerate launch \
    --config_file examples/accelerate_configs/deepspeed_zero3.yaml \
    examples/scripts/online_dpo_vlm.py \
    --model_name_or_path Qwen/Qwen2.5-VL-3B-Instruct \
    --reward_model_path Qwen/Qwen2.5-VL-3B-Instruct \
    --output_dir online-dpo-Qwen2.5-VL-3B-Instruct \
    --learning_rate 1e-5 \
    --gradient_checkpointing \
    --dtype bfloat16 \
    --max_length 1536 \
    --max_new_tokens 1024 \
    --use_vllm \
    --vllm_mode server \
    --use_peft \
    --lora_target_modules "q_proj", "v_proj" \
    --per_device_train_batch_size 1 \
    --gradient_accumulation_steps 2

# For HuggingFaceTB/SmolVLM2-2.2B-Instruct
pip install num2words

accelerate launch \
    --config_file examples/accelerate_configs/deepspeed_zero3.yaml \
    examples/scripts/online_dpo_vlm.py \
    --model_name_or_path HuggingFaceTB/SmolVLM2-2.2B-Instruct \
    --reward_model_path HuggingFaceTB/SmolVLM2-2.2B-Instruct \
    --output_dir online-dpo-SmolVLM2-2.2B-Instruct \
    --learning_rate 1e-5 \
    --dtype bfloat16 \
    --max_length 1536 \
    --max_new_tokens 1024 \
    --use_peft \
    --lora_target_modules "q_proj", "v_proj" \
    --per_device_train_batch_size 1 \
    --gradient_accumulation_steps 2

# Single GPU test command:
python examples/scripts/online_dpo_vlm.py \
    --model_name_or_path HuggingFaceTB/SmolVLM2-2.2B-Instruct \
    --reward_model_path HuggingFaceTB/SmolVLM2-2.2B-Instruct \
    --output_dir online-dpo-SmolVLM2-2.2B-Instruct-test \
    --learning_rate 1e-5 \
    --dtype bfloat16 \
    --max_length 1536 \
    --max_new_tokens 128 \
    --use_peft \
    --lora_target_modules "q_proj", "v_proj" \
    --per_device_train_batch_size 1 \
    --gradient_accumulation_steps 1 \
    --max_steps 2 \
    --logging_steps 1 \
    --trust_remote_code
"""

import os

import torch
import transformers
from datasets import load_dataset
from latex2sympy2_extended import NormalizationConfig
from math_verify import LatexExtractionConfig, parse, verify
from transformers import AutoConfig, AutoProcessor, GenerationConfig

from trl import (
    LogCompletionsCallback,
    ModelConfig,
    OnlineDPOConfig,
    OnlineDPOTrainer,
    ScriptArguments,
    TrlParser,
    get_kbit_device_map,
    get_peft_config,
    get_quantization_config,
)
from trl.rewards import think_format_reward


# Enable logging in a Hugging Face Space
os.environ.setdefault("TRACKIO_SPACE_ID", "trl-trackio")


if __name__ == "__main__":
    parser = TrlParser((ScriptArguments, OnlineDPOConfig, ModelConfig))
    script_args, training_args, model_args = parser.parse_args_and_config()
    training_args.gradient_checkpointing_kwargs = {"use_reentrant": True}

    dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype)
    quantization_config = get_quantization_config(model_args)
    model_kwargs = dict(
        revision=model_args.model_revision,
        attn_implementation=model_args.attn_implementation,
        dtype=dtype,
        use_cache=False if training_args.gradient_checkpointing else True,
        device_map=get_kbit_device_map() if quantization_config is not None else None,
        quantization_config=quantization_config,
    )

    # Load the VLM model using the correct architecture (from the GRPO pattern)
    config = AutoConfig.from_pretrained(model_args.model_name_or_path)
    architecture = getattr(transformers, config.architectures[0])
    model = architecture.from_pretrained(
        model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, **model_kwargs
    )

    # For VLM online DPO, using a reward model is complex because it needs images.
    # Reward functions are used below instead; in production, you'd want a proper
    # text-only reward model or a custom judge.
    reward_model = None
    reward_processor = None

    # Load the processor for the main model
    processor = AutoProcessor.from_pretrained(
        model_args.model_name_or_path,
        trust_remote_code=model_args.trust_remote_code,
    )
    if hasattr(processor, "tokenizer"):
        processor.tokenizer.padding_side = "left"
        if processor.tokenizer.pad_token_id is None:
            processor.tokenizer.pad_token = processor.tokenizer.eos_token

    ################
    # Dataset
    ################
    dataset = load_dataset("lmms-lab/multimodal-open-r1-8k-verified", split="train")
    dataset = dataset.train_test_split(test_size=100, seed=42)

    SYSTEM_PROMPT = (
        "A conversation between user and assistant. The user asks a question, and the assistant solves it. The "
        "assistant first thinks about the reasoning process in the mind and then provides the user with the answer. "
        "The reasoning process and answer are enclosed within <think></think> tags, i.e., <think>\nThis is my "
        "reasoning.\n</think>\nThis is my answer."
    )

    def make_conversation(example):
        # Create the conversational format that OnlineDPOTrainer expects
        prompt = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": example["problem"]},
        ]
        return {"prompt": prompt, "image": example["image"]}

    dataset = dataset.map(make_conversation)

    # Filter out big images (from the GRPO pattern)
    def filter_big_images(example):
        image = example["image"]
        return image.size[0] < 512 and image.size[1] < 512

    dataset = dataset.filter(filter_big_images)

    def convert_to_rgb(example):
        image = example["image"]
        if image.mode != "RGB":
            image = image.convert("RGB")
        example["image"] = image
        return example

    dataset = dataset.map(convert_to_rgb)

    train_dataset = dataset["train"]
    eval_dataset = dataset["test"] if training_args.eval_strategy != "no" else None

    ################
    # Reward Function for Training (same as GRPO VLM)
    ################
    def accuracy_reward(completions, solution: list[str], **kwargs):
        """Reward function that checks if the completion matches the ground truth.
        - If both gold and prediction are parseable → use math verification.
        - If not parseable → compare as normalized text.
        """
        rewards = []
        contents = [completion[0]["content"] for completion in completions]
        for content, sol in zip(contents, solution):
            try:
                gold_parsed = parse(sol, extraction_mode="first_match")
            except Exception:
                gold_parsed = []

            if len(gold_parsed) != 0:
                # Try parsing the predicted answer too
                try:
                    answer_parsed = parse(
                        content,
                        extraction_config=[
                            LatexExtractionConfig(
                                normalization_config=NormalizationConfig(
                                    nits=False,
                                    malformed_operators=False,
                                    basic_latex=True,
                                    boxed="all",
                                    units=True,
                                ),
                                boxed_match_priority=0,
                                try_extract_without_anchor=False,
                            )
                        ],
                        extraction_mode="first_match",
                    )
                    reward = float(verify(gold_parsed, answer_parsed))
                except Exception as e:
                    print(f"verify failed: {e}, answer: {content}, gold: {sol}")
                    reward = None
            else:
                # Fall back to text match
                reward = float(content.strip().lower() == sol.strip().lower())

            rewards.append(reward)

        return rewards

    ################
    # Training
    ################
    trainer = OnlineDPOTrainer(
        model=model,
        reward_funcs=[think_format_reward, accuracy_reward],  # same reward functions as the GRPO VLM example
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        processing_class=processor,
        peft_config=get_peft_config(model_args),
    )

    # Add the completion logging callback (from the online DPO pattern)
    if training_args.eval_strategy != "no":
        generation_config = GenerationConfig(
            max_new_tokens=training_args.max_new_tokens, do_sample=True, temperature=training_args.temperature
        )
        completions_callback = LogCompletionsCallback(trainer, generation_config, num_prompts=8)
        trainer.add_callback(completions_callback)

    trainer.train()

    # Save and push to hub
    trainer.save_model(training_args.output_dir)
    if training_args.push_to_hub:
        trainer.push_to_hub(dataset_name="lmms-lab/multimodal-open-r1-8k-verified")
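For readers adapting this script: custom reward functions passed via `reward_funcs` follow the same calling convention as `accuracy_reward` above, receiving the conversational completions plus any extra dataset columns as keyword arguments and returning one score (or None) per completion. A hypothetical extra reward, not part of this commit:

# Illustration only: penalize overly long completions. Completions arrive as lists of
# {"role", "content"} messages, and the function returns one float per sample.
def brevity_reward(completions, **kwargs):
    contents = [completion[0]["content"] for completion in completions]
    return [1.0 if len(content) <= 1024 else 0.0 for content in contents]

# It would then be appended to the list handed to the trainer, e.g.
# reward_funcs=[think_format_reward, accuracy_reward, brevity_reward]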

examples/scripts/xpo.py

Lines changed: 1 addition & 1 deletion
@@ -119,7 +119,7 @@
     trainer = XPOTrainer(
         model=model,
         ref_model=ref_model,
-        reward_model=reward_model,
+        reward_funcs=reward_model,
         judge=judge,
         args=training_args,
         train_dataset=dataset[script_args.dataset_train_split],
