-
Notifications
You must be signed in to change notification settings - Fork 2.2k
🧺 [1/N] Refactor _generate
in GRPO/RLOO: list of ints instead of tensors
#4146
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
…_thw` in GRPO and RLOO trainers; update `split_pixel_values_by_grid` to use `image_grid_thw`
trl/trainer/grpo_trainer.py
Outdated
**kwargs, | ||
) | ||
prompt_inputs = super()._prepare_inputs(prompt_inputs) | ||
prompt_inputs = self.processing_class(text=prompts_text, add_special_tokens=False, **kwargs) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
the function must now return a list of ints, so we must remove padding
prompt_mask, | ||
completion_mask, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
prompt and completion masks are later inferred from the sequence lengths
_generate
in GRPO/RLOO_generate
in GRPO/RLOO
_generate
in GRPO/RLOO_generate
in GRPO/RLOO: list of ints instead of tensors
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks.
trl/trainer/utils.py
Outdated
sequences (`list[int]`): | ||
Input sequence of token IDs. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The sequences
name in the docstring is not aligned with the ids
name in the signature.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Additionally, before it accepted batch_size sequences (within the tensor) and now it accepts a single sequence (list[int]). Isn't this breaking something? Some tests should be failing because of the new behavior?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Some tests should be failing because of the new behavior?
yes, tests have been updated as well, see TruncateWithProtectedTokensTester
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The
sequences
name in the docstring is not aligned with theids
name in the signature.
thanks! fixed in c570fb0
Co-authored-by: Albert Villanova del Moral <[email protected]>
Co-authored-by: Albert Villanova del Moral <[email protected]>
This PR belongs to a sequence of PR that aims to refactor the generation part of GRPO/RLOO to allow for easier customization and ultimately tool calling
Previous:
image_split_sizes
in favour ofimage_grid_thw
#4111_generate
#4114Next:
_generate
in GRPO/RLOO: Useprompt_ids
from generation #4152_generate
in GRPO/RLOO: Rely on generator for prompt truncation #4153_generate
in GRPO/RLOO: Moveforward_kwargs
outside generation method #4154_generate
in GRPO/RLOO: Insert images in the prompt #4155The idea with this PR is to make
_generate
return list of ints instead of tensors. This will help a lots when implementing tool calling.Several modifications:
truncate_with_protected_tokens
: instead of operating on 2D tensors (ids and mask), it will operate on sequence ids directly:before
after
_generate
now returns list of ids, instead of tensor + maskconversion to tensor is handle in
_generate_and_score_completions
The generation part is moved to a function
_generate_single_turn
, which is called by_generate