-
Notifications
You must be signed in to change notification settings - Fork 261
fix: skip multimodal samples that exceed seq_len instead of truncating #2064
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
0db6199
43ae501
f38ce72
a403a7f
5051bc2
ccb7d5e
448d636
929c818
df63cbf
3045142
a3e2799
3a3597b
bd93215
0e77ad0
ec320ee
f72b471
cefe93b
e7f3da3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -143,11 +143,23 @@ async def orchestrate(config: OrchestratorConfig): | |
| tokenizer = AutoTokenizer.from_pretrained(config.model.name, trust_remote_code=config.model.trust_remote_code) | ||
|
|
||
| processor = None | ||
| image_token_id = None | ||
| if is_vlm: | ||
| logger.info(f"Loading VLM processor for {config.model.name}") | ||
| processor = AutoProcessor.from_pretrained( | ||
| config.model.name, trust_remote_code=config.model.trust_remote_code, use_fast=True | ||
| ) | ||
| if config.model.vlm.image_token_id is not None: | ||
| image_token_id = config.model.vlm.image_token_id | ||
| else: | ||
| from transformers import AutoConfig | ||
|
|
||
| from prime_rl.utils.vlm import get_image_token_id | ||
|
|
||
| model_config = AutoConfig.from_pretrained( | ||
| config.model.name, trust_remote_code=config.model.trust_remote_code | ||
| ) | ||
| image_token_id = get_image_token_id(model_config) | ||
|
|
||
| # Setup monitor | ||
| logger.info(f"Initializing monitor (wandb={config.wandb}, prime_monitor={config.prime_monitor})") | ||
|
|
@@ -578,12 +590,29 @@ def process_rollout(rollout: vf.RolloutOutput, rollout_idx: int) -> list[Trainin | |
| rollout_samples_per_rollout: list[int] = [] | ||
| num_prefill_tokens = 0 | ||
| num_decode_tokens = 0 | ||
| num_total_samples = 0 | ||
| num_skipped_vlm_truncation = 0 | ||
| for rollout, advantage, samples in zip(train_rollouts, advantages, results): | ||
| rollout_prefill_tokens = 0 | ||
| rollout_decode_tokens = 0 | ||
| if samples is not None: | ||
| num_total_samples += len(samples) | ||
| rollout_samples_per_rollout.append(len(samples)) | ||
| for sample in samples: | ||
| # Multimodal samples where image tokens were truncated cannot be trained: | ||
| # the pixel_values expect a specific number of image_pad tokens and any | ||
| # mismatch crashes the vision encoder. This happens when vLLM truncates | ||
| # the prompt to fit max_model_len. Skip these before sending to the trainer. | ||
| if sample.pixel_values is not None and sample.image_grid_thw and image_token_id is not None: | ||
| expected_image_tokens = sum( | ||
| t * ((h + 1) // 2) * ((w + 1) // 2) for t, h, w in sample.image_grid_thw | ||
| ) | ||
| all_ids = sample.prompt_ids + sample.completion_ids | ||
| actual_image_tokens = sum(1 for tid in all_ids if tid == image_token_id) | ||
| if actual_image_tokens != expected_image_tokens: | ||
| num_skipped_vlm_truncation += 1 | ||
| continue | ||
|
cursor[bot] marked this conversation as resolved.
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. **Multimodal overflow not skipped before trainer** (High Severity): The new filter only skips multimodal samples when image-token counts mismatch, but it does not skip samples where [comment truncated]. Additional Locations (1) |
||
|
|
||
| sample.advantage = advantage | ||
| sample.reward = rollout["reward"] | ||
| sample_decode_tokens = sum(sample.completion_mask) | ||
|
|
@@ -598,6 +627,20 @@ def process_rollout(rollout: vf.RolloutOutput, rollout_idx: int) -> list[Trainin | |
| num_prefill_tokens += rollout_prefill_tokens | ||
| num_decode_tokens += rollout_decode_tokens | ||
|
|
||
| num_skipped = num_total_samples - len(train_examples) | ||
| reasons = [] | ||
| if num_skipped_vlm_truncation > 0: | ||
| reasons.append(f"{num_skipped_vlm_truncation} multimodal with truncated image tokens") | ||
| if num_skipped > 0: | ||
| logger.warning(f"Skipped {num_skipped}/{num_total_samples} samples ({', '.join(reasons)})") | ||
|
|
||
| if not train_examples: | ||
| detail = f" ({', '.join(reasons)})" if reasons else "" | ||
| raise ValueError( | ||
| f"All {num_total_samples} training samples were skipped{detail}. " | ||
| f"Increase seq_len (currently {config.seq_len}) to fit multimodal samples." | ||
| ) | ||
|
cursor[bot] marked this conversation as resolved.
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. **Validation task left running on new early error** (Low Severity): The new early [comment truncated]. Additional Locations (1) |
||
|
|
||
| parallel_preprocess_time = time.perf_counter() - parallel_preprocess_start | ||
| logger.debug( | ||
| f"Converted {len(train_rollouts)} rollouts ({num_unique_examples} unique examples) " | ||
|
|
||


There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Missing changelog for config schema change
Low Severity
`VLMConfig` adds a new config field, `image_token_id`, in `src/prime_rl/configs/shared.py`, but this PR does not include a `CHANGELOG.md` update documenting the config structure change. Triggered by project rule: BugBot Instructions