diff --git a/trl/__init__.py b/trl/__init__.py
index 44228d2092..5920fccac9 100644
--- a/trl/__init__.py
+++ b/trl/__init__.py
@@ -69,6 +69,7 @@
         "GKDTrainer",
         "GRPOConfig",
         "GRPOTrainer",
+        "Qwen3VLGRPOTrainer",
         "HfPairwiseJudge",
         "KTOConfig",
         "KTOTrainer",
@@ -150,6 +151,7 @@
         GKDTrainer,
         GRPOConfig,
         GRPOTrainer,
+        Qwen3VLGRPOTrainer,
         HfPairwiseJudge,
         KTOConfig,
         KTOTrainer,
diff --git a/trl/trainer/__init__.py b/trl/trainer/__init__.py
index 98846bf715..f6dd90c97e 100644
--- a/trl/trainer/__init__.py
+++ b/trl/trainer/__init__.py
@@ -36,7 +36,7 @@
     "gkd_config": ["GKDConfig"],
    "gkd_trainer": ["GKDTrainer"],
    "grpo_config": ["GRPOConfig"],
-    "grpo_trainer": ["GRPOTrainer"],
+    "grpo_trainer": ["GRPOTrainer", "Qwen3VLGRPOTrainer"],
    "judges": [
        "AllTrueJudge",
        "BaseBinaryJudge",
@@ -96,7 +96,7 @@
     from .gkd_config import GKDConfig
    from .gkd_trainer import GKDTrainer
    from .grpo_config import GRPOConfig
-    from .grpo_trainer import GRPOTrainer
+    from .grpo_trainer import GRPOTrainer, Qwen3VLGRPOTrainer
    from .judges import (
        AllTrueJudge,
        BaseBinaryJudge,
diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py
index ff7674214a..d2540612d1 100644
--- a/trl/trainer/grpo_trainer.py
+++ b/trl/trainer/grpo_trainer.py
@@ -2026,3 +2026,149 @@ def _save_checkpoint(self, model, trial):
             model_name = self.args.hub_model_id.split("/")[-1]
        self.create_model_card(model_name=model_name)
        super()._save_checkpoint(model, trial)
+
+
+class Qwen3VLGRPOTrainer(GRPOTrainer):
+    r"""
+    A GRPOTrainer specialization for Qwen3-VL (multi-modal, video-based).
+
+    Motivation
+    ----------
+    - Qwen3-VL uses a Processor that expects full multi-modal conversations
+      (with `"type": "video"`, `"fps"`, etc.) in `apply_chat_template`.
+    - The generic GRPOTrainer `_generate_single_turn` assumes text-only or
+      image-style inputs, and its batching/vLLM logic currently does not
+      work with Qwen3-VL video inputs.
+    - This subclass keeps the GRPO logic intact, but overrides
+      `_generate_single_turn` with a simple, per-sample generation loop
+      that is compatible with Qwen3-VL's Processor.
+
+    Expected `prompts` format
+    -------------------------
+    `prompts` is a list of conversations, each being a list of messages:
+
+        [
+            [
+                {
+                    "role": "system",
+                    "content": [
+                        {"type": "text", "text": "..."},
+                        ...
+                    ],
+                },
+                {
+                    "role": "user",
+                    "content": [
+                        {"type": "video", "video": "", "fps": 4},
+                        {"type": "text", "text": "Question ..."},
+                    ],
+                },
+            ],
+            ...
+        ]
+
+    The method returns:
+    - prompt_ids: List[List[int]]
+    - completion_ids: List[List[int]]
+    - logprobs: None (not used in this path)
+    - extra_fields: dict (empty for now)
+    """
+
+    def _generate_single_turn(self, prompts: list):
+        if self.use_vllm or self.use_transformers_paged:
+            raise ValueError(
+                "Qwen3VLGRPOTrainer currently supports only the standard "
+                "transformers.generate path. Please set `use_vllm=False` and "
+                "`use_transformers_paged=False` in GRPOConfig."
+            )
+
+        device = self.accelerator.device
+
+        cleaned_prompts: list[list[dict]] = []
+        for conv in prompts:
+            if not isinstance(conv, list):
+                cleaned_prompts.append(conv)
+                continue
+
+            new_conv = []
+            for msg in conv:
+                role = msg.get("role", "user")
+                content = msg.get("content", [])
+                new_content = []
+
+                for chunk in content:
+                    ctype = chunk.get("type")
+
+                    if ctype == "text":
+                        new_content.append(
+                            {
+                                "type": "text",
+                                "text": chunk.get("text", ""),
+                            }
+                        )
+                    elif ctype == "image":
+                        new_content.append(
+                            {
+                                "type": "image",
+                                "image": chunk.get("image"),
+                            }
+                        )
+                    elif ctype == "video":
+                        new_content.append(
+                            {
+                                "type": "video",
+                                "video": chunk.get("video"),
+                                "fps": chunk.get("fps", None),
+                            }
+                        )
+                    else:
+                        new_content.append(chunk)
+
+                new_conv.append({"role": role, "content": new_content})
+            cleaned_prompts.append(new_conv)
+
+        prompt_ids_list: list[list[int]] = []
+        completion_ids_list: list[list[int]] = []
+        logprobs = None
+        extra_fields: dict[str, list] = {}
+
+        gen_config = self.generation_config
+
+        model = self.accelerator.unwrap_model(self.model)
+        was_training = model.training
+        model.eval()
+
+        with torch.no_grad():
+            for conv in cleaned_prompts:
+                processor_inputs = self.processing_class.apply_chat_template(
+                    conv,
+                    tokenize=True,
+                    add_generation_prompt=True,
+                    return_dict=True,
+                    return_tensors="pt",
+                    **self.chat_template_kwargs,
+                )
+
+                processor_inputs = {k: v.to(device) if hasattr(v, "to") else v for k, v in processor_inputs.items()}
+
+                output_ids = model.generate(
+                    **processor_inputs,
+                    generation_config=gen_config,
+                    # disable_compile=True  # only relevant when the model is compiled; left to the caller's configuration
+                )
+
+                input_ids = processor_inputs["input_ids"]  # [1, L_prompt]
+                assert (
+                    output_ids.shape[0] == 1
+                ), "Qwen3VLGRPOTrainer expects per-sample generation with batch size 1."
+
+                full_ids = output_ids[0]
+                prompt_len = input_ids.shape[1]
+
+                prompt_ids_list.append(input_ids[0].tolist())
+                completion_ids_list.append(full_ids[prompt_len:].tolist())
+
+        if was_training:
+            model.train()
+
+        return prompt_ids_list, completion_ids_list, logprobs, extra_fields
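
Usage sketch (illustrative only, not part of the diff): one way the new Qwen3VLGRPOTrainer could be driven, assuming the dataset's "prompt" column already holds conversations in the layout shown in the class docstring. The checkpoint name, video path, reward function, and hyperparameters below are placeholders chosen for the example, not values taken from this PR.

from datasets import Dataset
from transformers import AutoProcessor

from trl import GRPOConfig, Qwen3VLGRPOTrainer

MODEL_ID = "Qwen/Qwen3-VL-8B-Instruct"  # assumed checkpoint name; replace with the model you actually train


def reward_brevity(completions, **kwargs):
    # Toy reward for illustration: prefer shorter completions.
    return [-float(len(str(completion))) for completion in completions]


# One video question per sample; "prompt" follows the conversation layout from the docstring.
train_dataset = Dataset.from_list(
    [
        {
            "prompt": [
                {
                    "role": "user",
                    "content": [
                        {"type": "video", "video": "clips/example.mp4", "fps": 4},
                        {"type": "text", "text": "Describe what happens in this clip."},
                    ],
                }
            ]
        }
    ]
)

training_args = GRPOConfig(
    output_dir="qwen3vl-grpo",
    use_vllm=False,                 # the subclass rejects vLLM and paged generation
    use_transformers_paged=False,
    per_device_train_batch_size=2,
    num_generations=2,
)

trainer = Qwen3VLGRPOTrainer(
    model=MODEL_ID,
    reward_funcs=reward_brevity,
    args=training_args,
    train_dataset=train_dataset,
    processing_class=AutoProcessor.from_pretrained(MODEL_ID),
)
trainer.train()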