Merged
Commits
68 commits
552e899
Refactor image handling: replace `image_split_sizes` with `image_grid…
qgallouedec Sep 19, 2025
449ef07
simpler
qgallouedec Sep 19, 2025
c8933aa
gfpo
qgallouedec Sep 19, 2025
229c554
multi-image grpo
qgallouedec Sep 19, 2025
3ca6ad5
log with wandb
qgallouedec Sep 19, 2025
dcf4b92
no vlm reward models
qgallouedec Sep 20, 2025
30ad7ca
rloo
qgallouedec Sep 20, 2025
86cc30b
gfpo
qgallouedec Sep 20, 2025
088897b
fix
qgallouedec Sep 20, 2025
d2adc63
test peft
qgallouedec Sep 20, 2025
f4c82bf
fix gfpo
qgallouedec Sep 20, 2025
1257796
rloo test
qgallouedec Sep 20, 2025
099a39b
peft rloo
qgallouedec Sep 20, 2025
529add6
oops
qgallouedec Sep 20, 2025
fc6b11f
update test
qgallouedec Sep 20, 2025
ae1f497
generate method
qgallouedec Sep 20, 2025
f998432
debug
qgallouedec Sep 20, 2025
fa73876
skip failing test
qgallouedec Sep 20, 2025
52d8bd9
Merge branch 'main' into drop-image_split_sizes
qgallouedec Sep 20, 2025
dfc0d38
Merge branch 'drop-image_split_sizes' into multi-image-support
qgallouedec Sep 20, 2025
fc52e68
test fixed!
qgallouedec Sep 20, 2025
4d12aeb
Merge branch 'multi-image-support' into generate-method
qgallouedec Sep 20, 2025
4fc2b5b
gfpo
qgallouedec Sep 20, 2025
b628744
rm vllm
qgallouedec Sep 20, 2025
d3a769f
fix doc
qgallouedec Sep 20, 2025
e17ec42
Merge branch 'main' into drop-image_split_sizes
qgallouedec Sep 22, 2025
efbb03a
Merge branch 'drop-image_split_sizes' into multi-image-support
qgallouedec Sep 22, 2025
562c662
Merge branch 'main' into multi-image-support
qgallouedec Sep 22, 2025
485781c
Merge branch 'main' into multi-image-support
qgallouedec Sep 22, 2025
05270f8
update layers to ignore
qgallouedec Sep 22, 2025
1c53094
clarify image column desc
qgallouedec Sep 22, 2025
9b6652e
rm VLM x RM warning
qgallouedec Sep 23, 2025
c500440
Merge branch 'multi-image-support' into generate-method
qgallouedec Sep 23, 2025
a6a8c44
Merge branch 'main' into generate-method
qgallouedec Sep 23, 2025
d8665e1
Merge branch 'main' into generate-method
qgallouedec Sep 23, 2025
365d501
Merge branch 'main' into generate-method
qgallouedec Sep 23, 2025
cdb4c76
Merge branch 'main' into generate-method
qgallouedec Sep 24, 2025
c83e710
same for rloo
qgallouedec Sep 24, 2025
ec6ad25
nits style and align
qgallouedec Sep 24, 2025
b4cadde
Merge branch 'main' into generate-method
qgallouedec Sep 24, 2025
b0dceb9
restart
qgallouedec Sep 25, 2025
ebe32c2
progress
qgallouedec Sep 25, 2025
0213662
progress continues
qgallouedec Sep 25, 2025
8b3a724
progress again again
qgallouedec Sep 25, 2025
c1ae6aa
back to working point
qgallouedec Sep 25, 2025
1a66b43
revert chage data utils
qgallouedec Sep 25, 2025
2dc69a6
Merge branch 'main' into generate-method
qgallouedec Sep 26, 2025
9435a94
refactor in grpo
qgallouedec Sep 26, 2025
d3f1d3c
Merge branch 'main' into refactor_generate
qgallouedec Sep 26, 2025
3d8ea27
wrong merge commit
qgallouedec Sep 26, 2025
27dc958
fix num_input_tokens_seen
qgallouedec Sep 26, 2025
53772ef
getting closer
qgallouedec Sep 26, 2025
8766fa5
consistent naming
qgallouedec Sep 26, 2025
236b78b
better
qgallouedec Sep 26, 2025
9da4830
simplify a bit + comment
qgallouedec Sep 26, 2025
b3bd0b0
another one
qgallouedec Sep 26, 2025
8d34d54
remove pad token removal
qgallouedec Sep 26, 2025
55a2480
rloo + doc
qgallouedec Sep 26, 2025
c5064d6
gfpo
qgallouedec Sep 27, 2025
effb41b
Merge branch 'main' into refactor_generate
qgallouedec Sep 27, 2025
e82bfb4
Merge branch 'main' into refactor_generate
qgallouedec Sep 27, 2025
3a0ba92
Merge branch 'main' into refactor_generate
qgallouedec Sep 30, 2025
c5fa2df
Apply suggestion from @albertvillanova
qgallouedec Sep 30, 2025
c570fb0
fix docstring
qgallouedec Sep 30, 2025
2f70440
Apply suggestion from @albertvillanova
qgallouedec Sep 30, 2025
80b7403
style
qgallouedec Sep 30, 2025
84f400c
Merge branch 'main' into refactor_generate
qgallouedec Sep 30, 2025
c72f54a
Merge branch 'main' into refactor_generate
qgallouedec Sep 30, 2025
109 changes: 26 additions & 83 deletions tests/test_utils.py
@@ -937,136 +937,79 @@ def test_multi_images(self):
class TruncateWithProtectedTokensTester(TrlTestCase):
def test_basic_example(self):
"""Test the basic example from the problem description."""
prompt_ids = torch.tensor([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]])
prompt_mask = torch.ones_like(prompt_ids)
protected_tokens = [2, 3, 6]
prompt_ids = [1, 2, 3, 4, 5]
protected_tokens = [2, 3]
target_length = 3

new_ids, new_mask = truncate_with_protected_tokens(prompt_ids, prompt_mask, target_length, protected_tokens)

expected_ids = torch.tensor([[2, 3, 5], [6, 9, 10]])
expected_mask = torch.ones_like(expected_ids)
new_ids = truncate_with_protected_tokens(prompt_ids, target_length, protected_tokens)

self.assertTrue(torch.equal(new_ids, expected_ids))
self.assertTrue(torch.equal(new_mask, expected_mask))
expected_ids = [2, 3, 5]
self.assertEqual(new_ids, expected_ids)

def test_no_truncation_needed(self):
"""Test when target length equals current length."""
prompt_ids = torch.tensor([[1, 2, 3]])
prompt_mask = torch.ones_like(prompt_ids)
prompt_ids = [1, 2, 3]
protected_tokens = [2]
target_length = 3

new_ids, new_mask = truncate_with_protected_tokens(prompt_ids, prompt_mask, target_length, protected_tokens)
new_ids = truncate_with_protected_tokens(prompt_ids, target_length, protected_tokens)

self.assertTrue(torch.equal(new_ids, prompt_ids))
self.assertTrue(torch.equal(new_mask, prompt_mask))
self.assertEqual(new_ids, prompt_ids)

def test_no_protected_tokens(self):
"""Test truncation with no protected tokens (normal right truncation)."""
prompt_ids = torch.tensor([[1, 2, 3, 4, 5]])
prompt_mask = torch.ones_like(prompt_ids)
prompt_ids = [1, 2, 3, 4, 5]
protected_tokens = []
target_length = 3

new_ids, new_mask = truncate_with_protected_tokens(prompt_ids, prompt_mask, target_length, protected_tokens)
new_ids = truncate_with_protected_tokens(prompt_ids, target_length, protected_tokens)

expected_ids = torch.tensor([[3, 4, 5]]) # Last 3 tokens
self.assertTrue(torch.equal(new_ids, expected_ids))
expected_ids = [3, 4, 5] # Last 3 tokens
self.assertEqual(new_ids, expected_ids)

def test_all_tokens_protected(self):
"""Test when all remaining tokens are protected."""
prompt_ids = torch.tensor([[1, 2, 3, 4, 5]])
prompt_mask = torch.ones_like(prompt_ids)
prompt_ids = [1, 2, 3, 4, 5]
protected_tokens = [3, 4, 5]
target_length = 3

new_ids, new_mask = truncate_with_protected_tokens(prompt_ids, prompt_mask, target_length, protected_tokens)
new_ids = truncate_with_protected_tokens(prompt_ids, target_length, protected_tokens)

expected_ids = torch.tensor([[3, 4, 5]])
self.assertTrue(torch.equal(new_ids, expected_ids))
expected_ids = [3, 4, 5]
self.assertEqual(new_ids, expected_ids)

def test_too_many_protected_tokens(self):
"""Test error when too many protected tokens for target length."""
prompt_ids = torch.tensor([[1, 2, 3, 4, 5]])
prompt_mask = torch.ones_like(prompt_ids)
prompt_ids = [1, 2, 3, 4, 5]
protected_tokens = [1, 2, 3, 4]
target_length = 3

with self.assertRaises(ValueError):
truncate_with_protected_tokens(prompt_ids, prompt_mask, target_length, protected_tokens)
truncate_with_protected_tokens(prompt_ids, target_length, protected_tokens)

def test_single_batch_single_token(self):
"""Test edge case with single batch and single token."""
prompt_ids = torch.tensor([[5]])
prompt_mask = torch.ones_like(prompt_ids)
prompt_ids = [5]
protected_tokens = [5]
target_length = 1

new_ids, new_mask = truncate_with_protected_tokens(prompt_ids, prompt_mask, target_length, protected_tokens)

self.assertTrue(torch.equal(new_ids, prompt_ids))

def test_mask_preservation(self):
"""Test that mask values are correctly preserved."""
prompt_ids = torch.tensor([[1, 2, 3, 4, 5]])
prompt_mask = torch.tensor([[1, 0, 1, 0, 1]]) # Mixed mask values
protected_tokens = [2, 4]
target_length = 3

new_ids, new_mask = truncate_with_protected_tokens(prompt_ids, prompt_mask, target_length, protected_tokens)

expected_ids = torch.tensor([[2, 4, 5]])
expected_mask = torch.tensor([[0, 0, 1]]) # Corresponding mask values

self.assertTrue(torch.equal(new_ids, expected_ids))
self.assertTrue(torch.equal(new_mask, expected_mask))

def test_multiple_batches_different_protected(self):
"""Test multiple batches where protected tokens appear differently."""
prompt_ids = torch.tensor([[1, 2, 3, 4, 5], [2, 6, 7, 8, 9], [10, 11, 12, 2, 13]])
prompt_mask = torch.ones_like(prompt_ids)
protected_tokens = [2]
target_length = 3

new_ids, new_mask = truncate_with_protected_tokens(prompt_ids, prompt_mask, target_length, protected_tokens)
new_ids = truncate_with_protected_tokens(prompt_ids, target_length, protected_tokens)

expected_ids = torch.tensor(
[
[2, 4, 5], # 2 is protected, keep last 2 non-protected (4,5)
[2, 8, 9], # 2 is protected, keep last 2 non-protected (8,9)
[12, 2, 13], # 2 is protected, keep last 2 non-protected (12,13)
]
)

self.assertTrue(torch.equal(new_ids, expected_ids))
self.assertEqual(new_ids, prompt_ids)

def test_order_preservation(self):
"""Test that relative order is preserved."""
prompt_ids = torch.tensor([[10, 2, 20, 3, 30, 40]])
prompt_mask = torch.ones_like(prompt_ids)
prompt_ids = [10, 2, 20, 3, 30, 40]
protected_tokens = [2, 3]
target_length = 4

new_ids, new_mask = truncate_with_protected_tokens(prompt_ids, prompt_mask, target_length, protected_tokens)
new_ids = truncate_with_protected_tokens(prompt_ids, target_length, protected_tokens)

# Should keep protected tokens 2,3 and last 2 non-protected tokens 30,40
# Should keep protected tokens 2, 3 and last 2 non-protected tokens 30, 40
# Order should be: 2, 3, 30, 40 (maintaining original relative positions)
expected_ids = torch.tensor([[2, 3, 30, 40]])

self.assertTrue(torch.equal(new_ids, expected_ids))

def test_empty_protected_tokens_list(self):
"""Test with empty protected tokens list."""
prompt_ids = torch.tensor([[1, 2, 3, 4, 5]])
prompt_mask = torch.ones_like(prompt_ids)
protected_tokens = []
target_length = 2

new_ids, new_mask = truncate_with_protected_tokens(prompt_ids, prompt_mask, target_length, protected_tokens)
expected_ids = [2, 3, 30, 40]

expected_ids = torch.tensor([[4, 5]]) # Last 2 tokens
self.assertTrue(torch.equal(new_ids, expected_ids))
self.assertEqual(new_ids, expected_ids)


class UnsplitPixelValuesByGridTester(TrlTestCase):
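For context on the test changes above: truncate_with_protected_tokens is now called with a plain list of token IDs and no attention mask, instead of batched tensors, and it returns a list. Below is a minimal sketch of a function with the behavior these tests expect; it is an illustrative reimplementation derived from the test cases, not the actual TRL utility.

def truncate_with_protected_tokens(ids, target_length, protected_tokens):
    """Truncate `ids` to `target_length`, keeping every protected token and
    filling the remaining slots with the right-most non-protected tokens,
    while preserving the original relative order."""
    protected = set(protected_tokens)
    num_protected = sum(1 for t in ids if t in protected)
    if num_protected > target_length:
        raise ValueError(
            f"target_length ({target_length}) is too small to keep all {num_protected} protected tokens."
        )
    budget = target_length - num_protected  # non-protected tokens we may keep
    keep = [False] * len(ids)
    # Walk from the right so the kept non-protected tokens are the last ones.
    for i in range(len(ids) - 1, -1, -1):
        if ids[i] in protected:
            keep[i] = True
        elif budget > 0:
            keep[i] = True
            budget -= 1
    return [t for t, k in zip(ids, keep) if k]

# Matches test_order_preservation above: protected tokens 2 and 3 are kept,
# plus the last two non-protected tokens, in their original relative order.
assert truncate_with_protected_tokens([10, 2, 20, 3, 30, 40], 4, [2, 3]) == [2, 3, 30, 40]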
33 changes: 24 additions & 9 deletions trl/experimental/gfpo/gfpo_trainer.py
@@ -20,7 +20,7 @@

from ...data_utils import is_conversational
from ...trainer.grpo_trainer import GRPOTrainer as _GRPOTrainer
from ...trainer.utils import nanmax, nanmin, nanstd
from ...trainer.utils import nanmax, nanmin, nanstd, pad


logger = logging.getLogger(__name__)
@@ -78,18 +78,33 @@ def _generate_and_score_completions(self, inputs):
images = None

(
prompt_ids,
completion_ids,
prompt_mask,
completion_mask,
prompt_ids_list,
completion_ids_list,
num_items_in_batch,
sampling_per_token_logps,
sampling_per_token_logps_list,
forward_kwargs,
) = self._generate(prompts, images)

# Convert tensor to a list of lists of token IDs. This will be passed to the reward function, avoiding the need
# to re-tokenize completions if the reward is computed from tokens.
completion_ids_list = [row[mask_row].tolist() for row, mask_row in zip(completion_ids, completion_mask.bool())]
# Convert lists of token IDs to padded tensors
prompt_ids = [torch.tensor(ids, device=device) for ids in prompt_ids_list]
prompt_mask = [torch.ones_like(ids, dtype=torch.long) for ids in prompt_ids]
prompt_ids = pad(prompt_ids, padding_value=self.pad_token_id, padding_side="left")
prompt_mask = pad(prompt_mask, padding_value=0, padding_side="left")
completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids_list]
completion_mask = [torch.ones_like(ids, dtype=torch.long) for ids in completion_ids]
completion_ids = pad(completion_ids, padding_value=self.pad_token_id, padding_side="right")
completion_mask = pad(completion_mask, padding_value=0, padding_side="right")
if sampling_per_token_logps_list is not None:
sampling_per_token_logps = [torch.tensor(logps, device=device) for logps in sampling_per_token_logps_list]
sampling_per_token_logps = pad(sampling_per_token_logps, padding_value=0.0, padding_side="right")
else:
sampling_per_token_logps = None

# If mask_truncated_completions is enabled, zero out truncated completions in completion_mask
if self.mask_truncated_completions:
eos_and_pad = [self.eos_token_id, self.pad_token_id]
is_truncated = torch.tensor([ids[-1] not in eos_and_pad for ids in completion_ids_list], device=device)
completion_mask = completion_mask * (~is_truncated).unsqueeze(1).int()

# Concatenate prompt_mask with completion_mask for logit computation
prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1) # (B, P+C)
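The block above is the heart of the refactor for GFPO: _generate now returns ragged Python lists (prompt_ids_list, completion_ids_list, sampling_per_token_logps_list), and the trainer pads them back into batch tensors and rebuilds the masks itself. The self-contained sketch below mimics that padding step and the mask_truncated_completions logic; the pad helper here is only a rough stand-in for trl.trainer.utils.pad, and the token IDs and special-token values are made up for illustration.

import torch

def pad(tensors, padding_value=0, padding_side="right"):
    # Rough stand-in for trl.trainer.utils.pad: stack 1-D tensors of varying
    # length into one (batch, max_len) tensor, padding on the chosen side.
    max_len = max(t.size(0) for t in tensors)
    out = torch.full((len(tensors), max_len), padding_value, dtype=tensors[0].dtype)
    for i, t in enumerate(tensors):
        if padding_side == "right":
            out[i, : t.size(0)] = t
        else:  # "left"
            out[i, max_len - t.size(0) :] = t
    return out

pad_token_id, eos_token_id = 0, 2  # hypothetical special-token IDs
completion_ids_list = [[5, 6, 7, eos_token_id], [8, 9, 10, 11, 12]]  # second sample hit the length limit

completion_ids = [torch.tensor(ids) for ids in completion_ids_list]
completion_mask = [torch.ones_like(ids) for ids in completion_ids]
completion_ids = pad(completion_ids, padding_value=pad_token_id, padding_side="right")
completion_mask = pad(completion_mask, padding_value=0, padding_side="right")

# mask_truncated_completions: zero out samples whose last token is neither EOS nor pad,
# i.e. completions that were cut off by the generation length limit.
is_truncated = torch.tensor([ids[-1] not in (eos_token_id, pad_token_id) for ids in completion_ids_list])
completion_mask = completion_mask * (~is_truncated).unsqueeze(1).int()

print(completion_ids)   # tensor([[ 5,  6,  7,  2,  0], [ 8,  9, 10, 11, 12]])
print(completion_mask)  # tensor([[1, 1, 1, 1, 0], [0, 0, 0, 0, 0]])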