diff --git a/QEfficient/generation/cloud_infer.py b/QEfficient/generation/cloud_infer.py index 8519d824c..5068c174e 100644 --- a/QEfficient/generation/cloud_infer.py +++ b/QEfficient/generation/cloud_infer.py @@ -90,8 +90,10 @@ def __init__( self.program = qaicrt.Program(self.context, None, qpc, prog_properties) if self.program.load() != qaicrt.QStatus.QS_SUCCESS: raise RuntimeError("Failed to load program") + self.is_active = False if activate: self.activate() + self.is_active = True # Create input qbuffers and buf_dims self.qbuffers = [qaicrt.QBuffer(bytes(binding.size)) for binding in self.bindings] self.buf_dims = qaicrt.BufferDimensionsVecRef( @@ -108,15 +110,17 @@ def output_names(self) -> List[str]: def activate(self): """Activate qpc""" - - self.program.activate() - self.execObj = qaicrt.ExecObj(self.context, self.program) + if not self.is_active: + self.program.activate() + self.execObj = qaicrt.ExecObj(self.context, self.program) + self.is_active = True def deactivate(self): """Deactivate qpc""" - - del self.execObj - self.program.deactivate() + if self.is_active: + del self.execObj + self.program.deactivate() + self.is_active = False def set_buffers(self, buffers: Dict[str, np.ndarray]): """ diff --git a/QEfficient/generation/embedding_handler.py b/QEfficient/generation/embedding_handler.py new file mode 100644 index 000000000..76da7afc2 --- /dev/null +++ b/QEfficient/generation/embedding_handler.py @@ -0,0 +1,367 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +""" +Vision Handler for Vision-Language Models + +This module provides the VisionHandler class that encapsulates all vision model +operations, separating them from the main text generation logic. +""" + +from typing import Any, Dict, Optional, Tuple + +import numpy as np +import requests +import torch +from PIL import Image +from transformers import AutoImageProcessor + +from QEfficient.generation.cloud_infer import QAICInferenceSession +from QEfficient.utils.logging_utils import logger + + +class VisionHandler: + """ + Handles all vision model operations for vision-language models. + + This class encapsulates vision preprocessing, inference, and output handling, + providing a clean separation between vision and language processing. + """ + + def __init__( + self, + qeff_model: Optional[QAICInferenceSession], + vision_session: Optional[QAICInferenceSession], + processor: Optional[AutoImageProcessor], + config: Optional[Dict[str, Any]] = None, + lang_session: Optional[QAICInferenceSession] = None, + ): + """ + Initialize vision handler + + Args: + vision_session: QAICInferenceSession for vision model + processor: AutoImageProcessor for image preprocessing + config: Configuration dictionary with vision model parameters + lang_session: Optional language session for coordination (to avoid resource conflicts) + """ + self._qeff_model = qeff_model + self._vision_session = vision_session + self._processor = processor + self._config = config or {} + self._lang_session = lang_session # Store language session for coordination + + # Cache for vision output shapes + self._vision_output_shapes = None + + if self._vision_session and not self._processor: + logger.warning("Vision session provided but no processor. 
Vision functionality may be limited.") + + def is_available(self) -> bool: + """ + Check if vision processing is available + + Returns: + True if both vision session and processor are available + """ + return self._vision_session is not None and self._processor is not None + + def prepare_vlm_inputs(self, image_url: str, query: str, prefill_seq_len: int) -> Dict[str, np.ndarray]: + """ + Download and preprocess image into model inputs + + Args: + image_url: URL or path to image + query: Text query to process with image + + Returns: + Dictionary of vision model inputs + + Raises: + ValueError: If vision handler is not properly initialized + RuntimeError: If image processing fails + """ + if not self.is_available(): + raise ValueError("Vision handler not properly initialized. Need both vision_session and processor.") + + try: + # Download image + if image_url.startswith(("http://", "https://")): + image = Image.open(requests.get(image_url, stream=True).raw) + else: + image = Image.open(image_url) + + # Prepare conversation format + conversation = [ + { + "role": "user", + "content": [ + {"type": "text", "text": query}, + {"type": "image"}, + ], + }, + ] + + # Apply chat template + prompt = self._processor.apply_chat_template(conversation, add_generation_prompt=True) + + # Process image and text + inputs = self._processor(images=image, text=prompt, return_tensors="pt") + + if ( + hasattr(self._qeff_model.model.config, "model_type") + and self._qeff_model.model.config.model_type == "qwen2_5_vl" + ): + inputs = self._qeff_model.model.prepare_inputs_for_generation( + inputs=inputs, prefill_seq_len=prefill_seq_len, batch_size=inputs["input_ids"].shape[0] + ) + + # Convert to float32 if needed + if "pixel_values" in inputs: + inputs["pixel_values"] = inputs["pixel_values"].to(torch.float32) + + # Convert to numpy arrays + vision_inputs = {} + for k, v in inputs.items(): + if k in { + "pixel_values", + "image_masks", + "image_input_idx", + "valid_idx", + "aspect_ratio_ids", + "aspect_ratio_mask", + }: + vision_inputs[k] = np.array(v) + + # Convert specific inputs to float16 + vision_inputs_fp16 = {"pixel_values", "image_masks"} + for k in vision_inputs_fp16: + if k in vision_inputs: + vision_inputs[k] = vision_inputs[k].astype("float16") + + lang_inputs = {k: v for k, v in inputs.items() if k not in vision_inputs} + + return vision_inputs, lang_inputs + + except Exception as e: + raise RuntimeError(f"Failed to process image {image_url}: {str(e)}") + + def run_vision_inference(self, vision_inputs: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]: + """ + Execute vision model inference with session coordination + + Args: + vision_inputs: Preprocessed vision inputs + + Returns: + Vision embeddings and metadata + + Raises: + ValueError: If vision session is not available + RuntimeError: If inference fails + """ + if not self._vision_session: + raise ValueError("Vision session not available") + + lang_was_active = False + try: + # Coordinate with language session to avoid resource conflicts + if self._lang_session and self._lang_session.is_active: + logger.debug("Deactivating language session before vision inference") + self._lang_session.deactivate() + lang_was_active = True + + # Activate vision session + logger.debug("Activating vision session for inference") + self._vision_session.activate() + + # Run inference + vision_outputs = self._vision_session.run(vision_inputs) + + # Deactivate vision session + logger.debug("Deactivating vision session after inference") + self._vision_session.deactivate() + + 
# Reactivate language session if it was active before + if lang_was_active and self._lang_session: + logger.debug("Reactivating language session after vision inference") + self._lang_session.activate() + + return vision_outputs + + except Exception as e: + # Ensure proper cleanup on error + if self._vision_session: + try: + self._vision_session.deactivate() + except Exception: + logger.warning("Deactivating vision session failed") + + # Restore language session if needed + if lang_was_active and self._lang_session: + try: + self._lang_session.activate() + except Exception: + logger.warning("Deactivating language session failed") + + raise RuntimeError(f"Vision inference failed: {str(e)}") + + def get_vision_output_shapes(self) -> Dict[str, Tuple[int, ...]]: + """ + Get vision output dimensions from config or session + + Returns: + Dictionary mapping output names to shapes + """ + if self._vision_output_shapes is not None: + return self._vision_output_shapes + + # Try to get from config first + if self._config and "vision_output_shapes" in self._config: + self._vision_output_shapes = self._config["vision_output_shapes"] + return self._vision_output_shapes + + # Try to derive from vision session + if self._vision_session: + try: + shapes = {} + for output_name in self._vision_session.output_names: + if ( + hasattr(self._vision_session, "bindings") + and output_name in self._vision_session.binding_index_map + ): + binding_idx = self._vision_session.binding_index_map[output_name] + if hasattr(self._vision_session.bindings[binding_idx], "dims"): + shapes[output_name] = tuple(self._vision_session.bindings[binding_idx].dims) + + if shapes: + self._vision_output_shapes = shapes + return shapes + except Exception as e: + logger.warning(f"Could not derive vision output shapes from session: {e}") + + # Fallback to default shapes (these were hard-coded in original implementation) + default_shapes = { + "vision_embeds": (2448, 5120) # This should be derived from model config + } + + logger.warning("Using default vision output shapes. 
Consider providing shapes in config.") + self._vision_output_shapes = default_shapes + return default_shapes + + def setup_vision_buffers(self): + """ + Configure vision model output buffers + + Raises: + ValueError: If vision session is not available + """ + if not self._vision_session: + raise ValueError("Vision session not available") + + try: + shapes = self.get_vision_output_shapes() + + # Set up output buffers + buffers = {} + for output_name, shape in shapes.items(): + # Create placeholder with appropriate dtype + if "vision_embeds" in output_name: + buffers[output_name] = np.zeros(shape, dtype=np.float16) + else: + buffers[output_name] = np.zeros(shape, dtype=np.float32) + + self._vision_session.set_buffers(buffers) + + except Exception as e: + raise RuntimeError(f"Failed to setup vision buffers: {str(e)}") + + def prepare_complete_vision_language_inputs( + self, image_url: str, query: str + ) -> Tuple[Dict[str, np.ndarray], Dict[str, np.ndarray]]: + """ + Complete pipeline: prepare inputs and run vision inference + + Args: + image_url: URL or path to image + query: Text query + + Returns: + Tuple of (vision_inputs, vision_outputs) + """ + # Prepare vision inputs + vision_inputs = self.prepare_vision_inputs(image_url, query) + + # Setup buffers + self.setup_vision_buffers() + + # Run vision inference + vision_outputs = self.run_vision_inference(vision_inputs) + + return vision_inputs, vision_outputs + + def get_processed_inputs( + self, image_url: str, query: str, prefill_seq_len: int + ) -> Tuple[Dict[str, np.ndarray], Dict[str, np.ndarray]]: + """ + Process vision inputs and prepare language model inputs + + Args: + image_url: URL or path to image + query: Text query + padded_len: Padded sequence length for language model + + Returns: + Tuple of (language_inputs, vision_outputs) + """ + if not self.is_available(): + raise ValueError("Vision handler not properly initialized") + + try: + ## Get vlm inputs ## + vision_inputs, lang_inputs = self.prepare_vlm_inputs(image_url, query, prefill_seq_len) + + # Handle padding for language model + pad_token_id = 1 + input_ids_length = lang_inputs["input_ids"].shape[1] + num_chunks = -(input_ids_length // -prefill_seq_len) + padded_len = num_chunks * prefill_seq_len + + lang_inputs["input_ids"] = torch.nn.functional.pad( + lang_inputs["input_ids"], + (0, padded_len - input_ids_length), + "constant", + pad_token_id, + ) + lang_inputs["attention_mask"] = torch.nn.functional.pad( + lang_inputs["attention_mask"], (0, padded_len - input_ids_length), "constant", 0 + ) + + if "cross_attention_mask" in lang_inputs: + lang_inputs["cross_attention_mask"] = torch.nn.functional.pad( + lang_inputs["cross_attention_mask"], (0, 0, 0, 0, 0, padded_len - input_ids_length) + ) + + for k, v in lang_inputs.items(): + lang_inputs[k] = np.array(v) + + vision_outputs = {} + if vision_inputs: + self.setup_vision_buffers() + vision_outputs = self.run_vision_inference(vision_inputs) + + if "position_ids" in lang_inputs: + lang_inputs.pop("attention_mask") + else: + lang_inputs["position_ids"] = np.where(lang_inputs.pop("attention_mask"), np.arange(padded_len), -1) + + lang_inputs["image_idx"] = np.array([[0]]) + + return lang_inputs, vision_outputs, num_chunks + + except Exception as e: + raise RuntimeError(f"Failed to process vision-language inputs: {str(e)}") diff --git a/QEfficient/generation/text_generation_inference.py b/QEfficient/generation/text_generation_inference.py index 6d04cf573..e96908824 100755 --- a/QEfficient/generation/text_generation_inference.py 
+++ b/QEfficient/generation/text_generation_inference.py @@ -437,15 +437,19 @@ def __init__( include_sampler: bool = False, return_pdfs: bool = False, sampling_params: Optional[Dict[str, Any]] = None, + activate: bool = True, ) -> None: self._ctx_len = ctx_len self._write_io_dir = write_io_dir self.is_tlm = is_tlm self.return_pdfs = return_pdfs self.sampling_params = sampling_params + self._qpc_path = qpc_path # Store qpc_path for later use # Load QPC - self._session = QAICInferenceSession(qpc_path, device_id, enable_debug_logs=enable_debug_logs) + self._session = QAICInferenceSession( + qpc_path, device_id, activate=activate, enable_debug_logs=enable_debug_logs + ) # Validate sampler inputs for On-Device Sampling self.include_sampler = validate_sampler_inputs( @@ -778,6 +782,7 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i if decode_batch_id is not None: inputs["batch_index"] = decode_batch_id + if self.is_tlm: inputs["num_logits_to_keep"] = np.zeros((1, 1)) if self.include_sampler: @@ -808,6 +813,7 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i if self.include_sampler: chunk_inputs["last_accepted_output_tokens"] = chunk_inputs["input_ids"] outputs = self._session.run(chunk_inputs) + if self._write_io_dir is not None: write_io_files(inputs, outputs, self._write_io_dir, "prefill", "aic_batch_io", True, False) return ( diff --git a/QEfficient/generation/vlm_generation.py b/QEfficient/generation/vlm_generation.py new file mode 100644 index 000000000..2e8f04f2b --- /dev/null +++ b/QEfficient/generation/vlm_generation.py @@ -0,0 +1,784 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +""" +This module provides the VisionLanguageGeneration class that inherits from +QEffTextGenerationBase, enabling all advanced text generation features while +maintaining full API compatibility with the original VisionLanguageGeneration. + +Key enhancements: +- Continuous batching support for vision models +- Advanced streaming capabilities +- On-device sampling support +- LoRA adapter support +- Better performance metrics +""" + +from collections import deque +from time import perf_counter +from typing import Any, Dict, List, Optional, Union + +import numpy as np +from transformers import AutoImageProcessor, PreTrainedTokenizer, PreTrainedTokenizerFast + +from QEfficient.generation.cloud_infer import QAICInferenceSession +from QEfficient.generation.embedding_handler import VisionHandler +from QEfficient.generation.text_generation_inference import ( + CloudAI100ExecInfo, + PerfMetrics, + QEffTextGenerationBase, + TextGeneration, + calculate_latency, + write_io_files, +) +from QEfficient.utils import LRUCache +from QEfficient.utils.logging_utils import logger + + +class VisionLanguageGeneration(QEffTextGenerationBase): + """ + Enhanced vision-language generation class inheriting from QEffTextGenerationBase. + + This class maintains full API compatibility with VisionLanguageGeneration while + adding advanced features like continuous batching, streaming, and sampling. + + Example: + >>> # Drop-in replacement for VisionLanguageGeneration + >>> vlm = VisionLanguageGeneration( + ... tokenizer=tokenizer, + ... processor=processor, + ... lang_qpc_path="path/to/lang.qpc", + ... 
vision_qpc_path="path/to/vision.qpc", + ... device_id=[0] + ... ) + >>> result = vlm.generate( + ... images=["image1.jpg"], + ... prompts=["Describe this image"], + ... generation_len=512 + ... ) + + >>> # Enhanced usage with new features + >>> vlm_enhanced = VisionLanguageGeneration( + ... tokenizer=tokenizer, + ... processor=processor, + ... lang_qpc_path="path/to/lang.qpc", + ... vision_qpc_path="path/to/vision.qpc", + ... device_id=[0], + ... full_batch_size=8, # Enable continuous batching + ... include_sampler=True, # Enable on-device sampling + ... sampling_params=sampling_config + ... ) + """ + + def __init__( + self, + qeff_model, + tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], + processor: AutoImageProcessor, + lang_qpc_path: str, + vision_qpc_path: str, + device_id: Optional[List[int]] = None, + ctx_len: Optional[int] = None, + enable_debug_logs: bool = False, + write_io_dir: Optional[str] = None, + full_batch_size: Optional[int] = None, + is_tlm: bool = False, + include_sampler: bool = False, + return_pdfs: bool = False, + sampling_params: Optional[Dict[str, Any]] = None, + ): + """ + Initialize vision-language generation with enhanced capabilities + + Args: + qeff_model: QEff model instance + tokenizer: Text tokenizer + processor: Image processor + lang_qpc_path: Path to language model QPC + vision_qpc_path: Path to vision encoder QPC + device_id: Device IDs for execution (default: [0]) + ctx_len: Context length + enable_debug_logs: Enable debug logging + write_io_dir: Directory for I/O file writing + full_batch_size: Enable continuous batching (new feature) + is_tlm: Target language model flag + include_sampler: Enable on-device sampling (new feature) + return_pdfs: Return probability distributions + sampling_params: Sampling parameters for on-device sampling + """ + # Validate required parameters + if not lang_qpc_path: + raise TypeError("lang_qpc_path is required") + if not vision_qpc_path: + raise TypeError("vision_qpc_path is required") + + # Initialize base class with language QPC + # Pass activate=False to prevent premature activation before vision components are ready + super().__init__( + tokenizer=tokenizer, + qpc_path=lang_qpc_path, + full_batch_size=full_batch_size, + ctx_len=ctx_len, + device_id=device_id, + enable_debug_logs=enable_debug_logs, + write_io_dir=write_io_dir, + is_tlm=is_tlm, + include_sampler=include_sampler, + return_pdfs=return_pdfs, + sampling_params=sampling_params, + activate=False, # vision components need to be initialized first + ) + + # Vision-specific initialization + self.is_qwen2_5_vl = ( + hasattr(qeff_model.model.config, "model_type") and qeff_model.model.config.model_type == "qwen2_5_vl" + ) + self.qeff_model = qeff_model + self.processor = processor + self._vision_qpc_path = vision_qpc_path + self.device_id = device_id # Store device_id for vision components + self.enable_debug_logs = enable_debug_logs # Store for vision components + self._vision_outputs_cache = LRUCache(max_size=100) # LRU cache for vision outputs + self._vision_cache = {} # Cache for vision outputs across batches + self._init_vision_components() + + # Now that vision components are initialized, activate the text session + self._session.activate() + + logger.info( + f"VisionLanguageGeneration initialized: batch_size={self.batch_size}, " + f"prefill_seq_len={self._prefill_seq_len}, ctx_len={ctx_len}, " + f"continuous_batching={'enabled' if full_batch_size else 'disabled'}, " + f"sampling={'enabled' if include_sampler else 'disabled'}" + ) + + def 
_init_vision_components(self): + """Initialize vision-specific components""" + # Vision session (separate from base class language session) + self._vision_session = QAICInferenceSession( + self._vision_qpc_path, self.device_id, activate=False, enable_debug_logs=self.enable_debug_logs + ) + + # Vision handler with language session coordination + vision_config = self._get_vision_config() + self._vision_handler = VisionHandler( + qeff_model=self.qeff_model, + vision_session=self._vision_session, + processor=self.processor, + config=vision_config, + lang_session=self._session, # Pass language session for coordination + ) + + # Setup vision buffer skipping + self._setup_vision_buffer_skipping() + + def _get_vision_config(self) -> Dict[str, Any]: + """ + Derive vision config from session + + Returns: + Dictionary with vision configuration + """ + config = {} + if self._vision_session: + try: + shapes = {} + for output_name in self._vision_session.output_names: + if ( + hasattr(self._vision_session, "bindings") + and output_name in self._vision_session.binding_index_map + ): + binding_idx = self._vision_session.binding_index_map[output_name] + if hasattr(self._vision_session.bindings[binding_idx], "dims"): + shapes[output_name] = tuple(self._vision_session.bindings[binding_idx].dims) + + if shapes: + config["vision_output_shapes"] = shapes + except Exception as e: + logger.warning(f"Could not derive vision config from session: {e}") + + return config + + def _setup_vision_buffer_skipping(self): + """Skip KV cache and retained state buffers for vision session""" + # Pre-compute skip buffers + self._vision_skip_buffers = [ + x + for x in self._vision_session.input_names + self._vision_session.output_names + if x.startswith("past_") or x.endswith("_RetainedState") + ] + self._vision_session.skip_buffers(self._vision_skip_buffers) + + # Pre-compute language skip buffers + self._lang_skip_buffers = [ + x + for x in self._session.input_names + self._session.output_names + if x.startswith("past_") or x.endswith("_RetainedState") + ] + + def run_prefill_for_all_inputs(self, prompt_queue, generation_len): + """ + Runs prefill for all inputs in the prompt queue and updates the decode input. + + Method iterates over the full batch size and for each decode batch ID, it pops the next prompt from the queue. It then runs prefill for the next prompt and updates the decode input with the outputs. + + Args: + prompt_queue (deque): The queue of prompts. + generation_len (int): The generation length. + + """ + for decode_batch_id in range(self.full_batch_size): + next_prompt = prompt_queue.popleft() + + # run prefill for num_chunks + outputs, position_ids, generation_len = self.run_prefill( + next_prompt, generation_len, decode_batch_id=np.array(decode_batch_id, dtype=np.int64).reshape(1, 1) + ) + + if self.is_qwen2_5_vl: + _ = self.update_decode_inputs_qwen2_5_vl(outputs, position_ids, generation_len, decode_batch_id) + else: + _ = self.update_decode_input(outputs, position_ids, generation_len, decode_batch_id) + + def update_decode_inputs_qwen2_5_vl(self, outputs, position_ids, generation_len, decode_batch_id=None): + """ + Updates the decode input with the generated values. + Args: + outputs (dict): The outputs of the model. + position_ids (array): The position IDs. + generation_len (int): The generation length. + decode_batch_id (int, optional): The decode batch ID. If None, all values are updated. Defaults to None. + + Returns: + next_token_id (array): The next token ID. 
+ """ + next_token_id = self._fetch_next_token_id(outputs) + + # Store the generated values. + self.decode_input_ids[decode_batch_id or slice(None)] = next_token_id + self.decode_pos_ids[:, decode_batch_id] = position_ids.squeeze(1) + self.generated_ids[decode_batch_id or slice(None), 0] = next_token_id.squeeze(1) + self.generation_len[decode_batch_id or slice(None)] = generation_len + return next_token_id + + def _execute_chunked_prefill( + self, + lang_inputs: Dict[str, np.ndarray], + num_chunks: int, + decode_batch_id: Optional[np.ndarray] = None, + prefill_logit_bs: int = 1, + ) -> Dict[str, np.ndarray]: + """ + Execute chunked prefill with language inputs + + Args: + lang_inputs: Pre-processed language inputs with input_ids, position_ids, etc. + num_chunks: Number of chunks to process + decode_batch_id: Batch ID for continuous batching (optional) + prefill_logit_bs: Batch size for prefill logits + + Returns: + Final prefill outputs + """ + # Set output buffers + self._set_output_buffers(batch_size=prefill_logit_bs, sequence_length=1) + + # Skip buffers for dual-QPC coordination + self._session.skip_buffers(self._lang_skip_buffers) + + # Run chunked prefill + outputs = None + chunk_image_idx = None + + for i in range(num_chunks): + input_ids_slice = lang_inputs["input_ids"][:, i * self._prefill_seq_len : (i + 1) * self._prefill_seq_len] + position_ids_slice = lang_inputs["position_ids"][ + ..., i * self._prefill_seq_len : (i + 1) * self._prefill_seq_len + ] + + chunk_inputs = { + "input_ids": input_ids_slice, + "position_ids": position_ids_slice, + "image_idx": chunk_image_idx if chunk_image_idx is not None else np.array([[0]], dtype=np.int64), + } + + if decode_batch_id is not None: + chunk_inputs["batch_index"] = decode_batch_id + + if "cross_attention_mask" in lang_inputs: + chunk_inputs["cross_attention_mask"] = lang_inputs["cross_attention_mask"] + + outputs = self._session.run(chunk_inputs) + + if "image_idx_output" in outputs: + chunk_image_idx = outputs["image_idx_output"] + + if self._write_io_dir is not None: + write_io_files(lang_inputs, outputs, self._write_io_dir, "prefill", "aic_batch_io", True, False) + + # Prepare decode-time cross_attention_mask + if "cross_attention_mask" in lang_inputs: + bs, _, num_images, img_tiles = lang_inputs["cross_attention_mask"].shape + self._decode_cross_attention_mask = np.ones((bs, 1, num_images, img_tiles), dtype=np.int64) + else: + self._decode_cross_attention_mask = None + + return outputs + + def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_id=None): + """ + Override base class prefill to handle vision processing + + Args: + prompt: Can be string or tuple (image_path, text_prompt) + generation_len: Generation length + prefill_logit_bs: Prefill batch size + decode_batch_id: Batch ID for continuous batching + + Returns: + Same as base class: (outputs, position_ids, generation_len) + """ + # Normalize prompt: TextGeneration passes a list even for batch_size=1 + if isinstance(prompt, list) and len(prompt) > 0 and isinstance(prompt[0], tuple) and len(prompt[0]) == 2: + # Unwrap single (image_path, text_prompt) tuple + if len(prompt) == 1: + prompt = prompt[0] + else: + raise NotImplementedError( + "VisionLanguageGeneration.run_prefill currently supports a single (image, text) pair per call." 
+ ) + # Check if this is a vision-language prompt + if isinstance(prompt, tuple) and len(prompt) == 2: + image_path, text_prompt = prompt + + # Check cache for vision outputs + cache_key = image_path if isinstance(image_path, str) else str(image_path) + if cache_key in self._vision_cache: + lang_inputs, vision_outputs, num_chunks = self._vision_cache[cache_key] + logger.debug(f"Using cached vision outputs for {cache_key}") + else: + # Build language inputs with processor-aware vision/text integration + lang_inputs, vision_outputs, num_chunks = self._vision_handler.get_processed_inputs( + image_url=image_path, query=text_prompt, prefill_seq_len=self._prefill_seq_len + ) + # Cache for future use + self._vision_cache[cache_key] = (lang_inputs, vision_outputs, num_chunks) + logger.debug(f"Cached vision outputs for {cache_key}") + + # Set vision buffers in language session + self._session.set_buffers(vision_outputs) + logger.debug(f"Vision buffers set: {list(vision_outputs.keys())}") + self._vision_processed = True + self._vision_outputs = vision_outputs + + # Calculate generation_len consistent with ctx_len + max_gen_len = self._ctx_len - np.where(lang_inputs["position_ids"] != -1, 1, 0).sum(1, keepdims=True).max() + generation_len = self._fetch_generation_len(generation_len, max_gen_len) + + # Execute chunked prefill + outputs = self._execute_chunked_prefill(lang_inputs, num_chunks, decode_batch_id, prefill_logit_bs) + + self._session.skip_buffers(vision_outputs) + + # Prepare position_ids for decode phase (next position after prefill) + position_ids_decode = np.max(lang_inputs["position_ids"], axis=-1, keepdims=True) + 1 + + return outputs, position_ids_decode, generation_len + else: + # Fall back to base class for text-only + return super().run_prefill(prompt, generation_len, prefill_logit_bs, decode_batch_id) + + def _prepare_vision_language_prompt(self, text_prompt, image_path): + """ + Prepare text prompt with vision context + + This method handles the integration of vision and text inputs + according to the specific model's requirements. + """ + # For most vision-language models, we need to apply the chat template + # that includes both image and text components + try: + conversation = [ + { + "role": "user", + "content": [ + {"type": "text", "text": text_prompt}, + {"type": "image"}, + ], + }, + ] + + # Apply chat template + processed_prompt = self.processor.apply_chat_template(conversation, add_generation_prompt=True) + + return processed_prompt + + except Exception as e: + logger.warning(f"Failed to apply chat template: {e}. 
Using original prompt.") + return text_prompt + + def generate( + self, images: List[str], prompts: List[str], generation_len: Optional[int] = None, stream: bool = True, **kwargs + ) -> CloudAI100ExecInfo: + """ + Main generation method maintaining API compatibility with VisionLanguageGeneration + + Args: + images: List of image URLs/paths + prompts: List of text prompts + generation_len: Max generation length + stream: Enable streaming output + **kwargs: Additional arguments passed to base class + + Returns: + CloudAI100ExecInfo with results and metrics + + Raises: + ValueError: If images and prompts lengths don't match + """ + if len(images) != len(prompts): + raise ValueError(f"Number of images ({len(images)}) must match number of prompts ({len(prompts)})") + + # Clear vision cache for fresh generation + self._vision_cache.clear() + + logger.info(f"Generating for {len(images)} image-prompt pairs") + + # Convert to base class format: list of (image, prompt) tuples + vision_prompts = [(img, prompt) for img, prompt in zip(images, prompts)] + + # Use base class generate method with vision prompts + if self.full_batch_size is not None: + # Continuous batching mode (new capability) + return self._generate_continuous_batching(vision_prompts, generation_len, stream, **kwargs) + else: + # Regular batching mode + return self._generate_regular_batching(vision_prompts, generation_len, stream, **kwargs) + + def _generate_regular_batching(self, vision_prompts, generation_len, stream, **kwargs): + """Handle regular batching for vision-language generation without creating a second language session""" + batch_results = [] + for i in range(0, len(vision_prompts), self.batch_size): + batch = vision_prompts[i : i + self.batch_size] + + if stream: + print( + f"\nProcessing batch {i // self.batch_size + 1}/{(len(vision_prompts) - 1) // self.batch_size + 1}" + ) + for j, (img, prompt) in enumerate(batch): + print(f"Image: {img}") + print(f"Prompt: {prompt}") + print("Completion:", flush=True, end="") + + # Setup decode storage arrays for this batch (use ctx_len or generation_len whichever is larger) + exec_batch_size = self.batch_size + max_gen_length = self._ctx_len if not generation_len else max(self._ctx_len, generation_len) + self.initialize_decode_inputs( + num_prompts=len(batch), execution_batch_size=exec_batch_size, max_gen_length=max_gen_length + ) + + # Prefill using VLM-aware run_prefill (batch is a list of (image, text)) + start = perf_counter() + outputs, position_ids, generation_len_final = self.run_prefill( + batch, generation_len, prefill_logit_bs=self.batch_size + ) + self.update_decode_input(outputs, position_ids, generation_len_final) + + # Prepare decode + decode_inputs = self.prepare_decode_inputs() + + # Decode loop + loop_start = perf_counter() + num_token = self.run_decode(decode_inputs, generation_len_final, automation=False, streamer=None) + end = perf_counter() + + # Decode generated texts + generated_texts = self.tokenizer.batch_decode(self.generated_ids, skip_special_tokens=True) + + # Latency metrics + total_decode_tokens = num_token + prefill_time, decode_perf, total_perf, total_time = calculate_latency( + total_decode_tokens, loop_start, start, end + ) + perf_metrics = PerfMetrics(prefill_time, decode_perf, total_perf, total_time) + + # Package result for this batch + batch_results.append( + CloudAI100ExecInfo( + batch_size=self.batch_size, + generated_texts=generated_texts, + generated_ids=self.generated_ids, + perf_metrics=perf_metrics, + ) + ) + + # Aggregate results across 
batches + return self._aggregate_batch_results(batch_results) + + def _generate_continuous_batching(self, vision_prompts, generation_len, stream, **kwargs): + """Enable continuous batching for vision-language models (new capability)""" + logger.info("Using continuous batching for vision-language generation") + + if stream: + logger.warning("Streaming output not fully supported with continuous batching") + + # Reset vision processing state for new generation + self._vision_processed = False + self._vision_outputs = None + self._vision_outputs_cache = {} + + # Initialize decode inputs + num_prompts = len(vision_prompts) + execution_batch_size = self.full_batch_size + max_gen_length = self._ctx_len if not generation_len else max(self._ctx_len, generation_len) + + self.initialize_decode_inputs(num_prompts, execution_batch_size, max_gen_length) + if self.is_qwen2_5_vl: + self.decode_pos_ids = np.zeros((4, execution_batch_size, 1), np.int64) + + # Create prompt queue + prompt_queue = deque(vision_prompts) + + start = perf_counter() + + # Pre-process ALL vision inputs and cache them + logger.info("Pre-processing all vision inputs...") + for batch_id in range(min(self.full_batch_size, len(vision_prompts))): + img, prompt = vision_prompts[batch_id] + + # Process vision for this slot + lang_inputs, vision_outputs, num_chunks = self._vision_handler.get_processed_inputs( + image_url=img, query=prompt, prefill_seq_len=self._prefill_seq_len + ) + + # Cache vision outputs for this batch slot + self._vision_outputs_cache[batch_id] = { + "vision_outputs": vision_outputs, + "lang_inputs": lang_inputs, + "num_chunks": num_chunks, + } + + logger.debug(f"Cached vision outputs for batch_id {batch_id}") + + # Reset prompt queue for prefill + prompt_queue = deque(vision_prompts) + + self.batch_index = None + + # Run prefill for all inputs using cached vision + self.run_prefill_for_all_inputs_with_cached_vision(prompt_queue, generation_len) + + # Set vision buffers for decode (use first slot's vision for now) + # For identical images, any slot's vision works + cached_slot_0 = self._vision_outputs_cache.get(0) + if cached_slot_0: + self._session.set_buffers(cached_slot_0["vision_outputs"]) + logger.debug("Set vision buffers from slot 0 for decode phase") + + # Now set batch_index for decode phase + self.batch_index = np.arange(self.full_batch_size).reshape(-1, 1) + + loop_start = perf_counter() + decode_pause_time = self.run_continuous_batching_decode(prompt_queue, generation_len) + end = perf_counter() + + generated_texts = self.tokenizer.batch_decode(self.generated_ids, skip_special_tokens=True) + + total_decode_tokens = sum( + np.sum(self.generated_ids[i] != self.tokenizer.pad_token_id) - 1 for i in range(len(vision_prompts)) + ) + prefill_time, decode_perf, total_perf, total_time = calculate_latency( + total_decode_tokens, loop_start, start, end, decode_pause_time + ) + prefill_time /= len(vision_prompts) # Average prefill time for continuous batching + + perf_metrics = PerfMetrics(prefill_time, decode_perf, total_perf, total_time) + + return CloudAI100ExecInfo( + batch_size=1, generated_texts=generated_texts, generated_ids=self.generated_ids, perf_metrics=perf_metrics + ) + + def run_prefill_for_all_inputs_with_cached_vision(self, prompt_queue, generation_len): + """ + Runs prefill for all inputs using pre-cached vision outputs. + + This avoids the vision buffer overwriting issue by using cached vision + outputs instead of processing vision during each prefill iteration. 
+ + Args: + prompt_queue (deque): The queue of prompts. + generation_len (int): The generation length. + """ + for decode_batch_id in range(self.full_batch_size): + # Pop the promt as we are processing + _ = prompt_queue.popleft() + + # Get cached vision outputs for this batch slot + cached = self._vision_outputs_cache.get(decode_batch_id) + if cached: + vision_outputs = cached["vision_outputs"] + lang_inputs = cached["lang_inputs"] + num_chunks = cached["num_chunks"] + + # Set vision buffers for THIS prefill + self._session.set_buffers(vision_outputs) + logger.debug(f"Set vision buffers for batch_id {decode_batch_id} prefill") + + # Run prefill with cached inputs + outputs = self._execute_chunked_prefill( + lang_inputs, + num_chunks, + decode_batch_id=np.array(decode_batch_id, dtype=np.int64).reshape(1, 1), + prefill_logit_bs=1, + ) + + self._session.skip_buffers(vision_outputs.keys()) + + # Calculate position_ids for decode + position_ids_decode = np.max(lang_inputs["position_ids"], axis=-1, keepdims=True) + 1 + + # Calculate generation_len + max_gen_len = ( + self._ctx_len - np.where(lang_inputs["position_ids"] != -1, 1, 0).sum(1, keepdims=True).max() + ) + generation_len_final = self._fetch_generation_len(generation_len, max_gen_len) + + # Update decode inputs + if self.is_qwen2_5_vl: + self.update_decode_inputs_qwen2_5_vl( + outputs, position_ids_decode, generation_len_final, decode_batch_id + ) + else: + self.update_decode_input(outputs, position_ids_decode, generation_len_final, decode_batch_id) + else: + logger.error(f"No cached vision outputs for batch_id {decode_batch_id}") + raise RuntimeError(f"Vision outputs not cached for batch_id {decode_batch_id}") + + def prepare_decode_inputs(self): + """ + Override base class to handle vision-specific decode inputs + """ + decode_inputs = super().prepare_decode_inputs() + + # Add image_idx for vision-language models in CB mode during decode only + if self.batch_index is not None and hasattr(self, "_vision_outputs"): + # image_idx should be a single slot selector; decoder expects shape (1,1) + # Query binding dims if available to be robust + try: + if "image_idx" in getattr(self._session, "binding_index_map", {}): + idx = self._session.binding_index_map["image_idx"] + dims = tuple(self._session.bindings[idx].dims) + decode_inputs["image_idx"] = np.zeros(dims, dtype=np.int64) + else: + decode_inputs["image_idx"] = np.array([[0]], dtype=np.int64) + except Exception: + decode_inputs["image_idx"] = np.array([[0]], dtype=np.int64) + + # Include cross_attention_mask during decode if present/required + if hasattr(self, "_decode_cross_attention_mask") and self._decode_cross_attention_mask is not None: + # Decoder specialization expects a single mask (batch dim = 1) + decode_inputs["cross_attention_mask"] = self._decode_cross_attention_mask + + return decode_inputs + + def _aggregate_batch_results(self, batch_results): + """Aggregate results from multiple batches""" + if not batch_results: + raise ValueError("No batch results to aggregate") + + if len(batch_results) == 1: + return batch_results[0] + + # Aggregate multiple batch results + all_generated_texts = [] + all_generated_ids = [] + all_metrics = [] + + for result in batch_results: + if isinstance(result.generated_texts[0], list): + # Flatten nested lists + all_generated_texts.extend([text for batch in result.generated_texts for text in batch]) + else: + all_generated_texts.extend(result.generated_texts) + + if isinstance(result.generated_ids, list): + 
all_generated_ids.extend(result.generated_ids) + else: + all_generated_ids.append(result.generated_ids) + + all_metrics.append(result.perf_metrics) + + # Average metrics + avg_metrics = PerfMetrics( + prefill_time=np.mean([m.prefill_time for m in all_metrics]), + decode_perf=np.mean([m.decode_perf for m in all_metrics]), + total_perf=np.mean([m.total_perf for m in all_metrics]), + total_time=np.mean([m.total_time for m in all_metrics]), + ) + + return CloudAI100ExecInfo( + batch_size=batch_results[0].batch_size, + generated_texts=all_generated_texts, + generated_ids=all_generated_ids, + perf_metrics=avg_metrics, + ) + + def generate_stream_tokens( + self, images: List[str], prompts: List[str], generation_len: Optional[int] = None, **kwargs + ): + """ + Enable token-by-token streaming for vision models (new capability) + + Args: + images: List of image URLs/paths + prompts: List of text prompts + generation_len: Max generation length + **kwargs: Additional arguments + + Yields: + List of decoded tokens for each batch position + + Raises: + NotImplementedError: If continuous batching is enabled + """ + if self.full_batch_size is not None: + raise NotImplementedError("Token streaming not supported with continuous batching for VLM") + + if len(images) != len(prompts): + raise ValueError(f"Number of images ({len(images)}) must match number of prompts ({len(prompts)})") + + logger.info(f"Starting token streaming for {len(images)} image-prompt pairs") + + vision_prompts = [(img, prompt) for img, prompt in zip(images, prompts)] + + text_gen = TextGeneration( + tokenizer=self.tokenizer, + qpc_path=self._qpc_path, + ctx_len=self._ctx_len, + device_id=self.device_id, + enable_debug_logs=self.enable_debug_logs, + is_tlm=self.is_tlm, + include_sampler=self.include_sampler, + return_pdfs=self.return_pdfs, + sampling_params=self.sampling_params, + ) + + text_gen._qaic_model = self + + # Yield tokens as they're generated + for tokens in text_gen.generate_stream_tokens(vision_prompts, generation_len, **kwargs): + yield tokens + + def __repr__(self): + """String representation of the class""" + return ( + f"VisionLanguageGeneration(" + f"batch_size={self.batch_size}, " + f"ctx_len={self._ctx_len}, " + f"continuous_batching={'enabled' if self.full_batch_size else 'disabled'}, " + f"sampling={'enabled' if self.include_sampler else 'disabled'})" + ) diff --git a/QEfficient/transformers/models/llama4/modeling_llama4.py b/QEfficient/transformers/models/llama4/modeling_llama4.py index 212fe16ae..b7b951101 100644 --- a/QEfficient/transformers/models/llama4/modeling_llama4.py +++ b/QEfficient/transformers/models/llama4/modeling_llama4.py @@ -820,7 +820,7 @@ def forward(self, pixel_values): ) vision_flat = image_features.view(-1, image_features.size(-1)) projected_vision_flat = self.model.multi_modal_projector(vision_flat) - return projected_vision_flat + return projected_vision_flat # , pixel_values # This wrapper utilizes the 'vision_embeds', which contains vision embeddings, and an 'image_idx' index starting at 0. 
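
For reference, a minimal stand-alone sketch of the vision-embedding merge that the decoder wrapper in the next hunk performs: image placeholder tokens in input_ids are overwritten, in order, with rows of vision_embeds, while text positions keep their original embeddings. The token id, tensor sizes, and variable names below are illustrative assumptions (not taken from the model config), and the image_idx offset bookkeeping used by the wrapper is omitted.

    import torch

    IMAGE_TOKEN_ID = 32000                              # assumed placeholder id for image tokens
    input_ids = torch.tensor([[11, IMAGE_TOKEN_ID, IMAGE_TOKEN_ID, 42]])
    vision_embeds = torch.randn(2, 8)                   # one row per image patch, hidden size 8
    inputs_embeds = torch.randn(1, 4, 8)                # text embeddings from the language model

    selected = input_ids == IMAGE_TOKEN_ID              # (1, 4) mask of image positions
    indices1 = selected.to(torch.int64).cumsum(1) - 1   # running patch index, -1 before any image token
    indices0 = torch.arange(input_ids.shape[0]).view(-1, 1)
    image_features_expanded = vision_embeds.unsqueeze(0)[indices0, indices1.clamp(min=0)]
    merged = torch.where(selected.unsqueeze(-1), image_features_expanded, inputs_embeds)
    print(merged.shape)                                 # torch.Size([1, 4, 8])

The merge itself is independent of batching, which is why the hunk below can add the batch_index passthrough for continuous batching without touching this scatter logic.
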
@@ -836,7 +836,15 @@ def __init__(self, model): self.language_model = self.model.language_model self.config = self.model.config - def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_values): + def forward( + self, + input_ids, + vision_embeds, + position_ids, + image_idx, + past_key_values, + batch_index: Optional[torch.LongTensor] = None, + ): inputs_embeds = self.model.language_model.get_input_embeddings()(input_ids) selected = input_ids == self.model.config.image_token_index indices1 = selected.to(torch.int64).cumsum(1) - 1 @@ -846,7 +854,11 @@ def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_va image_embeds = torch.where(selected.unsqueeze(-1), image_features_expanded, inputs_embeds) inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), inputs_embeds, image_embeds) outputs = self.model.language_model( - inputs_embeds=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, use_cache=True + inputs_embeds=inputs_embeds, + position_ids=position_ids, + past_key_values=past_key_values, + batch_index=batch_index, + use_cache=True, ) next_idx = (indices1.max() + 1).unsqueeze(0).unsqueeze(0) image_idx = torch.where(image_idx < next_idx, next_idx, image_idx) @@ -893,6 +905,9 @@ def get_specializations( ctx_len: int, img_size: int, kv_offload: bool = False, + continuous_batching: bool = False, + kv_cache_batch_size: Optional[int] = None, + full_batch_size: Optional[int] = None, **compiler_options, ): max_num_tiles = compiler_options.pop("max_num_tiles", None) @@ -941,28 +956,42 @@ def get_specializations( "img_size": img_size, } ] - lang = [ - { - "batch_size": batch_size, - "seq_len": prefill_seq_len, - "ctx_len": ctx_len, - "max_num_tiles": max_num_tiles, - "img_size": img_size, - "vision_size": vision_size, - "chunk_length": prefill_seq_len, - "chunk_ctx_len": chunk_ctx_len, - }, - { - "batch_size": batch_size, - "seq_len": "1", - "ctx_len": ctx_len, - "max_num_tiles": max_num_tiles, - "img_size": img_size, - "vision_size": vision_size, - "chunk_length": prefill_seq_len, - "chunk_ctx_len": chunk_ctx_len, - }, - ] + + lang_prefill = { + "batch_size": 1 if continuous_batching else batch_size, + "seq_len": prefill_seq_len, + "ctx_len": ctx_len, + "max_num_tiles": max_num_tiles, + "img_size": img_size, + "vision_size": vision_size, + "chunk_length": prefill_seq_len, + "chunk_ctx_len": chunk_ctx_len, + } + if continuous_batching: + lang_prefill["full_batch_size"] = kv_cache_batch_size + else: + lang_prefill["batch_size"] = kv_cache_batch_size + if full_batch_size: + lang_prefill["full_batch_exec_size"] = full_batch_size + + lang_decode = { + "batch_size": full_batch_size if continuous_batching else batch_size, + "seq_len": 1, + "ctx_len": ctx_len, + "max_num_tiles": max_num_tiles, + "img_size": img_size, + "vision_size": vision_size, + "chunk_length": prefill_seq_len, + "chunk_ctx_len": chunk_ctx_len, + } + if continuous_batching: + lang_decode["full_batch_size"] = kv_cache_batch_size + else: + lang_decode["batch_size"] = kv_cache_batch_size + + lang = [] + lang.append(lang_prefill) + lang.append(lang_decode) specializations = {} @@ -971,18 +1000,22 @@ def get_specializations( specializations["lang"] = lang return specializations, compiler_options else: + lang[0].pop("vision_size") + lang[1].pop("vision_size") return lang, compiler_options - def get_onnx_dynamic_axes(self, kv_offload: bool = False): + def get_onnx_dynamic_axes(self, kv_offload: bool = False, continuous_batching: bool = False): # Define dynamic 
axes vision_dynamic_axes = {} lang_dynamic_axes = {} lang_dynamic_axes["input_ids"] = {0: "batch_size", 1: "seq_len"} lang_dynamic_axes["position_ids"] = {0: "batch_size", 1: "seq_len"} lang_dynamic_axes["vision_embeds"] = {0: "vision_size"} + if continuous_batching: + lang_dynamic_axes["batch_index"] = {0: "batch_size"} vision_dynamic_axes["pixel_values"] = {0: "max_num_tiles", 2: "img_size", 3: "img_size"} - pkv_dynamic_axes = {0: "batch_size"} + pkv_dynamic_axes = {0: "full_batch_size" if continuous_batching else "batch_size"} for i in range(self.language_model.config.num_hidden_layers): # switch between chunk_ctx_len and ctx_len for RoPE and NoPE layers. if int((i + 1) % 4 != 0): @@ -1011,6 +1044,7 @@ def get_output_names(self, kv_offload: bool = False): output_names = {} if kv_offload: + # vision_output_names.insert(1, "pixel_values_RetainedState") lang_output_names.insert(1, "vision_embeds_RetainedState") lang_output_names.insert(2, "image_idx_output") output_names["vision"] = vision_output_names @@ -1045,7 +1079,7 @@ def get_dummy_pkv_cache(self, config, batch_size, seq_len): past_key_values.append(pkv) return past_key_values - def get_dummy_inputs(self, kv_offload: bool = False): + def get_dummy_inputs(self, kv_offload: bool = False, continuous_batching: bool = False): if vis_cfg := getattr(self.config, "vision_config", None): img_size = getattr(vis_cfg, "image_size", 336) else: @@ -1090,10 +1124,14 @@ def get_dummy_inputs(self, kv_offload: bool = False): .repeat(constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, 1) ) lang_inputs["image_idx"] = torch.zeros((inputs_shapes["image_idx"]), dtype=torch.int64) + + bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE + fbs: int = constants.ONNX_EXPORT_EXAMPLE_FBS + # Add data for KV past_key_values = self.get_dummy_pkv_cache( config=self.language_model.config, - batch_size=constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, + batch_size=fbs if continuous_batching else bs, seq_len=constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN, ) @@ -1102,6 +1140,8 @@ def get_dummy_inputs(self, kv_offload: bool = False): for kv in ["key", "value"]: lang_inputs["past_key_values"][i].append(torch.zeros(past_key_values[0][0].shape, dtype=torch.float32)) + if continuous_batching: + lang_inputs["batch_index"] = torch.arange(bs).view(bs, 1) inputs = {} if kv_offload: inputs["vision"] = vision_inputs diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 633a0b29d..aeb72d858 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -14,6 +14,7 @@ import torch import torch.nn as nn from transformers import ( + AutoImageProcessor, AutoModel, AutoModelForCausalLM, AutoModelForCTC, @@ -35,6 +36,7 @@ calculate_latency, get_compilation_dims, ) +from QEfficient.generation.vlm_generation import VisionLanguageGeneration from QEfficient.transformers.modeling_utils import DYNAMIC_SEQ_LEN_SUPPORTED_MODEL_ARCH from QEfficient.transformers.models.pytorch_transforms import ( CustomOpsTransform, @@ -856,6 +858,7 @@ class _QEffAutoModelForImageTextToTextDualQPC: def __init__( self, model: nn.Module, + continuous_batching: bool = False, **kwargs, ): """ @@ -879,6 +882,7 @@ def __init__( self.config = model.config self.vision_model = QEffVisionEncoderForTextImageToTextModel(model, **kwargs) self.lang_model = QEffCausalLMForTextImageToTextModel(model, **kwargs) + self.continuous_batching = continuous_batching self.input_shapes, self.output_names = None, None @property @@ -978,8 +982,15 @@ def 
export( List[str] A list containing the paths to the generated ONNX graph files for both components. """ - inputs = self.model.get_dummy_inputs(kv_offload=True) - dynamic_axes = self.model.get_onnx_dynamic_axes(kv_offload=True) + # TODO This is a temporary change as continous batching is enabled only for few models. Once support is added for all the models this exception handing can be removed. + try: + inputs = self.model.get_dummy_inputs(kv_offload=True, continuous_batching=self.continuous_batching) + dynamic_axes = self.model.get_onnx_dynamic_axes( + kv_offload=True, continuous_batching=self.continuous_batching + ) + except TypeError: + inputs = self.model.get_dummy_inputs(kv_offload=True) + dynamic_axes = self.model.get_onnx_dynamic_axes(kv_offload=True) output_names = self.model.get_output_names(kv_offload=True) self.vision_model.export( @@ -1011,7 +1022,6 @@ def compile( num_cores: int = 16, # FIXME: Make this mandatory arg mxfp6_matmul: bool = False, mxint8_kv_cache: bool = False, - num_speculative_tokens: Optional[int] = None, skip_vision: Optional[bool] = False, skip_lang: Optional[bool] = False, **compiler_options, @@ -1068,14 +1078,20 @@ def compile( If `full_batch_size`, `kv_cache_batch_size`, or `num_speculative_tokens` are not None. If both `skip_lang` and `skip_vision` are True. """ - if any(param is not None for param in [full_batch_size, kv_cache_batch_size, num_speculative_tokens]): + if skip_lang and skip_vision: + raise ValueError("Expected at least one of 'skip_lang' or 'skip_vision' to be False") + + if self.continuous_batching and full_batch_size is None: + raise TypeError("`full_batch_size` is required when `continuous_batching=True`.") + + if kv_cache_batch_size and not full_batch_size: raise ValueError( - f"Expected 'full_batch_size', 'kv_cache_batch_size', 'num_speculative_tokens' to be None but got: " - f"full_batch_size={full_batch_size}, kv_cache_batch_size={kv_cache_batch_size}, num_speculative_tokens={num_speculative_tokens}, " + "KV caching requires continuous batching. Please set `full_batch_size` and " + "enable `continuous_batching=True` in `from_pretrained`." ) - if skip_lang and skip_vision: - raise ValueError("Expected at least one of 'skip_lang' or 'skip_vision' to be False") + # Infer kv_cache_batch_size if not provided + kv_cache_batch_size = kv_cache_batch_size or full_batch_size or batch_size output_names = self.model.get_output_names(kv_offload=True) @@ -1085,6 +1101,9 @@ def compile( ctx_len=ctx_len, img_size=img_size, kv_offload=True, + continuous_batching=self.continuous_batching, + kv_cache_batch_size=kv_cache_batch_size, + full_batch_size=full_batch_size, **compiler_options, ) @@ -1111,6 +1130,11 @@ def compile( ): self.export() + # TODO this hould be removed once the continous batching is supported for all the models. 
+ compiler_options.pop("continuous_batching", None) + compiler_options.pop("kv_cache_batch_size", None) + compiler_options.pop("full_batch_size", None) + if not skip_vision: self.vision_model._compile( compile_dir=compile_dir, @@ -1156,7 +1180,11 @@ def compile( def generate( self, - inputs: torch.Tensor, + inputs: Optional[torch.Tensor] = None, + tokenizer: Union[PreTrainedTokenizerFast, PreTrainedTokenizer] = None, + processor: Optional[AutoImageProcessor] = None, + images: List[str] = None, + prompts: List[str] = None, streamer: Optional[TextStreamer] = None, device_ids: List[int] = None, runtime_ai100: bool = True, @@ -1172,6 +1200,14 @@ def generate( inputs : Dict[str, Union[torch.Tensor, np.ndarray]] Inputs to run the execution, typically includes `pixel_values`, `input_ids`, `attention_mask`, etc. + tokenizer : PreTrainedTokenizer or PreTrainedTokenizerFast, optional + Tokenizer for the model. Used when images and prompts are provided. + processor : AutoImageProcessor, optional + Processor for the model. Used when images and prompts are provided. + images : List[str], optional + List of image paths or PIL images to process. + prompts : List[str], optional + List of text prompts corresponding to the images. streamer : TextStreamer, optional A streamer object to display generated tokens in real-time. Default is None. device_ids : List[int], optional @@ -1196,6 +1232,30 @@ def generate( if not runtime_ai100: raise NotImplementedError("PyTorch execution is not supported yet for this model!") + # Use VisionLanguageGeneration for image-prompt pairs + if (processor and images) or (tokenizer and prompts): + # Create VisionLanguageGeneration instance + batch_size_comp, ctx_len_comp, fbs = get_compilation_dims(self.lang_model.qpc_path) + vlm_gen = VisionLanguageGeneration( + qeff_model=self, + lang_qpc_path=self.lang_model.qpc_path, + vision_qpc_path=self.vision_model.qpc_path, + tokenizer=tokenizer, + processor=processor, + device_id=device_ids, # if device_ids is not None else [0], + ctx_len=ctx_len_comp, + full_batch_size=fbs, + ) + + # Call generate method + return vlm_gen.generate( + images=images, + prompts=prompts, + generation_len=generation_len, + stream=streamer is not None, + ) + + # Fallback to kv_offload_generate for direct inputs (backward compatibility) return self.kv_offload_generate( inputs=inputs, device_ids=device_ids, streamer=streamer, generation_len=generation_len ) @@ -1332,9 +1392,7 @@ def kv_offload_generate( lang_session.set_buffers(vision_outputs) - # Prepare inputs for prefill - chunk_inputs = lang_inputs.copy() - prefill_start = perf_counter() + lang_start = perf_counter() # Run prefill chunk_inputs = lang_inputs.copy() @@ -1346,7 +1404,7 @@ def kv_offload_generate( outputs = lang_session.run(chunk_inputs) chunk_inputs["image_idx"] = outputs["image_idx_output"] - prefill_time = perf_counter() - prefill_start + vision_end - vision_start + prefill_time = perf_counter() - lang_start + vision_end - vision_start # Skip inputs/outputs again lang_session.skip_buffers( [ @@ -1930,7 +1988,7 @@ class QEFFAutoModelForImageTextToText: _hf_auto_class = AutoModelForImageTextToText - def __new__(self, model: nn.Module, kv_offload: Optional[bool] = True, **kwargs): + def __new__(self, model: nn.Module, kv_offload: Optional[bool] = True, continuous_batching: bool = False, **kwargs): """ Instantiate the appropriate internal class for single or dual QPC mode. 
@@ -1951,13 +2009,19 @@ def __new__(self, model: nn.Module, kv_offload: Optional[bool] = True, **kwargs) The wrapped model instance, configured for either dual or single QPC. """ if kv_offload: - return _QEffAutoModelForImageTextToTextDualQPC(model, **kwargs) + return _QEffAutoModelForImageTextToTextDualQPC(model, continuous_batching, **kwargs) else: return _QEFFAutoModelForImageTextToTextSingleQPC(model, **kwargs) @classmethod @with_replaced_quantizers - def from_pretrained(cls, pretrained_model_name_or_path: str, kv_offload: Optional[bool] = None, **kwargs): + def from_pretrained( + cls, + pretrained_model_name_or_path: str, + kv_offload: Optional[bool] = None, + continuous_batching: bool = False, + **kwargs, + ): """ Load a QEfficient image-text-to-text model from a pretrained HuggingFace model or local path. @@ -1986,18 +2050,24 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, kv_offload: Optiona If `continuous_batching` is provided as True. """ # TODO: add a check to see if kv_offload is allowed for given model by loading the config and checking architecture or type of config here. + if continuous_batching and not kv_offload: + NotImplementedError("Continuous batching is not supported for kv_offload = False") + if kwargs.get("attn_implementation", None) not in {None, "eager"}: logger.warning('Updating attn_implementation="eager"') if kwargs.get("low_cpu_mem_usage", None): logger.warning("Updating low_cpu_mem_usage=False") - if kwargs.pop("continuous_batching", None): - NotImplementedError("Continuous batching is not supported for image-text-to-text models yet.") - kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False}) model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, **kwargs) - return cls(model, kv_offload=kv_offload, pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs) + return cls( + model, + kv_offload=kv_offload, + continuous_batching=continuous_batching, + pretrained_model_name_or_path=pretrained_model_name_or_path, + **kwargs, + ) MISCLASSIFIED_CAUSAL_LM_TO_QEFF_AUTO_CLASS_MAP = { @@ -2705,8 +2775,8 @@ def generate( raise TypeError("Please run compile API first!") generation_len = kwargs.pop("generation_len", None) return QEfficient.cloud_ai_100_exec_kv( - tokenizer, - self.qpc_path, + tokenizer=tokenizer, + qpc_path=self.qpc_path, prompt=prompts, device_id=device_id, generation_len=generation_len, diff --git a/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index e5e842e6f..0f6630210 100644 --- a/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -6,6 +6,7 @@ # ----------------------------------------------------------------------------- import math +import os from typing import Callable, List, Optional, Tuple, Union import torch @@ -360,7 +361,7 @@ def forward(self, x, seq_len=None): ) -def eager_attention_forward( +def eager_attention_forward_q_blocked( module: nn.Module, query: torch.Tensor, key: torch.Tensor, @@ -368,22 +369,107 @@ def eager_attention_forward( attention_mask: Optional[torch.Tensor], **kwargs, ): + """ + Q-blocked attention for Qwen2.5-VL. + Blocks only the query SL dimension. 
+ + Args: + query: (BS, NH, Q_LEN, DH) + key: (BS, NH_KV, KV_LEN, DH) + value: (BS, NH_KV, KV_LEN, DH) + attention_mask: (BS, NH, Q_LEN, KV_LEN) or broadcastable + """ + BS, NH, Q_LEN, DH = query.shape + _, _, KV_LEN, _ = key.shape + key_states = repeat_kv(key, module.num_key_value_groups) value_states = repeat_kv(value, module.num_key_value_groups) - attn_weights = torch.matmul(query, key_states.transpose(2, 3)) / math.sqrt(module.head_dim) - if attention_mask is not None: - attn_weights = torch.where( - attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights - ) + target_blocks_q = int(os.environ.get("num_q_blocks", Q_LEN)) + q_block_positions = [(i * Q_LEN) // target_blocks_q for i in range(target_blocks_q)] + scaling = 1.0 / math.sqrt(module.head_dim) + + q_output_blocks = [] + q_attn_weights_blocks = [] - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) - attn_output = torch.matmul(attn_weights, value_states) + # Process each Q block + for q_block_idx in range(target_blocks_q): + qi = q_block_positions[q_block_idx] + + # Calculate Q block size + if q_block_idx == target_blocks_q - 1: + real_q_len = Q_LEN - qi + else: + real_q_len = q_block_positions[q_block_idx + 1] - qi + + # Extract Q block + q_block = query[:, :, qi : qi + real_q_len, :] + attn_mask_block = None + if attention_mask is not None: + attn_mask_block = attention_mask[:, :, qi : qi + real_q_len, :] + + # Compute attention scores for this Q block + attn_weights = torch.matmul(q_block, key_states.transpose(2, 3)) * scaling + if attn_mask_block is not None: + attn_weights = torch.where( + attn_mask_block, + torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32, device=attn_weights.device), + attn_weights, + ) + + attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + + # Compute output for this Q block + output_block = torch.matmul(attn_weights, value_states) + + q_output_blocks.append(output_block) + q_attn_weights_blocks.append(attn_weights) + + attn_output = torch.cat(q_output_blocks, dim=2) attn_output = attn_output.transpose(1, 2).contiguous() + # Concatenate attention weights + attn_weights = torch.cat(q_attn_weights_blocks, dim=2) + return attn_output, attn_weights +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + **kwargs, +): + """ + Wrapper that routes to blocked or default attention based on environment variable. + """ + blocking_mode = os.environ.get("ATTENTION_BLOCKING_MODE", "default").lower() + + if blocking_mode == "q": + return eager_attention_forward_q_blocked(module, query, key, value, attention_mask, **kwargs) + elif blocking_mode == "default": + # Original implementation + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) / math.sqrt(module.head_dim) + + if attention_mask is not None: + attn_weights = torch.where( + attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights + ) + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + else: + raise ValueError(f"Invalid ATTENTION_BLOCKING_MODE: {blocking_mode}. 
Must be 'q' or 'default'") + + class QEffQwen2_5_VLAttention(Qwen2_5_VLAttention): """ Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer @@ -680,7 +766,15 @@ def __init__(self, model): self.model = model self.language_model = self.model.model.language_model - def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_values): + def forward( + self, + input_ids, + vision_embeds, + position_ids, + image_idx, + past_key_values, + batch_index: Optional[torch.LongTensor] = None, + ): inputs_embeds = self.model.get_input_embeddings()(input_ids) B, N, C = inputs_embeds.shape selected = input_ids == self.model.config.image_token_id @@ -691,7 +785,11 @@ def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_va image_input_embeds = torch.where(selected.unsqueeze(-1), image_features_expanded, inputs_embeds) inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), inputs_embeds, image_input_embeds) outputs = self.model.model( - inputs_embeds=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, use_cache=True + inputs_embeds=inputs_embeds, + position_ids=position_ids, + past_key_values=past_key_values, + batch_index=batch_index, + use_cache=True, ) logit_index = position_ids[0].to(torch.int32).argmax(1, keepdim=True) @@ -709,7 +807,7 @@ def get_qeff_vision_encoder(self): def get_qeff_language_decoder(self): return QEffQwen_2_5_vl_DecoderWrapper(self) - def get_dummy_inputs(self, kv_offload: bool = False, **kwargs): + def get_dummy_inputs(self, kv_offload: bool = False, continuous_batching: bool = False, **kwargs): inputs_shapes = {} inputs_shapes["input_ids"] = (constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN) @@ -745,10 +843,14 @@ def get_dummy_inputs(self, kv_offload: bool = False, **kwargs): .repeat(4, 1, 1) ) lang_inputs["image_idx"] = torch.zeros((inputs_shapes["image_idx"]), dtype=torch.int64) + + bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE + fbs: int = constants.ONNX_EXPORT_EXAMPLE_FBS + # Add data for KV kv_cache_shape = get_padding_shape_from_config( - config=self.model.config, - batch_size=constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, + config=self.model.config.text_config, + batch_size=fbs if continuous_batching else bs, seq_len=constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN, ) @@ -757,6 +859,9 @@ def get_dummy_inputs(self, kv_offload: bool = False, **kwargs): for kv in ["key", "value"]: lang_inputs["past_key_values"][i].append(torch.zeros(kv_cache_shape, dtype=torch.float32)) + if continuous_batching: + lang_inputs["batch_index"] = torch.arange(bs).view(bs, 1) + inputs = {} if kv_offload: inputs["vision"] = vision_inputs @@ -775,7 +880,11 @@ def get_specializations( img_size: None, height: int = None, width: int = None, + num_frames: int = 1, kv_offload: bool = False, + continuous_batching: bool = False, + kv_cache_batch_size: Optional[int] = None, + full_batch_size: Optional[int] = None, **compiler_options, ): if height is None or width is None: @@ -856,20 +965,37 @@ def smart_resize( "grid_w": grid_w, } ] - lang = [ - { - "batch_size": batch_size, - "seq_len": prefill_seq_len, - "ctx_len": ctx_len, - "vision_size": vision_size, - }, - { - "batch_size": batch_size, - "seq_len": "1", - "ctx_len": ctx_len, - "vision_size": vision_size, - }, - ] + lang_prefill = { + "batch_size": 1 if continuous_batching else batch_size, + "seq_len": prefill_seq_len, + "ctx_len": ctx_len, + "vision_size": vision_size, + "vision_batch_size": 
batch_size, + } + + if continuous_batching: + lang_prefill["full_batch_size"] = kv_cache_batch_size + else: + lang_prefill["batch_size"] = kv_cache_batch_size + if full_batch_size: + lang_prefill["full_batch_exec_size"] = full_batch_size + + lang_decode = { + "batch_size": full_batch_size if continuous_batching else batch_size, + "seq_len": 1, + "ctx_len": ctx_len, + "vision_size": vision_size, + "vision_batch_size": batch_size, + } + + if continuous_batching: + lang_decode["full_batch_size"] = kv_cache_batch_size + else: + lang_decode["batch_size"] = kv_cache_batch_size + + lang = [] + lang.append(lang_prefill) + lang.append(lang_decode) specializations = {} @@ -878,9 +1004,11 @@ def smart_resize( specializations["lang"] = lang return specializations, compiler_options else: + lang[0].pop("vision_size") + lang[1].pop("vision_size") return lang, compiler_options - def get_onnx_dynamic_axes(self, kv_offload: bool = False): + def get_onnx_dynamic_axes(self, kv_offload: bool = False, continuous_batching: bool = False): # Define dynamic axes num_layers = self.config.text_config.num_hidden_layers @@ -892,12 +1020,21 @@ def get_onnx_dynamic_axes(self, kv_offload: bool = False): lang_dynamic_axes = { "input_ids": {0: "batch_size", 1: "seq_len"}, "position_ids": {1: "batch_size", 2: "seq_len"}, - "vision_embeds": {0: "batch_size", 1: "vision_size"}, + "vision_embeds": {0: "vision_batch_size", 1: "vision_size"}, } for i in range(num_layers): - lang_dynamic_axes[f"past_key.{i}"] = {0: "batch_size", 2: "ctx_len"} - lang_dynamic_axes[f"past_value.{i}"] = {0: "batch_size", 2: "ctx_len"} + lang_dynamic_axes[f"past_key.{i}"] = { + 0: "full_batch_size" if continuous_batching else "batch_size", + 2: "ctx_len", + } + lang_dynamic_axes[f"past_value.{i}"] = { + 0: "full_batch_size" if continuous_batching else "batch_size", + 2: "ctx_len", + } + + if continuous_batching: + lang_dynamic_axes["batch_index"] = {0: "batch_size"} dynamic_axes = {} diff --git a/QEfficient/utils/__init__.py b/QEfficient/utils/__init__.py index e487d4af4..49f0ad30b 100755 --- a/QEfficient/utils/__init__.py +++ b/QEfficient/utils/__init__.py @@ -10,6 +10,7 @@ undo_transformers_quantizers, ) from QEfficient.utils._utils import ( # noqa: F401 + LRUCache, check_and_assign_cache_dir, create_json, create_model_params, diff --git a/QEfficient/utils/_utils.py b/QEfficient/utils/_utils.py index abe383556..d58f54952 100644 --- a/QEfficient/utils/_utils.py +++ b/QEfficient/utils/_utils.py @@ -33,6 +33,36 @@ from QEfficient.utils.logging_utils import logger +class LRUCache: + """Simple LRU cache with size limit for vision outputs""" + + def __init__(self, max_size=100): + self._cache = {} + self._access_order = [] + self._max_size = max_size + + def get(self, key): + if key in self._cache: + self._access_order.remove(key) + self._access_order.append(key) + return self._cache[key] + return None + + def put(self, key, value): + if key in self._cache: + self._access_order.remove(key) + elif len(self._cache) >= self._max_size: + oldest = self._access_order.pop(0) + del self._cache[oldest] + + self._cache[key] = value + self._access_order.append(key) + + def clear(self): + self._cache.clear() + self._access_order.clear() + + class DownloadRetryLimitExceeded(Exception): """ Used for raising error when hf_download fails to download the model after given max_retries. 
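The `LRUCache` added to `QEfficient/utils/_utils.py` above (and exported from `QEfficient.utils`) is a plain dict-plus-access-order cache intended for vision outputs. A small usage sketch of its eviction behavior follows; the keys and values are placeholders, not the real cache keys used by the runtime.

# Usage sketch for the LRUCache above: the least-recently-used entry is evicted
# once max_size is reached, and get() refreshes an entry's recency.
from QEfficient.utils import LRUCache

cache = LRUCache(max_size=2)
cache.put("image_a", {"vision_embeds": "..."})
cache.put("image_b", {"vision_embeds": "..."})
cache.get("image_a")                            # "image_a" becomes most recently used
cache.put("image_c", {"vision_embeds": "..."})  # evicts "image_b"
assert cache.get("image_b") is None
assert cache.get("image_a") is not None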
diff --git a/examples/llama4_CB_example_vision_lang.py b/examples/llama4_CB_example_vision_lang.py
new file mode 100644
index 000000000..f285ea278
--- /dev/null
+++ b/examples/llama4_CB_example_vision_lang.py
@@ -0,0 +1,93 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# ----------------------------------------------------------------------------
+
+import transformers
+from transformers import AutoConfig, AutoProcessor
+
+from QEfficient import QEFFAutoModelForImageTextToText
+
+model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
+config = AutoConfig.from_pretrained(model_id)
+# For testing purposes only
+config.text_config.num_hidden_layers = 4
+config.vision_config.num_hidden_layers = 2
+
+tokenizer = transformers.AutoTokenizer.from_pretrained(model_id)
+processor = AutoProcessor.from_pretrained(model_id)
+
+continuous_batching = False
+if continuous_batching:
+    qeff_model = QEFFAutoModelForImageTextToText.from_pretrained(
+        model_id,
+        attn_implementation="eager",
+        kv_offload=True,
+        config=config,
+        continuous_batching=True,
+    )
+
+    qeff_model.compile(
+        prefill_seq_len=128,
+        ctx_len=3072,
+        img_size=336,
+        num_cores=16,
+        num_devices=4,
+        max_num_tiles=17,
+        batch_size=1,
+        full_batch_size=4,
+        mxfp6_matmul=True,
+        mxint8_kv_cache=True,
+        aic_enable_depth_first=True,
+        mos=1,
+    )
+else:
+    qeff_model = QEFFAutoModelForImageTextToText.from_pretrained(
+        model_id,
+        attn_implementation="eager",
+        kv_offload=True,
+        config=config,
+    )
+
+    qeff_model.compile(
+        prefill_seq_len=128,
+        ctx_len=3072,
+        img_size=336,
+        num_cores=16,
+        num_devices=4,
+        max_num_tiles=17,
+        batch_size=1,
+        mxfp6_matmul=True,
+        mxint8_kv_cache=True,
+        aic_enable_depth_first=True,
+        mos=1,
+    )
+
+image_urls = [
+    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_layout.png",
+    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg",
+    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_layout.png",
+    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg",
+]
+
+prompts = [
+    "Can you describe the image in detail?",
+    "What are the objects in the image?",
+    "What is the main subject of the image?",
+    "What colors are predominant in the image?",
+]
+
+exec_info = qeff_model.generate(
+    tokenizer=tokenizer,
+    prompts=prompts,
+    processor=processor,
+    images=image_urls,
+    device_ids=[0, 1, 2, 3],
+    generation_len=100,
+)
+
+# print("Generated texts:", exec_info.generated_texts)
+print("Generated IDs:", exec_info.generated_ids)
+print(exec_info)
diff --git a/examples/qwen2_5_vl_CB.py b/examples/qwen2_5_vl_CB.py
new file mode 100644
index 000000000..96ef4898a
--- /dev/null
+++ b/examples/qwen2_5_vl_CB.py
@@ -0,0 +1,72 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+# To enable Q-blocking, run the command below; the default is no blocking:
+# ATTENTION_BLOCKING_MODE=q num_q_blocks=2 python -W ignore qwen2_5_vl_example.py
+
+import transformers
+from transformers import AutoConfig, AutoProcessor, TextStreamer
+
+from QEfficient import QEFFAutoModelForImageTextToText
+
+## For AWQ models, update the PyTorch version to 2.8.*
+model_id = "Qwen/Qwen2.5-VL-32B-Instruct"
+config = AutoConfig.from_pretrained(model_id)
+config.text_config.num_hidden_layers = 2
+
+qeff_model = QEFFAutoModelForImageTextToText.from_pretrained(
+    model_id,
+    attn_implementation="eager",
+    kv_offload=True,
+    config=config,
+    continuous_batching=True,
+)
+tokenizer = transformers.AutoTokenizer.from_pretrained(model_id)
+processor = AutoProcessor.from_pretrained(model_id)
+
+batch_size = 1
+## Vision + Text ##
+qeff_model.compile(
+    batch_size=batch_size,
+    full_batch_size=4,
+    prefill_seq_len=128,
+    ctx_len=4096,
+    num_cores=16,
+    num_devices=4,
+    height=354,
+    width=536,
+    mxfp6_matmul=True,
+    mxint8_kv_cache=True,
+    aic_enable_depth_first=True,
+    mos=1,
+)
+
+image_urls = [
+    "https://picsum.photos/id/237/536/354",
+    "https://picsum.photos/id/237/536/354",
+    "https://picsum.photos/id/237/536/354",
+    "https://picsum.photos/id/237/536/354",
+]
+
+prompts = [
+    "Can you describe the image in detail?",
+    "What are the objects in the image?",
+    "What is the main subject of the image?",
+    "What colors are predominant in the image?",
+]
+
+streamer = TextStreamer(tokenizer)
+output = qeff_model.generate(
+    tokenizer=tokenizer,
+    prompts=prompts,
+    processor=processor,
+    images=image_urls,
+    generation_len=100,
+)
+print(output.generated_ids)
+print(tokenizer.batch_decode(output.generated_ids))
+print(output)
diff --git a/examples/qwen2_5_vl_example.py b/examples/qwen2_5_vl_example.py
index 374f70ad2..d5d943c9c 100644
--- a/examples/qwen2_5_vl_example.py
+++ b/examples/qwen2_5_vl_example.py
@@ -5,6 +5,9 @@
 #
 # -----------------------------------------------------------------------------
 
+# To enable Q-blocking, run the command below; the default is no blocking:
+# ATTENTION_BLOCKING_MODE=q num_q_blocks=2 python -W ignore qwen2_5_vl_example.py
+
 import requests
 import transformers
 from PIL import Image
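As a sanity check on the Q-blocking used in `eager_attention_forward_q_blocked` above: softmax is taken row-wise over the key/value length, so splitting only the query length and concatenating the block outputs is numerically equivalent to unblocked attention. The sketch below demonstrates this on random tensors, independent of the Qwen2.5-VL modules; the shapes and block count are arbitrary assumptions.

# Standalone sketch: blocking only the query dimension does not change the result,
# because each query row still attends (with a full softmax) over all key/value rows.
import math

import torch


def attn(q, k, v):
    w = torch.softmax(q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1]), dim=-1)
    return w @ v


def attn_q_blocked(q, k, v, num_q_blocks=2):
    outs = [attn(q_blk, k, v) for q_blk in q.chunk(num_q_blocks, dim=-2)]
    return torch.cat(outs, dim=-2)


q = torch.randn(1, 4, 16, 8)   # (batch, heads, q_len, head_dim)
k = torch.randn(1, 4, 32, 8)   # (batch, heads, kv_len, head_dim)
v = torch.randn(1, 4, 32, 8)
assert torch.allclose(attn(q, k, v), attn_q_blocked(q, k, v), atol=1e-6)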