
Conversation

@qgallouedec (Member) commented Sep 26, 2025

This PR belongs to a sequence of PRs that aim to refactor the generation part of GRPO/RLOO to allow for easier customization and, ultimately, tool calling.

Previous:

Next:

The idea with this PR is to make _generate return lists of ints instead of tensors. This will help a lot when implementing tool calling.

Several modifications:

  1. truncate_with_protected_tokens: instead of operating on 2D tensors (ids and mask), it now operates on individual sequences of ids directly:

    before

    >>> ids = torch.tensor([[1, 2, 3], [4, 5, 6]])
    >>> mask = torch.tensor([[1, 1, 1], [0, 0, 1]])
    >>> truncate_with_protected_tokens(ids, mask, target_length=2, protected_tokens=[])
    tensor([[2, 3], [5, 6]]), tensor([[1, 1], [0, 1]])

    after

    >>> ids = [[1, 2, 3], [4, 5, 6]]
    >>> [truncate_with_protected_tokens(seq, target_length=2, protected_tokens=[]) for seq in ids]
    [[2, 3], [5, 6]]
  2. _generate now returns lists of ids instead of tensor + mask;
    conversion to tensors is handled in _generate_and_score_completions

  3. The generation part is moved to a function _generate_single_turn, which is called by _generate
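For illustration, here is a minimal per-sequence sketch of what the new truncate_with_protected_tokens could look like. This is an assumption about the implementation, not the actual TRL code; it only reproduces the docstring examples above (truncate from the left, but never drop a protected token):

```python
def truncate_with_protected_tokens(seq, target_length, protected_tokens):
    """Illustrative sketch (not the actual TRL implementation): truncate a
    single list of token IDs to `target_length`, keeping all protected
    tokens; non-protected tokens are dropped from the left."""
    protected_set = set(protected_tokens)
    # Budget of non-protected tokens we can still keep.
    budget = target_length - sum(1 for t in seq if t in protected_set)
    if budget < 0:
        raise ValueError("target_length too small to keep all protected tokens")
    kept = []
    # Walk from the right: keep every protected token, plus the last
    # `budget` non-protected ones, then restore the original order.
    for t in reversed(seq):
        if t in protected_set:
            kept.append(t)
        elif budget > 0:
            kept.append(t)
            budget -= 1
    return kept[::-1]

ids = [[1, 2, 3], [4, 5, 6]]
# [truncate_with_protected_tokens(s, 2, []) for s in ids] -> [[2, 3], [5, 6]]
```

Operating on one plain list at a time is what removes the need for a parallel mask tensor: each sequence carries its own length.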

**kwargs,
)
prompt_inputs = super()._prepare_inputs(prompt_inputs)
prompt_inputs = self.processing_class(text=prompts_text, add_special_tokens=False, **kwargs)
qgallouedec (Member Author):

the function must now return a list of ints, so we must remove padding
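As an illustration of the padding removal (a sketch under assumed data shapes; the helper name unpad is hypothetical):

```python
def unpad(ids, masks):
    """Hypothetical helper: drop padded positions using the attention mask so
    each row becomes a plain list of ints, as the refactored _generate expects.
    Shown on nested lists; tensor rows would first go through .tolist()."""
    return [[t for t, m in zip(row, mask) if m] for row, mask in zip(ids, masks)]

ids = [[1, 2, 3], [4, 5, 0]]
masks = [[1, 1, 1], [1, 1, 0]]
# unpad(ids, masks) -> [[1, 2, 3], [4, 5]]
```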

Comment on lines -1385 to -1389
prompt_mask,
completion_mask,
qgallouedec (Member Author):

prompt and completion masks are later inferred from the sequence lengths
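A sketch of that idea (the helper name and the right-padding convention are assumptions for illustration):

```python
def masks_from_lengths(lengths, max_len):
    """Hypothetical helper: rebuild right-padded attention masks from the
    per-sequence lengths, instead of carrying masks through _generate."""
    return [[1] * n + [0] * (max_len - n) for n in lengths]

# masks_from_lengths([3, 1], 3) -> [[1, 1, 1], [1, 0, 0]]
```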

@qgallouedec qgallouedec changed the title 🧺 Refactor _generate in GRPO/RLOO 🧺 [1/N] Refactor _generate in GRPO/RLOO Sep 26, 2025
@qgallouedec qgallouedec changed the title 🧺 [1/N] Refactor _generate in GRPO/RLOO 🧺 [1/N] Refactor _generate in GRPO/RLOO: list of ints instead of tensors Sep 26, 2025
@albertvillanova (Member) left a comment:

Thanks.

Comment on lines 1869 to 1870
sequences (`list[int]`):
Input sequence of token IDs.
albertvillanova (Member):

The sequences name in the docstring is not aligned with the ids name in the signature.

albertvillanova (Member):

Additionally, it previously accepted batch_size sequences (within a tensor) and now accepts a single sequence (list[int]). Isn't this breaking something? Shouldn't some tests be failing because of the new behavior?

qgallouedec (Member Author):

> Shouldn't some tests be failing because of the new behavior?

Yes, the tests have been updated as well; see TruncateWithProtectedTokensTester.

qgallouedec (Member Author):

> The sequences name in the docstring is not aligned with the ids name in the signature.

Thanks! Fixed in c570fb0.
