Skip to content

Commit 0e6702d

Browse files
jebastin-nadarfacebook-github-bot
authored andcommitted
Added diou and ciou losses for bbox regression
Summary: Resolves #1085 references and credits: https://github.com/Zzh-tju/DIoU-pytorch-detectron/blob/master/lib/utils/net.py https://github.com/facebookresearch/fvcore/blob/master/fvcore/nn/giou_loss.py Pull Request resolved: #3481 Test Plan: sandcastle Differential Revision: D31463505 Pulled By: ppwwyyxx fbshipit-source-id: 04d815f979b589b7e3b3e5d9c55eab318762efe8
1 parent 7f8f29d commit 0e6702d

File tree

7 files changed

+258
-8
lines changed

7 files changed

+258
-8
lines changed

detectron2/config/defaults.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,7 @@
219219
_C.MODEL.RPN.BATCH_SIZE_PER_IMAGE = 256
220220
# Target fraction of foreground (positive) examples per RPN minibatch
221221
_C.MODEL.RPN.POSITIVE_FRACTION = 0.5
222-
# Options are: "smooth_l1", "giou"
222+
# Options are: "smooth_l1", "giou", "diou", "ciou"
223223
_C.MODEL.RPN.BBOX_REG_LOSS_TYPE = "smooth_l1"
224224
_C.MODEL.RPN.BBOX_REG_LOSS_WEIGHT = 1.0
225225
# Weights on (dx, dy, dw, dh) for normalizing RPN anchor regression targets
@@ -290,7 +290,7 @@
290290
# C4 don't use head name option
291291
# Options for non-C4 models: FastRCNNConvFCHead,
292292
_C.MODEL.ROI_BOX_HEAD.NAME = ""
293-
# Options are: "smooth_l1", "giou"
293+
# Options are: "smooth_l1", "giou", "diou", "ciou"
294294
_C.MODEL.ROI_BOX_HEAD.BBOX_REG_LOSS_TYPE = "smooth_l1"
295295
# The final scaling coefficient on the box regression loss, used to balance the magnitude of its
296296
# gradients with other losses in the model. See also `MODEL.ROI_KEYPOINT_HEAD.LOSS_WEIGHT`.
@@ -455,7 +455,7 @@
455455
_C.MODEL.RETINANET.FOCAL_LOSS_GAMMA = 2.0
456456
_C.MODEL.RETINANET.FOCAL_LOSS_ALPHA = 0.25
457457
_C.MODEL.RETINANET.SMOOTH_L1_LOSS_BETA = 0.1
458-
# Options are: "smooth_l1", "giou"
458+
# Options are: "smooth_l1", "giou", "diou", "ciou"
459459
_C.MODEL.RETINANET.BBOX_REG_LOSS_TYPE = "smooth_l1"
460460

461461
# One of BN, SyncBN, FrozenBN, GN

detectron2/layers/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,5 +19,6 @@
1919
)
2020
from .blocks import CNNBlockBase, DepthwiseSeparableConv2d
2121
from .aspp import ASPP
22+
from .losses import ciou_loss, diou_loss
2223

2324
__all__ = [k for k in globals().keys() if not k.startswith("_")]

detectron2/layers/losses.py

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
import math
2+
import torch
3+
4+
5+
def diou_loss(
6+
boxes1: torch.Tensor,
7+
boxes2: torch.Tensor,
8+
reduction: str = "none",
9+
eps: float = 1e-7,
10+
) -> torch.Tensor:
11+
"""
12+
Distance Intersection over Union Loss (Zhaohui Zheng et. al)
13+
https://arxiv.org/abs/1911.08287
14+
Args:
15+
boxes1, boxes2 (Tensor): box locations in XYXY format, shape (N, 4) or (4,).
16+
reduction: 'none' | 'mean' | 'sum'
17+
'none': No reduction will be applied to the output.
18+
'mean': The output will be averaged.
19+
'sum': The output will be summed.
20+
eps (float): small number to prevent division by zero
21+
"""
22+
23+
x1, y1, x2, y2 = boxes1.unbind(dim=-1)
24+
x1g, y1g, x2g, y2g = boxes2.unbind(dim=-1)
25+
26+
# TODO: use torch._assert_async() when pytorch 1.8 support is dropped
27+
assert (x2 >= x1).all(), "bad box: x1 larger than x2"
28+
assert (y2 >= y1).all(), "bad box: y1 larger than y2"
29+
30+
# Intersection keypoints
31+
xkis1 = torch.max(x1, x1g)
32+
ykis1 = torch.max(y1, y1g)
33+
xkis2 = torch.min(x2, x2g)
34+
ykis2 = torch.min(y2, y2g)
35+
36+
intsct = torch.zeros_like(x1)
37+
mask = (ykis2 > ykis1) & (xkis2 > xkis1)
38+
intsct[mask] = (xkis2[mask] - xkis1[mask]) * (ykis2[mask] - ykis1[mask])
39+
union = (x2 - x1) * (y2 - y1) + (x2g - x1g) * (y2g - y1g) - intsct + eps
40+
iou = intsct / union
41+
42+
# smallest enclosing box
43+
xc1 = torch.min(x1, x1g)
44+
yc1 = torch.min(y1, y1g)
45+
xc2 = torch.max(x2, x2g)
46+
yc2 = torch.max(y2, y2g)
47+
diag_len = ((xc2 - xc1) ** 2) + ((yc2 - yc1) ** 2) + eps
48+
49+
# centers of boxes
50+
x_p = (x2 + x1) / 2
51+
y_p = (y2 + y1) / 2
52+
x_g = (x1g + x2g) / 2
53+
y_g = (y1g + y2g) / 2
54+
distance = ((x_p - x_g) ** 2) + ((y_p - y_g) ** 2)
55+
56+
# Eqn. (7)
57+
loss = 1 - iou + (distance / diag_len)
58+
if reduction == "mean":
59+
loss = loss.mean() if loss.numel() > 0 else 0.0 * loss.sum()
60+
elif reduction == "sum":
61+
loss = loss.sum()
62+
63+
return loss
64+
65+
66+
def ciou_loss(
67+
boxes1: torch.Tensor,
68+
boxes2: torch.Tensor,
69+
reduction: str = "none",
70+
eps: float = 1e-7,
71+
) -> torch.Tensor:
72+
"""
73+
Complete Intersection over Union Loss (Zhaohui Zheng et. al)
74+
https://arxiv.org/abs/1911.08287
75+
Args:
76+
boxes1, boxes2 (Tensor): box locations in XYXY format, shape (N, 4) or (4,).
77+
reduction: 'none' | 'mean' | 'sum'
78+
'none': No reduction will be applied to the output.
79+
'mean': The output will be averaged.
80+
'sum': The output will be summed.
81+
eps (float): small number to prevent division by zero
82+
"""
83+
84+
x1, y1, x2, y2 = boxes1.unbind(dim=-1)
85+
x1g, y1g, x2g, y2g = boxes2.unbind(dim=-1)
86+
87+
# TODO: use torch._assert_async() when pytorch 1.8 support is dropped
88+
assert (x2 >= x1).all(), "bad box: x1 larger than x2"
89+
assert (y2 >= y1).all(), "bad box: y1 larger than y2"
90+
91+
# Intersection keypoints
92+
xkis1 = torch.max(x1, x1g)
93+
ykis1 = torch.max(y1, y1g)
94+
xkis2 = torch.min(x2, x2g)
95+
ykis2 = torch.min(y2, y2g)
96+
97+
intsct = torch.zeros_like(x1)
98+
mask = (ykis2 > ykis1) & (xkis2 > xkis1)
99+
intsct[mask] = (xkis2[mask] - xkis1[mask]) * (ykis2[mask] - ykis1[mask])
100+
union = (x2 - x1) * (y2 - y1) + (x2g - x1g) * (y2g - y1g) - intsct + eps
101+
iou = intsct / union
102+
103+
# smallest enclosing box
104+
xc1 = torch.min(x1, x1g)
105+
yc1 = torch.min(y1, y1g)
106+
xc2 = torch.max(x2, x2g)
107+
yc2 = torch.max(y2, y2g)
108+
diag_len = ((xc2 - xc1) ** 2) + ((yc2 - yc1) ** 2) + eps
109+
110+
# centers of boxes
111+
x_p = (x2 + x1) / 2
112+
y_p = (y2 + y1) / 2
113+
x_g = (x1g + x2g) / 2
114+
y_g = (y1g + y2g) / 2
115+
distance = ((x_p - x_g) ** 2) + ((y_p - y_g) ** 2)
116+
117+
# width and height of boxes
118+
w_pred = x2 - x1
119+
h_pred = y2 - y1
120+
w_gt = x2g - x1g
121+
h_gt = y2g - y1g
122+
v = (4 / (math.pi ** 2)) * torch.pow((torch.atan(w_gt / h_gt) - torch.atan(w_pred / h_pred)), 2)
123+
with torch.no_grad():
124+
alpha = v / (1 - iou + v + eps)
125+
126+
# Eqn. (10)
127+
loss = 1 - iou + (distance / diag_len) + alpha * v
128+
if reduction == "mean":
129+
loss = loss.mean() if loss.numel() > 0 else 0.0 * loss.sum()
130+
elif reduction == "sum":
131+
loss = loss.sum()
132+
133+
return loss

detectron2/modeling/box_regression.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import torch
55
from fvcore.nn import giou_loss, smooth_l1_loss
66

7-
from detectron2.layers import cat
7+
from detectron2.layers import cat, ciou_loss, diou_loss
88
from detectron2.structures import Boxes
99

1010
# Value for clamping large dw and dh predictions. The heuristic is that we clamp
@@ -315,7 +315,8 @@ def _dense_box_regression_loss(
315315
pred_anchor_deltas: #lvl predictions, each is (N, HixWixA, 4)
316316
gt_boxes: N ground truth boxes, each has shape (R, 4) (R = sum(Hi * Wi * A))
317317
fg_mask: the foreground boolean mask of shape (N, R) to compute loss on
318-
box_reg_loss_type (str): Loss type to use. Supported losses: "smooth_l1", "giou".
318+
box_reg_loss_type (str): Loss type to use. Supported losses: "smooth_l1", "giou",
319+
"diou", "ciou".
319320
smooth_l1_beta (float): beta parameter for the smooth L1 regression loss. Default to
320321
use L1 loss. Only used when `box_reg_loss_type` is "smooth_l1"
321322
"""
@@ -336,6 +337,20 @@ def _dense_box_regression_loss(
336337
loss_box_reg = giou_loss(
337338
torch.stack(pred_boxes)[fg_mask], torch.stack(gt_boxes)[fg_mask], reduction="sum"
338339
)
340+
elif box_reg_loss_type == "diou":
341+
pred_boxes = [
342+
box2box_transform.apply_deltas(k, anchors) for k in cat(pred_anchor_deltas, dim=1)
343+
]
344+
loss_box_reg = diou_loss(
345+
torch.stack(pred_boxes)[fg_mask], torch.stack(gt_boxes)[fg_mask], reduction="sum"
346+
)
347+
elif box_reg_loss_type == "ciou":
348+
pred_boxes = [
349+
box2box_transform.apply_deltas(k, anchors) for k in cat(pred_anchor_deltas, dim=1)
350+
]
351+
loss_box_reg = ciou_loss(
352+
torch.stack(pred_boxes)[fg_mask], torch.stack(gt_boxes)[fg_mask], reduction="sum"
353+
)
339354
else:
340355
raise ValueError(f"Invalid dense box regression loss type '{box_reg_loss_type}'")
341356
return loss_box_reg

detectron2/modeling/meta_arch/retinanet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ def __init__(
8888
focal_loss_alpha (float): focal_loss_alpha
8989
focal_loss_gamma (float): focal_loss_gamma
9090
smooth_l1_beta (float): smooth_l1_beta
91-
box_reg_loss_type (str): Options are "smooth_l1", "giou"
91+
box_reg_loss_type (str): Options are "smooth_l1", "giou", "diou", "ciou"
9292
9393
# Inference parameters:
9494
test_score_thresh (float): Inference cls score threshold, only anchors with

detectron2/modeling/roi_heads/fast_rcnn.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,15 @@
77
from torch.nn import functional as F
88

99
from detectron2.config import configurable
10-
from detectron2.layers import ShapeSpec, batched_nms, cat, cross_entropy, nonzero_tuple
10+
from detectron2.layers import (
11+
ShapeSpec,
12+
batched_nms,
13+
cat,
14+
ciou_loss,
15+
cross_entropy,
16+
diou_loss,
17+
nonzero_tuple,
18+
)
1119
from detectron2.modeling.box_regression import Box2BoxTransform
1220
from detectron2.structures import Boxes, Instances
1321
from detectron2.utils.events import get_event_storage
@@ -207,7 +215,8 @@ def __init__(
207215
cls_agnostic_bbox_reg (bool): whether to use class agnostic for bbox regression
208216
smooth_l1_beta (float): transition point from L1 to L2 loss. Only used if
209217
`box_reg_loss_type` is "smooth_l1"
210-
box_reg_loss_type (str): Box regression loss type. One of: "smooth_l1", "giou"
218+
box_reg_loss_type (str): Box regression loss type. One of: "smooth_l1", "giou",
219+
"diou", "ciou"
211220
loss_weight (float|dict): weights to use for losses. Can be single float for weighting
212221
all losses, or a dict of individual weightings. Valid dict keys are:
213222
* "loss_cls": applied to classification loss
@@ -347,6 +356,16 @@ def box_reg_loss(self, proposal_boxes, gt_boxes, pred_deltas, gt_classes):
347356
fg_pred_deltas, proposal_boxes[fg_inds]
348357
)
349358
loss_box_reg = giou_loss(fg_pred_boxes, gt_boxes[fg_inds], reduction="sum")
359+
elif self.box_reg_loss_type == "diou":
360+
fg_pred_boxes = self.box2box_transform.apply_deltas(
361+
fg_pred_deltas, proposal_boxes[fg_inds]
362+
)
363+
loss_box_reg = diou_loss(fg_pred_boxes, gt_boxes[fg_inds], reduction="sum")
364+
elif self.box_reg_loss_type == "ciou":
365+
fg_pred_boxes = self.box2box_transform.apply_deltas(
366+
fg_pred_deltas, proposal_boxes[fg_inds]
367+
)
368+
loss_box_reg = ciou_loss(fg_pred_boxes, gt_boxes[fg_inds], reduction="sum")
350369
else:
351370
raise ValueError(f"Invalid bbox reg loss type '{self.box_reg_loss_type}'")
352371
# The reg loss is normalized using the total number of regions (R), not the number

tests/layers/test_losses.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
# Copyright (c) Facebook, Inc. and its affiliates.
2+
import numpy as np
3+
import unittest
4+
import torch
5+
6+
from detectron2.layers import ciou_loss, diou_loss
7+
8+
9+
class TestLosses(unittest.TestCase):
10+
def test_diou_loss(self):
11+
"""
12+
loss = 1 - iou + d/c
13+
where,
14+
d = (distance between centers of the 2 boxes)^2
15+
c = (diagonal length of the smallest enclosing box covering the 2 boxes)^2
16+
"""
17+
# Identical boxes should have loss of 0
18+
box = torch.tensor([-1, -1, 1, 1], dtype=torch.float32)
19+
loss = diou_loss(box, box)
20+
self.assertTrue(np.allclose(loss, [0.0]))
21+
22+
# Half size box inside other box
23+
# iou = 0.5, d = 0.25, c = 8
24+
box2 = torch.tensor([0, -1, 1, 1], dtype=torch.float32)
25+
loss = diou_loss(box, box2)
26+
self.assertTrue(np.allclose(loss, [0.53125]))
27+
28+
# Two diagonally adjacent boxes
29+
# iou = 0, d = 2, c = 8
30+
box3 = torch.tensor([0, 0, 1, 1], dtype=torch.float32)
31+
box4 = torch.tensor([1, 1, 2, 2], dtype=torch.float32)
32+
loss = diou_loss(box3, box4)
33+
self.assertTrue(np.allclose(loss, [1.25]))
34+
35+
# Test batched loss and reductions
36+
box1s = torch.stack([box, box3], dim=0)
37+
box2s = torch.stack([box2, box4], dim=0)
38+
39+
loss = diou_loss(box1s, box2s, reduction="sum")
40+
self.assertTrue(np.allclose(loss, [1.78125]))
41+
42+
loss = diou_loss(box1s, box2s, reduction="mean")
43+
self.assertTrue(np.allclose(loss, [0.890625]))
44+
45+
def test_ciou_loss(self):
46+
"""
47+
loss = 1 - iou + d/c + alpha*v
48+
where,
49+
d = (distance between centers of the 2 boxes)^2
50+
c = (diagonal length of the smallest enclosing box covering the 2 boxes)^2
51+
v = (4/pi^2) * (arctan(box1_w/box1_h) - arctan(box2_w/box2_h))^2
52+
alpha = v/(1 - iou + v)
53+
"""
54+
# Identical boxes should have loss of 0
55+
box = torch.tensor([-1, -1, 1, 1], dtype=torch.float32)
56+
loss = ciou_loss(box, box)
57+
self.assertTrue(np.allclose(loss, [0.0]))
58+
59+
# Half size box inside other box
60+
# iou = 0.5, d = 0.25, c = 8
61+
# v = (4/pi^2) * (arctan(1) - arctan(0.5))^2 = 0.042
62+
# alpha = 0.0775
63+
box2 = torch.tensor([0, -1, 1, 1], dtype=torch.float32)
64+
loss = ciou_loss(box, box2)
65+
self.assertTrue(np.allclose(loss, [0.5345]))
66+
67+
# Two diagonally adjacent boxes
68+
# iou = 0, d = 2, c = 8, v = 0, alpha = 0
69+
box3 = torch.tensor([0, 0, 1, 1], dtype=torch.float32)
70+
box4 = torch.tensor([1, 1, 2, 2], dtype=torch.float32)
71+
loss = ciou_loss(box3, box4)
72+
self.assertTrue(np.allclose(loss, [1.25]))
73+
74+
# Test batched loss and reductions
75+
box1s = torch.stack([box, box3], dim=0)
76+
box2s = torch.stack([box2, box4], dim=0)
77+
78+
loss = ciou_loss(box1s, box2s, reduction="sum")
79+
self.assertTrue(np.allclose(loss, [1.7845]))
80+
81+
loss = ciou_loss(box1s, box2s, reduction="mean")
82+
self.assertTrue(np.allclose(loss, [0.89225]))

0 commit comments

Comments
 (0)