Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
217 changes: 208 additions & 9 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -942,7 +942,7 @@ def test_basic_example(self):
protected_tokens = [2, 3, 6]
target_length = 3

new_ids, new_mask = truncate_with_protected_tokens(prompt_ids, prompt_mask, target_length, protected_tokens)
new_ids, new_mask, _ = truncate_with_protected_tokens(prompt_ids, prompt_mask, target_length, protected_tokens)

expected_ids = torch.tensor([[2, 3, 5], [6, 9, 10]])
expected_mask = torch.ones_like(expected_ids)
Expand All @@ -957,7 +957,7 @@ def test_no_truncation_needed(self):
protected_tokens = [2]
target_length = 3

new_ids, new_mask = truncate_with_protected_tokens(prompt_ids, prompt_mask, target_length, protected_tokens)
new_ids, new_mask, _ = truncate_with_protected_tokens(prompt_ids, prompt_mask, target_length, protected_tokens)

self.assertTrue(torch.equal(new_ids, prompt_ids))
self.assertTrue(torch.equal(new_mask, prompt_mask))
Expand All @@ -969,7 +969,7 @@ def test_no_protected_tokens(self):
protected_tokens = []
target_length = 3

new_ids, new_mask = truncate_with_protected_tokens(prompt_ids, prompt_mask, target_length, protected_tokens)
new_ids, _, _ = truncate_with_protected_tokens(prompt_ids, prompt_mask, target_length, protected_tokens)

expected_ids = torch.tensor([[3, 4, 5]]) # Last 3 tokens
self.assertTrue(torch.equal(new_ids, expected_ids))
Expand All @@ -981,7 +981,7 @@ def test_all_tokens_protected(self):
protected_tokens = [3, 4, 5]
target_length = 3

new_ids, new_mask = truncate_with_protected_tokens(prompt_ids, prompt_mask, target_length, protected_tokens)
new_ids, _, _ = truncate_with_protected_tokens(prompt_ids, prompt_mask, target_length, protected_tokens)

expected_ids = torch.tensor([[3, 4, 5]])
self.assertTrue(torch.equal(new_ids, expected_ids))
Expand All @@ -1003,7 +1003,7 @@ def test_single_batch_single_token(self):
protected_tokens = [5]
target_length = 1

new_ids, new_mask = truncate_with_protected_tokens(prompt_ids, prompt_mask, target_length, protected_tokens)
new_ids, _, _ = truncate_with_protected_tokens(prompt_ids, prompt_mask, target_length, protected_tokens)

self.assertTrue(torch.equal(new_ids, prompt_ids))

Expand All @@ -1014,7 +1014,7 @@ def test_mask_preservation(self):
protected_tokens = [2, 4]
target_length = 3

new_ids, new_mask = truncate_with_protected_tokens(prompt_ids, prompt_mask, target_length, protected_tokens)
new_ids, new_mask, _ = truncate_with_protected_tokens(prompt_ids, prompt_mask, target_length, protected_tokens)

expected_ids = torch.tensor([[2, 4, 5]])
expected_mask = torch.tensor([[0, 0, 1]]) # Corresponding mask values
Expand All @@ -1029,7 +1029,7 @@ def test_multiple_batches_different_protected(self):
protected_tokens = [2]
target_length = 3

new_ids, new_mask = truncate_with_protected_tokens(prompt_ids, prompt_mask, target_length, protected_tokens)
new_ids, _, _ = truncate_with_protected_tokens(prompt_ids, prompt_mask, target_length, protected_tokens)

expected_ids = torch.tensor(
[
Expand All @@ -1048,7 +1048,7 @@ def test_order_preservation(self):
protected_tokens = [2, 3]
target_length = 4

new_ids, new_mask = truncate_with_protected_tokens(prompt_ids, prompt_mask, target_length, protected_tokens)
new_ids, _, _ = truncate_with_protected_tokens(prompt_ids, prompt_mask, target_length, protected_tokens)

# Should keep protected tokens 2,3 and last 2 non-protected tokens 30,40
# Order should be: 2, 3, 30, 40 (maintaining original relative positions)
Expand All @@ -1063,10 +1063,209 @@ def test_empty_protected_tokens_list(self):
protected_tokens = []
target_length = 2

new_ids, new_mask = truncate_with_protected_tokens(prompt_ids, prompt_mask, target_length, protected_tokens)
new_ids, _, metadata = truncate_with_protected_tokens(prompt_ids, prompt_mask, target_length, protected_tokens)

expected_ids = torch.tensor([[4, 5]]) # Last 2 tokens
self.assertTrue(torch.equal(new_ids, expected_ids))
self.assertEqual(metadata["total_images_removed"], 0)

def test_image_token_block_detection_internvl(self):
"""Test InternVL image token blocks (256 tokens per image)."""
# Simulate InternVL with 768 image tokens (3 x 256)
image_token_id = 151667
prompt_ids = torch.tensor([[1, 2] + [image_token_id] * 768 + [3, 4, 5]])
prompt_mask = torch.ones_like(prompt_ids)
protected_tokens = []
target_length = 400 # Force reduction

new_ids, _, metadata = truncate_with_protected_tokens(
prompt_ids,
prompt_mask,
target_length,
protected_tokens,
image_token_id=image_token_id,
image_seq_length=256,
)

# Should reduce from 3 image blocks to 1 block (768 -> 256 tokens)
image_tokens_after = (new_ids == image_token_id).sum().item()
self.assertEqual(image_tokens_after, 256)
self.assertEqual(metadata["block_size"], 256)
self.assertEqual(metadata["images_removed_per_sequence"], [2]) # Removed 2 blocks, kept 1
self.assertEqual(metadata["total_images_removed"], 2)

def test_image_token_block_multiple_sequences_simple(self):
"""Test image block reduction with multiple sequences (simplified)."""
image_token_id = 151667
# Simple case: both sequences have exactly the same pattern
seq1 = [1, 2] + [image_token_id] * 768 + [3, 4] # 3 blocks
seq2 = [1, 2] + [image_token_id] * 768 + [3, 4] # 3 blocks (same)

prompt_ids = torch.tensor([seq1, seq2])
prompt_mask = torch.ones_like(prompt_ids)
protected_tokens = []
target_length = 400 # Force reduction

_, _, metadata = truncate_with_protected_tokens(
prompt_ids,
prompt_mask,
target_length,
protected_tokens,
image_token_id=image_token_id,
image_seq_length=256,
)

# Both sequences should be reduced from 3 blocks to 1 block
self.assertEqual(metadata["images_removed_per_sequence"], [2, 2]) # Both reduced 3->1
self.assertEqual(metadata["total_images_removed"], 4)
self.assertEqual(metadata["block_size"], 256)

def test_image_token_block_no_reduction_needed(self):
"""Test when image tokens don't need reduction."""
image_token_id = 151667
prompt_ids = torch.tensor([[1, 2] + [image_token_id] * 256 + [3, 4]]) # Only 1 block
prompt_mask = torch.ones_like(prompt_ids)
protected_tokens = []
target_length = 300 # No reduction needed

new_ids, _, metadata = truncate_with_protected_tokens(
prompt_ids, prompt_mask, target_length, protected_tokens, image_token_id=image_token_id
)

# Should not reduce anything
image_tokens_after = (new_ids == image_token_id).sum().item()
self.assertEqual(image_tokens_after, 256)
self.assertEqual(metadata["images_removed_per_sequence"], [0])
self.assertEqual(metadata["total_images_removed"], 0)

def test_image_token_block_with_protected_tokens(self):
"""Test image block reduction combined with protected token preservation."""
image_token_id = 151667
vision_start_token = 151666
prompt_ids = torch.tensor([[vision_start_token, 2] + [image_token_id] * 768 + [3, 4, 5]])
prompt_mask = torch.ones_like(prompt_ids)
protected_tokens = [vision_start_token, image_token_id]
target_length = 400

new_ids, _, metadata = truncate_with_protected_tokens(
prompt_ids,
prompt_mask,
target_length,
protected_tokens,
image_token_id=image_token_id,
image_seq_length=256,
)

# Should reduce image blocks first, then apply protected token logic
image_tokens_after = (new_ids == image_token_id).sum().item()
self.assertEqual(image_tokens_after, 256) # Reduced to 1 block
self.assertTrue((new_ids == vision_start_token).any()) # Protected token preserved
self.assertEqual(metadata["images_removed_per_sequence"], [2])

def test_image_token_block_edge_case_single_block(self):
"""Test that single image blocks are preserved when there's only one block."""
image_token_id = 151667
prompt_ids = torch.tensor([[1, 2] + [image_token_id] * 256 + [3, 4, 5]])
prompt_mask = torch.ones_like(prompt_ids)
protected_tokens = []
target_length = 260 # Larger than image block but smaller than total

new_ids, _, metadata = truncate_with_protected_tokens(
prompt_ids,
prompt_mask,
target_length,
protected_tokens,
image_token_id=image_token_id,
image_seq_length=256,
)

# Should not remove the single image block since there's only one
image_tokens_after = (new_ids == image_token_id).sum().item()
self.assertEqual(image_tokens_after, 256)
self.assertEqual(metadata["images_removed_per_sequence"], [0])

def test_image_token_block_non_consecutive_tokens(self):
"""Test that non-consecutive image tokens don't trigger block detection."""
image_token_id = 151667
# Scatter 256 image tokens throughout the sequence (not consecutive)
prompt_ids = torch.tensor([[1, image_token_id, 2, image_token_id, 3, image_token_id] * 100])
prompt_mask = torch.ones_like(prompt_ids)
protected_tokens = []
target_length = 200

_, _, metadata = truncate_with_protected_tokens(
prompt_ids,
prompt_mask,
target_length,
protected_tokens,
image_token_id=image_token_id,
image_seq_length=256,
)

# Should not detect blocks since tokens are not consecutive
self.assertIsNone(metadata["block_size"])
self.assertEqual(metadata["images_removed_per_sequence"], [0])

def test_image_token_block_no_detection_non_standard(self):
"""Test that non-standard block sizes fall back to regular truncation."""
image_token_id = 12345
# Test with non-standard block size that won't be detected
prompt_ids = torch.tensor([[1, 2] + [image_token_id] * 300 + [3, 4]]) # Non-standard size
prompt_mask = torch.ones_like(prompt_ids)
protected_tokens = []
target_length = 200 # Force reduction

new_ids, _, metadata = truncate_with_protected_tokens(
prompt_ids,
prompt_mask,
target_length,
protected_tokens,
image_token_id=image_token_id,
image_seq_length=256,
)

# Should not detect blocks and fall back to regular truncation
self.assertIsNone(metadata["block_size"])
self.assertEqual(metadata["images_removed_per_sequence"], [0])
self.assertEqual(new_ids.shape[1], 200) # Regular truncation to target length

def test_backward_compatibility_without_image_tokens(self):
"""Test that the enhanced function maintains backward compatibility."""
prompt_ids = torch.tensor([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]])
prompt_mask = torch.ones_like(prompt_ids)
protected_tokens = [2, 3, 6]
target_length = 3

# Call without image_token_id (backward compatibility)
new_ids, new_mask, metadata = truncate_with_protected_tokens(
prompt_ids, prompt_mask, target_length, protected_tokens
)

expected_ids = torch.tensor([[2, 3, 5], [6, 9, 10]])
expected_mask = torch.ones_like(expected_ids)

self.assertTrue(torch.equal(new_ids, expected_ids))
self.assertTrue(torch.equal(new_mask, expected_mask))
self.assertEqual(metadata["total_images_removed"], 0)
self.assertIsNone(metadata["block_size"])

def test_metadata_structure(self):
"""Test that metadata contains all expected fields."""
prompt_ids = torch.tensor([[1, 2, 3, 4, 5]])
prompt_mask = torch.ones_like(prompt_ids)
protected_tokens = []
target_length = 3

_, _, metadata = truncate_with_protected_tokens(prompt_ids, prompt_mask, target_length, protected_tokens)

# Check metadata structure
self.assertIn("images_removed_per_sequence", metadata)
self.assertIn("block_size", metadata)
self.assertIn("total_images_removed", metadata)
self.assertIsInstance(metadata["images_removed_per_sequence"], list)
self.assertEqual(len(metadata["images_removed_per_sequence"]), 1) # One sequence
self.assertEqual(metadata["total_images_removed"], 0)
self.assertIsNone(metadata["block_size"])


class UnsplitPixelValuesByGridTester(TrlTestCase):
Expand Down
50 changes: 48 additions & 2 deletions trl/trainer/grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1106,10 +1106,56 @@ def _generate_and_score_completions(
# because we can't use `skip_special_tokens=True` (some special tokens are still needed for generation).
protected = [self.image_token_id, self.vision_start_token_id, self.vision_end_token_id]
protected = [token for token in protected if token is not None]
prompt_ids, prompt_mask = truncate_with_protected_tokens(
prompt_ids, prompt_mask, self.max_prompt_length, protected

# Extract image_seq_length from model config for vision models
image_seq_length = None
if self.image_token_id is not None:
# Try to get image_seq_length from model config
config = getattr(self.model, "config", None)
if config is not None:
image_seq_length = getattr(config, "image_seq_length", None)
if image_seq_length is None:
# For some models, it might be in a nested config
vision_config = getattr(config, "vision_config", None)
if vision_config is not None:
image_seq_length = getattr(vision_config, "image_seq_length", None)

# Use the enhanced truncate_with_protected_tokens which handles image blocks and protected tokens intelligently
prompt_ids, prompt_mask, truncation_metadata = truncate_with_protected_tokens(
prompt_ids,
prompt_mask,
self.max_prompt_length,
protected,
image_token_id=self.image_token_id,
image_seq_length=image_seq_length,
)

# Fix pixel_values alignment if images were removed during truncation
if (
hasattr(prompt_inputs, "pixel_values")
and prompt_inputs.pixel_values is not None
and truncation_metadata["total_images_removed"] > 0
):
original_pixel_values = prompt_inputs.pixel_values

# Calculate how many images to keep per sequence
images_removed_per_seq = truncation_metadata["images_removed_per_sequence"]
batch_size = len(images_removed_per_seq)
original_images_per_seq = original_pixel_values.shape[0] // batch_size

# Create new pixel_values with reduced images
new_pixel_values_list = []
for seq_idx, images_removed in enumerate(images_removed_per_seq):
images_to_keep = original_images_per_seq - images_removed
seq_start = seq_idx * original_images_per_seq
seq_end = seq_start + images_to_keep
if seq_end <= original_pixel_values.shape[0]:
new_pixel_values_list.append(original_pixel_values[seq_start:seq_end])

if new_pixel_values_list:
new_pixel_values = torch.cat(new_pixel_values_list, dim=0)
prompt_inputs["pixel_values"] = new_pixel_values

prompts_text = self.processing_class.batch_decode(
prompt_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False
)
Expand Down
Loading
Loading