diff --git a/setup.py b/setup.py index d9e023cc01f..e14898ff253 100644 --- a/setup.py +++ b/setup.py @@ -111,7 +111,8 @@ def get_dist(pkgname): ] # Excluding 8.3.* because of https://github.com/pytorch/vision/issues/4934 - pillow_ver = " >= 5.3.0, !=8.3.*" + # TODO remove <11.3 bound and address corresponding deprecation warnings + pillow_ver = " >= 5.3.0, !=8.3.*, <11.3" pillow_req = "pillow-simd" if get_dist("pillow-simd") is not None else "pillow" requirements.append(pillow_req + pillow_ver) diff --git a/test/assets/fakedata/draw_rotated_boxes_fill.png b/test/assets/fakedata/draw_rotated_boxes_fill.png new file mode 100644 index 00000000000..474b771f04e Binary files /dev/null and b/test/assets/fakedata/draw_rotated_boxes_fill.png differ diff --git a/test/common_utils.py b/test/common_utils.py index 9da3cf52d1c..0da8e6bbc1d 100644 --- a/test/common_utils.py +++ b/test/common_utils.py @@ -21,7 +21,7 @@ from torch.testing._comparison import BooleanPair, NonePair, not_close_error_metas, NumberPair, TensorLikePair from torchvision import io, tv_tensors from torchvision.transforms._functional_tensor import _max_value as get_max_value -from torchvision.transforms.v2.functional import clamp_bounding_boxes, to_image, to_pil_image +from torchvision.transforms.v2.functional import to_image, to_pil_image IN_OSS_CI = any(os.getenv(var) == "true" for var in ["CIRCLECI", "GITHUB_ACTIONS"]) @@ -410,6 +410,7 @@ def make_bounding_boxes( canvas_size=DEFAULT_SIZE, *, format=tv_tensors.BoundingBoxFormat.XYXY, + clamping_mode="soft", num_boxes=1, dtype=None, device="cpu", @@ -423,13 +424,6 @@ def sample_position(values, max_value): format = tv_tensors.BoundingBoxFormat[format] dtype = dtype or torch.float32 - int_dtype = dtype in ( - torch.uint8, - torch.int8, - torch.int16, - torch.int32, - torch.int64, - ) h, w = (torch.randint(1, s, (num_boxes,)) for s in canvas_size) y = sample_position(h, canvas_size[0]) @@ -456,31 +450,19 @@ def sample_position(values, max_value): elif format is tv_tensors.BoundingBoxFormat.XYXYXYXY: r_rad = r * torch.pi / 180.0 cos, sin = torch.cos(r_rad), torch.sin(r_rad) - x1 = torch.round(x) if int_dtype else x - y1 = torch.round(y) if int_dtype else y - x2 = torch.round(x1 + w * cos) if int_dtype else x1 + w * cos - y2 = torch.round(y1 - w * sin) if int_dtype else y1 - w * sin - x3 = torch.round(x2 + h * sin) if int_dtype else x2 + h * sin - y3 = torch.round(y2 + h * cos) if int_dtype else y2 + h * cos - x4 = torch.round(x1 + h * sin) if int_dtype else x1 + h * sin - y4 = torch.round(y1 + h * cos) if int_dtype else y1 + h * cos + x1 = x + y1 = y + x2 = x1 + w * cos + y2 = y1 - w * sin + x3 = x2 + h * sin + y3 = y2 + h * cos + x4 = x1 + h * sin + y4 = y1 + h * cos parts = (x1, y1, x2, y2, x3, y3, x4, y4) else: raise ValueError(f"Format {format} is not supported") out_boxes = torch.stack(parts, dim=-1).to(dtype=dtype, device=device) - if tv_tensors.is_rotated_bounding_format(format): - # The rotated bounding boxes are not guaranteed to be within the canvas by design, - # so we apply clamping. 
We also add a 2 buffer to the canvas size to avoid - # numerical issues during the testing - buffer = 4 - out_boxes = clamp_bounding_boxes( - out_boxes, format=format, canvas_size=(canvas_size[0] - buffer, canvas_size[1] - buffer) - ) - if format is tv_tensors.BoundingBoxFormat.XYWHR or format is tv_tensors.BoundingBoxFormat.CXCYWHR: - out_boxes[:, :2] += buffer // 2 - elif format is tv_tensors.BoundingBoxFormat.XYXYXYXY: - out_boxes[:, :] += buffer // 2 - return tv_tensors.BoundingBoxes(out_boxes, format=format, canvas_size=canvas_size) + return tv_tensors.BoundingBoxes(out_boxes, format=format, canvas_size=canvas_size, clamping_mode=clamping_mode) def make_detection_masks(size=DEFAULT_SIZE, *, num_masks=1, dtype=None, device="cpu"): diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 7e667586ac1..06d6514770d 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -492,6 +492,7 @@ def adapt_fill(value, *, dtype): def reference_affine_bounding_boxes_helper(bounding_boxes, *, affine_matrix, new_canvas_size=None, clamp=True): format = bounding_boxes.format canvas_size = new_canvas_size or bounding_boxes.canvas_size + clamping_mode = bounding_boxes.clamping_mode def affine_bounding_boxes(bounding_boxes): dtype = bounding_boxes.dtype @@ -535,6 +536,7 @@ def affine_bounding_boxes(bounding_boxes): output, format=format, canvas_size=canvas_size, + clamping_mode=clamping_mode, ) else: # We leave the bounding box as float64 so the caller gets the full precision to perform any additional @@ -549,6 +551,7 @@ def affine_bounding_boxes(bounding_boxes): ), format=format, canvas_size=canvas_size, + clamping_mode=clamping_mode, ) @@ -557,16 +560,10 @@ def reference_affine_rotated_bounding_boxes_helper( ): format = bounding_boxes.format canvas_size = new_canvas_size or bounding_boxes.canvas_size + clamping_mode = bounding_boxes.clamping_mode def affine_rotated_bounding_boxes(bounding_boxes): dtype = bounding_boxes.dtype - int_dtype = dtype in ( - torch.uint8, - torch.int8, - torch.int16, - torch.int32, - torch.int64, - ) device = bounding_boxes.device # Go to float before converting to prevent precision loss in case of CXCYWHR -> XYXYXYXY and W or H is 1 @@ -601,23 +598,18 @@ def affine_rotated_bounding_boxes(bounding_boxes): ) output = output[[2, 3, 0, 1, 6, 7, 4, 5]] if flip else output - if not int_dtype: - output = _parallelogram_to_bounding_boxes(output) + output = _parallelogram_to_bounding_boxes(output) output = F.convert_bounding_box_format( output, old_format=tv_tensors.BoundingBoxFormat.XYXYXYXY, new_format=format ) - if torch.is_floating_point(output) and int_dtype: - # It is important to round before cast. - output = torch.round(output) - - # For rotated boxes, it is important to cast before clamping. 
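+        # Rotated boxes are now kept in floating point throughout, so no rounding
+        # or casting is needed before clamping.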
return ( F.clamp_bounding_boxes( output.to(dtype=dtype, device=device), format=format, canvas_size=canvas_size, + clamping_mode=clamping_mode, ) if clamp else output.to(dtype=output.dtype, device=device) @@ -635,6 +627,7 @@ def affine_rotated_bounding_boxes(bounding_boxes): ).reshape(bounding_boxes.shape), format=format, canvas_size=canvas_size, + clamping_mode=clamping_mode, ) @@ -754,6 +747,8 @@ def test_kernel_image(self, size, interpolation, use_max_size, antialias, dtype, def test_kernel_bounding_boxes(self, format, size, use_max_size, dtype, device): if not (max_size_kwarg := self._make_max_size_kwarg(use_max_size=use_max_size, size=size)): return + if not dtype.is_floating_point and tv_tensors.is_rotated_bounding_format(format): + pytest.xfail("Rotated bounding boxes should be floating point tensors") bounding_boxes = make_bounding_boxes( format=format, @@ -831,7 +826,6 @@ def test_functional(self, size, make_input): (F.resize_image, torch.Tensor), (F._geometry._resize_image_pil, PIL.Image.Image), (F.resize_image, tv_tensors.Image), - (F.resize_bounding_boxes, tv_tensors.BoundingBoxes), (F.resize_mask, tv_tensors.Mask), (F.resize_video, tv_tensors.Video), (F.resize_keypoints, tv_tensors.KeyPoints), @@ -1207,6 +1201,8 @@ def test_kernel_image(self, dtype, device): @pytest.mark.parametrize("dtype", [torch.float32, torch.int64]) @pytest.mark.parametrize("device", cpu_and_cuda()) def test_kernel_bounding_boxes(self, format, dtype, device): + if not dtype.is_floating_point and tv_tensors.is_rotated_bounding_format(format): + pytest.xfail("Rotated bounding boxes should be floating point tensors") bounding_boxes = make_bounding_boxes(format=format, dtype=dtype, device=device) check_kernel( F.horizontal_flip_bounding_boxes, @@ -1302,7 +1298,7 @@ def _reference_horizontal_flip_bounding_boxes(self, bounding_boxes: tv_tensors.B if tv_tensors.is_rotated_bounding_format(bounding_boxes.format) else reference_affine_bounding_boxes_helper ) - return helper(bounding_boxes, affine_matrix=affine_matrix) + return helper(bounding_boxes, affine_matrix=affine_matrix, clamp=False) @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat)) @pytest.mark.parametrize( @@ -1436,6 +1432,8 @@ def test_kernel_image(self, param, value, dtype, device): @pytest.mark.parametrize("dtype", [torch.float32, torch.int64]) @pytest.mark.parametrize("device", cpu_and_cuda()) def test_kernel_bounding_boxes(self, param, value, format, dtype, device): + if not dtype.is_floating_point and tv_tensors.is_rotated_bounding_format(format): + pytest.xfail("Rotated bounding boxes should be floating point tensors") bounding_boxes = make_bounding_boxes(format=format, dtype=dtype, device=device) self._check_kernel( F.affine_bounding_boxes, @@ -1650,7 +1648,7 @@ def test_functional_bounding_boxes_correctness(self, format, angle, translate, s center=center, ) - torch.testing.assert_close(actual, expected, atol=1e-5, rtol=1e-5) + torch.testing.assert_close(actual, expected, atol=1e-4, rtol=1e-4) @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat)) @pytest.mark.parametrize("center", _CORRECTNESS_AFFINE_KWARGS["center"]) @@ -1818,6 +1816,8 @@ def test_kernel_image(self, dtype, device): @pytest.mark.parametrize("dtype", [torch.float32, torch.int64]) @pytest.mark.parametrize("device", cpu_and_cuda()) def test_kernel_bounding_boxes(self, format, dtype, device): + if not dtype.is_floating_point and tv_tensors.is_rotated_bounding_format(format): + pytest.xfail("Rotated bounding boxes should be floating point tensors") 
        bounding_boxes = make_bounding_boxes(format=format, dtype=dtype, device=device)
        check_kernel(
            F.vertical_flip_bounding_boxes,
            bounding_boxes,
            format=format,
            canvas_size=bounding_boxes.canvas_size,
        )
@@ -1911,7 +1911,7 @@ def _reference_vertical_flip_bounding_boxes(self, bounding_boxes: tv_tensors.Bou
            if tv_tensors.is_rotated_bounding_format(bounding_boxes.format)
            else reference_affine_bounding_boxes_helper
        )
-        return helper(bounding_boxes, affine_matrix=affine_matrix)
+        return helper(bounding_boxes, affine_matrix=affine_matrix, clamp=False)

    @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
    @pytest.mark.parametrize("fn", [F.vertical_flip, transform_cls_to_functional(transforms.RandomVerticalFlip, p=1)])
@@ -2016,8 +2016,14 @@ def test_kernel_bounding_boxes(self, param, value, format, dtype, device):
        kwargs = {param: value}
        if param != "angle":
            kwargs["angle"] = self._MINIMAL_AFFINE_KWARGS["angle"]
+        if not dtype.is_floating_point and tv_tensors.is_rotated_bounding_format(format):
+            pytest.xfail("Rotated bounding boxes should be floating point tensors")

        bounding_boxes = make_bounding_boxes(format=format, dtype=dtype, device=device)
+        if tv_tensors.is_rotated_bounding_format(format):
+            # TODO there is a 1e-6 difference between GPU and CPU outputs
+            # due to clamping. To avoid failing this test, we do clamp beforehand.
+            bounding_boxes = F.clamp_bounding_boxes(bounding_boxes)

        check_kernel(
            F.rotate_bounding_boxes,
@@ -2076,7 +2082,6 @@ def test_functional(self, make_input):
        (F.rotate_image, torch.Tensor),
        (F._geometry._rotate_image_pil, PIL.Image.Image),
        (F.rotate_image, tv_tensors.Image),
-        (F.rotate_bounding_boxes, tv_tensors.BoundingBoxes),
        (F.rotate_mask, tv_tensors.Mask),
        (F.rotate_video, tv_tensors.Video),
        (F.rotate_keypoints, tv_tensors.KeyPoints),
@@ -2226,29 +2231,26 @@ def _reference_rotate_bounding_boxes(self, bounding_boxes, *, angle, expand, cen
            clamp=False,
        )

-        return F.clamp_bounding_boxes(self._recenter_bounding_boxes_after_expand(output, recenter_xy=recenter_xy)).to(
-            bounding_boxes
-        )
+        return self._recenter_bounding_boxes_after_expand(output, recenter_xy=recenter_xy).to(bounding_boxes)

    @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
    @pytest.mark.parametrize("angle", _CORRECTNESS_AFFINE_KWARGS["angle"])
    @pytest.mark.parametrize("expand", [False, True])
    @pytest.mark.parametrize("center", _CORRECTNESS_AFFINE_KWARGS["center"])
    def test_functional_bounding_boxes_correctness(self, format, angle, expand, center):
-        bounding_boxes = make_bounding_boxes(format=format)
+        bounding_boxes = make_bounding_boxes(format=format, clamping_mode=None)

        actual = F.rotate(bounding_boxes, angle=angle, expand=expand, center=center)
        expected = self._reference_rotate_bounding_boxes(bounding_boxes, angle=angle, expand=expand, center=center)
-
-        torch.testing.assert_close(actual, expected)
        torch.testing.assert_close(F.get_size(actual), F.get_size(expected), atol=2 if expand else 0, rtol=0)
+        torch.testing.assert_close(actual, expected)

    @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
    @pytest.mark.parametrize("expand", [False, True])
    @pytest.mark.parametrize("center", _CORRECTNESS_AFFINE_KWARGS["center"])
    @pytest.mark.parametrize("seed", list(range(5)))
    def 
test_transform_bounding_boxes_correctness(self, format, expand, center, seed actual = transform(bounding_boxes) expected = self._reference_rotate_bounding_boxes(bounding_boxes, **params, expand=expand, center=center) - - torch.testing.assert_close(actual, expected) torch.testing.assert_close(F.get_size(actual), F.get_size(expected), atol=2 if expand else 0, rtol=0) + torch.testing.assert_close(actual, expected) def _recenter_keypoints_after_expand(self, keypoints, *, recenter_xy): x, y = recenter_xy @@ -3236,6 +3237,8 @@ def test_kernel_image(self, param, value, dtype, device): @pytest.mark.parametrize("dtype", [torch.float32, torch.int64]) @pytest.mark.parametrize("device", cpu_and_cuda()) def test_kernel_bounding_boxes(self, format, dtype, device): + if not dtype.is_floating_point and tv_tensors.is_rotated_bounding_format(format): + pytest.xfail("Rotated bounding boxes should be floating point tensors") bounding_boxes = make_bounding_boxes(format=format, dtype=dtype, device=device) check_kernel( @@ -3289,7 +3292,6 @@ def test_functional(self, make_input): (F.elastic_image, torch.Tensor), (F._geometry._elastic_image_pil, PIL.Image.Image), (F.elastic_image, tv_tensors.Image), - (F.elastic_bounding_boxes, tv_tensors.BoundingBoxes), (F.elastic_mask, tv_tensors.Mask), (F.elastic_video, tv_tensors.Video), (F.elastic_keypoints, tv_tensors.KeyPoints), @@ -3400,6 +3402,8 @@ def test_kernel_image(self, kwargs, dtype, device): @pytest.mark.parametrize("dtype", [torch.float32, torch.int64]) @pytest.mark.parametrize("device", cpu_and_cuda()) def test_kernel_bounding_boxes(self, kwargs, format, dtype, device): + if not dtype.is_floating_point and tv_tensors.is_rotated_bounding_format(format): + pytest.xfail("Rotated bounding boxes should be floating point tensors") bounding_boxes = make_bounding_boxes(self.INPUT_SIZE, format=format, dtype=dtype, device=device) check_kernel(F.crop_bounding_boxes, bounding_boxes, format=format, **kwargs) @@ -3577,6 +3581,8 @@ def _reference_crop_bounding_boxes(self, bounding_boxes, *, top, left, height, w @pytest.mark.parametrize("dtype", [torch.float32, torch.int64]) @pytest.mark.parametrize("device", cpu_and_cuda()) def test_functional_bounding_box_correctness(self, kwargs, format, dtype, device): + if not dtype.is_floating_point and tv_tensors.is_rotated_bounding_format(format): + pytest.xfail("Rotated bounding boxes should be floating point tensors") bounding_boxes = make_bounding_boxes(self.INPUT_SIZE, format=format, dtype=dtype, device=device) actual = F.crop(bounding_boxes, **kwargs) @@ -3591,6 +3597,8 @@ def test_functional_bounding_box_correctness(self, kwargs, format, dtype, device @pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize("seed", list(range(5))) def test_transform_bounding_boxes_correctness(self, output_size, format, dtype, device, seed): + if not dtype.is_floating_point and tv_tensors.is_rotated_bounding_format(format): + pytest.xfail("Rotated bounding boxes should be floating point tensors") input_size = [s * 2 for s in output_size] bounding_boxes = make_bounding_boxes(input_size, format=format, dtype=dtype, device=device) @@ -4268,6 +4276,10 @@ def _reference_convert_bounding_box_format(self, bounding_boxes, new_format): @pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize("fn_type", ["functional", "transform"]) def test_correctness(self, old_format, new_format, dtype, device, fn_type): + if not dtype.is_floating_point and ( + tv_tensors.is_rotated_bounding_format(old_format) or 
tv_tensors.is_rotated_bounding_format(new_format) + ): + pytest.xfail("Rotated bounding boxes should be floating point tensors") bounding_boxes = make_bounding_boxes(format=old_format, dtype=dtype, device=device) if fn_type == "functional": @@ -4347,7 +4359,6 @@ def test_functional(self, make_input): (F.resized_crop_image, torch.Tensor), (F._geometry._resized_crop_image_pil, PIL.Image.Image), (F.resized_crop_image, tv_tensors.Image), - (F.resized_crop_bounding_boxes, tv_tensors.BoundingBoxes), (F.resized_crop_mask, tv_tensors.Mask), (F.resized_crop_video, tv_tensors.Video), (F.resized_crop_keypoints, tv_tensors.KeyPoints), @@ -4413,6 +4424,7 @@ def _reference_resized_crop_bounding_boxes(self, bounding_boxes, *, top, left, h [0, 0, 1], ], ) + affine_matrix = (resize_affine_matrix @ crop_affine_matrix)[:2, :] helper = ( @@ -4421,15 +4433,15 @@ def _reference_resized_crop_bounding_boxes(self, bounding_boxes, *, top, left, h else reference_affine_bounding_boxes_helper ) - return helper( - bounding_boxes, - affine_matrix=affine_matrix, - new_canvas_size=size, - ) + return helper(bounding_boxes, affine_matrix=affine_matrix, new_canvas_size=size, clamp=False) @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat)) def test_functional_bounding_boxes_correctness(self, format): - bounding_boxes = make_bounding_boxes(self.INPUT_SIZE, format=format) + # Note that we don't want to clamp because in + # _reference_resized_crop_bounding_boxes we are fusing the crop and the + # resize operation, where none of the croppings happen - particularly, + # the intermediate one. + bounding_boxes = make_bounding_boxes(self.INPUT_SIZE, format=format, clamping_mode=None) actual = F.resized_crop(bounding_boxes, **self.CROP_KWARGS, size=self.OUTPUT_SIZE) expected = self._reference_resized_crop_bounding_boxes( @@ -4707,6 +4719,8 @@ def _reference_pad_bounding_boxes(self, bounding_boxes, *, padding): @pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize("fn", [F.pad, transform_cls_to_functional(transforms.Pad)]) def test_bounding_boxes_correctness(self, padding, format, dtype, device, fn): + if not dtype.is_floating_point and tv_tensors.is_rotated_bounding_format(format): + pytest.xfail("Rotated bounding boxes should be floating point tensors") bounding_boxes = make_bounding_boxes(format=format, dtype=dtype, device=device) actual = fn(bounding_boxes, padding=padding) @@ -4877,6 +4891,8 @@ def _reference_center_crop_bounding_boxes(self, bounding_boxes, output_size): @pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize("fn", [F.center_crop, transform_cls_to_functional(transforms.CenterCrop)]) def test_bounding_boxes_correctness(self, output_size, format, dtype, device, fn): + if not dtype.is_floating_point and tv_tensors.is_rotated_bounding_format(format): + pytest.xfail("Rotated bounding boxes should be floating point tensors") bounding_boxes = make_bounding_boxes(self.INPUT_SIZE, format=format, dtype=dtype, device=device) actual = fn(bounding_boxes, output_size) @@ -5126,6 +5142,7 @@ def test_image_functional_correctness(self, coefficients, interpolation, fill): def _reference_perspective_bounding_boxes(self, bounding_boxes, *, startpoints, endpoints): format = bounding_boxes.format canvas_size = bounding_boxes.canvas_size + clamping_mode = bounding_boxes.clamping_mode dtype = bounding_boxes.dtype device = bounding_boxes.device is_rotated = tv_tensors.is_rotated_bounding_format(format) @@ -5226,6 +5243,7 @@ def perspective_bounding_boxes(bounding_boxes): output, 
format=format, canvas_size=canvas_size, + clamping_mode=clamping_mode, ).to(dtype=dtype, device=device) return tv_tensors.BoundingBoxes( @@ -5241,6 +5259,8 @@ def perspective_bounding_boxes(bounding_boxes): @pytest.mark.parametrize("dtype", [torch.int64, torch.float32]) @pytest.mark.parametrize("device", cpu_and_cuda()) def test_correctness_perspective_bounding_boxes(self, startpoints, endpoints, format, dtype, device): + if not dtype.is_floating_point and tv_tensors.is_rotated_bounding_format(format): + pytest.xfail("Rotated bounding boxes should be floating point tensors") bounding_boxes = make_bounding_boxes(format=format, dtype=dtype, device=device) actual = F.perspective(bounding_boxes, startpoints=startpoints, endpoints=endpoints) @@ -5506,29 +5526,37 @@ def test_correctness_image(self, mean, std, dtype, fn): class TestClampBoundingBoxes: @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat)) + @pytest.mark.parametrize("clamping_mode", ("soft", "hard", None)) @pytest.mark.parametrize("dtype", [torch.int64, torch.float32]) @pytest.mark.parametrize("device", cpu_and_cuda()) - def test_kernel(self, format, dtype, device): - bounding_boxes = make_bounding_boxes(format=format, dtype=dtype, device=device) + def test_kernel(self, format, clamping_mode, dtype, device): + if not dtype.is_floating_point and tv_tensors.is_rotated_bounding_format(format): + pytest.xfail("Rotated bounding boxes should be floating point tensors") + bounding_boxes = make_bounding_boxes(format=format, clamping_mode=clamping_mode, dtype=dtype, device=device) check_kernel( F.clamp_bounding_boxes, bounding_boxes, format=bounding_boxes.format, canvas_size=bounding_boxes.canvas_size, + clamping_mode=clamping_mode, ) @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat)) - def test_functional(self, format): - check_functional(F.clamp_bounding_boxes, make_bounding_boxes(format=format)) + @pytest.mark.parametrize("clamping_mode", ("soft", "hard", None)) + def test_functional(self, format, clamping_mode): + check_functional(F.clamp_bounding_boxes, make_bounding_boxes(format=format, clamping_mode=clamping_mode)) def test_errors(self): input_tv_tensor = make_bounding_boxes() input_pure_tensor = input_tv_tensor.as_subclass(torch.Tensor) format, canvas_size = input_tv_tensor.format, input_tv_tensor.canvas_size - for format_, canvas_size_ in [(None, None), (format, None), (None, canvas_size)]: + for format_, canvas_size_, clamping_mode_ in itertools.product( + (format, None), (canvas_size, None), (input_tv_tensor.clamping_mode, "auto") + ): with pytest.raises( - ValueError, match="For pure tensor inputs, `format` and `canvas_size` have to be passed." 
+                ValueError,
+                match="For pure tensor inputs, `format`, `canvas_size` and `clamping_mode` have to be passed.",
            ):
                F.clamp_bounding_boxes(input_pure_tensor, format=format_, canvas_size=canvas_size_)
@@ -5538,9 +5566,118 @@ def test_errors(self):
            ):
                F.clamp_bounding_boxes(input_tv_tensor, format=format_, canvas_size=canvas_size_)

+        with pytest.raises(ValueError, match="clamping_mode must be soft,"):
+            F.clamp_bounding_boxes(input_tv_tensor, clamping_mode="bad")
+        with pytest.raises(ValueError, match="clamping_mode must be soft,"):
+            transforms.ClampBoundingBoxes(clamping_mode="bad")(input_tv_tensor)
+
    def test_transform(self):
        check_transform(transforms.ClampBoundingBoxes(), make_bounding_boxes())

+    @pytest.mark.parametrize("rotated", (True, False))
+    @pytest.mark.parametrize("constructor_clamping_mode", ("soft", "hard", None))
+    @pytest.mark.parametrize("clamping_mode", ("soft", "hard", None, "auto"))
+    @pytest.mark.parametrize("pass_pure_tensor", (True, False))
+    @pytest.mark.parametrize("fn", [F.clamp_bounding_boxes, transform_cls_to_functional(transforms.ClampBoundingBoxes)])
+    def test_clamping_mode(self, rotated, constructor_clamping_mode, clamping_mode, pass_pure_tensor, fn):
+        # This test checks 2 things:
+        # - That passing clamping_mode="auto" to the clamp_bounding_boxes
+        #   functional (or to the class) relies on the box's `.clamping_mode`
+        #   attribute
+        # - That clamping happens when it should, and only when it should, i.e.
+        #   when the clamping mode is not None. It doesn't validate the
+        #   numerical results, only that clamping happened. For that, we create
+        #   a large 100x100 box inside of a small 10x10 image.
+
+        if pass_pure_tensor and fn is not F.clamp_bounding_boxes:
+            # Only the functional supports pure tensors, not the class
+            return
+        if pass_pure_tensor and clamping_mode == "auto":
+            # cannot leave clamping_mode="auto" when passing pure tensor
+            return
+
+        if rotated:
+            boxes = tv_tensors.BoundingBoxes(
+                [0.0, 0.0, 100.0, 100.0, 0.0],
+                format="XYWHR",
+                canvas_size=(10, 10),
+                clamping_mode=constructor_clamping_mode,
+            )
+            expected_clamped_output = torch.tensor([[0.0, 0.0, 10.0, 10.0, 0.0]])
+        else:
+            boxes = tv_tensors.BoundingBoxes(
+                [0, 100, 0, 100], format="XYXY", canvas_size=(10, 10), clamping_mode=constructor_clamping_mode
+            )
+            expected_clamped_output = torch.tensor([[0, 10, 0, 10]])
+
+        if pass_pure_tensor:
+            out = fn(
+                boxes.as_subclass(torch.Tensor),
+                format=boxes.format,
+                canvas_size=boxes.canvas_size,
+                clamping_mode=clamping_mode,
+            )
+        else:
+            out = fn(boxes, clamping_mode=clamping_mode)
+
+        clamping_mode_prevailing = constructor_clamping_mode if clamping_mode == "auto" else clamping_mode
+        if clamping_mode_prevailing is None:
+            assert_equal(boxes, out)  # should be a pass-through
+        else:
+            assert_equal(out, expected_clamped_output)
+
+
+class TestSetClampingMode:
+    @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
+    @pytest.mark.parametrize("constructor_clamping_mode", ("soft", "hard", None))
+    @pytest.mark.parametrize("desired_clamping_mode", ("soft", "hard", None))
+    def test_setter(self, format, constructor_clamping_mode, desired_clamping_mode):
+
+        in_boxes = make_bounding_boxes(format=format, clamping_mode=constructor_clamping_mode)
+        out_boxes = transforms.SetClampingMode(clamping_mode=desired_clamping_mode)(in_boxes)

+        assert in_boxes.clamping_mode == constructor_clamping_mode  # input is unchanged: no leak
+        assert out_boxes.clamping_mode == desired_clamping_mode
+
+    @pytest.mark.parametrize("format", 
list(tv_tensors.BoundingBoxFormat)) + @pytest.mark.parametrize("constructor_clamping_mode", ("soft", "hard", None)) + def test_pipeline_no_leak(self, format, constructor_clamping_mode): + class AssertClampingMode(transforms.Transform): + def __init__(self, expected_clamping_mode): + super().__init__() + self.expected_clamping_mode = expected_clamping_mode + + _transformed_types = (tv_tensors.BoundingBoxes,) + + def transform(self, inpt, _): + assert inpt.clamping_mode == self.expected_clamping_mode + return inpt + + t = transforms.Compose( + [ + transforms.SetClampingMode(None), + AssertClampingMode(None), + transforms.SetClampingMode("hard"), + AssertClampingMode("hard"), + transforms.SetClampingMode(None), + AssertClampingMode(None), + transforms.ClampBoundingBoxes("hard"), + ] + ) + + in_boxes = make_bounding_boxes(format=format, clamping_mode=constructor_clamping_mode) + out_boxes = t(in_boxes) + + assert in_boxes.clamping_mode == constructor_clamping_mode # input is unchanged: no leak + + # assert that the output boxes clamping_mode is the one set by the last SetClampingMode. + # ClampBoundingBoxes doesn't set clamping_mode. + assert out_boxes.clamping_mode is None + + def test_error(self): + with pytest.raises(ValueError, match="clamping_mode must be"): + transforms.SetClampingMode("bad") + class TestClampKeyPoints: @pytest.mark.parametrize("dtype", [torch.int64, torch.float32]) @@ -6834,14 +6971,11 @@ def test_classification_preset(image_type, label_type, dataset_return_type, to_t @pytest.mark.parametrize("input_size", [(17, 11), (11, 17), (11, 11)]) -@pytest.mark.parametrize("dtype", [torch.float32, torch.int64]) @pytest.mark.parametrize("device", cpu_and_cuda()) -def test_parallelogram_to_bounding_boxes(input_size, dtype, device): +def test_parallelogram_to_bounding_boxes(input_size, device): # Assert that applying `_parallelogram_to_bounding_boxes` to rotated boxes # does not modify the input. 
- bounding_boxes = make_bounding_boxes( - input_size, format=tv_tensors.BoundingBoxFormat.XYXYXYXY, dtype=dtype, device=device - ) + bounding_boxes = make_bounding_boxes(input_size, format=tv_tensors.BoundingBoxFormat.XYXYXYXY, device=device) actual = _parallelogram_to_bounding_boxes(bounding_boxes) torch.testing.assert_close(actual, bounding_boxes, rtol=0, atol=1) diff --git a/test/test_tv_tensors.py b/test/test_tv_tensors.py index 43efceba5c9..f9d545eb9c9 100644 --- a/test/test_tv_tensors.py +++ b/test/test_tv_tensors.py @@ -69,15 +69,39 @@ def test_bbox_instance(data, format): ) @pytest.mark.parametrize("scripted", (False, True)) def test_bbox_format(format, is_rotated_expected, scripted): - if isinstance(format, str): - format = tv_tensors.BoundingBoxFormat[(format.upper())] - fn = tv_tensors.is_rotated_bounding_format if scripted: fn = torch.jit.script(fn) assert fn(format) == is_rotated_expected +@pytest.mark.parametrize( + "format, support_integer_dtype", + [ + ("XYXY", True), + ("XYWH", True), + ("CXCYWH", True), + ("XYXYXYXY", False), + ("XYWHR", False), + ("CXCYWHR", False), + (tv_tensors.BoundingBoxFormat.XYXY, True), + (tv_tensors.BoundingBoxFormat.XYWH, True), + (tv_tensors.BoundingBoxFormat.CXCYWH, True), + (tv_tensors.BoundingBoxFormat.XYXYXYXY, False), + (tv_tensors.BoundingBoxFormat.XYWHR, False), + (tv_tensors.BoundingBoxFormat.CXCYWHR, False), + ], +) +@pytest.mark.parametrize("input_dtype", [torch.float32, torch.float64, torch.uint8]) +def test_bbox_format_dtype(format, support_integer_dtype, input_dtype): + tensor = torch.randint(0, 32, size=(5, 2), dtype=input_dtype) + if not input_dtype.is_floating_point and not support_integer_dtype: + with pytest.raises(ValueError, match="Rotated bounding boxes should be floating point tensors"): + tv_tensors.BoundingBoxes(tensor, format=format, canvas_size=(32, 32)) + else: + tv_tensors.BoundingBoxes(tensor, format=format, canvas_size=(32, 32)) + + def test_bbox_dim_error(): data_3d = [[[1, 2, 3, 4]]] with pytest.raises(ValueError, match="Expected a 1D or 2D tensor, got 3D"): @@ -406,3 +430,16 @@ def test_return_type_input(): tv_tensors.set_return_type("typo") tv_tensors.set_return_type("tensor") + + +def test_box_clamping_mode_default_and_error(): + assert ( + tv_tensors.BoundingBoxes([0.0, 0.0, 10.0, 10.0], format="XYXY", canvas_size=(100, 100)).clamping_mode == "soft" + ) + assert ( + tv_tensors.BoundingBoxes([0.0, 0.0, 10.0, 10.0, 0.0], format="XYWHR", canvas_size=(100, 100)).clamping_mode + == "soft" + ) + + with pytest.raises(ValueError, match="clamping_mode must be"): + tv_tensors.BoundingBoxes([0, 0, 10, 10], format="XYXY", canvas_size=(100, 100), clamping_mode="bad") diff --git a/test/test_utils.py b/test/test_utils.py index 000798a0609..8b6f357ce6e 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -177,6 +177,17 @@ def test_draw_rotated_boxes(): assert_equal(result, expected) +@pytest.mark.skipif(PILLOW_VERSION < (10, 1), reason="The reference image is only valid for PIL >= 10.1") +def test_draw_rotated_boxes_fill(): + img = torch.full((3, 500, 500), 255, dtype=torch.uint8) + colors = ["blue", "yellow", (0, 255, 0), "black"] + + result = utils.draw_bounding_boxes(img, rotated_boxes, colors=colors, fill=True) + path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "fakedata", "draw_rotated_boxes_fill.png") + expected = torch.as_tensor(np.array(Image.open(path))).permute(2, 0, 1) + assert_equal(result, expected) + + @pytest.mark.parametrize("fill", [True, False]) def 
test_draw_boxes_dtypes(fill):
    img_uint8 = torch.full((3, 100, 100), 255, dtype=torch.uint8)
diff --git a/torchvision/datasets/fakedata.py b/torchvision/datasets/fakedata.py
index 4f2343100a6..bcb413cdd32 100644
--- a/torchvision/datasets/fakedata.py
+++ b/torchvision/datasets/fakedata.py
@@ -11,7 +11,7 @@ class FakeData(VisionDataset):

    Args:
        size (int, optional): Size of the dataset. Default: 1000 images
-        image_size(tuple, optional): Size if the returned images. Default: (3, 224, 224)
+        image_size(tuple, optional): Size of the returned images. Default: (3, 224, 224)
        num_classes(int, optional): Number of classes in the dataset. Default: 10
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
diff --git a/torchvision/transforms/v2/__init__.py b/torchvision/transforms/v2/__init__.py
index 82a131d6fbc..408065dab94 100644
--- a/torchvision/transforms/v2/__init__.py
+++ b/torchvision/transforms/v2/__init__.py
@@ -41,7 +41,7 @@
    ScaleJitter,
    TenCrop,
)
-from ._meta import ClampBoundingBoxes, ClampKeyPoints, ConvertBoundingBoxFormat
+from ._meta import ClampBoundingBoxes, ClampKeyPoints, ConvertBoundingBoxFormat, SetClampingMode
from ._misc import (
    ConvertImageDtype,
    GaussianBlur,
diff --git a/torchvision/transforms/v2/_meta.py b/torchvision/transforms/v2/_meta.py
index 1e3d9be2f28..68395b468ba 100644
--- a/torchvision/transforms/v2/_meta.py
+++ b/torchvision/transforms/v2/_meta.py
@@ -2,6 +2,7 @@

from torchvision import tv_tensors
from torchvision.transforms.v2 import functional as F, Transform
+from torchvision.tv_tensors._bounding_boxes import CLAMPING_MODE_TYPE


class ConvertBoundingBoxFormat(Transform):
@@ -28,12 +29,19 @@ class ClampBoundingBoxes(Transform):

    The clamping is done according to the bounding boxes' ``canvas_size`` meta-data.

+    Args:
+        clamping_mode: TODOBB more docs. Default is "auto", which relies on the input box's ``clamping_mode`` attribute.
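+            An explicit "soft", "hard" or ``None`` overrides that attribute; with
+            ``None``, the box coordinates are left untouched.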
+ """ + def __init__(self, clamping_mode: Union[CLAMPING_MODE_TYPE, str] = "auto") -> None: + super().__init__() + self.clamping_mode = clamping_mode + _transformed_types = (tv_tensors.BoundingBoxes,) def transform(self, inpt: tv_tensors.BoundingBoxes, params: dict[str, Any]) -> tv_tensors.BoundingBoxes: - return F.clamp_bounding_boxes(inpt) # type: ignore[return-value] + return F.clamp_bounding_boxes(inpt, clamping_mode=self.clamping_mode) # type: ignore[return-value] class ClampKeyPoints(Transform): @@ -46,3 +54,21 @@ class ClampKeyPoints(Transform): def transform(self, inpt: tv_tensors.KeyPoints, params: dict[str, Any]) -> tv_tensors.KeyPoints: return F.clamp_keypoints(inpt) # type: ignore[return-value] + + +class SetClampingMode(Transform): + """TODOBB""" + + def __init__(self, clamping_mode: CLAMPING_MODE_TYPE) -> None: + super().__init__() + self.clamping_mode = clamping_mode + + if self.clamping_mode not in (None, "soft", "hard"): + raise ValueError(f"clamping_mode must be soft, hard or None, got {clamping_mode}") + + _transformed_types = (tv_tensors.BoundingBoxes,) + + def transform(self, inpt: tv_tensors.BoundingBoxes, params: dict[str, Any]) -> tv_tensors.BoundingBoxes: + out: tv_tensors.BoundingBoxes = inpt.clone() # type: ignore[assignment] + out.clamping_mode = self.clamping_mode + return out diff --git a/torchvision/transforms/v2/functional/_geometry.py b/torchvision/transforms/v2/functional/_geometry.py index 7e9766bdaf5..1c9ce3f6df0 100644 --- a/torchvision/transforms/v2/functional/_geometry.py +++ b/torchvision/transforms/v2/functional/_geometry.py @@ -20,6 +20,7 @@ pil_to_tensor, to_pil_image, ) +from torchvision.tv_tensors._bounding_boxes import CLAMPING_MODE_TYPE from torchvision.utils import _log_api_usage_once @@ -103,16 +104,10 @@ def horizontal_flip_bounding_boxes( bounding_boxes[:, 0::2].sub_(canvas_size[1]).neg_() bounding_boxes = bounding_boxes[:, [2, 3, 0, 1, 6, 7, 4, 5]] elif format == tv_tensors.BoundingBoxFormat.XYWHR: - - dtype = bounding_boxes.dtype - if not torch.is_floating_point(bounding_boxes): - # Casting to float to support cos and sin computations. - bounding_boxes = bounding_boxes.to(torch.float32) angle_rad = bounding_boxes[:, 4].mul(torch.pi).div(180) bounding_boxes[:, 0].add_(bounding_boxes[:, 2].mul(angle_rad.cos())).sub_(canvas_size[1]).neg_() bounding_boxes[:, 1].sub_(bounding_boxes[:, 2].mul(angle_rad.sin())) bounding_boxes[:, 4].neg_() - bounding_boxes = bounding_boxes.to(dtype) else: # format == tv_tensors.BoundingBoxFormat.CXCYWHR: bounding_boxes[:, 0].sub_(canvas_size[1]).neg_() bounding_boxes[:, 4].neg_() @@ -191,15 +186,10 @@ def vertical_flip_bounding_boxes( bounding_boxes[:, 1::2].sub_(canvas_size[0]).neg_() bounding_boxes = bounding_boxes[:, [2, 3, 0, 1, 6, 7, 4, 5]] elif format == tv_tensors.BoundingBoxFormat.XYWHR: - dtype = bounding_boxes.dtype - if not torch.is_floating_point(bounding_boxes): - # Casting to float to support cos and sin computations. 
- bounding_boxes = bounding_boxes.to(torch.float64) angle_rad = bounding_boxes[:, 4].mul(torch.pi).div(180) bounding_boxes[:, 1].sub_(bounding_boxes[:, 2].mul(angle_rad.sin())).sub_(canvas_size[0]).neg_() bounding_boxes[:, 0].add_(bounding_boxes[:, 2].mul(angle_rad.cos())) bounding_boxes[:, 4].neg_().add_(180) - bounding_boxes = bounding_boxes.to(dtype) else: # format == tv_tensors.BoundingBoxFormat.CXCYWHR: bounding_boxes[:, 1].sub_(canvas_size[0]).neg_() bounding_boxes[:, 4].neg_().add_(180) @@ -461,19 +451,6 @@ def _parallelogram_to_bounding_boxes(parallelogram: torch.Tensor) -> torch.Tenso torch.Tensor: Tensor of same shape as input containing the rectangle coordinates. The output maintains the same dtype as the input. """ - dtype = parallelogram.dtype - int_dtype = dtype in ( - torch.uint8, - torch.int8, - torch.int16, - torch.int32, - torch.int64, - ) - if int_dtype: - # Does not apply the transformation to `int` boxes as the rounding error - # will typically not ensure the resulting box has a rectangular shape. - return parallelogram.clone() - out_boxes = parallelogram.clone() # Calculate parallelogram diagonal vectors @@ -498,8 +475,8 @@ def _parallelogram_to_bounding_boxes(parallelogram: torch.Tensor) -> torch.Tenso diag24 * torch.abs(torch.sin(torch.atan2(dx42, dy42) - r_rad)), ) - delta_x = torch.round(w * cos).to(dtype) if int_dtype else w * cos - delta_y = torch.round(w * sin).to(dtype) if int_dtype else w * sin + delta_x = w * cos + delta_y = w * sin # Update coordinates to form a rectangle # Keeping the points (x1, y1) and (x3, y3) unchanged. out_boxes[..., 2] = torch.where(mask, parallelogram[..., 0] + delta_x, parallelogram[..., 2]) @@ -521,6 +498,7 @@ def resize_bounding_boxes( size: Optional[list[int]], max_size: Optional[int] = None, format: tv_tensors.BoundingBoxFormat = tv_tensors.BoundingBoxFormat.XYXY, + clamping_mode: CLAMPING_MODE_TYPE = "soft", ) -> tuple[torch.Tensor, tuple[int, int]]: # We set the default format as `tv_tensors.BoundingBoxFormat.XYXY` # to ensure backward compatibility. 
@@ -546,7 +524,10 @@ def resize_bounding_boxes( transformed_points = xyxyxyxy_boxes.mul(ratios) out_bboxes = _parallelogram_to_bounding_boxes(transformed_points) out_bboxes = clamp_bounding_boxes( - out_bboxes, format=tv_tensors.BoundingBoxFormat.XYXYXYXY, canvas_size=(new_height, new_width) + out_bboxes, + format=tv_tensors.BoundingBoxFormat.XYXYXYXY, + canvas_size=(new_height, new_width), + clamping_mode=clamping_mode, ) return ( convert_bounding_box_format( @@ -572,7 +553,12 @@ def _resize_bounding_boxes_dispatch( inpt: tv_tensors.BoundingBoxes, size: Optional[list[int]], max_size: Optional[int] = None, **kwargs: Any ) -> tv_tensors.BoundingBoxes: output, canvas_size = resize_bounding_boxes( - inpt.as_subclass(torch.Tensor), format=inpt.format, canvas_size=inpt.canvas_size, size=size, max_size=max_size + inpt.as_subclass(torch.Tensor), + format=inpt.format, + canvas_size=inpt.canvas_size, + size=size, + max_size=max_size, + clamping_mode=inpt.clamping_mode, ) return tv_tensors.wrap(output, like=inpt, canvas_size=canvas_size) @@ -1098,6 +1084,7 @@ def _affine_bounding_boxes_with_expand( shear: list[float], center: Optional[list[float]] = None, expand: bool = False, + clamping_mode: CLAMPING_MODE_TYPE = "soft", ) -> tuple[torch.Tensor, tuple[int, int]]: if bounding_boxes.numel() == 0: return bounding_boxes, canvas_size @@ -1176,14 +1163,14 @@ def _affine_bounding_boxes_with_expand( new_width, new_height = _compute_affine_output_size(affine_vector, width, height) canvas_size = (new_height, new_width) - out_bboxes = clamp_bounding_boxes(out_bboxes, format=intermediate_format, canvas_size=canvas_size) + out_bboxes = clamp_bounding_boxes( + out_bboxes, format=intermediate_format, canvas_size=canvas_size, clamping_mode=clamping_mode + ) out_bboxes = convert_bounding_box_format( out_bboxes, old_format=intermediate_format, new_format=format, inplace=True ).reshape(original_shape) if need_cast: - if dtype in (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64): - out_bboxes.round_() out_bboxes = out_bboxes.to(dtype) return out_bboxes, canvas_size @@ -1197,6 +1184,7 @@ def affine_bounding_boxes( scale: float, shear: list[float], center: Optional[list[float]] = None, + clamping_mode: CLAMPING_MODE_TYPE = "soft", ) -> torch.Tensor: out_box, _ = _affine_bounding_boxes_with_expand( bounding_boxes, @@ -1208,6 +1196,7 @@ def affine_bounding_boxes( shear=shear, center=center, expand=False, + clamping_mode=clamping_mode, ) return out_box @@ -1231,6 +1220,7 @@ def _affine_bounding_boxes_dispatch( scale=scale, shear=shear, center=center, + clamping_mode=inpt.clamping_mode, ) return tv_tensors.wrap(output, like=inpt) @@ -1432,6 +1422,7 @@ def rotate_bounding_boxes( angle: float, expand: bool = False, center: Optional[list[float]] = None, + clamping_mode: CLAMPING_MODE_TYPE = "soft", ) -> tuple[torch.Tensor, tuple[int, int]]: return _affine_bounding_boxes_with_expand( bounding_boxes, @@ -1443,6 +1434,7 @@ def rotate_bounding_boxes( shear=[0.0, 0.0], center=center, expand=expand, + clamping_mode=clamping_mode, ) @@ -1457,6 +1449,7 @@ def _rotate_bounding_boxes_dispatch( angle=angle, expand=expand, center=center, + clamping_mode=inpt.clamping_mode, ) return tv_tensors.wrap(output, like=inpt, canvas_size=canvas_size) @@ -1723,6 +1716,7 @@ def pad_bounding_boxes( canvas_size: tuple[int, int], padding: list[int], padding_mode: str = "constant", + clamping_mode: CLAMPING_MODE_TYPE = "soft", ) -> tuple[torch.Tensor, tuple[int, int]]: if padding_mode not in ["constant"]: # TODO: add support of other 
padding modes @@ -1745,7 +1739,10 @@ def pad_bounding_boxes( width += left + right canvas_size = (height, width) - return clamp_bounding_boxes(bounding_boxes, format=format, canvas_size=canvas_size), canvas_size + return ( + clamp_bounding_boxes(bounding_boxes, format=format, canvas_size=canvas_size, clamping_mode=clamping_mode), + canvas_size, + ) @_register_kernel_internal(pad, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False) @@ -1758,6 +1755,7 @@ def _pad_bounding_boxes_dispatch( canvas_size=inpt.canvas_size, padding=padding, padding_mode=padding_mode, + clamping_mode=inpt.clamping_mode, ) return tv_tensors.wrap(output, like=inpt, canvas_size=canvas_size) @@ -1836,6 +1834,7 @@ def crop_bounding_boxes( left: int, height: int, width: int, + clamping_mode: CLAMPING_MODE_TYPE = "soft", ) -> tuple[torch.Tensor, tuple[int, int]]: # Crop or implicit pad if left and/or top have negative values: @@ -1854,7 +1853,10 @@ def crop_bounding_boxes( if format == tv_tensors.BoundingBoxFormat.XYXYXYXY: bounding_boxes = _parallelogram_to_bounding_boxes(bounding_boxes) - return clamp_bounding_boxes(bounding_boxes, format=format, canvas_size=canvas_size), canvas_size + return ( + clamp_bounding_boxes(bounding_boxes, format=format, canvas_size=canvas_size, clamping_mode=clamping_mode), + canvas_size, + ) @_register_kernel_internal(crop, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False) @@ -1862,7 +1864,13 @@ def _crop_bounding_boxes_dispatch( inpt: tv_tensors.BoundingBoxes, top: int, left: int, height: int, width: int ) -> tv_tensors.BoundingBoxes: output, canvas_size = crop_bounding_boxes( - inpt.as_subclass(torch.Tensor), format=inpt.format, top=top, left=left, height=height, width=width + inpt.as_subclass(torch.Tensor), + format=inpt.format, + top=top, + left=left, + height=height, + width=width, + clamping_mode=inpt.clamping_mode, ) return tv_tensors.wrap(output, like=inpt, canvas_size=canvas_size) @@ -2066,6 +2074,7 @@ def perspective_bounding_boxes( startpoints: Optional[list[list[int]]], endpoints: Optional[list[list[int]]], coefficients: Optional[list[float]] = None, + clamping_mode: CLAMPING_MODE_TYPE = "soft", ) -> torch.Tensor: if bounding_boxes.numel() == 0: return bounding_boxes @@ -2130,7 +2139,9 @@ def perspective_bounding_boxes( out_bbox_mins, out_bbox_maxs = torch.aminmax(transformed_points, dim=1) out_bboxes = torch.cat([out_bbox_mins, out_bbox_maxs], dim=1) - out_bboxes = clamp_bounding_boxes(out_bboxes, format=intermediate_format, canvas_size=canvas_size) + out_bboxes = clamp_bounding_boxes( + out_bboxes, format=intermediate_format, canvas_size=canvas_size, clamping_mode=clamping_mode + ) out_bboxes = convert_bounding_box_format( out_bboxes, old_format=intermediate_format, new_format=format, inplace=True @@ -2185,6 +2196,7 @@ def _perspective_bounding_boxes_dispatch( startpoints=startpoints, endpoints=endpoints, coefficients=coefficients, + clamping_mode=inpt.clamping_mode, ) return tv_tensors.wrap(output, like=inpt) @@ -2377,6 +2389,7 @@ def elastic_bounding_boxes( format: tv_tensors.BoundingBoxFormat, canvas_size: tuple[int, int], displacement: torch.Tensor, + clamping_mode: CLAMPING_MODE_TYPE = "soft", ) -> torch.Tensor: expected_shape = (1, canvas_size[0], canvas_size[1], 2) if not isinstance(displacement, torch.Tensor): @@ -2397,11 +2410,11 @@ def elastic_bounding_boxes( original_shape = bounding_boxes.shape # TODO: first cast to float if bbox is int64 before convert_bounding_box_format - intermediate_format = tv_tensors.BoundingBoxFormat.XYXYXYXY if is_rotated else 
tv_tensors.BoundingBoxFormat.XYXY + intermediate_format = tv_tensors.BoundingBoxFormat.CXCYWHR if is_rotated else tv_tensors.BoundingBoxFormat.XYXY bounding_boxes = ( convert_bounding_box_format(bounding_boxes.clone(), old_format=format, new_format=intermediate_format) - ).reshape(-1, 8 if is_rotated else 4) + ).reshape(-1, 5 if is_rotated else 4) id_grid = _create_identity_grid(canvas_size, device=device, dtype=dtype) # We construct an approximation of inverse grid as inv_grid = id_grid - displacement @@ -2409,7 +2422,7 @@ def elastic_bounding_boxes( inv_grid = id_grid.sub_(displacement) # Get points from bboxes - points = bounding_boxes if is_rotated else bounding_boxes[:, [[0, 1], [2, 1], [2, 3], [0, 3]]] + points = bounding_boxes[:, :2] if is_rotated else bounding_boxes[:, [[0, 1], [2, 1], [2, 3], [0, 3]]] points = points.reshape(-1, 2) if points.is_floating_point(): points = points.ceil_() @@ -2421,17 +2434,15 @@ def elastic_bounding_boxes( transformed_points = inv_grid[0, index_y, index_x, :].add_(1).mul_(0.5 * t_size).sub_(0.5) if is_rotated: - transformed_points = transformed_points.reshape(-1, 8) - out_bboxes = _parallelogram_to_bounding_boxes(transformed_points).to(bounding_boxes.dtype) + transformed_points = transformed_points.reshape(-1, 2) + out_bboxes = torch.cat([transformed_points, bounding_boxes[:, 2:]], dim=1).to(bounding_boxes.dtype) else: transformed_points = transformed_points.reshape(-1, 4, 2) out_bbox_mins, out_bbox_maxs = torch.aminmax(transformed_points, dim=1) out_bboxes = torch.cat([out_bbox_mins, out_bbox_maxs], dim=1).to(bounding_boxes.dtype) out_bboxes = clamp_bounding_boxes( - out_bboxes, - format=intermediate_format, - canvas_size=canvas_size, + out_bboxes, format=intermediate_format, canvas_size=canvas_size, clamping_mode=clamping_mode ) return convert_bounding_box_format( @@ -2444,7 +2455,11 @@ def _elastic_bounding_boxes_dispatch( inpt: tv_tensors.BoundingBoxes, displacement: torch.Tensor, **kwargs ) -> tv_tensors.BoundingBoxes: output = elastic_bounding_boxes( - inpt.as_subclass(torch.Tensor), format=inpt.format, canvas_size=inpt.canvas_size, displacement=displacement + inpt.as_subclass(torch.Tensor), + format=inpt.format, + canvas_size=inpt.canvas_size, + displacement=displacement, + clamping_mode=inpt.clamping_mode, ) return tv_tensors.wrap(output, like=inpt) @@ -2581,11 +2596,18 @@ def center_crop_bounding_boxes( format: tv_tensors.BoundingBoxFormat, canvas_size: tuple[int, int], output_size: list[int], + clamping_mode: CLAMPING_MODE_TYPE = "soft", ) -> tuple[torch.Tensor, tuple[int, int]]: crop_height, crop_width = _center_crop_parse_output_size(output_size) crop_top, crop_left = _center_crop_compute_crop_anchor(crop_height, crop_width, *canvas_size) return crop_bounding_boxes( - bounding_boxes, format, top=crop_top, left=crop_left, height=crop_height, width=crop_width + bounding_boxes, + format, + top=crop_top, + left=crop_left, + height=crop_height, + width=crop_width, + clamping_mode=clamping_mode, ) @@ -2594,7 +2616,11 @@ def _center_crop_bounding_boxes_dispatch( inpt: tv_tensors.BoundingBoxes, output_size: list[int] ) -> tv_tensors.BoundingBoxes: output, canvas_size = center_crop_bounding_boxes( - inpt.as_subclass(torch.Tensor), format=inpt.format, canvas_size=inpt.canvas_size, output_size=output_size + inpt.as_subclass(torch.Tensor), + format=inpt.format, + canvas_size=inpt.canvas_size, + output_size=output_size, + clamping_mode=inpt.clamping_mode, ) return tv_tensors.wrap(output, like=inpt, canvas_size=canvas_size) @@ -2741,9 +2767,14 @@ def 
resized_crop_bounding_boxes( height: int, width: int, size: list[int], + clamping_mode: CLAMPING_MODE_TYPE = "soft", ) -> tuple[torch.Tensor, tuple[int, int]]: - bounding_boxes, canvas_size = crop_bounding_boxes(bounding_boxes, format, top, left, height, width) - return resize_bounding_boxes(bounding_boxes, format=format, canvas_size=canvas_size, size=size) + bounding_boxes, canvas_size = crop_bounding_boxes( + bounding_boxes, format, top, left, height, width, clamping_mode=clamping_mode + ) + return resize_bounding_boxes( + bounding_boxes, format=format, canvas_size=canvas_size, size=size, clamping_mode=clamping_mode + ) @_register_kernel_internal(resized_crop, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False) @@ -2751,7 +2782,14 @@ def _resized_crop_bounding_boxes_dispatch( inpt: tv_tensors.BoundingBoxes, top: int, left: int, height: int, width: int, size: list[int], **kwargs ) -> tv_tensors.BoundingBoxes: output, canvas_size = resized_crop_bounding_boxes( - inpt.as_subclass(torch.Tensor), format=inpt.format, top=top, left=left, height=height, width=width, size=size + inpt.as_subclass(torch.Tensor), + format=inpt.format, + top=top, + left=left, + height=height, + width=width, + size=size, + clamping_mode=inpt.clamping_mode, ) return tv_tensors.wrap(output, like=inpt, canvas_size=canvas_size) diff --git a/torchvision/transforms/v2/functional/_meta.py b/torchvision/transforms/v2/functional/_meta.py index 1729aa4bbaf..6256a288203 100644 --- a/torchvision/transforms/v2/functional/_meta.py +++ b/torchvision/transforms/v2/functional/_meta.py @@ -1,10 +1,11 @@ -from typing import Optional +from typing import Optional, Union import PIL.Image import torch from torchvision import tv_tensors from torchvision.transforms import _functional_pil as _FP from torchvision.tv_tensors import BoundingBoxFormat +from torchvision.tv_tensors._bounding_boxes import CLAMPING_MODE_TYPE from torchvision.utils import _log_api_usage_once @@ -370,8 +371,13 @@ def convert_bounding_box_format( def _clamp_bounding_boxes( - bounding_boxes: torch.Tensor, format: BoundingBoxFormat, canvas_size: tuple[int, int] + bounding_boxes: torch.Tensor, + format: BoundingBoxFormat, + canvas_size: tuple[int, int], + clamping_mode: CLAMPING_MODE_TYPE, ) -> torch.Tensor: + if clamping_mode is None: + return bounding_boxes.clone() # TODO: Investigate if it makes sense from a performance perspective to have an implementation for every # BoundingBoxFormat instead of converting back and forth in_dtype = bounding_boxes.dtype @@ -379,6 +385,7 @@ def _clamp_bounding_boxes( xyxy_boxes = convert_bounding_box_format( bounding_boxes, old_format=format, new_format=tv_tensors.BoundingBoxFormat.XYXY, inplace=True ) + # hard and soft modes are equivalent for non-rotated boxes xyxy_boxes[..., 0::2].clamp_(min=0, max=canvas_size[1]) xyxy_boxes[..., 1::2].clamp_(min=0, max=canvas_size[0]) out_boxes = convert_bounding_box_format( @@ -409,23 +416,113 @@ def _order_bounding_boxes_points( if indices is None: output_xyxyxyxy = bounding_boxes.reshape(-1, 8) x, y = output_xyxyxyxy[..., 0::2], output_xyxyxyxy[..., 1::2] - y_max = torch.max(y, dim=1, keepdim=True)[0] - _, x1 = ((y_max - y) / y_max + (x + 1) * 100).min(dim=1) + y_max = torch.max(y.abs(), dim=1, keepdim=True)[0] + x_max = torch.max(x.abs(), dim=1, keepdim=True)[0] + _, x1 = (y / y_max + (x / x_max) * 100).min(dim=1) indices = torch.ones_like(output_xyxyxyxy) indices[..., 0] = x1.mul(2) indices.cumsum_(1).remainder_(8) return indices, bounding_boxes.gather(1, indices.to(torch.int64)) -def _area(box: 
torch.Tensor) -> torch.Tensor: - x1, y1, x2, y2, x3, y3, x4, y4 = box.reshape(-1, 8).unbind(-1) - w = torch.sqrt((y2 - y1) ** 2 + (x2 - x1) ** 2) - h = torch.sqrt((y3 - y2) ** 2 + (x3 - x2) ** 2) - return w * h +def _get_slope_and_intercept(box: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """ + Calculate the slope and y-intercept of the lines defined by consecutive vertices in a bounding box. + This function computes the slope (a) and y-intercept (b) for each line segment in a bounding box, + where each line is defined by two consecutive vertices. + """ + x, y = box[..., ::2], box[..., 1::2] + a = y.diff(append=y[..., 0:1]) / x.diff(append=x[..., 0:1]) + b = y - a * x + return a, b + + +def _get_intersection_point(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + """ + Calculate the intersection point of two lines defined by their slopes and y-intercepts. + This function computes the intersection points between pairs of lines, where each line + is defined by the equation y = ax + b (slope and y-intercept form). + """ + batch_size = a.shape[0] + x = b.diff(prepend=b[..., 3:4]).neg() / a.diff(prepend=a[..., 3:4]) + y = a * x + b + return torch.cat((x.unsqueeze(-1), y.unsqueeze(-1)), dim=-1).view(batch_size, 8) + + +def _clamp_y_intercept( + bounding_boxes: torch.Tensor, + original_bounding_boxes: torch.Tensor, + canvas_size: tuple[int, int], + clamping_mode: CLAMPING_MODE_TYPE, +) -> torch.Tensor: + """ + Apply clamping to bounding box y-intercepts. This function handles two clamping strategies: + - Hard clamping: Ensures all box vertices stay within canvas boundaries, finding the largest + angle-preserving box enclosed within the original box and the image canvas. + - Soft clamping: Allows some vertices to extend beyond the canvas, finding the smallest + angle-preserving box that encloses the intersection of the original box and the image canvas. + + The function first calculates the slopes and y-intercepts of the lines forming the bounding box, + then applies various constraints to ensure the clamping conditions are respected. 
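+
+    Concretely, after hard clamping every vertex of the returned box lies inside the
+    canvas, while soft clamping may leave vertices outside the canvas as long as the
+    box still encloses the intersection of the original box with the canvas.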
+ """ + + # Calculate slopes and y-intercepts for bounding boxes + a, b = _get_slope_and_intercept(bounding_boxes) + a1, a2, a3, a4 = a.unbind(-1) + b1, b2, b3, b4 = b.unbind(-1) + + # Get y-intercepts from original bounding boxes + _, bm = _get_slope_and_intercept(original_bounding_boxes) + b1m, b2m, b3m, b4m = bm.unbind(-1) + + # Soft clamping: Clamp y-intercepts within canvas boundaries + b1 = b2.clamp(b1, b3).clamp(0, canvas_size[0]) + b4 = b3.clamp(b2, b4).clamp(0, canvas_size[0]) + + if clamping_mode is not None and clamping_mode == "hard": + # Hard clamping: Average b1 and b4, and adjust b2 and b3 for maximum area + b1 = b4 = (b1 + b4) / 2 + + # Calculate candidate values for b2 based on geometric constraints + b2_candidates = torch.stack( + [ + b1 * a2 / a1, # Constraint at y=0 + b3 * a2 / a3, # Constraint at y=0 + (a1 - a2) * canvas_size[1] + b1, # Constraint at x=canvas_width + (a3 - a2) * canvas_size[1] + b3, # Constraint at x=canvas_width + ], + dim=1, + ) + # Take maximum value that doesn't exceed original b2 + b2 = torch.max(b2_candidates, dim=1)[0].clamp(max=b2) + + # Calculate candidate values for b3 based on geometric constraints + b3_candidates = torch.stack( + [ + canvas_size[0] * (1 - a3 / a4) + b4 * a3 / a4, # Constraint at y=canvas_height + canvas_size[0] * (1 - a3 / a2) + b2 * a3 / a2, # Constraint at y=canvas_height + (a2 - a3) * canvas_size[1] + b2, # Constraint at x=canvas_width + (a4 - a3) * canvas_size[1] + b4, # Constraint at x=canvas_width + ], + dim=1, + ) + # Take minimum value that doesn't go below original b3 + b3 = torch.min(b3_candidates, dim=1)[0].clamp(min=b3) + + # Final clamping to ensure y-intercepts are within original box bounds + b1.clamp_(b1m, b3m) + b3.clamp_(b1m, b3m) + b2.clamp_(b2m, b4m) + b4.clamp_(b2m, b4m) + + return torch.stack([b1, b2, b3, b4], dim=-1) def _clamp_along_y_axis( bounding_boxes: torch.Tensor, + original_bounding_boxes: torch.Tensor, + canvas_size: tuple[int, int], + clamping_mode: CLAMPING_MODE_TYPE, ) -> torch.Tensor: """ Adjusts bounding boxes along the y-axis based on specific conditions. @@ -436,48 +533,47 @@ def _clamp_along_y_axis( Args: bounding_boxes (torch.Tensor): A tensor containing bounding box coordinates. + original_bounding_boxes (torch.Tensor): The original bounding boxes before any clamping is applied. + canvas_size (tuple[int, int]): The size of the canvas as (height, width). + clamping_mode (str, optional): The clamping strategy to use. Returns: torch.Tensor: The adjusted bounding boxes. 
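+
+    Note:
+        Three cases are handled, matching the conditions computed in the body: the
+        first vertex lies left of the canvas and the box is re-cut through the
+        intersection points of its edges, the box is nearly axis-aligned and its
+        x-coordinates are clamped at 0 directly, or the box is degenerate or lies
+        entirely outside the canvas and collapses to zeros.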
""" - original_dtype = bounding_boxes.dtype original_shape = bounding_boxes.shape - x1, y1, x2, y2, x3, y3, x4, y4 = bounding_boxes.reshape(-1, 8).unbind(-1) - a = (y2 - y1) / (x2 - x1) - b1 = y1 - a * x1 - b2 = y2 + x2 / a - b3 = y3 - a * x3 - b4 = y4 + x4 / a - b23 = (b2 - b3) / 2 * a / (1 + a**2) - z = torch.zeros_like(b1) - case_a = torch.cat([x.unsqueeze(1) for x in [z, b1, x2, y2, x3, y3, x3 - x2, y3 + b1 - y2]], dim=1) - case_b = torch.cat([x.unsqueeze(1) for x in [z, b4, x2 - x1, y2 - y1 + b4, x3, y3, x4, y4]], dim=1) - case_c = torch.cat( - [x.unsqueeze(1) for x in [z, (b2 + b3) / 2, b23, -b23 / a + b2, x3, y3, b23, b23 * a + b3]], dim=1 - ) - case_d = torch.zeros_like(case_c) - case_e = torch.cat([x.unsqueeze(1) for x in [x1.clamp(0), y1, x2.clamp(0), y2, x3, y3, x4, y4]], dim=1) - - cond_a = (x1 < 0).logical_and(x2 >= 0).logical_and(x3 >= 0).logical_and(x4 >= 0) - cond_a = cond_a.logical_and(_area(case_a) > _area(case_b)) - cond_a = cond_a.logical_or((x1 < 0).logical_and(x2 >= 0).logical_and(x3 >= 0).logical_and(x4 <= 0)) - cond_b = (x1 < 0).logical_and(x2 >= 0).logical_and(x3 >= 0).logical_and(x4 >= 0) - cond_b = cond_b.logical_and(_area(case_a) <= _area(case_b)) - cond_b = cond_b.logical_or((x1 < 0).logical_and(x2 <= 0).logical_and(x3 >= 0).logical_and(x4 >= 0)) - cond_c = (x1 < 0).logical_and(x2 <= 0).logical_and(x3 >= 0).logical_and(x4 <= 0) - cond_d = (x1 < 0).logical_and(x2 <= 0).logical_and(x3 <= 0).logical_and(x4 <= 0) - cond_e = x1.isclose(x2) - - for cond, case in zip( - [cond_a, cond_b, cond_c, cond_d, cond_e], - [case_a, case_b, case_c, case_d, case_e], + bounding_boxes = bounding_boxes.reshape(-1, 8) + original_bounding_boxes = original_bounding_boxes.reshape(-1, 8) + + # Calculate slopes (a) and y-intercepts (b) for all lines in the bounding boxes + a, b = _get_slope_and_intercept(bounding_boxes) + x1, y1, x2, y2, x3, y3, x4, y4 = bounding_boxes.unbind(-1) + b = _clamp_y_intercept(bounding_boxes, original_bounding_boxes, canvas_size, clamping_mode) + + case_a = _get_intersection_point(a, b) + case_b = bounding_boxes.clone() + case_b[..., 0].clamp_(0) # Clamp x1 to 0 + case_b[..., 6].clamp_(0) # Clamp x4 to 0 + case_c = torch.zeros_like(case_b) + + cond_a = (x1 < 0) & ~case_a.isnan().any(-1) # First point is outside left boundary + cond_b = y1.isclose(y2) | y3.isclose(y4) # First line is nearly vertical + cond_c = (x1 <= 0) & (x2 <= 0) & (x3 <= 0) & (x4 <= 0) # All points outside left boundary + cond_c = cond_c | y1.isclose(y4) | y2.isclose(y3) | (cond_b & x1.isclose(x2)) # First line is nearly horizontal + + for (cond, case) in zip( + [cond_a, cond_b, cond_c], + [case_a, case_b, case_c], ): bounding_boxes = torch.where(cond.unsqueeze(1).repeat(1, 8), case.reshape(-1, 8), bounding_boxes) - return bounding_boxes.to(original_dtype).reshape(original_shape) + + return bounding_boxes.reshape(original_shape) def _clamp_rotated_bounding_boxes( - bounding_boxes: torch.Tensor, format: BoundingBoxFormat, canvas_size: tuple[int, int] + bounding_boxes: torch.Tensor, + format: BoundingBoxFormat, + canvas_size: tuple[int, int], + clamping_mode: CLAMPING_MODE_TYPE, ) -> torch.Tensor: """ Clamp rotated bounding boxes to ensure they stay within the canvas boundaries. 
@@ -499,36 +595,38 @@ def _clamp_rotated_bounding_boxes(
     Returns:
         torch.Tensor: Clamped bounding boxes in the original format and shape
     """
+    if clamping_mode is None:
+        return bounding_boxes.clone()
     original_shape = bounding_boxes.shape
-    dtype = bounding_boxes.dtype
-    acceptable_dtypes = [torch.float64]  # Ensure consistency between CPU and GPU.
-    need_cast = dtype not in acceptable_dtypes
-    bounding_boxes = bounding_boxes.to(torch.float64) if need_cast else bounding_boxes.clone()
+    bounding_boxes = bounding_boxes.clone()

     out_boxes = (
         convert_bounding_box_format(
             bounding_boxes, old_format=format, new_format=tv_tensors.BoundingBoxFormat.XYXYXYXY, inplace=True
         )
     ).reshape(-1, 8)

+    original_boxes = out_boxes.clone()
     for _ in range(4):  # Iterate over the 4 vertices.
         indices, out_boxes = _order_bounding_boxes_points(out_boxes)
-        out_boxes = _clamp_along_y_axis(out_boxes)
+        _, original_boxes = _order_bounding_boxes_points(original_boxes, indices)
+        out_boxes = _clamp_along_y_axis(out_boxes, original_boxes, canvas_size, clamping_mode)
         _, out_boxes = _order_bounding_boxes_points(out_boxes, indices)
+        _, original_boxes = _order_bounding_boxes_points(original_boxes, indices)

         # rotate 90 degrees counterclockwise
         out_boxes[:, ::2], out_boxes[:, 1::2] = (
             out_boxes[:, 1::2].clone(),
             canvas_size[1] - out_boxes[:, ::2].clone(),
         )
+        original_boxes[:, ::2], original_boxes[:, 1::2] = (
+            original_boxes[:, 1::2].clone(),
+            canvas_size[1] - original_boxes[:, ::2].clone(),
+        )
         canvas_size = (canvas_size[1], canvas_size[0])

     out_boxes = convert_bounding_box_format(
         out_boxes, old_format=tv_tensors.BoundingBoxFormat.XYXYXYXY, new_format=format, inplace=True
     ).reshape(original_shape)

-    if need_cast:
-        if dtype in (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64):
-            out_boxes.round_()
-        out_boxes = out_boxes.to(dtype)
     return out_boxes


@@ -536,29 +634,43 @@ def clamp_bounding_boxes(
     inpt: torch.Tensor,
     format: Optional[BoundingBoxFormat] = None,
     canvas_size: Optional[tuple[int, int]] = None,
+    clamping_mode: Union[CLAMPING_MODE_TYPE, str] = "auto",
 ) -> torch.Tensor:
     """See :func:`~torchvision.transforms.v2.ClampBoundingBoxes` for details."""
     if not torch.jit.is_scripting():
         _log_api_usage_once(clamp_bounding_boxes)

+    if clamping_mode is not None and clamping_mode not in ("soft", "hard", "auto"):
+        raise ValueError(f"clamping_mode must be soft, hard, auto or None, got {clamping_mode}")
+
     if torch.jit.is_scripting() or is_pure_tensor(inpt):
-        if format is None or canvas_size is None:
-            raise ValueError("For pure tensor inputs, `format` and `canvas_size` have to be passed.")
+        if format is None or canvas_size is None or (clamping_mode is not None and clamping_mode == "auto"):
+            raise ValueError("For pure tensor inputs, `format`, `canvas_size` and `clamping_mode` have to be passed.")
         if tv_tensors.is_rotated_bounding_format(format):
-            return _clamp_rotated_bounding_boxes(inpt, format=format, canvas_size=canvas_size)
+            return _clamp_rotated_bounding_boxes(
+                inpt, format=format, canvas_size=canvas_size, clamping_mode=clamping_mode
+            )
         else:
-            return _clamp_bounding_boxes(inpt, format=format, canvas_size=canvas_size)
+            return _clamp_bounding_boxes(inpt, format=format, canvas_size=canvas_size, clamping_mode=clamping_mode)
     elif isinstance(inpt, tv_tensors.BoundingBoxes):
         if format is not None or canvas_size is not None:
             raise ValueError("For bounding box tv_tensor inputs, `format` and `canvas_size` must not be passed.")
+        if clamping_mode is not None and clamping_mode == "auto":
+            clamping_mode = inpt.clamping_mode
         if tv_tensors.is_rotated_bounding_format(inpt.format):
             output = _clamp_rotated_bounding_boxes(
-                inpt.as_subclass(torch.Tensor), format=inpt.format, canvas_size=inpt.canvas_size
+                inpt.as_subclass(torch.Tensor),
+                format=inpt.format,
+                canvas_size=inpt.canvas_size,
+                clamping_mode=clamping_mode,
             )
         else:
             output = _clamp_bounding_boxes(
-                inpt.as_subclass(torch.Tensor), format=inpt.format, canvas_size=inpt.canvas_size
+                inpt.as_subclass(torch.Tensor),
+                format=inpt.format,
+                canvas_size=inpt.canvas_size,
+                clamping_mode=clamping_mode,
             )
         return tv_tensors.wrap(output, like=inpt)
     else:
diff --git a/torchvision/tv_tensors/__init__.py b/torchvision/tv_tensors/__init__.py
index 40cd70d16cb..744e5241135 100644
--- a/torchvision/tv_tensors/__init__.py
+++ b/torchvision/tv_tensors/__init__.py
@@ -23,7 +23,7 @@ def wrap(wrappee, *, like, **kwargs):
         wrappee (Tensor): The tensor to convert.
         like (:class:`~torchvision.tv_tensors.TVTensor`): The reference.
             ``wrappee`` will be converted into the same subclass as ``like``.
-        kwargs: Can contain "format" and "canvas_size" if ``like`` is a :class:`~torchvision.tv_tensor.BoundingBoxes`.
+        kwargs: Can contain "format", "canvas_size" and "clamping_mode" if ``like`` is a :class:`~torchvision.tv_tensors.BoundingBoxes`.
             Ignored otherwise.
     """
     if isinstance(like, BoundingBoxes):
@@ -31,6 +31,7 @@ def wrap(wrappee, *, like, **kwargs):
             wrappee,
             format=kwargs.get("format", like.format),
             canvas_size=kwargs.get("canvas_size", like.canvas_size),
+            clamping_mode=kwargs.get("clamping_mode", like.clamping_mode),
         )
     elif isinstance(like, KeyPoints):
         return KeyPoints._wrap(wrappee, canvas_size=kwargs.get("canvas_size", like.canvas_size))
diff --git a/torchvision/tv_tensors/_bounding_boxes.py b/torchvision/tv_tensors/_bounding_boxes.py
index e661eaf8d73..e4963192671 100644
--- a/torchvision/tv_tensors/_bounding_boxes.py
+++ b/torchvision/tv_tensors/_bounding_boxes.py
@@ -3,7 +3,7 @@

 from collections.abc import Mapping, Sequence
 from enum import Enum
-from typing import Any
+from typing import Any, Optional

 import torch
 from torch.utils._pytree import tree_flatten
@@ -40,10 +40,23 @@ class BoundingBoxFormat(Enum):

 # TODO: Once torchscript supports Enums with staticmethod
 # this can be put into BoundingBoxFormat as staticmethod
-def is_rotated_bounding_format(format: BoundingBoxFormat) -> bool:
-    return (
-        format == BoundingBoxFormat.XYWHR or format == BoundingBoxFormat.CXCYWHR or format == BoundingBoxFormat.XYXYXYXY
-    )
+def is_rotated_bounding_format(format: BoundingBoxFormat | str) -> bool:
+    if isinstance(format, BoundingBoxFormat):
+        return (
+            format == BoundingBoxFormat.XYWHR
+            or format == BoundingBoxFormat.CXCYWHR
+            or format == BoundingBoxFormat.XYXYXYXY
+        )
+    elif isinstance(format, str):
+        return format in ("XYWHR", "CXCYWHR", "XYXYXYXY")
+    else:
+        raise ValueError(f"format should be str or BoundingBoxFormat, got {type(format)}")
+
+
+# This should ideally be a Literal, but torchscript fails.
+CLAMPING_MODE_TYPE = Optional[str]
+
+# TODOBB All docs. Add any new API to rst files, add tutorial[s].


 class BoundingBoxes(TVTensor):
@@ -62,6 +75,7 @@ class BoundingBoxes(TVTensor):
         data: Any data that can be turned into a tensor with :func:`torch.as_tensor`.
         format (BoundingBoxFormat, str): Format of the bounding box.
         canvas_size (two-tuple of ints): Height and width of the corresponding image or video.
+        clamping_mode: TODOBB
         dtype (torch.dtype, optional): Desired data type of the bounding box. If omitted, will be inferred from
             ``data``.
         device (torch.device, optional): Desired device of the bounding box. If omitted and ``data`` is a
@@ -72,19 +86,25 @@ class BoundingBoxes(TVTensor):

     format: BoundingBoxFormat
     canvas_size: tuple[int, int]
+    clamping_mode: CLAMPING_MODE_TYPE

     @classmethod
-    def _wrap(cls, tensor: torch.Tensor, *, format: BoundingBoxFormat | str, canvas_size: tuple[int, int], check_dims: bool = True) -> BoundingBoxes:  # type: ignore[override]
+    def _wrap(cls, tensor: torch.Tensor, *, format: BoundingBoxFormat | str, canvas_size: tuple[int, int], clamping_mode: CLAMPING_MODE_TYPE = "soft", check_dims: bool = True) -> BoundingBoxes:  # type: ignore[override]
         if check_dims:
             if tensor.ndim == 1:
                 tensor = tensor.unsqueeze(0)
             elif tensor.ndim != 2:
                 raise ValueError(f"Expected a 1D or 2D tensor, got {tensor.ndim}D")
+        if clamping_mode is not None and clamping_mode not in ("hard", "soft"):
+            raise ValueError(f"clamping_mode must be None, hard or soft, got {clamping_mode}.")
+
         if isinstance(format, str):
             format = BoundingBoxFormat[format.upper()]
+
         bounding_boxes = tensor.as_subclass(cls)
         bounding_boxes.format = format
         bounding_boxes.canvas_size = canvas_size
+        bounding_boxes.clamping_mode = clamping_mode
         return bounding_boxes

     def __new__(
@@ -93,12 +113,15 @@ def __new__(
         *,
         format: BoundingBoxFormat | str,
         canvas_size: tuple[int, int],
+        clamping_mode: CLAMPING_MODE_TYPE = "soft",
         dtype: torch.dtype | None = None,
         device: torch.device | str | int | None = None,
         requires_grad: bool | None = None,
     ) -> BoundingBoxes:
         tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
-        return cls._wrap(tensor, format=format, canvas_size=canvas_size)
+        if not torch.is_floating_point(tensor) and is_rotated_bounding_format(format):
+            raise ValueError(f"Rotated bounding boxes should be floating point tensors, got {tensor.dtype}.")
+        return cls._wrap(tensor, format=format, canvas_size=canvas_size, clamping_mode=clamping_mode)

     @classmethod
     def _wrap_output(
@@ -114,16 +137,25 @@ def _wrap_output(
         # something like some_xyxy_bbox + some_xywh_bbox; we don't guard against those cases.
flat_params, _ = tree_flatten(args + (tuple(kwargs.values()) if kwargs else ())) # type: ignore[operator] first_bbox_from_args = next(x for x in flat_params if isinstance(x, BoundingBoxes)) - format, canvas_size = first_bbox_from_args.format, first_bbox_from_args.canvas_size + format, canvas_size, clamping_mode = ( + first_bbox_from_args.format, + first_bbox_from_args.canvas_size, + first_bbox_from_args.clamping_mode, + ) if isinstance(output, torch.Tensor) and not isinstance(output, BoundingBoxes): - output = BoundingBoxes._wrap(output, format=format, canvas_size=canvas_size, check_dims=False) + output = BoundingBoxes._wrap( + output, format=format, canvas_size=canvas_size, clamping_mode=clamping_mode, check_dims=False + ) elif isinstance(output, (tuple, list)): # This branch exists for chunk() and unbind() output = type(output)( - BoundingBoxes._wrap(part, format=format, canvas_size=canvas_size, check_dims=False) for part in output + BoundingBoxes._wrap( + part, format=format, canvas_size=canvas_size, clamping_mode=clamping_mode, check_dims=False + ) + for part in output ) return output def __repr__(self, *, tensor_contents: Any = None) -> str: # type: ignore[override] - return self._make_repr(format=self.format, canvas_size=self.canvas_size) + return self._make_repr(format=self.format, canvas_size=self.canvas_size, clamping_mode=self.clamping_mode) diff --git a/torchvision/utils.py b/torchvision/utils.py index eec7d21293f..050cc51d893 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -136,7 +136,7 @@ def oriented_rectangle(self, xy, fill=None, outline=None, width=1): width=width, fill=outline, ) - self.rectangle(xy, fill=fill, outline=None, width=0) + self.polygon(xy, fill=fill, outline=None, width=0) def dashed_line(self, xy, fill=None, width=0, joint=None, dash_length=5, space_length=5): # Calculate the total length of the line
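
End-to-end, the changes above thread a clamping_mode attribute from the BoundingBoxes tv_tensor through clamp_bounding_boxes and tv_tensors.wrap. A rough usage sketch against a build with this patch applied (the box and canvas values are arbitrary, chosen only for illustration):

import torch
from torchvision import tv_tensors
from torchvision.transforms.v2.functional import clamp_bounding_boxes

# Rotated boxes must now be floating point; integer dtypes raise a ValueError.
boxes = tv_tensors.BoundingBoxes(
    torch.tensor([[10.0, 10.0, 40.0, 20.0, 45.0]]),  # cx, cy, w, h, angle (degrees)
    format="CXCYWHR",
    canvas_size=(32, 32),
    clamping_mode="soft",  # stored on the tv_tensor; "soft", "hard", or None
)

# The functional defaults to clamping_mode="auto", which resolves to the
# tv_tensor's own clamping_mode; the result keeps it via tv_tensors.wrap.
clamped = clamp_bounding_boxes(boxes)
print(clamped.clamping_mode)  # soft

# Passing None skips clamping entirely and returns an unclamped copy.
untouched = clamp_bounding_boxes(boxes, clamping_mode=None)

For pure tensor inputs, by contrast, "auto" cannot be resolved, so format, canvas_size and an explicit clamping_mode all have to be passed.
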