Skip to content

Commit c2e8c78

Browse files
author
pytorchbot
committed
2025-07-01 nightly release (fb3926e)
1 parent 5fd7644 commit c2e8c78

File tree

8 files changed

+334
-206
lines changed

8 files changed

+334
-206
lines changed

setup.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,8 @@ def get_dist(pkgname):
111111
]
112112

113113
# Excluding 8.3.* because of https://github.com/pytorch/vision/issues/4934
114-
pillow_ver = " >= 5.3.0, !=8.3.*"
114+
# TODO remove <11.3 bound and address corresponding deprecation warnings
115+
pillow_ver = " >= 5.3.0, !=8.3.*, <11.3"
115116
pillow_req = "pillow-simd" if get_dist("pillow-simd") is not None else "pillow"
116117
requirements.append(pillow_req + pillow_ver)
117118

test/common_utils.py

Lines changed: 10 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from torch.testing._comparison import BooleanPair, NonePair, not_close_error_metas, NumberPair, TensorLikePair
2222
from torchvision import io, tv_tensors
2323
from torchvision.transforms._functional_tensor import _max_value as get_max_value
24-
from torchvision.transforms.v2.functional import clamp_bounding_boxes, to_image, to_pil_image
24+
from torchvision.transforms.v2.functional import to_image, to_pil_image
2525

2626

2727
IN_OSS_CI = any(os.getenv(var) == "true" for var in ["CIRCLECI", "GITHUB_ACTIONS"])
@@ -410,7 +410,7 @@ def make_bounding_boxes(
410410
canvas_size=DEFAULT_SIZE,
411411
*,
412412
format=tv_tensors.BoundingBoxFormat.XYXY,
413-
clamping_mode="hard", # TODOBB
413+
clamping_mode="soft",
414414
num_boxes=1,
415415
dtype=None,
416416
device="cpu",
@@ -424,13 +424,6 @@ def sample_position(values, max_value):
424424
format = tv_tensors.BoundingBoxFormat[format]
425425

426426
dtype = dtype or torch.float32
427-
int_dtype = dtype in (
428-
torch.uint8,
429-
torch.int8,
430-
torch.int16,
431-
torch.int32,
432-
torch.int64,
433-
)
434427

435428
h, w = (torch.randint(1, s, (num_boxes,)) for s in canvas_size)
436429
y = sample_position(h, canvas_size[0])
@@ -457,33 +450,18 @@ def sample_position(values, max_value):
457450
elif format is tv_tensors.BoundingBoxFormat.XYXYXYXY:
458451
r_rad = r * torch.pi / 180.0
459452
cos, sin = torch.cos(r_rad), torch.sin(r_rad)
460-
x1 = torch.round(x) if int_dtype else x
461-
y1 = torch.round(y) if int_dtype else y
462-
x2 = torch.round(x1 + w * cos) if int_dtype else x1 + w * cos
463-
y2 = torch.round(y1 - w * sin) if int_dtype else y1 - w * sin
464-
x3 = torch.round(x2 + h * sin) if int_dtype else x2 + h * sin
465-
y3 = torch.round(y2 + h * cos) if int_dtype else y2 + h * cos
466-
x4 = torch.round(x1 + h * sin) if int_dtype else x1 + h * sin
467-
y4 = torch.round(y1 + h * cos) if int_dtype else y1 + h * cos
453+
x1 = x
454+
y1 = y
455+
x2 = x1 + w * cos
456+
y2 = y1 - w * sin
457+
x3 = x2 + h * sin
458+
y3 = y2 + h * cos
459+
x4 = x1 + h * sin
460+
y4 = y1 + h * cos
468461
parts = (x1, y1, x2, y2, x3, y3, x4, y4)
469462
else:
470463
raise ValueError(f"Format {format} is not supported")
471464
out_boxes = torch.stack(parts, dim=-1).to(dtype=dtype, device=device)
472-
if tv_tensors.is_rotated_bounding_format(format):
473-
# The rotated bounding boxes are not guaranteed to be within the canvas by design,
474-
# so we apply clamping. We also shrink the canvas by a buffer of 4 (2 on each side) to avoid
475-
# numerical issues during the testing
476-
buffer = 4
477-
out_boxes = clamp_bounding_boxes(
478-
out_boxes,
479-
format=format,
480-
canvas_size=(canvas_size[0] - buffer, canvas_size[1] - buffer),
481-
clamping_mode=clamping_mode,
482-
)
483-
if format is tv_tensors.BoundingBoxFormat.XYWHR or format is tv_tensors.BoundingBoxFormat.CXCYWHR:
484-
out_boxes[:, :2] += buffer // 2
485-
elif format is tv_tensors.BoundingBoxFormat.XYXYXYXY:
486-
out_boxes[:, :] += buffer // 2
487465
return tv_tensors.BoundingBoxes(out_boxes, format=format, canvas_size=canvas_size, clamping_mode=clamping_mode)
488466

489467

test/test_transforms_v2.py

Lines changed: 78 additions & 60 deletions
Large diffs are not rendered by default.

test/test_tv_tensors.py

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,15 +69,39 @@ def test_bbox_instance(data, format):
6969
)
7070
@pytest.mark.parametrize("scripted", (False, True))
7171
def test_bbox_format(format, is_rotated_expected, scripted):
72-
if isinstance(format, str):
73-
format = tv_tensors.BoundingBoxFormat[(format.upper())]
74-
7572
fn = tv_tensors.is_rotated_bounding_format
7673
if scripted:
7774
fn = torch.jit.script(fn)
7875
assert fn(format) == is_rotated_expected
7976

8077

78+
@pytest.mark.parametrize(
79+
"format, support_integer_dtype",
80+
[
81+
("XYXY", True),
82+
("XYWH", True),
83+
("CXCYWH", True),
84+
("XYXYXYXY", False),
85+
("XYWHR", False),
86+
("CXCYWHR", False),
87+
(tv_tensors.BoundingBoxFormat.XYXY, True),
88+
(tv_tensors.BoundingBoxFormat.XYWH, True),
89+
(tv_tensors.BoundingBoxFormat.CXCYWH, True),
90+
(tv_tensors.BoundingBoxFormat.XYXYXYXY, False),
91+
(tv_tensors.BoundingBoxFormat.XYWHR, False),
92+
(tv_tensors.BoundingBoxFormat.CXCYWHR, False),
93+
],
94+
)
95+
@pytest.mark.parametrize("input_dtype", [torch.float32, torch.float64, torch.uint8])
96+
def test_bbox_format_dtype(format, support_integer_dtype, input_dtype):
97+
tensor = torch.randint(0, 32, size=(5, 2), dtype=input_dtype)
98+
if not input_dtype.is_floating_point and not support_integer_dtype:
99+
with pytest.raises(ValueError, match="Rotated bounding boxes should be floating point tensors"):
100+
tv_tensors.BoundingBoxes(tensor, format=format, canvas_size=(32, 32))
101+
else:
102+
tv_tensors.BoundingBoxes(tensor, format=format, canvas_size=(32, 32))
103+
104+
81105
def test_bbox_dim_error():
82106
data_3d = [[[1, 2, 3, 4]]]
83107
with pytest.raises(ValueError, match="Expected a 1D or 2D tensor, got 3D"):
@@ -406,3 +430,13 @@ def test_return_type_input():
406430
tv_tensors.set_return_type("typo")
407431

408432
tv_tensors.set_return_type("tensor")
433+
434+
435+
def test_box_clamping_mode_default():
436+
assert (
437+
tv_tensors.BoundingBoxes([0.0, 0.0, 10.0, 10.0], format="XYXY", canvas_size=(100, 100)).clamping_mode == "soft"
438+
)
439+
assert (
440+
tv_tensors.BoundingBoxes([0.0, 0.0, 10.0, 10.0, 0.0], format="XYWHR", canvas_size=(100, 100)).clamping_mode
441+
== "soft"
442+
)

torchvision/transforms/v2/_meta.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, Optional, Union
1+
from typing import Any, Union
22

33
from torchvision import tv_tensors
44
from torchvision.transforms.v2 import functional as F, Transform
@@ -34,7 +34,9 @@ class ClampBoundingBoxes(Transform):
3434
3535
"""
3636

37-
def __init__(self, clamping_mode: Optional[CLAMPING_MODE_TYPE] = None) -> None:
37+
# TODOBB consider "auto" to be a Literal, make sure TorchScript is still happy
38+
# TODOBB validate clamping_mode
39+
def __init__(self, clamping_mode: Union[CLAMPING_MODE_TYPE, str] = "auto") -> None:
3840
super().__init__()
3941
self.clamping_mode = clamping_mode
4042

0 commit comments

Comments
 (0)