diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py
index 3ce603c3ed2..b2bf79a3b0b 100644
--- a/test/test_transforms_v2.py
+++ b/test/test_transforms_v2.py
@@ -21,6 +21,7 @@
 import torchvision.transforms.v2 as transforms
 
 from common_utils import (
+    assert_close,
     assert_equal,
     cache,
     cpu_and_cuda,
@@ -41,7 +42,6 @@
 )
 
 from torch import nn
-from torch.testing import assert_close
 from torch.utils._pytree import tree_flatten, tree_map
 from torch.utils.data import DataLoader, default_collate
 from torchvision import tv_tensors
@@ -5449,7 +5449,18 @@ def test_kernel_image(self, dtype, device):
     def test_kernel_video(self):
         check_kernel(F.equalize_image, make_video())
 
-    @pytest.mark.parametrize("make_input", [make_image_tensor, make_image_pil, make_image, make_video])
+    @pytest.mark.parametrize(
+        "make_input",
+        [
+            make_image_tensor,
+            make_image_pil,
+            make_image,
+            make_video,
+            pytest.param(
+                make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available")
+            ),
+        ],
+    )
     def test_functional(self, make_input):
         check_functional(F.equalize, make_input())
 
@@ -5460,33 +5471,68 @@ def test_functional(self, make_input):
             (F._color._equalize_image_pil, PIL.Image.Image),
             (F.equalize_image, tv_tensors.Image),
             (F.equalize_video, tv_tensors.Video),
+            pytest.param(
+                F._color._equalize_image_cvcuda,
+                None,
+                marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available"),
+            ),
         ],
     )
     def test_functional_signature(self, kernel, input_type):
+        if kernel is F._color._equalize_image_cvcuda:
+            input_type = _import_cvcuda().Tensor
         check_functional_kernel_signature_match(F.equalize, kernel=kernel, input_type=input_type)
 
     @pytest.mark.parametrize(
         "make_input",
-        [make_image_tensor, make_image_pil, make_image, make_video],
+        [
+            make_image_tensor,
+            make_image_pil,
+            make_image,
+            make_video,
+            pytest.param(
+                make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available")
+            ),
+        ],
     )
     def test_transform(self, make_input):
        check_transform(transforms.RandomEqualize(p=1), make_input())
 
     @pytest.mark.parametrize(("low", "high"), [(0, 64), (64, 192), (192, 256), (0, 1), (127, 128), (255, 256)])
+    @pytest.mark.parametrize(
+        "tensor_type",
+        [
+            torch.Tensor,
+            pytest.param(
+                "cvcuda.Tensor", marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available")
+            ),
+        ],
+    )
     @pytest.mark.parametrize("fn", [F.equalize, transform_cls_to_functional(transforms.RandomEqualize, p=1)])
-    def test_image_correctness(self, low, high, fn):
+    def test_image_correctness(self, low, high, tensor_type, fn):
         # We are not using the default `make_image` here since that uniformly samples the values over the whole value
         # range. Since the whole point of F.equalize is to transform an arbitrary distribution of values into a uniform
         # one over the full range, the information gain is low if we already provide something really close to the
         # expected value.
-        image = tv_tensors.Image(
-            torch.testing.make_tensor((3, 117, 253), dtype=torch.uint8, device="cpu", low=low, high=high)
-        )
+        shape = (3, 117, 253)
+        if tensor_type == "cvcuda.Tensor":
+            shape = (1, *shape)
+        image = tv_tensors.Image(torch.testing.make_tensor(shape, dtype=torch.uint8, device="cpu", low=low, high=high))
+
+        if tensor_type == "cvcuda.Tensor":
+            image = F.to_cvcuda_tensor(image)
 
         actual = fn(image)
+
+        if tensor_type == "cvcuda.Tensor":
+            image = F.cvcuda_to_tensor(image)[0].cpu()
+
         expected = F.to_image(F.equalize(F.to_pil_image(image)))
 
-        assert_equal(actual, expected)
+        if tensor_type == "cvcuda.Tensor":
+            assert_close(actual, expected, rtol=1e-10, atol=1)
+        else:
+            assert_equal(actual, expected)
 
 
 class TestUniformTemporalSubsample:
diff --git a/torchvision/transforms/v2/_color.py b/torchvision/transforms/v2/_color.py
index bf4ae55d232..da34538a249 100644
--- a/torchvision/transforms/v2/_color.py
+++ b/torchvision/transforms/v2/_color.py
@@ -5,6 +5,7 @@
 import torch
 from torchvision import transforms as _transforms
 from torchvision.transforms.v2 import functional as F, Transform
+from torchvision.transforms.v2.functional._utils import _is_cvcuda_tensor
 
 from ._transform import _RandomApplyTransform
 from ._utils import query_chw
@@ -265,6 +266,8 @@ class RandomEqualize(_RandomApplyTransform):
 
     _v1_transform_cls = _transforms.RandomEqualize
 
+    _transformed_types = _RandomApplyTransform._transformed_types + (_is_cvcuda_tensor,)
+
     def transform(self, inpt: Any, params: dict[str, Any]) -> Any:
         return self._call_kernel(F.equalize, inpt)
 
diff --git a/torchvision/transforms/v2/functional/_color.py b/torchvision/transforms/v2/functional/_color.py
index be254c0d63a..eea80a3af7e 100644
--- a/torchvision/transforms/v2/functional/_color.py
+++ b/torchvision/transforms/v2/functional/_color.py
@@ -1,3 +1,5 @@
+from typing import TYPE_CHECKING
+
 import PIL.Image
 import torch
 from torch.nn.functional import conv2d
@@ -9,7 +11,13 @@
 
 from ._misc import _num_value_bits, to_dtype_image
 from ._type_conversion import pil_to_tensor, to_pil_image
-from ._utils import _get_kernel, _register_kernel_internal
+from ._utils import _get_kernel, _import_cvcuda, _is_cvcuda_available, _register_kernel_internal
+
+
+CVCUDA_AVAILABLE = _is_cvcuda_available()
+
+if TYPE_CHECKING:
+    import cvcuda  # type: ignore[import-not-found]
 
 
 def rgb_to_grayscale(inpt: torch.Tensor, num_output_channels: int = 1) -> torch.Tensor:
@@ -649,6 +657,17 @@ def equalize_video(video: torch.Tensor) -> torch.Tensor:
     return equalize_image(video)
 
 
+def _equalize_image_cvcuda(
+    image: "cvcuda.Tensor",
+) -> "cvcuda.Tensor":
+    cvcuda = _import_cvcuda()
+    return cvcuda.histogrameq(image, dtype=image.dtype)
+
+
+if CVCUDA_AVAILABLE:
+    _register_kernel_internal(equalize, _import_cvcuda().Tensor)(_equalize_image_cvcuda)
+
+
 def invert(inpt: torch.Tensor) -> torch.Tensor:
     """See :func:`~torchvision.transforms.v2.RandomInvert`."""
     if torch.jit.is_scripting():
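
Note for reviewers: a minimal usage sketch of the dispatch path this patch enables, mirroring the round-trip exercised in test_image_correctness above. It assumes the cvcuda package is installed; the snippet is illustrative and not part of the patch.

    import torch
    from torchvision import tv_tensors
    from torchvision.transforms.v2 import functional as F
    from torchvision.transforms.v2.functional._utils import _is_cvcuda_available

    if _is_cvcuda_available():
        # Batched NCHW uint8 image, converted with the same helpers the tests use.
        image = tv_tensors.Image(torch.randint(0, 256, (1, 3, 117, 253), dtype=torch.uint8))
        cv_image = F.to_cvcuda_tensor(image)
        # F.equalize dispatches on input type: a cvcuda.Tensor is routed to the
        # newly registered _equalize_image_cvcuda kernel (cvcuda.histogrameq),
        # while torch.Tensor / PIL inputs keep their existing paths.
        equalized = F.equalize(cv_image)
        out = F.cvcuda_to_tensor(equalized)[0].cpu()  # back to a torch.Tensor, shape (3, 117, 253)

As the test tolerances (atol=1) suggest, CV-CUDA's histogram equalization may differ from the PIL reference by at most one intensity level, hence assert_close instead of assert_equal on that path.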