Skip to content

Commit 0e27f2a

Browse files
fix group of keypoints behavior
1 parent 5de3208 commit 0e27f2a

File tree

1 file changed

+3
-5
lines changed
  • torchvision/transforms/v2/functional

1 file changed

+3
-5
lines changed

torchvision/transforms/v2/functional/_misc.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -460,8 +460,7 @@ def sanitize_keypoints(
460460
461461
Keypoints can be passed as a set of individual keypoints or as a set of objects
462462
(e.g., polygons or polygonal chains) consisting of a fixed number of keypoints of shape ``[..., 2]``.
463-
When multiple groups of keypoints are passed
464-
(i.e., an at least 3-dimensional tensor with first dimension greater than 1),
463+
When groups of keypoints are passed (i.e., an at least 3-dimensional tensor),
465464
this transform will only remove entire groups, not individual keypoints within a group.
466465
467466
Args:
@@ -510,10 +509,9 @@ def _get_sanitize_keypoints_mask(
510509

511510
h, w = canvas_size
512511

513-
original_shape = key_points.shape[:-1]
514-
x, y = key_points[..., 0].squeeze(dim=0), key_points[..., 1].squeeze(dim=0)
512+
x, y = key_points[..., 0], key_points[..., 1]
515513
valid = (x >= 0) & (x < w) & (y >= 0) & (y < h)
516514

517-
valid = valid.flatten(start_dim=1).all(dim=1) if valid.ndim > 1 else valid.reshape(original_shape)
515+
valid = valid.flatten(start_dim=1).all(dim=1) if valid.ndim > 1 else valid
518516

519517
return valid

0 commit comments

Comments
 (0)