diff --git a/QEfficient/generation/text_generation_inference.py b/QEfficient/generation/text_generation_inference.py index fd7ef03ff..16bcf9e79 100755 --- a/QEfficient/generation/text_generation_inference.py +++ b/QEfficient/generation/text_generation_inference.py @@ -444,7 +444,11 @@ def __init__( self._set_tokenizer_params() # set tokenizer params # Skip inputs/outputs self._session.skip_buffers( - [x for x in self._session.input_names + self._session.output_names if x.startswith("past_")] + [ + x + for x in self._session.input_names + self._session.output_names + if x.startswith("past_") or x.endswith("_RetainedState") + ] ) def _set_tokenizer_params(self): @@ -822,6 +826,166 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len): return decode_pause_time + def run_vision_language_continuous_batching_decode( + self, prompt_queue, generation_len, shared_vision_embeddings=None + ): + """ + Runs continuous batching decode for vision language models with shared vision embeddings. + + Method sets up the initial conditions for decoding and preparing the decode inputs. Then enters a loop that continues as long as there are prompts in the queue or any decoding is ongoing. In each iteration of the loop, it runs the session with the current decode inputs, prepares the inputs for the next iteration and updates the decode inputs. If a prompt has been fully decoded, it runs prefill for the next prompt in the queue if available. + + Args: + prompt_queue (deque): The queue of prompts to be decoded. + generation_len (int): The generation length. + shared_vision_embeddings (np.array, optional): Shared vision embeddings for vision-language models. Defaults to None. + + """ + # Set logits placeholder for decode + logits_out_placeholder = np.zeros( + (self.full_batch_size, self._decode_seq_len, self._vocab_size), dtype=np.float32 + ) + self._session.set_buffers({"logits": logits_out_placeholder}) + + # Set shared vision embeddings if provided + if shared_vision_embeddings is not None: + self._session.set_buffers(shared_vision_embeddings) + + # Generate flag for tracking progress for each batch ID + current_decode_ongoing = np.full((self.full_batch_size, 1), True) + + # Generate an array for maintaining the tokens generated in each batch ID + generated_id_current_index = np.ones((self.full_batch_size, 1), np.int64) + + # Generate a batch ID map for mapping the batch ID if input > full_batch_size. + # This ID map will be used for storing all generated tokens + batch_id_map = {i: i for i in range(self.full_batch_size)} + decode_pause_time = 0 + + # Prepare decode inputs. + decode_inputs = self.prepare_decode_inputs() + + while prompt_queue or current_decode_ongoing.any(): + outputs = self._session.run(decode_inputs) + + # Prepare inputs for next iteration + logits = outputs["logits"] + if len(logits.shape) == 2: + logits = np.expand_dims(logits, 1) + next_token_id = logits.argmax(2) + + for decode_batch_id in range(self.full_batch_size): + if ( + next_token_id[decode_batch_id, -1] == self.tokenizer.eos_token_id + or generated_id_current_index[decode_batch_id] >= self.generation_len[decode_batch_id] + ): + if prompt_queue: + start = perf_counter() + # run prefill for next prompt input. 
+ outputs, position_ids, generation_len = self.run_vision_language_prefill( + prompt_queue.popleft(), + generation_len, + decode_batch_id=np.array(decode_batch_id, dtype=np.int64).reshape(1, 1), + shared_vision_embeddings=shared_vision_embeddings, + ) + + new_token_id = self.update_decode_input(outputs, position_ids, generation_len, decode_batch_id) + + batch_id_map[decode_batch_id] = max(batch_id_map.values()) + 1 + self.generated_ids[batch_id_map[decode_batch_id], 0] = new_token_id.squeeze(1) + generated_id_current_index[decode_batch_id] = 1 + + self._session.set_buffers({"logits": logits_out_placeholder}) + + # Re-set shared vision embeddings for consistency + if shared_vision_embeddings is not None: + self._session.set_buffers(shared_vision_embeddings) + + decode_pause_time += perf_counter() - start + + if self._prompt_to_lora_id_mapping_decode: + decode_inputs["lora_ids"][decode_batch_id] = self._prompt_to_lora_id_mapping_decode[ + batch_id_map[decode_batch_id] + ] + + else: + current_decode_ongoing[decode_batch_id] = False + else: + # If the generated sequence is valid and within generation len prepare for next decode + decode_inputs["input_ids"][decode_batch_id, -1] = next_token_id[decode_batch_id, -1] + decode_inputs["position_ids"][decode_batch_id, -1] += 1 + self.generated_ids[batch_id_map[decode_batch_id], generated_id_current_index[decode_batch_id]] = ( + next_token_id[decode_batch_id, -1] + ) + + generated_id_current_index[decode_batch_id] += 1 + + return decode_pause_time + + def run_vision_language_prefill( + self, prompt, generation_len, prefill_logit_bs=1, decode_batch_id=None, shared_vision_embeddings=None + ): + """ + Runs prefill for a vision-language prompt. Executes chunked prefill for the given prompt and returns the prefill outputs, the position IDs, and the resolved generation length. + Args: + prompt (str): The prompt for which to run prefill. + generation_len (int): Max allowed length for generating tokens. The decoding process will be terminated when generation length is reached. + prefill_logit_bs (int, optional): Batch size of the prefill logits buffer. Defaults to 1. + decode_batch_id (np.ndarray, optional): The decode batch ID for continuous batching. Defaults to None. + shared_vision_embeddings (dict, optional): Precomputed vision embedding buffers shared across prompts. Defaults to None. + """ + # Run prefill + inputs = self.tokenizer(prompt, return_tensors="np", padding=True) + position_ids = inputs["attention_mask"].sum(1, keepdims=True) + padded_len = inputs["input_ids"].shape[1] + num_chunks = -(padded_len // -self._prefill_seq_len) # ceil divide without float + padded_len = num_chunks * self._prefill_seq_len # Convert to a multiple of prompt_len + + # Initialize variables specific to request + # Calculate the max generation length. 
+ max_gen_len = self._ctx_len - position_ids.max() + generation_len = self._fetch_generation_len(generation_len, max_gen_len) + + # Set the prefill logic buffer + logits_out_placeholder = np.zeros((prefill_logit_bs, 1, self._vocab_size), dtype=np.float32) + self._session.set_buffers({"logits": logits_out_placeholder}) + + # Set shared vision embeddings if provided + if shared_vision_embeddings is not None: + self._session.set_buffers(shared_vision_embeddings) + + inputs = self.tokenizer(prompt, return_tensors="np", padding="max_length", max_length=padded_len) + inputs["position_ids"] = np.where(inputs.pop("attention_mask"), np.arange(padded_len), -1) + inputs.pop("token_type_ids", None) + + if decode_batch_id is not None: + inputs["batch_index"] = decode_batch_id + if self.is_tlm: + inputs["num_logits_to_keep"] = np.zeros((1, 1)) + + if self._prompt_to_lora_id_mapping_prefill: + if self.full_batch_size: + inputs["lora_ids"] = np.array( + self._prompt_to_lora_id_mapping_prefill.popleft(), dtype=np.int64 + ).reshape(1, 1) + else: + batch_lora_ids = [self._prompt_to_lora_id_mapping_prefill.popleft() for i in range(self.batch_size)] + inputs["lora_ids"] = np.array(batch_lora_ids, dtype=np.int64).reshape(self.batch_size, 1) + + for i in range(num_chunks): + chunk_inputs = inputs.copy() + chunk_inputs["input_ids"] = inputs["input_ids"][ + :, i * self._prefill_seq_len : (i + 1) * self._prefill_seq_len + ] + chunk_inputs["position_ids"] = inputs["position_ids"][ + :, i * self._prefill_seq_len : (i + 1) * self._prefill_seq_len + ] + outputs = self._session.run(chunk_inputs) + if self._write_io_dir is not None: + write_io_files(inputs, outputs, self._write_io_dir, "prefill", "aic_batch_io", True, False) + return ( + outputs, + position_ids, + generation_len, + ) + def run_decode(self, decode_inputs, generation_len, streamer: Optional[transformers.TextStreamer] = None): """ Default method for running decode. Executes the decoding process for a given set of inputs and a specified generation length. 
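Reviewer note: a minimal sketch of how the two methods added above are intended to be driven. It assumes a QEffTextGenerationBase-style object (here called generator) that already provides prepare_decode_inputs/update_decode_input, plus a separately compiled vision QPC session; generator, vision_session, vision_inputs, prompts, and generation_len are illustrative names, not part of this diff.

from collections import deque

import numpy as np

# Run the vision encoder once; its output buffers (e.g. {"vision_embeds": ...})
# are shared by every prompt in the batch.
vision_outputs = vision_session.run(vision_inputs)

prompt_queue = deque(prompts)

# Prefill one prompt per decode slot, then let continuous batching take over.
for decode_batch_id in range(generator.full_batch_size):
    outputs, position_ids, generation_len_i = generator.run_vision_language_prefill(
        prompt_queue.popleft(),
        generation_len,
        decode_batch_id=np.array(decode_batch_id, dtype=np.int64).reshape(1, 1),
        shared_vision_embeddings=vision_outputs,
    )
    generator.update_decode_input(outputs, position_ids, generation_len_i, decode_batch_id)

decode_pause_time = generator.run_vision_language_continuous_batching_decode(
    prompt_queue, generation_len, shared_vision_embeddings=vision_outputs
)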
diff --git a/QEfficient/transformers/cache_utils.py b/QEfficient/transformers/cache_utils.py index 16767fbe2..f87e008ee 100644 --- a/QEfficient/transformers/cache_utils.py +++ b/QEfficient/transformers/cache_utils.py @@ -443,6 +443,7 @@ def update( else: position_ids = cache_kwargs.get("position_ids") + batch_index = cache_kwargs.get("batch_index", None) # Check and fetch the batch index value from the kwargs is_sliding_layer = torch.tensor(bool(self.is_sliding[layer_idx])) # Update the position_ids to handle the sliding window @@ -460,10 +461,22 @@ def update( valid_mask = (kv_position_ids != -1).unsqueeze(1).unsqueeze(-1) key_states = torch.where(valid_mask == 1, key_states, torch.zeros_like(key_states)) value_states = torch.where(valid_mask == 1, value_states, torch.zeros_like(value_states)) - self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], kv_position_ids, key_states) - self.value_cache[layer_idx] = CtxScatterFunc.apply( - self.value_cache[layer_idx], kv_position_ids, value_states - ) + if batch_index is not None: + invalid_scatter_index = torch.iinfo(torch.int32).max + scatter_position_ids = torch.where(position_ids < 0, invalid_scatter_index, position_ids) + + self.key_cache[layer_idx] = CtxScatterFuncCB.apply( + self.key_cache[layer_idx], batch_index, scatter_position_ids, key_states + ) + + self.value_cache[layer_idx] = CtxScatterFuncCB.apply( + self.value_cache[layer_idx], batch_index, scatter_position_ids, value_states + ) + else: + self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], position_ids, key_states) + self.value_cache[layer_idx] = CtxScatterFunc.apply( + self.value_cache[layer_idx], position_ids, value_states + ) k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx] # Original Gather @@ -483,8 +496,12 @@ def update( final_indices = torch.where( (is_sliding_layer & (position_ids.max() >= (layer_ctx_len - 1))), rolling_indices, ctx_indices ) - k_out = CtxGatherFunc.apply(k_out, final_indices) - v_out = CtxGatherFunc.apply(v_out, final_indices) + if batch_index is not None: + k_out = CtxGatherFuncCB.apply(k_out, batch_index, final_indices) + v_out = CtxGatherFuncCB.apply(v_out, batch_index, final_indices) + else: + k_out = CtxGatherFunc.apply(k_out, final_indices) + v_out = CtxGatherFunc.apply(v_out, final_indices) ctx_v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out) v_out = torch.where((is_sliding_layer & (position_ids.max() >= (layer_ctx_len - 1))), v_out, ctx_v_out) return k_out, v_out
diff --git a/QEfficient/transformers/models/llama4/modeling_llama4.py b/QEfficient/transformers/models/llama4/modeling_llama4.py index 4b957ebec..ae9cd3a01 100644 --- a/QEfficient/transformers/models/llama4/modeling_llama4.py +++ b/QEfficient/transformers/models/llama4/modeling_llama4.py @@ -831,20 +831,69 @@ def __init__(self, model): self.language_model = self.model.language_model self.config = self.model.config - def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_values): + def forward( + self, + input_ids, + vision_embeds, + position_ids, + image_idx, + past_key_values, + batch_index: Optional[torch.LongTensor] = None, + ): inputs_embeds = self.model.language_model.get_input_embeddings()(input_ids) + batch_size = None + + # Handle the CB case with multiple prompts sharing the same image + if batch_index is not None and batch_index.numel() > 1: + # For CB with multiple prompts sharing the same image, reuse vision embeds across batches + batch_size = input_ids.shape[0] + + # Expand vision_embeds to match the batch size if needed + if vision_embeds.shape[0] == 1 and batch_size > 1: + vision_embeds = vision_embeds.expand(batch_size, -1, -1) + selected = input_ids == self.model.config.image_token_index indices1 = selected.to(torch.int64).cumsum(1) - 1 - indices1 = torch.where(indices1 != -1, indices1 + image_idx, indices1) + # indices1 = torch.where(indices1 != -1, indices1 + image_idx, indices1) + # indices0 = torch.arange(selected.unsqueeze(0).shape[0]).view(-1, 1) + # image_features_expanded = vision_embeds.unsqueeze(0)[indices0, indices1] + + # Handle batch-aware image indexing for CB + if batch_size is not None: + # For CB, use per-batch image indices + batch_image_idx = image_idx.expand_as(selected[:, :1]) + indices1 = torch.where(indices1 != -1, indices1 + batch_image_idx, indices1) + else: + # For non-CB, use global image indices + indices1 = torch.where(indices1 != -1, indices1 + image_idx, indices1) indices0 = torch.arange(selected.unsqueeze(0).shape[0]).view(-1, 1) - image_features_expanded = vision_embeds.unsqueeze(0)[indices0, indices1] + + # Handle vision embeddings indexing for batch processing + if vision_embeds.dim() == 3 and vision_embeds.shape[0] == input_ids.shape[0]: + # Batch-wise vision embeddings + image_features_expanded = vision_embeds[indices0, indices1] + else: + # Single vision embeddings for all batches / single image shared across all batches + image_features_expanded = vision_embeds.unsqueeze(0)[indices0, indices1] + + image_embeds = torch.where(selected.unsqueeze(-1), image_features_expanded, inputs_embeds) inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), inputs_embeds, image_embeds) outputs = self.model.language_model( - inputs_embeds=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, use_cache=True + inputs_embeds=inputs_embeds, + position_ids=position_ids, + past_key_values=past_key_values, + batch_index=batch_index, + use_cache=True, ) - next_idx = (indices1.max() + 1).unsqueeze(0).unsqueeze(0) - image_idx = torch.where(image_idx < next_idx, next_idx, image_idx) + # Update image_idx to point to the next available vision_embeds index - handle the batch case + if batch_index is not None and indices1.numel() > 0: + # For CB, update image_idx per batch + next_idx = (indices1.max(dim=1, keepdim=True)[0] + 1).unsqueeze(1) + image_idx = torch.where(image_idx < next_idx, next_idx, image_idx) + else: + next_idx = (indices1.max() + 1).unsqueeze(0).unsqueeze(0) + image_idx = torch.where(image_idx < next_idx, next_idx, image_idx) + return outputs.logits, vision_embeds, image_idx, outputs.past_key_values @@ -888,6 +937,9 @@ def get_specializations( ctx_len: int, img_size: int, kv_offload: bool = False, + continuous_batching: bool = False, + kv_cache_batch_size: Optional[int] = None, + full_batch_size: Optional[int] = None, **compiler_options, ): max_num_tiles = compiler_options.pop("max_num_tiles", None) @@ -931,33 +983,49 @@ def get_specializations( vision = [ { - "batch_size": batch_size, + "batch_size": 1, # Process the image only once for all batch sizes (prompts) in continuous batching "max_num_tiles": max_num_tiles, "img_size": img_size, } ] - lang = [ - { - "batch_size": batch_size, - "seq_len": prefill_seq_len, - "ctx_len": ctx_len, - "max_num_tiles": max_num_tiles, - "img_size": img_size, - "vision_size": vision_size, - "chunk_length": prefill_seq_len, - "chunk_ctx_len": chunk_ctx_len, - }, - { - "batch_size": batch_size, - "seq_len": "1", - "ctx_len": ctx_len, - 
"max_num_tiles": max_num_tiles, - "img_size": img_size, - "vision_size": vision_size, - "chunk_length": prefill_seq_len, - "chunk_ctx_len": chunk_ctx_len, - }, - ] + + lang_prefill = { + "batch_size": 1 if continuous_batching else batch_size, + "seq_len": prefill_seq_len, + "ctx_len": ctx_len, + "max_num_tiles": max_num_tiles, + "img_size": img_size, + "vision_size": vision_size, + "chunk_length": prefill_seq_len, + "chunk_ctx_len": chunk_ctx_len, + } + if continuous_batching: + lang_prefill["full_batch_size"] = full_batch_size or kv_cache_batch_size + # Enable multi-prompt support with shared vision embeddings + lang_prefill["shared_vision"] = 1 + else: + lang_prefill["batch_size"] = kv_cache_batch_size + if full_batch_size: + lang_prefill["full_batch_exec_size"] = full_batch_size + + lang_decode = { + "batch_size": full_batch_size if continuous_batching else batch_size, + "seq_len": 1, + "ctx_len": ctx_len, + "max_num_tiles": max_num_tiles, + "img_size": img_size, + "vision_size": vision_size, + "chunk_length": prefill_seq_len, + "chunk_ctx_len": chunk_ctx_len, + } + if continuous_batching: + lang_decode["full_batch_size"] = full_batch_size or kv_cache_batch_size + else: + lang_decode["batch_size"] = kv_cache_batch_size + + lang = [] + lang.append(lang_prefill) + lang.append(lang_decode) specializations = {} @@ -966,18 +1034,22 @@ def get_specializations( specializations["lang"] = lang return specializations, compiler_options else: + lang[0].pop("vision_size") + lang[1].pop("vision_size") return lang, compiler_options - def get_onnx_dynamic_axes(self, kv_offload: bool = False): + def get_onnx_dynamic_axes(self, kv_offload: bool = False, continuous_batching: bool = False): # Define dynamic axes vision_dynamic_axes = {} lang_dynamic_axes = {} lang_dynamic_axes["input_ids"] = {0: "batch_size", 1: "seq_len"} lang_dynamic_axes["position_ids"] = {0: "batch_size", 1: "seq_len"} lang_dynamic_axes["vision_embeds"] = {0: "vision_size"} + if continuous_batching: + lang_dynamic_axes["batch_index"] = {0: "batch_size"} vision_dynamic_axes["pixel_values"] = {0: "max_num_tiles", 2: "img_size", 3: "img_size"} - pkv_dynamic_axes = {0: "batch_size"} + pkv_dynamic_axes = {0: "full_batch_size" if continuous_batching else "batch_size"} for i in range(self.language_model.config.num_hidden_layers): # switch between chunk_ctx_len and ctx_len for RoPE and NoPE layers. 
if int((i + 1) % 4 != 0): @@ -1040,7 +1112,7 @@ def get_dummy_pkv_cache(self, config, batch_size, seq_len): past_key_values.append(pkv) return past_key_values - def get_dummy_inputs(self, kv_offload: bool = False): + def get_dummy_inputs(self, kv_offload: bool = False, continuous_batching: bool = False): if vis_cfg := getattr(self.config, "vision_config", None): img_size = getattr(vis_cfg, "image_size", 336) else: @@ -1085,10 +1157,14 @@ def get_dummy_inputs(self, kv_offload: bool = False): .repeat(constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, 1) ) lang_inputs["image_idx"] = torch.zeros((inputs_shapes["image_idx"]), dtype=torch.int64) + + bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE + fbs: int = constants.ONNX_EXPORT_EXAMPLE_FBS + # Add data for KV past_key_values = self.get_dummy_pkv_cache( config=self.language_model.config, - batch_size=constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, + batch_size=fbs if continuous_batching else bs, seq_len=constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN, ) @@ -1097,6 +1173,8 @@ def get_dummy_inputs(self, kv_offload: bool = False): for kv in ["key", "value"]: lang_inputs["past_key_values"][i].append(torch.zeros(past_key_values[0][0].shape, dtype=torch.float32)) + if continuous_batching: + lang_inputs["batch_index"] = torch.arange(bs).view(bs, 1) inputs = {} if kv_offload: inputs["vision"] = vision_inputs diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 42898381d..99e7e07cf 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -31,6 +31,7 @@ from QEfficient.generation.text_generation_inference import ( CloudAI100ExecInfoNew, PerfMetrics, + TextGeneration, calculate_latency, get_compilation_dims, ) @@ -545,6 +546,7 @@ class _QEffAutoModelForImageTextToTextDualQPC: def __init__( self, model: nn.Module, + continuous_batching, **kwargs, ): if kwargs.pop("full_batch_size", None): @@ -553,6 +555,7 @@ def __init__( self.config = model.config self.vision_model = QEffVisionEncoderForTextImageToTextModel(model, **kwargs) self.lang_model = QEffCausalLMForTextImageToTextModel(model, **kwargs) + self.continuous_batching = continuous_batching self.input_shapes, self.output_names = None, None @property @@ -592,8 +595,8 @@ def export( export_dir: Optional[str] = None, **kwargs, ) -> str: - inputs = self.model.get_dummy_inputs(kv_offload=True) - dynamic_axes = self.model.get_onnx_dynamic_axes(kv_offload=True) + inputs = self.model.get_dummy_inputs(kv_offload=True, continuous_batching=self.continuous_batching) + dynamic_axes = self.model.get_onnx_dynamic_axes(kv_offload=True, continuous_batching=self.continuous_batching) output_names = self.model.get_output_names(kv_offload=True) self.vision_model.export( @@ -630,14 +633,20 @@ def compile( skip_lang: Optional[bool] = False, **compiler_options, ) -> str: - if any(param is not None for param in [full_batch_size, kv_cache_batch_size, num_speculative_tokens]): + if skip_lang and skip_vision: + raise ValueError("Expected at least one of 'skip_lang' or 'skip_vision' to be False") + + if self.continuous_batching and full_batch_size is None: + raise TypeError("`full_batch_size` is required when `continuous_batching=True`.") + + if kv_cache_batch_size and not full_batch_size: raise ValueError( - f"Expected 'full_batch_size', 'kv_cache_batch_size', 'num_speculative_tokens' to be None but got: " - f"full_batch_size={full_batch_size}, kv_cache_batch_size={kv_cache_batch_size}, 
num_speculative_tokens={num_speculative_tokens}, " + "KV caching requires continuous batching. Please set `full_batch_size` and " + "enable `continuous_batching=True` in `from_pretrained`." + ) - if skip_lang and skip_vision: - raise ValueError("Expected at least one of 'skip_lang' or 'skip_vision' to be False") + # Infer kv_cache_batch_size if not provided + kv_cache_batch_size = kv_cache_batch_size or full_batch_size or batch_size output_names = self.model.get_output_names(kv_offload=True) @@ -647,6 +656,9 @@ def compile( ctx_len=ctx_len, img_size=img_size, kv_offload=True, + continuous_batching=self.continuous_batching, + kv_cache_batch_size=kv_cache_batch_size, + full_batch_size=full_batch_size, **compiler_options, ) @@ -715,6 +727,8 @@ def compile( def generate( self, inputs: torch.Tensor, + tokenizer: Union[PreTrainedTokenizerFast, PreTrainedTokenizer] = None, + prompts: List[str] = None, streamer: Optional[TextStreamer] = None, device_ids: List[int] = None, runtime_ai100: bool = True, @@ -733,10 +747,116 @@ def generate( if not runtime_ai100: raise NotImplementedError("PyTorch execution is not supported yet for this model!") + # Handle CB for multiple prompts with the same image + if self.continuous_batching and tokenizer and prompts and len(prompts) > 1: + return self.continuous_batching_multi_prompt_generate( + inputs=inputs, + tokenizer=tokenizer, + prompts=prompts, + device_ids=device_ids, + generation_len=generation_len, + streamer=streamer, + ) + if tokenizer and prompts: + return QEfficient.cloud_ai_100_exec_kv( + tokenizer, + self.lang_model.qpc_path, + prompt=prompts, + device_id=device_ids, + generation_len=generation_len, + ) + return self.kv_offload_generate( inputs=inputs, device_ids=device_ids, streamer=streamer, generation_len=generation_len ) + + def continuous_batching_multi_prompt_generate( + self, + inputs: torch.Tensor, + tokenizer: Union[PreTrainedTokenizerFast, PreTrainedTokenizer], + prompts: List[str], + device_ids: List[int] = None, + generation_len: Optional[int] = None, + streamer: Optional[TextStreamer] = None, + ): + """ + Optimized continuous batching generate function for multiple prompts with the same image. + This method processes a single image with multiple text prompts in a continuous batching manner by: + 1. Running the vision encoder once for the shared image. + 2. Using continuous batching for multiple prompts in the language decoder. + 3. Sharing vision embeddings across all prompts to save memory and computation. 
+ """ + if not self.lang_model.qpc_path: + raise TypeError("Please run compile API for language model first!") + + vision_session = None + lang_session = None + if self.vision_model.qpc_path: + vision_session = QAICInferenceSession(self.vision_model.qpc_path, device_ids) + + lang_session = QAICInferenceSession(self.lang_model.qpc_path, device_ids, activate=False) + + # Get compilation dimensions + batch_size, ctx_len, fbs = get_compilation_dims(self.lang_model.qpc_path) + + # Skip inputs/outputs + lang_session.skip_buffers( + [ + x + for x in lang_session.input_names + lang_session.output_names + if x.startswith("past_") or x.endswith("_RetainedState") + ] + ) + + # Process vision inputs once for all prompts + vision_inputs = { + k: v for k, v in inputs.items() if k in {"pixel_values", "aspect_ratio_ids", "aspect_ratio_mask"} + } + if vision_inputs: + vision_inputs["pixel_values"] = vision_inputs["pixel_values"].to(torch.float16).cpu().numpy() + vision_start = perf_counter() + vision_outputs = {} + if vision_inputs: + vision_outputs = vision_session.run(vision_inputs) + vision_end = perf_counter() + + # Deactivate vision session after use + if self.vision_model.qpc_path: + vision_session.deactivate() + + # Text generation instance for continuous batching + text_generator = TextGeneration( + tokenizer=tokenizer, + qpc_path=self.lang_model.qpc_path, + device_id=device_ids, + ctx_len=ctx_len, + enable_debug_logs=False, + full_batch_size=fbs, + ) + + # Prepare prompts for CB + # Each prompt processed with same vision embeddings + tokenized_prompts = [] + for prompt in prompts: + tokenized_prompt = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True) + tokenized_prompts.append(tokenized_prompt) + + # Run CB with shared vision embeddings + lang_session.activate() + + if vision_outputs: + lang_session.set_buffers(vision_outputs) + + # Execute continuous batching generate + exec_info = text_generator.generate( + prompt=prompts, + generation_len=generation_len, + streamer=streamer is not None, + ) + + print("Vision encoding time (s): ", vision_end - vision_start) + return exec_info + def kv_offload_generate( self, inputs: List[str] = None, @@ -808,7 +928,7 @@ def kv_offload_generate( } if vision_inputs: - vision_inputs["pixel_values"] = vision_inputs["pixel_values"].astype("float16") + vision_inputs["pixel_values"] = vision_inputs["pixel_values"].to(torch.float16).cpu().numpy() vision_start = perf_counter() vision_outputs = {} @@ -1259,15 +1379,21 @@ class QEFFAutoModelForImageTextToText: _hf_auto_class = AutoModelForImageTextToText - def __new__(self, model: nn.Module, kv_offload: Optional[bool] = True, **kwargs): + def __new__(self, model: nn.Module, kv_offload: Optional[bool] = True, continuous_batching: bool = False, **kwargs): if kv_offload: - return _QEffAutoModelForImageTextToTextDualQPC(model, **kwargs) + return _QEffAutoModelForImageTextToTextDualQPC(model, continuous_batching, **kwargs) else: return _QEFFAutoModelForImageTextToTextSingleQPC(model, **kwargs) @classmethod @with_replaced_quantizers - def from_pretrained(cls, pretrained_model_name_or_path: str, kv_offload: Optional[bool] = None, **kwargs): + def from_pretrained( + cls, + pretrained_model_name_or_path: str, + kv_offload: Optional[bool] = None, + continuous_batching: bool = False, + **kwargs, + ): """Used to load models supported by transformers.AutoModelForImageTextToText for Cloud AI 100. 
Args: @@ -1284,12 +1410,18 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, kv_offload: Optiona if kwargs.get("low_cpu_mem_usage", None): logger.warning("Updating low_cpu_mem_usage=False") - if kwargs.pop("continuous_batching", None): - NotImplementedError("Continuous batching is not supported for image-text-to-text models yet.") + if continuous_batching and not kv_offload: + raise NotImplementedError("Continuous batching is not supported with kv_offload=False") kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False}) model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, **kwargs) - return cls(model, kv_offload=kv_offload, pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs) + return cls( + model, + kv_offload=kv_offload, + continuous_batching=continuous_batching, + pretrained_model_name_or_path=pretrained_model_name_or_path, + **kwargs, + ) MISCLASSIFIED_CAUSAL_LM_TO_QEFF_AUTO_CLASS_MAP = {"InternVLChatModel": QEFFAutoModelForImageTextToText}
diff --git a/examples/llama4_CB_example.py b/examples/llama4_CB_example.py new file mode 100644 index 000000000..578581a3b --- /dev/null +++ b/examples/llama4_CB_example.py @@ -0,0 +1,99 @@ +# ----------------------------------------------------------------------------- + # + # Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. + # SPDX-License-Identifier: BSD-3-Clause + # + # ---------------------------------------------------------------------------- + + import torch + import transformers + from transformers import AutoConfig, AutoProcessor + + from QEfficient import QEFFAutoModelForImageTextToText + + model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct" + config = AutoConfig.from_pretrained(model_id) + # For testing purposes only + config.text_config.num_hidden_layers = 4 + config.vision_config.num_hidden_layers = 2 + + qeff_model = QEFFAutoModelForImageTextToText.from_pretrained( + model_id, + attn_implementation="eager", + kv_offload=True, + config=config, + continuous_batching=True, ) + tokenizer = transformers.AutoTokenizer.from_pretrained(model_id) + processor = AutoProcessor.from_pretrained(model_id) + + qeff_model.compile( + prefill_seq_len=128, + ctx_len=3072, + img_size=336, + num_cores=16, + num_devices=4, + max_num_tiles=17, + batch_size=1, + full_batch_size=4, + mxfp6_matmul=True, + mxint8_kv_cache=True, + aic_enable_depth_first=True, + mos=1, + ) + + image_url = ( + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_layout.png" + ) + + prompts = [ + "Can you describe the image in detail?", + "What are the objects in the image?", + "What is the main subject of the image?", + "What colors are predominant in the image?", + ] + + all_inputs = [] + for prompt in prompts: + messages = [ + { + "role": "user", + "content": [ + {"type": "image", "url": image_url}, + {"type": "text", "text": prompt}, + ], + }, + ] + + inputs = processor.apply_chat_template( + messages, + add_generation_prompt=True, + tokenize=True, + return_dict=True, + return_tensors="pt", + ) + inputs["pixel_values"] = inputs["pixel_values"].to(torch.float32) + all_inputs.append(inputs) + + + output = qeff_model.generate( + inputs=all_inputs[0], tokenizer=tokenizer, device_ids=[0, 1, 2, 3], prompts=prompts, generation_len=100 + ) + + if hasattr(output, "generated_texts"): + for i, (prompt, response) in enumerate(zip(prompts, output.generated_texts)): + print(f"Prompt {i + 1}: {prompt}") + print(f"Response {i + 1}: {response}") + print("-" * 30) + else: + print("Generated 
IDs:", output.generated_ids) + decoded_responses = tokenizer.batch_decode(output.generated_ids, skip_special_tokens=True) + for i, (prompt, response) in enumerate(zip(prompts, decoded_responses)): + print(f"Prompt {i + 1}: {prompt}") + print(f"Response {i + 1}: {response}") + print("-" * 30) + +# print(output.generated_ids) +# print(tokenizer.batch_decode(output.generated_ids)) +print(output) +print()