# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# /// script
# dependencies = [
#     "trl @ git+https://github.com/huggingface/trl.git",
#     "peft",
#     "math-verify",
#     "latex2sympy2_extended",
#     "trackio",
#     "torchvision",
#     "kernels",
# ]
# ///
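# Note: the block above is PEP 723 inline script metadata. Assuming a recent `uv` is
# installed, the script can also be launched as `uv run examples/scripts/online_dpo_vlm.py ...`
# and the dependencies listed there are resolved automatically.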

"""
pip install math_verify

# For Qwen/Qwen2.5-VL-3B-Instruct
accelerate launch \
    --config_file examples/accelerate_configs/deepspeed_zero3.yaml \
    examples/scripts/online_dpo_vlm.py \
    --model_name_or_path Qwen/Qwen2.5-VL-3B-Instruct \
    --reward_model_path Qwen/Qwen2.5-VL-3B-Instruct \
    --output_dir online-dpo-Qwen2.5-VL-3B-Instruct \
    --learning_rate 1e-5 \
    --gradient_checkpointing \
    --dtype bfloat16 \
    --max_length 1536 \
    --max_new_tokens 1024 \
    --use_vllm \
    --vllm_mode server \
    --use_peft \
    --lora_target_modules q_proj v_proj \
    --per_device_train_batch_size 1 \
    --gradient_accumulation_steps 2
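
# Note (assumption): with `--use_vllm --vllm_mode server`, the trainer expects a vLLM
# server to already be running in a separate process, e.g. started with
# `trl vllm-serve --model Qwen/Qwen2.5-VL-3B-Instruct` on GPUs not used for training.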

# For HuggingFaceTB/SmolVLM2-2.2B-Instruct
pip install num2words

accelerate launch \
    --config_file examples/accelerate_configs/deepspeed_zero3.yaml \
    examples/scripts/online_dpo_vlm.py \
    --model_name_or_path HuggingFaceTB/SmolVLM2-2.2B-Instruct \
    --reward_model_path HuggingFaceTB/SmolVLM2-2.2B-Instruct \
    --output_dir online-dpo-SmolVLM2-2.2B-Instruct \
    --learning_rate 1e-5 \
    --dtype bfloat16 \
    --max_length 1536 \
    --max_new_tokens 1024 \
    --use_peft \
    --lora_target_modules q_proj v_proj \
    --per_device_train_batch_size 1 \
    --gradient_accumulation_steps 2

# Single GPU test command:
python examples/scripts/online_dpo_vlm.py \
    --model_name_or_path HuggingFaceTB/SmolVLM2-2.2B-Instruct \
    --reward_model_path HuggingFaceTB/SmolVLM2-2.2B-Instruct \
    --output_dir online-dpo-SmolVLM2-2.2B-Instruct-test \
    --learning_rate 1e-5 \
    --dtype bfloat16 \
    --max_length 1536 \
    --max_new_tokens 128 \
    --use_peft \
    --lora_target_modules q_proj v_proj \
    --per_device_train_batch_size 1 \
    --gradient_accumulation_steps 1 \
    --max_steps 2 \
    --logging_steps 1 \
    --trust_remote_code
"""

import os

import torch
import transformers
from datasets import load_dataset
from latex2sympy2_extended import NormalizationConfig
from math_verify import LatexExtractionConfig, parse, verify
from transformers import AutoConfig, AutoProcessor, GenerationConfig

from trl import (
    LogCompletionsCallback,
    ModelConfig,
    OnlineDPOConfig,
    OnlineDPOTrainer,
    ScriptArguments,
    TrlParser,
    get_kbit_device_map,
    get_peft_config,
    get_quantization_config,
)
from trl.rewards import think_format_reward


# Enable logging in a Hugging Face Space
os.environ.setdefault("TRACKIO_SPACE_ID", "trl-trackio")
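# `setdefault` keeps any TRACKIO_SPACE_ID already exported in the environment; the
# trackio backend is assumed to pick it up and publish the metrics dashboard to that Space.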


if __name__ == "__main__":
    parser = TrlParser((ScriptArguments, OnlineDPOConfig, ModelConfig))
    script_args, training_args, model_args = parser.parse_args_and_config()
    training_args.gradient_checkpointing_kwargs = {"use_reentrant": True}

    dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype)
    quantization_config = get_quantization_config(model_args)
    model_kwargs = dict(
        revision=model_args.model_revision,
        attn_implementation=model_args.attn_implementation,
        dtype=dtype,
        use_cache=False if training_args.gradient_checkpointing else True,
        device_map=get_kbit_device_map() if quantization_config is not None else None,
        quantization_config=quantization_config,
    )

    # Load the VLM with its concrete architecture class (mirrors the GRPO VLM example)
    config = AutoConfig.from_pretrained(model_args.model_name_or_path)
    architecture = getattr(transformers, config.architectures[0])
    model = architecture.from_pretrained(
        model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, **model_kwargs
    )

    # A separate reward model is awkward for VLM online DPO because it would also need
    # the images. Instead, completions are scored directly with the reward functions
    # passed to the trainer below (think_format_reward and accuracy_reward).
    reward_model = None
    reward_processor = None

    # Load processor for main model
    processor = AutoProcessor.from_pretrained(
        model_args.model_name_or_path,
        trust_remote_code=model_args.trust_remote_code,
    )
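    # Generation with decoder-only models expects left padding; reuse EOS as the pad
    # token below when the processor's tokenizer does not define one.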
    if hasattr(processor, "tokenizer"):
        processor.tokenizer.padding_side = "left"
        if processor.tokenizer.pad_token_id is None:
            processor.tokenizer.pad_token = processor.tokenizer.eos_token

    ################
    # Dataset
    ################
    dataset = load_dataset("lmms-lab/multimodal-open-r1-8k-verified", split="train")
    dataset = dataset.train_test_split(test_size=100, seed=42)

    SYSTEM_PROMPT = (
        "A conversation between user and assistant. The user asks a question, and the assistant solves it. The "
        "assistant first thinks about the reasoning process in the mind and then provides the user with the answer. "
        "The reasoning process and answer are enclosed within <think></think> tags, i.e., <think>\nThis is my "
        "reasoning.\n</think>\nThis is my answer."
    )
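
    # Note: this prompt format matches what `think_format_reward` is assumed to check:
    # a completion wrapped in <think>...</think> followed by the final answer.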

    def make_conversation(example):
        # Create conversational format that OnlineDPOTrainer expects
        prompt = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": example["problem"]},
        ]
        return {"prompt": prompt, "image": example["image"]}

    dataset = dataset.map(make_conversation)

    # Filter out large images (as in the GRPO VLM example)
    def filter_big_images(example):
        image = example["image"]
        return image.size[0] < 512 and image.size[1] < 512

    dataset = dataset.filter(filter_big_images)

    def convert_to_rgb(example):
        image = example["image"]
        if image.mode != "RGB":
            image = image.convert("RGB")
        example["image"] = image
        return example

    dataset = dataset.map(convert_to_rgb)

    train_dataset = dataset["train"]
    eval_dataset = dataset["test"] if training_args.eval_strategy != "no" else None

    ################
    # Reward Function for Training (same as GRPO VLM)
    ################
    def accuracy_reward(completions, solution: list[str], **kwargs):
        """Reward function that checks if the completion matches the ground truth.
        - If both gold and prediction are parseable → use math verification.
        - If not parseable → compare as normalized text.
        """
        rewards = []
        contents = [completion[0]["content"] for completion in completions]
        for content, sol in zip(contents, solution):
            try:
                gold_parsed = parse(sol, extraction_mode="first_match")
            except Exception:
                gold_parsed = []

            if len(gold_parsed) != 0:
                # Try parsing predicted answer too
                try:
                    answer_parsed = parse(
                        content,
                        extraction_config=[
                            LatexExtractionConfig(
                                normalization_config=NormalizationConfig(
                                    nits=False,
                                    malformed_operators=False,
                                    basic_latex=True,
                                    boxed="all",
                                    units=True,
                                ),
                                boxed_match_priority=0,
                                try_extract_without_anchor=False,
                            )
                        ],
                        extraction_mode="first_match",
                    )
                    reward = float(verify(gold_parsed, answer_parsed))
                except Exception as e:
                    print(f"verify failed: {e}, answer: {content}, gold: {sol}")
                    reward = None
            else:
                # fallback to text match
                reward = float(content.strip().lower() == sol.strip().lower())

            rewards.append(reward)

        return rewards
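
    # Illustrative behaviour (hypothetical inputs, not executed by the script):
    #   accuracy_reward(
    #       completions=[[{"role": "assistant", "content": "<think>2+2</think>\\boxed{4}"}]],
    #       solution=["\\boxed{4}"],
    #   )
    # would be expected to return [1.0]; a mismatching boxed answer yields [0.0], and a
    # parse/verify failure on the prediction yields [None] for that sample.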

    ################
    # Training
    ################
    trainer = OnlineDPOTrainer(
        model=model,
        reward_funcs=[think_format_reward, accuracy_reward],  # Use same reward functions as GRPO VLM
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        processing_class=processor,
        peft_config=get_peft_config(model_args),
    )
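
    # Rough sketch of each training step (the online DPO recipe, not TRL internals): the
    # trainer samples a pair of completions per prompt, scores both with the reward
    # functions above, treats the higher-scoring one as "chosen", and applies the DPO loss.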

    # Log a few sample completions during evaluation (as in the other online DPO examples)
    if training_args.eval_strategy != "no":
        generation_config = GenerationConfig(
            max_new_tokens=training_args.max_new_tokens, do_sample=True, temperature=training_args.temperature
        )
        completions_callback = LogCompletionsCallback(trainer, generation_config, num_prompts=8)
        trainer.add_callback(completions_callback)

    trainer.train()

    # Save and push to hub
    trainer.save_model(training_args.output_dir)
    if training_args.push_to_hub:
        trainer.push_to_hub(dataset_name="lmms-lab/multimodal-open-r1-8k-verified")