qgallouedec
Member

This PR is the second in a sequence of PRs (after #4111) that aim to refactor the generation part of GRPO/RLOO to allow for easier customization.

While refactoring, I realized that clean multi-image support helps achieve a cleaner separation between functions.

Try it with:

from datasets import load_dataset

from trl import GRPOConfig, GRPOTrainer


# If not handled properly, prompt truncation may truncate image token
dataset = load_dataset("trl-internal-testing/zen-multi-image", "conversational_prompt_only", split="train")

dataset = dataset.filter(lambda x: len(x["images"]) > 0) # currently, mixing samples with and without images is not supported

def my_reward_function(prompts, completions, **kwargs):
    return [1.0] * len(prompts)

training_args = GRPOConfig(
    output_dir="tmp_dir",   
    learning_rate=0.1,  # increase the learning rate to speed up the test
    per_device_train_batch_size=6,  # reduce the batch size to reduce memory usage
    num_generations=3,  # reduce the number of generations to reduce memory usage
    max_completion_length=8,  # reduce the completion length to reduce memory usage
    max_prompt_length=32,
    report_to="none",
)
trainer = GRPOTrainer(
    model="Qwen/Qwen2-VL-2B-Instruct",
    reward_funcs=my_reward_function,  # define a dummy reward function
    args=training_args,
    train_dataset=dataset,
)

previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

trainer.train()

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@qgallouedec qgallouedec changed the base branch from main to drop-image_split_sizes September 19, 2025 22:53
)
trainer = GRPOTrainer(
    model=model_id,
    reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
Member Author

@qgallouedec qgallouedec Sep 20, 2025

We don't support visual reward models, so it doesn't really make sense to test this case, where the image is dropped and a warning is raised.

Comment on lines +1022 to +1029
# VLM reward models aren't supported yet, so we drop the image and raise a warning if needed
for prompt in prompts:
    for turn in prompt:
        if isinstance(turn["content"], list):
            logger.warning_once("Visual reward models aren't supported yet; dropping image.")
            turn["content"] = " ".join(
                e["text"] for e in turn["content"] if e["type"] == "text"
            )
Member Author

@qgallouedec qgallouedec Sep 20, 2025

This converts the prompt from

[{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What color is the sky?"}]}]

to

[{"role": "user", "content": "What color is the sky?"}]

and raises a warning.
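
A minimal sketch of the transformation described above, written as a standalone function. The name `flatten_multimodal_prompts` is hypothetical, and Python's standard `logging` module stands in for the trainer's `logger.warning_once`; this is an illustration, not the exact code in the PR.

```python
import logging

logger = logging.getLogger(__name__)


def flatten_multimodal_prompts(prompts):
    """Return text-only copies of conversational prompts (hypothetical helper).

    Any turn whose "content" is a list of parts (e.g. image + text) is replaced
    by the space-joined text parts. The input prompts are left unmodified.
    """
    warned = False
    flattened = []
    for prompt in prompts:
        new_prompt = []
        for turn in prompt:
            content = turn["content"]
            if isinstance(content, list):
                if not warned:
                    # Warn only once per call (assumption: mirrors warning_once loosely)
                    logger.warning("Visual reward models aren't supported yet; dropping image.")
                    warned = True
                content = " ".join(part["text"] for part in content if part["type"] == "text")
            new_prompt.append({**turn, "content": content})
        flattened.append(new_prompt)
    return flattened


prompts = [[{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What color is the sky?"}]}]]
text_only = flatten_multimodal_prompts(prompts)
```

Unlike the in-place loop in the diff, this sketch builds new dicts, so the original multimodal prompts remain available to the caller.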

Comment on lines -1068 to -1071
# We don't yet support visual reward models/function, so we keep a copy of the original text-only prompts for
# later use in the reward computation. If images are present, we insert {"type": "image"} as required by the
# VLM chat template.
original_prompts = copy.deepcopy(prompts)
Member Author

Instead of keeping the original prompt, we just drop the image later and raise a warning; see https://github.com/huggingface/trl/pull/4113/files#r2364899902

# important because rewards will be normalized per group, and completions are distributed. We will later slice
# rewards_per_func to extract each process's subset.
- rewards_per_func = self._calculate_rewards(inputs, original_prompts, completions, completion_ids_list)
+ rewards_per_func = self._calculate_rewards(inputs, prompts, completions, completion_ids_list)

Comment on lines +1800 to +1804
if self._logs["images"]:
    table["images"] = []
    for image_list in self._logs["images"]:
        # Convert images to wandb Image objects for proper visualization
        table["images"].append([wandb.Image(image) for image in image_list])
Member Author

tested

Member Author

[Image: Screenshot 2025-09-19 at 5.30.14 PM]

Comment on lines +1796 to +1800
boundaries = [0, *accumulate(batch["num_images"])] # [3, 4, 5] -> [0, 3, 7, 12]
sections = [sum(lengths[boundaries[i] : boundaries[i + 1]]) for i in range(len(batch["num_images"]))]
split_values = list(torch.split(batch["pixel_values"], sections, dim=0))
image_grid_thw = list(torch.split(batch["image_grid_thw"], batch["num_images"], dim=0))
return {**batch, "pixel_values": split_values, "image_grid_thw": image_grid_thw}
Member Author

@qgallouedec qgallouedec Sep 20, 2025

Instead of keeping image_grid_thw as is, we need to split it according to the number of images per sample. It gets concatenated again later in _get_per_token_logps_and_entropies (see line 807).
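
To make the bookkeeping concrete, here is a small pure-Python sketch of how the per-sample sections can be derived. It assumes, as the snippet above suggests, that `lengths` holds the number of `pixel_values` rows per image (the product t*h*w of each `image_grid_thw` row). Plain lists stand in for tensors, and `split_rows` is a hypothetical stand-in for `torch.split`.

```python
from itertools import accumulate
from math import prod


def split_rows(rows, sections):
    """Split a list into consecutive chunks of the given sizes (stand-in for torch.split)."""
    bounds = [0, *accumulate(sections)]
    return [rows[bounds[i] : bounds[i + 1]] for i in range(len(sections))]


# Toy batch: 2 samples, the first with 2 images and the second with 1 image.
num_images = [2, 1]
image_grid_thw = [(1, 2, 2), (1, 2, 4), (1, 4, 4)]  # one (t, h, w) row per image
lengths = [prod(thw) for thw in image_grid_thw]  # pixel_values rows per image
pixel_values = list(range(sum(lengths)))  # placeholder rows, one int per row

# Image-index boundaries of each sample, then total rows per sample.
boundaries = [0, *accumulate(num_images)]
sections = [sum(lengths[boundaries[i] : boundaries[i + 1]]) for i in range(len(num_images))]

pixel_values_per_sample = split_rows(pixel_values, sections)
grid_per_sample = split_rows(image_grid_thw, num_images)
```

The point of the two-level split is that `pixel_values` is indexed by patch rows while `image_grid_thw` is indexed by images, so each needs its own section sizes to end up with one entry per sample.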

Comment on lines 806 to 808
model_inputs["image_grid_thw"] = torch.cat(image_grid_thw[start : start + batch_size])
start_pixel_idx = 0 if start == 0 else torch.cat(image_grid_thw[:start]).prod(-1).sum().item()
end_pixel_idx = torch.cat(image_grid_thw[: start + batch_size]).prod(-1).sum().item()
Member Author

See https://github.com/huggingface/trl/pull/4113/files#r2364904060: image_grid_thw is no longer a tensor but a list of tensors.
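
As a rough illustration of why the indexing changes, the following pure-Python sketch computes the pixel-row range for a mini-batch slice when `image_grid_thw` is a list with one entry per sample. Lists of `(t, h, w)` tuples stand in for tensors, and `pixel_range` is a hypothetical helper, not a TRL function.

```python
from math import prod


def pixel_range(image_grid_thw, start, batch_size):
    """Return (start_pixel_idx, end_pixel_idx) for samples [start, start + batch_size).

    image_grid_thw has one entry per sample; each entry is a list of
    (t, h, w) tuples, one per image in that sample.
    """

    def rows(samples):
        # Total number of pixel_values rows contributed by these samples.
        return sum(prod(thw) for sample in samples for thw in sample)

    start_pixel_idx = rows(image_grid_thw[:start])
    end_pixel_idx = rows(image_grid_thw[: start + batch_size])
    return start_pixel_idx, end_pixel_idx


grids = [[(1, 2, 2), (1, 2, 4)], [(1, 4, 4)], [(1, 2, 2)]]  # 3 samples
first = pixel_range(grids, start=0, batch_size=2)
second = pixel_range(grids, start=2, batch_size=1)
```

Because slicing now happens over samples rather than over images, a sample with several images stays intact within a mini-batch.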

@qgallouedec qgallouedec changed the title Multi image support for GRPO/RLOO [WIP] Multi image support for GRPO/RLOO Sep 20, 2025
@qgallouedec qgallouedec changed the title [WIP] Multi image support for GRPO/RLOO Multi image support for GRPO/RLOO Sep 20, 2025
Member

@lewtun lewtun left a comment

LGTM, with a question about whether raising an error rather than a warning is best when images + text are passed to the reward function.

# Because of the way the tiny models are initialized, the gradient does not flow properly through the
# vision parts of the model, so we skip them. Ideally, we should fix the init of these models.
params_to_skip = (
    # "model.vision_tower.",
Member

These are commented out - restore?


self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

for n, param in previous_trainable_params.items():

for prompt in prompts:
    for turn in prompt:
        if isinstance(turn["content"], list):
            logger.warning_once("Visual reward models aren't supported yet; dropping image.")
Member

Would raising an error be better than a warning? Otherwise, I can imagine the warning being missed and the training "failing silently" because the reward is only computed on the text part.
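
One way to frame the trade-off raised here is a configurable policy. The sketch below is purely illustrative: `check_text_only` and its `strict` parameter are hypothetical and are not part of TRL or of this PR.

```python
import warnings


def check_text_only(prompts, strict=False):
    """Detect multimodal turns before they reach a text-only reward model (hypothetical).

    Returns True if every turn is plain text. Otherwise raises ValueError when
    strict=True, or emits a warning and returns False.
    """
    has_multimodal = any(
        isinstance(turn["content"], list) for prompt in prompts for turn in prompt
    )
    if not has_multimodal:
        return True
    message = "Visual reward models aren't supported yet; images would be dropped."
    if strict:
        # Fail loudly so the mismatch cannot go unnoticed.
        raise ValueError(message)
    warnings.warn(message)
    return False
```

With `strict=True` the run stops at the first multimodal prompt; with the default it continues and only the text part would be scored, which is the behavior the warning in the diff corresponds to.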

table["images"] = []
for image_list in self._logs["images"]:
    # Convert images to wandb Image objects for proper visualization
    table["images"].append([wandb.Image(image) for image in image_list])
Member

At some point it would be nice to also add the trackio variant for table images
