Commit ea66a9e

🧺 [1/N] Refactor _generate in GRPO/RLOO: list of ints instead of tensors (#4146)
Co-authored-by: Albert Villanova del Moral <[email protected]>
Parent: da209f8

5 files changed, +216 -252 lines

tests/test_utils.py

Lines changed: 26 additions & 83 deletions
@@ -1025,136 +1025,79 @@ def test_multi_images(self):
 class TruncateWithProtectedTokensTester(TrlTestCase):
     def test_basic_example(self):
         """Test the basic example from the problem description."""
-        prompt_ids = torch.tensor([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]])
-        prompt_mask = torch.ones_like(prompt_ids)
-        protected_tokens = [2, 3, 6]
+        prompt_ids = [1, 2, 3, 4, 5]
+        protected_tokens = [2, 3]
         target_length = 3

-        new_ids, new_mask = truncate_with_protected_tokens(prompt_ids, prompt_mask, target_length, protected_tokens)
-
-        expected_ids = torch.tensor([[2, 3, 5], [6, 9, 10]])
-        expected_mask = torch.ones_like(expected_ids)
+        new_ids = truncate_with_protected_tokens(prompt_ids, target_length, protected_tokens)

-        self.assertTrue(torch.equal(new_ids, expected_ids))
-        self.assertTrue(torch.equal(new_mask, expected_mask))
+        expected_ids = [2, 3, 5]
+        self.assertEqual(new_ids, expected_ids)

     def test_no_truncation_needed(self):
         """Test when target length equals current length."""
-        prompt_ids = torch.tensor([[1, 2, 3]])
-        prompt_mask = torch.ones_like(prompt_ids)
+        prompt_ids = [1, 2, 3]
         protected_tokens = [2]
         target_length = 3

-        new_ids, new_mask = truncate_with_protected_tokens(prompt_ids, prompt_mask, target_length, protected_tokens)
+        new_ids = truncate_with_protected_tokens(prompt_ids, target_length, protected_tokens)

-        self.assertTrue(torch.equal(new_ids, prompt_ids))
-        self.assertTrue(torch.equal(new_mask, prompt_mask))
+        self.assertEqual(new_ids, prompt_ids)

     def test_no_protected_tokens(self):
         """Test truncation with no protected tokens (normal right truncation)."""
-        prompt_ids = torch.tensor([[1, 2, 3, 4, 5]])
-        prompt_mask = torch.ones_like(prompt_ids)
+        prompt_ids = [1, 2, 3, 4, 5]
         protected_tokens = []
         target_length = 3

-        new_ids, new_mask = truncate_with_protected_tokens(prompt_ids, prompt_mask, target_length, protected_tokens)
+        new_ids = truncate_with_protected_tokens(prompt_ids, target_length, protected_tokens)

-        expected_ids = torch.tensor([[3, 4, 5]])  # Last 3 tokens
-        self.assertTrue(torch.equal(new_ids, expected_ids))
+        expected_ids = [3, 4, 5]  # Last 3 tokens
+        self.assertEqual(new_ids, expected_ids)

     def test_all_tokens_protected(self):
         """Test when all remaining tokens are protected."""
-        prompt_ids = torch.tensor([[1, 2, 3, 4, 5]])
-        prompt_mask = torch.ones_like(prompt_ids)
+        prompt_ids = [1, 2, 3, 4, 5]
         protected_tokens = [3, 4, 5]
         target_length = 3

-        new_ids, new_mask = truncate_with_protected_tokens(prompt_ids, prompt_mask, target_length, protected_tokens)
+        new_ids = truncate_with_protected_tokens(prompt_ids, target_length, protected_tokens)

-        expected_ids = torch.tensor([[3, 4, 5]])
-        self.assertTrue(torch.equal(new_ids, expected_ids))
+        expected_ids = [3, 4, 5]
+        self.assertEqual(new_ids, expected_ids)

     def test_too_many_protected_tokens(self):
         """Test error when too many protected tokens for target length."""
-        prompt_ids = torch.tensor([[1, 2, 3, 4, 5]])
-        prompt_mask = torch.ones_like(prompt_ids)
+        prompt_ids = [1, 2, 3, 4, 5]
         protected_tokens = [1, 2, 3, 4]
         target_length = 3

         with self.assertRaises(ValueError):
-            truncate_with_protected_tokens(prompt_ids, prompt_mask, target_length, protected_tokens)
+            truncate_with_protected_tokens(prompt_ids, target_length, protected_tokens)

     def test_single_batch_single_token(self):
         """Test edge case with single batch and single token."""
-        prompt_ids = torch.tensor([[5]])
-        prompt_mask = torch.ones_like(prompt_ids)
+        prompt_ids = [5]
         protected_tokens = [5]
         target_length = 1

-        new_ids, new_mask = truncate_with_protected_tokens(prompt_ids, prompt_mask, target_length, protected_tokens)
-
-        self.assertTrue(torch.equal(new_ids, prompt_ids))
-
-    def test_mask_preservation(self):
-        """Test that mask values are correctly preserved."""
-        prompt_ids = torch.tensor([[1, 2, 3, 4, 5]])
-        prompt_mask = torch.tensor([[1, 0, 1, 0, 1]])  # Mixed mask values
-        protected_tokens = [2, 4]
-        target_length = 3
-
-        new_ids, new_mask = truncate_with_protected_tokens(prompt_ids, prompt_mask, target_length, protected_tokens)
-
-        expected_ids = torch.tensor([[2, 4, 5]])
-        expected_mask = torch.tensor([[0, 0, 1]])  # Corresponding mask values
-
-        self.assertTrue(torch.equal(new_ids, expected_ids))
-        self.assertTrue(torch.equal(new_mask, expected_mask))
-
-    def test_multiple_batches_different_protected(self):
-        """Test multiple batches where protected tokens appear differently."""
-        prompt_ids = torch.tensor([[1, 2, 3, 4, 5], [2, 6, 7, 8, 9], [10, 11, 12, 2, 13]])
-        prompt_mask = torch.ones_like(prompt_ids)
-        protected_tokens = [2]
-        target_length = 3
-
-        new_ids, new_mask = truncate_with_protected_tokens(prompt_ids, prompt_mask, target_length, protected_tokens)
+        new_ids = truncate_with_protected_tokens(prompt_ids, target_length, protected_tokens)

-        expected_ids = torch.tensor(
-            [
-                [2, 4, 5],  # 2 is protected, keep last 2 non-protected (4,5)
-                [2, 8, 9],  # 2 is protected, keep last 2 non-protected (8,9)
-                [12, 2, 13],  # 2 is protected, keep last 2 non-protected (12,13)
-            ]
-        )
-
-        self.assertTrue(torch.equal(new_ids, expected_ids))
+        self.assertEqual(new_ids, prompt_ids)

     def test_order_preservation(self):
         """Test that relative order is preserved."""
-        prompt_ids = torch.tensor([[10, 2, 20, 3, 30, 40]])
-        prompt_mask = torch.ones_like(prompt_ids)
+        prompt_ids = [10, 2, 20, 3, 30, 40]
         protected_tokens = [2, 3]
         target_length = 4

-        new_ids, new_mask = truncate_with_protected_tokens(prompt_ids, prompt_mask, target_length, protected_tokens)
+        new_ids = truncate_with_protected_tokens(prompt_ids, target_length, protected_tokens)

-        # Should keep protected tokens 2,3 and last 2 non-protected tokens 30,40
+        # Should keep protected tokens 2, 3 and last 2 non-protected tokens 30, 40
         # Order should be: 2, 3, 30, 40 (maintaining original relative positions)
-        expected_ids = torch.tensor([[2, 3, 30, 40]])
-
-        self.assertTrue(torch.equal(new_ids, expected_ids))
-
-    def test_empty_protected_tokens_list(self):
-        """Test with empty protected tokens list."""
-        prompt_ids = torch.tensor([[1, 2, 3, 4, 5]])
-        prompt_mask = torch.ones_like(prompt_ids)
-        protected_tokens = []
-        target_length = 2
-
-        new_ids, new_mask = truncate_with_protected_tokens(prompt_ids, prompt_mask, target_length, protected_tokens)
+        expected_ids = [2, 3, 30, 40]

-        expected_ids = torch.tensor([[4, 5]])  # Last 2 tokens
-        self.assertTrue(torch.equal(new_ids, expected_ids))
+        self.assertEqual(new_ids, expected_ids)


 class UnsplitPixelValuesByGridTester(TrlTestCase):
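
For orientation, here is a minimal list-based sketch of the contract these updated tests pin down: keep every protected token, fill the remaining budget with the rightmost non-protected tokens, preserve relative order, and raise ValueError when the protected tokens alone exceed the target length. This is an illustration written against the tests, not the actual implementation in trl/trainer/utils.py.

# Illustrative sketch only; the real truncate_with_protected_tokens lives in trl/trainer/utils.py.
def truncate_with_protected_tokens(ids, target_length, protected_tokens):
    protected = set(protected_tokens)
    num_protected = sum(1 for t in ids if t in protected)
    if num_protected > target_length:
        raise ValueError("Protected tokens alone exceed the target length.")
    # Fill the remaining budget with the rightmost non-protected positions.
    budget = target_length - num_protected
    non_protected = [i for i, t in enumerate(ids) if t not in protected]
    keep = set(non_protected[max(len(non_protected) - budget, 0):])
    # Emit the surviving tokens in their original relative order.
    return [t for i, t in enumerate(ids) if t in protected or i in keep]

# Matches test_order_preservation above:
# truncate_with_protected_tokens([10, 2, 20, 3, 30, 40], 4, [2, 3]) == [2, 3, 30, 40]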

trl/experimental/gfpo/gfpo_trainer.py

Lines changed: 24 additions & 9 deletions
@@ -20,7 +20,7 @@

 from ...data_utils import is_conversational
 from ...trainer.grpo_trainer import GRPOTrainer as _GRPOTrainer
-from ...trainer.utils import nanmax, nanmin, nanstd
+from ...trainer.utils import nanmax, nanmin, nanstd, pad


 logger = logging.getLogger(__name__)
@@ -78,18 +78,33 @@ def _generate_and_score_completions(self, inputs):
             images = None

         (
-            prompt_ids,
-            completion_ids,
-            prompt_mask,
-            completion_mask,
+            prompt_ids_list,
+            completion_ids_list,
             num_items_in_batch,
-            sampling_per_token_logps,
+            sampling_per_token_logps_list,
             forward_kwargs,
         ) = self._generate(prompts, images)

-        # Convert tensor to a list of lists of token IDs. This will be passed to the reward function, avoiding the need
-        # to re-tokenize completions if the reward is computed from tokens.
-        completion_ids_list = [row[mask_row].tolist() for row, mask_row in zip(completion_ids, completion_mask.bool())]
+        # Convert lists of token IDs to padded tensors
+        prompt_ids = [torch.tensor(ids, device=device) for ids in prompt_ids_list]
+        prompt_mask = [torch.ones_like(ids, dtype=torch.long) for ids in prompt_ids]
+        prompt_ids = pad(prompt_ids, padding_value=self.pad_token_id, padding_side="left")
+        prompt_mask = pad(prompt_mask, padding_value=0, padding_side="left")
+        completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids_list]
+        completion_mask = [torch.ones_like(ids, dtype=torch.long) for ids in completion_ids]
+        completion_ids = pad(completion_ids, padding_value=self.pad_token_id, padding_side="right")
+        completion_mask = pad(completion_mask, padding_value=0, padding_side="right")
+        if sampling_per_token_logps_list is not None:
+            sampling_per_token_logps = [torch.tensor(logps, device=device) for logps in sampling_per_token_logps_list]
+            sampling_per_token_logps = pad(sampling_per_token_logps, padding_value=0.0, padding_side="right")
+        else:
+            sampling_per_token_logps = None
+
+        # If mask_truncated_completions is enabled, zero out truncated completions in completion_mask
+        if self.mask_truncated_completions:
+            eos_and_pad = [self.eos_token_id, self.pad_token_id]
+            is_truncated = torch.tensor([ids[-1] not in eos_and_pad for ids in completion_ids_list], device=device)
+            completion_mask = completion_mask * (~is_truncated).unsqueeze(1).int()

         # Concatenate prompt_mask with completion_mask for logit computation
         prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)  # (B, P+C)
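
The core of the new hunk is rebuilding padded batch tensors and masks from the per-sequence lists that _generate now returns. Below is a standalone sketch of that conversion and of the truncation masking, importing pad from trl.trainer.utils exactly as the hunk does; the token values, pad_token_id=0, and eos_token_id=2 are invented for illustration.

import torch

from trl.trainer.utils import pad

pad_token_id = 0  # illustrative IDs, not tied to any real tokenizer
eos_token_id = 2

completion_ids_list = [[11, 12, 2], [13, 14, 15, 16]]  # second sequence lacks an EOS

# Lists -> 1-D tensors, with an all-ones mask per sequence before padding.
completion_ids = [torch.tensor(ids) for ids in completion_ids_list]
completion_mask = [torch.ones_like(ids, dtype=torch.long) for ids in completion_ids]

# Completions are right-padded; the prompts in the hunk use padding_side="left" instead.
completion_ids = pad(completion_ids, padding_value=pad_token_id, padding_side="right")
completion_mask = pad(completion_mask, padding_value=0, padding_side="right")

# A truncated completion ends in neither EOS nor padding; zero out its whole mask row.
is_truncated = torch.tensor([ids[-1] not in (eos_token_id, pad_token_id) for ids in completion_ids_list])
completion_mask = completion_mask * (~is_truncated).unsqueeze(1).int()

print(completion_ids)   # tensor([[11, 12,  2,  0], [13, 14, 15, 16]])
print(completion_mask)  # tensor([[1, 1, 1, 0], [0, 0, 0, 0]])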
