diff --git a/QEfficient/__init__.py b/QEfficient/__init__.py
index be4b86321..33c6f5588 100644
--- a/QEfficient/__init__.py
+++ b/QEfficient/__init__.py
@@ -8,15 +8,16 @@
 import os
 import warnings
 
+import QEfficient.utils.model_registery  # noqa: F401
 from QEfficient.utils import custom_format_warning
+from QEfficient.utils.logging_utils import logger
 
 # For faster downloads via hf_transfer
 # This code is put above import statements as this needs to be executed before
 # hf_transfer is imported (will happen on line 15 via leading imports)
 os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
 # Placeholder for all non-transformer models registered in QEfficient
-import QEfficient.utils.model_registery  # noqa: F401
-from QEfficient.utils.logging_utils import logger
+
 
 # custom warning for the better logging experience
 warnings.formatwarning = custom_format_warning
@@ -43,6 +44,7 @@ def check_qaic_sdk():
     from QEfficient.base import (
         QEFFAutoModel,
         QEFFAutoModelForCausalLM,
+        QEFFAutoModelForCTC,
         QEFFAutoModelForImageTextToText,
         QEFFAutoModelForSpeechSeq2Seq,
         QEFFCommonLoader,
@@ -63,6 +65,7 @@ def check_qaic_sdk():
         "cloud_ai_100_exec_kv",
         "QEFFAutoModel",
         "QEFFAutoModelForCausalLM",
+        "QEFFAutoModelForCTC",
         "QEffAutoPeftModelForCausalLM",
         "QEFFAutoModelForImageTextToText",
         "QEFFAutoModelForSpeechSeq2Seq",
diff --git a/QEfficient/base/__init__.py b/QEfficient/base/__init__.py
index d29ca7d29..d106a0759 100644
--- a/QEfficient/base/__init__.py
+++ b/QEfficient/base/__init__.py
@@ -9,6 +9,7 @@
 from QEfficient.transformers.models.modeling_auto import (  # noqa: F401
     QEFFAutoModel,
     QEFFAutoModelForCausalLM,
+    QEFFAutoModelForCTC,
     QEFFAutoModelForImageTextToText,
     QEFFAutoModelForSpeechSeq2Seq,
 )
diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py
index 564cd0a5f..fef550c87 100644
--- a/QEfficient/transformers/models/modeling_auto.py
+++ b/QEfficient/transformers/models/modeling_auto.py
@@ -16,6 +16,7 @@
 from transformers import (
     AutoModel,
     AutoModelForCausalLM,
+    AutoModelForCTC,
     AutoModelForImageTextToText,
     AutoModelForSpeechSeq2Seq,
     PreTrainedTokenizer,
@@ -2123,3 +2124,287 @@ def generate(
         generated_ids=generated_ids,
         perf_metrics=PerfMetrics(prefill_time, decode_perf, total_perf, total_time),
     )
+
+
+class QEFFAutoModelForCTC(QEFFTransformersBase):
+    """
+    The QEFFAutoModelForCTC class is designed for transformer models with a Connectionist Temporal Classification (CTC) speech-to-text head,
+    including Wav2Vec2 and other encoder-only speech models optimized for alignment-free transcription.
+    Although it is possible to initialize the class directly, we highly recommend using the ``from_pretrained`` method for initialization.
+
+    ``Mandatory`` Args:
+        :model (nn.Module): PyTorch model
+
+    .. code-block:: python
+        import torchaudio
+        from QEfficient import QEFFAutoModelForCTC
+        from transformers import AutoProcessor
+
+        # Initialize the model using from_pretrained similar to transformers.AutoModelForCTC.
+        model = QEFFAutoModelForCTC.from_pretrained(model_name)
+
+        # Now you can directly compile the model for Cloud AI 100
+        model.compile(num_cores=16)  # Considering you have a Cloud AI 100 SKU
+
+        # prepare input
+        processor = AutoProcessor.from_pretrained(model_name)
+        input_audio, sample_rate = [...]  # audio data loaded in via some external audio package, such as librosa or soundfile
+
+        # Resample the input_audio if necessary
+        if input_audio.shape[0] > 1:
+            input_audio = input_audio.mean(dim=0)
+        if sample_rate != 16000:
+            resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
+            input_audio = resampler(input_audio)
+
+        # You can now execute the model
+        out = model.generate(processor, inputs=input_audio)
+    """
+
+    _hf_auto_class = AutoModelForCTC
+    _pytorch_transforms = [CustomOpsTransform, AwqToMatmulNbitsTransform, GPTQToMatmulNbitsTransform]
+    _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]
+
+    def __init__(self, model: nn.Module, **kwargs):
+        super().__init__(model, **kwargs)
+        self.model.base_model.config.use_cache = True
+
+        self.hash_params["qeff_auto_class"] = self.__class__.__name__
+
+    @classmethod
+    @with_replaced_quantizers
+    def from_pretrained(cls, pretrained_model_name_or_path, pooling=None, *args, **kwargs):
+        """
+        This method serves as the easiest entry point into using QEfficient. The interface is designed to be similar to transformers.AutoModelForCTC.
+        Once the model is initialized, you can use other methods such as export, compile, and generate on the same object.
+
+        Args:
+            pretrained_model_name_or_path (str): The name or path of the pre-trained model.
+
+        .. code-block:: python
+            import torchaudio
+            from QEfficient import QEFFAutoModelForCTC
+            from transformers import AutoProcessor
+
+            # Initialize the model using from_pretrained similar to transformers.AutoModelForCTC.
+            model = QEFFAutoModelForCTC.from_pretrained(model_name)
+
+            # Now you can directly compile the model for Cloud AI 100
+            model.compile(num_cores=16)  # Considering you have a Cloud AI 100 SKU
+
+            # prepare input
+            processor = AutoProcessor.from_pretrained(model_name)
+            input_audio, sample_rate = [...]  # audio data loaded in via some external audio package, such as librosa or soundfile
+
+            # Resample the input_audio if necessary
+            if input_audio.shape[0] > 1:
+                input_audio = input_audio.mean(dim=0)
+            if sample_rate != 16000:
+                resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
+                input_audio = resampler(input_audio)
+
+            # You can now execute the model
+            out = model.generate(processor, inputs=input_audio)
+        """
+        if kwargs.get("attn_implementation", None) not in {None, "eager"}:
+            logger.warning('Updating attn_implementation="eager"')
+
+        if kwargs.get("low_cpu_mem_usage", None):
+            logger.warning("Updating low_cpu_mem_usage=False")
+
+        kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False})
+
+        model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
+
+        # This supports models that should be classified into a different auto class but are loaded by transformers via this class
+        kv_offload = kwargs.pop("kv_offload", None)
+        if model.__class__.__name__ in MISCLASSIFIED_CAUSAL_LM_TO_QEFF_AUTO_CLASS_MAP:
+            return MISCLASSIFIED_CAUSAL_LM_TO_QEFF_AUTO_CLASS_MAP[model.__class__.__name__](
+                model, kv_offload=kv_offload, **kwargs
+            )
+
+        return cls(model, pretrained_model_name_or_path=pretrained_model_name_or_path, pooling=pooling, **kwargs)
+
+    @property
+    def get_model_config(self) -> dict:
+        return self.model.config.__dict__
+
+    def export(self, export_dir: Optional[str] = None) -> str:
+        """
+        Exports the model to ``ONNX`` format using ``torch.onnx.export``.
+
+        ``Optional`` Args:
+            :export_dir (str, optional): The directory path to store the ONNX graph.
+
+        Returns:
+            :str: Path of the generated ``ONNX`` graph.
+        """
+        bs = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE
+        seq_len = constants.WAV2VEC2_MAX_SEQ_LEN
+
+        example_inputs = {
+            "input_values": torch.zeros((bs, seq_len), dtype=torch.float32),
+        }
+
+        dynamic_axes = {"input_values": {0: "batch_size", 1: "seq_len"}}
+
+        output_names = ["logits"]
+
+        return self._export(
+            example_inputs,
+            output_names,
+            dynamic_axes,
+            export_dir=export_dir,
+        )
+
+    def compile(
+        self,
+        onnx_path: Optional[str] = None,
+        compile_dir: Optional[str] = None,
+        *,
+        seq_len: Union[int, List[int]] = 480000,
+        batch_size: int = 1,
+        num_devices: int = 1,
+        num_cores: int = 16,  # FIXME: Make this mandatory arg
+        mxfp6_matmul: bool = False,
+        **compiler_options,
+    ) -> str:
+        """
+        This method compiles the exported ``ONNX`` model using the Cloud AI 100 Platform SDK compiler binary found at ``/opt/qti-aic/exec/qaic-exec`` and generates a ``qpc`` package.
+        If the model has not been exported yet, this method will handle the export process.
+        You can pass any other arguments that `qaic-exec` takes as extra kwargs.
+
+        ``Optional`` Args:
+            :onnx_path (str, optional): Path to a pre-exported ONNX model.
+            :compile_dir (str, optional): Path for saving the generated qpc.
+            :seq_len (Union[int, List[int]], optional): Sequence length(s), in audio samples, to compile for; the padded input audio must be no longer than ``seq_len``. ``Defaults to 480000`` (30 seconds at 16 kHz).
+            :batch_size (int, optional): Batch size. ``Defaults to 1``.
+            :num_devices (int): Number of devices the model needs to be compiled for. Defaults to 1.
+            :num_cores (int): Number of cores used to compile the model.
+            :mxfp6_matmul (bool, optional): Whether to use ``mxfp6`` compression for weights. ``Defaults to False``.
+            :compiler_options (dict, optional): Additional compiler options.
+
+            For QAIC Compiler: Extra arguments for qaic-exec can be passed.
+                :aic_enable_depth_first (bool, optional): Enables DFS with default memory size. ``Defaults to False``.
+                :allow_mxint8_mdp_io (bool, optional): Allows MXINT8 compression of MDP IO traffic. ``Defaults to False.``
+
+                Params are converted to flags as below:
+
+                - aic_hw_version=ai100 -> -aic-hw-version=ai100
+                - aic_hw_version=ai200 -> -aic-hw-version=ai200
+
+            For QNN Compiler: Following arguments can be passed.
+                :enable_qnn (bool): Enables QNN Compilation.
+                :qnn_config (str): Path of QNN Config parameters file. Any extra parameters for QNN compilation can be passed via this file.
+
+        Returns:
+            :str: Path of the compiled ``qpc`` package.
+        """
+
+        specializations = [
+            {"batch_size": batch_size, "seq_len": sl} for sl in (seq_len if isinstance(seq_len, list) else [seq_len])
+        ]
+
+        return self._compile(
+            onnx_path=onnx_path,
+            compile_dir=compile_dir,
+            compile_only=True,
+            specializations=specializations,
+            convert_to_fp16=True,
+            mxfp6_matmul=mxfp6_matmul,
+            mdp_ts_num_devices=num_devices,
+            aic_num_cores=num_cores,
+            **compiler_options,
+        )
+
+    def generate(
+        self,
+        processor,
+        inputs: torch.Tensor,
+        device_ids: List[int] = None,
+        runtime_ai100: bool = True,
+    ) -> Union[torch.Tensor, np.ndarray]:
+        """
+        This method generates output by executing the PyTorch runtime or the compiled ``qpc`` on ``Cloud AI 100`` hardware cards.
+        ``Mandatory`` Args:
+            :inputs (Union[torch.Tensor, np.ndarray]): inputs to run the execution.
+            :processor (AutoProcessor): The Processor to use for encoding the waveform.
+        ``Optional`` Args:
+            :device_ids (List[int], optional): Device IDs for running the qpc; pass ``[0]`` for a single-device model or ``[0, 1, 2, 3]`` for a tensor-sliced model.
+            :runtime_ai100 (bool, optional): ``AI_100`` and ``PyTorch`` runtimes are supported as of now. Defaults to ``True`` for the ``AI_100`` runtime.
+        Returns:
+            :list: Transcriptions from the ``AI_100`` or ``PyTorch`` runtime.
+        """
+        # AI_100 runtime
+        if runtime_ai100:
+            if not isinstance(self.qpc_path, Path):
+                raise TypeError("Please run compile API first!")
+
+            return self.cloud_ai_100_feature_generate(processor, inputs=inputs, device_ids=device_ids)
+        # PyTorch runtime
+        else:
+            return self.pytorch_feature_generate(processor, model=self.model, inputs=inputs)
+
+    def cloud_ai_100_feature_generate(
+        self,
+        processor,
+        inputs: torch.Tensor,
+        device_ids: List[int] = [0],
+    ) -> np.ndarray:
+        """
+        Generates transcriptions for the given audio input using the AI 100 runtime.
+
+        ``Mandatory`` Args:
+            :inputs (Union[torch.Tensor, np.ndarray]): inputs to run the execution.
+            :processor (AutoProcessor): The Processor to use for encoding the waveform.
+        ``Optional`` Args:
+            device_ids (List[int], optional): A list of device IDs to use for the session. Defaults to [0].
+
+        """
+
+        if self.qpc_session is None:
+            self.qpc_session = QAICInferenceSession(str(self.qpc_path), device_ids)
+            self.batch_size = self.qpc_session.bindings[0].dims[0]
+
+        # Dynamically switch to the closest compiled seq_len based on the processed input length
+        inputs = processor(inputs, return_tensors="pt")
+        input_ids_len = inputs["input_values"].shape[-1]
+
+        for allowed_shape in self.qpc_session.allowed_shapes:
+            seq_len_allowed = allowed_shape[1][1][1]
+
+            if seq_len_allowed >= input_ids_len:
+                self.seq_len = seq_len_allowed
+                break
+
+        # Handle the single-seq_len case, where allowed shapes cannot be fetched
+        self.seq_len = self.qpc_session.bindings[0].dims[1] if not hasattr(self, "seq_len") else self.seq_len
+        input_values = np.array(
+            torch.nn.functional.pad(inputs["input_values"], (0, self.seq_len - input_ids_len), "constant", 0)
+        )
+        inputs = dict(input_values=input_values)
+        outputs = self.qpc_session.run(inputs)
+        logits = outputs["logits"]
+        predicted_ids = np.argmax(logits, axis=-1)
+        transcriptions = processor.batch_decode(torch.tensor(predicted_ids))
+        return transcriptions
+
+    def pytorch_feature_generate(self, processor, model, inputs: Union[torch.Tensor, np.ndarray]) -> List[torch.Tensor]:
+        """
+        Generates transcriptions for the given audio input using the PyTorch model.
+
+        ``Mandatory`` Args:
+            :model: The transformed PyTorch model used for generating transcriptions.
+            :inputs (Union[torch.Tensor, np.ndarray]): inputs to run the execution.
+            :processor (AutoProcessor): The Processor to use for encoding the waveform.
+
+        """
+        input_values = processor(
+            inputs[0], return_tensors="pt", max_length=self.seq_len, truncation=True, padding="max_length"
+        ).input_values
+        logits = model(input_values[0]).logits
+        logits = logits.detach().numpy()
+        predicted_ids = np.argmax(logits, axis=-1)
+        transcriptions = processor.batch_decode(predicted_ids)
+        return transcriptions
diff --git a/QEfficient/utils/constants.py b/QEfficient/utils/constants.py
index 8228b7c0e..a5ccfeab7 100644
--- a/QEfficient/utils/constants.py
+++ b/QEfficient/utils/constants.py
@@ -119,6 +119,9 @@ def get_models_dir():
 # Gemma3 Constant
 GEMMA3_MAX_POSITION_EMBEDDINGS = 32768
 
+# Wav2Vec2 Constant
+WAV2VEC2_MAX_SEQ_LEN = 480000  # 30 seconds of audio at 16 kHz sampling rate (16,000 samples/sec × 30 sec)
+
 
 class Constants:
     # Export Constants.
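Aside on the new `compile(seq_len=...)` interface and the `WAV2VEC2_MAX_SEQ_LEN` default above: a minimal sketch, under assumed bucket sizes, of compiling several sequence-length specializations so that `generate()` only pads each clip up to the smallest bucket that fits. Only `480000` comes from this patch; the 5 s and 15 s buckets are illustrative, and running this requires a Cloud AI 100 setup.

```python
from QEfficient import QEFFAutoModelForCTC

# Illustrative bucket sizes in samples at 16 kHz: 5 s, 15 s, and the 30 s maximum.
# Only 480000 (WAV2VEC2_MAX_SEQ_LEN) is taken from this patch; the others are assumptions.
seq_len_buckets = [80000, 240000, 480000]

model = QEFFAutoModelForCTC.from_pretrained("facebook/wav2vec2-base-960h")
# One specialization is generated per bucket, mirroring the list comprehension in compile().
model.compile(num_cores=16, batch_size=1, seq_len=seq_len_buckets)
```

At run time, `cloud_ai_100_feature_generate` walks `allowed_shapes` and picks the first compiled `seq_len` that covers the processed input, so short clips are not padded all the way out to 30 seconds.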
diff --git a/examples/wav2vec2_example/README.md b/examples/wav2vec2_example/README.md
new file mode 100644
index 000000000..fba8d9ad2
--- /dev/null
+++ b/examples/wav2vec2_example/README.md
@@ -0,0 +1,21 @@
+# Speech Recognition with Wav2Vec2
+This directory contains an example script showing how to use the `QEFFAutoModelForCTC` class. (For now, only Wav2Vec2 models on audio clips shorter than 30 seconds have been validated.)
+
+## Required packages:
+- `librosa==0.10.2`
+- `soundfile==0.13.1`
+
+You can install them using pip:
+```sh
+pip install librosa==0.10.2 soundfile==0.13.1
+```
+
+To run the example script after installing the packages:
+```sh
+python run_wav2vec2_inference.py
+```
+
+Expected output for the given data sample:
+```sh
+MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL
+```
\ No newline at end of file
diff --git a/examples/wav2vec2_example/run_wav2vec2_inference.py b/examples/wav2vec2_example/run_wav2vec2_inference.py
new file mode 100644
index 000000000..961aabeb8
--- /dev/null
+++ b/examples/wav2vec2_example/run_wav2vec2_inference.py
@@ -0,0 +1,24 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+from datasets import load_dataset
+from transformers import AutoProcessor
+
+from QEfficient import QEFFAutoModelForCTC
+
+base_model_name = "facebook/wav2vec2-base-960h"
+
+ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
+data = ds[0]["audio"]["array"]
+# flatten to a 1-D waveform; the processor adds the batch dimension (batch size 1)
+data = data.reshape(-1)
+sample_rate = ds[0]["audio"]["sampling_rate"]
+processor = AutoProcessor.from_pretrained(base_model_name)
+
+model = QEFFAutoModelForCTC.from_pretrained(base_model_name)
+model.compile(num_cores=16)
+print(model.generate(processor, inputs=data))
diff --git a/tests/transformers/models/test_audio_embedding_models.py b/tests/transformers/models/test_audio_embedding_models.py
new file mode 100644
index 000000000..da30c76b0
--- /dev/null
+++ b/tests/transformers/models/test_audio_embedding_models.py
@@ -0,0 +1,202 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+import os
+from typing import List, Optional
+
+import numpy as np
+import onnx
+import onnxruntime
+import pytest
+import torch
+from datasets import load_dataset
+from transformers import AutoModelForCTC, AutoProcessor
+
+from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCTC
+from QEfficient.transformers.quantizers.auto import replace_transformers_quantizers
+from QEfficient.utils import hf_download
+from QEfficient.utils._utils import create_json, load_hf_processor
+from QEfficient.utils.constants import WAV2VEC2_MAX_SEQ_LEN, QnnConstants
+from QEfficient.utils.device_utils import get_available_device_id
+
+test_models = [
+    "facebook/wav2vec2-base-960h",
+]
+
+
+def load_ctc_model(model_config):
+    """
+    Function to load a model from Hugging Face
+    --------
+
+    :model_config: Dict
+
+    :return model_hf, params
+    """
+    model_path = hf_download(
+        repo_id=model_config["model_name"],
+        ignore_patterns=["*.onnx", "*.ot", "*.md", "*.tflite", "*.pdf", "*.h5", "*.msgpack"],
+    )
+    model_hf = AutoModelForCTC.from_pretrained(
+        model_path,
+        attn_implementation="eager",
+        low_cpu_mem_usage=False,
+    )  # Run models for single layers only
+    params = sum(p.numel() for p in model_hf.parameters())
+    model_hf.eval()
+    return model_hf, params
+
+
+def run_ctc_pytorch_hf(model, processor: AutoProcessor, inputs: np.ndarray, sample_rate: int) -> List[str]:
+    """
+    Run PyTorch inference on the model
+
+    ``Mandatory`` Args:
+        :model: The PyTorch model used for generating transcripts
+        :processor: autoprocessor to process inputs and decode logits
+        :inputs (np.ndarray): inputs to run the execution.
+        :sample_rate (int): the sample rate of the audio file
+
+
+
+    Returns:
+        torch.Tensor: Logits generated by the model for the input audio.
+    """
+    seq_len = WAV2VEC2_MAX_SEQ_LEN
+
+    # prepare inputs
+    input_features = processor(
+        inputs[0], return_tensors="pt", max_length=seq_len, truncation=True, padding="max_length"
+    ).input_values
+
+    model_inputs = dict(
+        input_values=input_features,
+    )
+    outputs = torch.tensor(model(**model_inputs).logits)
+    return outputs
+
+
+def run_ctc_ort(onnx_path, config, processor: AutoProcessor, inputs: np.ndarray, sample_rate: int) -> List[str]:
+    """
+    Run onnxruntime inference on the model
+
+    ``Mandatory`` Args:
+        :onnx_path: Path to the exported ONNX model used for generating transcripts
+        :processor: autoprocessor to process inputs and decode logits
+        :inputs (np.ndarray): inputs to run the execution.
+        :sample_rate (int): sampling rate at which the input audio is stored in inputs (needed for the processor)
+
+    Returns:
+        torch.Tensor: Logits generated by the model for the input audio.
+ """ + seq_len = 480000 + + # Replace invalid index value for INT32 max to 0 using add_initializer + m = onnx.load(onnx_path, load_external_data=False) + # NOTE: OrtValue objects should be kept around until the session is run, hence this dict is required + added_initializers = {} + for node in m.graph.node: + if node.op_type == "Constant": + np_tensor = onnx.numpy_helper.to_array(node.attribute[0].t, os.path.dirname(onnx_path)) + if len(np_tensor.shape) == 0 and np_tensor.item() == 2147483647: + added_initializers[node.output[0]] = onnxruntime.OrtValue.ortvalue_from_numpy( + np.array(0, np_tensor.dtype) + ) + + session_options = onnxruntime.SessionOptions() + for name, value in added_initializers.items(): + session_options.add_initializer(name, value) + + session = onnxruntime.InferenceSession(onnx_path, session_options) + + # prepare inputs + input_features = processor( + inputs[0], return_tensors="pt", max_length=seq_len, truncation=True, padding="max_length" + ).input_values + + model_inputs = dict(input_values=(input_features).numpy()) + outputs = session.run(None, model_inputs) + logits = torch.tensor(outputs[0]) + return logits + + +def check_ctc_pytorch_vs_kv_vs_ort_vs_ai100( + model_name: str, + n_layer: int = 1, + enable_qnn: Optional[bool] = False, + qnn_config: Optional[str] = None, +): + """ + Validate the PyTorch model, the PyTorch model after ONNX model and the Cloud AI 100 model + ``Mandatory`` Args: + :model_name (str): Hugging Face Model Card name, Example: ``whisper`` + :n_layers (int): Number of layers for the Model. + """ + replace_transformers_quantizers() + model_config = {"model_name": model_name} + model_config["n_layer"] = n_layer + + model_hf, _ = load_ctc_model(model_config) + + processor = load_hf_processor(pretrained_model_name_or_path=model_name) + batch_size = 1 + + ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + data = ds[0]["audio"]["array"] + data = torch.tensor(data).unsqueeze(0).numpy() + sample_rate = ds[0]["audio"]["sampling_rate"] + pytorch_tokens = run_ctc_pytorch_hf(model_hf, processor, data, sample_rate) + predicted_ids = torch.argmax(pytorch_tokens, dim=-1) + pytorch_output = processor.batch_decode(predicted_ids) + + qeff_model = QEFFAutoModelForCTC(model_hf, pretrained_model_name_or_path=model_name) + qeff_model.export() + ort_tokens = run_ctc_ort(qeff_model.onnx_path, qeff_model.model.config, processor, data, sample_rate) + predicted_ids = torch.argmax(ort_tokens, dim=-1) + ort_output = processor.batch_decode(predicted_ids) + assert pytorch_output == ort_output, "Tokens don't match for pytorch output and ORT output." + if not get_available_device_id(): + pytest.skip("No available devices to run model on Cloud AI 100") + qeff_model.compile( + num_cores=16, + batch_size=batch_size, + enable_qnn=enable_qnn, + qnn_config=qnn_config, + ) + cloud_ai_100_output = qeff_model.generate(processor, data) + assert pytorch_output == cloud_ai_100_output, "Tokens don't match for pytorch output and Cloud AI 100 output." + assert os.path.isfile(os.path.join(os.path.dirname(qeff_model.qpc_path), "qconfig.json")) + + +@pytest.mark.on_qaic +@pytest.mark.parametrize("model_name", test_models) +def test_ctc_pytorch_vs_kv_vs_ort_vs_ai100(model_name): + """ + Test function to validate the PyTorch model, the PyTorch model the ONNX model, and the Cloud AI 100 model. 
+    ``Mandatory`` Args:
+        :model_name (str): Hugging Face Model Card name, Example: ``facebook/wav2vec2-base-960h``
+    """
+    check_ctc_pytorch_vs_kv_vs_ort_vs_ai100(model_name=model_name, n_layer=4)
+
+
+@pytest.mark.on_qaic
+@pytest.mark.qnn
+@pytest.mark.skip(reason="Wav2Vec2 is currently not supported on QNN")
+@pytest.mark.parametrize("model_name", test_models)
+def test_ctc_pytorch_vs_kv_vs_ort_vs_ai100_qnn(model_name):
+    """
+    QNN Compilation path test.
+    Test function to validate the PyTorch model, the ONNX model, and the Cloud AI 100 model.
+    ``Mandatory`` Args:
+        :model_name (str): Hugging Face Model Card name, Example: ``facebook/wav2vec2-base-960h``
+    """
+    qnn_config_json_path = os.path.join(os.getcwd(), "qnn_config.json")
+    create_json(qnn_config_json_path, QnnConstants.QNN_SAMPLE_CONFIG)
+
+    check_ctc_pytorch_vs_kv_vs_ort_vs_ai100(
+        model_name=model_name, n_layer=4, enable_qnn=True, qnn_config=qnn_config_json_path
+    )
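For reference, here is a minimal CPU-only sketch of the CTC post-processing path that `cloud_ai_100_feature_generate` and the example script rely on (pad to a fixed bucket, take logits, greedy argmax, `batch_decode`), runnable without Cloud AI 100 hardware. It assumes `facebook/wav2vec2-base-960h` from the patch; the 3-second bucket is an invented value and the silent waveform is only a stand-in, so the decoded string will be empty.

```python
import torch
from transformers import AutoModelForCTC, AutoProcessor

model_name = "facebook/wav2vec2-base-960h"
processor = AutoProcessor.from_pretrained(model_name)
model = AutoModelForCTC.from_pretrained(model_name)
model.eval()

# A 2-second silent waveform stands in for real audio (16 kHz mono float32).
waveform = torch.zeros(32000)

# Pad to a fixed bucket length, as cloud_ai_100_feature_generate does before running the qpc.
bucket_len = 48000  # illustrative 3-second bucket, not a value from the patch
inputs = processor(waveform.numpy(), sampling_rate=16000, return_tensors="pt")
pad = bucket_len - inputs["input_values"].shape[-1]
input_values = torch.nn.functional.pad(inputs["input_values"], (0, pad), "constant", 0)

with torch.no_grad():
    logits = model(input_values).logits      # (batch, frames, vocab)

predicted_ids = torch.argmax(logits, dim=-1)  # greedy CTC decode
print(processor.batch_decode(predicted_ids))  # CTC collapse and blank removal happen in the tokenizer
```

The Cloud AI 100 path in the patch performs the same pad/argmax/decode steps; the only difference is that the logits come from `QAICInferenceSession.run` on the compiled qpc instead of a local PyTorch forward pass.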