Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
0db6199
fix: skip multimodal samples that exceed seq_len instead of truncating
hallerite Mar 22, 2026
43ae501
fix: handle empty batch when all multimodal samples are skipped
hallerite Mar 24, 2026
f38ce72
fix: validate image token count, catch vLLM left-truncation
hallerite Mar 24, 2026
a403a7f
make warning
hallerite Mar 24, 2026
5051bc2
Merge remote-tracking branch 'origin/main' into fix/vlm-truncation-sa…
hallerite Mar 24, 2026
ccb7d5e
test: fix packer test passing tokenizer=None
hallerite Mar 24, 2026
448d636
fix: use VLM registry for image token ID instead of hardcoded string
hallerite Mar 25, 2026
929c818
fix: skip multimodal samples that exceed seq_len, document
hallerite Mar 25, 2026
df63cbf
refactor: move multimodal seq_len filtering to orchestrator
hallerite Mar 31, 2026
3045142
test: update batch tests for orchestrator-side VLM filtering
hallerite Mar 31, 2026
a3e2799
refactor: unify skip logging with per-reason breakdown
hallerite Mar 31, 2026
3a3597b
fix: validate image token count instead of checking seq_len overflow
hallerite Mar 31, 2026
bd93215
fix: move reasons list outside conditional to avoid NameError
hallerite Mar 31, 2026
0e77ad0
fix: raise on multimodal truncation in prepare_sample as defensive guard
hallerite Mar 31, 2026
ec320ee
refactor: add image_token to VLMConfig instead of hardcoding
hallerite Mar 31, 2026
f72b471
refactor: store image_token_id in VLM registry, not config
hallerite Mar 31, 2026
cefe93b
refactor: allow user override of image_token_id in VLMConfig
hallerite Mar 31, 2026
e7f3da3
refactor: image_token_id from config or registry, no tokenizer lookup
hallerite Mar 31, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion docs/multimodal.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,9 @@ To add permanent support for a new model family, add an entry to `VLM_REGISTRY`

- **Vision encoder is frozen**: The vision encoder is automatically frozen during training. Only the language model is trained.

- **No multimodal-safe truncation**: Token sequences are truncated to `seq_len`, but `pixel_values` and `image_grid_thw` are passed through unchanged. If a multimodal sample exceeds `seq_len`, image tokens can be dropped while image tensors still describe the full set of images. Ensure `seq_len` covers your longest VLM samples.
- **Multimodal samples that exceed `seq_len` are skipped**: Truncating a multimodal sample would break the alignment between image tokens and `pixel_values`. Instead of producing corrupt training data, such samples are dropped with a warning. Ensure `seq_len` covers your longest VLM samples or reduce rollout length.

- **Keep `max_model_len` large for VLMs**: vLLM's tokenizer left-truncates prompts that exceed `max_model_len - max_tokens`, which can silently chop image placeholder tokens from early images while `pixel_values` remain intact. This causes a fatal mismatch at training time. With the model's default context length (e.g. 32768) this never happens, but if you reduce `max_model_len` for memory reasons, make sure it's large enough to fit your longest expanded VLM prompt.

- **Optimization dtype must be bfloat16**: Set `optimization_dtype = "bfloat16"` and `reduce_dtype = "bfloat16"` in your trainer config.

Expand Down
8 changes: 8 additions & 0 deletions src/prime_rl/configs/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,14 @@ class VLMConfig(BaseConfig):
Field(description="Dotted attribute path to the language model module (e.g. 'model.language_model')."),
]

image_token_id: Annotated[
int | None,
Field(
description="Token ID of the image placeholder in the vocabulary. "
"If None, resolved automatically from the VLM registry."
),
] = None

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing changelog for config schema change

Low Severity

VLMConfig adds a new config field, image_token_id, in src/prime_rl/configs/shared.py, but this PR does not include a CHANGELOG.md update documenting the config structure change.

Fix in Cursor Fix in Web

Triggered by project rule: BugBot Instructions


class BaseModelConfig(BaseConfig):
"""Configures the model."""
Expand Down
43 changes: 43 additions & 0 deletions src/prime_rl/orchestrator/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,11 +143,23 @@ async def orchestrate(config: OrchestratorConfig):
tokenizer = AutoTokenizer.from_pretrained(config.model.name, trust_remote_code=config.model.trust_remote_code)

processor = None
image_token_id = None
if is_vlm:
logger.info(f"Loading VLM processor for {config.model.name}")
processor = AutoProcessor.from_pretrained(
config.model.name, trust_remote_code=config.model.trust_remote_code, use_fast=True
)
if config.model.vlm.image_token_id is not None:
image_token_id = config.model.vlm.image_token_id
else:
from transformers import AutoConfig

from prime_rl.utils.vlm import get_image_token_id

model_config = AutoConfig.from_pretrained(
config.model.name, trust_remote_code=config.model.trust_remote_code
)
image_token_id = get_image_token_id(model_config)

# Setup monitor
logger.info(f"Initializing monitor (wandb={config.wandb}, prime_monitor={config.prime_monitor})")
Expand Down Expand Up @@ -578,12 +590,29 @@ def process_rollout(rollout: vf.RolloutOutput, rollout_idx: int) -> list[Trainin
rollout_samples_per_rollout: list[int] = []
num_prefill_tokens = 0
num_decode_tokens = 0
num_total_samples = 0
num_skipped_vlm_truncation = 0
for rollout, advantage, samples in zip(train_rollouts, advantages, results):
rollout_prefill_tokens = 0
rollout_decode_tokens = 0
if samples is not None:
num_total_samples += len(samples)
rollout_samples_per_rollout.append(len(samples))
for sample in samples:
# Multimodal samples where image tokens were truncated cannot be trained:
# the pixel_values expect a specific number of image_pad tokens and any
# mismatch crashes the vision encoder. This happens when vLLM truncates
# the prompt to fit max_model_len. Skip these before sending to the trainer.
if sample.pixel_values is not None and sample.image_grid_thw and image_token_id is not None:
expected_image_tokens = sum(
t * ((h + 1) // 2) * ((w + 1) // 2) for t, h, w in sample.image_grid_thw
)
all_ids = sample.prompt_ids + sample.completion_ids
actual_image_tokens = sum(1 for tid in all_ids if tid == image_token_id)
if actual_image_tokens != expected_image_tokens:
num_skipped_vlm_truncation += 1
continue
Comment thread
cursor[bot] marked this conversation as resolved.
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Multimodal overflow not skipped before trainer

High Severity

The new filter only skips multimodal samples when image-token counts mismatch, but it does not skip samples where len(prompt_ids + completion_ids) > seq_len and counts still match. Those samples now reach prepare_sample, which raises ValueError in src/prime_rl/trainer/batch.py, causing training to fail instead of skipping.

Additional Locations (1)
Fix in Cursor Fix in Web


sample.advantage = advantage
sample.reward = rollout["reward"]
sample_decode_tokens = sum(sample.completion_mask)
Expand All @@ -598,6 +627,20 @@ def process_rollout(rollout: vf.RolloutOutput, rollout_idx: int) -> list[Trainin
num_prefill_tokens += rollout_prefill_tokens
num_decode_tokens += rollout_decode_tokens

num_skipped = num_total_samples - len(train_examples)
reasons = []
if num_skipped_vlm_truncation > 0:
reasons.append(f"{num_skipped_vlm_truncation} multimodal with truncated image tokens")
if num_skipped > 0:
logger.warning(f"Skipped {num_skipped}/{num_total_samples} samples ({', '.join(reasons)})")

if not train_examples:
detail = f" ({', '.join(reasons)})" if reasons else ""
raise ValueError(
f"All {num_total_samples} training samples were skipped{detail}. "
f"Increase seq_len (currently {config.seq_len}) to fit multimodal samples."
)
Comment thread
cursor[bot] marked this conversation as resolved.
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Empty-step batches now fail with wrong diagnosis

Medium Severity

The new if not train_examples path raises a ValueError that always attributes failure to seq_len, even when no samples were produced for unrelated reasons. This turns non-truncation cases into hard failures with misleading guidance.

Fix in Cursor Fix in Web

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Validation task left running on new early error

Low Severity

The new early ValueError path can exit before val_task is awaited or cancelled. When validation is enabled, this leaves an in-flight async validation job running during shutdown, causing unnecessary outstanding inference work and unclean task termination.

Additional Locations (1)
Fix in Cursor Fix in Web


parallel_preprocess_time = time.perf_counter() - parallel_preprocess_start
logger.debug(
f"Converted {len(train_rollouts)} rollouts ({num_unique_examples} unique examples) "
Expand Down
6 changes: 6 additions & 0 deletions src/prime_rl/trainer/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,12 @@ def prepare_sample(training_example: TrainingSample, seq_len: int) -> MicroBatch
routed_experts = training_example.routed_experts

if len(input_ids) > seq_len:
if training_example.pixel_values is not None:
raise ValueError(
f"Multimodal sample exceeds seq_len ({len(input_ids)} > {seq_len}) "
f"but was not filtered by the orchestrator. This would corrupt "
f"pixel_values/token alignment. Increase seq_len or check orchestrator filtering."
)
input_ids = input_ids[:seq_len]
loss_mask = loss_mask[:seq_len]
inference_logprobs = inference_logprobs[:seq_len]
Expand Down
19 changes: 16 additions & 3 deletions src/prime_rl/utils/vlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,20 @@ class VLMModelInfo:

vision_encoder_attr: str
language_model_attr: str
image_token_id: int


# Central registry: model_type -> architecture info.
# All currently supported Qwen VLM families share the same module layout and
# the same image placeholder token (<|image_pad|> = 151655), so the registry is
# generated from the list of model types. Each key gets its own VLMModelInfo
# instance to avoid any accidental shared-state coupling between entries.
VLM_REGISTRY: dict[str, VLMModelInfo] = {
    model_type: VLMModelInfo(
        vision_encoder_attr="model.visual",
        language_model_attr="model.language_model",
        image_token_id=151655,
    )
    for model_type in ("qwen3_vl", "qwen3_5", "qwen3_5_moe")
}

# Text-only default
Expand Down Expand Up @@ -80,6 +87,12 @@ def get_language_model(model: nn.Module, override: str | None = None) -> nn.Modu
return model.model


def get_image_token_id(model_config: PretrainedConfig) -> int | None:
    """Look up the image placeholder token ID for *model_config*.

    Resolves the model via the VLM registry. Returns ``None`` when the model
    type is not registered, i.e. for text-only models.
    """
    info = _get_model_info_from_config(model_config)
    if info is None:
        return None
    return info.image_token_id
Comment thread
cursor[bot] marked this conversation as resolved.


def get_layer_prefix(model_config: PretrainedConfig, override: str | None = None) -> str:
"""Return the weight key prefix for language model layers.

Expand Down
55 changes: 55 additions & 0 deletions tests/unit/orchestrator/test_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,3 +134,58 @@ def test_prepare_sample_none_routed_experts():

micro_batch = prepare_sample(sample, seq_len=8)
assert micro_batch.routed_experts is None


def test_prepare_sample_raises_on_multimodal_exceeding_seq_len():
    """Multimodal samples that exceed seq_len raise ValueError (should be filtered by orchestrator)."""
    # 5 total tokens against seq_len=3 would require truncation, which is
    # forbidden for multimodal samples because it breaks pixel_values alignment.
    overlong_sample = TrainingSample(
        prompt_ids=[1, 2, 3],
        prompt_mask=[False, False, False],
        completion_ids=[4, 5],
        completion_mask=[True, True],
        completion_logprobs=[-0.1, -0.2],
        completion_temperatures=[1.0, 1.0],
        advantage=1.0,
        pixel_values=b"\x00" * 16,
        pixel_values_shape=[1, 16],
        image_grid_thw=[[1, 1, 1]],
    )

    with pytest.raises(ValueError, match="Multimodal sample exceeds seq_len"):
        prepare_sample(overlong_sample, seq_len=3)


def test_prepare_sample_keeps_multimodal_within_seq_len():
    """Multimodal samples within seq_len are kept with pixel_values intact."""
    pixels = b"\x00" * 16
    # 4 total tokens fit comfortably in seq_len=8: no truncation, no error.
    fitting_sample = TrainingSample(
        prompt_ids=[1, 2],
        prompt_mask=[False, False],
        completion_ids=[3, 4],
        completion_mask=[True, True],
        completion_logprobs=[-0.1, -0.2],
        completion_temperatures=[1.0, 1.0],
        advantage=1.0,
        pixel_values=pixels,
        pixel_values_shape=[1, 16],
        image_grid_thw=[[1, 1, 1]],
    )

    micro_batch = prepare_sample(fitting_sample, seq_len=8)

    assert micro_batch is not None
    # Image tensors must pass through unchanged alongside the full token sequence.
    assert micro_batch.pixel_values == pixels
    assert len(micro_batch.input_ids) == 4


def test_prepare_sample_still_truncates_text_only():
    """Text-only samples are still truncated normally."""
    # No pixel_values: the multimodal guard must not trigger, and the
    # 5-token sequence is silently truncated down to seq_len.
    text_sample = TrainingSample(
        prompt_ids=[1, 2, 3],
        prompt_mask=[False, False, False],
        completion_ids=[4, 5],
        completion_mask=[True, True],
        completion_logprobs=[-0.1, -0.2],
        completion_temperatures=[1.0, 1.0],
        advantage=1.0,
    )

    micro_batch = prepare_sample(text_sample, seq_len=3)

    assert micro_batch is not None
    assert len(micro_batch.input_ids) == 3
    assert micro_batch.pixel_values is None