33# SPDX-License-Identifier: Apache-2.0
44#
55
6+ from __future__ import annotations # TODO: remove when Python3.9 support is dropped
7+
68import logging as log
79from pathlib import Path
10+ from typing import TYPE_CHECKING , Any
11+
12+ if TYPE_CHECKING :
13+ from os import PathLike
14+
15+ from numpy import ndarray
816
917try :
1018 import openvino .runtime as ov
1119 from openvino import (
1220 AsyncInferQueue ,
1321 Core ,
1422 Dimension ,
23+ OVAny ,
1524 PartialShape ,
1625 Type ,
1726 get_version ,
3544)
3645
3746
38- def create_core ():
47+ def create_core () -> Core :
3948 if openvino_absent :
4049 msg = "The OpenVINO package is not installed"
4150 raise ImportError (msg )
@@ -45,7 +54,7 @@ def create_core():
4554 return Core ()
4655
4756
48- def parse_devices (device_string ) :
57+ def parse_devices (device_string : str ) -> list [ str ] :
4958 colon_position = device_string .find (":" )
5059 if colon_position != - 1 :
5160 device_type = device_string [:colon_position ]
@@ -57,7 +66,7 @@ def parse_devices(device_string):
5766 if parenthesis_position != - 1 :
5867 device = device [:parenthesis_position ]
5968 return devices
60- return ( device_string ,)
69+ return [ device_string ]
6170
6271
6372def parse_value_per_device (devices : set [str ], values_string : str ) -> dict [str , int ]:
@@ -82,7 +91,7 @@ def parse_value_per_device(devices: set[str], values_string: str) -> dict[str, i
8291def get_user_config (
8392 flags_d : str ,
8493 flags_nstreams : str ,
85- flags_nthreads : int ,
94+ flags_nthreads : int | None = None ,
8695) -> dict [str , str ]:
8796 config = {}
8897
@@ -111,17 +120,17 @@ class OpenvinoAdapter(InferenceAdapter):
111120
112121 def __init__ (
113122 self ,
114- core ,
115- model ,
116- weights_path = "" ,
117- model_parameters = {},
118- device = "CPU" ,
119- plugin_config = None ,
120- max_num_requests = 0 ,
121- precision = "FP16" ,
122- download_dir = None ,
123- cache_dir = None ,
124- ):
123+ core : Core ,
124+ model : str ,
125+ weights_path : PathLike | None = None ,
126+ model_parameters : dict [ str , Any ] = {},
127+ device : str = "CPU" ,
128+ plugin_config : dict [ str , Any ] | None = None ,
129+ max_num_requests : int = 0 ,
130+ precision : str = "FP16" ,
131+ download_dir : PathLike | None = None ,
132+ cache_dir : PathLike | None = None ,
133+ ) -> None :
125134 """precision, download_dir and cache_dir are ignored if model is a path to a file"""
126135 self .core = core
127136 self .model_path = model
@@ -179,7 +188,7 @@ def __init__(
179188 msg = "Model must be bytes, a file or existing OMZ model name"
180189 raise RuntimeError (msg )
181190
182- def load_model (self ):
191+ def load_model (self ) -> None :
183192 self .compiled_model = self .core .compile_model (
184193 self .model ,
185194 self .device ,
@@ -201,7 +210,7 @@ def load_model(self):
201210 )
202211 self .log_runtime_settings ()
203212
204- def log_runtime_settings (self ):
213+ def log_runtime_settings (self ) -> None :
205214 devices = set (parse_devices (self .device ))
206215 if "AUTO" not in devices :
207216 for device in devices :
@@ -222,7 +231,7 @@ def log_runtime_settings(self):
222231 pass
223232 log .info (f"\t Number of model infer requests: { len (self .async_queue )} " )
224233
225- def get_input_layers (self ):
234+ def get_input_layers (self ) -> dict [ str , Metadata ] :
226235 inputs = {}
227236 for input in self .model .inputs :
228237 input_shape = get_input_shape (input )
@@ -235,7 +244,11 @@ def get_input_layers(self):
235244 )
236245 return self ._get_meta_from_ngraph (inputs )
237246
238- def get_layout_for_input (self , input , shape = None ) -> str :
247+ def get_layout_for_input (
248+ self ,
249+ input : ov .Output ,
250+ shape : list [int ] | tuple [int , int , int , int ] | None = None ,
251+ ) -> str :
239252 input_layout = ""
240253 if self .model_parameters ["input_layouts" ]:
241254 input_layout = Layout .from_user_layouts (
@@ -251,7 +264,7 @@ def get_layout_for_input(self, input, shape=None) -> str:
251264 )
252265 return input_layout
253266
254- def get_output_layers (self ):
267+ def get_output_layers (self ) -> dict [ str , Metadata ] :
255268 outputs = {}
256269 for i , output in enumerate (self .model .outputs ):
257270 output_shape = output .partial_shape .get_min_shape () if self .model .is_dynamic () else output .shape
@@ -273,13 +286,13 @@ def reshape_model(self, new_shape):
273286 }
274287 self .model .reshape (new_shape )
275288
276- def get_raw_result (self , request ) :
289+ def get_raw_result (self , request : ov . InferRequest ) -> dict [ str , ndarray ] :
277290 return {key : request .get_tensor (key ).data for key in self .get_output_layers ()}
278291
279292 def copy_raw_result (self , request ):
280293 return {key : request .get_tensor (key ).data .copy () for key in self .get_output_layers ()}
281294
282- def infer_sync (self , dict_data ) :
295+ def infer_sync (self , dict_data : dict [ str , ndarray ]) -> dict [ str , ndarray ] :
283296 self .infer_request = self .async_queue [self .async_queue .get_idle_request_id ()]
284297 self .infer_request .infer (dict_data )
285298 return self .get_raw_result (self .infer_request )
@@ -299,7 +312,7 @@ def await_all(self) -> None:
299312 def await_any (self ) -> None :
300313 self .async_queue .get_idle_request_id ()
301314
302- def _get_meta_from_ngraph (self , layers_info ) :
315+ def _get_meta_from_ngraph (self , layers_info : dict [ str , Metadata ]) -> dict [ str , Metadata ] :
303316 for node in self .model .get_ordered_ops ():
304317 layer_name = node .get_friendly_name ()
305318 if layer_name not in layers_info :
@@ -319,24 +332,24 @@ def operations_by_type(self, operation_type):
319332 )
320333 return layers_info
321334
322- def get_rt_info (self , path ) :
335+ def get_rt_info (self , path : list [ str ]) -> OVAny :
323336 if self .is_onnx_file :
324337 return get_rt_info_from_dict (self .onnx_metadata , path )
325338 return self .model .get_rt_info (path )
326339
327340 def embed_preprocessing (
328341 self ,
329- layout ,
342+ layout : str ,
330343 resize_mode : str ,
331- interpolation_mode ,
332- target_shape : tuple [int ],
333- pad_value ,
344+ interpolation_mode : str ,
345+ target_shape : tuple [int , ... ],
346+ pad_value : int ,
334347 dtype : type = int ,
335- brg2rgb = False ,
336- mean = None ,
337- scale = None ,
338- input_idx = 0 ,
339- ):
348+ brg2rgb : bool = False ,
349+ mean : list [ Any ] | None = None ,
350+ scale : list [ Any ] | None = None ,
351+ input_idx : int = 0 ,
352+ ) -> None :
340353 ppp = PrePostProcessor (self .model )
341354
342355 # Change the input type to the 8-bit image
@@ -371,7 +384,7 @@ def embed_preprocessing(
371384 ppp .input (input_idx ).tensor ().set_shape (input_shape )
372385 ppp .input (input_idx ).preprocess ().custom (
373386 RESIZE_MODE_MAP [resize_mode ](
374- target_shape ,
387+ ( target_shape [ 0 ], target_shape [ 1 ]) ,
375388 INTERPOLATION_MODE_MAP [interpolation_mode ],
376389 pad_value ,
377390 ),
@@ -407,7 +420,7 @@ def get_model(self):
407420 return self .model
408421
409422
410- def get_input_shape (input_tensor ) :
423+ def get_input_shape (input_tensor : ov . Output ) -> list [ int ] :
411424 def string_to_tuple (string , casting_type = int ):
412425 processed = string .replace (" " , "" ).replace ("(" , "" ).replace (")" , "" ).split ("," )
413426 processed = filter (lambda x : x , processed )
@@ -428,4 +441,4 @@ def string_to_tuple(string, casting_type=int):
428441 else :
429442 shape_list .append (int (dim ))
430443 return shape_list
431- return string_to_tuple (preprocessed )
444+ return list ( string_to_tuple (preprocessed ) )
0 commit comments