Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 73 additions & 23 deletions src/model_api/models/anomaly.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,13 +66,53 @@ def __init__(
preload: bool = False,
) -> None:
super().__init__(inference_adapter, configuration, preload)
self._check_io_number(1, 1)
self._check_io_number(1, (1, 4))
self.normalization_scale: float
self.image_threshold: float
self.pixel_threshold: float
self.task: str
self.labels: list[str]

def preprocess(self, inputs: np.ndarray) -> list[dict]:
    """Prepare an image for an Anomalib model.

    Anomalib models typically expect float32 inputs in the [0, 1] range, so
    unless the model has embedded preprocessing the image is cast to float32
    (uint8 images are additionally divided by 255) and passed through
    ``_change_layout``. ``InputTransform`` is deliberately not applied here,
    since it might apply the wrong normalization.

    Args:
        inputs: the input image; for dynamic-shape models it is unpacked
            as (height, width, channels).

    Returns:
        A two-element list: the dict mapping ``image_blob_name`` to the
        processed image, and the metadata dict with ``original_shape`` and
        ``resized_shape`` (width, height, channels).
    """
    source_shape = inputs.shape

    if self._is_dynamic:
        # Dynamic models take the image at its own size.
        height, width, channels = inputs.shape
        target_shape = (width, height, channels)
        passthrough = False
    else:
        target_shape = (self.w, self.h, self.c)
        # Fixed-shape models with embedded preprocessing only need a batch axis.
        passthrough = self.embedded_processing

    if passthrough:
        blob = inputs[None]
    else:
        blob = inputs.astype(np.float32)
        if inputs.dtype == np.uint8:
            # Map 8-bit pixel values into [0, 1].
            blob = blob / 255.0
        blob = self._change_layout(blob)

    meta = {
        "original_shape": source_shape,
        "resized_shape": target_shape,
    }
    return [{self.image_blob_name: blob}, meta]

def postprocess(self, outputs: dict[str, np.ndarray], meta: dict[str, Any]) -> AnomalyResult:
"""Post-processes the outputs and returns the results.

Expand All @@ -87,39 +127,49 @@ def postprocess(self, outputs: dict[str, np.ndarray], meta: dict[str, Any]) -> A
pred_label: str | None = None
pred_mask: np.ndarray | None = None
pred_boxes: np.ndarray | None = None
predictions = outputs[next(iter(self.outputs))]

if len(predictions.shape) == 1:
pred_score = predictions
else:
anomaly_map = predictions.squeeze()
pred_score = anomaly_map.reshape(-1).max()
anomalib_keys = ["pred_score", "pred_label", "pred_mask", "anomaly_map"]
if not all(key in outputs for key in anomalib_keys):
predictions = outputs[next(iter(self.outputs))]

pred_label = self.labels[1] if pred_score > self.image_threshold else self.labels[0]
if len(predictions.shape) == 1:
npred_score = predictions
else:
anomaly_map = predictions.squeeze()
npred_score = anomaly_map.reshape(-1).max()

assert anomaly_map is not None
pred_mask = (anomaly_map >= self.pixel_threshold).astype(np.uint8)
anomaly_map = self._normalize(anomaly_map, self.pixel_threshold)
anomaly_map *= 255
anomaly_map = np.round(anomaly_map).astype(np.uint8)
pred_mask = cv2.resize(
pred_mask,
(meta["original_shape"][1], meta["original_shape"][0]),
)
pred_label = self.labels[1] if npred_score > self.image_threshold else self.labels[0]

# normalize
pred_score = self._normalize(pred_score, self.image_threshold)
assert anomaly_map is not None
pred_mask = (anomaly_map >= self.pixel_threshold).astype(np.uint8)
anomaly_map = self._normalize(anomaly_map, self.pixel_threshold)

if pred_label == self.labels[0]: # normal
pred_score = 1 - pred_score # Score of normal is 1 - score of anomaly
# normalize
npred_score = self._normalize(npred_score, self.image_threshold)

if pred_label == self.labels[0]: # normal
npred_score = 1 - npred_score # Score of normal is 1 - score of anomaly
pred_score = npred_score.item()
else:
pred_score = outputs["pred_score"].item()
pred_label = str(outputs["pred_label"].item())
anomaly_map = outputs["anomaly_map"].squeeze()
pred_mask = outputs["pred_mask"].squeeze().astype(np.uint8)

anomaly_map *= 255
anomaly_map = np.round(anomaly_map).astype(np.uint8)

# resize outputs
if anomaly_map is not None:
anomaly_map = cv2.resize(
anomaly_map,
(meta["original_shape"][1], meta["original_shape"][0]),
)

pred_mask = cv2.resize(
pred_mask,
(meta["original_shape"][1], meta["original_shape"][0]),
)

if self.task == "detection":
pred_boxes = self._get_boxes(pred_mask)

Expand All @@ -128,7 +178,7 @@ def postprocess(self, outputs: dict[str, np.ndarray], meta: dict[str, Any]) -> A
pred_boxes=pred_boxes,
pred_label=pred_label,
pred_mask=pred_mask,
pred_score=pred_score.item(),
pred_score=pred_score,
)

@classmethod
Expand Down
32 changes: 27 additions & 5 deletions src/model_api/models/image_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,11 @@ def __init__(self, inference_adapter: InferenceAdapter, configuration: dict = {}
self.n, self.c, self.h, self.w = self.inputs[self.image_blob_name].shape
else:
self.n, self.h, self.w, self.c = self.inputs[self.image_blob_name].shape

self._is_dynamic = False
if self.h == -1 or self.w == -1:
self._is_dynamic = True

self.resize = RESIZE_TYPES[self.resize_type]
self.input_transform = InputTransform(
self.reverse_input_channels,
Expand All @@ -83,7 +88,7 @@ def __init__(self, inference_adapter: InferenceAdapter, configuration: dict = {}
layout = self.inputs[self.image_blob_name].layout
if self.embedded_processing:
self.h, self.w = self.orig_height, self.orig_width
else:
elif not self._is_dynamic:
inference_adapter.embed_preprocessing(
layout=layout,
resize_mode=self.resize_type,
Expand Down Expand Up @@ -213,11 +218,24 @@ def preprocess(self, inputs: np.ndarray) -> list[dict]:
}
- the input metadata, which might be used in `postprocess` method
"""
if self._is_dynamic:
h, w, c = inputs.shape
resized_shape = (w, h, c)
processed_image = self.input_transform(inputs)
processed_image = self._change_layout(processed_image)
else:
resized_shape = (self.w, self.h, self.c)
if self.embedded_processing:
processed_image = inputs[None]
else:
processed_image = self.input_transform(inputs)
processed_image = self._change_layout(processed_image)

return [
{self.image_blob_name: inputs[None]},
{self.image_blob_name: processed_image},
{
"original_shape": inputs.shape,
"resized_shape": (self.w, self.h, self.c),
"resized_shape": resized_shape,
},
]

Expand All @@ -230,9 +248,13 @@ def _change_layout(self, image: np.ndarray) -> np.ndarray:
Returns:
- the image with layout aligned with the model layout
"""
h, w, c = image.shape if self._is_dynamic else (self.h, self.w, self.c)

# For fixed models, use the predefined dimensions
if self.nchw_layout:
image = image.transpose((2, 0, 1)) # HWC->CHW
image = image.reshape((1, self.c, self.h, self.w))
image = image.reshape((1, c, h, w))
else:
image = image.reshape((1, self.h, self.w, self.c))
image = image.reshape((1, h, w, c))

return image
27 changes: 22 additions & 5 deletions src/model_api/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,8 +125,7 @@ def get_model_class(cls, name: str) -> Type:
if name.lower() == subclass.__model__.lower():
return subclass
return cls.raise_error(
f"There is no model with name {name} in list: "
f"{', '.join([subclass.__model__ for subclass in subclasses])}",
f"There is no model with name {name} in list: {', '.join([subclass.__model__ for subclass in subclasses])}",
)

@classmethod
Expand Down Expand Up @@ -195,12 +194,30 @@ def create_model(
cache_dir=cache_dir,
)
if model_type is None:
model_type = inference_adapter.get_rt_info(
["model_info", "model_type"],
).astype(str)
try:
model_type = inference_adapter.get_rt_info(
["model_info", "model_type"],
).astype(str)
except RuntimeError:
model_type = cls.detect_model_type(inference_adapter)
Model = cls.get_model_class(model_type)
return Model(inference_adapter, configuration, preload)

@classmethod
def detect_model_type(cls, inference_adapter) -> str:
    """Infer the model type from the adapter's input and output layers.

    ``create_model`` falls back to this when reading ``model_type`` from the
    model's rt_info raises ``RuntimeError``.

    Args:
        inference_adapter: adapter providing ``get_input_layers()`` and
            ``get_output_layers()``. Presumably both return mappings keyed
            by layer name - verify against the adapter implementations.

    Returns:
        ``"AnomalyDetection"`` when the model has exactly one input and its
        outputs are exactly ``pred_score``, ``pred_label``, ``anomaly_map``
        and ``pred_mask``; ``"unknown"`` otherwise.
    """
    anomalib_outputs = {"pred_score", "pred_label", "anomaly_map", "pred_mask"}

    input_layers = inference_adapter.get_input_layers()
    output_layers = inference_adapter.get_output_layers()

    # Anomalib model pattern: a single input and exactly the four named outputs.
    # Set equality already implies there are four outputs, so no separate
    # length check is needed.
    if len(input_layers) == 1 and set(output_layers) == anomalib_outputs:
        return "AnomalyDetection"

    # No known pattern matched; the caller's class lookup will report this name.
    return "unknown"

@classmethod
def get_subclasses(cls) -> list[Any]:
"""Retrieves all the subclasses of the model class given."""
Expand Down