diff --git a/src/model_api/models/anomaly.py b/src/model_api/models/anomaly.py
index 0421c96b..4e108fb1 100644
--- a/src/model_api/models/anomaly.py
+++ b/src/model_api/models/anomaly.py
@@ -66,13 +66,53 @@ def __init__(
         preload: bool = False,
     ) -> None:
         super().__init__(inference_adapter, configuration, preload)
-        self._check_io_number(1, 1)
+        self._check_io_number(1, (1, 4))
         self.normalization_scale: float
         self.image_threshold: float
        self.pixel_threshold: float
         self.task: str
         self.labels: list[str]
 
+    def preprocess(self, inputs: np.ndarray) -> list[dict]:
+        """Data preprocess method for Anomalib models.
+
+        Anomalib models typically expect float32 inputs in the [0, 1] range.
+        """
+        original_shape = inputs.shape
+
+        if self._is_dynamic:
+            h, w, c = inputs.shape
+            resized_shape = (w, h, c)
+
+            # For Anomalib models, convert to float32 and normalize to [0, 1] if needed
+            if inputs.dtype == np.uint8:
+                processed_image = inputs.astype(np.float32) / 255.0
+            else:
+                processed_image = inputs.astype(np.float32)
+
+            # Apply the layout change but skip InputTransform, which might apply the wrong normalization
+            processed_image = self._change_layout(processed_image)
+        else:
+            resized_shape = (self.w, self.h, self.c)
+            # For fixed-shape models, use standard preprocessing
+            if self.embedded_processing:
+                processed_image = inputs[None]
+            else:
+                # Convert to float32 and normalize for Anomalib
+                if inputs.dtype == np.uint8:
+                    processed_image = inputs.astype(np.float32) / 255.0
+                else:
+                    processed_image = inputs.astype(np.float32)
+                processed_image = self._change_layout(processed_image)
+
+        return [
+            {self.image_blob_name: processed_image},
+            {
+                "original_shape": original_shape,
+                "resized_shape": resized_shape,
+            },
+        ]
+
     def postprocess(self, outputs: dict[str, np.ndarray], meta: dict[str, Any]) -> AnomalyResult:
         """Post-processes the outputs and returns the results.
 
@@ -87,39 +127,49 @@ def postprocess(self, outputs: dict[str, np.ndarray], meta: dict[str, Any]) -> AnomalyResult:
         pred_label: str | None = None
         pred_mask: np.ndarray | None = None
         pred_boxes: np.ndarray | None = None
-        predictions = outputs[next(iter(self.outputs))]
-        if len(predictions.shape) == 1:
-            pred_score = predictions
-        else:
-            anomaly_map = predictions.squeeze()
-            pred_score = anomaly_map.reshape(-1).max()
+        anomalib_keys = ["pred_score", "pred_label", "pred_mask", "anomaly_map"]
+        if not all(key in outputs for key in anomalib_keys):
+            predictions = outputs[next(iter(self.outputs))]
 
-        pred_label = self.labels[1] if pred_score > self.image_threshold else self.labels[0]
+            if len(predictions.shape) == 1:
+                npred_score = predictions
+            else:
+                anomaly_map = predictions.squeeze()
+                npred_score = anomaly_map.reshape(-1).max()
 
-        assert anomaly_map is not None
-        pred_mask = (anomaly_map >= self.pixel_threshold).astype(np.uint8)
-        anomaly_map = self._normalize(anomaly_map, self.pixel_threshold)
-        anomaly_map *= 255
-        anomaly_map = np.round(anomaly_map).astype(np.uint8)
-        pred_mask = cv2.resize(
-            pred_mask,
-            (meta["original_shape"][1], meta["original_shape"][0]),
-        )
+            pred_label = self.labels[1] if npred_score > self.image_threshold else self.labels[0]
 
-        # normalize
-        pred_score = self._normalize(pred_score, self.image_threshold)
+            assert anomaly_map is not None
+            pred_mask = (anomaly_map >= self.pixel_threshold).astype(np.uint8)
+            anomaly_map = self._normalize(anomaly_map, self.pixel_threshold)
 
-        if pred_label == self.labels[0]:  # normal
-            pred_score = 1 - pred_score  # Score of normal is 1 - score of anomaly
+            # normalize
+            npred_score = self._normalize(npred_score, self.image_threshold)
+
+            if pred_label == self.labels[0]:  # normal
+                npred_score = 1 - npred_score  # Score of normal is 1 - score of anomaly
+            pred_score = npred_score.item()
+        else:
+            pred_score = outputs["pred_score"].item()
+            pred_label = str(outputs["pred_label"].item())
+            anomaly_map = outputs["anomaly_map"].squeeze()
+            pred_mask = outputs["pred_mask"].squeeze().astype(np.uint8)
+
+        anomaly_map *= 255
+        anomaly_map = np.round(anomaly_map).astype(np.uint8)
 
-        # resize outputs
         if anomaly_map is not None:
             anomaly_map = cv2.resize(
                 anomaly_map,
                 (meta["original_shape"][1], meta["original_shape"][0]),
             )
+        pred_mask = cv2.resize(
+            pred_mask,
+            (meta["original_shape"][1], meta["original_shape"][0]),
+        )
+
         if self.task == "detection":
             pred_boxes = self._get_boxes(pred_mask)
 
@@ -128,7 +178,7 @@ def postprocess(self, outputs: dict[str, np.ndarray], meta: dict[str, Any]) -> AnomalyResult:
             pred_boxes=pred_boxes,
             pred_label=pred_label,
             pred_mask=pred_mask,
-            pred_score=pred_score.item(),
+            pred_score=pred_score,
         )
 
     @classmethod
diff --git a/src/model_api/models/image_model.py b/src/model_api/models/image_model.py
index b68ebc2f..fd698e3f 100644
--- a/src/model_api/models/image_model.py
+++ b/src/model_api/models/image_model.py
@@ -73,6 +73,11 @@ def __init__(self, inference_adapter: InferenceAdapter, configuration: dict = {}
             self.n, self.c, self.h, self.w = self.inputs[self.image_blob_name].shape
         else:
             self.n, self.h, self.w, self.c = self.inputs[self.image_blob_name].shape
+
+        self._is_dynamic = False
+        if self.h == -1 or self.w == -1:
+            self._is_dynamic = True
+
         self.resize = RESIZE_TYPES[self.resize_type]
         self.input_transform = InputTransform(
             self.reverse_input_channels,
@@ -83,7 +88,7 @@ def __init__(self, inference_adapter: InferenceAdapter, configuration: dict = {}
         layout = self.inputs[self.image_blob_name].layout
         if self.embedded_processing:
             self.h, self.w = self.orig_height, self.orig_width
-        else:
+        elif not self._is_dynamic:
             inference_adapter.embed_preprocessing(
                 layout=layout,
                 resize_mode=self.resize_type,
@@ -213,11 +218,24 @@ def preprocess(self, inputs: np.ndarray) -> list[dict]:
                 }
             - the input metadata, which might be used in `postprocess` method
         """
+        if self._is_dynamic:
+            h, w, c = inputs.shape
+            resized_shape = (w, h, c)
+            processed_image = self.input_transform(inputs)
+            processed_image = self._change_layout(processed_image)
+        else:
+            resized_shape = (self.w, self.h, self.c)
+            if self.embedded_processing:
+                processed_image = inputs[None]
+            else:
+                processed_image = self.input_transform(inputs)
+                processed_image = self._change_layout(processed_image)
+
         return [
-            {self.image_blob_name: inputs[None]},
+            {self.image_blob_name: processed_image},
             {
                 "original_shape": inputs.shape,
-                "resized_shape": (self.w, self.h, self.c),
+                "resized_shape": resized_shape,
             },
         ]
 
@@ -230,9 +248,13 @@ def _change_layout(self, image: np.ndarray) -> np.ndarray:
         Returns:
             - the image with layout aligned with the model layout
         """
+        # For dynamic models, use the image's own dimensions; otherwise the configured ones
+        h, w, c = image.shape if self._is_dynamic else (self.h, self.w, self.c)
+
         if self.nchw_layout:
             image = image.transpose((2, 0, 1))  # HWC->CHW
-            image = image.reshape((1, self.c, self.h, self.w))
+            image = image.reshape((1, c, h, w))
         else:
-            image = image.reshape((1, self.h, self.w, self.c))
+            image = image.reshape((1, h, w, c))
+
         return image
diff --git a/src/model_api/models/model.py b/src/model_api/models/model.py
index 88bb13a9..95abeb0f 100644
--- a/src/model_api/models/model.py
+++ b/src/model_api/models/model.py
@@ -125,8 +125,7 @@ def get_model_class(cls, name: str) -> Type:
             if name.lower() == subclass.__model__.lower():
                 return subclass
         return cls.raise_error(
-            f"There is no model with name {name} in list: "
-            f"{', '.join([subclass.__model__ for subclass in subclasses])}",
+            f"There is no model with name {name} in list: {', '.join([subclass.__model__ for subclass in subclasses])}",
         )
 
     @classmethod
@@ -195,12 +194,30 @@ def create_model(
                 cache_dir=cache_dir,
             )
         if model_type is None:
-            model_type = inference_adapter.get_rt_info(
-                ["model_info", "model_type"],
-            ).astype(str)
+            try:
+                model_type = inference_adapter.get_rt_info(
+                    ["model_info", "model_type"],
+                ).astype(str)
+            except RuntimeError:
+                model_type = cls.detect_model_type(inference_adapter)
         Model = cls.get_model_class(model_type)
         return Model(inference_adapter, configuration, preload)
 
+    @classmethod
+    def detect_model_type(cls, inference_adapter: InferenceAdapter) -> str:
+        """Detects the model type from the available input/output layer information."""
+        input_layers = inference_adapter.get_input_layers()
+        output_layers = inference_adapter.get_output_layers()
+
+        # Check for the Anomalib model pattern: one input and the four characteristic output names
+        if len(input_layers) == 1 and len(output_layers) == 4:
+            expected_outputs = {"pred_score", "pred_label", "anomaly_map", "pred_mask"}
+            actual_outputs = set(output_layers.keys())
+            if expected_outputs == actual_outputs:
+                return "AnomalyDetection"
+
+        return "unknown"
+
     @classmethod
     def get_subclasses(cls) -> list[Any]:
         """Retrieves all the subclasses of the model class given."""
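
Usage note (not part of the diff): a minimal sketch of the new fallback path, assuming an Anomalib model exported to OpenVINO IR whose rt_info carries no "model_type" entry; the IR and image file names below are placeholders.

import cv2

from model_api.models import Model

# Without "model_type" in rt_info, create_model() catches the RuntimeError from
# get_rt_info() and falls back to Model.detect_model_type(), which maps the four
# Anomalib outputs (pred_score, pred_label, anomaly_map, pred_mask) to "AnomalyDetection".
model = Model.create_model("exported_anomalib_model.xml")

image = cv2.imread("sample.png")  # HWC BGR uint8 image
result = model(image)  # AnomalyResult with pred_score, pred_label, pred_mask, anomaly_map
print(result.pred_label, result.pred_score)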