diff --git a/tests/test_utils.py b/tests/test_utils.py index a9febf2e9be..4758109e08b 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1025,136 +1025,79 @@ def test_multi_images(self): class TruncateWithProtectedTokensTester(TrlTestCase): def test_basic_example(self): """Test the basic example from the problem description.""" - prompt_ids = torch.tensor([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]) - prompt_mask = torch.ones_like(prompt_ids) - protected_tokens = [2, 3, 6] + prompt_ids = [1, 2, 3, 4, 5] + protected_tokens = [2, 3] target_length = 3 - new_ids, new_mask = truncate_with_protected_tokens(prompt_ids, prompt_mask, target_length, protected_tokens) - - expected_ids = torch.tensor([[2, 3, 5], [6, 9, 10]]) - expected_mask = torch.ones_like(expected_ids) + new_ids = truncate_with_protected_tokens(prompt_ids, target_length, protected_tokens) - self.assertTrue(torch.equal(new_ids, expected_ids)) - self.assertTrue(torch.equal(new_mask, expected_mask)) + expected_ids = [2, 3, 5] + self.assertEqual(new_ids, expected_ids) def test_no_truncation_needed(self): """Test when target length equals current length.""" - prompt_ids = torch.tensor([[1, 2, 3]]) - prompt_mask = torch.ones_like(prompt_ids) + prompt_ids = [1, 2, 3] protected_tokens = [2] target_length = 3 - new_ids, new_mask = truncate_with_protected_tokens(prompt_ids, prompt_mask, target_length, protected_tokens) + new_ids = truncate_with_protected_tokens(prompt_ids, target_length, protected_tokens) - self.assertTrue(torch.equal(new_ids, prompt_ids)) - self.assertTrue(torch.equal(new_mask, prompt_mask)) + self.assertEqual(new_ids, prompt_ids) def test_no_protected_tokens(self): """Test truncation with no protected tokens (normal right truncation).""" - prompt_ids = torch.tensor([[1, 2, 3, 4, 5]]) - prompt_mask = torch.ones_like(prompt_ids) + prompt_ids = [1, 2, 3, 4, 5] protected_tokens = [] target_length = 3 - new_ids, new_mask = truncate_with_protected_tokens(prompt_ids, prompt_mask, target_length, protected_tokens) + new_ids = truncate_with_protected_tokens(prompt_ids, target_length, protected_tokens) - expected_ids = torch.tensor([[3, 4, 5]]) # Last 3 tokens - self.assertTrue(torch.equal(new_ids, expected_ids)) + expected_ids = [3, 4, 5] # Last 3 tokens + self.assertEqual(new_ids, expected_ids) def test_all_tokens_protected(self): """Test when all remaining tokens are protected.""" - prompt_ids = torch.tensor([[1, 2, 3, 4, 5]]) - prompt_mask = torch.ones_like(prompt_ids) + prompt_ids = [1, 2, 3, 4, 5] protected_tokens = [3, 4, 5] target_length = 3 - new_ids, new_mask = truncate_with_protected_tokens(prompt_ids, prompt_mask, target_length, protected_tokens) + new_ids = truncate_with_protected_tokens(prompt_ids, target_length, protected_tokens) - expected_ids = torch.tensor([[3, 4, 5]]) - self.assertTrue(torch.equal(new_ids, expected_ids)) + expected_ids = [3, 4, 5] + self.assertEqual(new_ids, expected_ids) def test_too_many_protected_tokens(self): """Test error when too many protected tokens for target length.""" - prompt_ids = torch.tensor([[1, 2, 3, 4, 5]]) - prompt_mask = torch.ones_like(prompt_ids) + prompt_ids = [1, 2, 3, 4, 5] protected_tokens = [1, 2, 3, 4] target_length = 3 with self.assertRaises(ValueError): - truncate_with_protected_tokens(prompt_ids, prompt_mask, target_length, protected_tokens) + truncate_with_protected_tokens(prompt_ids, target_length, protected_tokens) def test_single_batch_single_token(self): """Test edge case with single batch and single token.""" - prompt_ids = torch.tensor([[5]]) - prompt_mask 
= torch.ones_like(prompt_ids) + prompt_ids = [5] protected_tokens = [5] target_length = 1 - new_ids, new_mask = truncate_with_protected_tokens(prompt_ids, prompt_mask, target_length, protected_tokens) - - self.assertTrue(torch.equal(new_ids, prompt_ids)) - - def test_mask_preservation(self): - """Test that mask values are correctly preserved.""" - prompt_ids = torch.tensor([[1, 2, 3, 4, 5]]) - prompt_mask = torch.tensor([[1, 0, 1, 0, 1]]) # Mixed mask values - protected_tokens = [2, 4] - target_length = 3 - - new_ids, new_mask = truncate_with_protected_tokens(prompt_ids, prompt_mask, target_length, protected_tokens) - - expected_ids = torch.tensor([[2, 4, 5]]) - expected_mask = torch.tensor([[0, 0, 1]]) # Corresponding mask values - - self.assertTrue(torch.equal(new_ids, expected_ids)) - self.assertTrue(torch.equal(new_mask, expected_mask)) - - def test_multiple_batches_different_protected(self): - """Test multiple batches where protected tokens appear differently.""" - prompt_ids = torch.tensor([[1, 2, 3, 4, 5], [2, 6, 7, 8, 9], [10, 11, 12, 2, 13]]) - prompt_mask = torch.ones_like(prompt_ids) - protected_tokens = [2] - target_length = 3 - - new_ids, new_mask = truncate_with_protected_tokens(prompt_ids, prompt_mask, target_length, protected_tokens) + new_ids = truncate_with_protected_tokens(prompt_ids, target_length, protected_tokens) - expected_ids = torch.tensor( - [ - [2, 4, 5], # 2 is protected, keep last 2 non-protected (4,5) - [2, 8, 9], # 2 is protected, keep last 2 non-protected (8,9) - [12, 2, 13], # 2 is protected, keep last 2 non-protected (12,13) - ] - ) - - self.assertTrue(torch.equal(new_ids, expected_ids)) + self.assertEqual(new_ids, prompt_ids) def test_order_preservation(self): """Test that relative order is preserved.""" - prompt_ids = torch.tensor([[10, 2, 20, 3, 30, 40]]) - prompt_mask = torch.ones_like(prompt_ids) + prompt_ids = [10, 2, 20, 3, 30, 40] protected_tokens = [2, 3] target_length = 4 - new_ids, new_mask = truncate_with_protected_tokens(prompt_ids, prompt_mask, target_length, protected_tokens) + new_ids = truncate_with_protected_tokens(prompt_ids, target_length, protected_tokens) - # Should keep protected tokens 2,3 and last 2 non-protected tokens 30,40 + # Should keep protected tokens 2, 3 and last 2 non-protected tokens 30, 40 # Order should be: 2, 3, 30, 40 (maintaining original relative positions) - expected_ids = torch.tensor([[2, 3, 30, 40]]) - - self.assertTrue(torch.equal(new_ids, expected_ids)) - - def test_empty_protected_tokens_list(self): - """Test with empty protected tokens list.""" - prompt_ids = torch.tensor([[1, 2, 3, 4, 5]]) - prompt_mask = torch.ones_like(prompt_ids) - protected_tokens = [] - target_length = 2 - - new_ids, new_mask = truncate_with_protected_tokens(prompt_ids, prompt_mask, target_length, protected_tokens) + expected_ids = [2, 3, 30, 40] - expected_ids = torch.tensor([[4, 5]]) # Last 2 tokens - self.assertTrue(torch.equal(new_ids, expected_ids)) + self.assertEqual(new_ids, expected_ids) class UnsplitPixelValuesByGridTester(TrlTestCase): diff --git a/trl/experimental/gfpo/gfpo_trainer.py b/trl/experimental/gfpo/gfpo_trainer.py index 784db368afd..58ca39a6f45 100644 --- a/trl/experimental/gfpo/gfpo_trainer.py +++ b/trl/experimental/gfpo/gfpo_trainer.py @@ -20,7 +20,7 @@ from ...data_utils import is_conversational from ...trainer.grpo_trainer import GRPOTrainer as _GRPOTrainer -from ...trainer.utils import nanmax, nanmin, nanstd +from ...trainer.utils import nanmax, nanmin, nanstd, pad logger = logging.getLogger(__name__) @@ 
-78,18 +78,33 @@ def _generate_and_score_completions(self, inputs): images = None ( - prompt_ids, - completion_ids, - prompt_mask, - completion_mask, + prompt_ids_list, + completion_ids_list, num_items_in_batch, - sampling_per_token_logps, + sampling_per_token_logps_list, forward_kwargs, ) = self._generate(prompts, images) - # Convert tensor to a list of lists of token IDs. This will be passed to the reward function, avoiding the need - # to re-tokenize completions if the reward is computed from tokens. - completion_ids_list = [row[mask_row].tolist() for row, mask_row in zip(completion_ids, completion_mask.bool())] + # Convert lists of token IDs to padded tensors + prompt_ids = [torch.tensor(ids, device=device) for ids in prompt_ids_list] + prompt_mask = [torch.ones_like(ids, dtype=torch.long) for ids in prompt_ids] + prompt_ids = pad(prompt_ids, padding_value=self.pad_token_id, padding_side="left") + prompt_mask = pad(prompt_mask, padding_value=0, padding_side="left") + completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids_list] + completion_mask = [torch.ones_like(ids, dtype=torch.long) for ids in completion_ids] + completion_ids = pad(completion_ids, padding_value=self.pad_token_id, padding_side="right") + completion_mask = pad(completion_mask, padding_value=0, padding_side="right") + if sampling_per_token_logps_list is not None: + sampling_per_token_logps = [torch.tensor(logps, device=device) for logps in sampling_per_token_logps_list] + sampling_per_token_logps = pad(sampling_per_token_logps, padding_value=0.0, padding_side="right") + else: + sampling_per_token_logps = None + + # If mask_truncated_completions is enabled, zero out truncated completions in completion_mask + if self.mask_truncated_completions: + eos_and_pad = [self.eos_token_id, self.pad_token_id] + is_truncated = torch.tensor([ids[-1] not in eos_and_pad for ids in completion_ids_list], device=device) + completion_mask = completion_mask * (~is_truncated).unsqueeze(1).int() # Concatenate prompt_mask with completion_mask for logit computation prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1) # (B, P+C) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index f4617665c88..22ab4df9275 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1073,9 +1073,8 @@ def _calculate_rewards(self, inputs, prompts, completions, completion_ids_list): rewards_per_func = gather(rewards_per_func) return rewards_per_func - def _generate(self, prompts: list[str], images: Optional[list]): + def _generate_single_turn(self, prompts: list[str], images: Optional[list]): device = self.accelerator.device - mode = "train" if self.model.training else "eval" # If the prompts are conversational and the inputs contain images, we need to convert the prompts from # [{"role": "user", "content": "What color is the sky?"}] to @@ -1102,21 +1101,19 @@ def _generate(self, prompts: list[str], images: Optional[list]): prompt_inputs = super()._prepare_inputs(prompt_inputs) prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"] forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]} + prompt_ids = [p[m].tolist() for p, m in zip(prompt_ids, prompt_mask.bool())] if self.max_prompt_length is not None: # If max_prompt_length is set, we trim the prompt to keep only the last `max_prompt_length` tokens. - # Then we decode those tokens back into text. 
We manually remove leading pad tokens from the decoded text, - # because we can't use `skip_special_tokens=True` (some special tokens are still needed for generation). + # Then we decode those tokens back into text. We set `skip_special_tokens=False` because some special + # tokens are needed for generation. protected = [self.image_token_id, self.vision_start_token_id, self.vision_end_token_id] protected = [token for token in protected if token is not None] - prompt_ids, prompt_mask = truncate_with_protected_tokens( - prompt_ids, prompt_mask, self.max_prompt_length, protected - ) + prompt_ids = [truncate_with_protected_tokens(ids, self.max_prompt_length, protected) for ids in prompt_ids] prompts_text = self.processing_class.batch_decode( prompt_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False ) - prompts_text = [re.sub(rf"^({re.escape(self.pad_token)})+", "", text) for text in prompts_text] # The chat template sometimes inserts a single image token into the prompt text. However, when this text is # later tokenized, the single image token string is expanded into multiple image token IDs, depending on the @@ -1195,14 +1192,14 @@ def _generate(self, prompts: list[str], images: Optional[list]): # Broadcast the completions from the main process to all processes, ensuring each process receives its corresponding slice. obj_list = [payload] broadcast_object_list(obj_list, from_process=0) - completion_ids, all_logprobs = obj_list[0] + all_completion_ids, all_logprobs = obj_list[0] process_slice = slice( self.accelerator.process_index * len(prompts), (self.accelerator.process_index + 1) * len(prompts), ) - completion_ids = completion_ids[process_slice] - all_logprobs = all_logprobs[process_slice] + completion_ids = all_completion_ids[process_slice] + logprobs = all_logprobs[process_slice] # Generate completions using colocated vLLM instances: each device holds vLLM copy and work on their own batch of prompts elif self.vllm_mode == "colocate": @@ -1255,7 +1252,7 @@ def _generate(self, prompts: list[str], images: Optional[list]): with profiling_context(self, "vLLM.generate"): all_outputs = self.llm.generate(vllm_inputs, sampling_params=sampling_params, use_tqdm=False) - completion_ids = [output.token_ids for outputs in all_outputs for output in outputs.outputs] + all_completion_ids = [output.token_ids for outputs in all_outputs for output in outputs.outputs] all_logprobs = [ [next(iter(lp.values())).logprob for lp in output.logprobs] for outputs in all_outputs @@ -1267,22 +1264,15 @@ def _generate(self, prompts: list[str], images: Optional[list]): # Each rank generates all outputs — we keep only our share. 
local_rank_in_group = torch.distributed.get_rank(group=self.tp_group) tp_slice = slice(local_rank_in_group * orig_size, (local_rank_in_group + 1) * orig_size) - completion_ids = completion_ids[tp_slice] - all_logprobs = all_logprobs[tp_slice] + completion_ids = all_completion_ids[tp_slice] + logprobs = all_logprobs[tp_slice] + else: + completion_ids = all_completion_ids + logprobs = all_logprobs if self.args.vllm_enable_sleep_mode: self.llm.sleep(level=1) - # Pad the completions, and concatenate them with the prompts - completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids] - completion_mask = [torch.ones(len(ids), device=device, dtype=torch.long) for ids in completion_ids] - completion_ids = pad(completion_ids, padding_value=self.pad_token_id) - completion_mask = pad(completion_mask, padding_value=0) - sampling_per_token_logps = [ - torch.tensor(logprobs, device=device, dtype=torch.float32) for logprobs in all_logprobs - ] - sampling_per_token_logps = pad(sampling_per_token_logps, padding_value=0.0) - elif self.use_transformers_paged: # Re-process inputs for paged generation if needed # Note: images are already validated and preprocessed above @@ -1312,16 +1302,18 @@ def _generate(self, prompts: list[str], images: Optional[list]): ) unwrapped_model.train() # restore training mode, as generate_batch forces eval mode completion_ids = [output.generated_tokens for output in all_outputs.values()] - completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids] - completion_ids = pad(completion_ids, padding_value=self.pad_token_id, padding_side="right") - prompt_ids = [torch.tensor(ids, device=device) for ids in paged_prompt_inputs.input_ids] - prompt_ids = pad(prompt_ids, padding_value=self.pad_token_id, padding_side="left") + prompt_ids = paged_prompt_inputs.input_ids # Restore the original attention implementation, training mode self.model_wrapped.config._attn_implementation = previous_attn - sampling_per_token_logps = None # not used in this case + logprobs = None # not used in this case else: # Regular generation path + prompt_ids = [torch.tensor(ids, device=device) for ids in prompt_ids] + prompt_mask = [torch.ones_like(ids, dtype=torch.long) for ids in prompt_ids] + prompt_ids = pad(prompt_ids, padding_value=self.pad_token_id, padding_side="left") + prompt_mask = pad(prompt_mask, padding_value=0, padding_side="left") + with ( profiling_context(self, "transformers.generate"), unwrap_model_for_generation( @@ -1341,29 +1333,36 @@ def _generate(self, prompts: list[str], images: Optional[list]): prompt_length = prompt_ids.size(1) prompt_ids = prompt_completion_ids[:, :prompt_length] completion_ids = prompt_completion_ids[:, prompt_length:] - sampling_per_token_logps = None # not used in this case - # Mask everything after the first EOS token - is_eos = completion_ids == self.eos_token_id - eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device) - eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)] - sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1) - completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int() + # Mask everything after the first EOS token + is_eos = completion_ids == self.eos_token_id + eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device) + eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)] + sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1) 
+ completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int() + prompt_ids = [p[m].tolist() for p, m in zip(prompt_ids, prompt_mask.bool())] + completion_ids = [c[m].tolist() for c, m in zip(completion_ids, completion_mask.bool())] + logprobs = None # not used in this case - # Sum along sequence dimension (dim=1) to get completion length per sequence, used for logging - completion_lengths = completion_mask.sum(1) - agg_completion_lengths = self.accelerator.gather(completion_lengths) - num_items_in_batch = agg_completion_lengths.sum() # this is required for the DAPO loss + return prompt_ids, completion_ids, logprobs, forward_kwargs - # If mask_truncated_completions is enabled, zero out truncated completions in completion_mask - if self.mask_truncated_completions: - truncated_completions = ~is_eos.any(dim=1) - completion_mask = completion_mask * (~truncated_completions).unsqueeze(1).int() + def _generate(self, prompts: list[str], images: Optional[list]): + device = self.accelerator.device + mode = "train" if self.model.training else "eval" + + prompt_ids, completion_ids, logprobs, forward_kwargs = self._generate_single_turn(prompts, images) + + # Get completion length per sequence, used for logging + prompt_lengths = torch.tensor([len(ids) for ids in prompt_ids], device=device) + completion_lengths = torch.tensor([len(ids) for ids in completion_ids], device=device) + agg_prompt_lengths = self.accelerator.gather(prompt_lengths) + agg_completion_lengths = self.accelerator.gather(completion_lengths) + total_prompt_tokens = agg_prompt_lengths.sum() + total_completion_tokens = agg_completion_lengths.sum() # = num_items_in_batch, required for the DAPO loss # Log the metrics if mode == "train": - attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) - self.state.num_input_tokens_seen += self.accelerator.gather(attention_mask.sum()).sum().item() + self.state.num_input_tokens_seen += (total_prompt_tokens + total_completion_tokens).item() self._metrics[mode]["num_tokens"] = [self.state.num_input_tokens_seen] # Log completion lengths, mean, min, max @@ -1372,25 +1371,18 @@ def _generate(self, prompts: list[str], images: Optional[list]): self._metrics[mode]["completions/max_length"].append(agg_completion_lengths.float().max().item()) # Identify sequences that terminated with EOS and log their lengths - agg_terminated_with_eos = self.accelerator.gather(is_eos.any(dim=1)) - term_completion_lengths = agg_completion_lengths[agg_terminated_with_eos] - clipped_completions_ratio = 1 - len(term_completion_lengths) / len(agg_completion_lengths) - self._metrics[mode]["completions/clipped_ratio"].append(clipped_completions_ratio) + eos_and_pad = [self.eos_token_id, self.pad_token_id] + is_truncated = torch.tensor([ids[-1] not in eos_and_pad for ids in completion_ids], device=device) + agg_is_truncated = self.accelerator.gather(is_truncated) + self._metrics[mode]["completions/clipped_ratio"].append(agg_is_truncated.float().mean().item()) + term_completion_lengths = agg_completion_lengths[~agg_is_truncated] if len(term_completion_lengths) == 0: # edge case where no terminated sequences are found term_completion_lengths = torch.zeros(1, device=device) self._metrics[mode]["completions/mean_terminated_length"].append(term_completion_lengths.float().mean().item()) self._metrics[mode]["completions/min_terminated_length"].append(term_completion_lengths.float().min().item()) self._metrics[mode]["completions/max_terminated_length"].append(term_completion_lengths.float().max().item()) - return ( - 
prompt_ids, - completion_ids, - prompt_mask, - completion_mask, - num_items_in_batch, - sampling_per_token_logps, - forward_kwargs, - ) + return prompt_ids, completion_ids, total_completion_tokens, logprobs, forward_kwargs def _generate_and_score_completions( self, inputs: list[dict[str, Union[torch.Tensor, Any]]] @@ -1408,18 +1400,33 @@ def _generate_and_score_completions( images = None ( - prompt_ids, - completion_ids, - prompt_mask, - completion_mask, + prompt_ids_list, + completion_ids_list, num_items_in_batch, - sampling_per_token_logps, + sampling_per_token_logps_list, forward_kwargs, ) = self._generate(prompts, images) - # Convert tensor to a list of lists of token IDs. This will be passed to the reward function, avoiding the need - # to re-tokenize completions if the reward is computed from tokens. - completion_ids_list = [row[mask_row].tolist() for row, mask_row in zip(completion_ids, completion_mask.bool())] + # Convert lists of token IDs to padded tensors + prompt_ids = [torch.tensor(ids, device=device) for ids in prompt_ids_list] + prompt_mask = [torch.ones_like(ids, dtype=torch.long) for ids in prompt_ids] + prompt_ids = pad(prompt_ids, padding_value=self.pad_token_id, padding_side="left") + prompt_mask = pad(prompt_mask, padding_value=0, padding_side="left") + completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids_list] + completion_mask = [torch.ones_like(ids, dtype=torch.long) for ids in completion_ids] + completion_ids = pad(completion_ids, padding_value=self.pad_token_id, padding_side="right") + completion_mask = pad(completion_mask, padding_value=0, padding_side="right") + if sampling_per_token_logps_list is not None: + sampling_per_token_logps = [torch.tensor(logps, device=device) for logps in sampling_per_token_logps_list] + sampling_per_token_logps = pad(sampling_per_token_logps, padding_value=0.0, padding_side="right") + else: + sampling_per_token_logps = None + + # If mask_truncated_completions is enabled, zero out truncated completions in completion_mask + if self.mask_truncated_completions: + eos_and_pad = [self.eos_token_id, self.pad_token_id] + is_truncated = torch.tensor([ids[-1] not in eos_and_pad for ids in completion_ids_list], device=device) + completion_mask = completion_mask * (~is_truncated).unsqueeze(1).int() # Concatenate prompt_mask with completion_mask for logit computation prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1) # (B, P+C) diff --git a/trl/trainer/rloo_trainer.py b/trl/trainer/rloo_trainer.py index c9abc5787ee..4c9a3623c3d 100644 --- a/trl/trainer/rloo_trainer.py +++ b/trl/trainer/rloo_trainer.py @@ -1064,9 +1064,8 @@ def _calculate_rewards(self, inputs, prompts, completions, completion_ids_list): rewards_per_func = gather(rewards_per_func) return rewards_per_func - def _generate(self, prompts: list[str], images: Optional[list]): + def _generate_single_turn(self, prompts: list[str], images: Optional[list]): device = self.accelerator.device - mode = "train" if self.model.training else "eval" # If the prompts are conversational and the inputs contain images, we need to convert the prompts from # [{"role": "user", "content": "What color is the sky?"}] to @@ -1093,21 +1092,19 @@ def _generate(self, prompts: list[str], images: Optional[list]): prompt_inputs = super()._prepare_inputs(prompt_inputs) prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"] forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]} + prompt_ids = 
[p[m].tolist() for p, m in zip(prompt_ids, prompt_mask.bool())] if self.max_prompt_length is not None: # If max_prompt_length is set, we trim the prompt to keep only the last `max_prompt_length` tokens. - # Then we decode those tokens back into text. We manually remove leading pad tokens from the decoded text, - # because we can't use `skip_special_tokens=True` (some special tokens are still needed for generation). + # Then we decode those tokens back into text. We set `skip_special_tokens=False` because some special + # tokens are needed for generation. protected = [self.image_token_id, self.vision_start_token_id, self.vision_end_token_id] protected = [token for token in protected if token is not None] - prompt_ids, prompt_mask = truncate_with_protected_tokens( - prompt_ids, prompt_mask, self.max_prompt_length, protected - ) + prompt_ids = [truncate_with_protected_tokens(ids, self.max_prompt_length, protected) for ids in prompt_ids] prompts_text = self.processing_class.batch_decode( prompt_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False ) - prompts_text = [re.sub(rf"^({re.escape(self.pad_token)})+", "", text) for text in prompts_text] # The chat template sometimes inserts a single image token into the prompt text. However, when this text is # later tokenized, the single image token string is expanded into multiple image token IDs, depending on the @@ -1186,13 +1183,13 @@ def _generate(self, prompts: list[str], images: Optional[list]): # Broadcast the completions from the main process to all processes, ensuring each process receives its corresponding slice. obj_list = [payload] broadcast_object_list(obj_list, from_process=0) - completion_ids, _ = obj_list[0] + all_completion_ids, _ = obj_list[0] process_slice = slice( self.accelerator.process_index * len(prompts), (self.accelerator.process_index + 1) * len(prompts), ) - completion_ids = completion_ids[process_slice] + completion_ids = all_completion_ids[process_slice] # Generate completions using colocated vLLM instances: each device holds vLLM copy and work on their own batch of prompts elif self.vllm_mode == "colocate": @@ -1244,24 +1241,20 @@ def _generate(self, prompts: list[str], images: Optional[list]): with profiling_context(self, "vLLM.generate"): all_outputs = self.llm.generate(vllm_inputs, sampling_params=sampling_params, use_tqdm=False) - completion_ids = [output.token_ids for outputs in all_outputs for output in outputs.outputs] + all_completion_ids = [output.token_ids for outputs in all_outputs for output in outputs.outputs] if self.vllm_tensor_parallel_size > 1: # Slice completions for this rank within its TP group. # Each rank generates all outputs — we keep only our share. 
local_rank_in_group = torch.distributed.get_rank(group=self.tp_group) tp_slice = slice(local_rank_in_group * orig_size, (local_rank_in_group + 1) * orig_size) - completion_ids = completion_ids[tp_slice] + completion_ids = all_completion_ids[tp_slice] + else: + completion_ids = all_completion_ids if self.args.vllm_enable_sleep_mode: self.llm.sleep(level=1) - # Pad the completions, and concatenate them with the prompts - completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids] - completion_mask = [torch.ones(len(ids), device=device, dtype=torch.long) for ids in completion_ids] - completion_ids = pad(completion_ids, padding_value=self.pad_token_id) - completion_mask = pad(completion_mask, padding_value=0) - elif self.use_transformers_paged: # Re-process inputs for paged generation if needed # Note: images are already validated and preprocessed above @@ -1291,15 +1284,17 @@ def _generate(self, prompts: list[str], images: Optional[list]): ) unwrapped_model.train() # restore training mode, as generate_batch forces eval mode completion_ids = [output.generated_tokens for output in all_outputs.values()] - completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids] - completion_ids = pad(completion_ids, padding_value=self.pad_token_id, padding_side="right") - prompt_ids = [torch.tensor(ids, device=device) for ids in paged_prompt_inputs.input_ids] - prompt_ids = pad(prompt_ids, padding_value=self.pad_token_id, padding_side="left") + prompt_ids = paged_prompt_inputs.input_ids # Restore the original attention implementation, training mode self.model_wrapped.config._attn_implementation = previous_attn else: # Regular generation path + prompt_ids = [torch.tensor(ids, device=device) for ids in prompt_ids] + prompt_mask = [torch.ones_like(ids, dtype=torch.long) for ids in prompt_ids] + prompt_ids = pad(prompt_ids, padding_value=self.pad_token_id, padding_side="left") + prompt_mask = pad(prompt_mask, padding_value=0, padding_side="left") + with ( profiling_context(self, "transformers.generate"), unwrap_model_for_generation( @@ -1320,25 +1315,34 @@ def _generate(self, prompts: list[str], images: Optional[list]): prompt_ids = prompt_completion_ids[:, :prompt_length] completion_ids = prompt_completion_ids[:, prompt_length:] - # Mask everything after the first EOS token - is_eos = completion_ids == self.eos_token_id - eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device) - eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)] - sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1) - completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int() + # Mask everything after the first EOS token + is_eos = completion_ids == self.eos_token_id + eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device) + eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)] + sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1) + completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int() + prompt_ids = [p[m].tolist() for p, m in zip(prompt_ids, prompt_mask.bool())] + completion_ids = [c[m].tolist() for c, m in zip(completion_ids, completion_mask.bool())] - # Sum along sequence dimension (dim=1) to get completion length per sequence, used for logging - completion_lengths = completion_mask.sum(1) + return prompt_ids, completion_ids, forward_kwargs - # If mask_truncated_completions is enabled, zero out 
truncated completions in completion_mask - if self.mask_truncated_completions: - truncated_completions = ~is_eos.any(dim=1) - completion_mask = completion_mask * (~truncated_completions).unsqueeze(1).int() + def _generate(self, prompts: list[str], images: Optional[list]): + device = self.accelerator.device + mode = "train" if self.model.training else "eval" + + prompt_ids, completion_ids, forward_kwargs = self._generate_single_turn(prompts, images) + + # Get completion length per sequence, used for logging + prompt_lengths = torch.tensor([len(ids) for ids in prompt_ids], device=device) + completion_lengths = torch.tensor([len(ids) for ids in completion_ids], device=device) + agg_prompt_lengths = self.accelerator.gather(prompt_lengths) + agg_completion_lengths = self.accelerator.gather(completion_lengths) + total_prompt_tokens = agg_prompt_lengths.sum() + total_completion_tokens = agg_completion_lengths.sum() # = num_items_in_batch, required for the DAPO loss # Log the metrics if mode == "train": - attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) - self.state.num_input_tokens_seen += self.accelerator.gather(attention_mask.sum()).sum().item() + self.state.num_input_tokens_seen += (total_prompt_tokens + total_completion_tokens).item() self._metrics[mode]["num_tokens"] = [self.state.num_input_tokens_seen] # Log completion lengths, mean, min, max @@ -1348,17 +1352,18 @@ def _generate(self, prompts: list[str], images: Optional[list]): self._metrics[mode]["completions/max_length"].append(agg_completion_lengths.float().max().item()) # Identify sequences that terminated with EOS and log their lengths - agg_terminated_with_eos = self.accelerator.gather(is_eos.any(dim=1)) - term_completion_lengths = agg_completion_lengths[agg_terminated_with_eos] - clipped_completions_ratio = 1 - len(term_completion_lengths) / len(agg_completion_lengths) - self._metrics[mode]["completions/clipped_ratio"].append(clipped_completions_ratio) + eos_and_pad = [self.eos_token_id, self.pad_token_id] + is_truncated = torch.tensor([ids[-1] not in eos_and_pad for ids in completion_ids], device=device) + agg_is_truncated = self.accelerator.gather(is_truncated) + self._metrics[mode]["completions/clipped_ratio"].append(agg_is_truncated.float().mean().item()) + term_completion_lengths = agg_completion_lengths[~agg_is_truncated] if len(term_completion_lengths) == 0: # edge case where no terminated sequences are found term_completion_lengths = torch.zeros(1, device=device) self._metrics[mode]["completions/mean_terminated_length"].append(term_completion_lengths.float().mean().item()) self._metrics[mode]["completions/min_terminated_length"].append(term_completion_lengths.float().min().item()) self._metrics[mode]["completions/max_terminated_length"].append(term_completion_lengths.float().max().item()) - return prompt_ids, completion_ids, prompt_mask, completion_mask, forward_kwargs + return prompt_ids, completion_ids, forward_kwargs def _generate_and_score_completions( self, inputs: list[dict[str, Union[torch.Tensor, Any]]] @@ -1375,11 +1380,23 @@ def _generate_and_score_completions( else: images = None - prompt_ids, completion_ids, prompt_mask, completion_mask, forward_kwargs = self._generate(prompts, images) + prompt_ids_list, completion_ids_list, forward_kwargs = self._generate(prompts, images) + + # Convert lists of token IDs to padded tensors + prompt_ids = [torch.tensor(ids, device=device) for ids in prompt_ids_list] + prompt_mask = [torch.ones_like(ids, dtype=torch.long) for ids in prompt_ids] + prompt_ids = 
pad(prompt_ids, padding_value=self.pad_token_id, padding_side="left") + prompt_mask = pad(prompt_mask, padding_value=0, padding_side="left") + completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids_list] + completion_mask = [torch.ones_like(ids, dtype=torch.long) for ids in completion_ids] + completion_ids = pad(completion_ids, padding_value=self.pad_token_id, padding_side="right") + completion_mask = pad(completion_mask, padding_value=0, padding_side="right") - # Convert tensor to a list of lists of token IDs. This will be passed to the reward function, avoiding the need - # to re-tokenize completions if the reward is computed from tokens. - completion_ids_list = [row[mask_row].tolist() for row, mask_row in zip(completion_ids, completion_mask.bool())] + # If mask_truncated_completions is enabled, zero out truncated completions in completion_mask + if self.mask_truncated_completions: + eos_and_pad = [self.eos_token_id, self.pad_token_id] + is_truncated = torch.tensor([ids[-1] not in eos_and_pad for ids in completion_ids_list], device=device) + completion_mask = completion_mask * (~is_truncated).unsqueeze(1).int() # Concatenate prompt_mask with completion_mask for logit computation prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1) # (B, P+C) diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py index a12fdec7b44..76691bff7fb 100644 --- a/trl/trainer/utils.py +++ b/trl/trainer/utils.py @@ -1923,63 +1923,45 @@ def unsplit_pixel_values_by_grid(batch: dict[str, Union[torch.Tensor, list[torch return batch -def truncate_with_protected_tokens( - ids: torch.Tensor, mask: torch.Tensor, target_length: int, protected_tokens: list[int] -) -> tuple[torch.Tensor, torch.Tensor]: +def truncate_with_protected_tokens(ids: list[int], target_length: int, protected_tokens: list[int]) -> list[int]: """ - Truncate tensors to target length while preserving protected tokens. + Truncate list to target length while preserving protected tokens. Args: - ids (`torch.Tensor`): - Input tensor of token IDs, shape (batch_size, sequence_length). - mask (`torch.Tensor`): - Input tensor of attention masks, shape (batch_size, sequence_length). + ids (`list[int]`): + Input sequence of token IDs. target_length (`int`): - Desired length of the output sequences. + Desired length of the output sequence. protected_tokens (`list[int]`): List of token IDs that should be preserved in the output. - """ - protected_set = set(protected_tokens) - # Create protected_tokens tensor once to avoid recreating it on every call - protected_tokens_tensor = torch.tensor(list(protected_set), device=ids.device) - - def process_sequence(ids, mask): - # Create boolean masks - is_protected = torch.isin(ids, protected_tokens_tensor) - is_non_protected = ~is_protected - - # Count tokens - num_protected = is_protected.sum().item() - num_non_protected_needed = target_length - num_protected - - if num_non_protected_needed < 0: - raise ValueError( - f"target_length ({target_length}) is too small for the protected tokens ({num_protected} tokens). " - f"Please increase target length to at least {num_protected} or disable truncation." 
- ) - - # Select which non-protected tokens to keep (rightmost ones) - non_protected_indices = torch.where(is_non_protected)[0] - keep_non_protected = torch.zeros_like(is_non_protected) - if num_non_protected_needed > 0: - keep_indices = non_protected_indices[-num_non_protected_needed:] - keep_non_protected[keep_indices] = True - - # Final mask: protected OR selected non-protected - keep_mask = is_protected | keep_non_protected - return ids[keep_mask], mask[keep_mask] - - # Process each sequence in the batch - truncated_seq = [] - truncated_mask = [] + Returns: + `list[int]`: Truncated sequence. - for i in range(ids.shape[0]): - new_ids, new_mask = process_sequence(ids[i], mask[i]) - truncated_seq.append(new_ids) - truncated_mask.append(new_mask) + Raises: + `ValueError`: If the number of protected tokens in `ids` exceeds `target_length`. + """ + protected_set = set(protected_tokens) - return torch.stack(truncated_seq), torch.stack(truncated_mask) + # Count protected tokens + num_protected = sum(1 for t in ids if t in protected_set) + if num_protected > target_length: + raise ValueError( + f"target_length ({target_length}) is too small for the protected tokens ({num_protected} tokens). " + f"Please increase target length to at least {num_protected} or disable truncation." + ) + num_non_protected_needed = target_length - num_protected + result = [] + + # Iterate backward to select all protected tokens and rightmost non-protected tokens + for t in reversed(ids): + if t in protected_set: + result.append(t) + elif num_non_protected_needed > 0: + result.append(t) + num_non_protected_needed -= 1 + # Reverse to restore original order + return result[::-1] TListOrMapping = TypeVar("TListOrMapping", list, Mapping)
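
Usage notes (illustrative, not part of the patch):

truncate_with_protected_tokens now takes a single list of token IDs and returns a list, as exercised by the updated tests above. A minimal sketch of the new call, assuming a TRL checkout with this patch applied:

    from trl.trainer.utils import truncate_with_protected_tokens

    prompt_ids = [1, 2, 3, 4, 5]
    # Keep the protected tokens 2 and 3, fill the remaining slot with the rightmost
    # non-protected token, and preserve the original relative order.
    print(truncate_with_protected_tokens(prompt_ids, 3, protected_tokens=[2, 3]))  # [2, 3, 5]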
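
Because _generate/_generate_single_turn now return plain lists of token IDs, _generate_and_score_completions rebuilds the padded tensors itself: prompts are left-padded, completions are right-padded, and a completion counts as truncated when its last token is neither EOS nor pad. A small sketch of that round-trip under the same assumption; pad_token_id = 0 and eos_token_id = 2 are arbitrary placeholders for self.pad_token_id and self.eos_token_id:

    import torch

    from trl.trainer.utils import pad

    pad_token_id, eos_token_id = 0, 2  # placeholders
    prompt_ids_list = [[1, 2, 3], [4, 5]]
    completion_ids_list = [[5, 6, 2], [7, 8, 9]]  # first ends with EOS, second is truncated

    # Prompts: left-padded, as in the trainers.
    prompt_ids = [torch.tensor(ids) for ids in prompt_ids_list]
    prompt_mask = [torch.ones_like(ids, dtype=torch.long) for ids in prompt_ids]
    prompt_ids = pad(prompt_ids, padding_value=pad_token_id, padding_side="left")  # [[1, 2, 3], [0, 4, 5]]
    prompt_mask = pad(prompt_mask, padding_value=0, padding_side="left")           # [[1, 1, 1], [0, 1, 1]]

    # Completions: right-padded.
    completion_ids = [torch.tensor(ids) for ids in completion_ids_list]
    completion_mask = [torch.ones_like(ids, dtype=torch.long) for ids in completion_ids]
    completion_ids = pad(completion_ids, padding_value=pad_token_id, padding_side="right")
    completion_mask = pad(completion_mask, padding_value=0, padding_side="right")

    # mask_truncated_completions: zero out sequences whose last token is neither EOS nor pad.
    eos_and_pad = [eos_token_id, pad_token_id]
    is_truncated = torch.tensor([ids[-1] not in eos_and_pad for ids in completion_ids_list])
    completion_mask = completion_mask * (~is_truncated).unsqueeze(1).int()
    print(completion_mask)  # tensor([[1, 1, 1], [0, 0, 0]])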