From 99c497a2ecd3f76901a10e8cf2b4062e317e1a0e Mon Sep 17 00:00:00 2001 From: Vladisalv Sovrasov Date: Tue, 10 Dec 2024 21:39:52 +0900 Subject: [PATCH 1/6] Update adapter doc and api --- .../model_api/adapters/inference_adapter.py | 76 ++++++++++++++----- .../python/model_api/adapters/onnx_adapter.py | 41 +++++++--- .../model_api/adapters/openvino_adapter.py | 27 ++++++- .../python/model_api/adapters/ovms_adapter.py | 2 +- 4 files changed, 115 insertions(+), 31 deletions(-) diff --git a/model_api/python/model_api/adapters/inference_adapter.py b/model_api/python/model_api/adapters/inference_adapter.py index efebe977..4512f601 100644 --- a/model_api/python/model_api/adapters/inference_adapter.py +++ b/model_api/python/model_api/adapters/inference_adapter.py @@ -69,7 +69,7 @@ def get_output_layers(self): """ @abstractmethod - def reshape_model(self, new_shape): + def reshape_model(self, new_shape: dict): """Reshapes the model inputs to fit the new input shape. Args: @@ -83,7 +83,7 @@ def reshape_model(self, new_shape): """ @abstractmethod - def infer_sync(self, dict_data) -> dict: + def infer_sync(self, dict_data: dict) -> dict: """Performs the synchronous model inference. The infer is a blocking method. Args: @@ -104,7 +104,7 @@ def infer_sync(self, dict_data) -> dict: """ @abstractmethod - def infer_async(self, dict_data, callback_data): + def infer_async(self, dict_data: dict, callback_data: Any): """ Performs the asynchronous model inference and sets the callback for inference completion. Also, it should @@ -122,11 +122,11 @@ def infer_async(self, dict_data, callback_data): """ @abstractmethod - def get_raw_result(self, infer_result) -> dict: + def get_raw_result(self, infer_result: dict) -> dict: """Gets raw results from the internal inference framework representation as a dict. Args: - - infer_result: framework-specific result of inference from the model + - infer_resul (dict): framework-specific result of inference from the model Returns: - raw result (dict) - model raw output in the following format: @@ -138,7 +138,7 @@ def get_raw_result(self, infer_result) -> dict: """ @abstractmethod - def is_ready(self): + def is_ready(self) -> bool: """In case of asynchronous execution checks if one can submit input data to the model for inference, or all infer requests are busy. @@ -160,29 +160,67 @@ def await_any(self): """ @abstractmethod - def get_rt_info(self, path): - """Forwards to openvino.Model.get_rt_info(path)""" + def get_rt_info(self, path: list[str]) -> Any: + """ + Returns an attribute stored in model info. + + Args: + path (list[str]): a sequence of tag names leading to the attribute. + + Returns: + Any: a value stored under corresponding tag sequence. + """ @abstractmethod def update_model_info(self, model_info: dict[str, Any]): - """Updates model with the provided model info.""" + """ + Updates model with the provided model info. Model info dict can + also contain nested dicts. + + Args: + model_info (dict[str, Any]): model info dict to write to the model. + """ @abstractmethod - def save_model(self, path: str, weights_path: str, version: str): - """Serializes model to the filesystem.""" + def save_model(self, path: str, weights_path: str | None, version: str | None): + """ + Serializes model to the filesystem. + + Args: + path (str): Path to write the resulting model. + weights_path (str | None): Optional path to save weights if they are stored separately. + version (str | None): Optional model version. + """ @abstractmethod def embed_preprocessing( self, - layout, + layout: str, resize_mode: str, - interpolation_mode, + interpolation_mode: str, target_shape: tuple[int, ...], - pad_value, + pad_value: int, dtype: type = int, - brg2rgb=False, - mean=None, - scale=None, - input_idx=0, + brg2rgb: bool = False, + mean: list[Any] | None = None, + scale: list[Any] | None = None, + input_idx: int = 0, ): - """Embeds preprocessing into the model using OpenVINO preprocessing API""" + """ + Embeds preprocessing into the model if possible with the adapter being used. + In some cases, this method would just add extra python preprocessing steps + instaed actuall of embedding it into the model representation. + + Args: + layout (str): Layout, for instance NCHW. + resize_mode (str): Resize type to use for preprocessing. + interpolation_mode (str): Resize interpolation mode. + target_shape (tuple[int, ...]): Target resize shape. + pad_value (int): Value to pad with if resize implies padding. + dtype (type, optional): Input data type for the preprocessing module. Defaults to int. + bgr2rgb (bool, optional): Defines if we need to swap R and B channels in case of image input. + Defaults to False. + mean (list[Any] | None, optional): Mean values to perform input normalization. Defaults to None. + scale (list[Any] | None, optional): Scale values to perform input normalization. Defaults to None. + input_idx (int, optional): Index of the model input to apply preprocessing to. Defaults to 0. + """ diff --git a/model_api/python/model_api/adapters/onnx_adapter.py b/model_api/python/model_api/adapters/onnx_adapter.py index fd56e2b5..90900aab 100644 --- a/model_api/python/model_api/adapters/onnx_adapter.py +++ b/model_api/python/model_api/adapters/onnx_adapter.py @@ -131,17 +131,20 @@ def get_raw_result(self, infer_result): def embed_preprocessing( self, - layout, + layout: str, resize_mode: str, - interpolation_mode, - target_shape, - pad_value, + interpolation_mode: str, + target_shape: tuple[int, ...], + pad_value: int, dtype: type = int, - brg2rgb=False, - mean=None, - scale=None, - input_idx=0, + brg2rgb: bool = False, + mean: list[Any] | None = None, + scale: list[Any] | None = None, + input_idx: int = 0, ): + """ + Adds external preprocessing steps done before ONNX model execution. + """ preproc_funcs = [np.squeeze] if resize_mode != "crop": if resize_mode == "fit_to_window_letterbox": @@ -170,13 +173,23 @@ def embed_preprocessing( ) def get_model(self): - """Return the reference to the ONNXRuntime session.""" + """Return a reference to the ONNXRuntime session.""" return self.model def reshape_model(self, new_shape): + """ "Not supported by ONNX adapter.""" raise NotImplementedError def get_rt_info(self, path): + """ + Returns an attribute stored in model info. + + Args: + path (list[str]): a sequence of tag names leading to the attribute. + + Returns: + Any: a value stored under corresponding tag sequence. + """ return get_rt_info_from_dict(self.onnx_metadata, path) def update_model_info(self, model_info: dict[str, Any]): @@ -189,7 +202,15 @@ def update_model_info(self, model_info: dict[str, Any]): else: meta.value = str(model_info[item]) - def save_model(self, path: str, weights_path: str = "", version: str = "UNSPECIFIED"): + def save_model(self, path: str, weights_path: str | None, version: str | None): + """ + Serializes model to the filesystem. + + Args: + path (str): paths to save .onnx file. + weights_path (str | None): not used by ONNX adapter. + version (str | None): not used by ONNX adapter. + """ onnx.save(self.model, path) diff --git a/model_api/python/model_api/adapters/openvino_adapter.py b/model_api/python/model_api/adapters/openvino_adapter.py index 03f973e0..69e39da2 100644 --- a/model_api/python/model_api/adapters/openvino_adapter.py +++ b/model_api/python/model_api/adapters/openvino_adapter.py @@ -333,6 +333,15 @@ def operations_by_type(self, operation_type): return layers_info def get_rt_info(self, path: list[str]) -> OVAny: + """ + Gets an attribute value from OV.model_info structure. + + Args: + path (list[str]): a suquence of tag names leading to the attribute. + + Returns: + OVAny: attribute value wrapped into OVAny object. + """ if self.is_onnx_file: return get_rt_info_from_dict(self.onnx_metadata, path) return self.model.get_rt_info(path) @@ -350,6 +359,9 @@ def embed_preprocessing( scale: list[Any] | None = None, input_idx: int = 0, ) -> None: + """ + Embeds OpenVINO PrePostProcessor module into the model. + """ ppp = PrePostProcessor(self.model) # Change the input type to the 8-bit image @@ -429,7 +441,20 @@ def update_model_info(self, model_info: dict[str, Any]): for name in model_info: self.model.set_rt_info(model_info[name], ["model_info", name]) - def save_model(self, path: str, weights_path: str = "", version: str = "UNSPECIFIED"): + def save_model(self, path: str, weights_path: str | None, version: str | None): + """ + Saves OV model as two files: .xml (architecture) and .bin (weights). + + Args: + path (str): path to save the model files (.xml and .bin). + weights_path (str, optional): Optional path to save .bin if it differs from .xml path. Defaults to None. + version (str, optional): Output IR model version (for instance, IR_V10). Defaults to None. + """ + if weights_path is None: + weights_path = "" + if version is None: + version = "UNSPECIFIED" + ov.serialize(self.get_model(), path, weights_path, version) diff --git a/model_api/python/model_api/adapters/ovms_adapter.py b/model_api/python/model_api/adapters/ovms_adapter.py index d07d479e..8530387a 100644 --- a/model_api/python/model_api/adapters/ovms_adapter.py +++ b/model_api/python/model_api/adapters/ovms_adapter.py @@ -127,7 +127,7 @@ def update_model_info(self, model_info: dict[str, Any]): msg = "OVMSAdapter does not support updating model info" raise NotImplementedError(msg) - def save_model(self, path: str, weights_path: str = "", version: str = "UNSPECIFIED"): + def save_model(self, path: str, weights_path: str | None, version: str | None): msg = "OVMSAdapter does not support saving a model" raise NotImplementedError(msg) From 1f335bfec9dfa4ae66df36c4bc180f889d37ebf0 Mon Sep 17 00:00:00 2001 From: Vladisalv Sovrasov Date: Tue, 10 Dec 2024 23:22:24 +0900 Subject: [PATCH 2/6] Update model docs --- .../model_api/adapters/inference_adapter.py | 11 ++- .../python/model_api/adapters/onnx_adapter.py | 6 +- .../model_api/adapters/openvino_adapter.py | 4 +- .../python/model_api/adapters/ovms_adapter.py | 6 +- .../python/model_api/models/image_model.py | 10 ++ model_api/python/model_api/models/model.py | 94 ++++++++++++++++--- 6 files changed, 111 insertions(+), 20 deletions(-) diff --git a/model_api/python/model_api/adapters/inference_adapter.py b/model_api/python/model_api/adapters/inference_adapter.py index 4512f601..2ba7855c 100644 --- a/model_api/python/model_api/adapters/inference_adapter.py +++ b/model_api/python/model_api/adapters/inference_adapter.py @@ -5,7 +5,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass, field -from typing import Any +from typing import Any, Callable @dataclass @@ -137,6 +137,15 @@ def get_raw_result(self, infer_result: dict) -> dict: } """ + @abstractmethod + def set_callback(self, callback_fn: Callable): + """ + Sets callback that grabs results of async inference. + + Args: + callback_fn (Callable): Callback function. + """ + @abstractmethod def is_ready(self) -> bool: """In case of asynchronous execution checks if one can submit input data diff --git a/model_api/python/model_api/adapters/onnx_adapter.py b/model_api/python/model_api/adapters/onnx_adapter.py index 90900aab..db8ac738 100644 --- a/model_api/python/model_api/adapters/onnx_adapter.py +++ b/model_api/python/model_api/adapters/onnx_adapter.py @@ -7,7 +7,7 @@ import sys from functools import partial, reduce -from typing import Any +from typing import Any, Callable import numpy as np @@ -111,7 +111,7 @@ def infer_sync(self, dict_data): def infer_async(self, dict_data, callback_data): raise NotImplementedError - def set_callback(self, callback_fn): + def set_callback(self, callback_fn: Callable): self.callback_fn = callback_fn def is_ready(self): @@ -126,7 +126,7 @@ def await_all(self): def await_any(self): pass - def get_raw_result(self, infer_result): + def get_raw_result(self, infer_result: dict): pass def embed_preprocessing( diff --git a/model_api/python/model_api/adapters/openvino_adapter.py b/model_api/python/model_api/adapters/openvino_adapter.py index 69e39da2..f45abdd0 100644 --- a/model_api/python/model_api/adapters/openvino_adapter.py +++ b/model_api/python/model_api/adapters/openvino_adapter.py @@ -7,7 +7,7 @@ import logging as log from pathlib import Path -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Callable if TYPE_CHECKING: from os import PathLike @@ -300,7 +300,7 @@ def infer_sync(self, dict_data: dict[str, ndarray]) -> dict[str, ndarray]: def infer_async(self, dict_data, callback_data) -> None: self.async_queue.start_async(dict_data, callback_data) - def set_callback(self, callback_fn): + def set_callback(self, callback_fn: Callable): self.async_queue.set_callback(callback_fn) def is_ready(self) -> bool: diff --git a/model_api/python/model_api/adapters/ovms_adapter.py b/model_api/python/model_api/adapters/ovms_adapter.py index 8530387a..7060ed19 100644 --- a/model_api/python/model_api/adapters/ovms_adapter.py +++ b/model_api/python/model_api/adapters/ovms_adapter.py @@ -4,7 +4,7 @@ # import re -from typing import Any +from typing import Any, Callable import numpy as np @@ -79,7 +79,7 @@ def infer_async(self, dict_data, callback_data): raw_result = {output_name: raw_result} self.callback_fn(raw_result, (lambda x: x, callback_data)) - def set_callback(self, callback_fn): + def set_callback(self, callback_fn: Callable): self.callback_fn = callback_fn def is_ready(self): @@ -98,7 +98,7 @@ def await_all(self): def await_any(self): pass - def get_raw_result(self, infer_result): + def get_raw_result(self, infer_result: dict): pass def embed_preprocessing( diff --git a/model_api/python/model_api/models/image_model.py b/model_api/python/model_api/models/image_model.py index 31404d62..b68ebc2f 100644 --- a/model_api/python/model_api/models/image_model.py +++ b/model_api/python/model_api/models/image_model.py @@ -146,6 +146,16 @@ def parameters(cls) -> dict[str, Any]: return parameters def get_label_name(self, label_id: int) -> str: + """ + Returns a label name by it's index. + If index is out of range, and auto-generated name is returned. + + Args: + label_id (int): label index. + + Returns: + str: label name. + """ if self.labels is None: return f"#{label_id}" if label_id >= len(self.labels): diff --git a/model_api/python/model_api/models/model.py b/model_api/python/model_api/models/model.py index 6f001454..05df1a8b 100644 --- a/model_api/python/model_api/models/model.py +++ b/model_api/python/model_api/models/model.py @@ -8,7 +8,7 @@ import logging as log import re from contextlib import contextmanager -from typing import TYPE_CHECKING, Any, NoReturn, Type +from typing import TYPE_CHECKING, Any, Callable, NoReturn, Type from model_api.adapters.inference_adapter import InferenceAdapter from model_api.adapters.onnx_adapter import ONNXRuntimeAdapter @@ -98,11 +98,26 @@ def __init__(self, inference_adapter: InferenceAdapter, configuration: dict = {} self.load() self.callback_fn = lambda _: None - def get_model(self): + def get_model(self) -> Any: + """ + Returns underlying adapter-specific model. + + Returns: + Any: Model object. + """ return self.inference_adapter.get_model() @classmethod def get_model_class(cls, name: str) -> Type: + """ + Retrieves a wrapper class by a given wrapper name. + + Args: + name (str): Wrapper name. + + Returns: + Type: Model class. + """ subclasses = [subclass for subclass in cls.get_subclasses() if subclass.__model__] if cls.__model__: subclasses.append(cls) @@ -188,6 +203,7 @@ def create_model( @classmethod def get_subclasses(cls) -> list[Any]: + """Retrieves all the subclasses of the model class given.""" all_subclasses = [] for subclass in cls.__subclasses__(): all_subclasses.append(subclass) @@ -195,7 +211,11 @@ def get_subclasses(cls) -> list[Any]: return all_subclasses @classmethod - def available_wrappers(cls): + def available_wrappers(cls) -> list[str]: + """ + Prepares a list of all discoverable wrapper names + (including custom ones inherited from the core wrappers). + """ available_classes = [cls] if cls.__model__ else [] available_classes.extend(cls.get_subclasses()) return [subclass.__model__ for subclass in available_classes if subclass.__model__] @@ -368,7 +388,7 @@ def __call__(self, inputs: ndarray): raw_result = self.infer_sync(dict_data) return self.postprocess(raw_result, input_meta) - def infer_batch(self, inputs): + def infer_batch(self, inputs: list) -> list[Any]: """Applies preprocessing, asynchronous inference, postprocessing routines to a collection of inputs. Args: @@ -402,11 +422,24 @@ def batch_infer_callback(result, id): return [completed_results[i] for i in range(len(inputs))] def load(self, force: bool = False) -> None: + """ + Prepares the model to be executed by the inference adapter. + + Args: + force (bool, optional): Forces the process even if the model is ready. Defaults to False. + """ if not self.model_loaded or force: self.model_loaded = True self.inference_adapter.load_model() - def reshape(self, new_shape): + def reshape(self, new_shape: dict): + """ + Reshapes the model inputs to fit the new input shape. + + Args: + new_shape (_type_): a dictionary with inputs names as keys and + list of new shape as values in the following format. + """ if self.model_loaded: self.logger.warning( f"{self.__model__}: the model already loaded to device, ", @@ -418,6 +451,10 @@ def reshape(self, new_shape): self.outputs = self.inference_adapter.get_output_layers() def infer_sync(self, dict_data: dict[str, ndarray]) -> dict[str, ndarray]: + """ + Performs the synchronous model inference. The infer is a blocking method. + See InferenceAdapter documentation for details. + """ if not self.model_loaded: self.raise_error( "The model is not loaded to the device. Please, create the wrapper " @@ -425,7 +462,14 @@ def infer_sync(self, dict_data: dict[str, ndarray]) -> dict[str, ndarray]: ) return self.inference_adapter.infer_sync(dict_data) - def infer_async_raw(self, dict_data, callback_data): + def infer_async_raw(self, dict_data: dict, callback_data: Any): + """ + Runs asynchronous inference on raw data skipping preprocess() call. + + Args: + dict_data (dict): data to be passed to the model + callback_data (Any): data to be passed to the callback alongside with inference results. + """ if not self.model_loaded: self.raise_error( "The model is not loaded to the device. Please, create the wrapper " @@ -433,7 +477,15 @@ def infer_async_raw(self, dict_data, callback_data): ) self.inference_adapter.infer_async(dict_data, callback_data) - def infer_async(self, input_data, user_data): + def infer_async(self, input_data: dict, user_data: Any): + """ + Runs asynchronous model inference. + + Args: + input_data (dict): Input dict containing model input name as keys and data object as values. + user_data (Any): data to be passed to the callback alongside with inference results. + """ + if not self.model_loaded: self.raise_error( "The model is not loaded to the device. Please, create the wrapper " @@ -452,23 +504,35 @@ def infer_async(self, input_data, user_data): ) @staticmethod - def process_callback(request, callback_data): + def _process_callback(request, callback_data: Any): + """ + A wrapper for async inference callback. + """ meta, get_result_fn, postprocess_fn, callback_fn, user_data = callback_data raw_result = get_result_fn(request) result = postprocess_fn(raw_result, meta) callback_fn(result, user_data) - def set_callback(self, callback_fn): + def set_callback(self, callback_fn: Callable): + """ + Sets callback that grabs results of async inference. + + Args: + callback_fn (Callable): _description_ + """ self.callback_fn = callback_fn - self.inference_adapter.set_callback(Model.process_callback) + self.inference_adapter.set_callback(Model._process_callback) def is_ready(self): + """Checks if model is ready for async inference.""" return self.inference_adapter.is_ready() def await_all(self): + """Waits for all async inference requests to be completed.""" self.inference_adapter.await_all() def await_any(self): + """Waits for model to be available for an async infer request.""" self.inference_adapter.await_any() def log_layers_info(self): @@ -484,7 +548,15 @@ def log_layers_info(self): f"precision: {metadata.precision}, layout: {metadata.layout}", ) - def save(self, path: str, weights_path: str = "", version: str = "UNSPECIFIED"): + def save(self, path: str, weights_path: str | None, version: str | None): + """ + Serializes model to the filesystem. Model format depends in the InferenceAdapter being used. + + Args: + path (str): Path to write the resulting model. + weights_path (str | None): Optional path to save weights if they are stored separately. + version (str | None): Optional model version. + """ model_info = { "model_type": self.__model__, } From b0e00d94e5fa11bd11cd9fb90a16046a14db5879 Mon Sep 17 00:00:00 2001 From: Vladisalv Sovrasov Date: Tue, 10 Dec 2024 23:31:45 +0900 Subject: [PATCH 3/6] Fix python 3.9 support --- model_api/python/model_api/adapters/inference_adapter.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/model_api/python/model_api/adapters/inference_adapter.py b/model_api/python/model_api/adapters/inference_adapter.py index 2ba7855c..61727ea5 100644 --- a/model_api/python/model_api/adapters/inference_adapter.py +++ b/model_api/python/model_api/adapters/inference_adapter.py @@ -3,6 +3,8 @@ # SPDX-License-Identifier: Apache-2.0 # +from __future__ import annotations # TODO: remove when Python3.9 support is dropped + from abc import ABC, abstractmethod from dataclasses import dataclass, field from typing import Any, Callable From 2ce5812451dd1399b253b217b584285b81d69dfd Mon Sep 17 00:00:00 2001 From: Vladisalv Sovrasov Date: Wed, 11 Dec 2024 00:37:45 +0900 Subject: [PATCH 4/6] Fix imports in ovms adapter --- model_api/python/model_api/adapters/ovms_adapter.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/model_api/python/model_api/adapters/ovms_adapter.py b/model_api/python/model_api/adapters/ovms_adapter.py index 7060ed19..64370ff6 100644 --- a/model_api/python/model_api/adapters/ovms_adapter.py +++ b/model_api/python/model_api/adapters/ovms_adapter.py @@ -3,6 +3,8 @@ # SPDX-License-Identifier: Apache-2.0 # +from __future__ import annotations # TODO: remove when Python3.9 support is dropped + import re from typing import Any, Callable From f54afc3e8506f27f97f646bffc5e1327f5210160 Mon Sep 17 00:00:00 2001 From: Vladisalv Sovrasov Date: Wed, 11 Dec 2024 00:47:36 +0900 Subject: [PATCH 5/6] Add default args to Model.save --- model_api/python/model_api/adapters/onnx_adapter.py | 2 +- model_api/python/model_api/adapters/openvino_adapter.py | 2 +- model_api/python/model_api/adapters/ovms_adapter.py | 2 +- model_api/python/model_api/models/model.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/model_api/python/model_api/adapters/onnx_adapter.py b/model_api/python/model_api/adapters/onnx_adapter.py index db8ac738..619caae1 100644 --- a/model_api/python/model_api/adapters/onnx_adapter.py +++ b/model_api/python/model_api/adapters/onnx_adapter.py @@ -202,7 +202,7 @@ def update_model_info(self, model_info: dict[str, Any]): else: meta.value = str(model_info[item]) - def save_model(self, path: str, weights_path: str | None, version: str | None): + def save_model(self, path: str, weights_path: str | None = None, version: str | None = None): """ Serializes model to the filesystem. diff --git a/model_api/python/model_api/adapters/openvino_adapter.py b/model_api/python/model_api/adapters/openvino_adapter.py index f45abdd0..81ad654b 100644 --- a/model_api/python/model_api/adapters/openvino_adapter.py +++ b/model_api/python/model_api/adapters/openvino_adapter.py @@ -441,7 +441,7 @@ def update_model_info(self, model_info: dict[str, Any]): for name in model_info: self.model.set_rt_info(model_info[name], ["model_info", name]) - def save_model(self, path: str, weights_path: str | None, version: str | None): + def save_model(self, path: str, weights_path: str | None = None, version: str | None = None): """ Saves OV model as two files: .xml (architecture) and .bin (weights). diff --git a/model_api/python/model_api/adapters/ovms_adapter.py b/model_api/python/model_api/adapters/ovms_adapter.py index 64370ff6..6dd5eef7 100644 --- a/model_api/python/model_api/adapters/ovms_adapter.py +++ b/model_api/python/model_api/adapters/ovms_adapter.py @@ -129,7 +129,7 @@ def update_model_info(self, model_info: dict[str, Any]): msg = "OVMSAdapter does not support updating model info" raise NotImplementedError(msg) - def save_model(self, path: str, weights_path: str | None, version: str | None): + def save_model(self, path: str, weights_path: str | None = None, version: str | None = None): msg = "OVMSAdapter does not support saving a model" raise NotImplementedError(msg) diff --git a/model_api/python/model_api/models/model.py b/model_api/python/model_api/models/model.py index 05df1a8b..707eee3e 100644 --- a/model_api/python/model_api/models/model.py +++ b/model_api/python/model_api/models/model.py @@ -548,7 +548,7 @@ def log_layers_info(self): f"precision: {metadata.precision}, layout: {metadata.layout}", ) - def save(self, path: str, weights_path: str | None, version: str | None): + def save(self, path: str, weights_path: str | None = None, version: str | None = None): """ Serializes model to the filesystem. Model format depends in the InferenceAdapter being used. From e8bb296fb48289c130b929f2cb3f673418deb3dc Mon Sep 17 00:00:00 2001 From: Vladislav Sovrasov Date: Thu, 12 Dec 2024 11:39:36 +0100 Subject: [PATCH 6/6] Apply suggestions Co-authored-by: Ashwin Vaidya --- model_api/python/model_api/adapters/inference_adapter.py | 2 +- model_api/python/model_api/models/model.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/model_api/python/model_api/adapters/inference_adapter.py b/model_api/python/model_api/adapters/inference_adapter.py index 61727ea5..c18b452b 100644 --- a/model_api/python/model_api/adapters/inference_adapter.py +++ b/model_api/python/model_api/adapters/inference_adapter.py @@ -128,7 +128,7 @@ def get_raw_result(self, infer_result: dict) -> dict: """Gets raw results from the internal inference framework representation as a dict. Args: - - infer_resul (dict): framework-specific result of inference from the model + - infer_result (dict): framework-specific result of inference from the model Returns: - raw result (dict) - model raw output in the following format: diff --git a/model_api/python/model_api/models/model.py b/model_api/python/model_api/models/model.py index 707eee3e..b301b98c 100644 --- a/model_api/python/model_api/models/model.py +++ b/model_api/python/model_api/models/model.py @@ -437,7 +437,7 @@ def reshape(self, new_shape: dict): Reshapes the model inputs to fit the new input shape. Args: - new_shape (_type_): a dictionary with inputs names as keys and + new_shape (dict): a dictionary with inputs names as keys and list of new shape as values in the following format. """ if self.model_loaded: