Skip to content

Commit d4d7fb5

Browse files
authored
Reorganize bounding box utilities and namespaces (#439)
* Reorganize bounding box utilities and namespaces * regorganize bbox api entrypoint
1 parent aa6a747 commit d4d7fb5

16 files changed

+123
-105
lines changed

benchmarks/metrics/coco/mean_average_precision_bucket_performance.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,7 @@ def produce_random_data(include_confidence=False, num_images=128, num_classes=20
3636
)
3737

3838
images = [
39-
keras_cv.utils.bounding_box.pad_bounding_box_batch_to_shape(
40-
x, [25, images[0].shape[1]]
41-
)
39+
keras_cv.bounding_box.pad_batch_to_shape(x, [25, images[0].shape[1]])
4240
for x in images
4341
]
4442
return tf.stack(images, axis=0)

benchmarks/metrics/coco/mean_average_precision_performance.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,7 @@ def produce_random_data(include_confidence=False, num_images=128, num_classes=20
3636
)
3737

3838
images = [
39-
keras_cv.utils.bounding_box.pad_bounding_box_batch_to_shape(
40-
x, [25, images[0].shape[1]]
41-
)
39+
keras_cv.bounding_box.pad_batch_to_shape(x, [25, images[0].shape[1]])
4240
for x in images
4341
]
4442
return tf.stack(images, axis=0)

benchmarks/metrics/coco/recall_performance.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,7 @@ def produce_random_data(include_confidence=False, num_images=128, num_classes=20
3636
)
3737

3838
images = [
39-
keras_cv.utils.bounding_box.pad_bounding_box_batch_to_shape(
40-
x, [25, images[0].shape[1]]
41-
)
39+
keras_cv.bounding_box.pad_batch_to_shape(x, [25, images[0].shape[1]])
4240
for x in images
4341
]
4442
return tf.stack(images, axis=0)

keras_cv/bounding_box/__init__.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
# Copyright 2022 The KerasCV Authors
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from keras_cv.bounding_box.convert_to_corners import convert_to_corners
16+
from keras_cv.bounding_box.pad_batch_to_shape import pad_batch_to_shape
17+
18+
# These are the indexes used in Tensors to represent each corresponding side.
19+
LEFT, TOP, RIGHT, BOTTOM = 0, 1, 2, 3
20+
21+
# Regardless of format these constants are consistent.
22+
# Class is held in the 5th index
23+
CLASS = 4
24+
# Confidence exists only on y_pred, and is in the 6th index.
25+
CONFIDENCE = 5

keras_cv/utils/bounding_box_test.py renamed to keras_cv/bounding_box/bounding_box_test.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
import tensorflow as tf
1616

17-
from keras_cv.utils import bounding_box
17+
from keras_cv import bounding_box
1818

1919

2020
class BBOXTestCase(tf.test.TestCase):
@@ -97,25 +97,21 @@ def test_yolo_to_corner(self):
9797
def test_bounding_box_padding(self):
9898
bounding_boxes = [[1, 2, 3, 4], [5, 6, 7, 8]]
9999
target_shape = [3, 4]
100-
result = bounding_box.pad_bounding_box_batch_to_shape(
101-
bounding_boxes, target_shape
102-
)
100+
result = bounding_box.pad_batch_to_shape(bounding_boxes, target_shape)
103101
self.assertAllClose(result, [[1, 2, 3, 4], [5, 6, 7, 8], [-1, -1, -1, -1]])
104102

105103
target_shape = [2, 5]
106-
result = bounding_box.pad_bounding_box_batch_to_shape(
107-
bounding_boxes, target_shape
108-
)
104+
result = bounding_box.pad_batch_to_shape(bounding_boxes, target_shape)
109105
self.assertAllClose(result, [[1, 2, 3, 4, -1], [5, 6, 7, 8, -1]])
110106

111107
# Make sure to raise error if the rank is different between bounding_box and
112108
# target shape
113109
with self.assertRaisesRegex(ValueError, "Target shape should have same rank"):
114-
bounding_box.pad_bounding_box_batch_to_shape(bounding_boxes, [1, 2, 3])
110+
bounding_box.pad_batch_to_shape(bounding_boxes, [1, 2, 3])
115111

116112
# Make sure raise error if the target shape is smaller
117113
target_shape = [3, 2]
118114
with self.assertRaisesRegex(
119115
ValueError, "Target shape should be larger than bounding box shape"
120116
):
121-
bounding_box.pad_bounding_box_batch_to_shape(bounding_boxes, target_shape)
117+
bounding_box.pad_batch_to_shape(bounding_boxes, target_shape)
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
# Copyright 2022 The KerasCV Authors
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Shared utility functions for working with bounding boxes.
16+
17+
Usually bounding boxes is a 2D Tensor with shape [batch, 4]. The second dimension
18+
will contain 4 numbers based on 2 different formats. In KerasCV, we will use the
19+
`corners` format, which is [LEFT, TOP, RIGHT, BOTTOM].
20+
21+
In this file, provide utility functions for manipulating bounding boxes and converting
22+
their formats.
23+
"""
24+
25+
import tensorflow as tf
26+
27+
28+
def convert_to_corners(bounding_boxes, format):
29+
"""Converts bounding_boxes to corners format.
30+
31+
Converts bounding boxes from the provided format to corners format, which is:
32+
`[left, top, right, bottom]`.
33+
34+
args:
35+
format: one of "coco" or "yolo". The formats are as follows-
36+
coco=[x_min, y_min, width, height]
37+
yolo=[x_center, y_center, width, height]
38+
"""
39+
if format == "coco":
40+
return _coco_to_corners(bounding_boxes)
41+
elif format == "yolo":
42+
return _yolo_to_corners(bounding_boxes)
43+
else:
44+
raise ValueError(
45+
"Unsupported format passed to convert_to_corners(). "
46+
f"Want one 'coco' or 'yolo', got format=={format}"
47+
)
48+
49+
50+
def _yolo_to_corners(bounding_boxes):
51+
x, y, width, height, rest = tf.split(bounding_boxes, [1, 1, 1, 1, -1], axis=-1)
52+
return tf.concat(
53+
[
54+
x - width / 2.0,
55+
y - height / 2.0,
56+
x + width / 2.0,
57+
y + height / 2.0,
58+
rest, # In case there is any more index after the HEIGHT.
59+
],
60+
axis=-1,
61+
)
62+
63+
64+
def _coco_to_corners(bounding_boxes):
65+
x, y, width, height, rest = tf.split(bounding_boxes, [1, 1, 1, 1, -1], axis=-1)
66+
return tf.concat(
67+
[
68+
x,
69+
y,
70+
x + width,
71+
y + height,
72+
rest, # In case there is any more index after the HEIGHT.
73+
],
74+
axis=-1,
75+
)

keras_cv/utils/bounding_box.py renamed to keras_cv/bounding_box/pad_batch_to_shape.py

Lines changed: 4 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -11,80 +11,10 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
15-
"""Shared utility functions for working with bounding boxes.
16-
17-
Usually bounding boxes is a 2D Tensor with shape [batch, 4]. The second dimension
18-
will contain 4 numbers based on 2 different formats. In KerasCV, we will use the
19-
`corners` format, which is [LEFT, TOP, RIGHT, BOTTOM].
20-
21-
In this file, provide utility functions for manipulating bounding boxes and converting
22-
their formats.
23-
"""
24-
2514
import tensorflow as tf
2615

27-
# These are the indexes used in Tensors to represent each corresponding side.
28-
LEFT, TOP, RIGHT, BOTTOM = 0, 1, 2, 3
29-
30-
# Regardless of format these constants are consistent.
31-
# Class is held in the 5th index
32-
CLASS = 4
33-
# Confidence exists only on y_pred, and is in the 6th index.
34-
CONFIDENCE = 5
35-
36-
37-
def convert_to_corners(bounding_boxes, format):
38-
"""Converts bounding_boxes to corners format.
39-
40-
Converts bounding boxes from the provided format to corners format, which is:
41-
`[left, top, right, bottom]`.
42-
43-
args:
44-
format: one of "coco" or "yolo". The formats are as follows-
45-
coco=[x_min, y_min, width, height]
46-
yolo=[x_center, y_center, width, height]
47-
"""
48-
if format == "coco":
49-
return _coco_to_corners(bounding_boxes)
50-
elif format == "yolo":
51-
return _yolo_to_corners(bounding_boxes)
52-
else:
53-
raise ValueError(
54-
"Unsupported format passed to convert_to_corners(). "
55-
f"Want one 'coco' or 'yolo', got format=={format}"
56-
)
57-
58-
59-
def _yolo_to_corners(bounding_boxes):
60-
x, y, width, height, rest = tf.split(bounding_boxes, [1, 1, 1, 1, -1], axis=-1)
61-
return tf.concat(
62-
[
63-
x - width / 2.0,
64-
y - height / 2.0,
65-
x + width / 2.0,
66-
y + height / 2.0,
67-
rest, # In case there is any more index after the HEIGHT.
68-
],
69-
axis=-1,
70-
)
71-
72-
73-
def _coco_to_corners(bounding_boxes):
74-
x, y, width, height, rest = tf.split(bounding_boxes, [1, 1, 1, 1, -1], axis=-1)
75-
return tf.concat(
76-
[
77-
x,
78-
y,
79-
x + width,
80-
y + height,
81-
rest, # In case there is any more index after the HEIGHT.
82-
],
83-
axis=-1,
84-
)
85-
8616

87-
def pad_bounding_box_batch_to_shape(bounding_boxes, target_shape, padding_values=-1):
17+
def pad_batch_to_shape(bounding_boxes, target_shape, padding_values=-1):
8818
"""Pads a list of bounding boxes with -1s.
8919
9020
Boxes represented by all -1s are ignored by COCO metrics.
@@ -93,17 +23,17 @@ def pad_bounding_box_batch_to_shape(bounding_boxes, target_shape, padding_values
9323
bounding_box = [[1, 2, 3, 4], [5, 6, 7, 8]] # 2 bounding_boxes with with xywh or
9424
corners format.
9525
target_shape = [3, 4] # Add 1 more dummy bounding_box
96-
result = pad_bounding_box_batch_to_shape(bounding_box, target_shape)
26+
result = pad_batch_to_shape(bounding_box, target_shape)
9727
# result == [[1, 2, 3, 4], [5, 6, 7, 8], [-1, -1, -1, -1]]
9828
9929
target_shape = [2, 5] # Add 1 more index after the current 4 coordinates.
100-
result = pad_bounding_box_batch_to_shape(bounding_box, target_shape)
30+
result = pad_batch_to_shape(bounding_box, target_shape)
10131
# result == [[1, 2, 3, 4, -1], [5, 6, 7, 8, -1]]
10232
10333
Args:
10434
bounding_boxes: tf.Tensor of bounding boxes in any format.
10535
target_shape: Target shape to pad bounding box to. This should have the same
106-
rank as the bbounding_boxs. Note that if the target_shape contains any
36+
rank as the bounding_boxes. Note that if the target_shape contains any
10737
dimension that is smaller than the bounding box shape, then no value will be
10838
padded.
10939
padding_values: value to pad, defaults to -1 to mask out in coco metrics.

keras_cv/metrics/coco/mean_average_precision.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@
1515

1616
import tensorflow as tf
1717

18+
from keras_cv import bounding_box
1819
from keras_cv.metrics.coco import utils
19-
from keras_cv.utils import bounding_box
2020
from keras_cv.utils import iou as iou_lib
2121

2222

@@ -63,7 +63,7 @@ class COCOMeanAveragePrecision(tf.keras.metrics.Metric):
6363
account for this, you may either pass a `tf.RaggedTensor`, or pad Tensors
6464
with `-1`s to indicate unused boxes. A utility function to perform this
6565
padding is available at
66-
`keras_cv.utils.bounding_box.pad_bounding_box_batch_to_shape()`.
66+
`keras_cv.bounding_box.pad_batch_to_shape()`.
6767
6868
```python
6969
coco_map = keras_cv.metrics.COCOMeanAveragePrecision(

keras_cv/metrics/coco/mean_average_precision_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@
1717
import tensorflow as tf
1818
from tensorflow import keras
1919

20+
from keras_cv import bounding_box
2021
from keras_cv.metrics import COCOMeanAveragePrecision
21-
from keras_cv.utils import bounding_box as bounding_box_utils
2222

2323

2424
class COCOMeanAveragePrecisionTest(tf.test.TestCase):
@@ -208,7 +208,7 @@ def test_bounding_box_counting(self):
208208
y_true = tf.constant([[[0, 0, 100, 100, 1]]], dtype=tf.float64)
209209
y_pred = tf.constant([[[0, 50, 100, 150, 1, 1.0]]], dtype=tf.float32)
210210

211-
y_true = bounding_box_utils.pad_bounding_box_batch_to_shape(y_true, (1, 20, 5))
211+
y_true = bounding_box.pad_batch_to_shape(y_true, (1, 20, 5))
212212

213213
metric = COCOMeanAveragePrecision(
214214
iou_thresholds=[0.15],

keras_cv/metrics/coco/numerical_tests/mean_average_precision_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@
1616
import numpy as np
1717
import tensorflow as tf
1818

19+
from keras_cv import bounding_box
1920
from keras_cv.metrics.coco import COCOMeanAveragePrecision
20-
from keras_cv.utils import bounding_box
2121

2222
SAMPLE_FILE = os.path.dirname(os.path.abspath(__file__)) + "/sample_boxes.npz"
2323

0 commit comments

Comments
 (0)