7
7
8
8
import os
9
9
from typing import List , Optional
10
+
10
11
import numpy as np
11
12
import onnx
12
13
import onnxruntime
13
14
import pytest
14
15
import torch
15
16
from datasets import load_dataset
16
17
from transformers import AutoModelForCTC , AutoProcessor
18
+
17
19
from QEfficient .transformers .models .modeling_auto import QEFFAutoModelForCTC
18
20
from QEfficient .transformers .quantizers .auto import replace_transformers_quantizers
19
21
from QEfficient .utils import hf_download
@@ -147,18 +149,16 @@ def check_CTC_pytorch_vs_kv_vs_ort_vs_ai100(
147
149
data = ds [0 ]["audio" ]["array" ]
148
150
data = torch .tensor (data ).unsqueeze (0 ).numpy ()
149
151
sample_rate = ds [0 ]["audio" ]["sampling_rate" ]
150
- pytorch_tokens = run_CTC_pytorch_hf (model_hf , processor , data , sample_rate )
152
+ pytorch_tokens = run_CTC_pytorch_hf (model_hf , processor , data , sample_rate )
151
153
predicted_ids = torch .argmax (pytorch_tokens , dim = - 1 )
152
154
pytorch_output = processor .batch_decode (predicted_ids )
153
-
155
+
154
156
qeff_model = QEFFAutoModelForCTC (model_hf , pretrained_model_name_or_path = model_name )
155
157
qeff_model .export ()
156
- ort_tokens = run_CTC_ort (qeff_model .onnx_path , qeff_model .model .config , processor , data , sample_rate )
158
+ ort_tokens = run_CTC_ort (qeff_model .onnx_path , qeff_model .model .config , processor , data , sample_rate )
157
159
predicted_ids = torch .argmax (ort_tokens , dim = - 1 )
158
160
ort_output = processor .batch_decode (predicted_ids )
159
- assert (pytorch_output == ort_output ), (
160
- "Tokens don't match for pytorch output and ORT output."
161
- )
161
+ assert pytorch_output == ort_output , "Tokens don't match for pytorch output and ORT output."
162
162
if not get_available_device_id ():
163
163
pytest .skip ("No available devices to run model on Cloud AI 100" )
164
164
qeff_model .compile (
@@ -168,11 +168,8 @@ def check_CTC_pytorch_vs_kv_vs_ort_vs_ai100(
168
168
qnn_config = qnn_config ,
169
169
)
170
170
cloud_ai_100_output = qeff_model .generate (processor , data )
171
- assert (pytorch_output == cloud_ai_100_output ), (
172
- "Tokens don't match for pytorch output and Cloud AI 100 output."
173
- )
171
+ assert pytorch_output == cloud_ai_100_output , "Tokens don't match for pytorch output and Cloud AI 100 output."
174
172
assert os .path .isfile (os .path .join (os .path .dirname (qeff_model .qpc_path ), "qconfig.json" ))
175
-
176
173
177
174
178
175
@pytest .mark .on_qaic
0 commit comments