Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 27 additions & 1 deletion model_api/python/model_api/models/result_types/anomaly.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,16 @@

from __future__ import annotations

import cv2
import numpy as np

from model_api.visualizer.layout import Flatten, Layout
from model_api.visualizer.primitives import BoundingBoxes, Label, Overlay, Polygon

class AnomalyResult:
from .base import Result


class AnomalyResult(Result):
"""Results for anomaly models."""

def __init__(
Expand All @@ -19,6 +25,7 @@ def __init__(
pred_mask: np.ndarray | None = None,
pred_score: float | None = None,
) -> None:
super().__init__()
self.anomaly_map = anomaly_map
self.pred_boxes = pred_boxes
self.pred_label = pred_label
Expand All @@ -40,3 +47,22 @@ def __str__(self) -> str:
f"pred_label:{self.pred_label};"
f"pred_mask min:{pred_mask_min} max:{pred_mask_max};"
)

def _register_primitives(self) -> None:
"""Converts the result to primitives."""
anomaly_map = cv2.applyColorMap(self.anomaly_map, cv2.COLORMAP_JET)
self._add_primitive(Overlay(anomaly_map))
for box in self.pred_boxes:
self._add_primitive(BoundingBoxes(*box))
if self.pred_label is not None:
self._add_primitive(Label(self.pred_label, bg_color="red" if self.pred_label == "Anomaly" else "green"))
self._add_primitive(Label(f"Score: {self.pred_score}"))
self._add_primitive(Polygon(mask=self.pred_mask))

@property
def default_layout(self) -> Layout:
return Flatten(
Overlay,
Polygon,
Label,
)
12 changes: 12 additions & 0 deletions model_api/python/model_api/models/result_types/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
"""Base result type"""

# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from abc import ABC

from model_api.visualizer.visualize_mixin import VisualizeMixin


class Result(VisualizeMixin, ABC):
"""Base result type."""
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,16 @@

from typing import TYPE_CHECKING

from model_api.visualizer.primitives import Label

from .base import Result
from .utils import array_shape_to_str

if TYPE_CHECKING:
import numpy as np


class ClassificationResult:
class ClassificationResult(Result):
"""Results for classification models."""

def __init__(
Expand All @@ -35,3 +38,8 @@ def __str__(self) -> str:
f"{labels}, {array_shape_to_str(self.saliency_map)}, {array_shape_to_str(self.feature_vector)}, "
f"{array_shape_to_str(self.raw_scores)}"
)

def _register_primitives(self) -> None:
# TODO add saliency map
for idx, label, confidence in self.top_labels:
self._add_primitive(Label(f"Rank: {idx}, {label}: {confidence:.3f}"))
8 changes: 8 additions & 0 deletions model_api/python/model_api/visualizer/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
"""Visualizer."""

# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from .visualizer import Visualizer

__all__ = ["Visualizer"]
85 changes: 85 additions & 0 deletions model_api/python/model_api/visualizer/layout.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
"""Visualization Layout"""

# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations

from abc import ABC
from typing import TYPE_CHECKING, Type

from PIL import Image

if TYPE_CHECKING:
from model_api.visualizer.primitives import Primitive

from .visualize_mixin import VisualizeMixin


class Layout(ABC):
"""Base class for layouts."""

def _compute_on_primitive(self, primitive: Primitive, image: Image, result: VisualizeMixin) -> Image | None:
if result.has_primitive(primitive):
primitives = result.get_primitive(primitive)
for primitive in primitives:
image = primitive.compute(image)
return image
return None


class HStack(Layout):
"""Horizontal stack layout."""

def __init__(self, *args: Layout | Type[Primitive]) -> None:
self.children = args

def __call__(self, image: Image, result: VisualizeMixin) -> Image:
images: list[Image] = []
for child in self.children:
if isinstance(child, Layout):
images.append(child(image, result))
else:
_image = image.copy()
_image = self._compute_on_primitive(child, _image, result)
if _image is not None:
images.append(_image)
return self._stitch(*images)

def _stitch(self, *images: Image) -> Image:
"""Stitch images together.

Args:
images (Image): Images to stitch.

Returns:
Image: Stitched image.
"""
new_image = Image.new(
"RGB",
(
sum(image.width for image in images),
max(image.height for image in images),
),
)
x_offset = 0
for image in images:
new_image.paste(image, (x_offset, 0))
x_offset += image.width
return new_image


class VStack(Layout):
"""Vertical stack layout."""


class Flatten(Layout):
"""Put all primitives on top of each other"""

def __init__(self, *args: Type[Primitive]) -> None:
self.children = args

def __call__(self, image: Image, result: VisualizeMixin) -> Image:
_image: Image = image.copy()
for child in self.children:
_image = self._compute_on_primitive(child, _image, result)
return _image
134 changes: 134 additions & 0 deletions model_api/python/model_api/visualizer/primitives.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
"""Base class for primitives."""

# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

from abc import ABC, abstractmethod
from io import BytesIO

import cv2
import numpy as np
from PIL import Image, ImageDraw, ImageFont


class Primitive(ABC):
"""Primitive class."""

@abstractmethod
def compute(self, **kwargs) -> Image:
pass


class Label(Primitive):
"""Label primitive."""

def __init__(
self,
label: str,
fg_color: str | tuple[int, int, int] = "black",
bg_color: str | tuple[int, int, int] = "yellow",
font_path: str | None | BytesIO = None,
size: int = 16,
) -> None:
self.label = label
self.fg_color = fg_color
self.bg_color = bg_color
self.font = ImageFont.load_default(size=size) if font_path is None else ImageFont.truetype(font_path, size)

def compute(self, image: Image, overlay_on_image: bool = True, buffer_y: int = 5) -> Image:
"""Generate label image.

If overlay_on_image is True, the label will be drawn on top of the image.
Else only the label will be drawn. This is useful for collecting labels so that they can be drawn on the same
image.
"""
dummy_image = Image.new("RGB", (1, 1))
draw = ImageDraw.Draw(dummy_image)
textbox = draw.textbbox((0, 0), self.label, font=self.font)
label_image = Image.new("RGB", (textbox[2] - textbox[0], textbox[3] + buffer_y - textbox[1]), self.bg_color)
draw = ImageDraw.Draw(label_image)
draw.text((0, 0), self.label, font=self.font, fill=self.fg_color)
if overlay_on_image:
image.paste(label_image, (0, 0))
return image
return label_image

@classmethod
def overlay_labels(cls, image: Image, label_images: list[Image], buffer: int = 5) -> Image:
"""Overlay multiple label images on top of the image.

Paste the labels in a row but wrap the labels if they exceed the image width.
"""
offset_x = 0
offset_y = 0
for label_image in label_images:
image.paste(label_image, (offset_x, offset_y))
offset_x += label_image.width + buffer
if offset_x + label_image.width > image.width:
offset_x = 0
offset_y += label_image.height
return image


class Polygon(Primitive):
"""Polygon primitive."""

def __init__(
self,
points: list[tuple[int, int]] | None = None,
mask: np.ndarray | None = None,
color: str | tuple[int, int, int] = "blue",
) -> None:
self.points = self._get_points(points, mask)
self.color = color

def _get_points(self, points: list[tuple[int, int]] | None, mask: np.ndarray | None) -> list[tuple[int, int]]:
if points is not None:
return points
return self._get_points_from_mask(mask)

def _get_points_from_mask(self, mask: np.ndarray) -> list[tuple[int, int]]:
contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
_points = contours[0].squeeze().tolist()
return [tuple(point) for point in _points]

def compute(self, image: Image) -> Image:
draw = ImageDraw.Draw(image)
draw.polygon(self.points, fill=self.color)
return image


class Overlay(Primitive):
"""Overlay an image.

Useful for XAI and Anomaly Maps.
"""

def __init__(self, image: Image | np.ndarray, opacity: float = 0.4) -> None:
self.image = self._to_image(image)
self.opacity = opacity

def _to_image(self, image: Image | np.ndarray) -> Image:
if isinstance(image, Image.Image):
return image
return Image.fromarray(image)

def compute(self, image: Image) -> Image:
_image = self.image.resize(image.size)
return Image.blend(image, _image, self.opacity)


class BoundingBoxes(Primitive):
def __init__(self, x1: int, y1: int, x2: int, y2: int, color: str | tuple[int, int, int] = "blue") -> None:
self.x1 = x1
self.y1 = y1
self.x2 = x2
self.y2 = y2
self.color = color
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In multiclass detection task bbox also has a label, and it looks like bbox primitive needs to have it as well


def compute(self, image: Image) -> Image:
draw = ImageDraw.Draw(image)
draw.rectangle([self.x1, self.y1, self.x2, self.y2], fill=None, outline=self.color, width=2)
return image
73 changes: 73 additions & 0 deletions model_api/python/model_api/visualizer/visualize_mixin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
"""Mixin for visualization."""

# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from abc import ABC, abstractmethod
from typing import Type

from .layout import Layout
from .primitives import BoundingBoxes, Label, Overlay, Polygon, Primitive


class VisualizeMixin(ABC):
"""Mixin for visualization."""

def __init__(self) -> None:
self._labels = []
self._polygons = []
self._overlays = []
self._bounding_boxes = []
self._registered_primitives = False

@abstractmethod
def _register_primitives(self) -> None:
"""Convert result entities to primitives."""

@property
@abstractmethod
def default_layout(self) -> Layout:
"""Default layout."""

def _add_primitive(self, primitive: Primitive) -> None:
"""Add primitive."""
if isinstance(primitive, Label):
self._labels.append(primitive)
elif isinstance(primitive, Polygon):
self._polygons.append(primitive)
elif isinstance(primitive, Overlay):
self._overlays.append(primitive)
elif isinstance(primitive, BoundingBoxes):
self._bounding_boxes.append(primitive)

def has_primitive(self, primitive: Type[Primitive]) -> bool:
"""Check if the primitive type is registered."""
self._register_primitives_if_needed()
if primitive == Label:
return bool(self._labels)
if primitive == Polygon:
return bool(self._polygons)
if primitive == Overlay:
return bool(self._overlays)
if primitive == BoundingBoxes:
return bool(self._bounding_boxes)
return False

def get_primitive(self, primitive: Type[Primitive]) -> Primitive:
"""Get primitive."""
self._register_primitives_if_needed()
if primitive == Label:
return self._labels
if primitive == Polygon:
return self._polygons
if primitive == Overlay:
return self._overlays
if primitive == BoundingBoxes:
return self._bounding_boxes
msg = f"Primitive {primitive} not found"
raise ValueError(msg)

def _register_primitives_if_needed(self):
if not self._registered_primitives:
self._register_primitives()
self._registered_primitives = True
Loading