Commit 2a5f66f
Refactor result handling of DetectionResult and InstanceSegResult (#239)
* Refactor detection handling to use DetectionResult and update related documentation
* Refactor DetectionResult class and YOLO
* Fix detection bounding box types
* Refactor segmentation handling to use RotatedSegmentationResult and update related tests
* Remove unused 'strict' parameter from detection and segmentation result classes and update label handling in tilers
* Remove unused 'strict' parameter from zip function in instance segmentation models and tilers
* update docs
* move parsers back to ssd
1 parent 32e8a5a commit 2a5f66f

File tree

17 files changed: +671 −534 lines changed

docs/source/python/models/detection_model.md

Lines changed: 13 additions & 13 deletions

````diff
@@ -12,15 +12,12 @@ A single input image of shape (H, W, 3) where H and W are the height and width o
 
 ### Outputs
 
-Detection model outputs a list of detection objects (i.e `list[Detection]`) wrapped in `DetectionResult`, each object containing the following attributes:
+Detection model outputs a `DetectionResult` object containing the following attributes:
 
-- `score` (float) - Confidence score of the object.
-- `id` (int) - Class label of the object.
-- `str_label` (str) - String label of the object.
-- `xmin` (int) - X-coordinate of the top-left corner of the bounding box.
-- `ymin` (int) - Y-coordinate of the top-left corner of the bounding box.
-- `xmax` (int) - X-coordinate of the bottom-right corner of the bounding box.
-- `ymax` (int) - Y-coordinate of the bottom-right corner of the bounding box.
+- `boxes` (np.ndarray) - Bounding boxes of the detected objects, each in (x1, y1, x2, y2) format.
+- `scores` (np.ndarray) - Confidence scores of the detected objects.
+- `labels` (np.ndarray) - Class labels of the detected objects.
+- `label_names` (list[str]) - List of class names of the detected objects.
 
 ## Example
 
@@ -34,11 +31,14 @@ model = SSD.create_model("model.xml")
 # Forward pass
 predictions = model(image)
 
-# Iterate over the segmented objects
-for pred_obj in predictions.objects:
-    pred_score = pred_obj.score
-    label_id = pred_obj.id
-    bbox = [pred_obj.xmin, pred_obj.ymin, pred_obj.xmax, pred_obj.ymax]
+# Iterate over the detection results
+for box, score, label, label_name in zip(
+    predictions.boxes,
+    predictions.scores,
+    predictions.labels,
+    predictions.label_names,
+):
+    print(f"Box: {box}, Score: {score}, Label: {label}, Label Name: {label_name}")
 ```
 
 ```{eval-rst}
````
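For orientation, a minimal sketch of the container shape the list above describes. Only the attribute names come from these docs; the class itself is a hypothetical stand-in, not the library's actual `DetectionResult` definition:

```python
# Hypothetical stand-in for DetectionResult; attribute names follow the docs
# above, everything else (class name, defaults) is assumed for illustration.
from dataclasses import dataclass, field

import numpy as np


@dataclass
class DetectionResultSketch:
    boxes: np.ndarray   # (N, 4) array, each row (x1, y1, x2, y2)
    scores: np.ndarray  # (N,) confidence per detection
    labels: np.ndarray  # (N,) integer class indices
    label_names: list[str] = field(default_factory=list)
```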

docs/source/python/models/instance_segmentation.md

Lines changed: 17 additions & 14 deletions

````diff
@@ -12,16 +12,13 @@ A single input image of shape (H, W, 3) where H and W are the height and width o
 
 ### Outputs
 
-Instance segmentation model outputs a list of segmented objects (i.e `list[SegmentedObject]`) wrapped in `InstanceSegmentationResult.segmentedObjects`, each containing the following attributes:
+Instance segmentation model outputs an `InstanceSegmentationResult` object containing the following attributes:
 
-- `mask` (numpy.ndarray) - A binary mask of the object.
-- `score` (float) - Confidence score of the object.
-- `id` (int) - Class label of the object.
-- `str_label` (str) - String label of the object.
-- `xmin` (int) - X-coordinate of the top-left corner of the bounding box.
-- `ymin` (int) - Y-coordinate of the top-left corner of the bounding box.
-- `xmax` (int) - X-coordinate of the bottom-right corner of the bounding box.
-- `ymax` (int) - Y-coordinate of the bottom-right corner of the bounding box.
+- `boxes` (np.ndarray) - Bounding boxes of the detected objects, each in (x1, y1, x2, y2) format.
+- `scores` (np.ndarray) - Confidence scores of the detected objects.
+- `masks` (np.ndarray) - Segmentation masks of the detected objects.
+- `labels` (np.ndarray) - Class labels of the detected objects.
+- `label_names` (list[str]) - List of class names of the detected objects.
 
 ## Example
 
@@ -36,11 +33,17 @@ model = MaskRCNNModel.create_model("model.xml")
 predictions = model(image)
 
 # Iterate over the segmented objects
-for pred_obj in predictions.segmentedObjects:
-    pred_mask = pred_obj.mask
-    pred_score = pred_obj.score
-    label_id = pred_obj.id
-    bbox = [pred_obj.xmin, pred_obj.ymin, pred_obj.xmax, pred_obj.ymax]
+for box, score, mask, label, label_name in zip(
+    predictions.boxes,
+    predictions.scores,
+    predictions.masks,
+    predictions.labels,
+    predictions.label_names,
+):
+    print(f"Box: {box}, Score: {score}, Label: {label}, Label Name: {label_name}")
+    cv2.imshow("Mask", mask)
+    cv2.waitKey(0)
+cv2.destroyAllWindows()
 ```
 
 ```{eval-rst}
````
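Because the result now stores parallel arrays instead of per-object instances, whole-result operations reduce to a single boolean-mask indexing. A minimal sketch, assuming only the documented array attributes; the function name and threshold are illustrative:

```python
# Keep only instances above a score threshold by indexing every per-instance
# array with the same boolean mask.
import numpy as np


def filter_by_score(boxes, scores, masks, labels, threshold=0.5):
    keep = scores > threshold
    return boxes[keep], scores[keep], masks[keep], labels[keep]


boxes = np.array([[0, 0, 10, 10], [2, 2, 4, 4]])
scores = np.array([0.9, 0.2])
masks = np.zeros((2, 16, 16), dtype=np.uint8)
labels = np.array([1, 3])
print(filter_by_score(boxes, scores, masks, labels)[1])  # [0.9]
```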

model_api/python/model_api/models/__init__.py

Lines changed: 2 additions & 6 deletions

````diff
@@ -16,13 +16,11 @@
     ClassificationResult,
     Contour,
     DetectedKeypoints,
-    Detection,
     DetectionResult,
     ImageResultWithSoftPrediction,
     InstanceSegmentationResult,
     PredictedMask,
-    SegmentedObject,
-    SegmentedObjectWithRects,
+    RotatedSegmentationResult,
     VisualPromptingResult,
     ZSLVisualPromptingResult,
 )
@@ -90,14 +88,12 @@
     "SAMImageEncoder",
     "ClassificationResult",
     "Prompt",
-    "Detection",
     "DetectionResult",
     "DetectedKeypoints",
     "classification_models",
     "detection_models",
     "segmentation_models",
-    "SegmentedObject",
-    "SegmentedObjectWithRects",
+    "RotatedSegmentationResult",
     "add_rotated_rects",
     "get_contours",
 ]
````
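The public import surface changes accordingly. A short sketch of downstream code updated for the rename, assuming the package installs as `model_api` (as the paths above suggest):

```python
# Detection, SegmentedObject and SegmentedObjectWithRects leave the package's
# __all__; downstream code imports the result containers instead.
from model_api.models import (
    DetectionResult,
    InstanceSegmentationResult,
    RotatedSegmentationResult,
)
```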

model_api/python/model_api/models/detection_model.py

Lines changed: 24 additions & 53 deletions

````diff
@@ -3,8 +3,10 @@
 # SPDX-License-Identifier: Apache-2.0
 #
 
+import numpy as np
+
 from .image_model import ImageModel
-from .result_types import Detection
+from .result_types import DetectionResult
 from .types import ListValue, NumericalValue, StringValue
 from .utils import load_labels
 
@@ -65,18 +67,15 @@ def parameters(cls):
 
         return parameters
 
-    def _resize_detections(self, detections: list[Detection], meta):
+    def _resize_detections(self, detection_result: DetectionResult, meta: dict):
         """Resizes detection bounding boxes according to initial image shape.
 
         It implements image resizing depending on the set `resize_type` (see `ImageModel` for details).
         Next, it applies bounding boxes clipping.
 
         Args:
-            detections (List[Detection]): list of detections with coordinates in normalized form
+            detection_result (DetectionResult): detection result with coordinates in normalized form
             meta (dict): the input metadata obtained from `preprocess` method
-
-        Returns:
-            - list of detections with resized and clipped coordinates to fit the initial image
         """
         input_img_height, input_img_widht = meta["original_shape"][:2]
         inverted_scale_x = input_img_widht / self.w
@@ -92,63 +91,35 @@ def _resize_detections(self, detections: list[Detection], meta):
         pad_left = (self.w - round(input_img_widht / inverted_scale_x)) // 2
         pad_top = (self.h - round(input_img_height / inverted_scale_y)) // 2
 
-        def _clamp_and_round(val, min_value, max_value):
-            return round(max(min_value, min(max_value, val)))
+        boxes = detection_result.bboxes
+        boxes[:, 0::2] = (boxes[:, 0::2] * self.w - pad_left) * inverted_scale_x
+        boxes[:, 1::2] = (boxes[:, 1::2] * self.h - pad_top) * inverted_scale_y
+        np.round(boxes, out=boxes)
+        boxes[:, 0::2] = np.clip(boxes[:, 0::2], 0, input_img_widht)
+        boxes[:, 1::2] = np.clip(boxes[:, 1::2], 0, input_img_height)
+        detection_result.bboxes = boxes.astype(np.int32)
 
-        for detection in detections:
-            detection.xmin = _clamp_and_round(
-                (detection.xmin * self.w - pad_left) * inverted_scale_x,
-                0,
-                input_img_widht,
-            )
-            detection.ymin = _clamp_and_round(
-                (detection.ymin * self.h - pad_top) * inverted_scale_y,
-                0,
-                input_img_height,
-            )
-            detection.xmax = _clamp_and_round(
-                (detection.xmax * self.w - pad_left) * inverted_scale_x,
-                0,
-                input_img_widht,
-            )
-            detection.ymax = _clamp_and_round(
-                (detection.ymax * self.h - pad_top) * inverted_scale_y,
-                0,
-                input_img_height,
-            )
-
-        return detections
-
-    def _filter_detections(self, detections: list[Detection], box_area_threshold=0.0):
+    def _filter_detections(self, detection_result: DetectionResult, box_area_threshold=0.0):
         """Filters detections by confidence threshold and box size threshold
 
         Args:
-            detections (List[Detection]): list of detections with coordinates in normalized form
+            detection_result (DetectionResult): DetectionResult object with coordinates in normalized form
             box_area_threshold (float): minimal area of the bounding box to be considered
 
         Returns:
             - list of detections with confidence above the threshold
         """
-        filtered_detections = []
-        for detection in detections:
-            if (
-                detection.score < self.confidence_threshold
-                or (detection.xmax - detection.xmin) * (detection.ymax - detection.ymin) < box_area_threshold
-            ):
-                continue
-            filtered_detections.append(detection)
-
-        return filtered_detections
-
-    def _add_label_names(self, detections: list[Detection]):
+        keep = (detection_result.get_obj_sizes() > box_area_threshold) & (
+            detection_result.scores > self.confidence_threshold
+        )
+        detection_result.bboxes = detection_result.bboxes[keep]
+        detection_result.labels = detection_result.labels[keep]
+        detection_result.scores = detection_result.scores[keep]
+
+    def _add_label_names(self, detection_result: DetectionResult) -> None:
         """Adds label names to detections if they are available
 
         Args:
-            detections (List[Detection]): list of detections with coordinates in normalized form
-
-        Returns:
-            - list of detections with label strings
+            detection_result (DetectionResult): detection result whose labels are mapped to label names
         """
-        for detection in detections:
-            detection.str_label = self.get_label_name(detection.id)
-        return detections
+        detection_result.label_names = [self.get_label_name(label_idx) for label_idx in detection_result.labels]
````
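The refactor replaces the per-box `_clamp_and_round` loop with array-wide operations on the `(x1, y1, x2, y2)` columns. A standalone sketch of the same scale, round, and clip math, with made-up sizes and no letterbox padding (all values here are illustrative, not taken from the model):

```python
# Vectorized box rescaling as in _resize_detections above, on toy values.
import numpy as np

w, h = 640, 480                    # assumed model input size
orig_w, orig_h = 1280, 720         # assumed original image size
scale_x, scale_y = orig_w / w, orig_h / h
pad_left, pad_top = 0, 0           # no letterbox padding in this toy case

boxes = np.array([[0.1, 0.2, 0.5, 0.9]])  # normalized (x1, y1, x2, y2)
boxes[:, 0::2] = (boxes[:, 0::2] * w - pad_left) * scale_x  # x1 and x2 columns
boxes[:, 1::2] = (boxes[:, 1::2] * h - pad_top) * scale_y   # y1 and y2 columns
np.round(boxes, out=boxes)
boxes[:, 0::2] = np.clip(boxes[:, 0::2], 0, orig_w)
boxes[:, 1::2] = np.clip(boxes[:, 1::2], 0, orig_h)
print(boxes.astype(np.int32))  # [[128 144 640 648]], pixels in the original image
```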

model_api/python/model_api/models/instance_segmentation.py

Lines changed: 30 additions & 32 deletions

````diff
@@ -9,7 +9,7 @@
 from model_api.adapters.inference_adapter import InferenceAdapter
 
 from .image_model import ImageModel
-from .result_types import InstanceSegmentationResult, SegmentedObject
+from .result_types import InstanceSegmentationResult
 from .types import BooleanValue, ListValue, NumericalValue, StringValue
 from .utils import load_labels
 
@@ -176,27 +176,31 @@ def postprocess(self, outputs: dict, meta: dict) -> InstanceSegmentationResult:
             out=boxes,
         )
 
-        objects = []
         has_feature_vector_name = _feature_vector_name in self.outputs
         if has_feature_vector_name:
             if not self.labels:
                 self.raise_error("Can't get number of classes because labels are empty")
             saliency_maps: list = [[] for _ in range(len(self.labels))]
         else:
             saliency_maps = []
-        for box, confidence, cls, raw_mask in zip(boxes, scores, labels, masks):
-            x1, y1, x2, y2 = box
-            if (x2 - x1) * (y2 - y1) < 1 or (confidence <= self.confidence_threshold and not has_feature_vector_name):
-                continue
 
-            # Skip if label index is out of bounds
-            if self.labels and cls >= len(self.labels):
-                continue
+        # Apply confidence threshold, bounding box area filter and label index filter.
+        keep = (scores > self.confidence_threshold) & ((boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) > 1)
+
+        if self.labels:
+            keep &= labels < len(self.labels)
+
+        boxes = boxes[keep].astype(np.int32)
+        scores = scores[keep]
+        labels = labels[keep]
+        masks = masks[keep]
 
-            # Get label string
-            str_label = self.labels[cls] if self.labels else f"#{cls}"
+        resized_masks, label_names = [], []
+        for box, label_idx, raw_mask in zip(boxes, labels, masks):
+            if self.labels:
+                label_names.append(self.labels[label_idx])
 
-            raw_cls_mask = raw_mask[cls, ...] if self.is_segmentoly else raw_mask
+            raw_cls_mask = raw_mask[label_idx, ...] if self.is_segmentoly else raw_mask
             if self.postprocess_semantic_masks or has_feature_vector_name:
                 resized_mask = _segm_postprocess(
                     box,
@@ -205,27 +209,21 @@ def postprocess(self, outputs: dict, meta: dict) -> InstanceSegmentationResult:
                 )
             else:
                 resized_mask = raw_cls_mask
-            if confidence > self.confidence_threshold:
-                output_mask = resized_mask if self.postprocess_semantic_masks else raw_cls_mask
-                xmin, ymin, xmax, ymax = box.astype(int)
-                objects.append(
-                    SegmentedObject(
-                        xmin,
-                        ymin,
-                        xmax,
-                        ymax,
-                        score=confidence,
-                        id=cls,
-                        str_label=str_label,
-                        mask=output_mask,
-                    ),
-                )
-            if has_feature_vector_name and confidence > self.confidence_threshold:
-                saliency_maps[cls - 1].append(resized_mask)
+
+            output_mask = resized_mask if self.postprocess_semantic_masks else raw_cls_mask
+            resized_masks.append(output_mask)
+            if has_feature_vector_name:
+                saliency_maps[label_idx - 1].append(resized_mask)
+
+        _masks = np.stack(resized_masks) if len(resized_masks) > 0 else np.empty((0, 16, 16), dtype=np.uint8)
         return InstanceSegmentationResult(
-            objects,
-            _average_and_normalize(saliency_maps),
-            outputs.get(_feature_vector_name, np.ndarray(0)),
+            bboxes=boxes,
+            labels=labels,
+            scores=scores,
+            masks=_masks,
+            label_names=label_names if label_names else None,
+            saliency_map=_average_and_normalize(saliency_maps),
+            feature_vector=outputs.get(_feature_vector_name, np.ndarray(0)),
         )
````
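The rewritten `postprocess` applies all per-instance filtering up front as a single boolean `keep` mask instead of `continue` statements inside the loop. A toy re-run of that expression with made-up scores, boxes, and class count:

```python
# The keep-mask filter from postprocess above, on illustrative values.
import numpy as np

scores = np.array([0.9, 0.3, 0.8])
boxes = np.array([[0, 0, 10, 10], [0, 0, 8, 8], [5, 5, 8, 8]], dtype=np.float32)
labels = np.array([0, 1, 7])
num_classes = 5  # stands in for len(self.labels)

keep = (scores > 0.5) & ((boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) > 1)
keep &= labels < num_classes  # drop out-of-range label indices
print(keep)  # [ True False False]: low score drops #1, out-of-range label drops #2
```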
model_api/python/model_api/models/keypoint_detection.py

Lines changed: 9 additions & 7 deletions

````diff
@@ -10,7 +10,7 @@
 import numpy as np
 
 from .image_model import ImageModel
-from .result_types import DetectedKeypoints, Detection
+from .result_types import DetectedKeypoints, DetectionResult
 from .types import ListValue
 
 
@@ -77,25 +77,27 @@ def __init__(self, base_model: KeypointDetectionModel) -> None:
     def predict(
         self,
         image: np.ndarray,
-        detections: list[Detection],
+        detection_result: DetectionResult,
     ) -> list[DetectedKeypoints]:
         """Predicts keypoints for the given image and detections.
 
         Args:
             image (np.ndarray): input full-size image
-            detections (list[Detection]): detections located within the given image
+            detection_result (DetectionResult): detections located within the given image
 
         Returns:
             list[DetectedKeypoints]: per detection keypoints in detection coordinates
         """
         crops = []
-        for det in detections:
-            crops.append(image[det.ymin : det.ymax, det.xmin : det.xmax])
+        for box in detection_result.bboxes:
+            x1, y1, x2, y2 = box
+            crops.append(image[y1:y2, x1:x2])
 
         crops_results = self.predict_crops(crops)
-        for i, det in enumerate(detections):
+        for i, box in enumerate(detection_result.bboxes):
+            x1, y1, x2, y2 = box
             crops_results[i] = DetectedKeypoints(
-                crops_results[i].keypoints + np.array([det.xmin, det.ymin]),
+                crops_results[i].keypoints + np.array([x1, y1]),
                 crops_results[i].scores,
             )
````
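`predict` follows the usual top-down keypoint pattern: crop each detected box from the image, run the crop-level model, then shift the crop-local keypoints back by the box origin. A self-contained sketch with a stub standing in for `predict_crops` (the stub and all values are illustrative):

```python
# Crop-then-offset pattern from predict() above, with a fake crop predictor.
import numpy as np


def fake_keypoints_in_crop(crop: np.ndarray) -> np.ndarray:
    h, w = crop.shape[:2]
    return np.array([[w / 2, h / 2]])  # one keypoint at the crop center


image = np.zeros((720, 1280, 3), dtype=np.uint8)
bboxes = np.array([[100, 200, 300, 400]])  # (x1, y1, x2, y2) in pixels

for x1, y1, x2, y2 in bboxes:
    crop = image[y1:y2, x1:x2]
    keypoints = fake_keypoints_in_crop(crop)
    print(keypoints + np.array([x1, y1]))  # [[200. 300.]] in full-image coords
```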