fix: reconcile VLM image tokens with pixel_values after prompt truncation#2013
Conversation
In multi-turn VLM conversations that exceed seq_len, vLLM left-truncates the prompt to fit in max_model_len, removing image_pad tokens from early turns. But pixel_values/image_grid_thw are computed independently by the orchestrator's image processor from all conversation images. This causes a mismatch: input_ids has fewer image tokens than image_grid_thw expects, crashing the trainer with "Image features and image tokens do not match".

Fix: after seq_len truncation in prepare_sample, compare the image_pad tokens actually present in input_ids against image_grid_thw. Drop images whose placeholder tokens were truncated and remove any orphaned partial image tokens, so both sides agree on the same set of complete images. No-op for text-only samples and VLM samples where tokens already match.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
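The reconciliation the commit message describes can be sketched as a toy, illustrative only and not the actual prime_rl code; the merge_size=2 token formula is an assumption about Qwen-style image processors:

```python
# Toy sketch of the reconciliation described above (not the real prime_rl code).
# Assumption: each image contributes (t*h*w) // merge_size**2 placeholder tokens,
# as in Qwen-style processors with merge_size=2.
IMAGE_PAD = 151655  # Qwen3-VL image placeholder token id

def grid_tokens(grid, merge_size=2):
    t, h, w = grid
    return (t * h * w) // merge_size**2

def reconcile(input_ids, grids):
    """Keep only images whose placeholder tokens fully survived truncation."""
    present = sum(1 for tok in input_ids if tok == IMAGE_PAD)
    kept, used = [], 0
    for g in grids:
        n = grid_tokens(g)
        if used + n > present:
            break  # this image's tokens were (partly) truncated: drop it
        kept.append(g)
        used += n
    # placeholder tokens beyond `used` are orphans and would be masked out
    return kept, used
```

With three surviving placeholder tokens and two images of two tokens each, only the first image survives and the third placeholder token is treated as an orphan.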
Cursor Bugbot has reviewed your changes and found 1 potential issue.
Bugbot Autofix prepared a fix for the issue found in the latest run.
- ✅ Fixed: Left-truncation drops images from wrong end of grid
- Fixed _trim_multimodal_to_match to detect left-truncation and correctly pair surviving image tokens with their corresponding pixel values.
Or push these changes by commenting:
@cursor push 7372554878
Preview (7372554878)
diff --git a/src/prime_rl/trainer/batch.py b/src/prime_rl/trainer/batch.py
--- a/src/prime_rl/trainer/batch.py
+++ b/src/prime_rl/trainer/batch.py
@@ -31,16 +31,54 @@
if num_image_tokens == expected_tokens:
return input_ids, None, pixel_values, pixel_values_shape, image_grid_thw
- # Keep only complete images that fit within the available image tokens
- kept_grids = []
- valid_tokens = 0
- for grid in image_grid_thw:
- t = _grid_tokens(grid)
- if valid_tokens + t <= num_image_tokens:
- kept_grids.append(grid)
- valid_tokens += t
- else:
+ # Determine which images to keep based on the truncation pattern
+ # Count tokens per image
+ tokens_per_image = [_grid_tokens(g) for g in image_grid_thw]
+
+ # Detect if we have left-truncation by checking if the first image is incomplete
+ # Count image tokens before the first complete image boundary
+ first_image_tokens = 0
+ for t in input_ids:
+ if t == _IMAGE_PAD_TOKEN_ID:
+ first_image_tokens += 1
+ elif first_image_tokens > 0:
+ # We've seen some image tokens and now hit a non-image token
+                break
+
+ # Check if first image is partial (indicates left-truncation)
+ skip_images = 0
+ if first_image_tokens > 0 and first_image_tokens < tokens_per_image[0]:
+ # Left-truncation case: skip images from the beginning
+ remaining_tokens = num_image_tokens - first_image_tokens
+ skip_images = 1 # Skip the partial first image
+
+ # Skip additional complete images that don't have tokens
+ for i in range(1, len(tokens_per_image)):
+ if remaining_tokens >= tokens_per_image[i]:
+ break
+ skip_images += 1
+
+ # Keep grids starting after the skipped images
+ kept_grids = []
+ valid_tokens = 0
+ for i in range(skip_images, len(image_grid_thw)):
+ t = tokens_per_image[i]
+ if valid_tokens + t <= remaining_tokens:
+ kept_grids.append(image_grid_thw[i])
+ valid_tokens += t
+ else:
+ break
+ else:
+ # Right-truncation case: keep images from the beginning
+ kept_grids = []
+ valid_tokens = 0
+ for grid in image_grid_thw:
+ t = _grid_tokens(grid)
+ if valid_tokens + t <= num_image_tokens:
+ kept_grids.append(grid)
+ valid_tokens += t
+ else:
+ break
# Build keep mask: True for non-image tokens and complete-image tokens, False for orphans
keep_mask = []
@@ -58,9 +96,18 @@
return input_ids, keep_mask, None, None, None
patch_dim = pixel_values_shape[1] if pixel_values_shape else 0
+
+ # Calculate pixel_values offset for left-truncation
+ if skip_images > 0:
+ # Calculate bytes to skip for the dropped images
+ skipped_patches = sum(g[0] * g[1] * g[2] for g in image_grid_thw[:skip_images])
+ skip_bytes = skipped_patches * 4 * patch_dim
+ else:
+ skip_bytes = 0
+
kept_patches = sum(g[0] * g[1] * g[2] for g in kept_grids)
kept_bytes = kept_patches * 4 * patch_dim
- return input_ids, keep_mask, pixel_values[:kept_bytes], [kept_patches, patch_dim], kept_grids
+ return input_ids, keep_mask, pixel_values[skip_bytes:skip_bytes + kept_bytes], [kept_patches, patch_dim], kept_grids
 def prepare_sample(training_example: TrainingSample, seq_len: int) -> MicroBatch:
    patch_dim = pixel_values_shape[1] if pixel_values_shape else 0
    kept_patches = sum(g[0] * g[1] * g[2] for g in kept_grids)
    kept_bytes = kept_patches * 4 * patch_dim
    return input_ids, keep_mask, pixel_values[:kept_bytes], [kept_patches, patch_dim], kept_grids
Left-truncation drops images from wrong end of grid
High Severity
_trim_multimodal_to_match always iterates image_grid_thw front-to-back and keeps the earliest images, which is correct for seq_len right-truncation but wrong for vLLM left-truncation. When vLLM removes tokens from the beginning, the surviving complete image tokens correspond to later images, but the function keeps pixel values and grids for the first images. This causes the model to pair image features with the wrong placeholder tokens — e.g., image 1's pixel features get mapped onto tokens that actually belong to images 2+. The test only asserts count equality, so it passes despite the semantic mismatch. For images with different grid sizes this can also produce count mismatches and crashes.
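The wrong-pairing failure mode can be shown with a toy example (hypothetical sizes, not the real helper): when both images contribute the same number of placeholder tokens, a front-to-back keep passes the count check yet pairs the wrong image's features with the surviving tokens.

```python
# Toy illustration (hypothetical sizes): two images each contribute 4
# placeholder tokens; left-truncation kept only the LAST 4 tokens, i.e.
# exactly image 1's tokens.
tokens_per_image = [4, 4]
surviving = 4

# Front-to-back keep, as in the original helper: counts match, so no crash,
# but the kept grid/pixels are image 0's while the tokens belong to image 1.
kept_front, used = [], 0
for i, n in enumerate(tokens_per_image):
    if used + n <= surviving:
        kept_front.append(i)
        used += n
    else:
        break

# Back-to-front keep matches left-truncation semantics:
kept_back, used = [], 0
for i in range(len(tokens_per_image) - 1, -1, -1):
    n = tokens_per_image[i]
    if used + n <= surviving:
        kept_back.insert(0, i)
        used += n
    else:
        break

print(kept_front, kept_back)  # same count of tokens, different images kept
```

This is exactly why a count-equality test passes while the semantic pairing is wrong.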
Additional Locations (1)
from prime_rl.transport.types import MicroBatch, TrainingSample

# Qwen3-VL image placeholder token ID
_IMAGE_PAD_TOKEN_ID = 151655
need to add this to the model registry, should def not be a magic number
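One possible shape for the registry the reviewer is asking for; the names here are hypothetical, not an existing prime_rl API:

```python
# Hypothetical sketch of a per-model-family token registry, replacing the
# hard-coded _IMAGE_PAD_TOKEN_ID = 151655 magic number. Names are illustrative.
from dataclasses import dataclass

@dataclass(frozen=True)
class MultimodalTokens:
    image_pad_token_id: int

_MODEL_REGISTRY = {
    "qwen3-vl": MultimodalTokens(image_pad_token_id=151655),
}

def image_pad_token_id(model_family: str) -> int:
    try:
        return _MODEL_REGISTRY[model_family].image_pad_token_id
    except KeyError:
        raise ValueError(f"no multimodal token config for {model_family!r}")
```

A lookup keyed by model family keeps the trainer code model-agnostic and makes adding the next VLM a one-line registry change.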



Summary
Summary

In multi-turn VLM conversations that exceed seq_len, vLLM left-truncates the prompt to fit in max_model_len, removing image_pad tokens from early turns. But pixel_values/image_grid_thw are computed independently by the orchestrator's image processor from all conversation images. This causes a mismatch: the trainer crashes with ValueError: Image features and image tokens do not match.

The fix runs after seq_len truncation in prepare_sample: compare image_pad tokens in input_ids against image_grid_thw, drop images whose placeholder tokens were truncated, and remove any orphaned partial image tokens so both sides agree on the same set of complete images. No-op for text-only samples and VLM samples where tokens already match.

Test plan

- _trim_multimodal_to_match: noop cases (text-only, already matching), partial image drop, all images dropped
- prepare_sample: seq_len truncation path, vLLM pre-truncation path

Note

Medium Risk: Changes core sample-prep logic for multimodal training and can alter which images/features are fed to the model when truncation occurs, though the behavior is narrowly scoped and covered by new tests.

Overview

Fixes Qwen3-VL batch preparation when prompts are truncated by ensuring image placeholder tokens in input_ids stay consistent with pixel_values/image_grid_thw. prepare_sample now calls a new _trim_multimodal_to_match helper that drops orphaned/partial image tokens and trims (or removes) trailing image features/grids accordingly, while leaving text-only or already-matching samples unchanged. Adds unit tests covering noop, partial/all-image drops, and both seq_len truncation and vLLM left-truncation scenarios.

Written by Cursor Bugbot for commit f4db3bb.
input_idsstay consistent withpixel_values/image_grid_thw.prepare_samplenow calls a new_trim_multimodal_to_matchhelper that drops orphaned/partial image tokens and trims (or removes) trailing image features/grids accordingly, while leaving text-only or already-matching samples unchanged. Adds unit tests covering noop, partial/all-image drops, and bothseq_lentruncation and vLLM left-truncation scenarios.Written by Cursor Bugbot for commit f4db3bb. This will update automatically on new commits. Configure here.