diff --git a/tests/test_utils.py b/tests/test_utils.py
index 6f6ba1579ef..50d54311272 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -942,7 +942,7 @@ def test_basic_example(self):
         protected_tokens = [2, 3, 6]
         target_length = 3
 
-        new_ids, new_mask = truncate_with_protected_tokens(prompt_ids, prompt_mask, target_length, protected_tokens)
+        new_ids, new_mask, _ = truncate_with_protected_tokens(prompt_ids, prompt_mask, target_length, protected_tokens)
 
         expected_ids = torch.tensor([[2, 3, 5], [6, 9, 10]])
         expected_mask = torch.ones_like(expected_ids)
@@ -957,7 +957,7 @@ def test_no_truncation_needed(self):
         protected_tokens = [2]
         target_length = 3
 
-        new_ids, new_mask = truncate_with_protected_tokens(prompt_ids, prompt_mask, target_length, protected_tokens)
+        new_ids, new_mask, _ = truncate_with_protected_tokens(prompt_ids, prompt_mask, target_length, protected_tokens)
 
         self.assertTrue(torch.equal(new_ids, prompt_ids))
         self.assertTrue(torch.equal(new_mask, prompt_mask))
@@ -969,7 +969,7 @@ def test_no_protected_tokens(self):
         protected_tokens = []
         target_length = 3
 
-        new_ids, new_mask = truncate_with_protected_tokens(prompt_ids, prompt_mask, target_length, protected_tokens)
+        new_ids, _, _ = truncate_with_protected_tokens(prompt_ids, prompt_mask, target_length, protected_tokens)
 
         expected_ids = torch.tensor([[3, 4, 5]])  # Last 3 tokens
         self.assertTrue(torch.equal(new_ids, expected_ids))
@@ -981,7 +981,7 @@ def test_all_tokens_protected(self):
         protected_tokens = [3, 4, 5]
         target_length = 3
 
-        new_ids, new_mask = truncate_with_protected_tokens(prompt_ids, prompt_mask, target_length, protected_tokens)
+        new_ids, _, _ = truncate_with_protected_tokens(prompt_ids, prompt_mask, target_length, protected_tokens)
 
         expected_ids = torch.tensor([[3, 4, 5]])
         self.assertTrue(torch.equal(new_ids, expected_ids))
@@ -1003,7 +1003,7 @@ def test_single_batch_single_token(self):
         protected_tokens = [5]
         target_length = 1
 
-        new_ids, new_mask = truncate_with_protected_tokens(prompt_ids, prompt_mask, target_length, protected_tokens)
+        new_ids, _, _ = truncate_with_protected_tokens(prompt_ids, prompt_mask, target_length, protected_tokens)
 
         self.assertTrue(torch.equal(new_ids, prompt_ids))
 
@@ -1014,7 +1014,7 @@ def test_mask_preservation(self):
         protected_tokens = [2, 4]
         target_length = 3
 
-        new_ids, new_mask = truncate_with_protected_tokens(prompt_ids, prompt_mask, target_length, protected_tokens)
+        new_ids, new_mask, _ = truncate_with_protected_tokens(prompt_ids, prompt_mask, target_length, protected_tokens)
 
         expected_ids = torch.tensor([[2, 4, 5]])
         expected_mask = torch.tensor([[0, 0, 1]])  # Corresponding mask values
@@ -1029,7 +1029,7 @@ def test_multiple_batches_different_protected(self):
         protected_tokens = [2]
         target_length = 3
 
-        new_ids, new_mask = truncate_with_protected_tokens(prompt_ids, prompt_mask, target_length, protected_tokens)
+        new_ids, _, _ = truncate_with_protected_tokens(prompt_ids, prompt_mask, target_length, protected_tokens)
 
         expected_ids = torch.tensor(
             [
@@ -1048,7 +1048,7 @@ def test_order_preservation(self):
         protected_tokens = [2, 3]
         target_length = 4
 
-        new_ids, new_mask = truncate_with_protected_tokens(prompt_ids, prompt_mask, target_length, protected_tokens)
+        new_ids, _, _ = truncate_with_protected_tokens(prompt_ids, prompt_mask, target_length, protected_tokens)
 
         # Should keep protected tokens 2,3 and last 2 non-protected tokens 30,40
         # Order should be: 2, 3, 30, 40 (maintaining original relative positions)
@@ -1063,10 +1063,209 @@ def test_empty_protected_tokens_list(self):
         protected_tokens = []
         target_length = 2
 
-        new_ids, new_mask = truncate_with_protected_tokens(prompt_ids, prompt_mask, target_length, protected_tokens)
+        new_ids, _, metadata = truncate_with_protected_tokens(prompt_ids, prompt_mask, target_length, protected_tokens)
 
         expected_ids = torch.tensor([[4, 5]])  # Last 2 tokens
         self.assertTrue(torch.equal(new_ids, expected_ids))
+        self.assertEqual(metadata["total_images_removed"], 0)
+
+    def test_image_token_block_detection_internvl(self):
+        """Test InternVL image token blocks (256 tokens per image)."""
+        # Simulate InternVL with 768 image tokens (3 x 256)
+        image_token_id = 151667
+        prompt_ids = torch.tensor([[1, 2] + [image_token_id] * 768 + [3, 4, 5]])
+        prompt_mask = torch.ones_like(prompt_ids)
+        protected_tokens = []
+        target_length = 400  # Force reduction
+
+        new_ids, _, metadata = truncate_with_protected_tokens(
+            prompt_ids,
+            prompt_mask,
+            target_length,
+            protected_tokens,
+            image_token_id=image_token_id,
+            image_seq_length=256,
+        )
+
+        # Should reduce from 3 image blocks to 1 block (768 -> 256 tokens)
+        image_tokens_after = (new_ids == image_token_id).sum().item()
+        self.assertEqual(image_tokens_after, 256)
+        self.assertEqual(metadata["block_size"], 256)
+        self.assertEqual(metadata["images_removed_per_sequence"], [2])  # Removed 2 blocks, kept 1
+        self.assertEqual(metadata["total_images_removed"], 2)
+
+    def test_image_token_block_multiple_sequences_simple(self):
+        """Test image block reduction with multiple sequences (simplified)."""
+        image_token_id = 151667
+        # Simple case: both sequences have exactly the same pattern
+        seq1 = [1, 2] + [image_token_id] * 768 + [3, 4]  # 3 blocks
+        seq2 = [1, 2] + [image_token_id] * 768 + [3, 4]  # 3 blocks (same)
+
+        prompt_ids = torch.tensor([seq1, seq2])
+        prompt_mask = torch.ones_like(prompt_ids)
+        protected_tokens = []
+        target_length = 400  # Force reduction
+
+        _, _, metadata = truncate_with_protected_tokens(
+            prompt_ids,
+            prompt_mask,
+            target_length,
+            protected_tokens,
+            image_token_id=image_token_id,
+            image_seq_length=256,
+        )
+
+        # Both sequences should be reduced from 3 blocks to 1 block
+        self.assertEqual(metadata["images_removed_per_sequence"], [2, 2])  # Both reduced 3->1
+        self.assertEqual(metadata["total_images_removed"], 4)
+        self.assertEqual(metadata["block_size"], 256)
+
+    def test_image_token_block_no_reduction_needed(self):
+        """Test when image tokens don't need reduction."""
+        image_token_id = 151667
+        prompt_ids = torch.tensor([[1, 2] + [image_token_id] * 256 + [3, 4]])  # Only 1 block
+        prompt_mask = torch.ones_like(prompt_ids)
+        protected_tokens = []
+        target_length = 300  # No reduction needed
+
+        new_ids, _, metadata = truncate_with_protected_tokens(
+            prompt_ids, prompt_mask, target_length, protected_tokens, image_token_id=image_token_id
+        )
+
+        # Should not reduce anything
+        image_tokens_after = (new_ids == image_token_id).sum().item()
+        self.assertEqual(image_tokens_after, 256)
+        self.assertEqual(metadata["images_removed_per_sequence"], [0])
+        self.assertEqual(metadata["total_images_removed"], 0)
+
+    def test_image_token_block_with_protected_tokens(self):
+        """Test image block reduction combined with protected token preservation."""
+        image_token_id = 151667
+        vision_start_token = 151666
+        prompt_ids = torch.tensor([[vision_start_token, 2] + [image_token_id] * 768 + [3, 4, 5]])
+        prompt_mask = torch.ones_like(prompt_ids)
+        protected_tokens = [vision_start_token, image_token_id]
+        target_length = 400
+
+        new_ids, _, metadata = truncate_with_protected_tokens(
+            prompt_ids,
+            prompt_mask,
+            target_length,
+            protected_tokens,
+            image_token_id=image_token_id,
+            image_seq_length=256,
+        )
+
+        # Should reduce image blocks first, then apply protected token logic
+        image_tokens_after = (new_ids == image_token_id).sum().item()
+        self.assertEqual(image_tokens_after, 256)  # Reduced to 1 block
+        self.assertTrue((new_ids == vision_start_token).any())  # Protected token preserved
+        self.assertEqual(metadata["images_removed_per_sequence"], [2])
+
+    def test_image_token_block_edge_case_single_block(self):
+        """Test that single image blocks are preserved when there's only one block."""
+        image_token_id = 151667
+        prompt_ids = torch.tensor([[1, 2] + [image_token_id] * 256 + [3, 4, 5]])
+        prompt_mask = torch.ones_like(prompt_ids)
+        protected_tokens = []
+        target_length = 260  # Larger than image block but smaller than total
+
+        new_ids, _, metadata = truncate_with_protected_tokens(
+            prompt_ids,
+            prompt_mask,
+            target_length,
+            protected_tokens,
+            image_token_id=image_token_id,
+            image_seq_length=256,
+        )
+
+        # Should not remove the single image block since there's only one
+        image_tokens_after = (new_ids == image_token_id).sum().item()
+        self.assertEqual(image_tokens_after, 256)
+        self.assertEqual(metadata["images_removed_per_sequence"], [0])
+
+    def test_image_token_block_non_consecutive_tokens(self):
+        """Test that non-consecutive image tokens don't trigger block detection."""
+        image_token_id = 151667
+        # Scatter 300 image tokens throughout the sequence (not consecutive)
+        prompt_ids = torch.tensor([[1, image_token_id, 2, image_token_id, 3, image_token_id] * 100])
+        prompt_mask = torch.ones_like(prompt_ids)
+        protected_tokens = []
+        target_length = 200
+
+        _, _, metadata = truncate_with_protected_tokens(
+            prompt_ids,
+            prompt_mask,
+            target_length,
+            protected_tokens,
+            image_token_id=image_token_id,
+            image_seq_length=256,
+        )
+
+        # Should not detect blocks since tokens are not consecutive
+        self.assertIsNone(metadata["block_size"])
+        self.assertEqual(metadata["images_removed_per_sequence"], [0])
+
+    def test_image_token_block_no_detection_non_standard(self):
+        """Test that non-standard block sizes fall back to regular truncation."""
+        image_token_id = 12345
+        # Test with non-standard block size that won't be detected
+        prompt_ids = torch.tensor([[1, 2] + [image_token_id] * 300 + [3, 4]])  # Non-standard size
+        prompt_mask = torch.ones_like(prompt_ids)
+        protected_tokens = []
+        target_length = 200  # Force reduction
+
+        new_ids, _, metadata = truncate_with_protected_tokens(
+            prompt_ids,
+            prompt_mask,
+            target_length,
+            protected_tokens,
+            image_token_id=image_token_id,
+            image_seq_length=256,
+        )
+
+        # Should not detect blocks and fall back to regular truncation
+        self.assertIsNone(metadata["block_size"])
+        self.assertEqual(metadata["images_removed_per_sequence"], [0])
+        self.assertEqual(new_ids.shape[1], 200)  # Regular truncation to target length
+
+    def test_backward_compatibility_without_image_tokens(self):
+        """Test that the enhanced function maintains backward compatibility."""
+        prompt_ids = torch.tensor([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]])
+        prompt_mask = torch.ones_like(prompt_ids)
+        protected_tokens = [2, 3, 6]
+        target_length = 3
+
+        # Call without image_token_id (backward compatibility)
+        new_ids, new_mask, metadata = truncate_with_protected_tokens(
+            prompt_ids, prompt_mask, target_length, protected_tokens
+        )
+
+        expected_ids = torch.tensor([[2, 3, 5], [6, 9, 10]])
+        expected_mask = torch.ones_like(expected_ids)
+
+        self.assertTrue(torch.equal(new_ids, expected_ids))
+        self.assertTrue(torch.equal(new_mask, expected_mask))
+        self.assertEqual(metadata["total_images_removed"], 0)
+        self.assertIsNone(metadata["block_size"])
+
+    def test_metadata_structure(self):
+        """Test that metadata contains all expected fields."""
+        prompt_ids = torch.tensor([[1, 2, 3, 4, 5]])
+        prompt_mask = torch.ones_like(prompt_ids)
+        protected_tokens = []
+        target_length = 3
+
+        _, _, metadata = truncate_with_protected_tokens(prompt_ids, prompt_mask, target_length, protected_tokens)
+
+        # Check metadata structure
+        self.assertIn("images_removed_per_sequence", metadata)
+        self.assertIn("block_size", metadata)
+        self.assertIn("total_images_removed", metadata)
+        self.assertIsInstance(metadata["images_removed_per_sequence"], list)
+        self.assertEqual(len(metadata["images_removed_per_sequence"]), 1)  # One sequence
+        self.assertEqual(metadata["total_images_removed"], 0)
+        self.assertIsNone(metadata["block_size"])
 
 
 class UnsplitPixelValuesByGridTester(TrlTestCase):
diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py
index 9088ea39a83..39dd1976b3a 100644
--- a/trl/trainer/grpo_trainer.py
+++ b/trl/trainer/grpo_trainer.py
@@ -1106,10 +1106,56 @@ def _generate_and_score_completions(
             # because we can't use `skip_special_tokens=True` (some special tokens are still needed for generation).
             protected = [self.image_token_id, self.vision_start_token_id, self.vision_end_token_id]
             protected = [token for token in protected if token is not None]
-            prompt_ids, prompt_mask = truncate_with_protected_tokens(
-                prompt_ids, prompt_mask, self.max_prompt_length, protected
+
+            # Extract image_seq_length from model config for vision models
+            image_seq_length = None
+            if self.image_token_id is not None:
+                # Try to get image_seq_length from model config
+                config = getattr(self.model, "config", None)
+                if config is not None:
+                    image_seq_length = getattr(config, "image_seq_length", None)
+                    if image_seq_length is None:
+                        # For some models, it might be in a nested config
+                        vision_config = getattr(config, "vision_config", None)
+                        if vision_config is not None:
+                            image_seq_length = getattr(vision_config, "image_seq_length", None)
+
+            # Use the enhanced truncate_with_protected_tokens which handles image blocks and protected tokens intelligently
+            prompt_ids, prompt_mask, truncation_metadata = truncate_with_protected_tokens(
+                prompt_ids,
+                prompt_mask,
+                self.max_prompt_length,
+                protected,
+                image_token_id=self.image_token_id,
+                image_seq_length=image_seq_length,
             )
+            # Fix pixel_values alignment if images were removed during truncation
+            if (
+                hasattr(prompt_inputs, "pixel_values")
+                and prompt_inputs.pixel_values is not None
+                and truncation_metadata["total_images_removed"] > 0
+            ):
+                original_pixel_values = prompt_inputs.pixel_values
+
+                # Calculate how many images to keep per sequence
+                images_removed_per_seq = truncation_metadata["images_removed_per_sequence"]
+                batch_size = len(images_removed_per_seq)
+                original_images_per_seq = original_pixel_values.shape[0] // batch_size
+
+                # Create new pixel_values with reduced images
+                new_pixel_values_list = []
+                for seq_idx, images_removed in enumerate(images_removed_per_seq):
+                    images_to_keep = original_images_per_seq - images_removed
+                    seq_start = seq_idx * original_images_per_seq
+                    seq_end = seq_start + images_to_keep
+                    if seq_end <= original_pixel_values.shape[0]:
+                        new_pixel_values_list.append(original_pixel_values[seq_start:seq_end])
+
+                if new_pixel_values_list:
+                    new_pixel_values = torch.cat(new_pixel_values_list, dim=0)
prompt_inputs["pixel_values"] = new_pixel_values + prompts_text = self.processing_class.batch_decode( prompt_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False ) diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py index b9f97020ed2..b2b904bc285 100644 --- a/trl/trainer/utils.py +++ b/trl/trainer/utils.py @@ -1862,10 +1862,15 @@ def unsplit_pixel_values_by_grid(batch: dict[str, Union[torch.Tensor, list[torch def truncate_with_protected_tokens( - ids: torch.Tensor, mask: torch.Tensor, target_length: int, protected_tokens: list[int] -) -> tuple[torch.Tensor, torch.Tensor]: + ids: torch.Tensor, + mask: torch.Tensor, + target_length: int, + protected_tokens: list[int], + image_token_id: int = None, + image_seq_length: int = None, +) -> tuple[torch.Tensor, torch.Tensor, dict]: """ - Truncate tensors to target length while preserving protected tokens. + Truncate tensors to target length while preserving protected tokens and handling image token blocks. Args: ids (`torch.Tensor`): @@ -1876,48 +1881,178 @@ def truncate_with_protected_tokens( Desired length of the output sequences. protected_tokens (`list[int]`): List of token IDs that should be preserved in the output. + image_token_id (`int`, *optional*): + Token ID for image tokens. If provided, will preserve image token blocks. + image_seq_length (`int`, *optional*): + Number of image tokens per image block. If provided, will use this for block detection and reduction. + + Returns: + tuple[torch.Tensor, torch.Tensor, dict]: + - Truncated token IDs + - Truncated attention mask + - Metadata dict with image reduction info (images_removed_per_sequence, block_size) """ protected_set = set(protected_tokens) # Create protected_tokens tensor once to avoid recreating it on every call protected_tokens_tensor = torch.tensor(list(protected_set), device=ids.device) + def detect_image_block_size(ids, image_token_id, image_seq_length): + """ + Detect image token blocks using the provided image_seq_length. + + Args: + ids: Token IDs tensor + image_token_id: ID of image tokens + image_seq_length: Expected number of image tokens per image block + + Returns: + image_seq_length if valid blocks are found, None otherwise + """ + if image_token_id is None or image_seq_length is None: + return None + + # Find all positions of image tokens + image_positions = torch.where(ids == image_token_id)[0] + total_image_tokens = len(image_positions) + + if total_image_tokens == 0: + return None + + # Check if we have complete blocks of the expected size + if total_image_tokens >= image_seq_length and total_image_tokens % image_seq_length == 0: + return image_seq_length + + return None + + def reduce_image_blocks_if_needed(ids, mask, image_token_id, target_reduction, image_seq_length): + """ + Reduce image token blocks to fit within target length while keeping complete blocks. Returns (ids, mask, + images_removed) where images_removed is the number of complete image blocks removed. 
+ """ + if image_token_id is None: + return ids, mask, 0 + + block_size = detect_image_block_size(ids, image_token_id, image_seq_length) + if block_size is None: + return ids, mask, 0 + + # Count current image tokens + n_image_tokens = (ids == image_token_id).sum().item() + if n_image_tokens == 0 or n_image_tokens % block_size != 0: + return ids, mask, 0 + + n_blocks = n_image_tokens // block_size + + # Calculate how many blocks we need to remove + if target_reduction > 0 and n_blocks > 1: + # For image blocks, prioritize keeping complete blocks even if it means over-reducing + # If we need to reduce any amount and have multiple blocks, remove excess blocks + # to keep only one image block (common pattern for GRPO with vision models) + blocks_to_keep = 1 # Always keep exactly one image for vision models + tokens_to_keep = blocks_to_keep * block_size + blocks_removed = n_blocks - blocks_to_keep + + if tokens_to_keep < n_image_tokens: + # Find image token positions + image_positions = torch.where(ids == image_token_id)[0] + + # Create mask to keep only the first tokens_to_keep image tokens + keep_mask = torch.ones_like(ids, dtype=torch.bool) + positions_to_remove = image_positions[tokens_to_keep:] + keep_mask[positions_to_remove] = False + + # Apply the mask + ids = ids[keep_mask] + mask = mask[keep_mask] + + return ids, mask, blocks_removed + + return ids, mask, 0 + def process_sequence(ids, mask): - # Create boolean masks - is_protected = torch.isin(ids, protected_tokens_tensor) - is_non_protected = ~is_protected - - # Count tokens - num_protected = is_protected.sum().item() - num_non_protected_needed = target_length - num_protected - - if num_non_protected_needed < 0: - raise ValueError( - f"target_length ({target_length}) is too small for the protected tokens ({num_protected} tokens). " - f"Please increase target length to at least {num_protected} or disable truncation." + # First, try to reduce image blocks if needed + current_length = ids.shape[0] + images_removed = 0 + block_size = None + + if current_length > target_length and image_token_id is not None: + target_reduction = current_length - target_length + ids, mask, images_removed = reduce_image_blocks_if_needed( + ids, mask, image_token_id, target_reduction, image_seq_length ) + if images_removed > 0: + block_size = detect_image_block_size(ids, image_token_id, image_seq_length) + + # If still too long, fall back to the original protected token logic + if ids.shape[0] > target_length: + # Create boolean masks + is_protected = torch.isin(ids, protected_tokens_tensor) + is_non_protected = ~is_protected + + # Count tokens + num_protected = is_protected.sum().item() + num_non_protected_needed = target_length - num_protected - # Select which non-protected tokens to keep (rightmost ones) - non_protected_indices = torch.where(is_non_protected)[0] - keep_non_protected = torch.zeros_like(is_non_protected) - if num_non_protected_needed > 0: - keep_indices = non_protected_indices[-num_non_protected_needed:] - keep_non_protected[keep_indices] = True + if num_non_protected_needed < 0: + raise ValueError( + f"target_length ({target_length}) is too small for the protected tokens ({num_protected} tokens). " + f"Please increase target length to at least {num_protected} or disable truncation." 
+                )
 
-        # Final mask: protected OR selected non-protected
-        keep_mask = is_protected | keep_non_protected
+            # Select which non-protected tokens to keep (rightmost ones)
+            non_protected_indices = torch.where(is_non_protected)[0]
+            keep_non_protected = torch.zeros_like(is_non_protected)
+            if num_non_protected_needed > 0:
+                keep_indices = non_protected_indices[-num_non_protected_needed:]
+                keep_non_protected[keep_indices] = True
 
-        return ids[keep_mask], mask[keep_mask]
+            # Final mask: protected OR selected non-protected
+            keep_mask = is_protected | keep_non_protected
+            ids = ids[keep_mask]
+            mask = mask[keep_mask]
+
+        return ids, mask, images_removed, block_size
 
     # Process each sequence in the batch
     truncated_seq = []
     truncated_mask = []
+    images_removed_per_sequence = []
+    detected_block_size = None
     for i in range(ids.shape[0]):
-        new_ids, new_mask = process_sequence(ids[i], mask[i])
+        new_ids, new_mask, images_removed, block_size = process_sequence(ids[i], mask[i])
         truncated_seq.append(new_ids)
         truncated_mask.append(new_mask)
-
-    return torch.stack(truncated_seq), torch.stack(truncated_mask)
+        images_removed_per_sequence.append(images_removed)
+        if block_size is not None:
+            detected_block_size = block_size
+
+    # Pad sequences to same length for stacking
+    if truncated_seq:
+        max_length = max(seq.shape[0] for seq in truncated_seq)
+        padded_seq = []
+        padded_mask = []
+
+        for seq, mask_seq in zip(truncated_seq, truncated_mask):
+            if seq.shape[0] < max_length:
+                pad_length = max_length - seq.shape[0]
+                # Use pad token 0 for padding sequence
+                seq = torch.cat([seq, torch.zeros(pad_length, dtype=seq.dtype, device=seq.device)])
+                mask_seq = torch.cat([mask_seq, torch.zeros(pad_length, dtype=mask_seq.dtype, device=mask_seq.device)])
+            padded_seq.append(seq)
+            padded_mask.append(mask_seq)
+
+        truncated_seq = padded_seq
+        truncated_mask = padded_mask
+
+    # Create metadata dict
+    metadata = {
+        "images_removed_per_sequence": images_removed_per_sequence,
+        "block_size": detected_block_size,
+        "total_images_removed": sum(images_removed_per_sequence),
+    }
+
+    return torch.stack(truncated_seq), torch.stack(truncated_mask), metadata
 
 
 def create_model_from_path(model_id: str, **kwargs) -> PreTrainedModel: