Skip to content

Commit 629c2e5

Browse files
committed
Added test for AutoModelForCTC class.
Signed-off-by: Tanisha Chawada <[email protected]>
1 parent 5e5f40b commit 629c2e5

File tree

1 file changed

+7
-10
lines changed

1 file changed

+7
-10
lines changed

tests/transformers/models/test_audio_embedding_models.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,15 @@
77

88
import os
99
from typing import List, Optional
10+
1011
import numpy as np
1112
import onnx
1213
import onnxruntime
1314
import pytest
1415
import torch
1516
from datasets import load_dataset
1617
from transformers import AutoModelForCTC, AutoProcessor
18+
1719
from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCTC
1820
from QEfficient.transformers.quantizers.auto import replace_transformers_quantizers
1921
from QEfficient.utils import hf_download
@@ -147,18 +149,16 @@ def check_CTC_pytorch_vs_kv_vs_ort_vs_ai100(
147149
data = ds[0]["audio"]["array"]
148150
data = torch.tensor(data).unsqueeze(0).numpy()
149151
sample_rate = ds[0]["audio"]["sampling_rate"]
150-
pytorch_tokens=run_CTC_pytorch_hf(model_hf, processor, data, sample_rate)
152+
pytorch_tokens = run_CTC_pytorch_hf(model_hf, processor, data, sample_rate)
151153
predicted_ids = torch.argmax(pytorch_tokens, dim=-1)
152154
pytorch_output = processor.batch_decode(predicted_ids)
153-
155+
154156
qeff_model = QEFFAutoModelForCTC(model_hf, pretrained_model_name_or_path=model_name)
155157
qeff_model.export()
156-
ort_tokens=run_CTC_ort(qeff_model.onnx_path, qeff_model.model.config, processor, data, sample_rate)
158+
ort_tokens = run_CTC_ort(qeff_model.onnx_path, qeff_model.model.config, processor, data, sample_rate)
157159
predicted_ids = torch.argmax(ort_tokens, dim=-1)
158160
ort_output = processor.batch_decode(predicted_ids)
159-
assert (pytorch_output == ort_output), (
160-
"Tokens don't match for pytorch output and ORT output."
161-
)
161+
assert pytorch_output == ort_output, "Tokens don't match for pytorch output and ORT output."
162162
if not get_available_device_id():
163163
pytest.skip("No available devices to run model on Cloud AI 100")
164164
qeff_model.compile(
@@ -168,11 +168,8 @@ def check_CTC_pytorch_vs_kv_vs_ort_vs_ai100(
168168
qnn_config=qnn_config,
169169
)
170170
cloud_ai_100_output = qeff_model.generate(processor, data)
171-
assert (pytorch_output == cloud_ai_100_output), (
172-
"Tokens don't match for pytorch output and Cloud AI 100 output."
173-
)
171+
assert pytorch_output == cloud_ai_100_output, "Tokens don't match for pytorch output and Cloud AI 100 output."
174172
assert os.path.isfile(os.path.join(os.path.dirname(qeff_model.qpc_path), "qconfig.json"))
175-
176173

177174

178175
@pytest.mark.on_qaic

0 commit comments

Comments
 (0)