
Commit b0ee5a4

Added caching for vision outputs
Signed-off-by: Rishin Raj <[email protected]>
1 parent ca0cc03 commit b0ee5a4

File tree

3 files changed: +212 -60 lines

QEfficient/generation/vlm_generation.py

Lines changed: 181 additions & 60 deletions
@@ -35,6 +35,7 @@
     calculate_latency,
     write_io_files,
 )
+from QEfficient.utils import LRUCache
 from QEfficient.utils.logging_utils import logger
 
 
@@ -140,6 +141,8 @@ def __init__(
         self._vision_qpc_path = vision_qpc_path
         self.device_id = device_id  # Store device_id for vision components
         self.enable_debug_logs = enable_debug_logs  # Store for vision components
+        self._vision_outputs_cache = LRUCache(max_size=100)  # LRU cache for vision outputs
+        self._vision_cache = {}  # Cache for vision outputs across batches
         self._init_vision_components()
 
         # Now that vision components are initialized, activate the text session
@@ -201,14 +204,20 @@ def _get_vision_config(self) -> Dict[str, Any]:
 
     def _setup_vision_buffer_skipping(self):
         """Skip KV cache and retained state buffers for vision session"""
-        skip_patterns = [lambda x: x.startswith("past_"), lambda x: x.endswith("_RetainedState")]
-
-        buffers_to_skip = [
+        # Pre-compute skip buffers
+        self._vision_skip_buffers = [
             x
             for x in self._vision_session.input_names + self._vision_session.output_names
-            if any(pattern(x) for pattern in skip_patterns)
+            if x.startswith("past_") or x.endswith("_RetainedState")
+        ]
+        self._vision_session.skip_buffers(self._vision_skip_buffers)
+
+        # Pre-compute language skip buffers
+        self._lang_skip_buffers = [
+            x
+            for x in self._session.input_names + self._session.output_names
+            if x.startswith("past_") or x.endswith("_RetainedState")
         ]
-        self._vision_session.skip_buffers(buffers_to_skip)
 
     def run_prefill_for_all_inputs(self, prompt_queue, generation_len):
         """
@@ -255,6 +264,70 @@ def update_decode_inputs_qwen2_5_vl(self, outputs, position_ids, generation_len,
         self.generation_len[decode_batch_id or slice(None)] = generation_len
         return next_token_id
 
+    def _execute_chunked_prefill(
+        self,
+        lang_inputs: Dict[str, np.ndarray],
+        num_chunks: int,
+        decode_batch_id: Optional[np.ndarray] = None,
+        prefill_logit_bs: int = 1,
+    ) -> Dict[str, np.ndarray]:
+        """
+        Execute chunked prefill with language inputs (Optimization 3: extracted common logic).
+
+        Args:
+            lang_inputs: Pre-processed language inputs with input_ids, position_ids, etc.
+            num_chunks: Number of chunks to process
+            decode_batch_id: Batch ID for continuous batching (optional)
+            prefill_logit_bs: Batch size for prefill logits
+
+        Returns:
+            Final prefill outputs
+        """
+        # Set output buffers
+        self._set_output_buffers(batch_size=prefill_logit_bs, sequence_length=1)
+
+        # Skip buffers for dual-QPC coordination (Optimization 2: use cached list)
+        self._session.skip_buffers(self._lang_skip_buffers)
+
+        # Run chunked prefill
+        outputs = None
+        chunk_image_idx = None
+
+        for i in range(num_chunks):
+            input_ids_slice = lang_inputs["input_ids"][:, i * self._prefill_seq_len : (i + 1) * self._prefill_seq_len]
+            position_ids_slice = lang_inputs["position_ids"][
+                ..., i * self._prefill_seq_len : (i + 1) * self._prefill_seq_len
+            ]
+
+            chunk_inputs = {
+                "input_ids": input_ids_slice,
+                "position_ids": position_ids_slice,
+                "image_idx": chunk_image_idx if chunk_image_idx is not None else np.array([[0]], dtype=np.int64),
+            }
+
+            if decode_batch_id is not None:
+                chunk_inputs["batch_index"] = decode_batch_id
+
+            if "cross_attention_mask" in lang_inputs:
+                chunk_inputs["cross_attention_mask"] = lang_inputs["cross_attention_mask"]
+
+            outputs = self._session.run(chunk_inputs)
+
+            if "image_idx_output" in outputs:
+                chunk_image_idx = outputs["image_idx_output"]
+
+            if self._write_io_dir is not None:
+                write_io_files(lang_inputs, outputs, self._write_io_dir, "prefill", "aic_batch_io", True, False)
+
+        # Prepare decode-time cross_attention_mask
+        if "cross_attention_mask" in lang_inputs:
+            bs, _, num_images, img_tiles = lang_inputs["cross_attention_mask"].shape
+            self._decode_cross_attention_mask = np.ones((bs, 1, num_images, img_tiles), dtype=np.int64)
+        else:
+            self._decode_cross_attention_mask = None
+
+        return outputs
+
     def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_id=None):
         """
         Override base class prefill to handle vision processing
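
Note: to make the chunk slicing in _execute_chunked_prefill above concrete, here is a small standalone sketch of the arithmetic (plain NumPy only; the 300-token prompt and prefill_seq_len of 128 are illustrative assumptions, and the padding step mirrors what the upstream vision handler is expected to have done before num_chunks is computed):

import numpy as np

prefill_seq_len = 128                      # assumed compile-time prefill length
input_ids = np.arange(300)[None, :]        # illustrative prompt: batch 1, 300 tokens

# Pad to a multiple of prefill_seq_len, then derive num_chunks (assumption: the
# handler has already done this before the chunked prefill loop runs).
padded_len = -(-input_ids.shape[1] // prefill_seq_len) * prefill_seq_len
input_ids = np.pad(input_ids, ((0, 0), (0, padded_len - input_ids.shape[1])))
num_chunks = padded_len // prefill_seq_len  # 3 chunks for 300 tokens at 128

for i in range(num_chunks):
    chunk = input_ids[:, i * prefill_seq_len : (i + 1) * prefill_seq_len]
    print(i, chunk.shape)  # every prefill call receives a (1, 128) slice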
@@ -281,10 +354,21 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i
         if isinstance(prompt, tuple) and len(prompt) == 2:
             image_path, text_prompt = prompt
 
-            # Build language inputs with processor-aware vision/text integration
-            lang_inputs, vision_outputs, num_chunks = self._vision_handler.get_language_inputs_from_vision_processing(
-                image_url=image_path, query=text_prompt, prefill_seq_len=self._prefill_seq_len
-            )
+            # Check cache for vision outputs
+            cache_key = image_path if isinstance(image_path, str) else str(image_path)
+            if cache_key in self._vision_cache:
+                lang_inputs, vision_outputs, num_chunks = self._vision_cache[cache_key]
+                logger.debug(f"Using cached vision outputs for {cache_key}")
+            else:
+                # Build language inputs with processor-aware vision/text integration
+                lang_inputs, vision_outputs, num_chunks = (
+                    self._vision_handler.get_language_inputs_from_vision_processing(
+                        image_url=image_path, query=text_prompt, prefill_seq_len=self._prefill_seq_len
+                    )
+                )
+                # Cache for future use
+                self._vision_cache[cache_key] = (lang_inputs, vision_outputs, num_chunks)
+                logger.debug(f"Cached vision outputs for {cache_key}")
 
             # Set vision buffers in language session
             self._session.set_buffers(vision_outputs)
@@ -296,58 +380,12 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i
             max_gen_len = self._ctx_len - np.where(lang_inputs["position_ids"] != -1, 1, 0).sum(1, keepdims=True).max()
             generation_len = self._fetch_generation_len(generation_len, max_gen_len)
 
-            # Set the prefill output buffers
-            self._set_output_buffers(batch_size=prefill_logit_bs, sequence_length=1)
-
-            # Run prefill across chunks, updating image_idx as needed
-            outputs = None
-            chunk_image_idx = None  # track image_idx across chunks
-            for i in range(num_chunks):
-                input_ids_slice = lang_inputs["input_ids"][
-                    :, i * self._prefill_seq_len : (i + 1) * self._prefill_seq_len
-                ]
-                position_ids_slice = lang_inputs["position_ids"][
-                    ..., i * self._prefill_seq_len : (i + 1) * self._prefill_seq_len
-                ]
-                # Build minimal input set to avoid unintended buffers (e.g., batch_index) during prefill
-                chunk_inputs = {
-                    "input_ids": input_ids_slice,
-                    "position_ids": position_ids_slice,
-                    "image_idx": chunk_image_idx if chunk_image_idx is not None else np.array([[0]], dtype=np.int64),
-                }
-                if decode_batch_id is not None:
-                    chunk_inputs["batch_index"] = decode_batch_id
-                if "cross_attention_mask" in lang_inputs:
-                    chunk_inputs["cross_attention_mask"] = lang_inputs["cross_attention_mask"]
-
-                outputs = self._session.run(chunk_inputs)
-
-                # Update image_idx for next chunk if provided by model
-                if "image_idx_output" in outputs:
-                    chunk_image_idx = outputs["image_idx_output"]
-
-                if self._write_io_dir is not None:
-                    write_io_files(lang_inputs, outputs, self._write_io_dir, "prefill", "aic_batch_io", True, False)
+            # Execute chunked prefill (Optimization 3: use extracted method)
+            outputs = self._execute_chunked_prefill(lang_inputs, num_chunks, decode_batch_id, prefill_logit_bs)
 
             # Prepare position_ids for decode phase (next position after prefill)
             position_ids_decode = np.max(lang_inputs["position_ids"], axis=-1, keepdims=True) + 1
 
-            # Prepare decode-time cross_attention_mask (ones over image tiles) if available
-            if "cross_attention_mask" in lang_inputs:
-                bs, _, num_images, img_tiles = lang_inputs["cross_attention_mask"].shape
-                self._decode_cross_attention_mask = np.ones((bs, 1, num_images, img_tiles), dtype=np.int64)
-            else:
-                self._decode_cross_attention_mask = None
-
-            # Skip retained_state and past_ buffers before decode for dual-QPC coordination
-            self._session.skip_buffers(
-                [
-                    x
-                    for x in self._session.input_names + self._session.output_names
-                    if x.startswith("past_") or x.endswith("_RetainedState")
-                ]
-            )
-
             return outputs, position_ids_decode, generation_len
         else:
             # Fall back to base class for text-only
@@ -404,6 +442,9 @@ def generate(
         if len(images) != len(prompts):
             raise ValueError(f"Number of images ({len(images)}) must match number of prompts ({len(prompts)})")
 
+        # Clear vision cache for fresh generation
+        self._vision_cache.clear()
+
         logger.info(f"Generating for {len(images)} image-prompt pairs")
 
         # Convert to base class format: list of (image, prompt) tuples
@@ -487,8 +528,8 @@ def _generate_continuous_batching(self, vision_prompts, generation_len, stream,
         # Reset vision processing state for new generation
         self._vision_processed = False
         self._vision_outputs = None
+        self._vision_outputs_cache = {}
 
-        # Use the base class continuous batching logic directly
         # Initialize decode inputs
         num_prompts = len(vision_prompts)
         execution_batch_size = self.full_batch_size
@@ -503,10 +544,39 @@ def _generate_continuous_batching(self, vision_prompts, generation_len, stream,
 
         start = perf_counter()
 
+        # Pre-process ALL vision inputs and cache them
+        logger.info("Pre-processing all vision inputs...")
+        for batch_id in range(min(self.full_batch_size, len(vision_prompts))):
+            img, prompt = vision_prompts[batch_id]
+
+            # Process vision for this slot
+            lang_inputs, vision_outputs, num_chunks = self._vision_handler.get_language_inputs_from_vision_processing(
+                image_url=img, query=prompt, prefill_seq_len=self._prefill_seq_len
+            )
+
+            # Cache vision outputs for this batch slot (Optimization 4: use LRU cache)
+            self._vision_outputs_cache[batch_id] = {
+                "vision_outputs": vision_outputs,
+                "lang_inputs": lang_inputs,
+                "num_chunks": num_chunks,
+            }
+
+            logger.debug(f"Cached vision outputs for batch_id {batch_id}")
+
+        # Reset prompt queue for prefill
+        prompt_queue = deque(vision_prompts)
+
         self.batch_index = None
 
-        # Run prefill for all inputs first (batch_index should NOT be set during prefill)
-        self.run_prefill_for_all_inputs(prompt_queue, generation_len)
+        # Run prefill for all inputs using cached vision
+        self.run_prefill_for_all_inputs_with_cached_vision(prompt_queue, generation_len)
+
+        # Set vision buffers for decode (use first slot's vision for now)
+        # For identical images, any slot's vision works (Optimization 4: use LRU cache)
+        cached_slot_0 = self._vision_outputs_cache.get(0)
+        if cached_slot_0:
+            self._session.set_buffers(cached_slot_0["vision_outputs"])
+            logger.debug("Set vision buffers from slot 0 for decode phase")
 
         # Now set batch_index for decode phase
         self.batch_index = np.arange(self.full_batch_size).reshape(-1, 1)
@@ -531,6 +601,57 @@ def _generate_continuous_batching(self, vision_prompts, generation_len, stream,
             batch_size=1, generated_texts=generated_texts, generated_ids=self.generated_ids, perf_metrics=perf_metrics
         )
 
+    def run_prefill_for_all_inputs_with_cached_vision(self, prompt_queue, generation_len):
+        """
+        Runs prefill for all inputs using pre-cached vision outputs.
+
+        This avoids the vision buffer overwriting issue by using cached vision
+        outputs instead of processing vision during each prefill iteration.
+
+        Args:
+            prompt_queue (deque): The queue of prompts.
+            generation_len (int): The generation length.
+        """
+        for decode_batch_id in range(self.full_batch_size):
+            # Get cached vision outputs for this batch slot (Optimization 4: use LRU cache)
+            cached = self._vision_outputs_cache.get(decode_batch_id)
+            if cached:
+                vision_outputs = cached["vision_outputs"]
+                lang_inputs = cached["lang_inputs"]
+                num_chunks = cached["num_chunks"]
+
+                # Set vision buffers for THIS prefill
+                self._session.set_buffers(vision_outputs)
+                logger.debug(f"Set vision buffers for batch_id {decode_batch_id} prefill")
+
+                # Run prefill with cached inputs (Optimization 3: use extracted method)
+                outputs = self._execute_chunked_prefill(
+                    lang_inputs,
+                    num_chunks,
+                    decode_batch_id=np.array(decode_batch_id, dtype=np.int64).reshape(1, 1),
+                    prefill_logit_bs=1,
+                )
+
+                # Calculate position_ids for decode
+                position_ids_decode = np.max(lang_inputs["position_ids"], axis=-1, keepdims=True) + 1
+
+                # Calculate generation_len
+                max_gen_len = (
+                    self._ctx_len - np.where(lang_inputs["position_ids"] != -1, 1, 0).sum(1, keepdims=True).max()
+                )
+                generation_len_final = self._fetch_generation_len(generation_len, max_gen_len)
+
+                # Update decode inputs
+                if self.is_qwen2_5_vl:
+                    self.update_decode_inputs_qwen2_5_vl(
+                        outputs, position_ids_decode, generation_len_final, decode_batch_id
+                    )
+                else:
+                    self.update_decode_input(outputs, position_ids_decode, generation_len_final, decode_batch_id)
+            else:
+                logger.error(f"No cached vision outputs for batch_id {decode_batch_id}")
+                raise RuntimeError(f"Vision outputs not cached for batch_id {decode_batch_id}")
+
     def prepare_decode_inputs(self):
         """
         Override base class to handle vision-specific decode inputs
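
For orientation, the continuous-batching changes above reduce to: run the vision encoder once per batch slot, cache its outputs keyed by slot id, then replay the cached outputs during each slot's prefill so later slots cannot clobber buffers an earlier slot still needs. A minimal standalone sketch of that pattern follows; process_vision and prefill_from_cache are hypothetical stand-ins, not QEfficient APIs.

from typing import Any, Dict, List, Tuple


def process_vision(image: str, prompt: str) -> Dict[str, Any]:
    # Hypothetical stand-in for the expensive vision pass; run exactly once per slot.
    return {"vision_outputs": f"embeddings({image})", "lang_inputs": prompt, "num_chunks": 1}


def prefill_from_cache(slot: int, cached: Dict[str, Any]) -> None:
    # Hypothetical stand-in for setting vision buffers and running chunked prefill.
    print(f"slot {slot}: prefill using {cached['vision_outputs']}")


def prefill_all(vision_prompts: List[Tuple[str, str]], full_batch_size: int) -> None:
    # Phase 1: pre-process and cache vision outputs per batch slot.
    cache: Dict[int, Dict[str, Any]] = {}
    for slot in range(min(full_batch_size, len(vision_prompts))):
        img, prompt = vision_prompts[slot]
        cache[slot] = process_vision(img, prompt)

    # Phase 2: prefill every slot from the cache; a missing entry is an error,
    # mirroring the RuntimeError raised in run_prefill_for_all_inputs_with_cached_vision.
    for slot in range(full_batch_size):
        cached = cache.get(slot)
        if cached is None:
            raise RuntimeError(f"Vision outputs not cached for slot {slot}")
        prefill_from_cache(slot, cached)


prefill_all([("cat.png", "Describe the image"), ("dog.png", "What breed is this?")], full_batch_size=2)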

QEfficient/utils/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -10,6 +10,7 @@
     undo_transformers_quantizers,
 )
 from QEfficient.utils._utils import (  # noqa: F401
+    LRUCache,
     check_and_assign_cache_dir,
     create_json,
     create_model_params,

QEfficient/utils/_utils.py

Lines changed: 30 additions & 0 deletions
@@ -33,6 +33,36 @@
 from QEfficient.utils.logging_utils import logger
 
 
+class LRUCache:
+    """Simple LRU cache with size limit for vision outputs"""
+
+    def __init__(self, max_size=100):
+        self._cache = {}
+        self._access_order = []
+        self._max_size = max_size
+
+    def get(self, key):
+        if key in self._cache:
+            self._access_order.remove(key)
+            self._access_order.append(key)
+            return self._cache[key]
+        return None
+
+    def put(self, key, value):
+        if key in self._cache:
+            self._access_order.remove(key)
+        elif len(self._cache) >= self._max_size:
+            oldest = self._access_order.pop(0)
+            del self._cache[oldest]
+
+        self._cache[key] = value
+        self._access_order.append(key)
+
+    def clear(self):
+        self._cache.clear()
+        self._access_order.clear()
+
+
 class DownloadRetryLimitExceeded(Exception):
     """
     Used for raising error when hf_download fails to download the model after given max_retries.
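
A quick usage sketch of the LRUCache added above (it assumes this commit is installed so the class is importable from QEfficient.utils; the values are placeholder strings rather than real vision outputs): put inserts or refreshes a key, get marks it most recently used, and inserting past max_size evicts the least recently used entry.

from QEfficient.utils import LRUCache  # exported by this commit via QEfficient/utils/__init__.py

cache = LRUCache(max_size=2)
cache.put("img_a.png", "vision outputs for A")
cache.put("img_b.png", "vision outputs for B")

# Touching img_a.png marks it as most recently used.
assert cache.get("img_a.png") == "vision outputs for A"

# Adding a third entry evicts the least recently used key, img_b.png.
cache.put("img_c.png", "vision outputs for C")
assert cache.get("img_b.png") is None
assert cache.get("img_c.png") == "vision outputs for C"

# clear() empties both the mapping and the access-order list.
cache.clear()
assert cache.get("img_a.png") is None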
