
fix: reconcile VLM image tokens with pixel_values after prompt truncation#2013

Open
cdreetz wants to merge 3 commits into main from fix/vlm-image-token-mismatch

Conversation

@cdreetz
Contributor

@cdreetz cdreetz commented Mar 10, 2026

Summary

In multi-turn VLM conversations that exceed seq_len, vLLM left-truncates the prompt to fit in max_model_len, removing image_pad tokens from early turns. But pixel_values/image_grid_thw are computed independently by the orchestrator's image processor from all conversation images. This causes a mismatch — the trainer crashes with ValueError: Image features and image tokens do not match.

The fix runs after seq_len truncation in prepare_sample: compare image_pad tokens in input_ids against image_grid_thw, drop images whose placeholder tokens were truncated, and remove any orphaned partial image tokens so both sides agree on the same set of complete images. No-op for text-only samples and VLM samples where tokens already match.
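The reconciliation step described above can be sketched roughly as follows. This is a minimal illustration, not the actual helper: the token ID is the Qwen image placeholder mentioned in this PR, but `trim_to_match`, `grid_tokens`, and the `merge_size=2` spatial-merge assumption are illustrative, and this sketch covers only the right-truncation (seq_len) case:

```python
IMAGE_PAD_TOKEN_ID = 151655  # Qwen-VL image placeholder (from this PR)

def grid_tokens(grid, merge_size=2):
    # Placeholder tokens an image contributes after spatial merging (assumption)
    t, h, w = grid
    return (t * h * w) // (merge_size * merge_size)

def trim_to_match(input_ids, image_grid_thw, merge_size=2):
    """Keep only images whose placeholder tokens fully survived truncation,
    and drop orphaned partial placeholder tokens from input_ids."""
    num_image_tokens = sum(1 for tok in input_ids if tok == IMAGE_PAD_TOKEN_ID)
    kept, covered = [], 0
    for grid in image_grid_thw:
        n = grid_tokens(grid, merge_size)
        if covered + n <= num_image_tokens:
            kept.append(grid)
            covered += n
        else:
            break  # this image's tokens were (partially) truncated
    # Remove placeholder tokens beyond the last complete image
    out, seen = [], 0
    for tok in input_ids:
        if tok == IMAGE_PAD_TOKEN_ID:
            seen += 1
            if seen > covered:
                continue
        out.append(tok)
    return out, kept
```

For text-only samples (`image_grid_thw` empty) and samples where token counts already match, the sketch returns its inputs unchanged, matching the no-op behavior described above.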

Test plan

  • Unit tests for _trim_multimodal_to_match: noop cases (text-only, already matching), partial image drop, all images dropped
  • Integration tests for prepare_sample: seq_len truncation path, vLLM pre-truncation path
  • End-to-end: 5-step VLM tic-tac-toe training run with Qwen3-VL-4B-Instruct completes without errors

Note

Medium Risk
Changes core sample-prep logic for multimodal training and can alter which images/features are fed to the model when truncation occurs, though the behavior is narrowly scoped and covered by new tests.

Overview
Fixes Qwen3-VL batch preparation when prompts are truncated by ensuring image placeholder tokens in input_ids stay consistent with pixel_values/image_grid_thw.

prepare_sample now calls a new _trim_multimodal_to_match helper that drops orphaned/partial image tokens and trims (or removes) trailing image features/grids accordingly, while leaving text-only or already-matching samples unchanged. Adds unit tests covering noop, partial/all-image drops, and both seq_len truncation and vLLM left-truncation scenarios.

Written by Cursor Bugbot for commit f4db3bb. This will update automatically on new commits.

root and others added 3 commits March 10, 2026 19:20
…tion

In multi-turn VLM conversations that exceed seq_len, vLLM left-truncates
the prompt to fit in max_model_len, removing image_pad tokens from early
turns. But pixel_values/image_grid_thw are computed independently by the
orchestrator's image processor from all conversation images.

This causes a mismatch: input_ids has fewer image tokens than
image_grid_thw expects, crashing the trainer with "Image features and
image tokens do not match".

Fix: after seq_len truncation in prepare_sample, compare the image_pad
tokens actually present in input_ids against image_grid_thw. Drop images
whose placeholder tokens were truncated and remove any orphaned partial
image tokens, so both sides agree on the same set of complete images.

No-op for text-only samples and VLM samples where tokens already match.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

@cursor cursor bot left a comment


Cursor Bugbot has reviewed your changes and found 1 potential issue.

Bugbot Autofix prepared a fix for the issue found in the latest run.

  • ✅ Fixed: Left-truncation drops images from wrong end of grid
    • Fixed _trim_multimodal_to_match to detect left-truncation and correctly pair surviving image tokens with their corresponding pixel values.

Create PR

Or push these changes by commenting:

@cursor push 7372554878
Preview (7372554878)
diff --git a/src/prime_rl/trainer/batch.py b/src/prime_rl/trainer/batch.py
--- a/src/prime_rl/trainer/batch.py
+++ b/src/prime_rl/trainer/batch.py
@@ -31,16 +31,54 @@
     if num_image_tokens == expected_tokens:
         return input_ids, None, pixel_values, pixel_values_shape, image_grid_thw
 
-    # Keep only complete images that fit within the available image tokens
-    kept_grids = []
-    valid_tokens = 0
-    for grid in image_grid_thw:
-        t = _grid_tokens(grid)
-        if valid_tokens + t <= num_image_tokens:
-            kept_grids.append(grid)
-            valid_tokens += t
-        else:
+    # Determine which images to keep based on the truncation pattern
+    # Count tokens per image
+    tokens_per_image = [_grid_tokens(g) for g in image_grid_thw]
+    
+    # Detect if we have left-truncation by checking if the first image is incomplete
+    # Count image tokens before the first complete image boundary
+    first_image_tokens = 0
+    for t in input_ids:
+        if t == _IMAGE_PAD_TOKEN_ID:
+            first_image_tokens += 1
+        elif first_image_tokens > 0:
+            # We've seen some image tokens and now hit a non-image token
             break
+    
+    # Check if first image is partial (indicates left-truncation)
+    skip_images = 0
+    if first_image_tokens > 0 and first_image_tokens < tokens_per_image[0]:
+        # Left-truncation case: skip images from the beginning
+        remaining_tokens = num_image_tokens - first_image_tokens
+        skip_images = 1  # Skip the partial first image
+        
+        # Skip additional complete images that don't have tokens
+        for i in range(1, len(tokens_per_image)):
+            if remaining_tokens >= tokens_per_image[i]:
+                break
+            skip_images += 1
+        
+        # Keep grids starting after the skipped images
+        kept_grids = []
+        valid_tokens = 0
+        for i in range(skip_images, len(image_grid_thw)):
+            t = tokens_per_image[i]
+            if valid_tokens + t <= remaining_tokens:
+                kept_grids.append(image_grid_thw[i])
+                valid_tokens += t
+            else:
+                break
+    else:
+        # Right-truncation case: keep images from the beginning
+        kept_grids = []
+        valid_tokens = 0
+        for grid in image_grid_thw:
+            t = _grid_tokens(grid)
+            if valid_tokens + t <= num_image_tokens:
+                kept_grids.append(grid)
+                valid_tokens += t
+            else:
+                break
 
     # Build keep mask: True for non-image tokens and complete-image tokens, False for orphans
     keep_mask = []
@@ -58,9 +96,18 @@
         return input_ids, keep_mask, None, None, None
 
     patch_dim = pixel_values_shape[1] if pixel_values_shape else 0
+    
+    # Calculate pixel_values offset for left-truncation
+    if skip_images > 0:
+        # Calculate bytes to skip for the dropped images
+        skipped_patches = sum(g[0] * g[1] * g[2] for g in image_grid_thw[:skip_images])
+        skip_bytes = skipped_patches * 4 * patch_dim
+    else:
+        skip_bytes = 0
+    
     kept_patches = sum(g[0] * g[1] * g[2] for g in kept_grids)
     kept_bytes = kept_patches * 4 * patch_dim
-    return input_ids, keep_mask, pixel_values[:kept_bytes], [kept_patches, patch_dim], kept_grids
+    return input_ids, keep_mask, pixel_values[skip_bytes:skip_bytes + kept_bytes], [kept_patches, patch_dim], kept_grids
 
 
 def prepare_sample(training_example: TrainingSample, seq_len: int) -> MicroBatch:


patch_dim = pixel_values_shape[1] if pixel_values_shape else 0
kept_patches = sum(g[0] * g[1] * g[2] for g in kept_grids)
kept_bytes = kept_patches * 4 * patch_dim
return input_ids, keep_mask, pixel_values[:kept_bytes], [kept_patches, patch_dim], kept_grids
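The byte arithmetic in the snippet above assumes `pixel_values` is a flat float32 buffer, so each patch row of width `patch_dim` occupies `4 * patch_dim` bytes. A small worked example (the numbers are illustrative; `patch_dim = 1176` assumes Qwen-style 3 × 2 × 14 × 14 patches):

```python
# One kept image with grid (t, h, w) = (1, 4, 4)
kept_grids = [(1, 4, 4)]
patch_dim = 1176  # e.g. 3 * 2 * 14 * 14 for Qwen-style patches (assumption)

kept_patches = sum(t * h * w for (t, h, w) in kept_grids)  # 16 patches
kept_bytes = kept_patches * 4 * patch_dim  # float32 => 4 bytes per element
assert kept_bytes == 16 * 4 * 1176
```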


Left-truncation drops images from wrong end of grid

High Severity

_trim_multimodal_to_match always iterates image_grid_thw front-to-back and keeps the earliest images, which is correct for seq_len right-truncation but wrong for vLLM left-truncation. When vLLM removes tokens from the beginning, the surviving complete image tokens correspond to later images, but the function keeps pixel values and grids for the first images. This causes the model to pair image features with the wrong placeholder tokens — e.g., image 1's pixel features get mapped onto tokens that actually belong to images 2+. The test only asserts count equality, so it passes despite the semantic mismatch. For images with different grid sizes this can also produce count mismatches and crashes.
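For the left-truncation case, one way to pair surviving tokens with the correct images is to match grids from the end rather than the front. The sketch below assumes surviving complete image tokens always form a suffix of the original image sequence; the helper name is illustrative, not part of the PR:

```python
def match_suffix(tokens_per_image, num_surviving_tokens):
    """Return (start, total): images [start:] fully survived left-truncation,
    and `total` is the number of placeholder tokens they account for.
    Surviving tokens beyond `total` are orphans from a partially cut image."""
    total = 0
    start = len(tokens_per_image)
    # Walk backwards, accumulating whole images while they still fit
    for i in range(len(tokens_per_image) - 1, -1, -1):
        if total + tokens_per_image[i] > num_surviving_tokens:
            break
        total += tokens_per_image[i]
        start = i
    return start, total
```

With three images of 4, 6, and 8 tokens and 14 surviving tokens, this keeps images 1 and 2; with 10 surviving tokens it keeps only the last image and leaves 2 orphan tokens to be stripped.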


@hallerite hallerite self-requested a review March 11, 2026 00:42
from prime_rl.transport.types import MicroBatch, TrainingSample

# Qwen3-VL image placeholder token ID
_IMAGE_PAD_TOKEN_ID = 151655
Member


need to add this to the model registry, should def not be a magic number
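One hypothetical shape for such a registry entry (the mapping name, keys, and lookup function are illustrative — only the 151655 value for Qwen-VL comes from this PR):

```python
# Illustrative registry: resolve the image placeholder token per model family
# instead of hardcoding a magic number at the call site.
IMAGE_PAD_TOKEN_IDS = {
    "qwen3-vl": 151655,  # value used in this PR
}

def image_pad_token_id(model_family: str) -> int:
    try:
        return IMAGE_PAD_TOKEN_IDS[model_family]
    except KeyError:
        raise ValueError(f"no image pad token registered for {model_family!r}")
```

Failing loudly on unknown families keeps the trimming logic from silently scanning for the wrong token ID on a new model.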
