Skip to content

Commit 17f3f18

Browse files
Update core API docs: Adapter and Model (#242)
* Update adapter doc and api * Update model docs * Fix python 3.9 support * Fix imports in ovms adapter * Add default args to Model.save * Apply suggestions Co-authored-by: Ashwin Vaidya <[email protected]> --------- Co-authored-by: Ashwin Vaidya <[email protected]>
1 parent 2a5f66f commit 17f3f18

File tree

6 files changed

+230
-51
lines changed

6 files changed

+230
-51
lines changed

model_api/python/model_api/adapters/inference_adapter.py

Lines changed: 69 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,11 @@
33
# SPDX-License-Identifier: Apache-2.0
44
#
55

6+
from __future__ import annotations # TODO: remove when Python3.9 support is dropped
7+
68
from abc import ABC, abstractmethod
79
from dataclasses import dataclass, field
8-
from typing import Any
10+
from typing import Any, Callable
911

1012

1113
@dataclass
@@ -69,7 +71,7 @@ def get_output_layers(self):
6971
"""
7072

7173
@abstractmethod
72-
def reshape_model(self, new_shape):
74+
def reshape_model(self, new_shape: dict):
7375
"""Reshapes the model inputs to fit the new input shape.
7476
7577
Args:
@@ -83,7 +85,7 @@ def reshape_model(self, new_shape):
8385
"""
8486

8587
@abstractmethod
86-
def infer_sync(self, dict_data) -> dict:
88+
def infer_sync(self, dict_data: dict) -> dict:
8789
"""Performs the synchronous model inference. The infer is a blocking method.
8890
8991
Args:
@@ -104,7 +106,7 @@ def infer_sync(self, dict_data) -> dict:
104106
"""
105107

106108
@abstractmethod
107-
def infer_async(self, dict_data, callback_data):
109+
def infer_async(self, dict_data: dict, callback_data: Any):
108110
"""
109111
Performs the asynchronous model inference and sets
110112
the callback for inference completion. Also, it should
@@ -122,11 +124,11 @@ def infer_async(self, dict_data, callback_data):
122124
"""
123125

124126
@abstractmethod
125-
def get_raw_result(self, infer_result) -> dict:
127+
def get_raw_result(self, infer_result: dict) -> dict:
126128
"""Gets raw results from the internal inference framework representation as a dict.
127129
128130
Args:
129-
- infer_result: framework-specific result of inference from the model
131+
- infer_result (dict): framework-specific result of inference from the model
130132
131133
Returns:
132134
- raw result (dict) - model raw output in the following format:
@@ -138,7 +140,16 @@ def get_raw_result(self, infer_result) -> dict:
138140
"""
139141

140142
@abstractmethod
141-
def is_ready(self):
143+
def set_callback(self, callback_fn: Callable):
144+
"""
145+
Sets callback that grabs results of async inference.
146+
147+
Args:
148+
callback_fn (Callable): Callback function.
149+
"""
150+
151+
@abstractmethod
152+
def is_ready(self) -> bool:
142153
"""In case of asynchronous execution checks if one can submit input data
143154
to the model for inference, or all infer requests are busy.
144155
@@ -160,29 +171,67 @@ def await_any(self):
160171
"""
161172

162173
@abstractmethod
163-
def get_rt_info(self, path):
164-
"""Forwards to openvino.Model.get_rt_info(path)"""
174+
def get_rt_info(self, path: list[str]) -> Any:
175+
"""
176+
Returns an attribute stored in model info.
177+
178+
Args:
179+
path (list[str]): a sequence of tag names leading to the attribute.
180+
181+
Returns:
182+
Any: a value stored under corresponding tag sequence.
183+
"""
165184

166185
@abstractmethod
167186
def update_model_info(self, model_info: dict[str, Any]):
168-
"""Updates model with the provided model info."""
187+
"""
188+
Updates model with the provided model info. Model info dict can
189+
also contain nested dicts.
190+
191+
Args:
192+
model_info (dict[str, Any]): model info dict to write to the model.
193+
"""
169194

170195
@abstractmethod
171-
def save_model(self, path: str, weights_path: str, version: str):
172-
"""Serializes model to the filesystem."""
196+
def save_model(self, path: str, weights_path: str | None, version: str | None):
197+
"""
198+
Serializes model to the filesystem.
199+
200+
Args:
201+
path (str): Path to write the resulting model.
202+
weights_path (str | None): Optional path to save weights if they are stored separately.
203+
version (str | None): Optional model version.
204+
"""
173205

174206
@abstractmethod
175207
def embed_preprocessing(
176208
self,
177-
layout,
209+
layout: str,
178210
resize_mode: str,
179-
interpolation_mode,
211+
interpolation_mode: str,
180212
target_shape: tuple[int, ...],
181-
pad_value,
213+
pad_value: int,
182214
dtype: type = int,
183-
brg2rgb=False,
184-
mean=None,
185-
scale=None,
186-
input_idx=0,
215+
brg2rgb: bool = False,
216+
mean: list[Any] | None = None,
217+
scale: list[Any] | None = None,
218+
input_idx: int = 0,
187219
):
188-
"""Embeds preprocessing into the model using OpenVINO preprocessing API"""
220+
"""
221+
Embeds preprocessing into the model if possible with the adapter being used.
222+
In some cases, this method would just add extra python preprocessing steps
223+
instaed actuall of embedding it into the model representation.
224+
225+
Args:
226+
layout (str): Layout, for instance NCHW.
227+
resize_mode (str): Resize type to use for preprocessing.
228+
interpolation_mode (str): Resize interpolation mode.
229+
target_shape (tuple[int, ...]): Target resize shape.
230+
pad_value (int): Value to pad with if resize implies padding.
231+
dtype (type, optional): Input data type for the preprocessing module. Defaults to int.
232+
bgr2rgb (bool, optional): Defines if we need to swap R and B channels in case of image input.
233+
Defaults to False.
234+
mean (list[Any] | None, optional): Mean values to perform input normalization. Defaults to None.
235+
scale (list[Any] | None, optional): Scale values to perform input normalization. Defaults to None.
236+
input_idx (int, optional): Index of the model input to apply preprocessing to. Defaults to 0.
237+
"""

model_api/python/model_api/adapters/onnx_adapter.py

Lines changed: 34 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
import sys
99
from functools import partial, reduce
10-
from typing import Any
10+
from typing import Any, Callable
1111

1212
import numpy as np
1313

@@ -111,7 +111,7 @@ def infer_sync(self, dict_data):
111111
def infer_async(self, dict_data, callback_data):
112112
raise NotImplementedError
113113

114-
def set_callback(self, callback_fn):
114+
def set_callback(self, callback_fn: Callable):
115115
self.callback_fn = callback_fn
116116

117117
def is_ready(self):
@@ -126,22 +126,25 @@ def await_all(self):
126126
def await_any(self):
127127
pass
128128

129-
def get_raw_result(self, infer_result):
129+
def get_raw_result(self, infer_result: dict):
130130
pass
131131

132132
def embed_preprocessing(
133133
self,
134-
layout,
134+
layout: str,
135135
resize_mode: str,
136-
interpolation_mode,
137-
target_shape,
138-
pad_value,
136+
interpolation_mode: str,
137+
target_shape: tuple[int, ...],
138+
pad_value: int,
139139
dtype: type = int,
140-
brg2rgb=False,
141-
mean=None,
142-
scale=None,
143-
input_idx=0,
140+
brg2rgb: bool = False,
141+
mean: list[Any] | None = None,
142+
scale: list[Any] | None = None,
143+
input_idx: int = 0,
144144
):
145+
"""
146+
Adds external preprocessing steps done before ONNX model execution.
147+
"""
145148
preproc_funcs = [np.squeeze]
146149
if resize_mode != "crop":
147150
if resize_mode == "fit_to_window_letterbox":
@@ -170,13 +173,23 @@ def embed_preprocessing(
170173
)
171174

172175
def get_model(self):
173-
"""Return the reference to the ONNXRuntime session."""
176+
"""Return a reference to the ONNXRuntime session."""
174177
return self.model
175178

176179
def reshape_model(self, new_shape):
180+
""" "Not supported by ONNX adapter."""
177181
raise NotImplementedError
178182

179183
def get_rt_info(self, path):
184+
"""
185+
Returns an attribute stored in model info.
186+
187+
Args:
188+
path (list[str]): a sequence of tag names leading to the attribute.
189+
190+
Returns:
191+
Any: a value stored under corresponding tag sequence.
192+
"""
180193
return get_rt_info_from_dict(self.onnx_metadata, path)
181194

182195
def update_model_info(self, model_info: dict[str, Any]):
@@ -189,7 +202,15 @@ def update_model_info(self, model_info: dict[str, Any]):
189202
else:
190203
meta.value = str(model_info[item])
191204

192-
def save_model(self, path: str, weights_path: str = "", version: str = "UNSPECIFIED"):
205+
def save_model(self, path: str, weights_path: str | None = None, version: str | None = None):
206+
"""
207+
Serializes model to the filesystem.
208+
209+
Args:
210+
path (str): paths to save .onnx file.
211+
weights_path (str | None): not used by ONNX adapter.
212+
version (str | None): not used by ONNX adapter.
213+
"""
193214
onnx.save(self.model, path)
194215

195216

model_api/python/model_api/adapters/openvino_adapter.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
import logging as log
99
from pathlib import Path
10-
from typing import TYPE_CHECKING, Any
10+
from typing import TYPE_CHECKING, Any, Callable
1111

1212
if TYPE_CHECKING:
1313
from os import PathLike
@@ -300,7 +300,7 @@ def infer_sync(self, dict_data: dict[str, ndarray]) -> dict[str, ndarray]:
300300
def infer_async(self, dict_data, callback_data) -> None:
301301
self.async_queue.start_async(dict_data, callback_data)
302302

303-
def set_callback(self, callback_fn):
303+
def set_callback(self, callback_fn: Callable):
304304
self.async_queue.set_callback(callback_fn)
305305

306306
def is_ready(self) -> bool:
@@ -333,6 +333,15 @@ def operations_by_type(self, operation_type):
333333
return layers_info
334334

335335
def get_rt_info(self, path: list[str]) -> OVAny:
336+
"""
337+
Gets an attribute value from OV.model_info structure.
338+
339+
Args:
340+
path (list[str]): a suquence of tag names leading to the attribute.
341+
342+
Returns:
343+
OVAny: attribute value wrapped into OVAny object.
344+
"""
336345
if self.is_onnx_file:
337346
return get_rt_info_from_dict(self.onnx_metadata, path)
338347
return self.model.get_rt_info(path)
@@ -350,6 +359,9 @@ def embed_preprocessing(
350359
scale: list[Any] | None = None,
351360
input_idx: int = 0,
352361
) -> None:
362+
"""
363+
Embeds OpenVINO PrePostProcessor module into the model.
364+
"""
353365
ppp = PrePostProcessor(self.model)
354366

355367
# Change the input type to the 8-bit image
@@ -429,7 +441,20 @@ def update_model_info(self, model_info: dict[str, Any]):
429441
for name in model_info:
430442
self.model.set_rt_info(model_info[name], ["model_info", name])
431443

432-
def save_model(self, path: str, weights_path: str = "", version: str = "UNSPECIFIED"):
444+
def save_model(self, path: str, weights_path: str | None = None, version: str | None = None):
445+
"""
446+
Saves OV model as two files: .xml (architecture) and .bin (weights).
447+
448+
Args:
449+
path (str): path to save the model files (.xml and .bin).
450+
weights_path (str, optional): Optional path to save .bin if it differs from .xml path. Defaults to None.
451+
version (str, optional): Output IR model version (for instance, IR_V10). Defaults to None.
452+
"""
453+
if weights_path is None:
454+
weights_path = ""
455+
if version is None:
456+
version = "UNSPECIFIED"
457+
433458
ov.serialize(self.get_model(), path, weights_path, version)
434459

435460

model_api/python/model_api/adapters/ovms_adapter.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,10 @@
33
# SPDX-License-Identifier: Apache-2.0
44
#
55

6+
from __future__ import annotations # TODO: remove when Python3.9 support is dropped
7+
68
import re
7-
from typing import Any
9+
from typing import Any, Callable
810

911
import numpy as np
1012

@@ -79,7 +81,7 @@ def infer_async(self, dict_data, callback_data):
7981
raw_result = {output_name: raw_result}
8082
self.callback_fn(raw_result, (lambda x: x, callback_data))
8183

82-
def set_callback(self, callback_fn):
84+
def set_callback(self, callback_fn: Callable):
8385
self.callback_fn = callback_fn
8486

8587
def is_ready(self):
@@ -98,7 +100,7 @@ def await_all(self):
98100
def await_any(self):
99101
pass
100102

101-
def get_raw_result(self, infer_result):
103+
def get_raw_result(self, infer_result: dict):
102104
pass
103105

104106
def embed_preprocessing(
@@ -127,7 +129,7 @@ def update_model_info(self, model_info: dict[str, Any]):
127129
msg = "OVMSAdapter does not support updating model info"
128130
raise NotImplementedError(msg)
129131

130-
def save_model(self, path: str, weights_path: str = "", version: str = "UNSPECIFIED"):
132+
def save_model(self, path: str, weights_path: str | None = None, version: str | None = None):
131133
msg = "OVMSAdapter does not support saving a model"
132134
raise NotImplementedError(msg)
133135

model_api/python/model_api/models/image_model.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,16 @@ def parameters(cls) -> dict[str, Any]:
146146
return parameters
147147

148148
def get_label_name(self, label_id: int) -> str:
149+
"""
150+
Returns a label name by it's index.
151+
If index is out of range, and auto-generated name is returned.
152+
153+
Args:
154+
label_id (int): label index.
155+
156+
Returns:
157+
str: label name.
158+
"""
149159
if self.labels is None:
150160
return f"#{label_id}"
151161
if label_id >= len(self.labels):

0 commit comments

Comments
 (0)