7 changes: 5 additions & 2 deletions QEfficient/__init__.py
@@ -8,15 +8,16 @@
import os
import warnings

import QEfficient.utils.model_registery # noqa: F401
from QEfficient.utils import custom_format_warning
from QEfficient.utils.logging_utils import logger

# For faster downloads via hf_transfer
# This code is put above import statements as this needs to be executed before
# hf_transfer is imported (will happen on line 15 via leading imports)
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
# Placeholder for all non-transformer models registered in QEfficient
import QEfficient.utils.model_registery # noqa: F401
from QEfficient.utils.logging_utils import logger


# custom warning for the better logging experience
warnings.formatwarning = custom_format_warning
@@ -43,6 +44,7 @@ def check_qaic_sdk():
from QEfficient.base import (
QEFFAutoModel,
QEFFAutoModelForCausalLM,
QEFFAutoModelForCTC,
QEFFAutoModelForImageTextToText,
QEFFAutoModelForSpeechSeq2Seq,
QEFFCommonLoader,
@@ -63,6 +65,7 @@ def check_qaic_sdk():
"cloud_ai_100_exec_kv",
"QEFFAutoModel",
"QEFFAutoModelForCausalLM",
"QEFFAutoModelForCTC",
"QEffAutoPeftModelForCausalLM",
"QEFFAutoModelForImageTextToText",
"QEFFAutoModelForSpeechSeq2Seq",
1 change: 1 addition & 0 deletions QEfficient/base/__init__.py
@@ -9,6 +9,7 @@
from QEfficient.transformers.models.modeling_auto import ( # noqa: F401
QEFFAutoModel,
QEFFAutoModelForCausalLM,
QEFFAutoModelForCTC,
QEFFAutoModelForImageTextToText,
QEFFAutoModelForSpeechSeq2Seq,
)
285 changes: 285 additions & 0 deletions QEfficient/transformers/models/modeling_auto.py
@@ -16,6 +16,7 @@
from transformers import (
AutoModel,
AutoModelForCausalLM,
AutoModelForCTC,
AutoModelForImageTextToText,
AutoModelForSpeechSeq2Seq,
PreTrainedTokenizer,
@@ -2123,3 +2124,287 @@ def generate(
generated_ids=generated_ids,
perf_metrics=PerfMetrics(prefill_time, decode_perf, total_perf, total_time),
)


class QEFFAutoModelForCTC(QEFFTransformersBase):
"""
The QEFFAutoModelForCTC class is designed for transformer models with a Connectionist Temporal Classification (CTC) speech-to-text head,
including Wav2Vec2 and other encoder-only speech models optimized for alignment-free transcription.
Although it is possible to initialize the class directly, we highly recommend using the ``from_pretrained`` method for initialization.

``Mandatory`` Args:
:model (nn.Module): PyTorch model

.. code-block:: python
import torchaudio
from QEfficient import QEFFAutoModelForCTC
from transformers import AutoProcessor

# Initialize the model using from_pretrained similar to transformers.AutoModelForCTC.
model=QEFFAutoModelForCTC.from_pretrained(model_name)

# Now you can directly compile the model for Cloud AI 100
model.compile(num_cores=16) # Considering you have a Cloud AI 100 SKU

# Prepare input
processor = AutoProcessor.from_pretrained(model_name)
input_audio, sample_rate = [...] # audio data loaded in via some external audio package, such as librosa or soundfile

# Convert to mono and resample the input_audio to 16 kHz if necessary
if input_audio.shape[0] > 1:
input_audio = input_audio.mean(dim=0)
if sample_rate != 16000:
resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
input_audio = resampler(input_audio)

# You can now execute the model
out = model.generate(processor, inputs=input_audio)
"""

_hf_auto_class = AutoModelForCTC
_pytorch_transforms = [CustomOpsTransform, AwqToMatmulNbitsTransform, GPTQToMatmulNbitsTransform]
_onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]

def __init__(self, model: nn.Module, **kwargs):
super().__init__(model, **kwargs)
self.model.base_model.config.use_cache = True

self.hash_params["qeff_auto_class"] = self.__class__.__name__

@classmethod
@with_replaced_quantizers
def from_pretrained(cls, pretrained_model_name_or_path, pooling=None, *args, **kwargs):
"""
This method serves as the easiest entry point into using QEfficient. The interface is designed to be similar to transformers.AutoModelForCTC.
Once the model is initialized, you can use other methods such as export, compile, and generate on the same object.

Args:
pretrained_model_name_or_path (str): The name or path of the pre-trained model.

.. code-block:: python

import torchaudio
from QEfficient import QEFFAutoModelForCTC
from transformers import AutoProcessor

# Initialize the model using from_pretrained similar to transformers.AutoModelForCTC.
model=QEFFAutoModelForCTC.from_pretrained(model_name)

# Now you can directly compile the model for Cloud AI 100
model.compile(num_cores=16) # Considering you have a Cloud AI 100 SKU

# Prepare input
processor = AutoProcessor.from_pretrained(model_name)
input_audio, sample_rate = [...] # audio data loaded in via some external audio package, such as librosa or soundfile

# Convert to mono and resample the input_audio to 16 kHz if necessary
if input_audio.shape[0] > 1:
input_audio = input_audio.mean(dim=0)
if sample_rate != 16000:
resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
input_audio = resampler(input_audio)

# You can now execute the model
out = model.generate(processor, inputs=input_audio)
"""
if kwargs.get("attn_implementation", None) not in {None, "eager"}:
logger.warning('Updating attn_implementation="eager"')

if kwargs.get("low_cpu_mem_usage", None):
logger.warning("Updating low_cpu_mem_usage=False")

kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False})

model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)

# This is to support models that should be classified into a different auto class but are loaded by transformers via this class
kv_offload = kwargs.pop("kv_offload", None)
if model.__class__.__name__ in MISCLASSIFIED_CAUSAL_LM_TO_QEFF_AUTO_CLASS_MAP:
return MISCLASSIFIED_CAUSAL_LM_TO_QEFF_AUTO_CLASS_MAP[model.__class__.__name__](
model, kv_offload=kv_offload, **kwargs
)

return cls(model, pretrained_model_name_or_path=pretrained_model_name_or_path, pooling=pooling, **kwargs)

@property
def get_model_config(self) -> dict:
return self.model.config.__dict__

def export(self, export_dir: Optional[str] = None) -> str:
"""
Exports the model to ``ONNX`` format using ``torch.onnx.export``.

``Optional`` Args:
:export_dir (str, optional): The directory path to store ONNX-graph.

Returns:
:str: Path of the generated ``ONNX`` graph.
"""
bs = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE
seq_len = constants.WAV2VEC2_MAX_SEQ_LEN

example_inputs = {
"input_values": torch.zeros((bs, seq_len), dtype=torch.float32),
}

dynamic_axes = {"input_values": {0: "batch_size", 1: "seq_len"}}

output_names = ["logits"]

return self._export(
example_inputs,
output_names,
dynamic_axes,
export_dir=export_dir,
)

def compile(
self,
onnx_path: Optional[str] = None,
compile_dir: Optional[str] = None,
*,
seq_len: Union[int, List[int]] = 480000,
batch_size: int = 1,
num_devices: int = 1,
num_cores: int = 16, # FIXME: Make this mandatory arg
mxfp6_matmul: bool = False,
**compiler_options,
) -> str:
"""
This method compiles the exported ``ONNX`` model using the Cloud AI 100 Platform SDK compiler binary found at ``/opt/qti-aic/exec/qaic-exec`` and generates a ``qpc`` package.
If the model has not been exported yet, this method will handle the export process.
Any other arguments that ``qaic-exec`` accepts can be passed as extra kwargs.

``Optional`` Args:
:onnx_path (str, optional): Path to pre-exported onnx model.
:compile_dir (str, optional): Path for saving the qpc generated.
:seq_len (Union[int, List[int]]): The input audio length (in samples) should be less than or equal to ``seq_len``. ``Defaults to 480000`` (30 seconds at 16 kHz).
:batch_size (int, optional): Batch size. ``Defaults to 1``.
:num_devices (int): Number of devices the model needs to be compiled for. Defaults to 1.
:num_cores (int): Number of cores used to compile the model.
:mxfp6_matmul (bool, optional): Whether to use ``mxfp6`` compression for weights. ``Defaults to False``.
:compiler_options (dict, optional): Additional compiler options.

For QAIC Compiler: Extra arguments for qaic-exec can be passed.
:aic_enable_depth_first (bool, optional): Enables DFS with default memory size. ``Defaults to False``.
:allow_mxint8_mdp_io (bool, optional): Allows MXINT8 compression of MDP IO traffic. ``Defaults to False.``

Params are converted to flags as below:

- aic_hw_version=ai100 -> -aic-hw-version=ai100
- aic_hw_version=ai200 -> -aic-hw-version=ai200

For QNN Compiler: Following arguments can be passed.
:enable_qnn (bool): Enables QNN Compilation.
:qnn_config (str): Path of QNN Config parameters file. Any extra parameters for QNN compilation can be passed via this file.

Returns:
:str: Path of the compiled ``qpc`` package.
"""

specializations = [
{"batch_size": batch_size, "seq_len": sl} for sl in (seq_len if isinstance(seq_len, list) else [seq_len])
]

return self._compile(
onnx_path=onnx_path,
compile_dir=compile_dir,
compile_only=True,
specializations=specializations,
convert_to_fp16=True,
mxfp6_matmul=mxfp6_matmul,
mdp_ts_num_devices=num_devices,
aic_num_cores=num_cores,
**compiler_options,
)

def generate(
self,
processor,
inputs: torch.Tensor,
device_ids: List[int] = None,
runtime_ai100: bool = True,
) -> Union[torch.Tensor, np.ndarray]:
"""
This method generates output by executing PyTorch runtime or the compiled ``qpc`` on ``Cloud AI 100`` Hardware cards.
``Mandatory`` Args:
:inputs (Union[torch.Tensor, np.ndarray]): inputs to run the execution.
:processor (AutoProcessor): The Processor to use for encoding the waveform.
``optional`` Args:
:device_ids (List[int]): Device IDs for running the qpc. Pass ``[0]`` for a single-device model or ``[0, 1, 2, 3]`` for a tensor-sliced model.
:runtime_ai100 (bool, optional): ``AI_100`` and ``PyTorch`` runtimes are supported as of now. Defaults to ``True`` for the ``AI_100`` runtime.
Returns:
:list: Transcriptions decoded from the model output, from either the ``AI_100`` or ``PyTorch`` runtime.
"""
# AI_100 runtime
if runtime_ai100:
if not isinstance(self.qpc_path, Path):
raise TypeError("Please run compile API first!")

return self.cloud_ai_100_feature_generate(processor, inputs=inputs, device_ids=device_ids)
# PyTorch runtime
else:
return self.pytorch_feature_generate(processor, model=self.model, inputs=inputs)

def cloud_ai_100_feature_generate(
self,
processor,
inputs: torch.Tensor,
device_ids: List[int] = [0],
) -> np.ndarray:
"""
Generates transcriptions for the given audio input using the AI 100 runtime.

``Mandatory`` Args:
:inputs (Union[torch.Tensor, np.ndarray]): inputs to run the execution.
:processor (AutoProcessor): The Processor to use for encoding the waveform.
``Optional`` Args:
:device_ids (List[int], optional): A list of device IDs to use for the session. ``Defaults to [0]``.

"""

if self.qpc_session is None:
self.qpc_session = QAICInferenceSession(str(self.qpc_path), device_ids)
self.batch_size = self.qpc_session.bindings[0].dims[0]

# Dynamically switch to the closest allowed seq_len based on the input length
inputs = processor(inputs, return_tensors="pt")
input_ids_len = inputs["input_values"].shape[-1]

for allowed_shape in self.qpc_session.allowed_shapes:
seq_len_allowed = allowed_shape[1][1][1]

if seq_len_allowed >= input_ids_len:
self.seq_len = seq_len_allowed
break

# Fall back to the binding dimensions when the model was compiled for a single seq_len, since allowed shapes cannot be fetched in that case
self.seq_len = self.qpc_session.bindings[0].dims[1] if not hasattr(self, "seq_len") else self.seq_len
input_values = np.array(
torch.nn.functional.pad(inputs["input_values"], (0, self.seq_len - input_ids_len), "constant", 0)
)
inputs = dict(input_values=input_values)
outputs = self.qpc_session.run(inputs)
logits = outputs["logits"]
predicted_ids = np.argmax(logits, axis=-1)
transcriptions = processor.batch_decode(torch.tensor(predicted_ids))
return transcriptions

def pytorch_feature_generate(self, processor, model, inputs: Union[torch.Tensor, np.ndarray]) -> List[torch.Tensor]:
"""
Generates transcriptions for the given audio input using the PyTorch model.

``Mandatory`` Args:
:model: The transformed PyTorch model used for generating features.
:inputs (Union[torch.Tensor, np.ndarray]): inputs to run the execution.
:processor (AutoProcessor): The Processor to use for encoding the waveform.

"""
input_values = processor(
inputs[0], return_tensors="pt", max_length=self.seq_len, truncation=True, padding="max_length"
).input_values
logits = model(input_values[0]).logits
logits = logits.detach().numpy()
predicted_ids = np.argmax(logits, axis=-1)
transcriptions = processor.batch_decode(predicted_ids)
return transcriptions
3 changes: 3 additions & 0 deletions QEfficient/utils/constants.py
@@ -119,6 +119,9 @@ def get_models_dir():
# Gemma3 Constant
GEMMA3_MAX_POSITION_EMBEDDINGS = 32768

# Wav2Vec2 Constant
WAV2VEC2_MAX_SEQ_LEN = 480000 # 30 seconds of audio at 16 kHz sampling rate (16,000 samples/sec × 30 sec)


class Constants:
# Export Constants.
21 changes: 21 additions & 0 deletions examples/wav2vec2_example/README.md
@@ -0,0 +1,21 @@
# Speech Recognition with Wav2Vec2
This directory contains an example script showing how to use the QEFFAutoModelForCTC class. (For now, only Wav2Vec2 models on audio shorter than 30 seconds have been validated.)

## Required packages:
- `librosa==0.10.2`
- `soundfile==0.13.1`

You can install them using pip:
```sh
pip install librosa==0.10.2 soundfile==0.13.1
```

To run the example script after installing the packages:
```sh
python run_wav2vec2_inference.py
```

Expected output for the given data sample:
```sh
MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL
```
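## Using your own audio

A minimal sketch of adapting the flow to a local audio file with the packages listed above; the file name `my_audio.wav` is a placeholder, and only Wav2Vec2-style checkpoints on audio shorter than 30 seconds have been validated so far:

```python
import librosa
from transformers import AutoProcessor

from QEfficient import QEFFAutoModelForCTC

model_name = "facebook/wav2vec2-base-960h"
processor = AutoProcessor.from_pretrained(model_name)

model = QEFFAutoModelForCTC.from_pretrained(model_name)
model.compile(num_cores=16)

# librosa returns a 1-D float32 waveform resampled to the 16 kHz rate the model expects
audio, _ = librosa.load("my_audio.wav", sr=16000, mono=True)
print(model.generate(processor, inputs=audio))
```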
24 changes: 24 additions & 0 deletions examples/wav2vec2_example/run_wav2vec2_inference.py
@@ -0,0 +1,24 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------

from datasets import load_dataset
from transformers import AutoProcessor

from QEfficient import QEFFAutoModelForCTC

base_model_name = "facebook/wav2vec2-base-960h"

ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
data = ds[0]["audio"]["array"]
# reshape so the shape corresponds to data with batch size 1
data = data.reshape(-1)
sample_rate = ds[0]["audio"]["sampling_rate"]
processor = AutoProcessor.from_pretrained(base_model_name)

model = QEFFAutoModelForCTC.from_pretrained(base_model_name)
model.compile(num_cores=16)
print(model.generate(processor, inputs=data))