Skip to content

Commit 1222eef

Browse files
Type annotation (#229)
* Partial annotation (#226) Signed-off-by: Ashwin Vaidya <[email protected]> * Typing: core modules (#227) * Typing for OV adapter * Enable type checking option * Fix typing issue in adapter * Cover model * Cover ImageModel * Cover types * Cover utils * Add missing future imports * Update type checking imports --------- Signed-off-by: Ashwin Vaidya <[email protected]> Co-authored-by: Ashwin Vaidya <[email protected]>
1 parent 4cc192d commit 1222eef

File tree

15 files changed

+348
-259
lines changed

15 files changed

+348
-259
lines changed

model_api/python/model_api/adapters/inference_adapter.py

Lines changed: 18 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -83,7 +83,7 @@ def reshape_model(self, new_shape):
8383
"""
8484

8585
@abstractmethod
86-
def infer_sync(self, dict_data):
86+
def infer_sync(self, dict_data) -> dict:
8787
"""Performs the synchronous model inference. The infer is a blocking method.
8888
8989
Args:
@@ -121,6 +121,22 @@ def infer_async(self, dict_data, callback_data):
121121
- callback_data: the data for callback, that will be taken after the model inference is ended
122122
"""
123123

124+
@abstractmethod
125+
def get_raw_result(self, infer_result) -> dict:
126+
"""Gets raw results from the internal inference framework representation as a dict.
127+
128+
Args:
129+
- infer_result: framework-specific result of inference from the model
130+
131+
Returns:
132+
- raw result (dict) - model raw output in the following format:
133+
{
134+
'output_layer_name_1': raw_result_1,
135+
'output_layer_name_2': raw_result_2,
136+
...
137+
}
138+
"""
139+
124140
@abstractmethod
125141
def is_ready(self):
126142
"""In case of asynchronous execution checks if one can submit input data
@@ -153,7 +169,7 @@ def embed_preprocessing(
153169
layout,
154170
resize_mode: str,
155171
interpolation_mode,
156-
target_shape: tuple[int],
172+
target_shape: tuple[int, ...],
157173
pad_value,
158174
dtype: type = int,
159175
brg2rgb=False,

model_api/python/model_api/adapters/onnx_adapter.py

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,8 @@
33
# SPDX-License-Identifier: Apache-2.0
44
#
55

6+
from __future__ import annotations # TODO: remove when Python3.9 support is dropped
7+
68
import sys
79
from functools import partial, reduce
810

@@ -122,6 +124,9 @@ def await_all(self):
122124
def await_any(self):
123125
pass
124126

127+
def get_raw_result(self, infer_result):
128+
pass
129+
125130
def embed_preprocessing(
126131
self,
127132
layout,

model_api/python/model_api/adapters/openvino_adapter.py

Lines changed: 49 additions & 36 deletions
Original file line number | Diff line number | Diff line change
@@ -3,15 +3,24 @@
33
# SPDX-License-Identifier: Apache-2.0
44
#
55

6+
from __future__ import annotations # TODO: remove when Python3.9 support is dropped
7+
68
import logging as log
79
from pathlib import Path
10+
from typing import TYPE_CHECKING, Any
11+
12+
if TYPE_CHECKING:
13+
from os import PathLike
14+
15+
from numpy import ndarray
816

917
try:
1018
import openvino.runtime as ov
1119
from openvino import (
1220
AsyncInferQueue,
1321
Core,
1422
Dimension,
23+
OVAny,
1524
PartialShape,
1625
Type,
1726
get_version,
@@ -35,7 +44,7 @@
3544
)
3645

3746

38-
def create_core():
47+
def create_core() -> Core:
3948
if openvino_absent:
4049
msg = "The OpenVINO package is not installed"
4150
raise ImportError(msg)
@@ -45,7 +54,7 @@ def create_core():
4554
return Core()
4655

4756

48-
def parse_devices(device_string):
57+
def parse_devices(device_string: str) -> list[str]:
4958
colon_position = device_string.find(":")
5059
if colon_position != -1:
5160
device_type = device_string[:colon_position]
@@ -57,7 +66,7 @@ def parse_devices(device_string):
5766
if parenthesis_position != -1:
5867
device = device[:parenthesis_position]
5968
return devices
60-
return (device_string,)
69+
return [device_string]
6170

6271

6372
def parse_value_per_device(devices: set[str], values_string: str) -> dict[str, int]:
@@ -82,7 +91,7 @@ def parse_value_per_device(devices: set[str], values_string: str) -> dict[str, i
8291
def get_user_config(
8392
flags_d: str,
8493
flags_nstreams: str,
85-
flags_nthreads: int,
94+
flags_nthreads: int | None = None,
8695
) -> dict[str, str]:
8796
config = {}
8897

@@ -111,17 +120,17 @@ class OpenvinoAdapter(InferenceAdapter):
111120

112121
def __init__(
113122
self,
114-
core,
115-
model,
116-
weights_path="",
117-
model_parameters={},
118-
device="CPU",
119-
plugin_config=None,
120-
max_num_requests=0,
121-
precision="FP16",
122-
download_dir=None,
123-
cache_dir=None,
124-
):
123+
core: Core,
124+
model: str,
125+
weights_path: PathLike | None = None,
126+
model_parameters: dict[str, Any] = {},
127+
device: str = "CPU",
128+
plugin_config: dict[str, Any] | None = None,
129+
max_num_requests: int = 0,
130+
precision: str = "FP16",
131+
download_dir: PathLike | None = None,
132+
cache_dir: PathLike | None = None,
133+
) -> None:
125134
"""precision, download_dir and cache_dir are ignored if model is a path to a file"""
126135
self.core = core
127136
self.model_path = model
@@ -179,7 +188,7 @@ def __init__(
179188
msg = "Model must be bytes, a file or existing OMZ model name"
180189
raise RuntimeError(msg)
181190

182-
def load_model(self):
191+
def load_model(self) -> None:
183192
self.compiled_model = self.core.compile_model(
184193
self.model,
185194
self.device,
@@ -201,7 +210,7 @@ def load_model(self):
201210
)
202211
self.log_runtime_settings()
203212

204-
def log_runtime_settings(self):
213+
def log_runtime_settings(self) -> None:
205214
devices = set(parse_devices(self.device))
206215
if "AUTO" not in devices:
207216
for device in devices:
@@ -222,7 +231,7 @@ def log_runtime_settings(self):
222231
pass
223232
log.info(f"\tNumber of model infer requests: {len(self.async_queue)}")
224233

225-
def get_input_layers(self):
234+
def get_input_layers(self) -> dict[str, Metadata]:
226235
inputs = {}
227236
for input in self.model.inputs:
228237
input_shape = get_input_shape(input)
@@ -235,7 +244,11 @@ def get_input_layers(self):
235244
)
236245
return self._get_meta_from_ngraph(inputs)
237246

238-
def get_layout_for_input(self, input, shape=None) -> str:
247+
def get_layout_for_input(
248+
self,
249+
input: ov.Output,
250+
shape: list[int] | tuple[int, int, int, int] | None = None,
251+
) -> str:
239252
input_layout = ""
240253
if self.model_parameters["input_layouts"]:
241254
input_layout = Layout.from_user_layouts(
@@ -251,7 +264,7 @@ def get_layout_for_input(self, input, shape=None) -> str:
251264
)
252265
return input_layout
253266

254-
def get_output_layers(self):
267+
def get_output_layers(self) -> dict[str, Metadata]:
255268
outputs = {}
256269
for i, output in enumerate(self.model.outputs):
257270
output_shape = output.partial_shape.get_min_shape() if self.model.is_dynamic() else output.shape
@@ -273,13 +286,13 @@ def reshape_model(self, new_shape):
273286
}
274287
self.model.reshape(new_shape)
275288

276-
def get_raw_result(self, request):
289+
def get_raw_result(self, request: ov.InferRequest) -> dict[str, ndarray]:
277290
return {key: request.get_tensor(key).data for key in self.get_output_layers()}
278291

279292
def copy_raw_result(self, request):
280293
return {key: request.get_tensor(key).data.copy() for key in self.get_output_layers()}
281294

282-
def infer_sync(self, dict_data):
295+
def infer_sync(self, dict_data: dict[str, ndarray]) -> dict[str, ndarray]:
283296
self.infer_request = self.async_queue[self.async_queue.get_idle_request_id()]
284297
self.infer_request.infer(dict_data)
285298
return self.get_raw_result(self.infer_request)
@@ -299,7 +312,7 @@ def await_all(self) -> None:
299312
def await_any(self) -> None:
300313
self.async_queue.get_idle_request_id()
301314

302-
def _get_meta_from_ngraph(self, layers_info):
315+
def _get_meta_from_ngraph(self, layers_info: dict[str, Metadata]) -> dict[str, Metadata]:
303316
for node in self.model.get_ordered_ops():
304317
layer_name = node.get_friendly_name()
305318
if layer_name not in layers_info:
@@ -319,24 +332,24 @@ def operations_by_type(self, operation_type):
319332
)
320333
return layers_info
321334

322-
def get_rt_info(self, path):
335+
def get_rt_info(self, path: list[str]) -> OVAny:
323336
if self.is_onnx_file:
324337
return get_rt_info_from_dict(self.onnx_metadata, path)
325338
return self.model.get_rt_info(path)
326339

327340
def embed_preprocessing(
328341
self,
329-
layout,
342+
layout: str,
330343
resize_mode: str,
331-
interpolation_mode,
332-
target_shape: tuple[int],
333-
pad_value,
344+
interpolation_mode: str,
345+
target_shape: tuple[int, ...],
346+
pad_value: int,
334347
dtype: type = int,
335-
brg2rgb=False,
336-
mean=None,
337-
scale=None,
338-
input_idx=0,
339-
):
348+
brg2rgb: bool = False,
349+
mean: list[Any] | None = None,
350+
scale: list[Any] | None = None,
351+
input_idx: int = 0,
352+
) -> None:
340353
ppp = PrePostProcessor(self.model)
341354

342355
# Change the input type to the 8-bit image
@@ -371,7 +384,7 @@ def embed_preprocessing(
371384
ppp.input(input_idx).tensor().set_shape(input_shape)
372385
ppp.input(input_idx).preprocess().custom(
373386
RESIZE_MODE_MAP[resize_mode](
374-
target_shape,
387+
(target_shape[0], target_shape[1]),
375388
INTERPOLATION_MODE_MAP[interpolation_mode],
376389
pad_value,
377390
),
@@ -407,7 +420,7 @@ def get_model(self):
407420
return self.model
408421

409422

410-
def get_input_shape(input_tensor):
423+
def get_input_shape(input_tensor: ov.Output) -> list[int]:
411424
def string_to_tuple(string, casting_type=int):
412425
processed = string.replace(" ", "").replace("(", "").replace(")", "").split(",")
413426
processed = filter(lambda x: x, processed)
@@ -428,4 +441,4 @@ def string_to_tuple(string, casting_type=int):
428441
else:
429442
shape_list.append(int(dim))
430443
return shape_list
431-
return string_to_tuple(preprocessed)
444+
return list(string_to_tuple(preprocessed))

model_api/python/model_api/adapters/ovms_adapter.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -97,6 +97,9 @@ def await_all(self):
9797
def await_any(self):
9898
pass
9999

100+
def get_raw_result(self, infer_result):
101+
pass
102+
100103
def embed_preprocessing(
101104
self,
102105
layout,

0 commit comments

Comments (0)