 
 import numpy as np
 import torch
+# yapf: disable
 from torch import nn
 from transformers import AutoModel, BatchFeature
 from transformers.models.gemma3n import (Gemma3nAudioConfig,
@@ ... @@
                                          MultiModalKwargsItems)
 from vllm.multimodal.parse import (ImageProcessorItems, MultiModalDataItems,
                                    MultiModalDataParser)
-# yapf: disable
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         BaseProcessingInfo,
                                         MultiModalPromptUpdates,
@@ -62,7 +62,8 @@ class Gemma3nImagePixelInputs(TypedDict):
 
 
 class Gemma3nAudioInputs(TypedDict):
-    input_features: torch.Tensor
+    input_features: Union[torch.Tensor, list[torch.Tensor]]
+    input_features_padded: torch.Tensor
     """Shape: `(batch_size * num_audio, seq_length, num_features)`"""
     input_features_mask: torch.Tensor
     """Shape: `(batch_size * num_audio, seq_length)`"""
@@ -188,8 +189,13 @@ def _call_hf_processor(
             mm_kwargs,
             tok_kwargs,
         )
+
         if 'input_features' in processed_outputs:
-            # Avoid padding since we need the output of each item to be
+            # Padding enables audio_tower to run in batched mode
+            processed_outputs["input_features_padded"] = \
+                processed_outputs["input_features"]
+
+            # Unpad features here since we need the output of each item to be
             # independent of other items for the cache to work correctly
             unpadded_features = [
                 f[mask] for f, mask in zip(
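The hunk above cuts off inside the list comprehension; a hedged sketch of the complete pattern (the key names mirror the diff, the helper itself is illustrative):

    def keep_padded_and_unpad(processed_outputs: dict) -> dict:
        # Preserve the padded batch so the audio tower can run batched later...
        processed_outputs["input_features_padded"] = \
            processed_outputs["input_features"]
        # ...and replace "input_features" with cache-friendly unpadded items,
        # one variable-length tensor per audio clip.
        processed_outputs["input_features"] = [
            f[mask] for f, mask in zip(processed_outputs["input_features_padded"],
                                       processed_outputs["input_features_mask"])
        ]
        return processed_outputs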
@@ -206,9 +212,11 @@ def _get_mm_fields_config(
         hf_processor_mm_kwargs: Mapping[str, object],
     ) -> Mapping[str, MultiModalFieldConfig]:
 
-        return dict(pixel_values=MultiModalFieldConfig.batched("image"),
-                    input_features=MultiModalFieldConfig.batched("audio"),
-                    input_features_mask=MultiModalFieldConfig.batched("audio"))
+        return dict(
+            pixel_values=MultiModalFieldConfig.batched("image"),
+            input_features=MultiModalFieldConfig.batched("audio"),
+            input_features_padded=MultiModalFieldConfig.batched("audio"),
+            input_features_mask=MultiModalFieldConfig.batched("audio"))
 
     def _get_prompt_updates(
         self,
@@ -516,9 +524,14 @@ def _parse_and_validate_audio_input(
         if input_features_mask is None:
             return None
 
+        input_features_padded = kwargs.pop("input_features_padded", None)
+        if input_features_padded is None:
+            return None
+
         return Gemma3nAudioInputs(
             input_features=input_features,
             input_features_mask=input_features_mask,
+            input_features_padded=input_features_padded,
         )
 
     def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
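A condensed sketch of the validation pattern in this hunk (simplified to a free function; the real method lives on the model class and returns Gemma3nAudioInputs):

    def parse_audio(**kwargs):
        input_features = kwargs.pop("input_features", None)
        input_features_mask = kwargs.pop("input_features_mask", None)
        input_features_padded = kwargs.pop("input_features_padded", None)
        # If any piece is missing, treat the audio input as absent rather
        # than returning a partially populated structure.
        if (input_features is None or input_features_mask is None
                or input_features_padded is None):
            return None
        return dict(input_features=input_features,
                    input_features_mask=input_features_mask,
                    input_features_padded=input_features_padded)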
@@ -564,7 +577,8 @@ def _process_audio_input(
         audio_input: Gemma3nAudioInputs,
     ) -> list[torch.Tensor]:
         assert self.audio_tower is not None
-        input_features = audio_input["input_features"].squeeze(1)
+        # Run on padded features to enable batching
+        input_features = audio_input["input_features_padded"].squeeze(1)
         input_features_mask = audio_input["input_features_mask"].squeeze(1)
         audio_outputs, audio_mask = self.audio_tower(input_features,
                                                      ~input_features_mask)
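Why both tensors are carried through to this point: the tower runs once over the padded batch, and the validity mask recovers per-item outputs afterwards. A runnable sketch with the tower stubbed out (the real interface takes features plus a padding mask, inferred from the ~input_features_mask inversion above):

    import torch

    def fake_audio_tower(x, padding_mask):
        # Stand-in for self.audio_tower: echoes features and the padding mask.
        return x, padding_mask

    features_padded = torch.randn(2, 10, 128)     # (num_audio, seq, feat)
    valid = torch.zeros(2, 10, dtype=torch.bool)  # True where frames are valid
    valid[0, :7] = True
    valid[1, :4] = True

    # One batched call instead of a Python-level loop over items.
    outputs, pad_mask = fake_audio_tower(features_padded, ~valid)
    # Per-item outputs are recovered by dropping padded positions.
    per_item = [o[~m] for o, m in zip(outputs, pad_mask)]
    assert [t.shape[0] for t in per_item] == [7, 4]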