-
Notifications
You must be signed in to change notification settings - Fork 41
Open
Description
After applying the following patch to `helper_classes.py`:
diff --git a/alt_e2eshark/onnx_tests/helper_classes.py b/alt_e2eshark/onnx_tests/helper_classes.py
index af30cfe..41651c9 100644
--- a/alt_e2eshark/onnx_tests/helper_classes.py
+++ b/alt_e2eshark/onnx_tests/helper_classes.py
@@ -21,6 +21,7 @@ from e2e_testing import azutils
from e2e_testing.framework import (
ExtraOptions,
ImporterOptions,
+ RuntimeOptions,
OnnxModelInfo,
TestTensors,
)
@@ -164,10 +165,16 @@ class OnnxModelZooDownloadableModel(OnnxModelInfo):
return yaml_path
def update_extra_options(self):
+ param_str = f"parameters=model={Path(self.model).parent}" + "/model.torch_onnx_params.irpa"
# called in __init__
# default to using opset version 21 for all ONNX Model Zoo models.
self.extra_options = ExtraOptions(
- import_model_options=ImporterOptions(opset_version=21)
+ import_model_options=ImporterOptions(opset_version=21, externalize_params=True, num_elements_threshold=2),
+ compiled_inference_options=RuntimeOptions(
+ common_extra_args=(
+ param_str,
+ ),
+ )
)
def update_input_name_to_shape_map(self):
All the MobileNet models fail compilation with the following error:
SHARK-TestSuite/alt_e2eshark/test-run/mobilenetv3_small_100_Opset16_timm/model.torch_onnx.mlir:343:12: error: failed to legalize operation 'torch.aten.sum.dim_IntList' that was explicitly marked illegal
%113 = torch.operator "onnx.ReduceMean"(%111, %112) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[1,16,56,56],f32>, !torch.vtensor<[2],si64>) -> !torch.vtensor<[1,16,1,1],f32>
^
/home/vinaydev/SHARK-TestSuite/alt_e2eshark/test-run/mobilenetv3_small_100_Opset16_timm/model.torch_onnx.mlir:343:12: note: see current operation: %636 = "torch.aten.sum.dim_IntList"(%509, %635, %21, %32) : (!torch.vtensor<[1,16,56,56],f32>, !torch.list<int>, !torch.bool, !torch.none) -> !torch.vtensor<[1,16,1,1],f32>
This issue is observable in a newly created environment when setting `num_elements_threshold` to any value <= 2 in `ImporterOptions`.
The failing models pass once the threshold is increased to more than 2.
Metadata
Metadata
Assignees
Labels
No labels