     calculate_latency,
     write_io_files,
 )
+from QEfficient.utils import LRUCache
 from QEfficient.utils.logging_utils import logger


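`LRUCache` is imported from `QEfficient.utils`, but its implementation is not part of this diff. Below is a minimal sketch of the interface the new code relies on (a `max_size` constructor argument, `get`, item assignment, and `clear`), assuming an `OrderedDict`-backed eviction policy; the actual QEfficient implementation may differ.

```python
# Hypothetical sketch of the LRUCache interface used in this PR; the real
# QEfficient.utils.LRUCache may be implemented differently.
from collections import OrderedDict

class LRUCache:
    def __init__(self, max_size=100):
        self.max_size = max_size
        self._data = OrderedDict()

    def get(self, key, default=None):
        if key not in self._data:
            return default
        self._data.move_to_end(key)  # mark as most recently used
        return self._data[key]

    def __setitem__(self, key, value):
        self._data[key] = value
        self._data.move_to_end(key)
        if len(self._data) > self.max_size:
            self._data.popitem(last=False)  # evict least recently used entry

    def clear(self):
        self._data.clear()
```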
@@ -140,6 +141,8 @@ def __init__(
         self._vision_qpc_path = vision_qpc_path
         self.device_id = device_id  # Store device_id for vision components
         self.enable_debug_logs = enable_debug_logs  # Store for vision components
+        self._vision_outputs_cache = LRUCache(max_size=100)  # per-batch-slot vision outputs, LRU-bounded
+        self._vision_cache = {}  # per-image-path vision outputs, reused across repeated images
         self._init_vision_components()

         # Now that vision components are initialized, activate the text session
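The two caches added here are keyed differently, as the later hunks show. A toy illustration of their layouts after a hypothetical run (keys and placeholder values are illustrative only, not from the diff):

```python
# _vision_cache: keyed by image path, filled lazily in run_prefill
vision_cache = {
    "cat.jpg": ("<lang_inputs>", "<vision_outputs>", 2),  # (lang_inputs, vision_outputs, num_chunks)
}

# _vision_outputs_cache: keyed by decode batch slot, filled eagerly in
# _generate_continuous_batching before prefill
vision_outputs_cache = {
    0: {"vision_outputs": "<...>", "lang_inputs": "<...>", "num_chunks": 2},
    1: {"vision_outputs": "<...>", "lang_inputs": "<...>", "num_chunks": 3},
}
```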
@@ -201,14 +204,20 @@ def _get_vision_config(self) -> Dict[str, Any]:

     def _setup_vision_buffer_skipping(self):
         """Skip KV cache and retained state buffers for vision session"""
-        skip_patterns = [lambda x: x.startswith("past_"), lambda x: x.endswith("_RetainedState")]
-
-        buffers_to_skip = [
+        # Pre-compute skip buffers
+        self._vision_skip_buffers = [
             x
             for x in self._vision_session.input_names + self._vision_session.output_names
-            if any(pattern(x) for pattern in skip_patterns)
+            if x.startswith("past_") or x.endswith("_RetainedState")
+        ]
+        self._vision_session.skip_buffers(self._vision_skip_buffers)
+
+        # Pre-compute language skip buffers
+        self._lang_skip_buffers = [
+            x
+            for x in self._session.input_names + self._session.output_names
+            if x.startswith("past_") or x.endswith("_RetainedState")
         ]
-        self._vision_session.skip_buffers(buffers_to_skip)

     def run_prefill_for_all_inputs(self, prompt_queue, generation_len):
         """
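For reference, the precomputed filter keeps exactly the KV-cache and retained-state buffer names. A standalone check with hypothetical buffer names (the `past_` / `_RetainedState` naming conventions come from the diff itself):

```python
# Hypothetical buffer names; real sessions expose their own input/output names.
input_names = ["input_ids", "position_ids", "past_key.0", "past_value.0"]
output_names = ["logits", "past_key.0_RetainedState", "past_value.0_RetainedState"]

# Same predicate as the refactored code, without the per-call lambdas
skip_buffers = [
    x
    for x in input_names + output_names
    if x.startswith("past_") or x.endswith("_RetainedState")
]
print(skip_buffers)
# ['past_key.0', 'past_value.0', 'past_key.0_RetainedState', 'past_value.0_RetainedState']
```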
@@ -255,6 +264,70 @@ def update_decode_inputs_qwen2_5_vl(self, outputs, position_ids, generation_len,
         self.generation_len[decode_batch_id or slice(None)] = generation_len
         return next_token_id

+    def _execute_chunked_prefill(
+        self,
+        lang_inputs: Dict[str, np.ndarray],
+        num_chunks: int,
+        decode_batch_id: Optional[np.ndarray] = None,
+        prefill_logit_bs: int = 1,
+    ) -> Dict[str, np.ndarray]:
+        """
+        Execute chunked prefill with language inputs (Optimization 3: extracted common logic).
+
+        Args:
+            lang_inputs: Pre-processed language inputs with input_ids, position_ids, etc.
+            num_chunks: Number of chunks to process
+            decode_batch_id: Batch ID for continuous batching (optional)
+            prefill_logit_bs: Batch size for prefill logits
+
+        Returns:
+            Final prefill outputs
+        """
+        # Set output buffers
+        self._set_output_buffers(batch_size=prefill_logit_bs, sequence_length=1)
+
+        # Skip buffers for dual-QPC coordination (Optimization 2: use cached list)
+        self._session.skip_buffers(self._lang_skip_buffers)
+
+        # Run chunked prefill
+        outputs = None
+        chunk_image_idx = None
+
+        for i in range(num_chunks):
+            input_ids_slice = lang_inputs["input_ids"][:, i * self._prefill_seq_len : (i + 1) * self._prefill_seq_len]
+            position_ids_slice = lang_inputs["position_ids"][
+                ..., i * self._prefill_seq_len : (i + 1) * self._prefill_seq_len
+            ]
+
+            chunk_inputs = {
+                "input_ids": input_ids_slice,
+                "position_ids": position_ids_slice,
+                "image_idx": chunk_image_idx if chunk_image_idx is not None else np.array([[0]], dtype=np.int64),
+            }
+
+            if decode_batch_id is not None:
+                chunk_inputs["batch_index"] = decode_batch_id
+
+            if "cross_attention_mask" in lang_inputs:
+                chunk_inputs["cross_attention_mask"] = lang_inputs["cross_attention_mask"]
+
+            outputs = self._session.run(chunk_inputs)
+
+            if "image_idx_output" in outputs:
+                chunk_image_idx = outputs["image_idx_output"]
+
+            if self._write_io_dir is not None:
+                write_io_files(lang_inputs, outputs, self._write_io_dir, "prefill", "aic_batch_io", True, False)
+
+        # Prepare decode-time cross_attention_mask
+        if "cross_attention_mask" in lang_inputs:
+            bs, _, num_images, img_tiles = lang_inputs["cross_attention_mask"].shape
+            self._decode_cross_attention_mask = np.ones((bs, 1, num_images, img_tiles), dtype=np.int64)
+        else:
+            self._decode_cross_attention_mask = None
+
+        return outputs
+
     def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_id=None):
         """
         Override base class prefill to handle vision processing
@@ -281,10 +354,21 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_id=None):
         if isinstance(prompt, tuple) and len(prompt) == 2:
             image_path, text_prompt = prompt

-            # Build language inputs with processor-aware vision/text integration
-            lang_inputs, vision_outputs, num_chunks = self._vision_handler.get_language_inputs_from_vision_processing(
-                image_url=image_path, query=text_prompt, prefill_seq_len=self._prefill_seq_len
-            )
+            # Check cache for vision outputs
+            cache_key = image_path if isinstance(image_path, str) else str(image_path)
+            if cache_key in self._vision_cache:
+                lang_inputs, vision_outputs, num_chunks = self._vision_cache[cache_key]
+                logger.debug(f"Using cached vision outputs for {cache_key}")
+            else:
+                # Build language inputs with processor-aware vision/text integration
+                lang_inputs, vision_outputs, num_chunks = (
+                    self._vision_handler.get_language_inputs_from_vision_processing(
+                        image_url=image_path, query=text_prompt, prefill_seq_len=self._prefill_seq_len
+                    )
+                )
+                # Cache for future use
+                self._vision_cache[cache_key] = (lang_inputs, vision_outputs, num_chunks)
+                logger.debug(f"Cached vision outputs for {cache_key}")

             # Set vision buffers in language session
             self._session.set_buffers(vision_outputs)
@@ -296,58 +380,12 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_id=None):
             max_gen_len = self._ctx_len - np.where(lang_inputs["position_ids"] != -1, 1, 0).sum(1, keepdims=True).max()
             generation_len = self._fetch_generation_len(generation_len, max_gen_len)

-            # Set the prefill output buffers
-            self._set_output_buffers(batch_size=prefill_logit_bs, sequence_length=1)
-
-            # Run prefill across chunks, updating image_idx as needed
-            outputs = None
-            chunk_image_idx = None  # track image_idx across chunks
-            for i in range(num_chunks):
-                input_ids_slice = lang_inputs["input_ids"][
-                    :, i * self._prefill_seq_len : (i + 1) * self._prefill_seq_len
-                ]
-                position_ids_slice = lang_inputs["position_ids"][
-                    ..., i * self._prefill_seq_len : (i + 1) * self._prefill_seq_len
-                ]
-                # Build minimal input set to avoid unintended buffers (e.g., batch_index) during prefill
-                chunk_inputs = {
-                    "input_ids": input_ids_slice,
-                    "position_ids": position_ids_slice,
-                    "image_idx": chunk_image_idx if chunk_image_idx is not None else np.array([[0]], dtype=np.int64),
-                }
-                if decode_batch_id is not None:
-                    chunk_inputs["batch_index"] = decode_batch_id
-                if "cross_attention_mask" in lang_inputs:
-                    chunk_inputs["cross_attention_mask"] = lang_inputs["cross_attention_mask"]
-
-                outputs = self._session.run(chunk_inputs)
-
-                # Update image_idx for next chunk if provided by model
-                if "image_idx_output" in outputs:
-                    chunk_image_idx = outputs["image_idx_output"]
-
-                if self._write_io_dir is not None:
-                    write_io_files(lang_inputs, outputs, self._write_io_dir, "prefill", "aic_batch_io", True, False)
+            # Execute chunked prefill (Optimization 3: use extracted method)
+            outputs = self._execute_chunked_prefill(lang_inputs, num_chunks, decode_batch_id, prefill_logit_bs)

             # Prepare position_ids for decode phase (next position after prefill)
             position_ids_decode = np.max(lang_inputs["position_ids"], axis=-1, keepdims=True) + 1

-            # Prepare decode-time cross_attention_mask (ones over image tiles) if available
-            if "cross_attention_mask" in lang_inputs:
-                bs, _, num_images, img_tiles = lang_inputs["cross_attention_mask"].shape
-                self._decode_cross_attention_mask = np.ones((bs, 1, num_images, img_tiles), dtype=np.int64)
-            else:
-                self._decode_cross_attention_mask = None
-
-            # Skip retained_state and past_ buffers before decode for dual-QPC coordination
-            self._session.skip_buffers(
-                [
-                    x
-                    for x in self._session.input_names + self._session.output_names
-                    if x.startswith("past_") or x.endswith("_RetainedState")
-                ]
-            )
-
             return outputs, position_ids_decode, generation_len
         else:
             # Fall back to base class for text-only
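The extracted `_execute_chunked_prefill` consumes the padded prompt in `prefill_seq_len`-sized windows. A small numpy illustration of the slicing, with hypothetical sizes:

```python
import numpy as np

prefill_seq_len = 4
input_ids = np.arange(12).reshape(1, 12)  # hypothetical prompt padded to 12 tokens
num_chunks = input_ids.shape[1] // prefill_seq_len

for i in range(num_chunks):
    # Same slice expression as the extracted method
    chunk = input_ids[:, i * prefill_seq_len : (i + 1) * prefill_seq_len]
    print(i, chunk.tolist())
# 0 [[0, 1, 2, 3]]
# 1 [[4, 5, 6, 7]]
# 2 [[8, 9, 10, 11]]
```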
@@ -404,6 +442,9 @@ def generate(
         if len(images) != len(prompts):
             raise ValueError(f"Number of images ({len(images)}) must match number of prompts ({len(prompts)})")

+        # Clear the per-image vision cache for a fresh generation
+        self._vision_cache.clear()
+
         logger.info(f"Generating for {len(images)} image-prompt pairs")

         # Convert to base class format: list of (image, prompt) tuples
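A hedged usage sketch of the updated `generate` entry point; the `generator` instance and image paths below are assumptions, not part of the diff. Because `_vision_cache` is cleared at the start of each call, repeated images within one call reuse cached vision outputs, while separate calls start cold:

```python
# Hypothetical call; image paths and the generator instance are assumptions.
generator.generate(
    images=["cat.jpg", "cat.jpg", "dog.jpg"],  # second "cat.jpg" hits _vision_cache
    prompts=["Describe the image.", "What animal is this?", "Describe the image."],
)
```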
@@ -487,8 +528,8 @@ def _generate_continuous_batching(self, vision_prompts, generation_len, stream,
         # Reset vision processing state for new generation
         self._vision_processed = False
         self._vision_outputs = None
+        self._vision_outputs_cache.clear()  # reset per-slot entries without discarding the LRU cache

-        # Use the base class continuous batching logic directly
         # Initialize decode inputs
         num_prompts = len(vision_prompts)
         execution_batch_size = self.full_batch_size
@@ -503,10 +544,39 @@ def _generate_continuous_batching(self, vision_prompts, generation_len, stream,

         start = perf_counter()

+        # Pre-process ALL vision inputs and cache them
+        logger.info("Pre-processing all vision inputs...")
+        for batch_id in range(min(self.full_batch_size, len(vision_prompts))):
+            img, prompt = vision_prompts[batch_id]
+
+            # Process vision for this slot
+            lang_inputs, vision_outputs, num_chunks = self._vision_handler.get_language_inputs_from_vision_processing(
+                image_url=img, query=prompt, prefill_seq_len=self._prefill_seq_len
+            )
+
+            # Cache vision outputs for this batch slot (Optimization 4: use LRU cache)
+            self._vision_outputs_cache[batch_id] = {
+                "vision_outputs": vision_outputs,
+                "lang_inputs": lang_inputs,
+                "num_chunks": num_chunks,
+            }
+
+            logger.debug(f"Cached vision outputs for batch_id {batch_id}")
+
+        # Reset prompt queue for prefill
+        prompt_queue = deque(vision_prompts)
+
         self.batch_index = None

-        # Run prefill for all inputs first (batch_index should NOT be set during prefill)
-        self.run_prefill_for_all_inputs(prompt_queue, generation_len)
+        # Run prefill for all inputs using cached vision
+        self.run_prefill_for_all_inputs_with_cached_vision(prompt_queue, generation_len)
+
+        # Set vision buffers for decode (use first slot's vision for now)
+        # For identical images, any slot's vision works (Optimization 4: use LRU cache)
+        cached_slot_0 = self._vision_outputs_cache.get(0)
+        if cached_slot_0:
+            self._session.set_buffers(cached_slot_0["vision_outputs"])
+            logger.debug("Set vision buffers from slot 0 for decode phase")

         # Now set batch_index for decode phase
         self.batch_index = np.arange(self.full_batch_size).reshape(-1, 1)
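The decode-phase `batch_index` is a column vector of slot indices, one row per decode slot. A quick numpy check with a hypothetical batch size:

```python
import numpy as np

full_batch_size = 4  # hypothetical
batch_index = np.arange(full_batch_size).reshape(-1, 1)
print(batch_index.shape)  # (4, 1)
print(batch_index.T)      # [[0 1 2 3]]: one slot index per decode row
```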
@@ -531,6 +601,57 @@ def _generate_continuous_batching(self, vision_prompts, generation_len, stream,
             batch_size=1, generated_texts=generated_texts, generated_ids=self.generated_ids, perf_metrics=perf_metrics
         )

+    def run_prefill_for_all_inputs_with_cached_vision(self, prompt_queue, generation_len):
+        """
+        Runs prefill for all inputs using pre-cached vision outputs.
+
+        This avoids the vision buffer overwriting issue by using cached vision
+        outputs instead of processing vision during each prefill iteration.
+
+        Args:
+            prompt_queue (deque): The queue of prompts.
+            generation_len (int): The generation length.
+        """
+        for decode_batch_id in range(self.full_batch_size):
+            # Get cached vision outputs for this batch slot (Optimization 4: use LRU cache)
+            cached = self._vision_outputs_cache.get(decode_batch_id)
+            if cached:
+                vision_outputs = cached["vision_outputs"]
+                lang_inputs = cached["lang_inputs"]
+                num_chunks = cached["num_chunks"]
+
+                # Set vision buffers for THIS prefill
+                self._session.set_buffers(vision_outputs)
+                logger.debug(f"Set vision buffers for batch_id {decode_batch_id} prefill")
+
+                # Run prefill with cached inputs (Optimization 3: use extracted method)
+                outputs = self._execute_chunked_prefill(
+                    lang_inputs,
+                    num_chunks,
+                    decode_batch_id=np.array(decode_batch_id, dtype=np.int64).reshape(1, 1),
+                    prefill_logit_bs=1,
+                )
+
+                # Calculate position_ids for decode
+                position_ids_decode = np.max(lang_inputs["position_ids"], axis=-1, keepdims=True) + 1
+
+                # Calculate generation_len
+                max_gen_len = (
+                    self._ctx_len - np.where(lang_inputs["position_ids"] != -1, 1, 0).sum(1, keepdims=True).max()
+                )
+                generation_len_final = self._fetch_generation_len(generation_len, max_gen_len)
+
+                # Update decode inputs
+                if self.is_qwen2_5_vl:
+                    self.update_decode_inputs_qwen2_5_vl(
+                        outputs, position_ids_decode, generation_len_final, decode_batch_id
+                    )
+                else:
+                    self.update_decode_input(outputs, position_ids_decode, generation_len_final, decode_batch_id)
+            else:
+                logger.error(f"No cached vision outputs for batch_id {decode_batch_id}")
+                raise RuntimeError(f"Vision outputs not cached for batch_id {decode_batch_id}")
+
     def prepare_decode_inputs(self):
         """
         Override base class to handle vision-specific decode inputs
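Both prefill paths compute the first decode position and the remaining generation budget from the padded `position_ids` (padding marked with -1). A small numpy check of those two expressions with hypothetical values:

```python
import numpy as np

ctx_len = 32  # hypothetical context length
position_ids = np.array([[0, 1, 2, 3, -1, -1]])  # hypothetical: 4 real tokens, 2 pad

# Next position to decode: one past the largest prefilled position
position_ids_decode = np.max(position_ids, axis=-1, keepdims=True) + 1
print(position_ids_decode)  # [[4]]

# Tokens already consumed from the context window
consumed = np.where(position_ids != -1, 1, 0).sum(1, keepdims=True).max()
print(ctx_len - consumed)  # 28 tokens left for generation
```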