Multi image support for GRPO/RLOO #4113
base: drop-image_split_sizes
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
)
trainer = GRPOTrainer(
    model=model_id,
    reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
we don't support visual reward models, so it doesn't really make sense to test this case, where the image is dropped and a warning is raised.
# VLM reward models aren't supported yet, so we drop the image and raise a warning if needed
for prompt in prompts:
    for turn in prompt:
        if isinstance(turn["content"], list):
            logger.warning_once("Visual reward models aren't supported yet; dropping image.")
            turn["content"] = " ".join(
                e["text"] for e in turn["content"] if e["type"] == "text"
            )
from
[{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What color is the sky?"}]}]
to
[{"role": "user", "content": "What color is the sky?"}]
plus raise a warning.
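For reference, a self-contained sketch of that flattening on the example above (a plain logging logger stands in for TRL's internal logger, which has warning_once):

import logging

logger = logging.getLogger(__name__)

prompts = [[{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What color is the sky?"}]}]]

for prompt in prompts:
    for turn in prompt:
        if isinstance(turn["content"], list):
            logger.warning("Visual reward models aren't supported yet; dropping image.")  # warning_once in TRL
            turn["content"] = " ".join(e["text"] for e in turn["content"] if e["type"] == "text")

print(prompts)  # [[{'role': 'user', 'content': 'What color is the sky?'}]]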
# We don't yet support visual reward models/functions, so we keep a copy of the original text-only prompts for
# later use in the reward computation. If images are present, we insert {"type": "image"} as required by the
# VLM chat template.
original_prompts = copy.deepcopy(prompts)
instead of keeping the original prompt, we just drop the image later and raise a warning; see https://github.com/huggingface/trl/pull/4113/files#r2364899902
# important because rewards will be normalized per group, and completions are distributed. We will later slice
# rewards_per_func to extract each process's subset.
- rewards_per_func = self._calculate_rewards(inputs, original_prompts, completions, completion_ids_list)
+ rewards_per_func = self._calculate_rewards(inputs, prompts, completions, completion_ids_list)
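The gather matters because advantages are normalized per group of num_generations completions, and a group may span processes; each process then slices out its own rows. A toy illustration of that normalization (values and group size made up):

import torch

num_generations = 4
rewards = torch.tensor([1.0, 2.0, 3.0, 4.0, 0.0, 0.0, 10.0, 10.0])  # gathered from all processes

# Normalize within each group of num_generations completions
mean_grouped = rewards.view(-1, num_generations).mean(dim=1).repeat_interleave(num_generations)
std_grouped = rewards.view(-1, num_generations).std(dim=1).repeat_interleave(num_generations)
advantages = (rewards - mean_grouped) / (std_grouped + 1e-4)

process_slice = slice(0, 4)  # each process then keeps only its own subset
print(advantages[process_slice])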
if self._logs["images"]:
    table["images"] = []
    for image_list in self._logs["images"]:
        # Convert images to wandb Image objects for proper visualization
        table["images"].append([wandb.Image(image) for image in image_list])
boundaries = [0, *accumulate(batch["num_images"])]  # [3, 4, 5] -> [0, 3, 7, 12]
sections = [sum(lengths[boundaries[i] : boundaries[i + 1]]) for i in range(len(batch["num_images"]))]
split_values = list(torch.split(batch["pixel_values"], sections, dim=0))
image_grid_thw = list(torch.split(batch["image_grid_thw"], batch["num_images"], dim=0))
return {**batch, "pixel_values": split_values, "image_grid_thw": image_grid_thw}
instead of keeping image_grid_thw as is, we need to split it depending on the number of images. It gets concatenated later in _get_per_token_logps_and_entropies (see line 807)
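A toy run of that splitting logic, with made-up shapes (two samples with 2 and 1 images; lengths is assumed to be the per-image patch count, i.e. image_grid_thw.prod(-1)):

from itertools import accumulate

import torch

batch = {
    "num_images": [2, 1],                                                # two samples, with 2 and 1 images
    "image_grid_thw": torch.tensor([[1, 2, 2], [1, 2, 4], [1, 4, 4]]),  # one (t, h, w) grid per image
    "pixel_values": torch.randn(28, 8),                                  # 4 + 8 + 16 = 28 patch rows
}
lengths = batch["image_grid_thw"].prod(-1).tolist()  # [4, 8, 16] patches per image

boundaries = [0, *accumulate(batch["num_images"])]  # [2, 1] -> [0, 2, 3]
sections = [sum(lengths[boundaries[i] : boundaries[i + 1]]) for i in range(len(batch["num_images"]))]
split_values = list(torch.split(batch["pixel_values"], sections, dim=0))
image_grid_thw = list(torch.split(batch["image_grid_thw"], batch["num_images"], dim=0))

print([v.shape for v in split_values])    # [torch.Size([12, 8]), torch.Size([16, 8])]
print([g.shape for g in image_grid_thw])  # [torch.Size([2, 3]), torch.Size([1, 3])]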
trl/trainer/grpo_trainer.py (outdated)
model_inputs["image_grid_thw"] = torch.cat(image_grid_thw[start : start + batch_size])
start_pixel_idx = 0 if start == 0 else torch.cat(image_grid_thw[:start]).prod(-1).sum().item()
end_pixel_idx = torch.cat(image_grid_thw[: start + batch_size]).prod(-1).sum().item()
See https://github.com/huggingface/trl/pull/4113/files#r2364904060; image_grid_thw is not a tensor anymore but a list of tensors.
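Continuing the toy shapes above: since grid.prod(-1) is the patch count per image, concatenating the per-sample grids up to a given sample and summing yields the pixel_values row offsets for a micro-batch (values below are illustrative):

import torch

# Per-sample grids, as produced by the split above
image_grid_thw = [torch.tensor([[1, 2, 2], [1, 2, 4]]), torch.tensor([[1, 4, 4]])]
start, batch_size = 1, 1  # micro-batch covering only the second sample

start_pixel_idx = 0 if start == 0 else torch.cat(image_grid_thw[:start]).prod(-1).sum().item()
end_pixel_idx = torch.cat(image_grid_thw[: start + batch_size]).prod(-1).sum().item()
print(start_pixel_idx, end_pixel_idx)  # 12 28 -> rows 12..28 of pixel_values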
LGTM, with a question about whether raising an error vs. a warning is best when images + text are being passed to the reward function.
# Because of the way the tiny models are initialized, the gradient does not flow properly through the
# vision parts of the model, so we skip them. Ideally, we should fix the init of these models.
params_to_skip = (
    # "model.vision_tower.",
These are commented out - restore?
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

for n, param in previous_trainable_params.items():
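The quoted hunk ends mid-loop; the standard TRL pattern for this check, written out here as an assumption (not quoted from the PR), skips the params_to_skip prefixes and asserts the rest changed:

# Assumed continuation, following TRL's usual test pattern
for n, param in previous_trainable_params.items():
    if n.startswith(params_to_skip):
        continue
    new_param = trainer.model.get_parameter(n)
    self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")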
Does the same comment for GRPO apply here? https://github.com/huggingface/trl/pull/4113/files#diff-96dca172e696190fc3e1469166e88aface95ebae959284c6806f2e25d2217c16R1587
for prompt in prompts:
    for turn in prompt:
        if isinstance(turn["content"], list):
            logger.warning_once("Visual reward models aren't supported yet; dropping image.")
Would raising an error be better than a warning? Otherwise I could imagine the warning being missed and the training "failing silently" because the reward is only computed on the text part.
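For comparison, the stricter behavior being suggested could look like this (a sketch, not part of the PR):

for prompt in prompts:
    for turn in prompt:
        if isinstance(turn["content"], list):
            raise NotImplementedError(
                "Visual reward models aren't supported yet; the reward would silently be computed "
                "on the text part only. Use a text-only reward model or drop the images yourself."
            )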
table["images"] = []
for image_list in self._logs["images"]:
    # Convert images to wandb Image objects for proper visualization
    table["images"].append([wandb.Image(image) for image in image_list])
At some point it would be nice to also add the trackio variant for table images.
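A hypothetical sketch of that trackio variant, assuming trackio's advertised drop-in wandb compatibility (whether trackio exposes an Image wrapper with this exact signature should be checked against its docs):

import trackio  # assumed drop-in replacement for wandb

if self._logs["images"]:
    table["images"] = []
    for image_list in self._logs["images"]:
        # Mirror the wandb branch with trackio's (assumed) Image wrapper
        table["images"].append([trackio.Image(image) for image in image_list])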
This PR is the second in a sequence of PRs (after #4111) that aims to refactor the generation part of GRPO/RLOO to allow for easier customization.
While refactoring, I realized that clean multi-image support helps achieve a cleaner separation between functions.
Try with:
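(The snippet is cut off here; a hypothetical minimal example of what a multi-image GRPO run could look like. The dataset contents, the "images" column name, and the model id are illustrative assumptions, not taken from the PR.)

from datasets import Dataset
from PIL import Image
from trl import GRPOConfig, GRPOTrainer

blue = Image.new("RGB", (64, 64), color="blue")
gray = Image.new("RGB", (64, 64), color="gray")

# Conversational prompts with two images per sample (hypothetical schema)
dataset = Dataset.from_list(
    [
        {
            "prompt": [
                {
                    "role": "user",
                    "content": [
                        {"type": "image"},
                        {"type": "image"},
                        {"type": "text", "text": "Which of the two images is brighter?"},
                    ],
                }
            ],
            "images": [blue, gray],
        }
    ]
    * 8
)

def reward_len(completions, **kwargs):
    # Toy text-only reward: prefer short completions
    return [-len(completion[0]["content"]) for completion in completions]

trainer = GRPOTrainer(
    model="Qwen/Qwen2.5-VL-3B-Instruct",
    reward_funcs=reward_len,
    args=GRPOConfig(output_dir="grpo-multi-image", max_completion_length=32),
    train_dataset=dataset,
)
trainer.train()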