From 1ee373eb2b244623f40a6af56f22362bb5af21c8 Mon Sep 17 00:00:00 2001 From: Mamta Singh Date: Sun, 5 Oct 2025 21:22:20 +0000 Subject: [PATCH 01/24] Enable CB for vlms with multiple images and multiple prompts Signed-off-by: Mamta Singh Signed-off-by: Rishin Raj --- .../generation/text_generation_inference.py | 272 +++++++++++++----- .../models/llama4/modeling_llama4.py | 98 +++++-- .../transformers/models/modeling_auto.py | 72 +++-- examples/llama4_CB_example_vision_lang.py | 65 +++++ 4 files changed, 393 insertions(+), 114 deletions(-) create mode 100644 examples/llama4_CB_example_vision_lang.py diff --git a/QEfficient/generation/text_generation_inference.py b/QEfficient/generation/text_generation_inference.py index 6d04cf573..b5079eac6 100755 --- a/QEfficient/generation/text_generation_inference.py +++ b/QEfficient/generation/text_generation_inference.py @@ -13,6 +13,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np +import torch import transformers from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast @@ -313,7 +314,10 @@ def calculate_latency(total_decoded_tokens, loop_start, start, end, decode_pause def cloud_ai_100_exec_kv( tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], - qpc_path: str, + processor, + lang_qpc_path: str, + vision_qpc_path: str, + images, prompt: Optional[str] = None, prompts_txt_file_path: Optional[str] = None, device_id: Optional[List[int]] = None, @@ -372,7 +376,7 @@ def cloud_ai_100_exec_kv( exec_info = QEfficient.cloud_ai_100_exec_kv(tokenizer=tokenizer, qpc_path=qpc_path, prompt="Hi there!!", device_id=[0]) """ - batch_size, ctx_len, full_batch_size = get_compilation_dims(qpc_path) + batch_size, ctx_len, full_batch_size = get_compilation_dims(lang_qpc_path) prompt: List[str] = get_input_prompts(prompt, prompts_txt_file_path) prompt = fix_prompts(prompt, batch_size, full_batch_size) if prompt_to_lora_id_mapping is not None: @@ -381,7 +385,9 @@ def cloud_ai_100_exec_kv( ) generate_text = TextGeneration( tokenizer=tokenizer, - qpc_path=qpc_path, + processor=processor, + lang_qpc_path=lang_qpc_path, + vision_qpc_path=vision_qpc_path, device_id=device_id, ctx_len=ctx_len, enable_debug_logs=enable_debug_logs, @@ -393,18 +399,19 @@ def cloud_ai_100_exec_kv( sampling_params=sampling_params, ) - for _ in range(0, int(iteration)): - if full_batch_size is None: - exec_info = [ - generate_text.generate(prompt[i : i + batch_size], generation_len, stream, prompt_to_lora_id_mapping) - for i in range(0, len(prompt), batch_size) - ] - prefill_time = np.average([info.perf_metrics.prefill_time for info in exec_info]) - decode_perf = np.average([info.perf_metrics.decode_perf for info in exec_info]) - total_perf = np.average([info.perf_metrics.total_perf for info in exec_info]) - total_time = np.average([info.perf_metrics.total_time for info in exec_info]) - generated_texts = [info.generated_texts for info in exec_info] - generated_ids = [info.generated_ids for info in exec_info] + exec_info = CloudAI100ExecInfo( + batch_size=batch_size, + generated_texts=generated_texts, + generated_ids=generated_ids, + perf_metrics=PerfMetrics(prefill_time, decode_perf, total_perf, total_time), + ) + else: + exec_info = generate_text.generate( + images=images, + prompt=prompt, + generation_len=generation_len, + prompt_to_lora_id_mapping=prompt_to_lora_id_mapping, + ) exec_info = CloudAI100ExecInfo( batch_size=batch_size, @@ -426,8 +433,10 @@ def cloud_ai_100_exec_kv( class QEffTextGenerationBase: def __init__( self, + processor, 
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], - qpc_path: str, + lang_qpc_path: str, + vision_qpc_path: Optional[str] = None, full_batch_size: Optional[int] = None, ctx_len: Optional[int] = None, device_id: Optional[List[int]] = None, @@ -445,11 +454,15 @@ def __init__( self.sampling_params = sampling_params # Load QPC - self._session = QAICInferenceSession(qpc_path, device_id, enable_debug_logs=enable_debug_logs) + if not lang_qpc_path: + raise TypeError("Please run compile API for language model first!") + self._lang_session = QAICInferenceSession(lang_qpc_path, device_id, activate=False) + if vision_qpc_path: + self._vision_session = QAICInferenceSession(vision_qpc_path, device_id) # Validate sampler inputs for On-Device Sampling self.include_sampler = validate_sampler_inputs( - session_inputs=set(self._session.input_names), include_sampler=include_sampler + session_inputs=set(self._lang_session.input_names), include_sampler=include_sampler ) # Fetch the variables from the QPC @@ -474,10 +487,23 @@ def __init__( self.generation_len = None self.tokenizer = tokenizer + self.processor = processor self._set_tokenizer_params() # set tokenizer params # Skip inputs/outputs - self._session.skip_buffers( - [x for x in self._session.input_names + self._session.output_names if x.startswith("past_")] + if self._vision_session: + self._vision_session.skip_buffers( + [ + x + for x in self._vision_session.input_names + self._vision_session.output_names + if x.startswith("past_") or x.endswith("_RetainedState") + ] + ) + self._lang_session.skip_buffers( + [ + x + for x in self._lang_session.input_names + self._lang_session.output_names + if x.startswith("past_") or x.endswith("_RetainedState") + ] ) def _set_tokenizer_params(self): @@ -502,13 +528,16 @@ def _fetch_full_batch_size( """ full_batch_size = None - if "batch_index" in self._session.binding_index_map: - if self._session.allowed_shapes: + if "batch_index" in self._lang_session.binding_index_map: + if self._lang_session.allowed_shapes: full_batch_size, _ = [ - x[self._session.binding_index_map["batch_index"]][1][0] for x in self._session.allowed_shapes + x[self._lang_session.binding_index_map["batch_index"]][1][0] + for x in self._lang_session.allowed_shapes ] else: - full_batch_size, _ = self._session.bindings[self._session.binding_index_map["batch_index"]].dims + full_batch_size, _ = self._lang_session.bindings[ + self._lang_session.binding_index_map["batch_index"] + ].dims return full_batch_size def _fetch_batch_size_prefill_seq_len( @@ -521,15 +550,17 @@ def _fetch_batch_size_prefill_seq_len( batch_size: The batch size fetched from the session's bindings or allowed shapes. prefill_seq_len: The prefill sequence length fetched from the session's bindings or allowed shapes. 
""" - if self._session.allowed_shapes: + if self._lang_session.allowed_shapes: batch_size = max( - [x[self._session.binding_index_map["input_ids"]][1][0] for x in self._session.allowed_shapes] + [x[self._lang_session.binding_index_map["input_ids"]][1][0] for x in self._lang_session.allowed_shapes] ) prefill_seq_len = max( - [x[self._session.binding_index_map["input_ids"]][1][1] for x in self._session.allowed_shapes] + [x[self._lang_session.binding_index_map["input_ids"]][1][1] for x in self._lang_session.allowed_shapes] ) else: - batch_size, prefill_seq_len = self._session.bindings[self._session.binding_index_map["input_ids"]].dims + batch_size, prefill_seq_len = self._lang_session.bindings[ + self._lang_session.binding_index_map["input_ids"] + ].dims return batch_size, prefill_seq_len def _fetch_decode_seq_len( @@ -542,9 +573,9 @@ def _fetch_decode_seq_len( decode_seq_len: The decode sequence length fetched from the session's bindings or allowed shapes. """ decode_seq_len = None - if self._session.allowed_shapes: + if self._lang_session.allowed_shapes: decode_seq_len = min( - [x[self._session.binding_index_map["input_ids"]][1][1] for x in self._session.allowed_shapes] + [x[self._lang_session.binding_index_map["input_ids"]][1][1] for x in self._lang_session.allowed_shapes] ) return decode_seq_len @@ -563,10 +594,10 @@ def _fetch_vocab_size( if self.include_sampler else "logits" ) - if self._session.allowed_shapes: - return [x[self._session.binding_index_map[key]] for x in self._session.allowed_shapes][0][1][2] + if self._lang_session.allowed_shapes: + return [x[self._lang_session.binding_index_map[key]] for x in self._lang_session.allowed_shapes][0][1][2] - return self._session.bindings[self._session.binding_index_map[key]].dims[2] + return self._lang_session.bindings[self._lang_session.binding_index_map[key]].dims[2] def _fetch_generation_len(self, generation_len, max_gen_len): """ @@ -655,7 +686,7 @@ def _fetch_next_token_id(self, outputs): logits = np.expand_dims(logits, 1) return logits.argmax(2) - def initialize_decode_inputs(self, num_prompts, execution_batch_size, max_gen_length): + def initialize_decode_inputs(self, num_images, num_prompts, execution_batch_size, max_gen_length): """ Initialize np arrays for storing the prefill output for all the decode batch size. """ @@ -702,7 +733,7 @@ def update_decode_input(self, outputs, position_ids, generation_len, decode_batc self.generation_len[decode_batch_id or slice(None)] = generation_len return next_token_id - def run_prefill_for_all_inputs(self, prompt_queue, generation_len): + def run_prefill_for_all_inputs(self, image_queue, prompt_queue, generation_len): """ Runs prefill for all inputs in the prompt queue and updates the decode input. 
@@ -715,10 +746,14 @@ def run_prefill_for_all_inputs(self, prompt_queue, generation_len): """ for decode_batch_id in range(self.full_batch_size): next_prompt = prompt_queue.popleft() + next_image = image_queue.popleft() # run prefill for num_chunks outputs, position_ids, generation_len = self.run_prefill( - next_prompt, generation_len, decode_batch_id=np.array(decode_batch_id, dtype=np.int64).reshape(1, 1) + next_image, + next_prompt, + generation_len, + decode_batch_id=np.array(decode_batch_id, dtype=np.int64).reshape(1, 1), ) _ = self.update_decode_input(outputs, position_ids, generation_len, decode_batch_id) @@ -733,14 +768,39 @@ def _set_output_buffers(self, batch_size: int = 1, sequence_length: int = 1): if self.include_sampler: if self.return_pdfs: probs_out_placeholder = np.zeros((batch_size, sequence_length, self._vocab_size), dtype=np.float32) - self._session.set_buffers({"probs": probs_out_placeholder}) + self._lang_session.set_buffers({"probs": probs_out_placeholder}) next_tokens_out_placeholder = np.zeros((batch_size, sequence_length, 1), dtype=np.int64) - self._session.set_buffers({"next_tokens": next_tokens_out_placeholder}) + self._lang_session.set_buffers({"next_tokens": next_tokens_out_placeholder}) else: logits_out_placeholder = np.zeros((batch_size, sequence_length, self._vocab_size), dtype=np.float32) - self._session.set_buffers({"logits": logits_out_placeholder}) + self._lang_session.set_buffers({"logits": logits_out_placeholder}) + + vision_embeds_out_placeholder = np.zeros((2448, 5120), dtype=np.float16) + self._vision_session.set_buffers({"vision_embeds": vision_embeds_out_placeholder}) + + def prepare_vision_language_inputs(self, prompt, image_url): + messages = [ + { + "role": "user", + "content": [ + {"type": "image", "url": image_url}, + {"type": "text", "text": prompt}, + ], + }, + ] + inputs = self.processor.apply_chat_template( + messages, + add_generation_prompt=True, + tokenize=True, + return_dict=True, + return_tensors="pt", + # padding="max_length", + # max_length=padded_len, + ) + inputs["pixel_values"] = inputs["pixel_values"].to(torch.float32) + return inputs - def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_id=None): + def run_prefill(self, image, prompt, generation_len, prefill_logit_bs=1, decode_batch_id=None): """ Runs prefill for a given prompt and generation length. @@ -758,7 +818,8 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i generation_len (int): The generation length. 
""" # Run prefill - inputs = self.tokenizer(prompt, return_tensors="np", padding=True) + inputs = self.prepare_vision_language_inputs(prompt, image) + position_ids = inputs["attention_mask"].sum(1, keepdims=True) padded_len = inputs["input_ids"].shape[1] num_chunks = -(padded_len // -self._prefill_seq_len) # ceil divide without float @@ -772,51 +833,110 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i # Set the prefill output buffers self._set_output_buffers(batch_size=prefill_logit_bs, sequence_length=1) - inputs = self.tokenizer(prompt, return_tensors="np", padding="max_length", max_length=padded_len) - inputs["position_ids"] = np.where(inputs.pop("attention_mask"), np.arange(padded_len), -1) - inputs.pop("token_type_ids", None) + pad_token_id = 1 + input_ids_length = inputs["input_ids"].shape[1] + num_chunks = -(input_ids_length // -self._prefill_seq_len) # ceil divide without float + padded_len = num_chunks * self._prefill_seq_len # Convert to a multiple of prompt_len + + inputs["input_ids"] = torch.nn.functional.pad( + inputs["input_ids"], + (0, padded_len - input_ids_length), + "constant", + pad_token_id, + ) + inputs["attention_mask"] = torch.nn.functional.pad( + inputs["attention_mask"], (0, padded_len - input_ids_length), "constant", 0 + ) + if "cross_attention_mask" in inputs: + inputs["cross_attention_mask"] = torch.nn.functional.pad( + inputs["cross_attention_mask"], (0, 0, 0, 0, 0, padded_len - input_ids_length) + ) + + for k, v in inputs.items(): + inputs[k] = np.array(v) + + vision_inputs = { + k: v for k, v in inputs.items() if k in {"pixel_values", "aspect_ratio_ids", "aspect_ratio_mask"} + } + if vision_inputs: + vision_inputs["pixel_values"] = vision_inputs["pixel_values"].astype("float16") + + vision_outputs = {} + if self._vision_session: + self._vision_session.activate() + # Run vision prefill + if vision_inputs: + vision_outputs = self._vision_session.run(vision_inputs) + + lang_inputs = {k: v for k, v in inputs.items() if k not in vision_inputs} + lang_inputs["position_ids"] = np.where( + lang_inputs.pop("attention_mask"), np.arange(padded_len), -1 + ) # Need to use -1 as position_ids for invalid tokens + + # not_mllama = hasattr(self.model.config, "model_type") and self.model.config.model_type != "mllama" + # if not_mllama: + lang_inputs["image_idx"] = np.array([[0]]) + + if self._vision_session: + self._vision_session.deactivate() + self._lang_session.activate() + self._lang_session.set_buffers(vision_outputs) if decode_batch_id is not None: - inputs["batch_index"] = decode_batch_id + lang_inputs["batch_index"] = decode_batch_id if self.is_tlm: - inputs["num_logits_to_keep"] = np.zeros((1, 1)) + lang_inputs["num_logits_to_keep"] = np.zeros((1, 1)) if self.include_sampler: - inputs["last_accepted_output_tokens"] = inputs["input_ids"] + lang_inputs["last_accepted_output_tokens"] = lang_inputs["input_ids"] for op in Constants.SAMPLER_OPS: if decode_batch_id is not None: - inputs[op] = self.sampling_params[op][decode_batch_id.flatten()] + lang_inputs[op] = self.sampling_params[op][decode_batch_id.flatten()] else: - inputs[op] = self.sampling_params[op] + lang_inputs[op] = self.sampling_params[op] if self._prompt_to_lora_id_mapping_prefill: if self.full_batch_size: - inputs["lora_ids"] = np.array( + lang_inputs["lora_ids"] = np.array( self._prompt_to_lora_id_mapping_prefill.popleft(), dtype=np.int64 ).reshape(1, 1) else: batch_lora_ids = [self._prompt_to_lora_id_mapping_prefill.popleft() for i in range(self.batch_size)] - 
inputs["lora_ids"] = np.array(batch_lora_ids, dtype=np.int64).reshape(self.batch_size, 1) + lang_inputs["lora_ids"] = np.array(batch_lora_ids, dtype=np.int64).reshape(self.batch_size, 1) + # Run language prefill + chunk_inputs = lang_inputs.copy() for i in range(num_chunks): - chunk_inputs = inputs.copy() - chunk_inputs["input_ids"] = inputs["input_ids"][ + chunk_inputs = lang_inputs.copy() + chunk_inputs["input_ids"] = lang_inputs["input_ids"][ :, i * self._prefill_seq_len : (i + 1) * self._prefill_seq_len ] - chunk_inputs["position_ids"] = inputs["position_ids"][ + chunk_inputs["position_ids"] = lang_inputs["position_ids"][ :, i * self._prefill_seq_len : (i + 1) * self._prefill_seq_len ] if self.include_sampler: chunk_inputs["last_accepted_output_tokens"] = chunk_inputs["input_ids"] - outputs = self._session.run(chunk_inputs) + outputs = self._lang_session.run(chunk_inputs) + chunk_inputs["image_idx"] = outputs["image_idx_output"] if self._write_io_dir is not None: write_io_files(inputs, outputs, self._write_io_dir, "prefill", "aic_batch_io", True, False) + + # Skip inputs/outputs again + self._lang_session.skip_buffers( + [ + x + for x in self._lang_session.input_names + self._lang_session.output_names + if x.startswith("past_") or x.endswith("_RetainedState") + ] + ) + if self._lang_session: + self._lang_session.deactivate() return ( outputs, position_ids, generation_len, ) - def run_continuous_batching_decode(self, prompt_queue, generation_len): + def run_continuous_batching_decode(self, image_queue, prompt_queue, generation_len): """ Runs continuous batching decode for the given prompt queue and generation length. @@ -848,7 +968,8 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len): decode_inputs = self.prepare_decode_inputs() while prompt_queue or current_decode_ongoing.any(): - outputs = self._session.run(decode_inputs) + self._lang_session.activate() + outputs = self._lang_session.run(decode_inputs) # Prepare inputs for next iteration next_token_id = self._fetch_next_token_id(outputs) @@ -898,6 +1019,8 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len): generated_id_current_index[decode_batch_id] += 1 + self._lang_session.deactivate() + return decode_pause_time def run_decode( @@ -919,13 +1042,13 @@ def run_decode( logits_out_placeholder = np.zeros( (self.batch_size, self._decode_seq_len, self._vocab_size), dtype=np.float32 ) - self._session.set_buffers({"logits": logits_out_placeholder}) + self._lang_session.set_buffers({"logits": logits_out_placeholder}) finished_sequences = decode_inputs["input_ids"] == self.tokenizer.eos_token_id num_token = 0 for num_token in range(1, generation_len): if streamer: streamer.put(decode_inputs["input_ids"][0]) - outputs = self._session.run(decode_inputs) + outputs = self._lang_session.run(decode_inputs) if self._write_io_dir is not None: write_io_files(decode_inputs, outputs, self._write_io_dir, "decode", "aic_batch_io", True, False) @@ -959,7 +1082,7 @@ def generate_decode_stream(self, decode_inputs, generation_len, automation): finished_sequences = decode_inputs["input_ids"] == self.tokenizer.eos_token_id for num_token in range(1, generation_len): yield decode_inputs["input_ids"] - outputs = self._session.run(decode_inputs) + outputs = self._lang_session.run(decode_inputs) if self._write_io_dir is not None: write_io_files(decode_inputs, outputs, self._write_io_dir, "decode", "aic_batch_io", True, False) @@ -979,8 +1102,10 @@ def generate_decode_stream(self, decode_inputs, generation_len, 
automation): class TextGeneration: def __init__( self, + processor, tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], - qpc_path: str, + vision_qpc_path: str, + lang_qpc_path: str, full_batch_size: Optional[int] = None, ctx_len: Optional[int] = None, device_id: Optional[List[int]] = None, @@ -992,8 +1117,10 @@ def __init__( sampling_params: Optional[Dict[str, Any]] = None, ) -> None: self._qaic_model = QEffTextGenerationBase( + processor, tokenizer=tokenizer, - qpc_path=qpc_path, + lang_qpc_path=lang_qpc_path, + vision_qpc_path=vision_qpc_path, full_batch_size=full_batch_size, ctx_len=ctx_len, device_id=device_id, @@ -1006,9 +1133,11 @@ def __init__( ) self._full_batch_size = self._qaic_model.full_batch_size self._tokenizer = self._qaic_model.tokenizer + self._processor = self._qaic_model.processor self._ctx_len = ctx_len self._perf_metrics = None self._prompt_queue = None + self._image_queue = None self._text_streamer = None @property @@ -1017,6 +1146,7 @@ def perf_metrics(self): def _setup_model_execution_inputs( self, + images, prompt: List[str], generation_len: Optional[int] = None, prompt_to_lora_id_mapping: Optional[List[int]] = None, @@ -1035,13 +1165,15 @@ def _setup_model_execution_inputs( # Create a prompt queue. self._prompt_queue = deque(prompt) + self._image_queue = deque(images) # Initialize np arrays for storing the prefill output for all the decode batch size. num_prompts = len(self._prompt_queue) + num_images = len(self._image_queue) if prompt_to_lora_id_mapping: self._qaic_model.initialize_lora_id_mapping(prompt_to_lora_id_mapping) - self._qaic_model.initialize_decode_inputs(num_prompts, execution_batch_size, max_gen_length) + self._qaic_model.initialize_decode_inputs(num_images, num_prompts, execution_batch_size, max_gen_length) def _regular_model_execution( self, @@ -1089,6 +1221,7 @@ def _regular_model_execution( def _continuous_batching_execution( self, + images, prompt: List[str], generation_len: Optional[int] = None, prompt_to_lora_id_mapping: Optional[List[int]] = None, @@ -1105,13 +1238,17 @@ def _continuous_batching_execution( Returns: :tuple: A tuple containing performance metrics and generated texts. 
""" - self._setup_model_execution_inputs(prompt, generation_len, prompt_to_lora_id_mapping) + self._setup_model_execution_inputs(images, prompt, generation_len, prompt_to_lora_id_mapping) self._qaic_model.batch_index = np.arange(self._full_batch_size).reshape(-1, 1) start = perf_counter() - self._qaic_model.run_prefill_for_all_inputs(self._prompt_queue, generation_len) + self._qaic_model.run_prefill_for_all_inputs(self._image_queue, self._prompt_queue, generation_len) + + print("\n\n\n\n Prefill for all inputs completed\n\n\n\n") loop_start = perf_counter() # Start decode loop timer - decode_pause_time = self._qaic_model.run_continuous_batching_decode(self._prompt_queue, generation_len) + decode_pause_time = self._qaic_model.run_continuous_batching_decode( + self._image_queue, self._prompt_queue, generation_len + ) end = perf_counter() generated_texts = self._tokenizer.batch_decode(self._qaic_model.generated_ids, skip_special_tokens=True) @@ -1176,6 +1313,7 @@ def generate_stream_tokens( def generate( self, + images, prompt: List[str], generation_len: Optional[int] = None, stream: bool = True, @@ -1197,7 +1335,7 @@ def generate( if self._full_batch_size is not None: logger.warning("Streamer is currently unavailable for continuous batch execution.") perf_metrics, generated_texts = self._continuous_batching_execution( - prompt, generation_len, prompt_to_lora_id_mapping + images, prompt, generation_len, prompt_to_lora_id_mapping ) else: if stream: diff --git a/QEfficient/transformers/models/llama4/modeling_llama4.py b/QEfficient/transformers/models/llama4/modeling_llama4.py index 212fe16ae..b7b951101 100644 --- a/QEfficient/transformers/models/llama4/modeling_llama4.py +++ b/QEfficient/transformers/models/llama4/modeling_llama4.py @@ -820,7 +820,7 @@ def forward(self, pixel_values): ) vision_flat = image_features.view(-1, image_features.size(-1)) projected_vision_flat = self.model.multi_modal_projector(vision_flat) - return projected_vision_flat + return projected_vision_flat # , pixel_values # This wrapper utilizes the 'vision_embeds', which contains vision embeddings, and an 'image_idx' index starting at 0. 
@@ -836,7 +836,15 @@ def __init__(self, model): self.language_model = self.model.language_model self.config = self.model.config - def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_values): + def forward( + self, + input_ids, + vision_embeds, + position_ids, + image_idx, + past_key_values, + batch_index: Optional[torch.LongTensor] = None, + ): inputs_embeds = self.model.language_model.get_input_embeddings()(input_ids) selected = input_ids == self.model.config.image_token_index indices1 = selected.to(torch.int64).cumsum(1) - 1 @@ -846,7 +854,11 @@ def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_va image_embeds = torch.where(selected.unsqueeze(-1), image_features_expanded, inputs_embeds) inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), inputs_embeds, image_embeds) outputs = self.model.language_model( - inputs_embeds=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, use_cache=True + inputs_embeds=inputs_embeds, + position_ids=position_ids, + past_key_values=past_key_values, + batch_index=batch_index, + use_cache=True, ) next_idx = (indices1.max() + 1).unsqueeze(0).unsqueeze(0) image_idx = torch.where(image_idx < next_idx, next_idx, image_idx) @@ -893,6 +905,9 @@ def get_specializations( ctx_len: int, img_size: int, kv_offload: bool = False, + continuous_batching: bool = False, + kv_cache_batch_size: Optional[int] = None, + full_batch_size: Optional[int] = None, **compiler_options, ): max_num_tiles = compiler_options.pop("max_num_tiles", None) @@ -941,28 +956,42 @@ def get_specializations( "img_size": img_size, } ] - lang = [ - { - "batch_size": batch_size, - "seq_len": prefill_seq_len, - "ctx_len": ctx_len, - "max_num_tiles": max_num_tiles, - "img_size": img_size, - "vision_size": vision_size, - "chunk_length": prefill_seq_len, - "chunk_ctx_len": chunk_ctx_len, - }, - { - "batch_size": batch_size, - "seq_len": "1", - "ctx_len": ctx_len, - "max_num_tiles": max_num_tiles, - "img_size": img_size, - "vision_size": vision_size, - "chunk_length": prefill_seq_len, - "chunk_ctx_len": chunk_ctx_len, - }, - ] + + lang_prefill = { + "batch_size": 1 if continuous_batching else batch_size, + "seq_len": prefill_seq_len, + "ctx_len": ctx_len, + "max_num_tiles": max_num_tiles, + "img_size": img_size, + "vision_size": vision_size, + "chunk_length": prefill_seq_len, + "chunk_ctx_len": chunk_ctx_len, + } + if continuous_batching: + lang_prefill["full_batch_size"] = kv_cache_batch_size + else: + lang_prefill["batch_size"] = kv_cache_batch_size + if full_batch_size: + lang_prefill["full_batch_exec_size"] = full_batch_size + + lang_decode = { + "batch_size": full_batch_size if continuous_batching else batch_size, + "seq_len": 1, + "ctx_len": ctx_len, + "max_num_tiles": max_num_tiles, + "img_size": img_size, + "vision_size": vision_size, + "chunk_length": prefill_seq_len, + "chunk_ctx_len": chunk_ctx_len, + } + if continuous_batching: + lang_decode["full_batch_size"] = kv_cache_batch_size + else: + lang_decode["batch_size"] = kv_cache_batch_size + + lang = [] + lang.append(lang_prefill) + lang.append(lang_decode) specializations = {} @@ -971,18 +1000,22 @@ def get_specializations( specializations["lang"] = lang return specializations, compiler_options else: + lang[0].pop("vision_size") + lang[1].pop("vision_size") return lang, compiler_options - def get_onnx_dynamic_axes(self, kv_offload: bool = False): + def get_onnx_dynamic_axes(self, kv_offload: bool = False, continuous_batching: bool = False): # Define dynamic 
axes vision_dynamic_axes = {} lang_dynamic_axes = {} lang_dynamic_axes["input_ids"] = {0: "batch_size", 1: "seq_len"} lang_dynamic_axes["position_ids"] = {0: "batch_size", 1: "seq_len"} lang_dynamic_axes["vision_embeds"] = {0: "vision_size"} + if continuous_batching: + lang_dynamic_axes["batch_index"] = {0: "batch_size"} vision_dynamic_axes["pixel_values"] = {0: "max_num_tiles", 2: "img_size", 3: "img_size"} - pkv_dynamic_axes = {0: "batch_size"} + pkv_dynamic_axes = {0: "full_batch_size" if continuous_batching else "batch_size"} for i in range(self.language_model.config.num_hidden_layers): # switch between chunk_ctx_len and ctx_len for RoPE and NoPE layers. if int((i + 1) % 4 != 0): @@ -1011,6 +1044,7 @@ def get_output_names(self, kv_offload: bool = False): output_names = {} if kv_offload: + # vision_output_names.insert(1, "pixel_values_RetainedState") lang_output_names.insert(1, "vision_embeds_RetainedState") lang_output_names.insert(2, "image_idx_output") output_names["vision"] = vision_output_names @@ -1045,7 +1079,7 @@ def get_dummy_pkv_cache(self, config, batch_size, seq_len): past_key_values.append(pkv) return past_key_values - def get_dummy_inputs(self, kv_offload: bool = False): + def get_dummy_inputs(self, kv_offload: bool = False, continuous_batching: bool = False): if vis_cfg := getattr(self.config, "vision_config", None): img_size = getattr(vis_cfg, "image_size", 336) else: @@ -1090,10 +1124,14 @@ def get_dummy_inputs(self, kv_offload: bool = False): .repeat(constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, 1) ) lang_inputs["image_idx"] = torch.zeros((inputs_shapes["image_idx"]), dtype=torch.int64) + + bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE + fbs: int = constants.ONNX_EXPORT_EXAMPLE_FBS + # Add data for KV past_key_values = self.get_dummy_pkv_cache( config=self.language_model.config, - batch_size=constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, + batch_size=fbs if continuous_batching else bs, seq_len=constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN, ) @@ -1102,6 +1140,8 @@ def get_dummy_inputs(self, kv_offload: bool = False): for kv in ["key", "value"]: lang_inputs["past_key_values"][i].append(torch.zeros(past_key_values[0][0].shape, dtype=torch.float32)) + if continuous_batching: + lang_inputs["batch_index"] = torch.arange(bs).view(bs, 1) inputs = {} if kv_offload: inputs["vision"] = vision_inputs diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 633a0b29d..90603d991 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -856,6 +856,7 @@ class _QEffAutoModelForImageTextToTextDualQPC: def __init__( self, model: nn.Module, + continuous_batching, **kwargs, ): """ @@ -879,6 +880,7 @@ def __init__( self.config = model.config self.vision_model = QEffVisionEncoderForTextImageToTextModel(model, **kwargs) self.lang_model = QEffCausalLMForTextImageToTextModel(model, **kwargs) + self.continuous_batching = continuous_batching self.input_shapes, self.output_names = None, None @property @@ -978,8 +980,8 @@ def export( List[str] A list containing the paths to the generated ONNX graph files for both components. 
""" - inputs = self.model.get_dummy_inputs(kv_offload=True) - dynamic_axes = self.model.get_onnx_dynamic_axes(kv_offload=True) + inputs = self.model.get_dummy_inputs(kv_offload=True, continuous_batching=self.continuous_batching) + dynamic_axes = self.model.get_onnx_dynamic_axes(kv_offload=True, continuous_batching=self.continuous_batching) output_names = self.model.get_output_names(kv_offload=True) self.vision_model.export( @@ -1068,14 +1070,20 @@ def compile( If `full_batch_size`, `kv_cache_batch_size`, or `num_speculative_tokens` are not None. If both `skip_lang` and `skip_vision` are True. """ - if any(param is not None for param in [full_batch_size, kv_cache_batch_size, num_speculative_tokens]): + if skip_lang and skip_vision: + raise ValueError("Expected at least one of 'skip_lang' or 'skip_vision' to be False") + + if self.continuous_batching and full_batch_size is None: + raise TypeError("`full_batch_size` is required when `continuous_batching=True`.") + + if kv_cache_batch_size and not full_batch_size: raise ValueError( - f"Expected 'full_batch_size', 'kv_cache_batch_size', 'num_speculative_tokens' to be None but got: " - f"full_batch_size={full_batch_size}, kv_cache_batch_size={kv_cache_batch_size}, num_speculative_tokens={num_speculative_tokens}, " + "KV caching requires continuous batching. Please set `full_batch_size` and " + "enable `continuous_batching=True` in `from_pretrained`." ) - if skip_lang and skip_vision: - raise ValueError("Expected at least one of 'skip_lang' or 'skip_vision' to be False") + # Infer kv_cache_batch_size if not provided + kv_cache_batch_size = kv_cache_batch_size or full_batch_size or batch_size output_names = self.model.get_output_names(kv_offload=True) @@ -1085,6 +1093,9 @@ def compile( ctx_len=ctx_len, img_size=img_size, kv_offload=True, + continuous_batching=self.continuous_batching, + kv_cache_batch_size=kv_cache_batch_size, + full_batch_size=full_batch_size, **compiler_options, ) @@ -1156,7 +1167,11 @@ def compile( def generate( self, - inputs: torch.Tensor, + inputs: Optional[torch.Tensor] = None, + tokenizer: Union[PreTrainedTokenizerFast, PreTrainedTokenizer] = None, + processor=None, + images: List[str] = None, + prompts: List[str] = None, streamer: Optional[TextStreamer] = None, device_ids: List[int] = None, runtime_ai100: bool = True, @@ -1196,6 +1211,17 @@ def generate( if not runtime_ai100: raise NotImplementedError("PyTorch execution is not supported yet for this model!") + if (processor and images) or (tokenizer and prompts): + return QEfficient.cloud_ai_100_exec_kv( + tokenizer, + processor, + self.lang_model.qpc_path, + self.vision_model.qpc_path, + images=images, + prompt=prompts, + device_id=device_ids, + generation_len=generation_len, + ) return self.kv_offload_generate( inputs=inputs, device_ids=device_ids, streamer=streamer, generation_len=generation_len ) @@ -1332,9 +1358,7 @@ def kv_offload_generate( lang_session.set_buffers(vision_outputs) - # Prepare inputs for prefill - chunk_inputs = lang_inputs.copy() - prefill_start = perf_counter() + lang_start = perf_counter() # Run prefill chunk_inputs = lang_inputs.copy() @@ -1346,7 +1370,7 @@ def kv_offload_generate( outputs = lang_session.run(chunk_inputs) chunk_inputs["image_idx"] = outputs["image_idx_output"] - prefill_time = perf_counter() - prefill_start + vision_end - vision_start + prefill_time = perf_counter() - lang_start + vision_end - vision_start # Skip inputs/outputs again lang_session.skip_buffers( [ @@ -1930,7 +1954,7 @@ class QEFFAutoModelForImageTextToText: 
_hf_auto_class = AutoModelForImageTextToText - def __new__(self, model: nn.Module, kv_offload: Optional[bool] = True, **kwargs): + def __new__(self, model: nn.Module, kv_offload: Optional[bool] = True, continuous_batching: bool = False, **kwargs): """ Instantiate the appropriate internal class for single or dual QPC mode. @@ -1951,13 +1975,19 @@ def __new__(self, model: nn.Module, kv_offload: Optional[bool] = True, **kwargs) The wrapped model instance, configured for either dual or single QPC. """ if kv_offload: - return _QEffAutoModelForImageTextToTextDualQPC(model, **kwargs) + return _QEffAutoModelForImageTextToTextDualQPC(model, continuous_batching, **kwargs) else: return _QEFFAutoModelForImageTextToTextSingleQPC(model, **kwargs) @classmethod @with_replaced_quantizers - def from_pretrained(cls, pretrained_model_name_or_path: str, kv_offload: Optional[bool] = None, **kwargs): + def from_pretrained( + cls, + pretrained_model_name_or_path: str, + kv_offload: Optional[bool] = None, + continuous_batching: bool = False, + **kwargs, + ): """ Load a QEfficient image-text-to-text model from a pretrained HuggingFace model or local path. @@ -1992,12 +2022,18 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, kv_offload: Optiona if kwargs.get("low_cpu_mem_usage", None): logger.warning("Updating low_cpu_mem_usage=False") - if kwargs.pop("continuous_batching", None): - NotImplementedError("Continuous batching is not supported for image-text-to-text models yet.") + if continuous_batching and not kv_offload: + NotImplementedError("Continuous batching is not supported for kv_offload = False") kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False}) model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, **kwargs) - return cls(model, kv_offload=kv_offload, pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs) + return cls( + model, + kv_offload=kv_offload, + continuous_batching=continuous_batching, + pretrained_model_name_or_path=pretrained_model_name_or_path, + **kwargs, + ) MISCLASSIFIED_CAUSAL_LM_TO_QEFF_AUTO_CLASS_MAP = { diff --git a/examples/llama4_CB_example_vision_lang.py b/examples/llama4_CB_example_vision_lang.py new file mode 100644 index 000000000..f6cd2bf5c --- /dev/null +++ b/examples/llama4_CB_example_vision_lang.py @@ -0,0 +1,65 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. 
+# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- + +import transformers +from transformers import AutoConfig, AutoProcessor + +from QEfficient import QEFFAutoModelForImageTextToText + +model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct" +config = AutoConfig.from_pretrained(model_id) +# For Testing Purpose Only +config.text_config.num_hidden_layers = 4 +config.vision_config.num_hidden_layers = 2 + +qeff_model = QEFFAutoModelForImageTextToText.from_pretrained( + model_id, + attn_implementation="eager", + kv_offload=True, + config=config, + continuous_batching=True, +) +tokenizer = transformers.AutoTokenizer.from_pretrained(model_id) +processor = AutoProcessor.from_pretrained(model_id) + +qeff_model.compile( + prefill_seq_len=128, + ctx_len=3072, + img_size=336, + num_cores=16, + num_devices=4, + max_num_tiles=17, + batch_size=1, + full_batch_size=4, + mxfp6_matmul=True, + mxint8_kv_cache=True, + aic_enable_depth_first=True, + mos=1, +) + +image_urls = [ + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_layout.png", + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg", + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_layout.png", + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg", +] + +prompts = [ + "Can you describe the image in detail?", + "What are the objects in the image?", + "What is the main subject of the image?", + "What colors are predominant in the image?", +] + +output = qeff_model.generate( + images=image_urls, + tokenizer=tokenizer, + processor=processor, + device_ids=[0, 1, 2, 3], + prompts=prompts, + generation_len=100, +) From e9cf657ab77e4dd36f91497b7a6bb0a4dc1c399b Mon Sep 17 00:00:00 2001 From: Mamta Singh Date: Tue, 7 Oct 2025 07:57:45 +0000 Subject: [PATCH 02/24] update text_generation_interface Signed-off-by: Mamta Singh Signed-off-by: Rishin Raj --- .../generation/text_generation_inference.py | 191 +++++++++++------- .../transformers/models/modeling_auto.py | 15 +- examples/llama4_CB_example_vision_lang.py | 4 +- 3 files changed, 125 insertions(+), 85 deletions(-) diff --git a/QEfficient/generation/text_generation_inference.py b/QEfficient/generation/text_generation_inference.py index b5079eac6..1bcd3ed3d 100755 --- a/QEfficient/generation/text_generation_inference.py +++ b/QEfficient/generation/text_generation_inference.py @@ -15,7 +15,7 @@ import numpy as np import torch import transformers -from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast +from transformers import AutoImageProcessor, PreTrainedTokenizer, PreTrainedTokenizerFast from QEfficient.generation.cloud_infer import QAICInferenceSession from QEfficient.utils import padding_check_and_fix @@ -314,10 +314,10 @@ def calculate_latency(total_decoded_tokens, loop_start, start, end, decode_pause def cloud_ai_100_exec_kv( tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], - processor, lang_qpc_path: str, - vision_qpc_path: str, - images, + processor: Optional[AutoImageProcessor] = None, + vision_qpc_path: Optional[str] = None, + images: Optional[str] = None, prompt: Optional[str] = None, prompts_txt_file_path: Optional[str] = None, device_id: Optional[List[int]] = None, @@ -398,6 +398,22 @@ def cloud_ai_100_exec_kv( 
return_pdfs=return_pdfs, sampling_params=sampling_params, ) + if full_batch_size is None: + exec_info = [ + generate_text.generate( + prompt=prompt[i : i + batch_size], + generation_len=generation_len, + stream=stream, + prompt_to_lora_id_mapping=prompt_to_lora_id_mapping, + ) + for i in range(0, len(prompt), batch_size) + ] + prefill_time = np.average([info.perf_metrics.prefill_time for info in exec_info]) + decode_perf = np.average([info.perf_metrics.decode_perf for info in exec_info]) + total_perf = np.average([info.perf_metrics.total_perf for info in exec_info]) + total_time = np.average([info.perf_metrics.total_time for info in exec_info]) + generated_texts = [info.generated_texts for info in exec_info] + generated_ids = [info.generated_ids for info in exec_info] exec_info = CloudAI100ExecInfo( batch_size=batch_size, @@ -407,8 +423,8 @@ def cloud_ai_100_exec_kv( ) else: exec_info = generate_text.generate( - images=images, prompt=prompt, + images=images, generation_len=generation_len, prompt_to_lora_id_mapping=prompt_to_lora_id_mapping, ) @@ -433,9 +449,9 @@ def cloud_ai_100_exec_kv( class QEffTextGenerationBase: def __init__( self, - processor, tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], lang_qpc_path: str, + processor: Optional[AutoImageProcessor] = None, vision_qpc_path: Optional[str] = None, full_batch_size: Optional[int] = None, ctx_len: Optional[int] = None, @@ -454,11 +470,13 @@ def __init__( self.sampling_params = sampling_params # Load QPC + self._lang_session = None + self._vision_session = None if not lang_qpc_path: raise TypeError("Please run compile API for language model first!") self._lang_session = QAICInferenceSession(lang_qpc_path, device_id, activate=False) if vision_qpc_path: - self._vision_session = QAICInferenceSession(vision_qpc_path, device_id) + self._vision_session = QAICInferenceSession(vision_qpc_path, device_id, activate=False) # Validate sampler inputs for On-Device Sampling self.include_sampler = validate_sampler_inputs( @@ -686,7 +704,7 @@ def _fetch_next_token_id(self, outputs): logits = np.expand_dims(logits, 1) return logits.argmax(2) - def initialize_decode_inputs(self, num_images, num_prompts, execution_batch_size, max_gen_length): + def initialize_decode_inputs(self, num_prompts, execution_batch_size, max_gen_length): """ Initialize np arrays for storing the prefill output for all the decode batch size. """ @@ -744,14 +762,18 @@ def run_prefill_for_all_inputs(self, image_queue, prompt_queue, generation_len): generation_len (int): The generation length. 
""" + next_prompt = None + next_image = None for decode_batch_id in range(self.full_batch_size): - next_prompt = prompt_queue.popleft() - next_image = image_queue.popleft() + if prompt_queue: + next_prompt = prompt_queue.popleft() + if image_queue: + next_image = image_queue.popleft() # run prefill for num_chunks outputs, position_ids, generation_len = self.run_prefill( - next_image, next_prompt, + next_image, generation_len, decode_batch_id=np.array(decode_batch_id, dtype=np.int64).reshape(1, 1), ) @@ -775,8 +797,9 @@ def _set_output_buffers(self, batch_size: int = 1, sequence_length: int = 1): logits_out_placeholder = np.zeros((batch_size, sequence_length, self._vocab_size), dtype=np.float32) self._lang_session.set_buffers({"logits": logits_out_placeholder}) - vision_embeds_out_placeholder = np.zeros((2448, 5120), dtype=np.float16) - self._vision_session.set_buffers({"vision_embeds": vision_embeds_out_placeholder}) + if self._vision_session: + vision_embeds_out_placeholder = np.zeros((2448, 5120), dtype=np.float16) + self._vision_session.set_buffers({"vision_embeds": vision_embeds_out_placeholder}) def prepare_vision_language_inputs(self, prompt, image_url): messages = [ @@ -794,13 +817,18 @@ def prepare_vision_language_inputs(self, prompt, image_url): tokenize=True, return_dict=True, return_tensors="pt", - # padding="max_length", - # max_length=padded_len, ) inputs["pixel_values"] = inputs["pixel_values"].to(torch.float32) return inputs - def run_prefill(self, image, prompt, generation_len, prefill_logit_bs=1, decode_batch_id=None): + def run_prefill( + self, + prompt: str, + image: Optional[str] = None, + generation_len: Optional[int] = None, + prefill_logit_bs=1, + decode_batch_id=None, + ): """ Runs prefill for a given prompt and generation length. @@ -817,8 +845,12 @@ def run_prefill(self, image, prompt, generation_len, prefill_logit_bs=1, decode_ position_ids (array): The position IDs. generation_len (int): The generation length. 
""" + # Run prefill - inputs = self.prepare_vision_language_inputs(prompt, image) + if image: + inputs = self.prepare_vision_language_inputs(prompt, image) + else: + inputs = self.tokenizer(prompt, return_tensors="np", padding=True) position_ids = inputs["attention_mask"].sum(1, keepdims=True) padded_len = inputs["input_ids"].shape[1] @@ -833,40 +865,45 @@ def run_prefill(self, image, prompt, generation_len, prefill_logit_bs=1, decode_ # Set the prefill output buffers self._set_output_buffers(batch_size=prefill_logit_bs, sequence_length=1) - pad_token_id = 1 - input_ids_length = inputs["input_ids"].shape[1] - num_chunks = -(input_ids_length // -self._prefill_seq_len) # ceil divide without float - padded_len = num_chunks * self._prefill_seq_len # Convert to a multiple of prompt_len - - inputs["input_ids"] = torch.nn.functional.pad( - inputs["input_ids"], - (0, padded_len - input_ids_length), - "constant", - pad_token_id, - ) - inputs["attention_mask"] = torch.nn.functional.pad( - inputs["attention_mask"], (0, padded_len - input_ids_length), "constant", 0 - ) - if "cross_attention_mask" in inputs: - inputs["cross_attention_mask"] = torch.nn.functional.pad( - inputs["cross_attention_mask"], (0, 0, 0, 0, 0, padded_len - input_ids_length) + vision_inputs = {} + vision_outputs = {} + if image: + pad_token_id = 1 + input_ids_length = inputs["input_ids"].shape[1] + num_chunks = -(input_ids_length // -self._prefill_seq_len) # ceil divide without float + padded_len = num_chunks * self._prefill_seq_len # Convert to a multiple of prompt_len + + inputs["input_ids"] = torch.nn.functional.pad( + inputs["input_ids"], + (0, padded_len - input_ids_length), + "constant", + pad_token_id, ) + inputs["attention_mask"] = torch.nn.functional.pad( + inputs["attention_mask"], (0, padded_len - input_ids_length), "constant", 0 + ) + if "cross_attention_mask" in inputs: + inputs["cross_attention_mask"] = torch.nn.functional.pad( + inputs["cross_attention_mask"], (0, 0, 0, 0, 0, padded_len - input_ids_length) + ) - for k, v in inputs.items(): - inputs[k] = np.array(v) + for k, v in inputs.items(): + inputs[k] = np.array(v) - vision_inputs = { - k: v for k, v in inputs.items() if k in {"pixel_values", "aspect_ratio_ids", "aspect_ratio_mask"} - } - if vision_inputs: - vision_inputs["pixel_values"] = vision_inputs["pixel_values"].astype("float16") + vision_inputs = { + k: v for k, v in inputs.items() if k in {"pixel_values", "aspect_ratio_ids", "aspect_ratio_mask"} + } + if vision_inputs: + vision_inputs["pixel_values"] = vision_inputs["pixel_values"].astype("float16") - vision_outputs = {} - if self._vision_session: - self._vision_session.activate() - # Run vision prefill - if vision_inputs: - vision_outputs = self._vision_session.run(vision_inputs) + # Run vision prefill + if vision_inputs: + self._vision_session.activate() + vision_outputs = self._vision_session.run(vision_inputs) + self._vision_session.deactivate() + else: + inputs = self.tokenizer(prompt, return_tensors="np", padding="max_length", max_length=padded_len) + inputs.pop("token_type_ids", None) lang_inputs = {k: v for k, v in inputs.items() if k not in vision_inputs} lang_inputs["position_ids"] = np.where( @@ -875,10 +912,9 @@ def run_prefill(self, image, prompt, generation_len, prefill_logit_bs=1, decode_ # not_mllama = hasattr(self.model.config, "model_type") and self.model.config.model_type != "mllama" # if not_mllama: - lang_inputs["image_idx"] = np.array([[0]]) + if image: + lang_inputs["image_idx"] = np.array([[0]]) - if self._vision_session: - 
self._vision_session.deactivate() self._lang_session.activate() self._lang_session.set_buffers(vision_outputs) @@ -904,7 +940,7 @@ def run_prefill(self, image, prompt, generation_len, prefill_logit_bs=1, decode_ lang_inputs["lora_ids"] = np.array(batch_lora_ids, dtype=np.int64).reshape(self.batch_size, 1) # Run language prefill - chunk_inputs = lang_inputs.copy() + for i in range(num_chunks): chunk_inputs = lang_inputs.copy() chunk_inputs["input_ids"] = lang_inputs["input_ids"][ @@ -916,7 +952,8 @@ def run_prefill(self, image, prompt, generation_len, prefill_logit_bs=1, decode_ if self.include_sampler: chunk_inputs["last_accepted_output_tokens"] = chunk_inputs["input_ids"] outputs = self._lang_session.run(chunk_inputs) - chunk_inputs["image_idx"] = outputs["image_idx_output"] + if image: + chunk_inputs["image_idx"] = outputs["image_idx_output"] if self._write_io_dir is not None: write_io_files(inputs, outputs, self._write_io_dir, "prefill", "aic_batch_io", True, False) @@ -928,15 +965,15 @@ def run_prefill(self, image, prompt, generation_len, prefill_logit_bs=1, decode_ if x.startswith("past_") or x.endswith("_RetainedState") ] ) - if self._lang_session: - self._lang_session.deactivate() + self._lang_session.deactivate() + return ( outputs, position_ids, generation_len, ) - def run_continuous_batching_decode(self, image_queue, prompt_queue, generation_len): + def run_continuous_batching_decode(self, prompt_queue, generation_len): """ Runs continuous batching decode for the given prompt queue and generation length. @@ -983,8 +1020,8 @@ def run_continuous_batching_decode(self, image_queue, prompt_queue, generation_l start = perf_counter() # run prefill for next prompt input. outputs, position_ids, generation_len = self.run_prefill( - prompt_queue.popleft(), - generation_len, + prompt=prompt_queue.popleft(), + generation_len=generation_len, decode_batch_id=np.array(decode_batch_id, dtype=np.int64).reshape(1, 1), ) @@ -1045,6 +1082,7 @@ def run_decode( self._lang_session.set_buffers({"logits": logits_out_placeholder}) finished_sequences = decode_inputs["input_ids"] == self.tokenizer.eos_token_id num_token = 0 + self._lang_session.activate() for num_token in range(1, generation_len): if streamer: streamer.put(decode_inputs["input_ids"][0]) @@ -1064,6 +1102,7 @@ def run_decode( if finished_sequences.all() and not automation: break + self._lang_session.deactivate() return num_token def generate_decode_stream(self, decode_inputs, generation_len, automation): @@ -1080,6 +1119,7 @@ def generate_decode_stream(self, decode_inputs, generation_len, automation): token_id (int): The token generated in the decoding process. 
""" finished_sequences = decode_inputs["input_ids"] == self.tokenizer.eos_token_id + self._lang_session.activate() for num_token in range(1, generation_len): yield decode_inputs["input_ids"] outputs = self._lang_session.run(decode_inputs) @@ -1096,16 +1136,17 @@ def generate_decode_stream(self, decode_inputs, generation_len, automation): if finished_sequences.all() and not automation: break + self._lang_session.deactivate() yield decode_inputs["input_ids"] # yield the last token class TextGeneration: def __init__( self, - processor, tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], - vision_qpc_path: str, lang_qpc_path: str, + processor: Optional[AutoImageProcessor] = None, + vision_qpc_path: Optional[str] = None, full_batch_size: Optional[int] = None, ctx_len: Optional[int] = None, device_id: Optional[List[int]] = None, @@ -1117,9 +1158,9 @@ def __init__( sampling_params: Optional[Dict[str, Any]] = None, ) -> None: self._qaic_model = QEffTextGenerationBase( - processor, tokenizer=tokenizer, lang_qpc_path=lang_qpc_path, + processor=processor, vision_qpc_path=vision_qpc_path, full_batch_size=full_batch_size, ctx_len=ctx_len, @@ -1146,8 +1187,8 @@ def perf_metrics(self): def _setup_model_execution_inputs( self, - images, prompt: List[str], + images: Optional[List[str]] = None, generation_len: Optional[int] = None, prompt_to_lora_id_mapping: Optional[List[int]] = None, ): @@ -1165,15 +1206,15 @@ def _setup_model_execution_inputs( # Create a prompt queue. self._prompt_queue = deque(prompt) - self._image_queue = deque(images) + if images: + self._image_queue = deque(images) # Initialize np arrays for storing the prefill output for all the decode batch size. num_prompts = len(self._prompt_queue) - num_images = len(self._image_queue) if prompt_to_lora_id_mapping: self._qaic_model.initialize_lora_id_mapping(prompt_to_lora_id_mapping) - self._qaic_model.initialize_decode_inputs(num_images, num_prompts, execution_batch_size, max_gen_length) + self._qaic_model.initialize_decode_inputs(num_prompts, execution_batch_size, max_gen_length) def _regular_model_execution( self, @@ -1196,12 +1237,14 @@ def _regular_model_execution( :tuple: A tuple containing performance metrics and generated texts. """ - self._setup_model_execution_inputs(prompt, generation_len, prompt_to_lora_id_mapping) + self._setup_model_execution_inputs( + prompt=prompt, generation_len=generation_len, prompt_to_lora_id_mapping=prompt_to_lora_id_mapping + ) if stream and self._text_streamer is None: self._text_streamer = transformers.TextStreamer(self._tokenizer) start = perf_counter() outputs, position_ids, generation_len = self._qaic_model.run_prefill( - prompt, generation_len, prefill_logit_bs=self._qaic_model.batch_size + prompt=prompt, generation_len=generation_len, prefill_logit_bs=self._qaic_model.batch_size ) self._qaic_model.update_decode_input(outputs, position_ids, generation_len) @@ -1221,8 +1264,8 @@ def _regular_model_execution( def _continuous_batching_execution( self, - images, prompt: List[str], + images: Optional[List[str]] = None, generation_len: Optional[int] = None, prompt_to_lora_id_mapping: Optional[List[int]] = None, ): @@ -1238,17 +1281,13 @@ def _continuous_batching_execution( Returns: :tuple: A tuple containing performance metrics and generated texts. 
""" - self._setup_model_execution_inputs(images, prompt, generation_len, prompt_to_lora_id_mapping) + self._setup_model_execution_inputs(prompt, images, generation_len, prompt_to_lora_id_mapping) self._qaic_model.batch_index = np.arange(self._full_batch_size).reshape(-1, 1) start = perf_counter() self._qaic_model.run_prefill_for_all_inputs(self._image_queue, self._prompt_queue, generation_len) - print("\n\n\n\n Prefill for all inputs completed\n\n\n\n") - loop_start = perf_counter() # Start decode loop timer - decode_pause_time = self._qaic_model.run_continuous_batching_decode( - self._image_queue, self._prompt_queue, generation_len - ) + decode_pause_time = self._qaic_model.run_continuous_batching_decode(self._prompt_queue, generation_len) end = perf_counter() generated_texts = self._tokenizer.batch_decode(self._qaic_model.generated_ids, skip_special_tokens=True) @@ -1289,7 +1328,7 @@ def generate_stream_tokens( self._setup_model_execution_inputs(prompt, generation_len, prompt_to_lora_id_mapping) start = perf_counter() outputs, position_ids, generation_len = self._qaic_model.run_prefill( - prompt, generation_len, prefill_logit_bs=self._qaic_model.batch_size + prompt=prompt, generation_len=generation_len, prefill_logit_bs=self._qaic_model.batch_size ) self._qaic_model.update_decode_input(outputs, position_ids, generation_len) @@ -1313,8 +1352,8 @@ def generate_stream_tokens( def generate( self, - images, prompt: List[str], + images: Optional[List[str]] = None, generation_len: Optional[int] = None, stream: bool = True, automation: Optional[bool] = False, @@ -1335,7 +1374,7 @@ def generate( if self._full_batch_size is not None: logger.warning("Streamer is currently unavailable for continuous batch execution.") perf_metrics, generated_texts = self._continuous_batching_execution( - images, prompt, generation_len, prompt_to_lora_id_mapping + prompt, images, generation_len, prompt_to_lora_id_mapping ) else: if stream: diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 90603d991..eed3782db 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -14,6 +14,7 @@ import torch import torch.nn as nn from transformers import ( + AutoImageProcessor, AutoModel, AutoModelForCausalLM, AutoModelForCTC, @@ -1169,7 +1170,7 @@ def generate( self, inputs: Optional[torch.Tensor] = None, tokenizer: Union[PreTrainedTokenizerFast, PreTrainedTokenizer] = None, - processor=None, + processor: Optional[AutoImageProcessor] = None, images: List[str] = None, prompts: List[str] = None, streamer: Optional[TextStreamer] = None, @@ -1213,10 +1214,10 @@ def generate( if (processor and images) or (tokenizer and prompts): return QEfficient.cloud_ai_100_exec_kv( - tokenizer, - processor, - self.lang_model.qpc_path, - self.vision_model.qpc_path, + tokenizer=tokenizer, + processor=processor, + lang_qpc_path=self.lang_model.qpc_path, + vision_qpc_path=self.vision_model.qpc_path, images=images, prompt=prompts, device_id=device_ids, @@ -2741,8 +2742,8 @@ def generate( raise TypeError("Please run compile API first!") generation_len = kwargs.pop("generation_len", None) return QEfficient.cloud_ai_100_exec_kv( - tokenizer, - self.qpc_path, + tokenizer=tokenizer, + lang_qpc_path=self.qpc_path, prompt=prompts, device_id=device_id, generation_len=generation_len, diff --git a/examples/llama4_CB_example_vision_lang.py b/examples/llama4_CB_example_vision_lang.py index f6cd2bf5c..ebe65bf82 100644 --- 
a/examples/llama4_CB_example_vision_lang.py +++ b/examples/llama4_CB_example_vision_lang.py @@ -56,10 +56,10 @@ ] output = qeff_model.generate( - images=image_urls, tokenizer=tokenizer, + prompts=prompts, processor=processor, + images=image_urls, device_ids=[0, 1, 2, 3], - prompts=prompts, generation_len=100, ) From f62f71ea07ba9fb06f2810f4dae4e13a8e1a9516 Mon Sep 17 00:00:00 2001 From: Asmita Goswami Date: Thu, 23 Oct 2025 08:40:16 +0000 Subject: [PATCH 03/24] Updated text_generation to run CB for VLMs Signed-off-by: Asmita Goswami Signed-off-by: Rishin Raj --- QEfficient/generation/cloud_infer.py | 31 ++++- .../generation/text_generation_inference.py | 108 +++++++++--------- 2 files changed, 78 insertions(+), 61 deletions(-) diff --git a/QEfficient/generation/cloud_infer.py b/QEfficient/generation/cloud_infer.py index 8519d824c..42c8b342e 100644 --- a/QEfficient/generation/cloud_infer.py +++ b/QEfficient/generation/cloud_infer.py @@ -90,8 +90,10 @@ def __init__( self.program = qaicrt.Program(self.context, None, qpc, prog_properties) if self.program.load() != qaicrt.QStatus.QS_SUCCESS: raise RuntimeError("Failed to load program") + self.is_active = False if activate: self.activate() + self.is_active = True # Create input qbuffers and buf_dims self.qbuffers = [qaicrt.QBuffer(bytes(binding.size)) for binding in self.bindings] self.buf_dims = qaicrt.BufferDimensionsVecRef( @@ -108,15 +110,32 @@ def output_names(self) -> List[str]: def activate(self): """Activate qpc""" - - self.program.activate() - self.execObj = qaicrt.ExecObj(self.context, self.program) + if not self.is_active: + self.program.activate() + self.execObj = qaicrt.ExecObj(self.context, self.program) + self.is_active = True def deactivate(self): """Deactivate qpc""" - - del self.execObj - self.program.deactivate() + if self.is_active: + del self.execObj + self.program.deactivate() + self.is_active = False + + def pause(self): + """Pause the session while preserving state""" + if self.is_active: + # Just deactivate the program and set state + self.program.deactivate() + self.is_active = False + + def resume(self): + """Resume a paused session""" + if not self.is_active: + # Reactivate program and create new execObj + self.program.activate() + self.execObj = qaicrt.ExecObj(self.context, self.program) + self.is_active = True def set_buffers(self, buffers: Dict[str, np.ndarray]): """ diff --git a/QEfficient/generation/text_generation_inference.py b/QEfficient/generation/text_generation_inference.py index 1bcd3ed3d..f014bf6a0 100755 --- a/QEfficient/generation/text_generation_inference.py +++ b/QEfficient/generation/text_generation_inference.py @@ -11,6 +11,8 @@ from dataclasses import dataclass from time import perf_counter from typing import Any, Dict, List, Optional, Tuple, Union +import requests +from PIL import Image import numpy as np import torch @@ -398,36 +400,24 @@ def cloud_ai_100_exec_kv( return_pdfs=return_pdfs, sampling_params=sampling_params, ) - if full_batch_size is None: - exec_info = [ - generate_text.generate( - prompt=prompt[i : i + batch_size], - generation_len=generation_len, - stream=stream, - prompt_to_lora_id_mapping=prompt_to_lora_id_mapping, - ) - for i in range(0, len(prompt), batch_size) - ] - prefill_time = np.average([info.perf_metrics.prefill_time for info in exec_info]) - decode_perf = np.average([info.perf_metrics.decode_perf for info in exec_info]) - total_perf = np.average([info.perf_metrics.total_perf for info in exec_info]) - total_time = np.average([info.perf_metrics.total_time for info in 
exec_info]) - generated_texts = [info.generated_texts for info in exec_info] - generated_ids = [info.generated_ids for info in exec_info] - - exec_info = CloudAI100ExecInfo( - batch_size=batch_size, - generated_texts=generated_texts, - generated_ids=generated_ids, - perf_metrics=PerfMetrics(prefill_time, decode_perf, total_perf, total_time), - ) - else: - exec_info = generate_text.generate( - prompt=prompt, - images=images, - generation_len=generation_len, - prompt_to_lora_id_mapping=prompt_to_lora_id_mapping, - ) + + for _ in range(0, int(iteration)): + if full_batch_size is None: + exec_info = [ + generate_text.generate( + prompt=prompt[i : i + batch_size], + generation_len=generation_len, + stream=stream, + prompt_to_lora_id_mapping=prompt_to_lora_id_mapping, + ) + for i in range(0, len(prompt), batch_size) + ] + prefill_time = np.average([info.perf_metrics.prefill_time for info in exec_info]) + decode_perf = np.average([info.perf_metrics.decode_perf for info in exec_info]) + total_perf = np.average([info.perf_metrics.total_perf for info in exec_info]) + total_time = np.average([info.perf_metrics.total_time for info in exec_info]) + generated_texts = [info.generated_texts for info in exec_info] + generated_ids = [info.generated_ids for info in exec_info] exec_info = CloudAI100ExecInfo( batch_size=batch_size, @@ -437,7 +427,10 @@ def cloud_ai_100_exec_kv( ) else: exec_info = generate_text.generate( - prompt=prompt, generation_len=generation_len, prompt_to_lora_id_mapping=prompt_to_lora_id_mapping + prompt=prompt, + images=images, + generation_len=generation_len, + prompt_to_lora_id_mapping=prompt_to_lora_id_mapping, ) print_latency_stats_kv(prompt, exec_info=exec_info, automation=automation) @@ -751,7 +744,7 @@ def update_decode_input(self, outputs, position_ids, generation_len, decode_batc self.generation_len[decode_batch_id or slice(None)] = generation_len return next_token_id - def run_prefill_for_all_inputs(self, image_queue, prompt_queue, generation_len): + def run_prefill_for_all_inputs(self, image_queue, prompt_queue, processor, generation_len): """ Runs prefill for all inputs in the prompt queue and updates the decode input. 
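A minimal sketch of the image-plus-prompt preprocessing pattern that this commit threads through run_prefill via the new processor argument; the checkpoint name and image URL below are illustrative assumptions, and the casting comment reflects what the surrounding hunks do rather than a fixed API contract:

    import requests
    from PIL import Image
    from transformers import AutoProcessor

    # Assumed checkpoint; any VLM processor exposing apply_chat_template works the same way.
    processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
    # Hypothetical image URL used only for illustration.
    image = Image.open(requests.get("https://example.com/cat.png", stream=True).raw)
    conversation = [
        {"role": "user", "content": [{"type": "text", "text": "Describe this image."}, {"type": "image"}]},
    ]
    prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
    inputs = processor(images=image, text=prompt, return_tensors="pt")
    # Downstream, pixel_values (and image_masks when present) are cast to float16 before the vision QPC runs,
    # while the remaining tensors feed the language QPC prefill.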
@@ -774,6 +767,7 @@ def run_prefill_for_all_inputs(self, image_queue, prompt_queue, generation_len): outputs, position_ids, generation_len = self.run_prefill( next_prompt, next_image, + processor, generation_len, decode_batch_id=np.array(decode_batch_id, dtype=np.int64).reshape(1, 1), ) @@ -801,30 +795,28 @@ def _set_output_buffers(self, batch_size: int = 1, sequence_length: int = 1): vision_embeds_out_placeholder = np.zeros((2448, 5120), dtype=np.float16) self._vision_session.set_buffers({"vision_embeds": vision_embeds_out_placeholder}) - def prepare_vision_language_inputs(self, prompt, image_url): - messages = [ + def prepare_vision_language_inputs(self, processor, query, image_url): + image = Image.open(requests.get(image_url, stream=True).raw) + conversation = [ { "role": "user", "content": [ - {"type": "image", "url": image_url}, - {"type": "text", "text": prompt}, + {"type": "text", "text": query}, + {"type": "image"}, ], }, ] - inputs = self.processor.apply_chat_template( - messages, - add_generation_prompt=True, - tokenize=True, - return_dict=True, - return_tensors="pt", - ) - inputs["pixel_values"] = inputs["pixel_values"].to(torch.float32) + prompt = processor.apply_chat_template(conversation, add_generation_prompt=True) + inputs = processor(images=image, text=prompt, return_tensors="pt") + if "pixel_values" in inputs: + inputs["pixel_values"] = inputs["pixel_values"].to(torch.float32) return inputs def run_prefill( self, prompt: str, image: Optional[str] = None, + processor: Optional[AutoImageProcessor] = None, generation_len: Optional[int] = None, prefill_logit_bs=1, decode_batch_id=None, @@ -848,7 +840,7 @@ def run_prefill( # Run prefill if image: - inputs = self.prepare_vision_language_inputs(prompt, image) + inputs = self.prepare_vision_language_inputs(processor, prompt, image) else: inputs = self.tokenizer(prompt, return_tensors="np", padding=True) @@ -891,16 +883,22 @@ def run_prefill( inputs[k] = np.array(v) vision_inputs = { - k: v for k, v in inputs.items() if k in {"pixel_values", "aspect_ratio_ids", "aspect_ratio_mask"} + k: v for k, v in inputs.items() if k in {"pixel_values", "image_masks", "image_input_idx", "valid_idx", "aspect_ratio_ids", "aspect_ratio_mask"} } - if vision_inputs: - vision_inputs["pixel_values"] = vision_inputs["pixel_values"].astype("float16") + # if vision_inputs: + # vision_inputs["pixel_values"] = vision_inputs["pixel_values"].astype("float16") + vision_inputs_fp16 = {"pixel_values", "image_masks"} + vision_inputs.update({k: vision_inputs[k].astype("float16") for k in vision_inputs_fp16 if k in vision_inputs}) + # if not(self._lang_session.is_active): + # self._lang_session.activate() # Run vision prefill if vision_inputs: + # self._lang_session.pause() self._vision_session.activate() vision_outputs = self._vision_session.run(vision_inputs) self._vision_session.deactivate() + # self._lang_session.resume() else: inputs = self.tokenizer(prompt, return_tensors="np", padding="max_length", max_length=padded_len) inputs.pop("token_type_ids", None) @@ -910,8 +908,9 @@ def run_prefill( lang_inputs.pop("attention_mask"), np.arange(padded_len), -1 ) # Need to use -1 as position_ids for invalid tokens - # not_mllama = hasattr(self.model.config, "model_type") and self.model.config.model_type != "mllama" - # if not_mllama: + # not_mllama = hasattr(self.model.config, "model_type") and self.model.config.model_type != "mllama" + # if not_mllama: + # lang_inputs["image_idx"] = np.array([[0]]) if image: lang_inputs["image_idx"] = np.array([[0]]) @@ -940,9 
+939,8 @@ def run_prefill( lang_inputs["lora_ids"] = np.array(batch_lora_ids, dtype=np.int64).reshape(self.batch_size, 1) # Run language prefill - + chunk_inputs = lang_inputs.copy() for i in range(num_chunks): - chunk_inputs = lang_inputs.copy() chunk_inputs["input_ids"] = lang_inputs["input_ids"][ :, i * self._prefill_seq_len : (i + 1) * self._prefill_seq_len ] @@ -965,7 +963,7 @@ def run_prefill( if x.startswith("past_") or x.endswith("_RetainedState") ] ) - self._lang_session.deactivate() + # self._lang_session.deactivate() return ( outputs, @@ -1004,8 +1002,8 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len): # Prepare decode inputs inputs. decode_inputs = self.prepare_decode_inputs() + # self._lang_session.activate() # Due to activating new session (new exec_obj) run values are changing while prompt_queue or current_decode_ongoing.any(): - self._lang_session.activate() outputs = self._lang_session.run(decode_inputs) # Prepare inputs for next iteration @@ -1284,7 +1282,7 @@ def _continuous_batching_execution( self._setup_model_execution_inputs(prompt, images, generation_len, prompt_to_lora_id_mapping) self._qaic_model.batch_index = np.arange(self._full_batch_size).reshape(-1, 1) start = perf_counter() - self._qaic_model.run_prefill_for_all_inputs(self._image_queue, self._prompt_queue, generation_len) + self._qaic_model.run_prefill_for_all_inputs(self._image_queue, self._prompt_queue, self._processor, generation_len) loop_start = perf_counter() # Start decode loop timer decode_pause_time = self._qaic_model.run_continuous_batching_decode(self._prompt_queue, generation_len) From 39f557447f136a990579967d8196e471269c96eb Mon Sep 17 00:00:00 2001 From: Asmita Goswami Date: Thu, 23 Oct 2025 08:50:48 +0000 Subject: [PATCH 04/24] Ruff format Signed-off-by: Asmita Goswami Signed-off-by: Rishin Raj --- .../generation/text_generation_inference.py | 30 ++++++++++++++----- 1 file changed, 22 insertions(+), 8 deletions(-) diff --git a/QEfficient/generation/text_generation_inference.py b/QEfficient/generation/text_generation_inference.py index f014bf6a0..190bcf764 100755 --- a/QEfficient/generation/text_generation_inference.py +++ b/QEfficient/generation/text_generation_inference.py @@ -11,12 +11,12 @@ from dataclasses import dataclass from time import perf_counter from typing import Any, Dict, List, Optional, Tuple, Union -import requests -from PIL import Image import numpy as np +import requests import torch import transformers +from PIL import Image from transformers import AutoImageProcessor, PreTrainedTokenizer, PreTrainedTokenizerFast from QEfficient.generation.cloud_infer import QAICInferenceSession @@ -883,12 +883,24 @@ def run_prefill( inputs[k] = np.array(v) vision_inputs = { - k: v for k, v in inputs.items() if k in {"pixel_values", "image_masks", "image_input_idx", "valid_idx", "aspect_ratio_ids", "aspect_ratio_mask"} + k: v + for k, v in inputs.items() + if k + in { + "pixel_values", + "image_masks", + "image_input_idx", + "valid_idx", + "aspect_ratio_ids", + "aspect_ratio_mask", + } } # if vision_inputs: # vision_inputs["pixel_values"] = vision_inputs["pixel_values"].astype("float16") vision_inputs_fp16 = {"pixel_values", "image_masks"} - vision_inputs.update({k: vision_inputs[k].astype("float16") for k in vision_inputs_fp16 if k in vision_inputs}) + vision_inputs.update( + {k: vision_inputs[k].astype("float16") for k in vision_inputs_fp16 if k in vision_inputs} + ) # if not(self._lang_session.is_active): # self._lang_session.activate() @@ -908,9 +920,9 @@ def 
run_prefill( lang_inputs.pop("attention_mask"), np.arange(padded_len), -1 ) # Need to use -1 as position_ids for invalid tokens - # not_mllama = hasattr(self.model.config, "model_type") and self.model.config.model_type != "mllama" - # if not_mllama: - # lang_inputs["image_idx"] = np.array([[0]]) + # not_mllama = hasattr(self.model.config, "model_type") and self.model.config.model_type != "mllama" + # if not_mllama: + # lang_inputs["image_idx"] = np.array([[0]]) if image: lang_inputs["image_idx"] = np.array([[0]]) @@ -1282,7 +1294,9 @@ def _continuous_batching_execution( self._setup_model_execution_inputs(prompt, images, generation_len, prompt_to_lora_id_mapping) self._qaic_model.batch_index = np.arange(self._full_batch_size).reshape(-1, 1) start = perf_counter() - self._qaic_model.run_prefill_for_all_inputs(self._image_queue, self._prompt_queue, self._processor, generation_len) + self._qaic_model.run_prefill_for_all_inputs( + self._image_queue, self._prompt_queue, self._processor, generation_len + ) loop_start = perf_counter() # Start decode loop timer decode_pause_time = self._qaic_model.run_continuous_batching_decode(self._prompt_queue, generation_len) From ecd59059d3b15fd4a9e4785f27b69af3c9a57355 Mon Sep 17 00:00:00 2001 From: Asmita Goswami Date: Fri, 24 Oct 2025 06:53:39 +0000 Subject: [PATCH 05/24] Updated qwen2_5 modelling for CB support Signed-off-by: Asmita Goswami Signed-off-by: Rishin Raj --- .../models/qwen2_5_vl/modeling_qwen2_5_vl.py | 67 +++++++++++++------ 1 file changed, 48 insertions(+), 19 deletions(-) diff --git a/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index e5e842e6f..a7027f82b 100644 --- a/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -680,7 +680,7 @@ def __init__(self, model): self.model = model self.language_model = self.model.model.language_model - def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_values): + def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_values, batch_index: Optional[torch.LongTensor] = None,): inputs_embeds = self.model.get_input_embeddings()(input_ids) B, N, C = inputs_embeds.shape selected = input_ids == self.model.config.image_token_id @@ -691,7 +691,7 @@ def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_va image_input_embeds = torch.where(selected.unsqueeze(-1), image_features_expanded, inputs_embeds) inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), inputs_embeds, image_input_embeds) outputs = self.model.model( - inputs_embeds=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, use_cache=True + inputs_embeds=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, batch_index=batch_index, use_cache=True ) logit_index = position_ids[0].to(torch.int32).argmax(1, keepdim=True) @@ -709,7 +709,7 @@ def get_qeff_vision_encoder(self): def get_qeff_language_decoder(self): return QEffQwen_2_5_vl_DecoderWrapper(self) - def get_dummy_inputs(self, kv_offload: bool = False, **kwargs): + def get_dummy_inputs(self, kv_offload: bool = False, continuous_batching: bool = False, **kwargs): inputs_shapes = {} inputs_shapes["input_ids"] = (constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN) @@ -745,10 +745,14 @@ def get_dummy_inputs(self, kv_offload: bool = False, **kwargs): .repeat(4, 1, 1) ) 
lang_inputs["image_idx"] = torch.zeros((inputs_shapes["image_idx"]), dtype=torch.int64) + + bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE + fbs: int = constants.ONNX_EXPORT_EXAMPLE_FBS + # Add data for KV kv_cache_shape = get_padding_shape_from_config( config=self.model.config, - batch_size=constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, + batch_size=fbs if continuous_batching else bs, seq_len=constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN, ) @@ -757,6 +761,9 @@ def get_dummy_inputs(self, kv_offload: bool = False, **kwargs): for kv in ["key", "value"]: lang_inputs["past_key_values"][i].append(torch.zeros(kv_cache_shape, dtype=torch.float32)) + if continuous_batching: + lang_inputs["batch_index"] = torch.arange(bs).view(bs, 1) + inputs = {} if kv_offload: inputs["vision"] = vision_inputs @@ -776,6 +783,9 @@ def get_specializations( height: int = None, width: int = None, kv_offload: bool = False, + continuous_batching: bool = False, + kv_cache_batch_size: Optional[int] = None, + full_batch_size: Optional[int] = None, **compiler_options, ): if height is None or width is None: @@ -856,20 +866,34 @@ def smart_resize( "grid_w": grid_w, } ] - lang = [ - { - "batch_size": batch_size, - "seq_len": prefill_seq_len, - "ctx_len": ctx_len, - "vision_size": vision_size, - }, - { - "batch_size": batch_size, - "seq_len": "1", - "ctx_len": ctx_len, - "vision_size": vision_size, - }, - ] + lang_prefill = { + "batch_size": 1 if continuous_batching else batch_size, + "seq_len": prefill_seq_len, + "ctx_len": ctx_len, + "vision_size": vision_size, + } + + if continuous_batching: lang_prefill["full_batch_size"] = kv_cache_batch_size + else: + lang_prefill["batch_size"] = kv_cache_batch_size + if full_batch_size: + lang_prefill["full_batch_exec_size"] = full_batch_size + + lang_decode = { + "batch_size": full_batch_size if continuous_batching else batch_size, + "seq_len": 1, + "ctx_len": ctx_len, + "vision_size": vision_size, + } + + if continuous_batching: + lang_decode["full_batch_size"] = kv_cache_batch_size + else: + lang_decode["batch_size"] = kv_cache_batch_size + + lang = [] + lang.append(lang_prefill) + lang.append(lang_decode) specializations = {} @@ -878,9 +902,11 @@ def smart_resize( specializations["lang"] = lang return specializations, compiler_options else: + lang[0].pop("vision_size") + lang[1].pop("vision_size") return lang, compiler_options - def get_onnx_dynamic_axes(self, kv_offload: bool = False): + def get_onnx_dynamic_axes(self, kv_offload: bool = False, continuous_batching: bool = False): # Define dynamic axes num_layers = self.config.text_config.num_hidden_layers @@ -895,6 +921,9 @@ def get_onnx_dynamic_axes(self, kv_offload: bool = False): "vision_embeds": {0: "batch_size", 1: "vision_size"}, } + if continuous_batching: + lang_dynamic_axes["batch_index"] = {0: "batch_size"} + for i in range(num_layers): lang_dynamic_axes[f"past_key.{i}"] = {0: "batch_size", 2: "ctx_len"} lang_dynamic_axes[f"past_value.{i}"] = {0: "batch_size", 2: "ctx_len"} From 3d9cd49a0c792ad0b319c2cdbdbbfd3e3aa3fecd Mon Sep 17 00:00:00 2001 From: Asmita Goswami Date: Fri, 24 Oct 2025 06:55:28 +0000 Subject: [PATCH 06/24] Updated qwen2_5 modelling for CB support Signed-off-by: Asmita Goswami Signed-off-by: Rishin Raj --- .../models/qwen2_5_vl/modeling_qwen2_5_vl.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index a7027f82b..8d316211b 100644 --- 
a/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -680,7 +680,15 @@ def __init__(self, model): self.model = model self.language_model = self.model.model.language_model - def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_values, batch_index: Optional[torch.LongTensor] = None,): + def forward( + self, + input_ids, + vision_embeds, + position_ids, + image_idx, + past_key_values, + batch_index: Optional[torch.LongTensor] = None, + ): inputs_embeds = self.model.get_input_embeddings()(input_ids) B, N, C = inputs_embeds.shape selected = input_ids == self.model.config.image_token_id @@ -691,7 +699,11 @@ def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_va image_input_embeds = torch.where(selected.unsqueeze(-1), image_features_expanded, inputs_embeds) inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), inputs_embeds, image_input_embeds) outputs = self.model.model( - inputs_embeds=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, batch_index=batch_index, use_cache=True + inputs_embeds=inputs_embeds, + position_ids=position_ids, + past_key_values=past_key_values, + batch_index=batch_index, + use_cache=True, ) logit_index = position_ids[0].to(torch.int32).argmax(1, keepdim=True) @@ -873,7 +885,8 @@ def smart_resize( "vision_size": vision_size, } - if continuous_batching: lang_prefill["full_batch_size"] = kv_cache_batch_size + if continuous_batching: + lang_prefill["full_batch_size"] = kv_cache_batch_size else: lang_prefill["batch_size"] = kv_cache_batch_size if full_batch_size: From 165c8fb6ac0be71434342b2d5b838233b86acb31 Mon Sep 17 00:00:00 2001 From: Asmita Goswami Date: Fri, 24 Oct 2025 06:58:06 +0000 Subject: [PATCH 07/24] Passed image queue in decode CB Signed-off-by: Asmita Goswami Signed-off-by: Rishin Raj --- .../generation/text_generation_inference.py | 34 ++++++++++++------- 1 file changed, 22 insertions(+), 12 deletions(-) diff --git a/QEfficient/generation/text_generation_inference.py b/QEfficient/generation/text_generation_inference.py index 190bcf764..829e1fc14 100755 --- a/QEfficient/generation/text_generation_inference.py +++ b/QEfficient/generation/text_generation_inference.py @@ -791,10 +791,6 @@ def _set_output_buffers(self, batch_size: int = 1, sequence_length: int = 1): logits_out_placeholder = np.zeros((batch_size, sequence_length, self._vocab_size), dtype=np.float32) self._lang_session.set_buffers({"logits": logits_out_placeholder}) - if self._vision_session: - vision_embeds_out_placeholder = np.zeros((2448, 5120), dtype=np.float16) - self._vision_session.set_buffers({"vision_embeds": vision_embeds_out_placeholder}) - def prepare_vision_language_inputs(self, processor, query, image_url): image = Image.open(requests.get(image_url, stream=True).raw) conversation = [ @@ -844,7 +840,10 @@ def run_prefill( else: inputs = self.tokenizer(prompt, return_tensors="np", padding=True) - position_ids = inputs["attention_mask"].sum(1, keepdims=True) + if "position_ids" in inputs: + position_ids = inputs["position_ids"] + else: + position_ids = inputs["attention_mask"].sum(1, keepdims=True) padded_len = inputs["input_ids"].shape[1] num_chunks = -(padded_len // -self._prefill_seq_len) # ceil divide without float padded_len = num_chunks * self._prefill_seq_len # Convert to a multiple of prompt_len @@ -916,9 +915,13 @@ def run_prefill( inputs.pop("token_type_ids", None) lang_inputs = {k: v for k, v in 
inputs.items() if k not in vision_inputs} - lang_inputs["position_ids"] = np.where( - lang_inputs.pop("attention_mask"), np.arange(padded_len), -1 - ) # Need to use -1 as position_ids for invalid tokens + if "position_ids" in inputs: + lang_inputs["position_ids"] = inputs["position_ids"] + lang_inputs.pop("attention_mask") + else: + lang_inputs["position_ids"] = np.where( + lang_inputs.pop("attention_mask"), np.arange(padded_len), -1 + ) # Need to use -1 as position_ids for invalid tokens # not_mllama = hasattr(self.model.config, "model_type") and self.model.config.model_type != "mllama" # if not_mllama: @@ -957,7 +960,7 @@ def run_prefill( :, i * self._prefill_seq_len : (i + 1) * self._prefill_seq_len ] chunk_inputs["position_ids"] = lang_inputs["position_ids"][ - :, i * self._prefill_seq_len : (i + 1) * self._prefill_seq_len + ..., i * self._prefill_seq_len : (i + 1) * self._prefill_seq_len ] if self.include_sampler: chunk_inputs["last_accepted_output_tokens"] = chunk_inputs["input_ids"] @@ -983,7 +986,7 @@ def run_prefill( generation_len, ) - def run_continuous_batching_decode(self, prompt_queue, generation_len): + def run_continuous_batching_decode(self, prompt_queue, image_queue, processor, generation_len): """ Runs continuous batching decode for the given prompt queue and generation length. @@ -1013,6 +1016,8 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len): decode_pause_time = 0 # Prepare decode inputs inputs. decode_inputs = self.prepare_decode_inputs() + next_prompt = None + next_image = None # self._lang_session.activate() # Due to activating new session (new exec_obj) run values are changing while prompt_queue or current_decode_ongoing.any(): @@ -1026,11 +1031,16 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len): next_token_id[decode_batch_id, -1] == self.tokenizer.eos_token_id or generated_id_current_index[decode_batch_id] >= self.generation_len[decode_batch_id] ): + if image_queue: + next_image = image_queue.popleft() if prompt_queue: + next_prompt = prompt_queue.popleft() start = perf_counter() # run prefill for next prompt input. 
outputs, position_ids, generation_len = self.run_prefill( - prompt=prompt_queue.popleft(), + prompt=next_prompt, + image=next_image, + processor=processor, generation_len=generation_len, decode_batch_id=np.array(decode_batch_id, dtype=np.int64).reshape(1, 1), ) @@ -1299,7 +1309,7 @@ def _continuous_batching_execution( ) loop_start = perf_counter() # Start decode loop timer - decode_pause_time = self._qaic_model.run_continuous_batching_decode(self._prompt_queue, generation_len) + decode_pause_time = self._qaic_model.run_continuous_batching_decode(self._prompt_queue, self._image_queue, self._processor, generation_len) end = perf_counter() generated_texts = self._tokenizer.batch_decode(self._qaic_model.generated_ids, skip_special_tokens=True) From cf2d4fba7b841685fce45deb8b71e65613b4e1a5 Mon Sep 17 00:00:00 2001 From: Asmita Goswami Date: Fri, 24 Oct 2025 07:03:10 +0000 Subject: [PATCH 08/24] Ruff format Signed-off-by: Asmita Goswami Signed-off-by: Rishin Raj --- QEfficient/generation/text_generation_inference.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/QEfficient/generation/text_generation_inference.py b/QEfficient/generation/text_generation_inference.py index 829e1fc14..c0b59349a 100755 --- a/QEfficient/generation/text_generation_inference.py +++ b/QEfficient/generation/text_generation_inference.py @@ -1309,7 +1309,9 @@ def _continuous_batching_execution( ) loop_start = perf_counter() # Start decode loop timer - decode_pause_time = self._qaic_model.run_continuous_batching_decode(self._prompt_queue, self._image_queue, self._processor, generation_len) + decode_pause_time = self._qaic_model.run_continuous_batching_decode( + self._prompt_queue, self._image_queue, self._processor, generation_len + ) end = perf_counter() generated_texts = self._tokenizer.batch_decode(self._qaic_model.generated_ids, skip_special_tokens=True) From d97bda9cedefd9dd3d962d6c37e9a7ba6c55a374 Mon Sep 17 00:00:00 2001 From: Rishin Raj Date: Mon, 27 Oct 2025 13:54:46 +0000 Subject: [PATCH 09/24] refactored the code Signed-off-by: Rishin Raj --- .../generation/text_generation_inference.py | 354 +++------- QEfficient/generation/vision_handler.py | 395 +++++++++++ QEfficient/generation/vlm_generation.py | 633 ++++++++++++++++++ .../transformers/models/modeling_auto.py | 29 +- examples/llama4_CB_example_vision_lang.py | 2 + 5 files changed, 1150 insertions(+), 263 deletions(-) create mode 100644 QEfficient/generation/vision_handler.py create mode 100644 QEfficient/generation/vlm_generation.py diff --git a/QEfficient/generation/text_generation_inference.py b/QEfficient/generation/text_generation_inference.py index c0b59349a..2a9f5bc3a 100755 --- a/QEfficient/generation/text_generation_inference.py +++ b/QEfficient/generation/text_generation_inference.py @@ -13,11 +13,8 @@ from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np -import requests -import torch import transformers -from PIL import Image -from transformers import AutoImageProcessor, PreTrainedTokenizer, PreTrainedTokenizerFast +from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast from QEfficient.generation.cloud_infer import QAICInferenceSession from QEfficient.utils import padding_check_and_fix @@ -316,10 +313,7 @@ def calculate_latency(total_decoded_tokens, loop_start, start, end, decode_pause def cloud_ai_100_exec_kv( tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], - lang_qpc_path: str, - processor: Optional[AutoImageProcessor] = None, - vision_qpc_path: Optional[str] = None, - 
images: Optional[str] = None, + qpc_path: str, prompt: Optional[str] = None, prompts_txt_file_path: Optional[str] = None, device_id: Optional[List[int]] = None, @@ -378,7 +372,7 @@ def cloud_ai_100_exec_kv( exec_info = QEfficient.cloud_ai_100_exec_kv(tokenizer=tokenizer, qpc_path=qpc_path, prompt="Hi there!!", device_id=[0]) """ - batch_size, ctx_len, full_batch_size = get_compilation_dims(lang_qpc_path) + batch_size, ctx_len, full_batch_size = get_compilation_dims(qpc_path) prompt: List[str] = get_input_prompts(prompt, prompts_txt_file_path) prompt = fix_prompts(prompt, batch_size, full_batch_size) if prompt_to_lora_id_mapping is not None: @@ -387,9 +381,7 @@ def cloud_ai_100_exec_kv( ) generate_text = TextGeneration( tokenizer=tokenizer, - processor=processor, - lang_qpc_path=lang_qpc_path, - vision_qpc_path=vision_qpc_path, + qpc_path=qpc_path, device_id=device_id, ctx_len=ctx_len, enable_debug_logs=enable_debug_logs, @@ -404,12 +396,7 @@ def cloud_ai_100_exec_kv( for _ in range(0, int(iteration)): if full_batch_size is None: exec_info = [ - generate_text.generate( - prompt=prompt[i : i + batch_size], - generation_len=generation_len, - stream=stream, - prompt_to_lora_id_mapping=prompt_to_lora_id_mapping, - ) + generate_text.generate(prompt[i : i + batch_size], generation_len, stream, prompt_to_lora_id_mapping) for i in range(0, len(prompt), batch_size) ] prefill_time = np.average([info.perf_metrics.prefill_time for info in exec_info]) @@ -427,10 +414,7 @@ def cloud_ai_100_exec_kv( ) else: exec_info = generate_text.generate( - prompt=prompt, - images=images, - generation_len=generation_len, - prompt_to_lora_id_mapping=prompt_to_lora_id_mapping, + prompt=prompt, generation_len=generation_len, prompt_to_lora_id_mapping=prompt_to_lora_id_mapping ) print_latency_stats_kv(prompt, exec_info=exec_info, automation=automation) @@ -443,9 +427,7 @@ class QEffTextGenerationBase: def __init__( self, tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], - lang_qpc_path: str, - processor: Optional[AutoImageProcessor] = None, - vision_qpc_path: Optional[str] = None, + qpc_path: str, full_batch_size: Optional[int] = None, ctx_len: Optional[int] = None, device_id: Optional[List[int]] = None, @@ -455,25 +437,21 @@ def __init__( include_sampler: bool = False, return_pdfs: bool = False, sampling_params: Optional[Dict[str, Any]] = None, + activate: bool = True, ) -> None: self._ctx_len = ctx_len self._write_io_dir = write_io_dir self.is_tlm = is_tlm self.return_pdfs = return_pdfs self.sampling_params = sampling_params + self._qpc_path = qpc_path # Store qpc_path for later use # Load QPC - self._lang_session = None - self._vision_session = None - if not lang_qpc_path: - raise TypeError("Please run compile API for language model first!") - self._lang_session = QAICInferenceSession(lang_qpc_path, device_id, activate=False) - if vision_qpc_path: - self._vision_session = QAICInferenceSession(vision_qpc_path, device_id, activate=False) + self._session = QAICInferenceSession(qpc_path, device_id, activate=activate, enable_debug_logs=enable_debug_logs) # Validate sampler inputs for On-Device Sampling self.include_sampler = validate_sampler_inputs( - session_inputs=set(self._lang_session.input_names), include_sampler=include_sampler + session_inputs=set(self._session.input_names), include_sampler=include_sampler ) # Fetch the variables from the QPC @@ -498,23 +476,10 @@ def __init__( self.generation_len = None self.tokenizer = tokenizer - self.processor = processor self._set_tokenizer_params() # set tokenizer 
params # Skip inputs/outputs - if self._vision_session: - self._vision_session.skip_buffers( - [ - x - for x in self._vision_session.input_names + self._vision_session.output_names - if x.startswith("past_") or x.endswith("_RetainedState") - ] - ) - self._lang_session.skip_buffers( - [ - x - for x in self._lang_session.input_names + self._lang_session.output_names - if x.startswith("past_") or x.endswith("_RetainedState") - ] + self._session.skip_buffers( + [x for x in self._session.input_names + self._session.output_names if x.startswith("past_")] ) def _set_tokenizer_params(self): @@ -539,16 +504,13 @@ def _fetch_full_batch_size( """ full_batch_size = None - if "batch_index" in self._lang_session.binding_index_map: - if self._lang_session.allowed_shapes: + if "batch_index" in self._session.binding_index_map: + if self._session.allowed_shapes: full_batch_size, _ = [ - x[self._lang_session.binding_index_map["batch_index"]][1][0] - for x in self._lang_session.allowed_shapes + x[self._session.binding_index_map["batch_index"]][1][0] for x in self._session.allowed_shapes ] else: - full_batch_size, _ = self._lang_session.bindings[ - self._lang_session.binding_index_map["batch_index"] - ].dims + full_batch_size, _ = self._session.bindings[self._session.binding_index_map["batch_index"]].dims return full_batch_size def _fetch_batch_size_prefill_seq_len( @@ -561,17 +523,15 @@ def _fetch_batch_size_prefill_seq_len( batch_size: The batch size fetched from the session's bindings or allowed shapes. prefill_seq_len: The prefill sequence length fetched from the session's bindings or allowed shapes. """ - if self._lang_session.allowed_shapes: + if self._session.allowed_shapes: batch_size = max( - [x[self._lang_session.binding_index_map["input_ids"]][1][0] for x in self._lang_session.allowed_shapes] + [x[self._session.binding_index_map["input_ids"]][1][0] for x in self._session.allowed_shapes] ) prefill_seq_len = max( - [x[self._lang_session.binding_index_map["input_ids"]][1][1] for x in self._lang_session.allowed_shapes] + [x[self._session.binding_index_map["input_ids"]][1][1] for x in self._session.allowed_shapes] ) else: - batch_size, prefill_seq_len = self._lang_session.bindings[ - self._lang_session.binding_index_map["input_ids"] - ].dims + batch_size, prefill_seq_len = self._session.bindings[self._session.binding_index_map["input_ids"]].dims return batch_size, prefill_seq_len def _fetch_decode_seq_len( @@ -584,9 +544,9 @@ def _fetch_decode_seq_len( decode_seq_len: The decode sequence length fetched from the session's bindings or allowed shapes. 
""" decode_seq_len = None - if self._lang_session.allowed_shapes: + if self._session.allowed_shapes: decode_seq_len = min( - [x[self._lang_session.binding_index_map["input_ids"]][1][1] for x in self._lang_session.allowed_shapes] + [x[self._session.binding_index_map["input_ids"]][1][1] for x in self._session.allowed_shapes] ) return decode_seq_len @@ -605,10 +565,10 @@ def _fetch_vocab_size( if self.include_sampler else "logits" ) - if self._lang_session.allowed_shapes: - return [x[self._lang_session.binding_index_map[key]] for x in self._lang_session.allowed_shapes][0][1][2] + if self._session.allowed_shapes: + return [x[self._session.binding_index_map[key]] for x in self._session.allowed_shapes][0][1][2] - return self._lang_session.bindings[self._lang_session.binding_index_map[key]].dims[2] + return self._session.bindings[self._session.binding_index_map[key]].dims[2] def _fetch_generation_len(self, generation_len, max_gen_len): """ @@ -744,7 +704,7 @@ def update_decode_input(self, outputs, position_ids, generation_len, decode_batc self.generation_len[decode_batch_id or slice(None)] = generation_len return next_token_id - def run_prefill_for_all_inputs(self, image_queue, prompt_queue, processor, generation_len): + def run_prefill_for_all_inputs(self, prompt_queue, generation_len): """ Runs prefill for all inputs in the prompt queue and updates the decode input. @@ -755,21 +715,14 @@ def run_prefill_for_all_inputs(self, image_queue, prompt_queue, processor, gener generation_len (int): The generation length. """ - next_prompt = None - next_image = None for decode_batch_id in range(self.full_batch_size): - if prompt_queue: - next_prompt = prompt_queue.popleft() - if image_queue: - next_image = image_queue.popleft() + next_prompt = prompt_queue.popleft() # run prefill for num_chunks + # NOTE: We pass decode_batch_id=None during prefill to ensure batch_index is not added + # The decode_batch_id is only used for updating decode inputs, not for prefill execution outputs, position_ids, generation_len = self.run_prefill( - next_prompt, - next_image, - processor, - generation_len, - decode_batch_id=np.array(decode_batch_id, dtype=np.int64).reshape(1, 1), + next_prompt, generation_len, decode_batch_id=None ) _ = self.update_decode_input(outputs, position_ids, generation_len, decode_batch_id) @@ -784,39 +737,14 @@ def _set_output_buffers(self, batch_size: int = 1, sequence_length: int = 1): if self.include_sampler: if self.return_pdfs: probs_out_placeholder = np.zeros((batch_size, sequence_length, self._vocab_size), dtype=np.float32) - self._lang_session.set_buffers({"probs": probs_out_placeholder}) + self._session.set_buffers({"probs": probs_out_placeholder}) next_tokens_out_placeholder = np.zeros((batch_size, sequence_length, 1), dtype=np.int64) - self._lang_session.set_buffers({"next_tokens": next_tokens_out_placeholder}) + self._session.set_buffers({"next_tokens": next_tokens_out_placeholder}) else: logits_out_placeholder = np.zeros((batch_size, sequence_length, self._vocab_size), dtype=np.float32) - self._lang_session.set_buffers({"logits": logits_out_placeholder}) - - def prepare_vision_language_inputs(self, processor, query, image_url): - image = Image.open(requests.get(image_url, stream=True).raw) - conversation = [ - { - "role": "user", - "content": [ - {"type": "text", "text": query}, - {"type": "image"}, - ], - }, - ] - prompt = processor.apply_chat_template(conversation, add_generation_prompt=True) - inputs = processor(images=image, text=prompt, return_tensors="pt") - if "pixel_values" in 
inputs: - inputs["pixel_values"] = inputs["pixel_values"].to(torch.float32) - return inputs + self._session.set_buffers({"logits": logits_out_placeholder}) - def run_prefill( - self, - prompt: str, - image: Optional[str] = None, - processor: Optional[AutoImageProcessor] = None, - generation_len: Optional[int] = None, - prefill_logit_bs=1, - decode_batch_id=None, - ): + def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_id=None): """ Runs prefill for a given prompt and generation length. @@ -833,17 +761,9 @@ def run_prefill( position_ids (array): The position IDs. generation_len (int): The generation length. """ - # Run prefill - if image: - inputs = self.prepare_vision_language_inputs(processor, prompt, image) - else: - inputs = self.tokenizer(prompt, return_tensors="np", padding=True) - - if "position_ids" in inputs: - position_ids = inputs["position_ids"] - else: - position_ids = inputs["attention_mask"].sum(1, keepdims=True) + inputs = self.tokenizer(prompt, return_tensors="np", padding=True) + position_ids = inputs["attention_mask"].sum(1, keepdims=True) padded_len = inputs["input_ids"].shape[1] num_chunks = -(padded_len // -self._prefill_seq_len) # ceil divide without float padded_len = num_chunks * self._prefill_seq_len # Convert to a multiple of prompt_len @@ -856,130 +776,60 @@ def run_prefill( # Set the prefill output buffers self._set_output_buffers(batch_size=prefill_logit_bs, sequence_length=1) - vision_inputs = {} - vision_outputs = {} - if image: - pad_token_id = 1 - input_ids_length = inputs["input_ids"].shape[1] - num_chunks = -(input_ids_length // -self._prefill_seq_len) # ceil divide without float - padded_len = num_chunks * self._prefill_seq_len # Convert to a multiple of prompt_len - - inputs["input_ids"] = torch.nn.functional.pad( - inputs["input_ids"], - (0, padded_len - input_ids_length), - "constant", - pad_token_id, - ) - inputs["attention_mask"] = torch.nn.functional.pad( - inputs["attention_mask"], (0, padded_len - input_ids_length), "constant", 0 - ) - if "cross_attention_mask" in inputs: - inputs["cross_attention_mask"] = torch.nn.functional.pad( - inputs["cross_attention_mask"], (0, 0, 0, 0, 0, padded_len - input_ids_length) - ) - - for k, v in inputs.items(): - inputs[k] = np.array(v) - - vision_inputs = { - k: v - for k, v in inputs.items() - if k - in { - "pixel_values", - "image_masks", - "image_input_idx", - "valid_idx", - "aspect_ratio_ids", - "aspect_ratio_mask", - } - } - # if vision_inputs: - # vision_inputs["pixel_values"] = vision_inputs["pixel_values"].astype("float16") - vision_inputs_fp16 = {"pixel_values", "image_masks"} - vision_inputs.update( - {k: vision_inputs[k].astype("float16") for k in vision_inputs_fp16 if k in vision_inputs} - ) - - # if not(self._lang_session.is_active): - # self._lang_session.activate() - # Run vision prefill - if vision_inputs: - # self._lang_session.pause() - self._vision_session.activate() - vision_outputs = self._vision_session.run(vision_inputs) - self._vision_session.deactivate() - # self._lang_session.resume() - else: - inputs = self.tokenizer(prompt, return_tensors="np", padding="max_length", max_length=padded_len) - inputs.pop("token_type_ids", None) - - lang_inputs = {k: v for k, v in inputs.items() if k not in vision_inputs} - if "position_ids" in inputs: - lang_inputs["position_ids"] = inputs["position_ids"] - lang_inputs.pop("attention_mask") - else: - lang_inputs["position_ids"] = np.where( - lang_inputs.pop("attention_mask"), np.arange(padded_len), -1 - ) # Need to use -1 as 
position_ids for invalid tokens - - # not_mllama = hasattr(self.model.config, "model_type") and self.model.config.model_type != "mllama" - # if not_mllama: - # lang_inputs["image_idx"] = np.array([[0]]) - if image: - lang_inputs["image_idx"] = np.array([[0]]) - - self._lang_session.activate() - self._lang_session.set_buffers(vision_outputs) + inputs = self.tokenizer(prompt, return_tensors="np", padding="max_length", max_length=padded_len) + inputs["position_ids"] = np.where(inputs.pop("attention_mask"), np.arange(padded_len), -1) + inputs.pop("token_type_ids", None) - if decode_batch_id is not None: - lang_inputs["batch_index"] = decode_batch_id + # Note: batch_index is only used during decode, not prefill + # During prefill in CB mode, we use the prefill specialization (batch_size=1, seq_len=128) + # which doesn't have batch_index parameter + # However, image_idx may be needed for VLM models during prefill if self.is_tlm: - lang_inputs["num_logits_to_keep"] = np.zeros((1, 1)) + inputs["num_logits_to_keep"] = np.zeros((1, 1)) if self.include_sampler: - lang_inputs["last_accepted_output_tokens"] = lang_inputs["input_ids"] + inputs["last_accepted_output_tokens"] = inputs["input_ids"] for op in Constants.SAMPLER_OPS: if decode_batch_id is not None: - lang_inputs[op] = self.sampling_params[op][decode_batch_id.flatten()] + inputs[op] = self.sampling_params[op][decode_batch_id.flatten()] else: - lang_inputs[op] = self.sampling_params[op] + inputs[op] = self.sampling_params[op] if self._prompt_to_lora_id_mapping_prefill: if self.full_batch_size: - lang_inputs["lora_ids"] = np.array( + inputs["lora_ids"] = np.array( self._prompt_to_lora_id_mapping_prefill.popleft(), dtype=np.int64 ).reshape(1, 1) else: batch_lora_ids = [self._prompt_to_lora_id_mapping_prefill.popleft() for i in range(self.batch_size)] - lang_inputs["lora_ids"] = np.array(batch_lora_ids, dtype=np.int64).reshape(self.batch_size, 1) + inputs["lora_ids"] = np.array(batch_lora_ids, dtype=np.int64).reshape(self.batch_size, 1) + + # Check if image_idx is expected by the session (for VLM models) + if "image_idx" in self._session.input_names or "image_idx" in getattr(self._session, 'binding_index_map', {}): + try: + binding_idx = self._session.binding_index_map.get("image_idx") + dims = self._session.bindings[binding_idx].dims if binding_idx is not None else (1, 1) + inputs["image_idx"] = np.zeros(tuple(dims), dtype=np.int64) + except Exception: + inputs["image_idx"] = np.array([[0]], dtype=np.int64) - # Run language prefill - chunk_inputs = lang_inputs.copy() for i in range(num_chunks): - chunk_inputs["input_ids"] = lang_inputs["input_ids"][ + chunk_inputs = inputs.copy() + chunk_inputs["input_ids"] = inputs["input_ids"][ :, i * self._prefill_seq_len : (i + 1) * self._prefill_seq_len ] - chunk_inputs["position_ids"] = lang_inputs["position_ids"][ - ..., i * self._prefill_seq_len : (i + 1) * self._prefill_seq_len + chunk_inputs["position_ids"] = inputs["position_ids"][ + :, i * self._prefill_seq_len : (i + 1) * self._prefill_seq_len ] if self.include_sampler: chunk_inputs["last_accepted_output_tokens"] = chunk_inputs["input_ids"] - outputs = self._lang_session.run(chunk_inputs) - if image: - chunk_inputs["image_idx"] = outputs["image_idx_output"] + outputs = self._session.run(chunk_inputs) + + # Update image_idx for next chunk if VLM model provides it + if "image_idx_output" in outputs: + inputs["image_idx"] = outputs["image_idx_output"] + if self._write_io_dir is not None: write_io_files(inputs, outputs, self._write_io_dir, "prefill", 
"aic_batch_io", True, False) - - # Skip inputs/outputs again - self._lang_session.skip_buffers( - [ - x - for x in self._lang_session.input_names + self._lang_session.output_names - if x.startswith("past_") or x.endswith("_RetainedState") - ] - ) - # self._lang_session.deactivate() - return ( outputs, position_ids, @@ -998,6 +848,15 @@ def run_continuous_batching_decode(self, prompt_queue, image_queue, processor, g """ + # Skip vision buffers during decode for vision-language models in CB mode + # Vision buffers don't have batch dimension and should only be used during prefill + if hasattr(self, '_vision_outputs'): + vision_buffer_names = [name for name in self._session.input_names + self._session.output_names + if 'vision' in name.lower() or 'pixel_values' in name.lower()] + if vision_buffer_names: + logger.debug(f"Skipping vision buffers during decode: {vision_buffer_names}") + self._session.skip_buffers(vision_buffer_names) + # Set output placeholders for decode self._set_output_buffers( batch_size=self.full_batch_size, @@ -1019,9 +878,8 @@ def run_continuous_batching_decode(self, prompt_queue, image_queue, processor, g next_prompt = None next_image = None - # self._lang_session.activate() # Due to activating new session (new exec_obj) run values are changing while prompt_queue or current_decode_ongoing.any(): - outputs = self._lang_session.run(decode_inputs) + outputs = self._session.run(decode_inputs) # Prepare inputs for next iteration next_token_id = self._fetch_next_token_id(outputs) @@ -1038,10 +896,8 @@ def run_continuous_batching_decode(self, prompt_queue, image_queue, processor, g start = perf_counter() # run prefill for next prompt input. outputs, position_ids, generation_len = self.run_prefill( - prompt=next_prompt, - image=next_image, - processor=processor, - generation_len=generation_len, + prompt_queue.popleft(), + generation_len, decode_batch_id=np.array(decode_batch_id, dtype=np.int64).reshape(1, 1), ) @@ -1076,8 +932,6 @@ def run_continuous_batching_decode(self, prompt_queue, image_queue, processor, g generated_id_current_index[decode_batch_id] += 1 - self._lang_session.deactivate() - return decode_pause_time def run_decode( @@ -1099,14 +953,13 @@ def run_decode( logits_out_placeholder = np.zeros( (self.batch_size, self._decode_seq_len, self._vocab_size), dtype=np.float32 ) - self._lang_session.set_buffers({"logits": logits_out_placeholder}) + self._session.set_buffers({"logits": logits_out_placeholder}) finished_sequences = decode_inputs["input_ids"] == self.tokenizer.eos_token_id num_token = 0 - self._lang_session.activate() for num_token in range(1, generation_len): if streamer: streamer.put(decode_inputs["input_ids"][0]) - outputs = self._lang_session.run(decode_inputs) + outputs = self._session.run(decode_inputs) if self._write_io_dir is not None: write_io_files(decode_inputs, outputs, self._write_io_dir, "decode", "aic_batch_io", True, False) @@ -1122,7 +975,6 @@ def run_decode( if finished_sequences.all() and not automation: break - self._lang_session.deactivate() return num_token def generate_decode_stream(self, decode_inputs, generation_len, automation): @@ -1139,10 +991,9 @@ def generate_decode_stream(self, decode_inputs, generation_len, automation): token_id (int): The token generated in the decoding process. 
""" finished_sequences = decode_inputs["input_ids"] == self.tokenizer.eos_token_id - self._lang_session.activate() for num_token in range(1, generation_len): yield decode_inputs["input_ids"] - outputs = self._lang_session.run(decode_inputs) + outputs = self._session.run(decode_inputs) if self._write_io_dir is not None: write_io_files(decode_inputs, outputs, self._write_io_dir, "decode", "aic_batch_io", True, False) @@ -1156,7 +1007,6 @@ def generate_decode_stream(self, decode_inputs, generation_len, automation): if finished_sequences.all() and not automation: break - self._lang_session.deactivate() yield decode_inputs["input_ids"] # yield the last token @@ -1164,9 +1014,7 @@ class TextGeneration: def __init__( self, tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], - lang_qpc_path: str, - processor: Optional[AutoImageProcessor] = None, - vision_qpc_path: Optional[str] = None, + qpc_path: str, full_batch_size: Optional[int] = None, ctx_len: Optional[int] = None, device_id: Optional[List[int]] = None, @@ -1179,9 +1027,7 @@ def __init__( ) -> None: self._qaic_model = QEffTextGenerationBase( tokenizer=tokenizer, - lang_qpc_path=lang_qpc_path, - processor=processor, - vision_qpc_path=vision_qpc_path, + qpc_path=qpc_path, full_batch_size=full_batch_size, ctx_len=ctx_len, device_id=device_id, @@ -1194,11 +1040,9 @@ def __init__( ) self._full_batch_size = self._qaic_model.full_batch_size self._tokenizer = self._qaic_model.tokenizer - self._processor = self._qaic_model.processor self._ctx_len = ctx_len self._perf_metrics = None self._prompt_queue = None - self._image_queue = None self._text_streamer = None @property @@ -1208,7 +1052,6 @@ def perf_metrics(self): def _setup_model_execution_inputs( self, prompt: List[str], - images: Optional[List[str]] = None, generation_len: Optional[int] = None, prompt_to_lora_id_mapping: Optional[List[int]] = None, ): @@ -1226,8 +1069,6 @@ def _setup_model_execution_inputs( # Create a prompt queue. self._prompt_queue = deque(prompt) - if images: - self._image_queue = deque(images) # Initialize np arrays for storing the prefill output for all the decode batch size. num_prompts = len(self._prompt_queue) @@ -1257,14 +1098,12 @@ def _regular_model_execution( :tuple: A tuple containing performance metrics and generated texts. """ - self._setup_model_execution_inputs( - prompt=prompt, generation_len=generation_len, prompt_to_lora_id_mapping=prompt_to_lora_id_mapping - ) + self._setup_model_execution_inputs(prompt, generation_len, prompt_to_lora_id_mapping) if stream and self._text_streamer is None: self._text_streamer = transformers.TextStreamer(self._tokenizer) start = perf_counter() outputs, position_ids, generation_len = self._qaic_model.run_prefill( - prompt=prompt, generation_len=generation_len, prefill_logit_bs=self._qaic_model.batch_size + prompt, generation_len, prefill_logit_bs=self._qaic_model.batch_size ) self._qaic_model.update_decode_input(outputs, position_ids, generation_len) @@ -1285,7 +1124,6 @@ def _regular_model_execution( def _continuous_batching_execution( self, prompt: List[str], - images: Optional[List[str]] = None, generation_len: Optional[int] = None, prompt_to_lora_id_mapping: Optional[List[int]] = None, ): @@ -1301,12 +1139,13 @@ def _continuous_batching_execution( Returns: :tuple: A tuple containing performance metrics and generated texts. 
""" - self._setup_model_execution_inputs(prompt, images, generation_len, prompt_to_lora_id_mapping) - self._qaic_model.batch_index = np.arange(self._full_batch_size).reshape(-1, 1) + self._setup_model_execution_inputs(prompt, generation_len, prompt_to_lora_id_mapping) start = perf_counter() - self._qaic_model.run_prefill_for_all_inputs( - self._image_queue, self._prompt_queue, self._processor, generation_len - ) + # Run prefill for all inputs first (batch_index should NOT be set during prefill) + self._qaic_model.run_prefill_for_all_inputs(self._prompt_queue, generation_len) + + # Now set batch_index for decode phase + self._qaic_model.batch_index = np.arange(self._full_batch_size).reshape(-1, 1) loop_start = perf_counter() # Start decode loop timer decode_pause_time = self._qaic_model.run_continuous_batching_decode( @@ -1352,7 +1191,7 @@ def generate_stream_tokens( self._setup_model_execution_inputs(prompt, generation_len, prompt_to_lora_id_mapping) start = perf_counter() outputs, position_ids, generation_len = self._qaic_model.run_prefill( - prompt=prompt, generation_len=generation_len, prefill_logit_bs=self._qaic_model.batch_size + prompt, generation_len, prefill_logit_bs=self._qaic_model.batch_size ) self._qaic_model.update_decode_input(outputs, position_ids, generation_len) @@ -1377,7 +1216,6 @@ def generate_stream_tokens( def generate( self, prompt: List[str], - images: Optional[List[str]] = None, generation_len: Optional[int] = None, stream: bool = True, automation: Optional[bool] = False, @@ -1398,7 +1236,7 @@ def generate( if self._full_batch_size is not None: logger.warning("Streamer is currently unavailable for continuous batch execution.") perf_metrics, generated_texts = self._continuous_batching_execution( - prompt, images, generation_len, prompt_to_lora_id_mapping + prompt, generation_len, prompt_to_lora_id_mapping ) else: if stream: diff --git a/QEfficient/generation/vision_handler.py b/QEfficient/generation/vision_handler.py new file mode 100644 index 000000000..5a67d237d --- /dev/null +++ b/QEfficient/generation/vision_handler.py @@ -0,0 +1,395 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +""" +Vision Handler for Vision-Language Models + +This module provides the VisionHandler class that encapsulates all vision model +operations, separating them from the main text generation logic. +""" + +import requests +import torch +import numpy as np +from typing import Any, Dict, List, Optional, Tuple, Union +from PIL import Image +from transformers import AutoImageProcessor + +from QEfficient.generation.cloud_infer import QAICInferenceSession +from QEfficient.utils.logging_utils import logger + + +class VisionHandler: + """ + Handles all vision model operations for vision-language models. + + This class encapsulates vision preprocessing, inference, and output handling, + providing a clean separation between vision and language processing. 
+ """ + + def __init__( + self, + vision_session: Optional[QAICInferenceSession], + processor: Optional[AutoImageProcessor], + config: Optional[Dict[str, Any]] = None, + lang_session: Optional[QAICInferenceSession] = None + ): + """ + Initialize vision handler + + Args: + vision_session: QAICInferenceSession for vision model + processor: AutoImageProcessor for image preprocessing + config: Configuration dictionary with vision model parameters + lang_session: Optional language session for coordination (to avoid resource conflicts) + """ + self._vision_session = vision_session + self._processor = processor + self._config = config or {} + self._lang_session = lang_session # Store language session for coordination + + # Cache for vision output shapes + self._vision_output_shapes = None + + if self._vision_session and not self._processor: + logger.warning("Vision session provided but no processor. Vision functionality may be limited.") + + def is_available(self) -> bool: + """ + Check if vision processing is available + + Returns: + True if both vision session and processor are available + """ + return self._vision_session is not None and self._processor is not None + + def prepare_vision_inputs(self, image_url: str, query: str) -> Dict[str, np.ndarray]: + """ + Download and preprocess image into model inputs + + Args: + image_url: URL or path to image + query: Text query to process with image + + Returns: + Dictionary of vision model inputs + + Raises: + ValueError: If vision handler is not properly initialized + RuntimeError: If image processing fails + """ + if not self.is_available(): + raise ValueError("Vision handler not properly initialized. Need both vision_session and processor.") + + try: + # Download image + if image_url.startswith(('http://', 'https://')): + image = Image.open(requests.get(image_url, stream=True).raw) + else: + image = Image.open(image_url) + + # Prepare conversation format + conversation = [ + { + "role": "user", + "content": [ + {"type": "text", "text": query}, + {"type": "image"}, + ], + }, + ] + + # Apply chat template + prompt = self._processor.apply_chat_template(conversation, add_generation_prompt=True) + + # Process image and text + inputs = self._processor(images=image, text=prompt, return_tensors="pt") + + # Convert to float32 if needed + if "pixel_values" in inputs: + inputs["pixel_values"] = inputs["pixel_values"].to(torch.float32) + + # Convert to numpy arrays + vision_inputs = {} + for k, v in inputs.items(): + if k in { + "pixel_values", + "image_masks", + "image_input_idx", + "valid_idx", + "aspect_ratio_ids", + "aspect_ratio_mask" + }: + vision_inputs[k] = np.array(v) + + # Convert specific inputs to float16 + vision_inputs_fp16 = {"pixel_values", "image_masks"} + for k in vision_inputs_fp16: + if k in vision_inputs: + vision_inputs[k] = vision_inputs[k].astype("float16") + + return vision_inputs + + except Exception as e: + raise RuntimeError(f"Failed to process image {image_url}: {str(e)}") + + def run_vision_inference(self, vision_inputs: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]: + """ + Execute vision model inference with session coordination + + Args: + vision_inputs: Preprocessed vision inputs + + Returns: + Vision embeddings and metadata + + Raises: + ValueError: If vision session is not available + RuntimeError: If inference fails + """ + if not self._vision_session: + raise ValueError("Vision session not available") + + lang_was_active = False + try: + # Coordinate with language session to avoid resource conflicts + if 
self._lang_session and self._lang_session.is_active: + logger.debug("Deactivating language session before vision inference") + self._lang_session.deactivate() + lang_was_active = True + + # Activate vision session + logger.debug("Activating vision session for inference") + self._vision_session.activate() + + # Run inference + vision_outputs = self._vision_session.run(vision_inputs) + + # Deactivate vision session + logger.debug("Deactivating vision session after inference") + self._vision_session.deactivate() + + # Reactivate language session if it was active before + if lang_was_active and self._lang_session: + logger.debug("Reactivating language session after vision inference") + self._lang_session.activate() + + return vision_outputs + + except Exception as e: + # Ensure proper cleanup on error + if self._vision_session: + try: + self._vision_session.deactivate() + except: + pass + + # Restore language session if needed + if lang_was_active and self._lang_session: + try: + self._lang_session.activate() + except: + pass + + raise RuntimeError(f"Vision inference failed: {str(e)}") + + def get_vision_output_shapes(self) -> Dict[str, Tuple[int, ...]]: + """ + Get vision output dimensions from config or session + + Returns: + Dictionary mapping output names to shapes + """ + if self._vision_output_shapes is not None: + return self._vision_output_shapes + + # Try to get from config first + if self._config and "vision_output_shapes" in self._config: + self._vision_output_shapes = self._config["vision_output_shapes"] + return self._vision_output_shapes + + # Try to derive from vision session + if self._vision_session: + try: + shapes = {} + for output_name in self._vision_session.output_names: + if hasattr(self._vision_session, 'bindings') and output_name in self._vision_session.binding_index_map: + binding_idx = self._vision_session.binding_index_map[output_name] + if hasattr(self._vision_session.bindings[binding_idx], 'dims'): + shapes[output_name] = tuple(self._vision_session.bindings[binding_idx].dims) + + if shapes: + self._vision_output_shapes = shapes + return shapes + except Exception as e: + logger.warning(f"Could not derive vision output shapes from session: {e}") + + # Fallback to default shapes (these were hard-coded in original implementation) + default_shapes = { + "vision_embeds": (2448, 5120) # This should be derived from model config + } + + logger.warning("Using default vision output shapes. 
Consider providing shapes in config.") + self._vision_output_shapes = default_shapes + return default_shapes + + def setup_vision_buffers(self): + """ + Configure vision model output buffers + + Raises: + ValueError: If vision session is not available + """ + if not self._vision_session: + raise ValueError("Vision session not available") + + try: + shapes = self.get_vision_output_shapes() + + # Set up output buffers + buffers = {} + for output_name, shape in shapes.items(): + # Create placeholder with appropriate dtype + if "vision_embeds" in output_name: + buffers[output_name] = np.zeros(shape, dtype=np.float16) + else: + buffers[output_name] = np.zeros(shape, dtype=np.float32) + + self._vision_session.set_buffers(buffers) + + except Exception as e: + raise RuntimeError(f"Failed to setup vision buffers: {str(e)}") + + def prepare_complete_vision_language_inputs( + self, + image_url: str, + query: str + ) -> Tuple[Dict[str, np.ndarray], Dict[str, np.ndarray]]: + """ + Complete pipeline: prepare inputs and run vision inference + + Args: + image_url: URL or path to image + query: Text query + + Returns: + Tuple of (vision_inputs, vision_outputs) + """ + # Prepare vision inputs + vision_inputs = self.prepare_vision_inputs(image_url, query) + + # Setup buffers + self.setup_vision_buffers() + + # Run vision inference + vision_outputs = self.run_vision_inference(vision_inputs) + + return vision_inputs, vision_outputs + + def get_language_inputs_from_vision_processing( + self, + image_url: str, + query: str, + padded_len: int + ) -> Tuple[Dict[str, np.ndarray], Dict[str, np.ndarray]]: + """ + Process vision inputs and prepare language model inputs + + Args: + image_url: URL or path to image + query: Text query + padded_len: Padded sequence length for language model + + Returns: + Tuple of (language_inputs, vision_outputs) + """ + if not self.is_available(): + raise ValueError("Vision handler not properly initialized") + + try: + # Download and process image + if image_url.startswith(('http://', 'https://')): + image = Image.open(requests.get(image_url, stream=True).raw) + else: + image = Image.open(image_url) + + # Prepare conversation + conversation = [ + { + "role": "user", + "content": [ + {"type": "text", "text": query}, + {"type": "image"}, + ], + }, + ] + + prompt = self._processor.apply_chat_template(conversation, add_generation_prompt=True) + inputs = self._processor(images=image, text=prompt, return_tensors="pt") + + if "pixel_values" in inputs: + inputs["pixel_values"] = inputs["pixel_values"].to(torch.float32) + + # Handle padding for language model + pad_token_id = 1 + input_ids_length = inputs["input_ids"].shape[1] + + inputs["input_ids"] = torch.nn.functional.pad( + inputs["input_ids"], + (0, padded_len - input_ids_length), + "constant", + pad_token_id, + ) + inputs["attention_mask"] = torch.nn.functional.pad( + inputs["attention_mask"], (0, padded_len - input_ids_length), "constant", 0 + ) + + if "cross_attention_mask" in inputs: + inputs["cross_attention_mask"] = torch.nn.functional.pad( + inputs["cross_attention_mask"], (0, 0, 0, 0, 0, padded_len - input_ids_length) + ) + + # Convert to numpy + for k, v in inputs.items(): + inputs[k] = np.array(v) + + # Separate vision and language inputs + vision_inputs = { + k: v for k, v in inputs.items() + if k in { + "pixel_values", + "image_masks", + "image_input_idx", + "valid_idx", + "aspect_ratio_ids", + "aspect_ratio_mask" + } + } + + # Convert vision inputs to appropriate dtypes + vision_inputs_fp16 = {"pixel_values", "image_masks"} + 
for k in vision_inputs_fp16: + if k in vision_inputs: + vision_inputs[k] = vision_inputs[k].astype("float16") + + # Run vision inference if we have vision inputs + vision_outputs = {} + if vision_inputs: + self.setup_vision_buffers() + vision_outputs = self.run_vision_inference(vision_inputs) + + # Prepare language inputs + lang_inputs = {k: v for k, v in inputs.items() if k not in vision_inputs} + lang_inputs["position_ids"] = np.where( + lang_inputs.pop("attention_mask"), np.arange(padded_len), -1 + ) + lang_inputs["image_idx"] = np.array([[0]]) + + return lang_inputs, vision_outputs + + except Exception as e: + raise RuntimeError(f"Failed to process vision-language inputs: {str(e)}") diff --git a/QEfficient/generation/vlm_generation.py b/QEfficient/generation/vlm_generation.py new file mode 100644 index 000000000..33b918b70 --- /dev/null +++ b/QEfficient/generation/vlm_generation.py @@ -0,0 +1,633 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +""" +This module provides the VisionLanguageGeneration class that inherits from +QEffTextGenerationBase, enabling all advanced text generation features while +maintaining full API compatibility with the original VisionLanguageGeneration.bn hv$$ Z&F + +Key enhancements: +- Continuous batching support for vision models +- Advanced streaming capabilities +- On-device sampling support +- LoRA adapter support +- Better performance metrics +""" + +from time import perf_counter +from typing import Any, Dict, List, Optional, Tuple, Union +from collections import deque + +import numpy as np +import transformers +from transformers import AutoImageProcessor, PreTrainedTokenizer, PreTrainedTokenizerFast + +from QEfficient.generation.cloud_infer import QAICInferenceSession +from QEfficient.generation.vision_handler import VisionHandler +from QEfficient.generation.text_generation_inference import ( + QEffTextGenerationBase, + TextGeneration, + CloudAI100ExecInfo, + PerfMetrics, + calculate_latency, + write_io_files, +) +from QEfficient.utils.constants import Constants +from QEfficient.utils.logging_utils import logger + + +class VisionLanguageGeneration(QEffTextGenerationBase): + """ + Enhanced vision-language generation class inheriting from QEffTextGenerationBase. + + This class maintains full API compatibility with VisionLanguageGeneration while + adding advanced features like continuous batching, streaming, and sampling. + + Example: + >>> # Drop-in replacement for VisionLanguageGeneration + >>> vlm = VisionLanguageGeneration( + ... tokenizer=tokenizer, + ... processor=processor, + ... lang_qpc_path="path/to/lang.qpc", + ... vision_qpc_path="path/to/vision.qpc", + ... device_id=[0] + ... ) + >>> result = vlm.generate( + ... images=["image1.jpg"], + ... prompts=["Describe this image"], + ... generation_len=512 + ... ) + + >>> # Enhanced usage with new features + >>> vlm_enhanced = VisionLanguageGeneration( + ... tokenizer=tokenizer, + ... processor=processor, + ... lang_qpc_path="path/to/lang.qpc", + ... vision_qpc_path="path/to/vision.qpc", + ... device_id=[0], + ... full_batch_size=8, # Enable continuous batching + ... include_sampler=True, # Enable on-device sampling + ... sampling_params=sampling_config + ... 
) + """ + + def __init__( + self, + tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], + processor: AutoImageProcessor, + lang_qpc_path: str, + vision_qpc_path: str, + device_id: Optional[List[int]] = None, + ctx_len: Optional[int] = None, + enable_debug_logs: bool = False, + write_io_dir: Optional[str] = None, + full_batch_size: Optional[int] = None, + is_tlm: bool = False, + include_sampler: bool = False, + return_pdfs: bool = False, + sampling_params: Optional[Dict[str, Any]] = None, + ): + """ + Initialize vision-language generation with enhanced capabilities + + Args: + tokenizer: Text tokenizer + processor: Image processor + lang_qpc_path: Path to language model QPC + vision_qpc_path: Path to vision encoder QPC + device_id: Device IDs for execution (default: [0]) + ctx_len: Context length + enable_debug_logs: Enable debug logging + write_io_dir: Directory for I/O file writing + full_batch_size: Enable continuous batching (new feature) + is_tlm: Target language model flag + include_sampler: Enable on-device sampling (new feature) + return_pdfs: Return probability distributions + sampling_params: Sampling parameters for on-device sampling + """ + # Validate required parameters + if not lang_qpc_path: + raise TypeError("lang_qpc_path is required") + if not vision_qpc_path: + raise TypeError("vision_qpc_path is required") + + # Initialize base class with language QPC + # Pass activate=False to prevent premature activation before vision components are ready + super().__init__( + tokenizer=tokenizer, + qpc_path=lang_qpc_path, + full_batch_size=full_batch_size, + ctx_len=ctx_len, + device_id=device_id, + enable_debug_logs=enable_debug_logs, + write_io_dir=write_io_dir, + is_tlm=is_tlm, + include_sampler=include_sampler, + return_pdfs=return_pdfs, + sampling_params=sampling_params, + activate=False, #vision components need to be initialized first + ) + + # Vision-specific initialization + self.processor = processor + self._vision_qpc_path = vision_qpc_path + self.device_id = device_id # Store device_id for vision components + self.enable_debug_logs = enable_debug_logs # Store for vision components + self._init_vision_components() + + # Now that vision components are initialized, activate the text session + self._session.activate() + + logger.info(f"VisionLanguageGeneration initialized: batch_size={self.batch_size}, " + f"prefill_seq_len={self._prefill_seq_len}, ctx_len={ctx_len}, " + f"continuous_batching={'enabled' if full_batch_size else 'disabled'}, " + f"sampling={'enabled' if include_sampler else 'disabled'}") + + def _init_vision_components(self): + """Initialize vision-specific components""" + # Vision session (separate from base class language session) + self._vision_session = QAICInferenceSession( + self._vision_qpc_path, + self.device_id, + activate=False, + enable_debug_logs=self.enable_debug_logs + ) + + # Vision handler with language session coordination + vision_config = self._get_vision_config() + self._vision_handler = VisionHandler( + vision_session=self._vision_session, + processor=self.processor, + config=vision_config, + lang_session=self._session # Pass language session for coordination + ) + + # Setup vision buffer skipping + self._setup_vision_buffer_skipping() + + def _get_vision_config(self) -> Dict[str, Any]: + """ + Derive vision config from session + + Returns: + Dictionary with vision configuration + """ + config = {} + if self._vision_session: + try: + shapes = {} + for output_name in self._vision_session.output_names: + if 
(hasattr(self._vision_session, 'bindings') and + output_name in self._vision_session.binding_index_map): + binding_idx = self._vision_session.binding_index_map[output_name] + if hasattr(self._vision_session.bindings[binding_idx], 'dims'): + shapes[output_name] = tuple( + self._vision_session.bindings[binding_idx].dims + ) + + if shapes: + config["vision_output_shapes"] = shapes + except Exception as e: + logger.warning(f"Could not derive vision config from session: {e}") + + return config + + def _setup_vision_buffer_skipping(self): + """Skip KV cache and retained state buffers for vision session""" + skip_patterns = [ + lambda x: x.startswith("past_"), + lambda x: x.endswith("_RetainedState") + ] + + buffers_to_skip = [ + x for x in self._vision_session.input_names + self._vision_session.output_names + if any(pattern(x) for pattern in skip_patterns) + ] + self._vision_session.skip_buffers(buffers_to_skip) + + def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_id=None): + """ + Override base class prefill to handle vision processing + + Args: + prompt: Can be string or tuple (image_path, text_prompt) + generation_len: Generation length + prefill_logit_bs: Prefill batch size + decode_batch_id: Batch ID for continuous batching + + Returns: + Same as base class: (outputs, position_ids, generation_len) + """ + # Normalize prompt: TextGeneration passes a list even for batch_size=1 + if isinstance(prompt, list) and len(prompt) > 0 and isinstance(prompt[0], tuple) and len(prompt[0]) == 2: + # Unwrap single (image_path, text_prompt) tuple + if len(prompt) == 1: + prompt = prompt[0] + else: + raise NotImplementedError( + "VisionLanguageGeneration.run_prefill currently supports a single (image, text) pair per call." + ) + # Check if this is a vision-language prompt + if isinstance(prompt, tuple) and len(prompt) == 2: + image_path, text_prompt = prompt + + # Process vision inputs if this is the first time or if vision hasn't been processed yet + if not hasattr(self, '_vision_processed') or not self._vision_processed: + logger.debug("Processing vision inputs for the first time") + vision_inputs = self._vision_handler.prepare_vision_inputs(image_path, text_prompt) + self._vision_handler.setup_vision_buffers() + vision_outputs = self._vision_handler.run_vision_inference(vision_inputs) + + # Set vision buffers in language session (shared across all batch indices) + self._session.set_buffers(vision_outputs) + logger.debug(f"Vision buffers set: {list(vision_outputs.keys())}") + + # Mark vision as processed + self._vision_processed = True + self._vision_outputs = vision_outputs + + # Prepare the text prompt with vision context + processed_prompt = self._prepare_vision_language_prompt(text_prompt, image_path) + + # Compute padded_len aligned to prefill_seq_len using tokenizer on processed prompt + tmp_tokens = self.tokenizer(processed_prompt, return_tensors="np", padding=True) + input_len = tmp_tokens["input_ids"].shape[1] + num_chunks = -(input_len // -self._prefill_seq_len) + padded_len = num_chunks * self._prefill_seq_len + + # Build language inputs with processor-aware vision/text integration + lang_inputs, vision_outputs = self._vision_handler.get_language_inputs_from_vision_processing( + image_url=image_path, + query=text_prompt, + padded_len=padded_len + ) + + # Set vision buffers in language session + self._session.set_buffers(vision_outputs) + + # Calculate generation_len consistent with ctx_len + max_gen_len = self._ctx_len - np.where(lang_inputs["position_ids"] != -1, 
1, 0).sum(1, keepdims=True).max() + generation_len = self._fetch_generation_len(generation_len, max_gen_len) + + # Set the prefill output buffers + self._set_output_buffers(batch_size=prefill_logit_bs, sequence_length=1) + + # Run prefill across chunks, updating image_idx as needed + outputs = None + chunk_image_idx = None # track image_idx across chunks + for i in range(num_chunks): + input_ids_slice = lang_inputs["input_ids"][ + :, i * self._prefill_seq_len : (i + 1) * self._prefill_seq_len + ] + position_ids_slice = lang_inputs["position_ids"][ + :, i * self._prefill_seq_len : (i + 1) * self._prefill_seq_len + ] + # Build minimal input set to avoid unintended buffers (e.g., batch_index) during prefill + chunk_inputs = { + "input_ids": input_ids_slice, + "position_ids": position_ids_slice, + "image_idx": chunk_image_idx if chunk_image_idx is not None else np.array([[0]], dtype=np.int64), + } + if "cross_attention_mask" in lang_inputs: + chunk_inputs["cross_attention_mask"] = lang_inputs["cross_attention_mask"] + + outputs = self._session.run(chunk_inputs) + + # Update image_idx for next chunk if provided by model + if "image_idx_output" in outputs: + chunk_image_idx = outputs["image_idx_output"] + + if self._write_io_dir is not None: + write_io_files(lang_inputs, outputs, self._write_io_dir, "prefill", "aic_batch_io", True, False) + + # Prepare position_ids for decode phase (next position after prefill) + position_ids_decode = np.max(lang_inputs["position_ids"], axis=-1, keepdims=True) + 1 + + # Prepare decode-time cross_attention_mask (ones over image tiles) if available + if "cross_attention_mask" in lang_inputs: + bs, _, num_images, img_tiles = lang_inputs["cross_attention_mask"].shape + self._decode_cross_attention_mask = np.ones((bs, 1, num_images, img_tiles), dtype=np.int64) + else: + self._decode_cross_attention_mask = None + + # Skip retained_state and past_ buffers before decode for dual-QPC coordination + self._session.skip_buffers( + [ + x + for x in self._session.input_names + self._session.output_names + if x.startswith("past_") or x.endswith("_RetainedState") + ] + ) + + return outputs, position_ids_decode, generation_len + else: + # Fall back to base class for text-only + return super().run_prefill(prompt, generation_len, prefill_logit_bs, decode_batch_id) + + def _prepare_vision_language_prompt(self, text_prompt, image_path): + """ + Prepare text prompt with vision context + + This method handles the integration of vision and text inputs + according to the specific model's requirements. + """ + # For most vision-language models, we need to apply the chat template + # that includes both image and text components + try: + conversation = [ + { + "role": "user", + "content": [ + {"type": "text", "text": text_prompt}, + {"type": "image"}, + ], + }, + ] + + # Apply chat template + processed_prompt = self.processor.apply_chat_template( + conversation, + add_generation_prompt=True + ) + + return processed_prompt + + except Exception as e: + logger.warning(f"Failed to apply chat template: {e}. 
Using original prompt.") + return text_prompt + + def generate( + self, + images: List[str], + prompts: List[str], + generation_len: Optional[int] = None, + stream: bool = True, + **kwargs + ) -> CloudAI100ExecInfo: + """ + Main generation method maintaining API compatibility with VisionLanguageGeneration + + Args: + images: List of image URLs/paths + prompts: List of text prompts + generation_len: Max generation length + stream: Enable streaming output + **kwargs: Additional arguments passed to base class + + Returns: + CloudAI100ExecInfo with results and metrics + + Raises: + ValueError: If images and prompts lengths don't match + """ + if len(images) != len(prompts): + raise ValueError( + f"Number of images ({len(images)}) must match number of prompts ({len(prompts)})" + ) + + logger.info(f"Generating for {len(images)} image-prompt pairs") + + # Convert to base class format: list of (image, prompt) tuples + vision_prompts = [(img, prompt) for img, prompt in zip(images, prompts)] + + # Use base class generate method with vision prompts + if self.full_batch_size is not None: + # Continuous batching mode (new capability) + return self._generate_continuous_batching(vision_prompts, generation_len, stream, **kwargs) + else: + # Regular batching mode + return self._generate_regular_batching(vision_prompts, generation_len, stream, **kwargs) + + def _generate_regular_batching(self, vision_prompts, generation_len, stream, **kwargs): + """Handle regular batching for vision-language generation without creating a second language session""" + batch_results = [] + for i in range(0, len(vision_prompts), self.batch_size): + batch = vision_prompts[i : i + self.batch_size] + + if stream: + print(f"\nProcessing batch {i // self.batch_size + 1}/{(len(vision_prompts) - 1) // self.batch_size + 1}") + for j, (img, prompt) in enumerate(batch): + print(f"Image: {img}") + print(f"Prompt: {prompt}") + print("Completion:", flush=True, end="") + + # Setup decode storage arrays for this batch (use ctx_len or generation_len whichever is larger) + exec_batch_size = self.batch_size + max_gen_length = self._ctx_len if not generation_len else max(self._ctx_len, generation_len) + self.initialize_decode_inputs(num_prompts=len(batch), execution_batch_size=exec_batch_size, max_gen_length=max_gen_length) + + # Prefill using VLM-aware run_prefill (batch is a list of (image, text)) + start = perf_counter() + outputs, position_ids, generation_len_final = self.run_prefill(batch, generation_len, prefill_logit_bs=self.batch_size) + self.update_decode_input(outputs, position_ids, generation_len_final) + + # Prepare decode + decode_inputs = self.prepare_decode_inputs() + + # Decode loop + loop_start = perf_counter() + num_token = self.run_decode(decode_inputs, generation_len_final, automation=False, streamer=None) + end = perf_counter() + + # Decode generated texts + generated_texts = self.tokenizer.batch_decode(self.generated_ids, skip_special_tokens=True) + + # Latency metrics + total_decode_tokens = num_token + prefill_time, decode_perf, total_perf, total_time = calculate_latency(total_decode_tokens, loop_start, start, end) + perf_metrics = PerfMetrics(prefill_time, decode_perf, total_perf, total_time) + + # Package result for this batch + batch_results.append( + CloudAI100ExecInfo( + batch_size=self.batch_size, + generated_texts=generated_texts, + generated_ids=self.generated_ids, + perf_metrics=perf_metrics, + ) + ) + + # Aggregate results across batches + return self._aggregate_batch_results(batch_results) + + def 
_generate_continuous_batching(self, vision_prompts, generation_len, stream, **kwargs): + """Enable continuous batching for vision-language models (new capability)""" + logger.info("Using continuous batching for vision-language generation") + + if stream: + logger.warning("Streaming output not fully supported with continuous batching") + + # Reset vision processing state for new generation + self._vision_processed = False + self._vision_outputs = None + + # Use the base class continuous batching logic directly + # Initialize decode inputs + num_prompts = len(vision_prompts) + execution_batch_size = self.full_batch_size + max_gen_length = self._ctx_len if not generation_len else max(self._ctx_len, generation_len) + + self.initialize_decode_inputs(num_prompts, execution_batch_size, max_gen_length) + + # Create prompt queue + prompt_queue = deque(vision_prompts) + + start = perf_counter() + + # IMPORTANT: Ensure batch_index is None during prefill phase + # Store the current batch_index and set it to None for prefill + saved_batch_index = self.batch_index + self.batch_index = None + + # Run prefill for all inputs first (batch_index should NOT be set during prefill) + self.run_prefill_for_all_inputs(prompt_queue, generation_len) + + # Now set batch_index for decode phase + self.batch_index = np.arange(self.full_batch_size).reshape(-1, 1) + + loop_start = perf_counter() + decode_pause_time = self.run_continuous_batching_decode(prompt_queue, generation_len) + end = perf_counter() + + generated_texts = self.tokenizer.batch_decode(self.generated_ids, skip_special_tokens=True) + + total_decode_tokens = sum( + np.sum(self.generated_ids[i] != self.tokenizer.pad_token_id) - 1 for i in range(len(vision_prompts)) + ) + prefill_time, decode_perf, total_perf, total_time = calculate_latency( + total_decode_tokens, loop_start, start, end, decode_pause_time + ) + prefill_time /= len(vision_prompts) # Average prefill time for continuous batching + + perf_metrics = PerfMetrics(prefill_time, decode_perf, total_perf, total_time) + + return CloudAI100ExecInfo( + batch_size=1, + generated_texts=generated_texts, + generated_ids=self.generated_ids, + perf_metrics=perf_metrics + ) + + def prepare_decode_inputs(self): + """ + Override base class to handle vision-specific decode inputs + """ + decode_inputs = super().prepare_decode_inputs() + + # Add image_idx for vision-language models in CB mode during decode only + if self.batch_index is not None and hasattr(self, '_vision_outputs'): + # image_idx should point to vision embedding (0 for all batch positions) + # since vision embeddings are shared across all batch indices + decode_inputs["image_idx"] = np.zeros_like(self.batch_index) + + # Include cross_attention_mask during decode if present/required + if hasattr(self, '_decode_cross_attention_mask') and self._decode_cross_attention_mask is not None: + decode_inputs["cross_attention_mask"] = self._decode_cross_attention_mask + + return decode_inputs + + def _aggregate_batch_results(self, batch_results): + """Aggregate results from multiple batches""" + if not batch_results: + raise ValueError("No batch results to aggregate") + + if len(batch_results) == 1: + return batch_results[0] + + # Aggregate multiple batch results + all_generated_texts = [] + all_generated_ids = [] + all_metrics = [] + + for result in batch_results: + if isinstance(result.generated_texts[0], list): + # Flatten nested lists + all_generated_texts.extend([text for batch in result.generated_texts for text in batch]) + else: + 
all_generated_texts.extend(result.generated_texts) + + if isinstance(result.generated_ids, list): + all_generated_ids.extend(result.generated_ids) + else: + all_generated_ids.append(result.generated_ids) + + all_metrics.append(result.perf_metrics) + + # Average metrics + avg_metrics = PerfMetrics( + prefill_time=np.mean([m.prefill_time for m in all_metrics]), + decode_perf=np.mean([m.decode_perf for m in all_metrics]), + total_perf=np.mean([m.total_perf for m in all_metrics]), + total_time=np.mean([m.total_time for m in all_metrics]) + ) + + return CloudAI100ExecInfo( + batch_size=batch_results[0].batch_size, + generated_texts=all_generated_texts, + generated_ids=all_generated_ids, + perf_metrics=avg_metrics + ) + + def generate_stream_tokens( + self, + images: List[str], + prompts: List[str], + generation_len: Optional[int] = None, + **kwargs + ): + """ + Enable token-by-token streaming for vision models (new capability) + + Args: + images: List of image URLs/paths + prompts: List of text prompts + generation_len: Max generation length + **kwargs: Additional arguments + + Yields: + List of decoded tokens for each batch position + + Raises: + NotImplementedError: If continuous batching is enabled + """ + if self.full_batch_size is not None: + raise NotImplementedError("Token streaming not supported with continuous batching for VLM") + + if len(images) != len(prompts): + raise ValueError( + f"Number of images ({len(images)}) must match number of prompts ({len(prompts)})" + ) + + logger.info(f"Starting token streaming for {len(images)} image-prompt pairs") + + vision_prompts = [(img, prompt) for img, prompt in zip(images, prompts)] + + text_gen = TextGeneration( + tokenizer=self.tokenizer, + qpc_path=self._qpc_path, + ctx_len=self._ctx_len, + device_id=self.device_id, + enable_debug_logs=self.enable_debug_logs, + is_tlm=self.is_tlm, + include_sampler=self.include_sampler, + return_pdfs=self.return_pdfs, + sampling_params=self.sampling_params, + ) + + text_gen._qaic_model = self + + # Yield tokens as they're generated + for tokens in text_gen.generate_stream_tokens(vision_prompts, generation_len, **kwargs): + yield tokens + + def __repr__(self): + """String representation of the class""" + return (f"VisionLanguageGeneration(" + f"batch_size={self.batch_size}, " + f"ctx_len={self._ctx_len}, " + f"continuous_batching={'enabled' if self.full_batch_size else 'disabled'}, " + f"sampling={'enabled' if self.include_sampler else 'disabled'})") diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index eed3782db..24a3de9bd 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -36,6 +36,7 @@ calculate_latency, get_compilation_dims, ) +from QEfficient.generation.vlm_generation import VisionLanguageGeneration from QEfficient.transformers.modeling_utils import DYNAMIC_SEQ_LEN_SUPPORTED_MODEL_ARCH from QEfficient.transformers.models.pytorch_transforms import ( CustomOpsTransform, @@ -1188,6 +1189,14 @@ def generate( inputs : Dict[str, Union[torch.Tensor, np.ndarray]] Inputs to run the execution, typically includes `pixel_values`, `input_ids`, `attention_mask`, etc. + tokenizer : PreTrainedTokenizer or PreTrainedTokenizerFast, optional + Tokenizer for the model. Used when images and prompts are provided. + processor : AutoImageProcessor, optional + Processor for the model. Used when images and prompts are provided. 
+ images : List[str], optional + List of image paths or PIL images to process. + prompts : List[str], optional + List of text prompts corresponding to the images. streamer : TextStreamer, optional A streamer object to display generated tokens in real-time. Default is None. device_ids : List[int], optional @@ -1212,17 +1221,27 @@ def generate( if not runtime_ai100: raise NotImplementedError("PyTorch execution is not supported yet for this model!") + # Use VisionLanguageGeneration for image-prompt pairs if (processor and images) or (tokenizer and prompts): - return QEfficient.cloud_ai_100_exec_kv( - tokenizer=tokenizer, - processor=processor, + # Create VisionLanguageGeneration instance + vlm_gen = VisionLanguageGeneration( lang_qpc_path=self.lang_model.qpc_path, vision_qpc_path=self.vision_model.qpc_path, + tokenizer=tokenizer, + processor=processor, + device_id=device_ids, # if device_ids is not None else [0], + ctx_len=3072 # TODO need to get it from the QPC + ) + + # Call generate method + return vlm_gen.generate( images=images, - prompt=prompts, - device_id=device_ids, + prompts=prompts, generation_len=generation_len, + stream=streamer is not None, ) + + # Fallback to kv_offload_generate for direct inputs (backward compatibility) return self.kv_offload_generate( inputs=inputs, device_ids=device_ids, streamer=streamer, generation_len=generation_len ) diff --git a/examples/llama4_CB_example_vision_lang.py b/examples/llama4_CB_example_vision_lang.py index ebe65bf82..8711e0f40 100644 --- a/examples/llama4_CB_example_vision_lang.py +++ b/examples/llama4_CB_example_vision_lang.py @@ -63,3 +63,5 @@ device_ids=[0, 1, 2, 3], generation_len=100, ) + +print(output) From e663cd6f316edc983374a2d4fbb0924a04dabd7b Mon Sep 17 00:00:00 2001 From: Mohit Soni Date: Tue, 28 Oct 2025 04:52:10 +0000 Subject: [PATCH 10/24] Qwen2.5vl CB Update Signed-off-by: Mohit Soni Signed-off-by: Rishin Raj --- .../models/qwen2_5_vl/modeling_qwen2_5_vl.py | 20 +++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index 8d316211b..8df72d0c9 100644 --- a/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -763,7 +763,7 @@ def get_dummy_inputs(self, kv_offload: bool = False, continuous_batching: bool = # Add data for KV kv_cache_shape = get_padding_shape_from_config( - config=self.model.config, + config=self.model.config.text_config, batch_size=fbs if continuous_batching else bs, seq_len=constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN, ) @@ -883,6 +883,7 @@ def smart_resize( "seq_len": prefill_seq_len, "ctx_len": ctx_len, "vision_size": vision_size, + "vision_batch_size": batch_size, } if continuous_batching: @@ -897,6 +898,7 @@ def smart_resize( "seq_len": 1, "ctx_len": ctx_len, "vision_size": vision_size, + "vision_batch_size": batch_size, } if continuous_batching: @@ -931,16 +933,22 @@ def get_onnx_dynamic_axes(self, kv_offload: bool = False, continuous_batching: b lang_dynamic_axes = { "input_ids": {0: "batch_size", 1: "seq_len"}, "position_ids": {1: "batch_size", 2: "seq_len"}, - "vision_embeds": {0: "batch_size", 1: "vision_size"}, + "vision_embeds": {0: "vision_batch_size", 1: "vision_size"}, } + for i in range(num_layers): + lang_dynamic_axes[f"past_key.{i}"] = { + 0: "full_batch_size" if continuous_batching else "batch_size", + 2: "ctx_len", + } + lang_dynamic_axes[f"past_value.{i}"] 
= { + 0: "full_batch_size" if continuous_batching else "batch_size", + 2: "ctx_len", + } + if continuous_batching: lang_dynamic_axes["batch_index"] = {0: "batch_size"} - for i in range(num_layers): - lang_dynamic_axes[f"past_key.{i}"] = {0: "batch_size", 2: "ctx_len"} - lang_dynamic_axes[f"past_value.{i}"] = {0: "batch_size", 2: "ctx_len"} - dynamic_axes = {} if kv_offload: From 799af59b4c761b257efc78e1d329d84bde7c1338 Mon Sep 17 00:00:00 2001 From: Rishin Raj Date: Tue, 28 Oct 2025 05:44:18 +0000 Subject: [PATCH 11/24] Lint fix and code cleaning Signed-off-by: Rishin Raj --- .../generation/text_generation_inference.py | 45 +-- QEfficient/generation/vision_handler.py | 193 ++++++----- QEfficient/generation/vlm_generation.py | 307 ++++++++---------- 3 files changed, 252 insertions(+), 293 deletions(-) diff --git a/QEfficient/generation/text_generation_inference.py b/QEfficient/generation/text_generation_inference.py index 2a9f5bc3a..81eb711c0 100755 --- a/QEfficient/generation/text_generation_inference.py +++ b/QEfficient/generation/text_generation_inference.py @@ -447,7 +447,9 @@ def __init__( self._qpc_path = qpc_path # Store qpc_path for later use # Load QPC - self._session = QAICInferenceSession(qpc_path, device_id, activate=activate, enable_debug_logs=enable_debug_logs) + self._session = QAICInferenceSession( + qpc_path, device_id, activate=activate, enable_debug_logs=enable_debug_logs + ) # Validate sampler inputs for On-Device Sampling self.include_sampler = validate_sampler_inputs( @@ -719,10 +721,8 @@ def run_prefill_for_all_inputs(self, prompt_queue, generation_len): next_prompt = prompt_queue.popleft() # run prefill for num_chunks - # NOTE: We pass decode_batch_id=None during prefill to ensure batch_index is not added - # The decode_batch_id is only used for updating decode inputs, not for prefill execution outputs, position_ids, generation_len = self.run_prefill( - next_prompt, generation_len, decode_batch_id=None + next_prompt, generation_len, decode_batch_id=np.array(decode_batch_id, dtype=np.int64).reshape(1, 1) ) _ = self.update_decode_input(outputs, position_ids, generation_len, decode_batch_id) @@ -780,10 +780,9 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i inputs["position_ids"] = np.where(inputs.pop("attention_mask"), np.arange(padded_len), -1) inputs.pop("token_type_ids", None) - # Note: batch_index is only used during decode, not prefill - # During prefill in CB mode, we use the prefill specialization (batch_size=1, seq_len=128) - # which doesn't have batch_index parameter - # However, image_idx may be needed for VLM models during prefill + if decode_batch_id is not None: + inputs["batch_index"] = decode_batch_id + if self.is_tlm: inputs["num_logits_to_keep"] = np.zeros((1, 1)) if self.include_sampler: @@ -804,7 +803,7 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i inputs["lora_ids"] = np.array(batch_lora_ids, dtype=np.int64).reshape(self.batch_size, 1) # Check if image_idx is expected by the session (for VLM models) - if "image_idx" in self._session.input_names or "image_idx" in getattr(self._session, 'binding_index_map', {}): + if "image_idx" in self._session.input_names or "image_idx" in getattr(self._session, "binding_index_map", {}): try: binding_idx = self._session.binding_index_map.get("image_idx") dims = self._session.bindings[binding_idx].dims if binding_idx is not None else (1, 1) @@ -823,11 +822,11 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i if 
self.include_sampler: chunk_inputs["last_accepted_output_tokens"] = chunk_inputs["input_ids"] outputs = self._session.run(chunk_inputs) - + # Update image_idx for next chunk if VLM model provides it if "image_idx_output" in outputs: inputs["image_idx"] = outputs["image_idx_output"] - + if self._write_io_dir is not None: write_io_files(inputs, outputs, self._write_io_dir, "prefill", "aic_batch_io", True, False) return ( @@ -836,7 +835,7 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i generation_len, ) - def run_continuous_batching_decode(self, prompt_queue, image_queue, processor, generation_len): + def run_continuous_batching_decode(self, prompt_queue, generation_len): """ Runs continuous batching decode for the given prompt queue and generation length. @@ -848,15 +847,6 @@ def run_continuous_batching_decode(self, prompt_queue, image_queue, processor, g """ - # Skip vision buffers during decode for vision-language models in CB mode - # Vision buffers don't have batch dimension and should only be used during prefill - if hasattr(self, '_vision_outputs'): - vision_buffer_names = [name for name in self._session.input_names + self._session.output_names - if 'vision' in name.lower() or 'pixel_values' in name.lower()] - if vision_buffer_names: - logger.debug(f"Skipping vision buffers during decode: {vision_buffer_names}") - self._session.skip_buffers(vision_buffer_names) - # Set output placeholders for decode self._set_output_buffers( batch_size=self.full_batch_size, @@ -875,8 +865,6 @@ def run_continuous_batching_decode(self, prompt_queue, image_queue, processor, g decode_pause_time = 0 # Prepare decode inputs inputs. decode_inputs = self.prepare_decode_inputs() - next_prompt = None - next_image = None while prompt_queue or current_decode_ongoing.any(): outputs = self._session.run(decode_inputs) @@ -889,10 +877,7 @@ def run_continuous_batching_decode(self, prompt_queue, image_queue, processor, g next_token_id[decode_batch_id, -1] == self.tokenizer.eos_token_id or generated_id_current_index[decode_batch_id] >= self.generation_len[decode_batch_id] ): - if image_queue: - next_image = image_queue.popleft() if prompt_queue: - next_prompt = prompt_queue.popleft() start = perf_counter() # run prefill for next prompt input. outputs, position_ids, generation_len = self.run_prefill( @@ -1141,16 +1126,10 @@ def _continuous_batching_execution( """ self._setup_model_execution_inputs(prompt, generation_len, prompt_to_lora_id_mapping) start = perf_counter() - # Run prefill for all inputs first (batch_index should NOT be set during prefill) self._qaic_model.run_prefill_for_all_inputs(self._prompt_queue, generation_len) - - # Now set batch_index for decode phase - self._qaic_model.batch_index = np.arange(self._full_batch_size).reshape(-1, 1) loop_start = perf_counter() # Start decode loop timer - decode_pause_time = self._qaic_model.run_continuous_batching_decode( - self._prompt_queue, self._image_queue, self._processor, generation_len - ) + decode_pause_time = self._qaic_model.run_continuous_batching_decode(self._prompt_queue, generation_len) end = perf_counter() generated_texts = self._tokenizer.batch_decode(self._qaic_model.generated_ids, skip_special_tokens=True) diff --git a/QEfficient/generation/vision_handler.py b/QEfficient/generation/vision_handler.py index 5a67d237d..558aa5eb6 100644 --- a/QEfficient/generation/vision_handler.py +++ b/QEfficient/generation/vision_handler.py @@ -12,10 +12,11 @@ operations, separating them from the main text generation logic. 
""" +from typing import Any, Dict, Optional, Tuple + +import numpy as np import requests import torch -import numpy as np -from typing import Any, Dict, List, Optional, Tuple, Union from PIL import Image from transformers import AutoImageProcessor @@ -26,21 +27,21 @@ class VisionHandler: """ Handles all vision model operations for vision-language models. - + This class encapsulates vision preprocessing, inference, and output handling, providing a clean separation between vision and language processing. """ - + def __init__( self, vision_session: Optional[QAICInferenceSession], processor: Optional[AutoImageProcessor], config: Optional[Dict[str, Any]] = None, - lang_session: Optional[QAICInferenceSession] = None + lang_session: Optional[QAICInferenceSession] = None, ): """ Initialize vision handler - + Args: vision_session: QAICInferenceSession for vision model processor: AutoImageProcessor for image preprocessing @@ -51,47 +52,47 @@ def __init__( self._processor = processor self._config = config or {} self._lang_session = lang_session # Store language session for coordination - + # Cache for vision output shapes self._vision_output_shapes = None - + if self._vision_session and not self._processor: logger.warning("Vision session provided but no processor. Vision functionality may be limited.") - + def is_available(self) -> bool: """ Check if vision processing is available - + Returns: True if both vision session and processor are available """ return self._vision_session is not None and self._processor is not None - + def prepare_vision_inputs(self, image_url: str, query: str) -> Dict[str, np.ndarray]: """ Download and preprocess image into model inputs - + Args: image_url: URL or path to image query: Text query to process with image - + Returns: Dictionary of vision model inputs - + Raises: ValueError: If vision handler is not properly initialized RuntimeError: If image processing fails """ if not self.is_available(): raise ValueError("Vision handler not properly initialized. 
Need both vision_session and processor.") - + try: # Download image - if image_url.startswith(('http://', 'https://')): + if image_url.startswith(("http://", "https://")): image = Image.open(requests.get(image_url, stream=True).raw) else: image = Image.open(image_url) - + # Prepare conversation format conversation = [ { @@ -102,58 +103,58 @@ def prepare_vision_inputs(self, image_url: str, query: str) -> Dict[str, np.ndar ], }, ] - + # Apply chat template prompt = self._processor.apply_chat_template(conversation, add_generation_prompt=True) - + # Process image and text inputs = self._processor(images=image, text=prompt, return_tensors="pt") - + # Convert to float32 if needed if "pixel_values" in inputs: inputs["pixel_values"] = inputs["pixel_values"].to(torch.float32) - + # Convert to numpy arrays vision_inputs = {} for k, v in inputs.items(): if k in { "pixel_values", - "image_masks", + "image_masks", "image_input_idx", "valid_idx", "aspect_ratio_ids", - "aspect_ratio_mask" + "aspect_ratio_mask", }: vision_inputs[k] = np.array(v) - + # Convert specific inputs to float16 vision_inputs_fp16 = {"pixel_values", "image_masks"} for k in vision_inputs_fp16: if k in vision_inputs: vision_inputs[k] = vision_inputs[k].astype("float16") - + return vision_inputs - + except Exception as e: raise RuntimeError(f"Failed to process image {image_url}: {str(e)}") - + def run_vision_inference(self, vision_inputs: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]: """ Execute vision model inference with session coordination - + Args: vision_inputs: Preprocessed vision inputs - + Returns: Vision embeddings and metadata - + Raises: ValueError: If vision session is not available RuntimeError: If inference fails """ if not self._vision_session: raise ValueError("Vision session not available") - + lang_was_active = False try: # Coordinate with language session to avoid resource conflicts @@ -161,95 +162,98 @@ def run_vision_inference(self, vision_inputs: Dict[str, np.ndarray]) -> Dict[str logger.debug("Deactivating language session before vision inference") self._lang_session.deactivate() lang_was_active = True - + # Activate vision session logger.debug("Activating vision session for inference") self._vision_session.activate() - + # Run inference vision_outputs = self._vision_session.run(vision_inputs) - + # Deactivate vision session logger.debug("Deactivating vision session after inference") self._vision_session.deactivate() - + # Reactivate language session if it was active before if lang_was_active and self._lang_session: logger.debug("Reactivating language session after vision inference") self._lang_session.activate() - + return vision_outputs - + except Exception as e: # Ensure proper cleanup on error if self._vision_session: try: self._vision_session.deactivate() - except: - pass - + except Exception: + logger.warning("Deactivating vision session failed") + # Restore language session if needed if lang_was_active and self._lang_session: try: self._lang_session.activate() - except: - pass - + except Exception: + logger.warning("Deactivating language session failed") + raise RuntimeError(f"Vision inference failed: {str(e)}") - + def get_vision_output_shapes(self) -> Dict[str, Tuple[int, ...]]: """ Get vision output dimensions from config or session - + Returns: Dictionary mapping output names to shapes """ if self._vision_output_shapes is not None: return self._vision_output_shapes - + # Try to get from config first if self._config and "vision_output_shapes" in self._config: self._vision_output_shapes = 
self._config["vision_output_shapes"] return self._vision_output_shapes - + # Try to derive from vision session if self._vision_session: try: shapes = {} for output_name in self._vision_session.output_names: - if hasattr(self._vision_session, 'bindings') and output_name in self._vision_session.binding_index_map: + if ( + hasattr(self._vision_session, "bindings") + and output_name in self._vision_session.binding_index_map + ): binding_idx = self._vision_session.binding_index_map[output_name] - if hasattr(self._vision_session.bindings[binding_idx], 'dims'): + if hasattr(self._vision_session.bindings[binding_idx], "dims"): shapes[output_name] = tuple(self._vision_session.bindings[binding_idx].dims) - + if shapes: self._vision_output_shapes = shapes return shapes except Exception as e: logger.warning(f"Could not derive vision output shapes from session: {e}") - + # Fallback to default shapes (these were hard-coded in original implementation) default_shapes = { "vision_embeds": (2448, 5120) # This should be derived from model config } - + logger.warning("Using default vision output shapes. Consider providing shapes in config.") self._vision_output_shapes = default_shapes return default_shapes - + def setup_vision_buffers(self): """ Configure vision model output buffers - + Raises: ValueError: If vision session is not available """ if not self._vision_session: raise ValueError("Vision session not available") - + try: shapes = self.get_vision_output_shapes() - + # Set up output buffers buffers = {} for output_name, shape in shapes.items(): @@ -258,86 +262,81 @@ def setup_vision_buffers(self): buffers[output_name] = np.zeros(shape, dtype=np.float16) else: buffers[output_name] = np.zeros(shape, dtype=np.float32) - + self._vision_session.set_buffers(buffers) - + except Exception as e: raise RuntimeError(f"Failed to setup vision buffers: {str(e)}") - + def prepare_complete_vision_language_inputs( - self, - image_url: str, - query: str + self, image_url: str, query: str ) -> Tuple[Dict[str, np.ndarray], Dict[str, np.ndarray]]: """ Complete pipeline: prepare inputs and run vision inference - + Args: image_url: URL or path to image query: Text query - + Returns: Tuple of (vision_inputs, vision_outputs) """ # Prepare vision inputs vision_inputs = self.prepare_vision_inputs(image_url, query) - + # Setup buffers self.setup_vision_buffers() - + # Run vision inference vision_outputs = self.run_vision_inference(vision_inputs) - + return vision_inputs, vision_outputs - + def get_language_inputs_from_vision_processing( - self, - image_url: str, - query: str, - padded_len: int + self, image_url: str, query: str, padded_len: int ) -> Tuple[Dict[str, np.ndarray], Dict[str, np.ndarray]]: """ Process vision inputs and prepare language model inputs - + Args: image_url: URL or path to image query: Text query padded_len: Padded sequence length for language model - + Returns: Tuple of (language_inputs, vision_outputs) """ if not self.is_available(): raise ValueError("Vision handler not properly initialized") - + try: # Download and process image - if image_url.startswith(('http://', 'https://')): + if image_url.startswith(("http://", "https://")): image = Image.open(requests.get(image_url, stream=True).raw) else: image = Image.open(image_url) - + # Prepare conversation conversation = [ { - "role": "user", + "role": "user", "content": [ {"type": "text", "text": query}, {"type": "image"}, ], }, ] - + prompt = self._processor.apply_chat_template(conversation, add_generation_prompt=True) inputs = 
self._processor(images=image, text=prompt, return_tensors="pt") - + if "pixel_values" in inputs: inputs["pixel_values"] = inputs["pixel_values"].to(torch.float32) - + # Handle padding for language model pad_token_id = 1 input_ids_length = inputs["input_ids"].shape[1] - + inputs["input_ids"] = torch.nn.functional.pad( inputs["input_ids"], (0, padded_len - input_ids_length), @@ -347,49 +346,49 @@ def get_language_inputs_from_vision_processing( inputs["attention_mask"] = torch.nn.functional.pad( inputs["attention_mask"], (0, padded_len - input_ids_length), "constant", 0 ) - + if "cross_attention_mask" in inputs: inputs["cross_attention_mask"] = torch.nn.functional.pad( inputs["cross_attention_mask"], (0, 0, 0, 0, 0, padded_len - input_ids_length) ) - + # Convert to numpy for k, v in inputs.items(): inputs[k] = np.array(v) - + # Separate vision and language inputs vision_inputs = { - k: v for k, v in inputs.items() - if k in { + k: v + for k, v in inputs.items() + if k + in { "pixel_values", "image_masks", - "image_input_idx", + "image_input_idx", "valid_idx", "aspect_ratio_ids", - "aspect_ratio_mask" + "aspect_ratio_mask", } } - + # Convert vision inputs to appropriate dtypes vision_inputs_fp16 = {"pixel_values", "image_masks"} for k in vision_inputs_fp16: if k in vision_inputs: vision_inputs[k] = vision_inputs[k].astype("float16") - + # Run vision inference if we have vision inputs vision_outputs = {} if vision_inputs: self.setup_vision_buffers() vision_outputs = self.run_vision_inference(vision_inputs) - + # Prepare language inputs lang_inputs = {k: v for k, v in inputs.items() if k not in vision_inputs} - lang_inputs["position_ids"] = np.where( - lang_inputs.pop("attention_mask"), np.arange(padded_len), -1 - ) + lang_inputs["position_ids"] = np.where(lang_inputs.pop("attention_mask"), np.arange(padded_len), -1) lang_inputs["image_idx"] = np.array([[0]]) - + return lang_inputs, vision_outputs - + except Exception as e: raise RuntimeError(f"Failed to process vision-language inputs: {str(e)}") diff --git a/QEfficient/generation/vlm_generation.py b/QEfficient/generation/vlm_generation.py index 33b918b70..edcef7364 100644 --- a/QEfficient/generation/vlm_generation.py +++ b/QEfficient/generation/vlm_generation.py @@ -8,45 +8,43 @@ """ This module provides the VisionLanguageGeneration class that inherits from QEffTextGenerationBase, enabling all advanced text generation features while -maintaining full API compatibility with the original VisionLanguageGeneration.bn hv$$ Z&F +maintaining full API compatibility with the original VisionLanguageGeneration.bn hv$$ Z&F Key enhancements: - Continuous batching support for vision models -- Advanced streaming capabilities +- Advanced streaming capabilities - On-device sampling support - LoRA adapter support - Better performance metrics """ -from time import perf_counter -from typing import Any, Dict, List, Optional, Tuple, Union from collections import deque +from time import perf_counter +from typing import Any, Dict, List, Optional, Union import numpy as np -import transformers from transformers import AutoImageProcessor, PreTrainedTokenizer, PreTrainedTokenizerFast from QEfficient.generation.cloud_infer import QAICInferenceSession -from QEfficient.generation.vision_handler import VisionHandler from QEfficient.generation.text_generation_inference import ( - QEffTextGenerationBase, - TextGeneration, CloudAI100ExecInfo, PerfMetrics, + QEffTextGenerationBase, + TextGeneration, calculate_latency, write_io_files, ) -from QEfficient.utils.constants import 
Constants +from QEfficient.generation.vision_handler import VisionHandler from QEfficient.utils.logging_utils import logger class VisionLanguageGeneration(QEffTextGenerationBase): """ Enhanced vision-language generation class inheriting from QEffTextGenerationBase. - + This class maintains full API compatibility with VisionLanguageGeneration while adding advanced features like continuous batching, streaming, and sampling. - + Example: >>> # Drop-in replacement for VisionLanguageGeneration >>> vlm = VisionLanguageGeneration( @@ -61,7 +59,7 @@ class VisionLanguageGeneration(QEffTextGenerationBase): ... prompts=["Describe this image"], ... generation_len=512 ... ) - + >>> # Enhanced usage with new features >>> vlm_enhanced = VisionLanguageGeneration( ... tokenizer=tokenizer, @@ -74,7 +72,7 @@ class VisionLanguageGeneration(QEffTextGenerationBase): ... sampling_params=sampling_config ... ) """ - + def __init__( self, tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], @@ -93,7 +91,7 @@ def __init__( ): """ Initialize vision-language generation with enhanced capabilities - + Args: tokenizer: Text tokenizer processor: Image processor @@ -114,12 +112,12 @@ def __init__( raise TypeError("lang_qpc_path is required") if not vision_qpc_path: raise TypeError("vision_qpc_path is required") - + # Initialize base class with language QPC # Pass activate=False to prevent premature activation before vision components are ready super().__init__( tokenizer=tokenizer, - qpc_path=lang_qpc_path, + qpc_path=lang_qpc_path, full_batch_size=full_batch_size, ctx_len=ctx_len, device_id=device_id, @@ -129,50 +127,49 @@ def __init__( include_sampler=include_sampler, return_pdfs=return_pdfs, sampling_params=sampling_params, - activate=False, #vision components need to be initialized first + activate=False, # vision components need to be initialized first ) - + # Vision-specific initialization self.processor = processor self._vision_qpc_path = vision_qpc_path self.device_id = device_id # Store device_id for vision components self.enable_debug_logs = enable_debug_logs # Store for vision components self._init_vision_components() - + # Now that vision components are initialized, activate the text session self._session.activate() - - logger.info(f"VisionLanguageGeneration initialized: batch_size={self.batch_size}, " - f"prefill_seq_len={self._prefill_seq_len}, ctx_len={ctx_len}, " - f"continuous_batching={'enabled' if full_batch_size else 'disabled'}, " - f"sampling={'enabled' if include_sampler else 'disabled'}") - + + logger.info( + f"VisionLanguageGeneration initialized: batch_size={self.batch_size}, " + f"prefill_seq_len={self._prefill_seq_len}, ctx_len={ctx_len}, " + f"continuous_batching={'enabled' if full_batch_size else 'disabled'}, " + f"sampling={'enabled' if include_sampler else 'disabled'}" + ) + def _init_vision_components(self): """Initialize vision-specific components""" # Vision session (separate from base class language session) self._vision_session = QAICInferenceSession( - self._vision_qpc_path, - self.device_id, - activate=False, - enable_debug_logs=self.enable_debug_logs + self._vision_qpc_path, self.device_id, activate=False, enable_debug_logs=self.enable_debug_logs ) - + # Vision handler with language session coordination vision_config = self._get_vision_config() self._vision_handler = VisionHandler( vision_session=self._vision_session, processor=self.processor, config=vision_config, - lang_session=self._session # Pass language session for coordination + lang_session=self._session, # Pass 
language session for coordination ) - + # Setup vision buffer skipping self._setup_vision_buffer_skipping() - + def _get_vision_config(self) -> Dict[str, Any]: """ Derive vision config from session - + Returns: Dictionary with vision configuration """ @@ -181,44 +178,42 @@ def _get_vision_config(self) -> Dict[str, Any]: try: shapes = {} for output_name in self._vision_session.output_names: - if (hasattr(self._vision_session, 'bindings') and - output_name in self._vision_session.binding_index_map): + if ( + hasattr(self._vision_session, "bindings") + and output_name in self._vision_session.binding_index_map + ): binding_idx = self._vision_session.binding_index_map[output_name] - if hasattr(self._vision_session.bindings[binding_idx], 'dims'): - shapes[output_name] = tuple( - self._vision_session.bindings[binding_idx].dims - ) - + if hasattr(self._vision_session.bindings[binding_idx], "dims"): + shapes[output_name] = tuple(self._vision_session.bindings[binding_idx].dims) + if shapes: config["vision_output_shapes"] = shapes except Exception as e: logger.warning(f"Could not derive vision config from session: {e}") - + return config - + def _setup_vision_buffer_skipping(self): """Skip KV cache and retained state buffers for vision session""" - skip_patterns = [ - lambda x: x.startswith("past_"), - lambda x: x.endswith("_RetainedState") - ] - + skip_patterns = [lambda x: x.startswith("past_"), lambda x: x.endswith("_RetainedState")] + buffers_to_skip = [ - x for x in self._vision_session.input_names + self._vision_session.output_names + x + for x in self._vision_session.input_names + self._vision_session.output_names if any(pattern(x) for pattern in skip_patterns) ] self._vision_session.skip_buffers(buffers_to_skip) - + def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_id=None): """ Override base class prefill to handle vision processing - + Args: prompt: Can be string or tuple (image_path, text_prompt) generation_len: Generation length prefill_logit_bs: Prefill batch size decode_batch_id: Batch ID for continuous batching - + Returns: Same as base class: (outputs, position_ids, generation_len) """ @@ -234,48 +229,46 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i # Check if this is a vision-language prompt if isinstance(prompt, tuple) and len(prompt) == 2: image_path, text_prompt = prompt - + # Process vision inputs if this is the first time or if vision hasn't been processed yet - if not hasattr(self, '_vision_processed') or not self._vision_processed: + if not hasattr(self, "_vision_processed") or not self._vision_processed: logger.debug("Processing vision inputs for the first time") vision_inputs = self._vision_handler.prepare_vision_inputs(image_path, text_prompt) self._vision_handler.setup_vision_buffers() vision_outputs = self._vision_handler.run_vision_inference(vision_inputs) - + # Set vision buffers in language session (shared across all batch indices) self._session.set_buffers(vision_outputs) logger.debug(f"Vision buffers set: {list(vision_outputs.keys())}") - + # Mark vision as processed self._vision_processed = True self._vision_outputs = vision_outputs - + # Prepare the text prompt with vision context processed_prompt = self._prepare_vision_language_prompt(text_prompt, image_path) - + # Compute padded_len aligned to prefill_seq_len using tokenizer on processed prompt tmp_tokens = self.tokenizer(processed_prompt, return_tensors="np", padding=True) input_len = tmp_tokens["input_ids"].shape[1] num_chunks = -(input_len // 
-self._prefill_seq_len) padded_len = num_chunks * self._prefill_seq_len - + # Build language inputs with processor-aware vision/text integration lang_inputs, vision_outputs = self._vision_handler.get_language_inputs_from_vision_processing( - image_url=image_path, - query=text_prompt, - padded_len=padded_len + image_url=image_path, query=text_prompt, padded_len=padded_len ) - + # Set vision buffers in language session self._session.set_buffers(vision_outputs) - + # Calculate generation_len consistent with ctx_len max_gen_len = self._ctx_len - np.where(lang_inputs["position_ids"] != -1, 1, 0).sum(1, keepdims=True).max() generation_len = self._fetch_generation_len(generation_len, max_gen_len) - + # Set the prefill output buffers self._set_output_buffers(batch_size=prefill_logit_bs, sequence_length=1) - + # Run prefill across chunks, updating image_idx as needed outputs = None chunk_image_idx = None # track image_idx across chunks @@ -294,16 +287,16 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i } if "cross_attention_mask" in lang_inputs: chunk_inputs["cross_attention_mask"] = lang_inputs["cross_attention_mask"] - + outputs = self._session.run(chunk_inputs) - + # Update image_idx for next chunk if provided by model if "image_idx_output" in outputs: chunk_image_idx = outputs["image_idx_output"] - + if self._write_io_dir is not None: write_io_files(lang_inputs, outputs, self._write_io_dir, "prefill", "aic_batch_io", True, False) - + # Prepare position_ids for decode phase (next position after prefill) position_ids_decode = np.max(lang_inputs["position_ids"], axis=-1, keepdims=True) + 1 @@ -322,16 +315,16 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i if x.startswith("past_") or x.endswith("_RetainedState") ] ) - + return outputs, position_ids_decode, generation_len else: # Fall back to base class for text-only return super().run_prefill(prompt, generation_len, prefill_logit_bs, decode_batch_id) - + def _prepare_vision_language_prompt(self, text_prompt, image_path): """ Prepare text prompt with vision context - + This method handles the integration of vision and text inputs according to the specific model's requirements. """ @@ -347,53 +340,43 @@ def _prepare_vision_language_prompt(self, text_prompt, image_path): ], }, ] - + # Apply chat template - processed_prompt = self.processor.apply_chat_template( - conversation, - add_generation_prompt=True - ) - + processed_prompt = self.processor.apply_chat_template(conversation, add_generation_prompt=True) + return processed_prompt - + except Exception as e: logger.warning(f"Failed to apply chat template: {e}. 
Using original prompt.") return text_prompt - + def generate( - self, - images: List[str], - prompts: List[str], - generation_len: Optional[int] = None, - stream: bool = True, - **kwargs + self, images: List[str], prompts: List[str], generation_len: Optional[int] = None, stream: bool = True, **kwargs ) -> CloudAI100ExecInfo: """ Main generation method maintaining API compatibility with VisionLanguageGeneration - + Args: images: List of image URLs/paths prompts: List of text prompts generation_len: Max generation length stream: Enable streaming output **kwargs: Additional arguments passed to base class - + Returns: CloudAI100ExecInfo with results and metrics - + Raises: ValueError: If images and prompts lengths don't match """ if len(images) != len(prompts): - raise ValueError( - f"Number of images ({len(images)}) must match number of prompts ({len(prompts)})" - ) - + raise ValueError(f"Number of images ({len(images)}) must match number of prompts ({len(prompts)})") + logger.info(f"Generating for {len(images)} image-prompt pairs") - + # Convert to base class format: list of (image, prompt) tuples vision_prompts = [(img, prompt) for img, prompt in zip(images, prompts)] - + # Use base class generate method with vision prompts if self.full_batch_size is not None: # Continuous batching mode (new capability) @@ -401,7 +384,7 @@ def generate( else: # Regular batching mode return self._generate_regular_batching(vision_prompts, generation_len, stream, **kwargs) - + def _generate_regular_batching(self, vision_prompts, generation_len, stream, **kwargs): """Handle regular batching for vision-language generation without creating a second language session""" batch_results = [] @@ -409,7 +392,9 @@ def _generate_regular_batching(self, vision_prompts, generation_len, stream, **k batch = vision_prompts[i : i + self.batch_size] if stream: - print(f"\nProcessing batch {i // self.batch_size + 1}/{(len(vision_prompts) - 1) // self.batch_size + 1}") + print( + f"\nProcessing batch {i // self.batch_size + 1}/{(len(vision_prompts) - 1) // self.batch_size + 1}" + ) for j, (img, prompt) in enumerate(batch): print(f"Image: {img}") print(f"Prompt: {prompt}") @@ -418,11 +403,15 @@ def _generate_regular_batching(self, vision_prompts, generation_len, stream, **k # Setup decode storage arrays for this batch (use ctx_len or generation_len whichever is larger) exec_batch_size = self.batch_size max_gen_length = self._ctx_len if not generation_len else max(self._ctx_len, generation_len) - self.initialize_decode_inputs(num_prompts=len(batch), execution_batch_size=exec_batch_size, max_gen_length=max_gen_length) + self.initialize_decode_inputs( + num_prompts=len(batch), execution_batch_size=exec_batch_size, max_gen_length=max_gen_length + ) # Prefill using VLM-aware run_prefill (batch is a list of (image, text)) start = perf_counter() - outputs, position_ids, generation_len_final = self.run_prefill(batch, generation_len, prefill_logit_bs=self.batch_size) + outputs, position_ids, generation_len_final = self.run_prefill( + batch, generation_len, prefill_logit_bs=self.batch_size + ) self.update_decode_input(outputs, position_ids, generation_len_final) # Prepare decode @@ -438,7 +427,9 @@ def _generate_regular_batching(self, vision_prompts, generation_len, stream, **k # Latency metrics total_decode_tokens = num_token - prefill_time, decode_perf, total_perf, total_time = calculate_latency(total_decode_tokens, loop_start, start, end) + prefill_time, decode_perf, total_perf, total_time = calculate_latency( + total_decode_tokens, 
loop_start, start, end + ) perf_metrics = PerfMetrics(prefill_time, decode_perf, total_perf, total_time) # Package result for this batch @@ -453,48 +444,45 @@ def _generate_regular_batching(self, vision_prompts, generation_len, stream, **k # Aggregate results across batches return self._aggregate_batch_results(batch_results) - + def _generate_continuous_batching(self, vision_prompts, generation_len, stream, **kwargs): """Enable continuous batching for vision-language models (new capability)""" logger.info("Using continuous batching for vision-language generation") - + if stream: logger.warning("Streaming output not fully supported with continuous batching") - + # Reset vision processing state for new generation self._vision_processed = False self._vision_outputs = None - + # Use the base class continuous batching logic directly # Initialize decode inputs num_prompts = len(vision_prompts) execution_batch_size = self.full_batch_size max_gen_length = self._ctx_len if not generation_len else max(self._ctx_len, generation_len) - + self.initialize_decode_inputs(num_prompts, execution_batch_size, max_gen_length) - + # Create prompt queue prompt_queue = deque(vision_prompts) - + start = perf_counter() - - # IMPORTANT: Ensure batch_index is None during prefill phase - # Store the current batch_index and set it to None for prefill - saved_batch_index = self.batch_index + self.batch_index = None - + # Run prefill for all inputs first (batch_index should NOT be set during prefill) self.run_prefill_for_all_inputs(prompt_queue, generation_len) - + # Now set batch_index for decode phase self.batch_index = np.arange(self.full_batch_size).reshape(-1, 1) - + loop_start = perf_counter() decode_pause_time = self.run_continuous_batching_decode(prompt_queue, generation_len) end = perf_counter() - + generated_texts = self.tokenizer.batch_decode(self.generated_ids, skip_special_tokens=True) - + total_decode_tokens = sum( np.sum(self.generated_ids[i] != self.tokenizer.pad_token_id) - 1 for i in range(len(vision_prompts)) ) @@ -502,110 +490,101 @@ def _generate_continuous_batching(self, vision_prompts, generation_len, stream, total_decode_tokens, loop_start, start, end, decode_pause_time ) prefill_time /= len(vision_prompts) # Average prefill time for continuous batching - + perf_metrics = PerfMetrics(prefill_time, decode_perf, total_perf, total_time) - + return CloudAI100ExecInfo( - batch_size=1, - generated_texts=generated_texts, - generated_ids=self.generated_ids, - perf_metrics=perf_metrics + batch_size=1, generated_texts=generated_texts, generated_ids=self.generated_ids, perf_metrics=perf_metrics ) - + def prepare_decode_inputs(self): """ Override base class to handle vision-specific decode inputs """ decode_inputs = super().prepare_decode_inputs() - + # Add image_idx for vision-language models in CB mode during decode only - if self.batch_index is not None and hasattr(self, '_vision_outputs'): + if self.batch_index is not None and hasattr(self, "_vision_outputs"): # image_idx should point to vision embedding (0 for all batch positions) # since vision embeddings are shared across all batch indices decode_inputs["image_idx"] = np.zeros_like(self.batch_index) - + # Include cross_attention_mask during decode if present/required - if hasattr(self, '_decode_cross_attention_mask') and self._decode_cross_attention_mask is not None: + if hasattr(self, "_decode_cross_attention_mask") and self._decode_cross_attention_mask is not None: decode_inputs["cross_attention_mask"] = self._decode_cross_attention_mask - + return 
decode_inputs - + def _aggregate_batch_results(self, batch_results): """Aggregate results from multiple batches""" if not batch_results: raise ValueError("No batch results to aggregate") - + if len(batch_results) == 1: return batch_results[0] - + # Aggregate multiple batch results all_generated_texts = [] all_generated_ids = [] all_metrics = [] - + for result in batch_results: if isinstance(result.generated_texts[0], list): # Flatten nested lists all_generated_texts.extend([text for batch in result.generated_texts for text in batch]) else: all_generated_texts.extend(result.generated_texts) - + if isinstance(result.generated_ids, list): all_generated_ids.extend(result.generated_ids) else: all_generated_ids.append(result.generated_ids) - + all_metrics.append(result.perf_metrics) - + # Average metrics avg_metrics = PerfMetrics( prefill_time=np.mean([m.prefill_time for m in all_metrics]), decode_perf=np.mean([m.decode_perf for m in all_metrics]), total_perf=np.mean([m.total_perf for m in all_metrics]), - total_time=np.mean([m.total_time for m in all_metrics]) + total_time=np.mean([m.total_time for m in all_metrics]), ) - + return CloudAI100ExecInfo( batch_size=batch_results[0].batch_size, generated_texts=all_generated_texts, generated_ids=all_generated_ids, - perf_metrics=avg_metrics + perf_metrics=avg_metrics, ) - + def generate_stream_tokens( - self, - images: List[str], - prompts: List[str], - generation_len: Optional[int] = None, - **kwargs + self, images: List[str], prompts: List[str], generation_len: Optional[int] = None, **kwargs ): """ Enable token-by-token streaming for vision models (new capability) - + Args: images: List of image URLs/paths prompts: List of text prompts generation_len: Max generation length **kwargs: Additional arguments - + Yields: List of decoded tokens for each batch position - + Raises: NotImplementedError: If continuous batching is enabled """ if self.full_batch_size is not None: raise NotImplementedError("Token streaming not supported with continuous batching for VLM") - + if len(images) != len(prompts): - raise ValueError( - f"Number of images ({len(images)}) must match number of prompts ({len(prompts)})" - ) - + raise ValueError(f"Number of images ({len(images)}) must match number of prompts ({len(prompts)})") + logger.info(f"Starting token streaming for {len(images)} image-prompt pairs") - + vision_prompts = [(img, prompt) for img, prompt in zip(images, prompts)] - + text_gen = TextGeneration( tokenizer=self.tokenizer, qpc_path=self._qpc_path, @@ -617,17 +596,19 @@ def generate_stream_tokens( return_pdfs=self.return_pdfs, sampling_params=self.sampling_params, ) - + text_gen._qaic_model = self - + # Yield tokens as they're generated for tokens in text_gen.generate_stream_tokens(vision_prompts, generation_len, **kwargs): yield tokens - + def __repr__(self): """String representation of the class""" - return (f"VisionLanguageGeneration(" - f"batch_size={self.batch_size}, " - f"ctx_len={self._ctx_len}, " - f"continuous_batching={'enabled' if self.full_batch_size else 'disabled'}, " - f"sampling={'enabled' if self.include_sampler else 'disabled'})") + return ( + f"VisionLanguageGeneration(" + f"batch_size={self.batch_size}, " + f"ctx_len={self._ctx_len}, " + f"continuous_batching={'enabled' if self.full_batch_size else 'disabled'}, " + f"sampling={'enabled' if self.include_sampler else 'disabled'})" + ) From 621e3a8bb1c7a0033ea2556388636fd2c85683d6 Mon Sep 17 00:00:00 2001 From: Rishin Raj Date: Tue, 28 Oct 2025 15:06:41 +0000 Subject: [PATCH 12/24] fix for fbs 
>1 Signed-off-by: Rishin Raj --- .../generation/text_generation_inference.py | 24 +++++-------------- QEfficient/generation/vlm_generation.py | 23 +++++++++++++----- .../transformers/models/modeling_auto.py | 10 ++++---- 3 files changed, 29 insertions(+), 28 deletions(-) diff --git a/QEfficient/generation/text_generation_inference.py b/QEfficient/generation/text_generation_inference.py index 81eb711c0..3487f145c 100755 --- a/QEfficient/generation/text_generation_inference.py +++ b/QEfficient/generation/text_generation_inference.py @@ -508,9 +508,10 @@ def _fetch_full_batch_size( full_batch_size = None if "batch_index" in self._session.binding_index_map: if self._session.allowed_shapes: - full_batch_size, _ = [ - x[self._session.binding_index_map["batch_index"]][1][0] for x in self._session.allowed_shapes - ] + # Take the maximum batch_index dimension across specializations (prefill vs decode) + full_batch_size = max( + [x[self._session.binding_index_map["batch_index"]][1][0] for x in self._session.allowed_shapes] + ) else: full_batch_size, _ = self._session.bindings[self._session.binding_index_map["batch_index"]].dims return full_batch_size @@ -802,15 +803,6 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i batch_lora_ids = [self._prompt_to_lora_id_mapping_prefill.popleft() for i in range(self.batch_size)] inputs["lora_ids"] = np.array(batch_lora_ids, dtype=np.int64).reshape(self.batch_size, 1) - # Check if image_idx is expected by the session (for VLM models) - if "image_idx" in self._session.input_names or "image_idx" in getattr(self._session, "binding_index_map", {}): - try: - binding_idx = self._session.binding_index_map.get("image_idx") - dims = self._session.bindings[binding_idx].dims if binding_idx is not None else (1, 1) - inputs["image_idx"] = np.zeros(tuple(dims), dtype=np.int64) - except Exception: - inputs["image_idx"] = np.array([[0]], dtype=np.int64) - for i in range(num_chunks): chunk_inputs = inputs.copy() chunk_inputs["input_ids"] = inputs["input_ids"][ @@ -823,10 +815,6 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i chunk_inputs["last_accepted_output_tokens"] = chunk_inputs["input_ids"] outputs = self._session.run(chunk_inputs) - # Update image_idx for next chunk if VLM model provides it - if "image_idx_output" in outputs: - inputs["image_idx"] = outputs["image_idx_output"] - if self._write_io_dir is not None: write_io_files(inputs, outputs, self._write_io_dir, "prefill", "aic_batch_io", True, False) return ( @@ -850,7 +838,7 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len): # Set output placeholders for decode self._set_output_buffers( batch_size=self.full_batch_size, - sequence_length=self._decode_seq_len, + sequence_length=1, ) # Generate flag for tracking progress for each batch ID @@ -894,7 +882,7 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len): self._set_output_buffers( batch_size=self.full_batch_size, - sequence_length=self._decode_seq_len, + sequence_length=1, ) decode_pause_time += perf_counter() - start diff --git a/QEfficient/generation/vlm_generation.py b/QEfficient/generation/vlm_generation.py index edcef7364..58044d8bd 100644 --- a/QEfficient/generation/vlm_generation.py +++ b/QEfficient/generation/vlm_generation.py @@ -230,9 +230,9 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i if isinstance(prompt, tuple) and len(prompt) == 2: image_path, text_prompt = prompt - # Process vision inputs if this is the 
first time or if vision hasn't been processed yet - if not hasattr(self, "_vision_processed") or not self._vision_processed: - logger.debug("Processing vision inputs for the first time") + # Process vision inputs. In CB, process for every prefill since each slot can have different image. + if self.full_batch_size is not None or not hasattr(self, "_vision_processed") or not self._vision_processed: + logger.debug("Processing vision inputs for this request") vision_inputs = self._vision_handler.prepare_vision_inputs(image_path, text_prompt) self._vision_handler.setup_vision_buffers() vision_outputs = self._vision_handler.run_vision_inference(vision_inputs) @@ -285,6 +285,8 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i "position_ids": position_ids_slice, "image_idx": chunk_image_idx if chunk_image_idx is not None else np.array([[0]], dtype=np.int64), } + if decode_batch_id is not None: + chunk_inputs["batch_index"] = decode_batch_id if "cross_attention_mask" in lang_inputs: chunk_inputs["cross_attention_mask"] = lang_inputs["cross_attention_mask"] @@ -505,12 +507,21 @@ def prepare_decode_inputs(self): # Add image_idx for vision-language models in CB mode during decode only if self.batch_index is not None and hasattr(self, "_vision_outputs"): - # image_idx should point to vision embedding (0 for all batch positions) - # since vision embeddings are shared across all batch indices - decode_inputs["image_idx"] = np.zeros_like(self.batch_index) + # image_idx should be a single slot selector; decoder expects shape (1,1) + # Query binding dims if available to be robust + try: + if "image_idx" in getattr(self._session, "binding_index_map", {}): + idx = self._session.binding_index_map["image_idx"] + dims = tuple(self._session.bindings[idx].dims) + decode_inputs["image_idx"] = np.zeros(dims, dtype=np.int64) + else: + decode_inputs["image_idx"] = np.array([[0]], dtype=np.int64) + except Exception: + decode_inputs["image_idx"] = np.array([[0]], dtype=np.int64) # Include cross_attention_mask during decode if present/required if hasattr(self, "_decode_cross_attention_mask") and self._decode_cross_attention_mask is not None: + # Decoder specialization expects a single mask (batch dim = 1) decode_inputs["cross_attention_mask"] = self._decode_cross_attention_mask return decode_inputs diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 24a3de9bd..9b182f11c 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -1224,15 +1224,17 @@ def generate( # Use VisionLanguageGeneration for image-prompt pairs if (processor and images) or (tokenizer and prompts): # Create VisionLanguageGeneration instance + batch_size_comp, ctx_len_comp, fbs = get_compilation_dims(self.lang_model.qpc_path) vlm_gen = VisionLanguageGeneration( lang_qpc_path=self.lang_model.qpc_path, vision_qpc_path=self.vision_model.qpc_path, tokenizer=tokenizer, processor=processor, - device_id=device_ids, # if device_ids is not None else [0], - ctx_len=3072 # TODO need to get it from the QPC + device_id=device_ids, # if device_ids is not None else [0], + ctx_len=ctx_len_comp, + full_batch_size=fbs, ) - + # Call generate method return vlm_gen.generate( images=images, @@ -1240,7 +1242,7 @@ def generate( generation_len=generation_len, stream=streamer is not None, ) - + # Fallback to kv_offload_generate for direct inputs (backward compatibility) return self.kv_offload_generate( inputs=inputs, 
device_ids=device_ids, streamer=streamer, generation_len=generation_len From 59dff468e4cced359d36ec2772cc762a135a7d08 Mon Sep 17 00:00:00 2001 From: Rishin Raj Date: Wed, 29 Oct 2025 10:10:03 +0000 Subject: [PATCH 13/24] Updated cloud_ai_100_exec_kv call in modelling_auto Signed-off-by: Rishin Raj --- QEfficient/transformers/models/modeling_auto.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 9b182f11c..6d340cd0b 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -2764,7 +2764,7 @@ def generate( generation_len = kwargs.pop("generation_len", None) return QEfficient.cloud_ai_100_exec_kv( tokenizer=tokenizer, - lang_qpc_path=self.qpc_path, + qpc_path=self.qpc_path, prompt=prompts, device_id=device_id, generation_len=generation_len, From fc69e3a807e0a644f7d9edfa3334462c664e08a7 Mon Sep 17 00:00:00 2001 From: Rishin Raj Date: Thu, 30 Oct 2025 05:54:06 +0000 Subject: [PATCH 14/24] Removed redundant vision execution and refactoring Signed-off-by: Rishin Raj --- ...vision_handler.py => embedding_handler.py} | 0 QEfficient/generation/vlm_generation.py | 21 ++---- examples/llama4_CB_example_vision_lang.py | 73 +++++++++++++------ 3 files changed, 55 insertions(+), 39 deletions(-) rename QEfficient/generation/{vision_handler.py => embedding_handler.py} (100%) diff --git a/QEfficient/generation/vision_handler.py b/QEfficient/generation/embedding_handler.py similarity index 100% rename from QEfficient/generation/vision_handler.py rename to QEfficient/generation/embedding_handler.py diff --git a/QEfficient/generation/vlm_generation.py b/QEfficient/generation/vlm_generation.py index 58044d8bd..fd2e2ab66 100644 --- a/QEfficient/generation/vlm_generation.py +++ b/QEfficient/generation/vlm_generation.py @@ -34,7 +34,7 @@ calculate_latency, write_io_files, ) -from QEfficient.generation.vision_handler import VisionHandler +from QEfficient.generation.embedding_handler import VisionHandler from QEfficient.utils.logging_utils import logger @@ -230,21 +230,7 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i if isinstance(prompt, tuple) and len(prompt) == 2: image_path, text_prompt = prompt - # Process vision inputs. In CB, process for every prefill since each slot can have different image. - if self.full_batch_size is not None or not hasattr(self, "_vision_processed") or not self._vision_processed: - logger.debug("Processing vision inputs for this request") - vision_inputs = self._vision_handler.prepare_vision_inputs(image_path, text_prompt) - self._vision_handler.setup_vision_buffers() - vision_outputs = self._vision_handler.run_vision_inference(vision_inputs) - - # Set vision buffers in language session (shared across all batch indices) - self._session.set_buffers(vision_outputs) - logger.debug(f"Vision buffers set: {list(vision_outputs.keys())}") - - # Mark vision as processed - self._vision_processed = True - self._vision_outputs = vision_outputs - + # Process vision inputs. In CB, process for every prefill since each slot can have different image. 
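The continuous-batching changes above rely on a slot convention: during prefill each request is written into its own slot, so batch_index is a single (1, 1) slot id, and only after all prefills does batch_index widen to (full_batch_size, 1) for the joint decode loop. A minimal shape sketch in plain numpy (variable names here are illustrative, not the QEfficient API):

import numpy as np

full_batch_size = 4

# Prefill phase: one request per slot, batch_index is a (1, 1) slot id
for decode_batch_id in range(full_batch_size):
    slot_index = np.array(decode_batch_id, dtype=np.int64).reshape(1, 1)
    assert slot_index.shape == (1, 1)

# Decode phase: all slots step together, batch_index covers every slot
batch_index = np.arange(full_batch_size, dtype=np.int64).reshape(-1, 1)
assert batch_index.shape == (full_batch_size, 1)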
# Prepare the text prompt with vision context processed_prompt = self._prepare_vision_language_prompt(text_prompt, image_path) @@ -261,6 +247,9 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i # Set vision buffers in language session self._session.set_buffers(vision_outputs) + logger.debug(f"Vision buffers set: {list(vision_outputs.keys())}") + self._vision_processed = True + self._vision_outputs = vision_outputs # Calculate generation_len consistent with ctx_len max_gen_len = self._ctx_len - np.where(lang_inputs["position_ids"] != -1, 1, 0).sum(1, keepdims=True).max() diff --git a/examples/llama4_CB_example_vision_lang.py b/examples/llama4_CB_example_vision_lang.py index 8711e0f40..492e39203 100644 --- a/examples/llama4_CB_example_vision_lang.py +++ b/examples/llama4_CB_example_vision_lang.py @@ -16,30 +16,55 @@ config.text_config.num_hidden_layers = 4 config.vision_config.num_hidden_layers = 2 -qeff_model = QEFFAutoModelForImageTextToText.from_pretrained( - model_id, - attn_implementation="eager", - kv_offload=True, - config=config, - continuous_batching=True, -) tokenizer = transformers.AutoTokenizer.from_pretrained(model_id) processor = AutoProcessor.from_pretrained(model_id) -qeff_model.compile( - prefill_seq_len=128, - ctx_len=3072, - img_size=336, - num_cores=16, - num_devices=4, - max_num_tiles=17, - batch_size=1, - full_batch_size=4, - mxfp6_matmul=True, - mxint8_kv_cache=True, - aic_enable_depth_first=True, - mos=1, -) +continious_batching = True +if continious_batching: + + qeff_model = QEFFAutoModelForImageTextToText.from_pretrained( + model_id, + attn_implementation="eager", + kv_offload=True, + config=config, + continuous_batching=True, + ) + + qeff_model.compile( + prefill_seq_len=128, + ctx_len=3072, + img_size=336, + num_cores=16, + num_devices=4, + max_num_tiles=17, + batch_size=1, + full_batch_size=4, + mxfp6_matmul=True, + mxint8_kv_cache=True, + aic_enable_depth_first=True, + mos=1, + ) +else: + qeff_model = QEFFAutoModelForImageTextToText.from_pretrained( + model_id, + attn_implementation="eager", + kv_offload=True, + config=config, + ) + + qeff_model.compile( + prefill_seq_len=128, + ctx_len=3072, + img_size=336, + num_cores=16, + num_devices=4, + max_num_tiles=17, + batch_size=1, + mxfp6_matmul=True, + mxint8_kv_cache=True, + aic_enable_depth_first=True, + mos=1, + ) image_urls = [ "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_layout.png", @@ -55,7 +80,7 @@ "What colors are predominant in the image?", ] -output = qeff_model.generate( +exec_info = qeff_model.generate( tokenizer=tokenizer, prompts=prompts, processor=processor, @@ -64,4 +89,6 @@ generation_len=100, ) -print(output) +# print("Generated texts:", exec_info.generated_texts) +print("Generated IDs:", exec_info.generated_ids) +print(exec_info) From 275d4cdb7a3db7742479f6e558360fa334b9413e Mon Sep 17 00:00:00 2001 From: Vinayak Baddi Date: Thu, 30 Oct 2025 08:12:04 +0000 Subject: [PATCH 15/24] nit: update QBlocking to LM Attention in Qwen2.5VL Signed-off-by: vbaddi Signed-off-by: Rishin Raj --- .../models/qwen2_5_vl/modeling_qwen2_5_vl.py | 102 ++++++++++++++++-- 1 file changed, 94 insertions(+), 8 deletions(-) diff --git a/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index 8df72d0c9..b77f5075c 100644 --- a/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ 
b/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -6,6 +6,7 @@ # ----------------------------------------------------------------------------- import math +import os from typing import Callable, List, Optional, Tuple, Union import torch @@ -360,7 +361,7 @@ def forward(self, x, seq_len=None): ) -def eager_attention_forward( +def eager_attention_forward_q_blocked( module: nn.Module, query: torch.Tensor, key: torch.Tensor, @@ -368,22 +369,107 @@ def eager_attention_forward( attention_mask: Optional[torch.Tensor], **kwargs, ): + """ + Q-blocked attention for Qwen2.5-VL. + Blocks only the query SL dimension. + + Args: + query: (BS, NH, Q_LEN, DH) + key: (BS, NH_KV, KV_LEN, DH) + value: (BS, NH_KV, KV_LEN, DH) + attention_mask: (BS, NH, Q_LEN, KV_LEN) or broadcastable + """ + BS, NH, Q_LEN, DH = query.shape + _, _, KV_LEN, _ = key.shape + key_states = repeat_kv(key, module.num_key_value_groups) value_states = repeat_kv(value, module.num_key_value_groups) - attn_weights = torch.matmul(query, key_states.transpose(2, 3)) / math.sqrt(module.head_dim) - if attention_mask is not None: - attn_weights = torch.where( - attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights - ) + target_blocks_q = int(os.environ.get("num_q_blocks", Q_LEN)) + q_block_positions = [(i * Q_LEN) // target_blocks_q for i in range(target_blocks_q)] + scaling = 1.0 / math.sqrt(module.head_dim) + + q_output_blocks = [] + q_attn_weights_blocks = [] + + # Process each Q block + for q_block_idx in range(target_blocks_q): + qi = q_block_positions[q_block_idx] - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) - attn_output = torch.matmul(attn_weights, value_states) + # Calculate Q block size + if q_block_idx == target_blocks_q - 1: + real_q_len = Q_LEN - qi + else: + real_q_len = q_block_positions[q_block_idx + 1] - qi + + # Extract Q block + q_block = query[:, :, qi : qi + real_q_len, :] + attn_mask_block = None + if attention_mask is not None: + attn_mask_block = attention_mask[:, :, qi : qi + real_q_len, :] + + # Compute attention scores for this Q block + attn_weights = torch.matmul(q_block, key_states.transpose(2, 3)) * scaling + if attn_mask_block is not None: + attn_weights = torch.where( + attn_mask_block, + torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32, device=attn_weights.device), + attn_weights, + ) + + attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + + # Compute output for this Q block + output_block = torch.matmul(attn_weights, value_states) + + q_output_blocks.append(output_block) + q_attn_weights_blocks.append(attn_weights) + + attn_output = torch.cat(q_output_blocks, dim=2) attn_output = attn_output.transpose(1, 2).contiguous() + # Concatenate attention weights + attn_weights = torch.cat(q_attn_weights_blocks, dim=2) + return attn_output, attn_weights +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + **kwargs, +): + """ + Wrapper that routes to blocked or default attention based on environment variable. 
+ """ + blocking_mode = os.environ.get("ATTENTION_BLOCKING_MODE", "default").lower() + + if blocking_mode == "q": + return eager_attention_forward_q_blocked(module, query, key, value, attention_mask, **kwargs) + elif blocking_mode == "default": + # Original implementation + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) / math.sqrt(module.head_dim) + + if attention_mask is not None: + attn_weights = torch.where( + attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights + ) + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + else: + raise ValueError(f"Invalid ATTENTION_BLOCKING_MODE: {blocking_mode}. Must be 'q' or 'default'") + + class QEffQwen2_5_VLAttention(Qwen2_5_VLAttention): """ Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer From e8253ff267eabac1f6d5b7e52f6866162d755872 Mon Sep 17 00:00:00 2001 From: Vinayak Baddi Date: Thu, 30 Oct 2025 08:16:34 +0000 Subject: [PATCH 16/24] nit: update readme for qblocking in example file and lint/format checks Signed-off-by: vbaddi Signed-off-by: Rishin Raj --- QEfficient/generation/vlm_generation.py | 4 ++-- examples/qwen2_5_vl_example.py | 3 +++ 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/QEfficient/generation/vlm_generation.py b/QEfficient/generation/vlm_generation.py index fd2e2ab66..8ec17c388 100644 --- a/QEfficient/generation/vlm_generation.py +++ b/QEfficient/generation/vlm_generation.py @@ -26,6 +26,7 @@ from transformers import AutoImageProcessor, PreTrainedTokenizer, PreTrainedTokenizerFast from QEfficient.generation.cloud_infer import QAICInferenceSession +from QEfficient.generation.embedding_handler import VisionHandler from QEfficient.generation.text_generation_inference import ( CloudAI100ExecInfo, PerfMetrics, @@ -34,7 +35,6 @@ calculate_latency, write_io_files, ) -from QEfficient.generation.embedding_handler import VisionHandler from QEfficient.utils.logging_utils import logger @@ -230,7 +230,7 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i if isinstance(prompt, tuple) and len(prompt) == 2: image_path, text_prompt = prompt - # Process vision inputs. In CB, process for every prefill since each slot can have different image. + # Process vision inputs. In CB, process for every prefill since each slot can have different image. 
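The Q-blocked attention path introduced for Qwen2.5-VL splits only the query sequence dimension; every block still attends over the full key/value length, so the blocked result matches unblocked attention and only the peak activation size changes. A small self-contained check of that property in plain PyTorch (a sketch, not the QEfficient module; only the num_q_blocks environment variable name is taken from the patch):

import math
import os

import torch
import torch.nn.functional as F


def q_blocked_attention(query, key, value, num_q_blocks):
    # query/key/value: (BS, NH, SL, DH); only the query SL axis is chunked
    scale = 1.0 / math.sqrt(query.shape[-1])
    outputs = []
    for q_block in torch.chunk(query, num_q_blocks, dim=2):
        scores = torch.matmul(q_block, key.transpose(2, 3)) * scale
        outputs.append(torch.matmul(F.softmax(scores, dim=-1), value))
    return torch.cat(outputs, dim=2)


torch.manual_seed(0)
q, k, v = (torch.randn(1, 4, 16, 8) for _ in range(3))
num_blocks = int(os.environ.get("num_q_blocks", 2))
assert torch.allclose(q_blocked_attention(q, k, v, 1), q_blocked_attention(q, k, v, num_blocks), atol=1e-6)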
# Prepare the text prompt with vision context processed_prompt = self._prepare_vision_language_prompt(text_prompt, image_path) diff --git a/examples/qwen2_5_vl_example.py b/examples/qwen2_5_vl_example.py index 374f70ad2..d5d943c9c 100644 --- a/examples/qwen2_5_vl_example.py +++ b/examples/qwen2_5_vl_example.py @@ -5,6 +5,9 @@ # # ----------------------------------------------------------------------------- +# If we want to enable QBlocking Run below command:, default is without blocking +# ATTENTION_BLOCKING_MODE=q num_q_blocks=2 python -W ignore qwen2_5_vl_example.py + import requests import transformers from PIL import Image From 4d9afe2186ef95efef4d83621abc4b057afbe119 Mon Sep 17 00:00:00 2001 From: Vinayak Baddi Date: Thu, 30 Oct 2025 08:18:43 +0000 Subject: [PATCH 17/24] nit: lint/format checks Signed-off-by: vbaddi Signed-off-by: Rishin Raj --- examples/llama4_CB_example_vision_lang.py | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/llama4_CB_example_vision_lang.py b/examples/llama4_CB_example_vision_lang.py index 492e39203..a9fc3a8aa 100644 --- a/examples/llama4_CB_example_vision_lang.py +++ b/examples/llama4_CB_example_vision_lang.py @@ -21,7 +21,6 @@ continious_batching = True if continious_batching: - qeff_model = QEFFAutoModelForImageTextToText.from_pretrained( model_id, attn_implementation="eager", From 6b079fe00d435687e1878e16d16c9c5d00ae2bea Mon Sep 17 00:00:00 2001 From: Rishin Raj Date: Thu, 30 Oct 2025 08:58:51 +0000 Subject: [PATCH 18/24] CI failure fix Signed-off-by: Rishin Raj --- QEfficient/generation/cloud_infer.py | 15 --------------- .../generation/text_generation_inference.py | 12 ++++++------ examples/llama4_CB_example_vision_lang.py | 2 +- 3 files changed, 7 insertions(+), 22 deletions(-) diff --git a/QEfficient/generation/cloud_infer.py b/QEfficient/generation/cloud_infer.py index 42c8b342e..5068c174e 100644 --- a/QEfficient/generation/cloud_infer.py +++ b/QEfficient/generation/cloud_infer.py @@ -122,21 +122,6 @@ def deactivate(self): self.program.deactivate() self.is_active = False - def pause(self): - """Pause the session while preserving state""" - if self.is_active: - # Just deactivate the program and set state - self.program.deactivate() - self.is_active = False - - def resume(self): - """Resume a paused session""" - if not self.is_active: - # Reactivate program and create new execObj - self.program.activate() - self.execObj = qaicrt.ExecObj(self.context, self.program) - self.is_active = True - def set_buffers(self, buffers: Dict[str, np.ndarray]): """ Provide buffer mapping for input and output diff --git a/QEfficient/generation/text_generation_inference.py b/QEfficient/generation/text_generation_inference.py index 3487f145c..e96908824 100755 --- a/QEfficient/generation/text_generation_inference.py +++ b/QEfficient/generation/text_generation_inference.py @@ -508,10 +508,9 @@ def _fetch_full_batch_size( full_batch_size = None if "batch_index" in self._session.binding_index_map: if self._session.allowed_shapes: - # Take the maximum batch_index dimension across specializations (prefill vs decode) - full_batch_size = max( - [x[self._session.binding_index_map["batch_index"]][1][0] for x in self._session.allowed_shapes] - ) + full_batch_size, _ = [ + x[self._session.binding_index_map["batch_index"]][1][0] for x in self._session.allowed_shapes + ] else: full_batch_size, _ = self._session.bindings[self._session.binding_index_map["batch_index"]].dims return full_batch_size @@ -838,7 +837,7 @@ def run_continuous_batching_decode(self, prompt_queue, 
generation_len): # Set output placeholders for decode self._set_output_buffers( batch_size=self.full_batch_size, - sequence_length=1, + sequence_length=self._decode_seq_len, ) # Generate flag for tracking progress for each batch ID @@ -882,7 +881,7 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len): self._set_output_buffers( batch_size=self.full_batch_size, - sequence_length=1, + sequence_length=self._decode_seq_len, ) decode_pause_time += perf_counter() - start @@ -1113,6 +1112,7 @@ def _continuous_batching_execution( :tuple: A tuple containing performance metrics and generated texts. """ self._setup_model_execution_inputs(prompt, generation_len, prompt_to_lora_id_mapping) + self._qaic_model.batch_index = np.arange(self._full_batch_size).reshape(-1, 1) start = perf_counter() self._qaic_model.run_prefill_for_all_inputs(self._prompt_queue, generation_len) diff --git a/examples/llama4_CB_example_vision_lang.py b/examples/llama4_CB_example_vision_lang.py index a9fc3a8aa..f285ea278 100644 --- a/examples/llama4_CB_example_vision_lang.py +++ b/examples/llama4_CB_example_vision_lang.py @@ -19,7 +19,7 @@ tokenizer = transformers.AutoTokenizer.from_pretrained(model_id) processor = AutoProcessor.from_pretrained(model_id) -continious_batching = True +continious_batching = False if continious_batching: qeff_model = QEFFAutoModelForImageTextToText.from_pretrained( model_id, From 64dffdd70448536370c7b2afc0cdf56df212eded Mon Sep 17 00:00:00 2001 From: Mohit Soni Date: Thu, 30 Oct 2025 16:15:47 +0000 Subject: [PATCH 19/24] qwen2_5_vl inference changes Signed-off-by: Mohit Soni Signed-off-by: Rishin Raj --- QEfficient/generation/embedding_handler.py | 23 +++++- QEfficient/generation/vlm_generation.py | 69 ++++++++++++++---- .../transformers/models/modeling_auto.py | 1 + examples/qwen2_5_vl_CB.py | 72 +++++++++++++++++++ 4 files changed, 149 insertions(+), 16 deletions(-) create mode 100644 examples/qwen2_5_vl_CB.py diff --git a/QEfficient/generation/embedding_handler.py b/QEfficient/generation/embedding_handler.py index 558aa5eb6..fc1695756 100644 --- a/QEfficient/generation/embedding_handler.py +++ b/QEfficient/generation/embedding_handler.py @@ -34,6 +34,7 @@ class VisionHandler: def __init__( self, + qeff_model: Optional[QAICInferenceSession], vision_session: Optional[QAICInferenceSession], processor: Optional[AutoImageProcessor], config: Optional[Dict[str, Any]] = None, @@ -48,6 +49,7 @@ def __init__( config: Configuration dictionary with vision model parameters lang_session: Optional language session for coordination (to avoid resource conflicts) """ + self._qeff_model = qeff_model self._vision_session = vision_session self._processor = processor self._config = config or {} @@ -293,7 +295,7 @@ def prepare_complete_vision_language_inputs( return vision_inputs, vision_outputs def get_language_inputs_from_vision_processing( - self, image_url: str, query: str, padded_len: int + self, image_url: str, query: str, prefill_seq_len: int ) -> Tuple[Dict[str, np.ndarray], Dict[str, np.ndarray]]: """ Process vision inputs and prepare language model inputs @@ -330,12 +332,22 @@ def get_language_inputs_from_vision_processing( prompt = self._processor.apply_chat_template(conversation, add_generation_prompt=True) inputs = self._processor(images=image, text=prompt, return_tensors="pt") + if ( + hasattr(self._qeff_model.model.config, "model_type") + and self._qeff_model.model.config.model_type == "qwen2_5_vl" + ): + inputs = self._qeff_model.model.prepare_inputs_for_generation( + inputs=inputs, 
prefill_seq_len=prefill_seq_len, batch_size=inputs["input_ids"].shape[0] + ) + if "pixel_values" in inputs: inputs["pixel_values"] = inputs["pixel_values"].to(torch.float32) # Handle padding for language model pad_token_id = 1 input_ids_length = inputs["input_ids"].shape[1] + num_chunks = -(input_ids_length // -prefill_seq_len) + padded_len = num_chunks * prefill_seq_len inputs["input_ids"] = torch.nn.functional.pad( inputs["input_ids"], @@ -385,10 +397,15 @@ def get_language_inputs_from_vision_processing( # Prepare language inputs lang_inputs = {k: v for k, v in inputs.items() if k not in vision_inputs} - lang_inputs["position_ids"] = np.where(lang_inputs.pop("attention_mask"), np.arange(padded_len), -1) + + if "position_ids" in inputs: + lang_inputs["position_ids"] = inputs["position_ids"] + lang_inputs.pop("attention_mask") + else: + lang_inputs["position_ids"] = np.where(lang_inputs.pop("attention_mask"), np.arange(padded_len), -1) lang_inputs["image_idx"] = np.array([[0]]) - return lang_inputs, vision_outputs + return lang_inputs, vision_outputs, num_chunks except Exception as e: raise RuntimeError(f"Failed to process vision-language inputs: {str(e)}") diff --git a/QEfficient/generation/vlm_generation.py b/QEfficient/generation/vlm_generation.py index 8ec17c388..97664cf4a 100644 --- a/QEfficient/generation/vlm_generation.py +++ b/QEfficient/generation/vlm_generation.py @@ -75,6 +75,7 @@ class VisionLanguageGeneration(QEffTextGenerationBase): def __init__( self, + qeff_model, tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], processor: AutoImageProcessor, lang_qpc_path: str, @@ -131,6 +132,10 @@ def __init__( ) # Vision-specific initialization + self.is_qwen2_5_vl = ( + hasattr(qeff_model.model.config, "model_type") and qeff_model.model.config.model_type == "qwen2_5_vl" + ) + self.qeff_model = qeff_model self.processor = processor self._vision_qpc_path = vision_qpc_path self.device_id = device_id # Store device_id for vision components @@ -157,6 +162,7 @@ def _init_vision_components(self): # Vision handler with language session coordination vision_config = self._get_vision_config() self._vision_handler = VisionHandler( + qeff_model=self.qeff_model, vision_session=self._vision_session, processor=self.processor, config=vision_config, @@ -204,6 +210,51 @@ def _setup_vision_buffer_skipping(self): ] self._vision_session.skip_buffers(buffers_to_skip) + def run_prefill_for_all_inputs(self, prompt_queue, generation_len): + """ + Runs prefill for all inputs in the prompt queue and updates the decode input. + + Method iterates over the full batch size and for each decode batch ID, it pops the next prompt from the queue. It then runs prefill for the next prompt and updates the decode input with the outputs. + + Args: + prompt_queue (deque): The queue of prompts. + generation_len (int): The generation length. 
+ + """ + for decode_batch_id in range(self.full_batch_size): + next_prompt = prompt_queue.popleft() + + # run prefill for num_chunks + outputs, position_ids, generation_len = self.run_prefill( + next_prompt, generation_len, decode_batch_id=np.array(decode_batch_id, dtype=np.int64).reshape(1, 1) + ) + + if self.is_qwen2_5_vl: + _ = self.update_decode_inputs_qwen2_5_vl(outputs, position_ids, generation_len, decode_batch_id) + else: + _ = self.update_decode_input(outputs, position_ids, generation_len, decode_batch_id) + + def update_decode_inputs_qwen2_5_vl(self, outputs, position_ids, generation_len, decode_batch_id=None): + """ + Updates the decode input with the generated values. + Args: + outputs (dict): The outputs of the model. + position_ids (array): The position IDs. + generation_len (int): The generation length. + decode_batch_id (int, optional): The decode batch ID. If None, all values are updated. Defaults to None. + + Returns: + next_token_id (array): The next token ID. + """ + next_token_id = self._fetch_next_token_id(outputs) + + # Store the generated values. + self.decode_input_ids[decode_batch_id or slice(None)] = next_token_id + self.decode_pos_ids[:, decode_batch_id] = position_ids.squeeze(1) + self.generated_ids[decode_batch_id or slice(None), 0] = next_token_id.squeeze(1) + self.generation_len[decode_batch_id or slice(None)] = generation_len + return next_token_id + def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_id=None): """ Override base class prefill to handle vision processing @@ -230,19 +281,9 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i if isinstance(prompt, tuple) and len(prompt) == 2: image_path, text_prompt = prompt - # Process vision inputs. In CB, process for every prefill since each slot can have different image. 
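The chunk count computed in embedding_handler above, num_chunks = -(input_ids_length // -prefill_seq_len), is Python's negative floor-division idiom for ceiling division, and padded_len = num_chunks * prefill_seq_len rounds the prompt up to a whole number of prefill chunks. A quick standalone check of that arithmetic (hypothetical lengths, same formula as the patch):

def ceil_div(a: int, b: int) -> int:
    # Floor-dividing a negated numerator yields ceiling division in Python
    return -(a // -b)


prefill_seq_len = 128
for input_ids_length in (1, 127, 128, 129, 300):
    num_chunks = ceil_div(input_ids_length, prefill_seq_len)
    padded_len = num_chunks * prefill_seq_len
    assert input_ids_length <= padded_len < input_ids_length + prefill_seq_len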
- # Prepare the text prompt with vision context - processed_prompt = self._prepare_vision_language_prompt(text_prompt, image_path) - - # Compute padded_len aligned to prefill_seq_len using tokenizer on processed prompt - tmp_tokens = self.tokenizer(processed_prompt, return_tensors="np", padding=True) - input_len = tmp_tokens["input_ids"].shape[1] - num_chunks = -(input_len // -self._prefill_seq_len) - padded_len = num_chunks * self._prefill_seq_len - # Build language inputs with processor-aware vision/text integration - lang_inputs, vision_outputs = self._vision_handler.get_language_inputs_from_vision_processing( - image_url=image_path, query=text_prompt, padded_len=padded_len + lang_inputs, vision_outputs, num_chunks = self._vision_handler.get_language_inputs_from_vision_processing( + image_url=image_path, query=text_prompt, prefill_seq_len=self._prefill_seq_len ) # Set vision buffers in language session @@ -266,7 +307,7 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i :, i * self._prefill_seq_len : (i + 1) * self._prefill_seq_len ] position_ids_slice = lang_inputs["position_ids"][ - :, i * self._prefill_seq_len : (i + 1) * self._prefill_seq_len + ..., i * self._prefill_seq_len : (i + 1) * self._prefill_seq_len ] # Build minimal input set to avoid unintended buffers (e.g., batch_index) during prefill chunk_inputs = { @@ -454,6 +495,8 @@ def _generate_continuous_batching(self, vision_prompts, generation_len, stream, max_gen_length = self._ctx_len if not generation_len else max(self._ctx_len, generation_len) self.initialize_decode_inputs(num_prompts, execution_batch_size, max_gen_length) + if self.is_qwen2_5_vl: + self.decode_pos_ids = np.zeros((4, execution_batch_size, 1), np.int64) # Create prompt queue prompt_queue = deque(vision_prompts) diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 6d340cd0b..c7c08044e 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -1226,6 +1226,7 @@ def generate( # Create VisionLanguageGeneration instance batch_size_comp, ctx_len_comp, fbs = get_compilation_dims(self.lang_model.qpc_path) vlm_gen = VisionLanguageGeneration( + qeff_model=self, lang_qpc_path=self.lang_model.qpc_path, vision_qpc_path=self.vision_model.qpc_path, tokenizer=tokenizer, diff --git a/examples/qwen2_5_vl_CB.py b/examples/qwen2_5_vl_CB.py new file mode 100644 index 000000000..96ef4898a --- /dev/null +++ b/examples/qwen2_5_vl_CB.py @@ -0,0 +1,72 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. 
+# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +# If we want to enable QBlocking Run below command:, default is without blocking +# ATTENTION_BLOCKING_MODE=q num_q_blocks=2 python -W ignore qwen2_5_vl_example.py + +import transformers +from transformers import AutoConfig, AutoProcessor, TextStreamer + +from QEfficient import QEFFAutoModelForImageTextToText + +## For AWQ model update pytorch version to 2.8.* +model_id = "Qwen/Qwen2.5-VL-32B-Instruct" +config = AutoConfig.from_pretrained(model_id) +config.text_config.num_hidden_layers = 2 + +qeff_model = QEFFAutoModelForImageTextToText.from_pretrained( + model_id, + attn_implementation="eager", + kv_offload=True, + config=config, + continuous_batching=True, +) +tokenizer = transformers.AutoTokenizer.from_pretrained(model_id) +processor = AutoProcessor.from_pretrained(model_id) + +batch_size = 1 +## Vision + Text ## +qeff_model.compile( + batch_size=batch_size, + full_batch_size=4, + prefill_seq_len=128, + ctx_len=4096, + num_cores=16, + num_devices=4, + height=354, + width=536, + mxfp6_matmul=True, + mxint8_kv_cache=True, + aic_enable_depth_first=True, + mos=1, +) + +image_urls = [ + "https://picsum.photos/id/237/536/354", + "https://picsum.photos/id/237/536/354", + "https://picsum.photos/id/237/536/354", + "https://picsum.photos/id/237/536/354", +] + +prompts = [ + "Can you describe the image in detail?", + "What are the objects in the image?", + "What is the main subject of the image?", + "What colors are predominant in the image?", +] + +streamer = TextStreamer(tokenizer) +output = qeff_model.generate( + tokenizer=tokenizer, + prompts=prompts, + processor=processor, + images=image_urls, + generation_len=100, +) +print(output.generated_ids) +print(tokenizer.batch_decode(output.generated_ids)) +print(output) From ca0cc0333f8b832b0d42904dbab85a170b0d3ffd Mon Sep 17 00:00:00 2001 From: Rishin Raj Date: Fri, 31 Oct 2025 08:07:33 +0000 Subject: [PATCH 20/24] Removed CB regard kw args for functionin non CB models Signed-off-by: Rishin Raj --- .../transformers/models/modeling_auto.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index c7c08044e..eea61b357 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -858,7 +858,7 @@ class _QEffAutoModelForImageTextToTextDualQPC: def __init__( self, model: nn.Module, - continuous_batching, + continuous_batching: bool = False, **kwargs, ): """ @@ -982,8 +982,15 @@ def export( List[str] A list containing the paths to the generated ONNX graph files for both components. """ - inputs = self.model.get_dummy_inputs(kv_offload=True, continuous_batching=self.continuous_batching) - dynamic_axes = self.model.get_onnx_dynamic_axes(kv_offload=True, continuous_batching=self.continuous_batching) + # TODO This is a temporary change as continous batching is enabled only for few models. Once support is added for all the models this exception handing can be removed. 
+ try: + inputs = self.model.get_dummy_inputs(kv_offload=True, continuous_batching=self.continuous_batching) + dynamic_axes = self.model.get_onnx_dynamic_axes( + kv_offload=True, continuous_batching=self.continuous_batching + ) + except TypeError: + inputs = self.model.get_dummy_inputs(kv_offload=True) + dynamic_axes = self.model.get_onnx_dynamic_axes(kv_offload=True) output_names = self.model.get_output_names(kv_offload=True) self.vision_model.export( @@ -1124,6 +1131,11 @@ def compile( ): self.export() + # TODO this hould be removed once the continous batching is supported for all the models. + compiler_options.pop("continuous_batching", None) + compiler_options.pop("kv_cache_batch_size", None) + compiler_options.pop("full_batch_size", None) + if not skip_vision: self.vision_model._compile( compile_dir=compile_dir, From b0ee5a4c961d1c00a350321dd3d19f3856fc237e Mon Sep 17 00:00:00 2001 From: Rishin Raj Date: Fri, 31 Oct 2025 17:58:39 +0000 Subject: [PATCH 21/24] Added caching for vision outputs Signed-off-by: Rishin Raj --- QEfficient/generation/vlm_generation.py | 241 ++++++++++++++++++------ QEfficient/utils/__init__.py | 1 + QEfficient/utils/_utils.py | 30 +++ 3 files changed, 212 insertions(+), 60 deletions(-) diff --git a/QEfficient/generation/vlm_generation.py b/QEfficient/generation/vlm_generation.py index 97664cf4a..624af91b3 100644 --- a/QEfficient/generation/vlm_generation.py +++ b/QEfficient/generation/vlm_generation.py @@ -35,6 +35,7 @@ calculate_latency, write_io_files, ) +from QEfficient.utils import LRUCache from QEfficient.utils.logging_utils import logger @@ -140,6 +141,8 @@ def __init__( self._vision_qpc_path = vision_qpc_path self.device_id = device_id # Store device_id for vision components self.enable_debug_logs = enable_debug_logs # Store for vision components + self._vision_outputs_cache = LRUCache(max_size=100) # LRU cache for vision outputs + self._vision_cache = {} # Cache for vision outputs across batches self._init_vision_components() # Now that vision components are initialized, activate the text session @@ -201,14 +204,20 @@ def _get_vision_config(self) -> Dict[str, Any]: def _setup_vision_buffer_skipping(self): """Skip KV cache and retained state buffers for vision session""" - skip_patterns = [lambda x: x.startswith("past_"), lambda x: x.endswith("_RetainedState")] - - buffers_to_skip = [ + # Pre-compute skip buffers + self._vision_skip_buffers = [ x for x in self._vision_session.input_names + self._vision_session.output_names - if any(pattern(x) for pattern in skip_patterns) + if x.startswith("past_") or x.endswith("_RetainedState") + ] + self._vision_session.skip_buffers(self._vision_skip_buffers) + + # Pre-compute language skip buffers + self._lang_skip_buffers = [ + x + for x in self._session.input_names + self._session.output_names + if x.startswith("past_") or x.endswith("_RetainedState") ] - self._vision_session.skip_buffers(buffers_to_skip) def run_prefill_for_all_inputs(self, prompt_queue, generation_len): """ @@ -255,6 +264,70 @@ def update_decode_inputs_qwen2_5_vl(self, outputs, position_ids, generation_len, self.generation_len[decode_batch_id or slice(None)] = generation_len return next_token_id + def _execute_chunked_prefill( + self, + lang_inputs: Dict[str, np.ndarray], + num_chunks: int, + decode_batch_id: Optional[np.ndarray] = None, + prefill_logit_bs: int = 1, + ) -> Dict[str, np.ndarray]: + """ + Execute chunked prefill with language inputs (Optimization 3: extracted common logic). 
+ + Args: + lang_inputs: Pre-processed language inputs with input_ids, position_ids, etc. + num_chunks: Number of chunks to process + decode_batch_id: Batch ID for continuous batching (optional) + prefill_logit_bs: Batch size for prefill logits + + Returns: + Final prefill outputs + """ + # Set output buffers + self._set_output_buffers(batch_size=prefill_logit_bs, sequence_length=1) + + # Skip buffers for dual-QPC coordination (Optimization 2: use cached list) + self._session.skip_buffers(self._lang_skip_buffers) + + # Run chunked prefill + outputs = None + chunk_image_idx = None + + for i in range(num_chunks): + input_ids_slice = lang_inputs["input_ids"][:, i * self._prefill_seq_len : (i + 1) * self._prefill_seq_len] + position_ids_slice = lang_inputs["position_ids"][ + ..., i * self._prefill_seq_len : (i + 1) * self._prefill_seq_len + ] + + chunk_inputs = { + "input_ids": input_ids_slice, + "position_ids": position_ids_slice, + "image_idx": chunk_image_idx if chunk_image_idx is not None else np.array([[0]], dtype=np.int64), + } + + if decode_batch_id is not None: + chunk_inputs["batch_index"] = decode_batch_id + + if "cross_attention_mask" in lang_inputs: + chunk_inputs["cross_attention_mask"] = lang_inputs["cross_attention_mask"] + + outputs = self._session.run(chunk_inputs) + + if "image_idx_output" in outputs: + chunk_image_idx = outputs["image_idx_output"] + + if self._write_io_dir is not None: + write_io_files(lang_inputs, outputs, self._write_io_dir, "prefill", "aic_batch_io", True, False) + + # Prepare decode-time cross_attention_mask + if "cross_attention_mask" in lang_inputs: + bs, _, num_images, img_tiles = lang_inputs["cross_attention_mask"].shape + self._decode_cross_attention_mask = np.ones((bs, 1, num_images, img_tiles), dtype=np.int64) + else: + self._decode_cross_attention_mask = None + + return outputs + def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_id=None): """ Override base class prefill to handle vision processing @@ -281,10 +354,21 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i if isinstance(prompt, tuple) and len(prompt) == 2: image_path, text_prompt = prompt - # Build language inputs with processor-aware vision/text integration - lang_inputs, vision_outputs, num_chunks = self._vision_handler.get_language_inputs_from_vision_processing( - image_url=image_path, query=text_prompt, prefill_seq_len=self._prefill_seq_len - ) + # Check cache for vision outputs + cache_key = image_path if isinstance(image_path, str) else str(image_path) + if cache_key in self._vision_cache: + lang_inputs, vision_outputs, num_chunks = self._vision_cache[cache_key] + logger.debug(f"Using cached vision outputs for {cache_key}") + else: + # Build language inputs with processor-aware vision/text integration + lang_inputs, vision_outputs, num_chunks = ( + self._vision_handler.get_language_inputs_from_vision_processing( + image_url=image_path, query=text_prompt, prefill_seq_len=self._prefill_seq_len + ) + ) + # Cache for future use + self._vision_cache[cache_key] = (lang_inputs, vision_outputs, num_chunks) + logger.debug(f"Cached vision outputs for {cache_key}") # Set vision buffers in language session self._session.set_buffers(vision_outputs) @@ -296,58 +380,12 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i max_gen_len = self._ctx_len - np.where(lang_inputs["position_ids"] != -1, 1, 0).sum(1, keepdims=True).max() generation_len = self._fetch_generation_len(generation_len, max_gen_len) - # 
Set the prefill output buffers - self._set_output_buffers(batch_size=prefill_logit_bs, sequence_length=1) - - # Run prefill across chunks, updating image_idx as needed - outputs = None - chunk_image_idx = None # track image_idx across chunks - for i in range(num_chunks): - input_ids_slice = lang_inputs["input_ids"][ - :, i * self._prefill_seq_len : (i + 1) * self._prefill_seq_len - ] - position_ids_slice = lang_inputs["position_ids"][ - ..., i * self._prefill_seq_len : (i + 1) * self._prefill_seq_len - ] - # Build minimal input set to avoid unintended buffers (e.g., batch_index) during prefill - chunk_inputs = { - "input_ids": input_ids_slice, - "position_ids": position_ids_slice, - "image_idx": chunk_image_idx if chunk_image_idx is not None else np.array([[0]], dtype=np.int64), - } - if decode_batch_id is not None: - chunk_inputs["batch_index"] = decode_batch_id - if "cross_attention_mask" in lang_inputs: - chunk_inputs["cross_attention_mask"] = lang_inputs["cross_attention_mask"] - - outputs = self._session.run(chunk_inputs) - - # Update image_idx for next chunk if provided by model - if "image_idx_output" in outputs: - chunk_image_idx = outputs["image_idx_output"] - - if self._write_io_dir is not None: - write_io_files(lang_inputs, outputs, self._write_io_dir, "prefill", "aic_batch_io", True, False) + # Execute chunked prefill (Optimization 3: use extracted method) + outputs = self._execute_chunked_prefill(lang_inputs, num_chunks, decode_batch_id, prefill_logit_bs) # Prepare position_ids for decode phase (next position after prefill) position_ids_decode = np.max(lang_inputs["position_ids"], axis=-1, keepdims=True) + 1 - # Prepare decode-time cross_attention_mask (ones over image tiles) if available - if "cross_attention_mask" in lang_inputs: - bs, _, num_images, img_tiles = lang_inputs["cross_attention_mask"].shape - self._decode_cross_attention_mask = np.ones((bs, 1, num_images, img_tiles), dtype=np.int64) - else: - self._decode_cross_attention_mask = None - - # Skip retained_state and past_ buffers before decode for dual-QPC coordination - self._session.skip_buffers( - [ - x - for x in self._session.input_names + self._session.output_names - if x.startswith("past_") or x.endswith("_RetainedState") - ] - ) - return outputs, position_ids_decode, generation_len else: # Fall back to base class for text-only @@ -404,6 +442,9 @@ def generate( if len(images) != len(prompts): raise ValueError(f"Number of images ({len(images)}) must match number of prompts ({len(prompts)})") + # Clear vision cache for fresh generation + self._vision_cache.clear() + logger.info(f"Generating for {len(images)} image-prompt pairs") # Convert to base class format: list of (image, prompt) tuples @@ -487,8 +528,8 @@ def _generate_continuous_batching(self, vision_prompts, generation_len, stream, # Reset vision processing state for new generation self._vision_processed = False self._vision_outputs = None + self._vision_outputs_cache = {} - # Use the base class continuous batching logic directly # Initialize decode inputs num_prompts = len(vision_prompts) execution_batch_size = self.full_batch_size @@ -503,10 +544,39 @@ def _generate_continuous_batching(self, vision_prompts, generation_len, stream, start = perf_counter() + # Pre-process ALL vision inputs and cache them + logger.info("Pre-processing all vision inputs...") + for batch_id in range(min(self.full_batch_size, len(vision_prompts))): + img, prompt = vision_prompts[batch_id] + + # Process vision for this slot + lang_inputs, vision_outputs, num_chunks = 
self._vision_handler.get_language_inputs_from_vision_processing( + image_url=img, query=prompt, prefill_seq_len=self._prefill_seq_len + ) + + # Cache vision outputs for this batch slot (Optimization 4: use LRU cache) + self._vision_outputs_cache[batch_id] = { + "vision_outputs": vision_outputs, + "lang_inputs": lang_inputs, + "num_chunks": num_chunks, + } + + logger.debug(f"Cached vision outputs for batch_id {batch_id}") + + # Reset prompt queue for prefill + prompt_queue = deque(vision_prompts) + self.batch_index = None - # Run prefill for all inputs first (batch_index should NOT be set during prefill) - self.run_prefill_for_all_inputs(prompt_queue, generation_len) + # Run prefill for all inputs using cached vision + self.run_prefill_for_all_inputs_with_cached_vision(prompt_queue, generation_len) + + # Set vision buffers for decode (use first slot's vision for now) + # For identical images, any slot's vision works (Optimization 4: use LRU cache) + cached_slot_0 = self._vision_outputs_cache.get(0) + if cached_slot_0: + self._session.set_buffers(cached_slot_0["vision_outputs"]) + logger.debug("Set vision buffers from slot 0 for decode phase") # Now set batch_index for decode phase self.batch_index = np.arange(self.full_batch_size).reshape(-1, 1) @@ -531,6 +601,57 @@ def _generate_continuous_batching(self, vision_prompts, generation_len, stream, batch_size=1, generated_texts=generated_texts, generated_ids=self.generated_ids, perf_metrics=perf_metrics ) + def run_prefill_for_all_inputs_with_cached_vision(self, prompt_queue, generation_len): + """ + Runs prefill for all inputs using pre-cached vision outputs. + + This avoids the vision buffer overwriting issue by using cached vision + outputs instead of processing vision during each prefill iteration. + + Args: + prompt_queue (deque): The queue of prompts. + generation_len (int): The generation length. 
+ """ + for decode_batch_id in range(self.full_batch_size): + # Get cached vision outputs for this batch slot (Optimization 4: use LRU cache) + cached = self._vision_outputs_cache.get(decode_batch_id) + if cached: + vision_outputs = cached["vision_outputs"] + lang_inputs = cached["lang_inputs"] + num_chunks = cached["num_chunks"] + + # Set vision buffers for THIS prefill + self._session.set_buffers(vision_outputs) + logger.debug(f"Set vision buffers for batch_id {decode_batch_id} prefill") + + # Run prefill with cached inputs (Optimization 3: use extracted method) + outputs = self._execute_chunked_prefill( + lang_inputs, + num_chunks, + decode_batch_id=np.array(decode_batch_id, dtype=np.int64).reshape(1, 1), + prefill_logit_bs=1, + ) + + # Calculate position_ids for decode + position_ids_decode = np.max(lang_inputs["position_ids"], axis=-1, keepdims=True) + 1 + + # Calculate generation_len + max_gen_len = ( + self._ctx_len - np.where(lang_inputs["position_ids"] != -1, 1, 0).sum(1, keepdims=True).max() + ) + generation_len_final = self._fetch_generation_len(generation_len, max_gen_len) + + # Update decode inputs + if self.is_qwen2_5_vl: + self.update_decode_inputs_qwen2_5_vl( + outputs, position_ids_decode, generation_len_final, decode_batch_id + ) + else: + self.update_decode_input(outputs, position_ids_decode, generation_len_final, decode_batch_id) + else: + logger.error(f"No cached vision outputs for batch_id {decode_batch_id}") + raise RuntimeError(f"Vision outputs not cached for batch_id {decode_batch_id}") + def prepare_decode_inputs(self): """ Override base class to handle vision-specific decode inputs diff --git a/QEfficient/utils/__init__.py b/QEfficient/utils/__init__.py index e487d4af4..49f0ad30b 100755 --- a/QEfficient/utils/__init__.py +++ b/QEfficient/utils/__init__.py @@ -10,6 +10,7 @@ undo_transformers_quantizers, ) from QEfficient.utils._utils import ( # noqa: F401 + LRUCache, check_and_assign_cache_dir, create_json, create_model_params, diff --git a/QEfficient/utils/_utils.py b/QEfficient/utils/_utils.py index abe383556..d58f54952 100644 --- a/QEfficient/utils/_utils.py +++ b/QEfficient/utils/_utils.py @@ -33,6 +33,36 @@ from QEfficient.utils.logging_utils import logger +class LRUCache: + """Simple LRU cache with size limit for vision outputs""" + + def __init__(self, max_size=100): + self._cache = {} + self._access_order = [] + self._max_size = max_size + + def get(self, key): + if key in self._cache: + self._access_order.remove(key) + self._access_order.append(key) + return self._cache[key] + return None + + def put(self, key, value): + if key in self._cache: + self._access_order.remove(key) + elif len(self._cache) >= self._max_size: + oldest = self._access_order.pop(0) + del self._cache[oldest] + + self._cache[key] = value + self._access_order.append(key) + + def clear(self): + self._cache.clear() + self._access_order.clear() + + class DownloadRetryLimitExceeded(Exception): """ Used for raising error when hf_download fails to download the model after given max_retries. 
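The LRUCache above is exported through QEfficient.utils so the generation layer can memoize vision-encoder outputs per image. A minimal usage sketch (illustrative keys and payloads, assuming a package build with this patch applied):

from QEfficient.utils import LRUCache

# Cache sized to two entries so eviction is easy to see; vlm_generation.py uses max_size=100.
vision_cache = LRUCache(max_size=2)

vision_cache.put("img_a.jpg", {"vision_embeds": "..."})   # payload values here are placeholders
vision_cache.put("img_b.jpg", {"vision_embeds": "..."})
vision_cache.get("img_a.jpg")                             # marks img_a.jpg as most recently used
vision_cache.put("img_c.jpg", {"vision_embeds": "..."})   # capacity exceeded -> evicts img_b.jpg

assert vision_cache.get("img_b.jpg") is None              # evicted
assert vision_cache.get("img_a.jpg") is not None          # retained
vision_cache.clear()

The list-based bookkeeping makes get/put O(n) in the number of cached entries, which is acceptable at a cap of ~100 vision outputs; collections.OrderedDict would give the same semantics in O(1) if the cache ever needs to grow.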
From a0a80c06615d047f9788cac190a55293decf3b38 Mon Sep 17 00:00:00 2001 From: Rishin Raj Date: Mon, 3 Nov 2025 06:25:03 +0000 Subject: [PATCH 22/24] Added change for poping prompt while running continuous batching Signed-off-by: Rishin Raj --- QEfficient/generation/vlm_generation.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/QEfficient/generation/vlm_generation.py b/QEfficient/generation/vlm_generation.py index 624af91b3..bb58b3e1d 100644 --- a/QEfficient/generation/vlm_generation.py +++ b/QEfficient/generation/vlm_generation.py @@ -613,6 +613,9 @@ def run_prefill_for_all_inputs_with_cached_vision(self, prompt_queue, generation generation_len (int): The generation length. """ for decode_batch_id in range(self.full_batch_size): + # Pop the promt as we are processing + _ = prompt_queue.popleft() + # Get cached vision outputs for this batch slot (Optimization 4: use LRU cache) cached = self._vision_outputs_cache.get(decode_batch_id) if cached: From 089248044a7ffda7ffa98d9eb349b457bfc5a81a Mon Sep 17 00:00:00 2001 From: Mohit Soni Date: Mon, 3 Nov 2025 08:50:51 +0000 Subject: [PATCH 23/24] Adding multi_frame modeling changes and some fix Signed-off-by: Mohit Soni --- QEfficient/generation/embedding_handler.py | 98 +++++-------------- QEfficient/generation/vlm_generation.py | 8 +- .../models/qwen2_5_vl/modeling_qwen2_5_vl.py | 1 + 3 files changed, 31 insertions(+), 76 deletions(-) diff --git a/QEfficient/generation/embedding_handler.py b/QEfficient/generation/embedding_handler.py index fc1695756..76da7afc2 100644 --- a/QEfficient/generation/embedding_handler.py +++ b/QEfficient/generation/embedding_handler.py @@ -70,7 +70,7 @@ def is_available(self) -> bool: """ return self._vision_session is not None and self._processor is not None - def prepare_vision_inputs(self, image_url: str, query: str) -> Dict[str, np.ndarray]: + def prepare_vlm_inputs(self, image_url: str, query: str, prefill_seq_len: int) -> Dict[str, np.ndarray]: """ Download and preprocess image into model inputs @@ -112,6 +112,14 @@ def prepare_vision_inputs(self, image_url: str, query: str) -> Dict[str, np.ndar # Process image and text inputs = self._processor(images=image, text=prompt, return_tensors="pt") + if ( + hasattr(self._qeff_model.model.config, "model_type") + and self._qeff_model.model.config.model_type == "qwen2_5_vl" + ): + inputs = self._qeff_model.model.prepare_inputs_for_generation( + inputs=inputs, prefill_seq_len=prefill_seq_len, batch_size=inputs["input_ids"].shape[0] + ) + # Convert to float32 if needed if "pixel_values" in inputs: inputs["pixel_values"] = inputs["pixel_values"].to(torch.float32) @@ -135,7 +143,9 @@ def prepare_vision_inputs(self, image_url: str, query: str) -> Dict[str, np.ndar if k in vision_inputs: vision_inputs[k] = vision_inputs[k].astype("float16") - return vision_inputs + lang_inputs = {k: v for k, v in inputs.items() if k not in vision_inputs} + + return vision_inputs, lang_inputs except Exception as e: raise RuntimeError(f"Failed to process image {image_url}: {str(e)}") @@ -294,7 +304,7 @@ def prepare_complete_vision_language_inputs( return vision_inputs, vision_outputs - def get_language_inputs_from_vision_processing( + def get_processed_inputs( self, image_url: str, query: str, prefill_seq_len: int ) -> Tuple[Dict[str, np.ndarray], Dict[str, np.ndarray]]: """ @@ -312,97 +322,43 @@ def get_language_inputs_from_vision_processing( raise ValueError("Vision handler not properly initialized") try: - # Download and process image - if image_url.startswith(("http://", "https://")): - 
image = Image.open(requests.get(image_url, stream=True).raw) - else: - image = Image.open(image_url) - - # Prepare conversation - conversation = [ - { - "role": "user", - "content": [ - {"type": "text", "text": query}, - {"type": "image"}, - ], - }, - ] - - prompt = self._processor.apply_chat_template(conversation, add_generation_prompt=True) - inputs = self._processor(images=image, text=prompt, return_tensors="pt") - - if ( - hasattr(self._qeff_model.model.config, "model_type") - and self._qeff_model.model.config.model_type == "qwen2_5_vl" - ): - inputs = self._qeff_model.model.prepare_inputs_for_generation( - inputs=inputs, prefill_seq_len=prefill_seq_len, batch_size=inputs["input_ids"].shape[0] - ) - - if "pixel_values" in inputs: - inputs["pixel_values"] = inputs["pixel_values"].to(torch.float32) + ## Get vlm inputs ## + vision_inputs, lang_inputs = self.prepare_vlm_inputs(image_url, query, prefill_seq_len) # Handle padding for language model pad_token_id = 1 - input_ids_length = inputs["input_ids"].shape[1] + input_ids_length = lang_inputs["input_ids"].shape[1] num_chunks = -(input_ids_length // -prefill_seq_len) padded_len = num_chunks * prefill_seq_len - inputs["input_ids"] = torch.nn.functional.pad( - inputs["input_ids"], + lang_inputs["input_ids"] = torch.nn.functional.pad( + lang_inputs["input_ids"], (0, padded_len - input_ids_length), "constant", pad_token_id, ) - inputs["attention_mask"] = torch.nn.functional.pad( - inputs["attention_mask"], (0, padded_len - input_ids_length), "constant", 0 + lang_inputs["attention_mask"] = torch.nn.functional.pad( + lang_inputs["attention_mask"], (0, padded_len - input_ids_length), "constant", 0 ) - if "cross_attention_mask" in inputs: - inputs["cross_attention_mask"] = torch.nn.functional.pad( - inputs["cross_attention_mask"], (0, 0, 0, 0, 0, padded_len - input_ids_length) + if "cross_attention_mask" in lang_inputs: + lang_inputs["cross_attention_mask"] = torch.nn.functional.pad( + lang_inputs["cross_attention_mask"], (0, 0, 0, 0, 0, padded_len - input_ids_length) ) - # Convert to numpy - for k, v in inputs.items(): - inputs[k] = np.array(v) - - # Separate vision and language inputs - vision_inputs = { - k: v - for k, v in inputs.items() - if k - in { - "pixel_values", - "image_masks", - "image_input_idx", - "valid_idx", - "aspect_ratio_ids", - "aspect_ratio_mask", - } - } - - # Convert vision inputs to appropriate dtypes - vision_inputs_fp16 = {"pixel_values", "image_masks"} - for k in vision_inputs_fp16: - if k in vision_inputs: - vision_inputs[k] = vision_inputs[k].astype("float16") + for k, v in lang_inputs.items(): + lang_inputs[k] = np.array(v) - # Run vision inference if we have vision inputs vision_outputs = {} if vision_inputs: self.setup_vision_buffers() vision_outputs = self.run_vision_inference(vision_inputs) - # Prepare language inputs - lang_inputs = {k: v for k, v in inputs.items() if k not in vision_inputs} - - if "position_ids" in inputs: - lang_inputs["position_ids"] = inputs["position_ids"] + if "position_ids" in lang_inputs: lang_inputs.pop("attention_mask") else: lang_inputs["position_ids"] = np.where(lang_inputs.pop("attention_mask"), np.arange(padded_len), -1) + lang_inputs["image_idx"] = np.array([[0]]) return lang_inputs, vision_outputs, num_chunks diff --git a/QEfficient/generation/vlm_generation.py b/QEfficient/generation/vlm_generation.py index bb58b3e1d..3e5878840 100644 --- a/QEfficient/generation/vlm_generation.py +++ b/QEfficient/generation/vlm_generation.py @@ -361,10 +361,8 @@ def run_prefill(self, prompt, 
generation_len, prefill_logit_bs=1, decode_batch_i logger.debug(f"Using cached vision outputs for {cache_key}") else: # Build language inputs with processor-aware vision/text integration - lang_inputs, vision_outputs, num_chunks = ( - self._vision_handler.get_language_inputs_from_vision_processing( - image_url=image_path, query=text_prompt, prefill_seq_len=self._prefill_seq_len - ) + lang_inputs, vision_outputs, num_chunks = self._vision_handler.get_processed_inputs( + image_url=image_path, query=text_prompt, prefill_seq_len=self._prefill_seq_len ) # Cache for future use self._vision_cache[cache_key] = (lang_inputs, vision_outputs, num_chunks) @@ -550,7 +548,7 @@ def _generate_continuous_batching(self, vision_prompts, generation_len, stream, img, prompt = vision_prompts[batch_id] # Process vision for this slot - lang_inputs, vision_outputs, num_chunks = self._vision_handler.get_language_inputs_from_vision_processing( + lang_inputs, vision_outputs, num_chunks = self._vision_handler.get_processed_inputs( image_url=img, query=prompt, prefill_seq_len=self._prefill_seq_len ) diff --git a/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index b77f5075c..0f6630210 100644 --- a/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -880,6 +880,7 @@ def get_specializations( img_size: None, height: int = None, width: int = None, + num_frames: int = 1, kv_offload: bool = False, continuous_batching: bool = False, kv_cache_batch_size: Optional[int] = None, From 875c4d2968f928c21076c22869f3fa3a4379cbbb Mon Sep 17 00:00:00 2001 From: Rishin Raj Date: Mon, 3 Nov 2025 15:04:14 +0000 Subject: [PATCH 24/24] Adressed review comments Signed-off-by: Rishin Raj --- QEfficient/generation/vlm_generation.py | 21 ++++++++++++------- .../transformers/models/modeling_auto.py | 7 +++---- 2 files changed, 16 insertions(+), 12 deletions(-) diff --git a/QEfficient/generation/vlm_generation.py b/QEfficient/generation/vlm_generation.py index 3e5878840..2e8f04f2b 100644 --- a/QEfficient/generation/vlm_generation.py +++ b/QEfficient/generation/vlm_generation.py @@ -8,7 +8,7 @@ """ This module provides the VisionLanguageGeneration class that inherits from QEffTextGenerationBase, enabling all advanced text generation features while -maintaining full API compatibility with the original VisionLanguageGeneration.bn hv$$ Z&F +maintaining full API compatibility with the original VisionLanguageGeneration. Key enhancements: - Continuous batching support for vision models @@ -95,6 +95,7 @@ def __init__( Initialize vision-language generation with enhanced capabilities Args: + qeff_model: QEff model instance tokenizer: Text tokenizer processor: Image processor lang_qpc_path: Path to language model QPC @@ -272,7 +273,7 @@ def _execute_chunked_prefill( prefill_logit_bs: int = 1, ) -> Dict[str, np.ndarray]: """ - Execute chunked prefill with language inputs (Optimization 3: extracted common logic). + Execute chunked prefill with language inputs Args: lang_inputs: Pre-processed language inputs with input_ids, position_ids, etc. 
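For reference, _execute_chunked_prefill assumes the prompt has already been padded to a whole number of prefill windows; the ceil-division and slicing it relies on can be sketched standalone (illustrative lengths and token values, not real processor outputs):

import numpy as np

prefill_seq_len = 128
pad_token_id = 1
input_ids = np.ones((1, 300), dtype=np.int64)             # pretend tokenized prompt of length 300

input_len = input_ids.shape[1]
num_chunks = -(input_len // -prefill_seq_len)              # ceil division -> 3 chunks
padded_len = num_chunks * prefill_seq_len                  # 384

padded_ids = np.full((1, padded_len), pad_token_id, dtype=np.int64)
padded_ids[:, :input_len] = input_ids
# padding positions get -1 so the model ignores them when building the KV cache
position_ids = np.where(np.arange(padded_len) < input_len, np.arange(padded_len), -1)[None, :]

for i in range(num_chunks):
    chunk_ids = padded_ids[:, i * prefill_seq_len : (i + 1) * prefill_seq_len]
    chunk_pos = position_ids[:, i * prefill_seq_len : (i + 1) * prefill_seq_len]
    # each (chunk_ids, chunk_pos) pair corresponds to one language-session run during prefill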
@@ -286,7 +287,7 @@ def _execute_chunked_prefill( # Set output buffers self._set_output_buffers(batch_size=prefill_logit_bs, sequence_length=1) - # Skip buffers for dual-QPC coordination (Optimization 2: use cached list) + # Skip buffers for dual-QPC coordination self._session.skip_buffers(self._lang_skip_buffers) # Run chunked prefill @@ -378,9 +379,11 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i max_gen_len = self._ctx_len - np.where(lang_inputs["position_ids"] != -1, 1, 0).sum(1, keepdims=True).max() generation_len = self._fetch_generation_len(generation_len, max_gen_len) - # Execute chunked prefill (Optimization 3: use extracted method) + # Execute chunked prefill outputs = self._execute_chunked_prefill(lang_inputs, num_chunks, decode_batch_id, prefill_logit_bs) + self._session.skip_buffers(vision_outputs) + # Prepare position_ids for decode phase (next position after prefill) position_ids_decode = np.max(lang_inputs["position_ids"], axis=-1, keepdims=True) + 1 @@ -552,7 +555,7 @@ def _generate_continuous_batching(self, vision_prompts, generation_len, stream, image_url=img, query=prompt, prefill_seq_len=self._prefill_seq_len ) - # Cache vision outputs for this batch slot (Optimization 4: use LRU cache) + # Cache vision outputs for this batch slot self._vision_outputs_cache[batch_id] = { "vision_outputs": vision_outputs, "lang_inputs": lang_inputs, @@ -570,7 +573,7 @@ def _generate_continuous_batching(self, vision_prompts, generation_len, stream, self.run_prefill_for_all_inputs_with_cached_vision(prompt_queue, generation_len) # Set vision buffers for decode (use first slot's vision for now) - # For identical images, any slot's vision works (Optimization 4: use LRU cache) + # For identical images, any slot's vision works cached_slot_0 = self._vision_outputs_cache.get(0) if cached_slot_0: self._session.set_buffers(cached_slot_0["vision_outputs"]) @@ -614,7 +617,7 @@ def run_prefill_for_all_inputs_with_cached_vision(self, prompt_queue, generation # Pop the promt as we are processing _ = prompt_queue.popleft() - # Get cached vision outputs for this batch slot (Optimization 4: use LRU cache) + # Get cached vision outputs for this batch slot cached = self._vision_outputs_cache.get(decode_batch_id) if cached: vision_outputs = cached["vision_outputs"] @@ -625,7 +628,7 @@ def run_prefill_for_all_inputs_with_cached_vision(self, prompt_queue, generation self._session.set_buffers(vision_outputs) logger.debug(f"Set vision buffers for batch_id {decode_batch_id} prefill") - # Run prefill with cached inputs (Optimization 3: use extracted method) + # Run prefill with cached inputs outputs = self._execute_chunked_prefill( lang_inputs, num_chunks, @@ -633,6 +636,8 @@ def run_prefill_for_all_inputs_with_cached_vision(self, prompt_queue, generation prefill_logit_bs=1, ) + self._session.skip_buffers(vision_outputs.keys()) + # Calculate position_ids for decode position_ids_decode = np.max(lang_inputs["position_ids"], axis=-1, keepdims=True) + 1 diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index eea61b357..aeb72d858 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -1022,7 +1022,6 @@ def compile( num_cores: int = 16, # FIXME: Make this mandatory arg mxfp6_matmul: bool = False, mxint8_kv_cache: bool = False, - num_speculative_tokens: Optional[int] = None, skip_vision: Optional[bool] = False, skip_lang: Optional[bool] = False, 
         **compiler_options,
@@ -2051,15 +2050,15 @@ def from_pretrained(
             If `continuous_batching` is provided as True.
         """
         # TODO: add a check to see if kv_offload is allowed for given model by loading the config and checking architecture or type of config here.
+        if continuous_batching and not kv_offload:
+            raise NotImplementedError("Continuous batching is not supported for kv_offload = False")
+
         if kwargs.get("attn_implementation", None) not in {None, "eager"}:
             logger.warning('Updating attn_implementation="eager"')

         if kwargs.get("low_cpu_mem_usage", None):
             logger.warning("Updating low_cpu_mem_usage=False")

-        if continuous_batching and not kv_offload:
-            NotImplementedError("Continuous batching is not supported for kv_offload = False")
-
         kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False})
         model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, **kwargs)
         return cls(