
Commit c9f50ed

Add human correction effort measure.

1 parent 90e0deb · commit c9f50ed

File tree

9 files changed: +224 -21 lines changed
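
The commit threads the new HCE metric through the individual-metric API, both recorder generations, the regression tests, and the package exports. As a minimal sketch of the resulting end-user call pattern, mirroring the usage added to examples/test_metrics.py (the image paths are placeholders):

```python
import cv2
import py_sod_metrics

# HCE estimates how much manual editing (polygon control points plus
# independent regions) a human would need to correct a prediction.
HCE = py_sod_metrics.HumanCorrectionEffortMeasure()

pred = cv2.imread("pred.png", cv2.IMREAD_GRAYSCALE)  # placeholder path
mask = cv2.imread("mask.png", cv2.IMREAD_GRAYSCALE)  # placeholder path
HCE.step(pred=pred, gt=mask)  # call once per image pair

print(HCE.get_results()["hce"])  # mean HCE over all recorded pairs
```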

examples/metric_recorder.py

Lines changed: 10 additions & 5 deletions

@@ -36,6 +36,7 @@ def _to_list_or_scalar(item):
     "em": py_sod_metrics.Emeasure,
     "sm": py_sod_metrics.Smeasure,
     "wfm": py_sod_metrics.WeightedFmeasure,
+    "hce": py_sod_metrics.HumanCorrectionEffortMeasure,
 }
 
 
@@ -45,13 +46,14 @@ def __init__(self):
         A class for computing various metrics.
         https://github.com/lartpang/Py-SOD-VOS-EvalToolkit/blob/81ce89da6813fdd3e22e3f20e3a09fe1e4a1a87c/utils/recorders/metric_recorder.py
 
-        Mainly used for the five metrics of the old implementation, i.e. mae/fm/sm/em/wfm. The V2 version is recommended.
+        Mainly used for the six metrics of the old implementation, i.e. mae/fm/sm/em/wfm/hce. The V2 version is recommended.
         """
         self.mae = INDIVADUAL_METRIC_MAPPING["mae"]()
         self.fm = INDIVADUAL_METRIC_MAPPING["fm"]()
         self.sm = INDIVADUAL_METRIC_MAPPING["sm"]()
         self.em = INDIVADUAL_METRIC_MAPPING["em"]()
         self.wfm = INDIVADUAL_METRIC_MAPPING["wfm"]()
+        self.hce = INDIVADUAL_METRIC_MAPPING["hce"]()
 
     def step(self, pre: np.ndarray, gt: np.ndarray):
         assert pre.shape == gt.shape
@@ -63,6 +65,7 @@ def step(self, pre: np.ndarray, gt: np.ndarray):
         self.fm.step(pre, gt)
         self.em.step(pre, gt)
         self.wfm.step(pre, gt)
+        self.hce.step(pre, gt)
 
     def get_results(self, num_bits: int = 3, return_ndarray: bool = False) -> dict:
         """
@@ -78,6 +81,7 @@ def get_results(self, num_bits: int = 3, return_ndarray: bool = False) -> dict:
         sm = self.sm.get_results()["sm"]
         em = self.em.get_results()["em"]
         mae = self.mae.get_results()["mae"]
+        hce = self.hce.get_results()["hce"]
 
         sequential_results = {
             "fm": np.flip(fm["curve"]),
@@ -95,6 +99,7 @@ def get_results(self, num_bits: int = 3, return_ndarray: bool = False) -> dict:
             "avgF": fm["curve"].mean(),
             "adpF": fm["adp"],
             "wFm": wfm,
+            "HCE": hce,
         }
         if num_bits is not None and isinstance(num_bits, int):
             numerical_results = {k: v.round(num_bits) for k, v in numerical_results.items()}
@@ -160,7 +165,7 @@ def get_results(self, num_bits: int = 3, return_ndarray: bool = False) -> dict:
 
 
 class GrayscaleMetricRecorderV2:
-    supported_metrics = ["mae", "em", "sm", "wfm"] + sorted(GRAYSCALE_METRIC_MAPPING.keys())
+    supported_metrics = ["mae", "em", "sm", "wfm", "hce"] + sorted(GRAYSCALE_METRIC_MAPPING.keys())
 
     def __init__(self, metric_names=("sm", "wfm", "mae", "fmeasure", "em")):
         """
@@ -209,7 +214,7 @@ def get_all_results(self, num_bits: int = 3, return_ndarray: bool = False) -> dict:
                 numerical_results[f"adp{_name}"] = adaptive_results
             else:
                 results = info[m_name]
-                if m_name in ("wfm", "sm", "mae"):
+                if m_name in ("wfm", "sm", "mae", "hce"):
                     numerical_results[m_name] = results
                 elif m_name in ("fm", "em"):
                     sequential_results[m_name] = np.flip(results["curve"])
@@ -235,7 +240,7 @@ def show(self, num_bits: int = 3, return_ndarray: bool = False) -> dict:
 
 
 class BinaryMetricRecorder:
-    supported_metrics = ["mae", "sm", "wfm"] + sorted(BINARY_METRIC_MAPPING.keys())
+    supported_metrics = ["mae", "sm", "wfm", "hce"] + sorted(BINARY_METRIC_MAPPING.keys())
 
     def __init__(self, metric_names=("bif1", "biprecision", "birecall", "biiou")):
         """
@@ -278,7 +283,7 @@ def get_all_results(self, num_bits: int = 3, return_ndarray: bool = False) -> dict:
                 numerical_results[_name] = binary_results
             else:
                 results = info[m_name]
-                if m_name in ("mae", "sm", "wfm"):
+                if m_name in ("mae", "sm", "wfm", "hce"):
                     numerical_results[m_name] = results
                 else:
                     raise NotImplementedError(m_name)
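
With "hce" registered in INDIVADUAL_METRIC_MAPPING, the V1 recorder now computes it alongside mae/fm/sm/em/wfm, and the V2/binary recorders accept it as a supported metric name. A hedged sketch of driving the V1 recorder shown above; the class name GrayscaleMetricRecorder is an assumption (only its methods appear in the hunks), and the toy arrays stand in for real predictions and masks:

```python
import numpy as np
from metric_recorder import GrayscaleMetricRecorder  # class name assumed

recorder = GrayscaleMetricRecorder()

# toy pair; real usage feeds a grayscale prediction and a binary GT mask
pred = (np.random.rand(64, 64) * 255).astype(np.uint8)
mask = ((np.random.rand(64, 64) > 0.5) * 255).astype(np.uint8)

recorder.step(pre=pred, gt=mask)
results = recorder.get_results(num_bits=3)
print(results)  # "HCE" now sits next to "wFm", "avgF", etc.
```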

examples/test_metrics.py

Lines changed: 13 additions & 4 deletions

@@ -50,6 +50,7 @@ def setUpClass(cls):
         SM = py_sod_metrics.Smeasure()
         EM = py_sod_metrics.Emeasure()
         MAE = py_sod_metrics.MAE()
+        HCE = py_sod_metrics.HumanCorrectionEffortMeasure()
         MSIOU = py_sod_metrics.MSIoU(with_dynamic=True, with_adaptive=True, with_binary=True)
 
         # fmt: off
@@ -169,6 +170,7 @@ def setUpClass(cls):
             SM.step(pred=pred, gt=mask)
             EM.step(pred=pred, gt=mask)
             MAE.step(pred=pred, gt=mask)
+            HCE.step(pred=pred, gt=mask)
             MSIOU.step(pred=pred, gt=mask)
             FMv2.step(pred=pred, gt=mask)
             SI_MAE.step(pred=pred, gt=mask)
@@ -179,13 +181,15 @@ def setUpClass(cls):
         sm = SM.get_results()["sm"]
         em = EM.get_results()["em"]
         mae = MAE.get_results()["mae"]
+        hce = HCE.get_results()["hce"]
         msiou = MSIOU.get_results()
         fmv2 = FMv2.get_results()
         si_mae = SI_MAE.get_results()["si_mae"]
         si_fmv2 = SI_FMv2.get_results()
 
         cls.curr_results = {
             "MAE": mae,
+            "HCE": hce,
             "Smeasure": sm,
             "wFmeasure": wfm,
             # "MSIOU": msiou,
@@ -258,18 +262,23 @@ def setUpClass(cls):
         print("Current results:")
         pprint(cls.curr_results)
         cls.default_results = default_results["v1_4_3"]  # 68
-        si_variant_results = default_results["v1_5_0"]  # 78+6
-        for res in [si_variant_results]:
-            if any([k in cls.default_results for k in res.keys()]):
+        for append_version in [
+            "v1_5_0",  # 78+6 Size-Invariant Variants
+            "v1_5_1",  # 1 HCE
+        ]:
+            if any([k in cls.default_results for k in default_results[append_version].keys()]):
                 raise ValueError("Some keys will be overwritten by the SI variant results.")
-        cls.default_results.update(res)
+            cls.default_results.update(default_results[append_version])
 
     def test_sm(self):
         self.assertEqual(self.curr_results["Smeasure"], self.default_results["Smeasure"])
 
     def test_wfm(self):
         self.assertEqual(self.curr_results["wFmeasure"], self.default_results["wFmeasure"])
 
+    def test_hce(self):
+        self.assertEqual(self.curr_results["HCE"], self.default_results["HCE"])
+
     def test_mae(self):
         self.assertEqual(self.curr_results["MAE"], self.default_results["MAE"])
 
examples/version_performance.json

Lines changed: 3 additions & 0 deletions

@@ -304,5 +304,8 @@
         "si_overall_auc_roc": 0.6192831970413093,
         "si_sample_auc_pr": 0.3036500410380263,
         "si_sample_auc_roc": 0.6192831970413093
+    },
+    "v1_5_1": {
+        "HCE": 73.66666666666667
     }
 }
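
The new "v1_5_1" block records the HCE regression baseline that test_hce asserts against. A small sketch, assuming the JSON is loaded the same way examples/test_metrics.py loads it, of how the version blocks are merged without overwriting earlier keys:

```python
import json

with open("examples/version_performance.json") as f:
    default_results = json.load(f)

# start from the v1_4_3 baseline, then append newer version blocks
merged = dict(default_results["v1_4_3"])
for append_version in ["v1_5_0", "v1_5_1"]:
    overlap = set(merged) & set(default_results[append_version])
    if overlap:  # each block may only introduce brand-new keys
        raise ValueError(f"Keys would be overwritten: {overlap}")
    merged.update(default_results[append_version])

print(merged["HCE"])  # 73.66666666666667
```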

py_sod_metrics/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -21,6 +21,7 @@
     MAE,
     Emeasure,
     Fmeasure,
+    HumanCorrectionEffortMeasure,
     Smeasure,
     WeightedFmeasure,
 )

py_sod_metrics/sod_metrics.py

Lines changed: 172 additions & 10 deletions

@@ -1,9 +1,10 @@
-# -*- coding: utf-8 -*-
 import warnings
 
+import cv2
 import numpy as np
 from scipy.ndimage import convolve
 from scipy.ndimage import distance_transform_edt as bwdist
+from skimage import measure, morphology
 
 from .utils import EPS, TYPE, get_adaptive_threshold, validate_and_normalize_input
 
@@ -380,9 +381,7 @@ def cal_em_with_threshold(self, pred: np.ndarray, gt: np.ndarray, threshold: float) -> float:
         results_parts = []
         for i, (part_numel, combination) in enumerate(zip(parts_numel, combinations)):
             align_matrix_value = (
-                2
-                * (combination[0] * combination[1])
-                / (combination[0] ** 2 + combination[1] ** 2 + EPS)
+                2 * (combination[0] * combination[1]) / (combination[0] ** 2 + combination[1] ** 2 + EPS)
             )
             enhanced_matrix_value = (align_matrix_value + 1) ** 2 / 4
             results_parts.append(enhanced_matrix_value * part_numel)
@@ -424,9 +423,7 @@ def cal_em_with_cumsumhistogram(self, pred: np.ndarray, gt: np.ndarray) -> np.ndarray:
         results_parts = np.empty(shape=(4, 256), dtype=np.float64)
         for i, (part_numel, combination) in enumerate(zip(parts_numel_w_thrs, combinations)):
             align_matrix_value = (
-                2
-                * (combination[0] * combination[1])
-                / (combination[0] ** 2 + combination[1] ** 2 + EPS)
+                2 * (combination[0] * combination[1]) / (combination[0] ** 2 + combination[1] ** 2 + EPS)
             )
             enhanced_matrix_value = (align_matrix_value + 1) ** 2 / 4
             results_parts[i] = enhanced_matrix_value * part_numel
@@ -435,9 +432,7 @@ def cal_em_with_cumsumhistogram(self, pred: np.ndarray, gt: np.ndarray) -> np.ndarray:
         em = enhanced_matrix_sum / (self.gt_size - 1 + EPS)
         return em
 
-    def generate_parts_numel_combinations(
-        self, fg_fg_numel, fg_bg_numel, pred_fg_numel, pred_bg_numel
-    ):
+    def generate_parts_numel_combinations(self, fg_fg_numel, fg_bg_numel, pred_fg_numel, pred_bg_numel):
         bg_fg_numel = self.gt_fg_numel - fg_fg_numel
         bg_bg_numel = pred_bg_numel - bg_fg_numel
 
@@ -571,3 +566,170 @@ def get_results(self) -> dict:
         """
         weighted_fm = np.mean(np.array(self.weighted_fms, dtype=TYPE))
         return dict(wfm=weighted_fm)
+
+
+class HumanCorrectionEffortMeasure(object):
+    def __init__(self, relax: int = 5, epsilon: float = 2.0):
+        """Human Correction Effort Measure for Dichotomous Image Segmentation.
+
+        ```
+        @inproceedings{HumanCorrectionEffortMeasure,
+            title = {Highly Accurate Dichotomous Image Segmentation},
+            author = {Xuebin Qin and Hang Dai and Xiaobin Hu and Deng-Ping Fan and Ling Shao and Luc Van Gool},
+            booktitle = ECCV,
+            year = {2022}
+        }
+        ```
+        """
+        self.hces = []
+        self.relax = relax
+        self.epsilon = epsilon
+        self.morphology_kernel = morphology.disk(1)
+
+    def step(self, pred: np.ndarray, gt: np.ndarray, normalize: bool = True):
+        """Accumulate the metric for the pair of pred and gt.
+
+        Args:
+            pred (np.ndarray): Prediction, gray scale image.
+            gt (np.ndarray): Ground truth, gray scale image.
+            normalize (bool, optional): Whether to normalize the input data. Defaults to True.
+        """
+        pred, gt = validate_and_normalize_input(pred, gt, normalize)
+
+        hce = self.cal_hce(pred, gt)
+        self.hces.append(hce)
+
+    def cal_hce(self, pred: np.ndarray, gt: np.ndarray) -> float:
+        gt_skeleton = morphology.skeletonize(gt).astype(bool)
+        pred = pred > 0.5
+
+        union = np.logical_or(gt, pred)
+        TP = np.logical_and(gt, pred)
+        FP = np.logical_xor(pred, TP)
+        FN = np.logical_xor(gt, TP)
+
+        # relax the union of gt and pred
+        eroded_union = cv2.erode(union.astype(np.uint8), self.morphology_kernel, iterations=self.relax)
+
+        # get the relaxed FP regions for computing the human efforts in correcting them
+        FP_ = np.logical_and(FP, eroded_union)  # get the relaxed FP
+        for i in range(0, self.relax):
+            FP_ = cv2.dilate(FP_.astype(np.uint8), self.morphology_kernel)
+            FP_ = np.logical_and(FP_.astype(bool), ~gt)
+        FP_ = np.logical_and(FP, FP_)
+
+        # get the relaxed FN regions for computing the human efforts in correcting them
+        FN_ = np.logical_and(FN, eroded_union)  # preserve the structural components of FN
+        # recover the FN, where pixels are not close to the TP borders
+        for i in range(0, self.relax):
+            FN_ = cv2.dilate(FN_.astype(np.uint8), self.morphology_kernel)
+            FN_ = np.logical_and(FN_, ~pred)
+        FN_ = np.logical_and(FN, FN_)
+        # preserve the structural components of FN
+        FN_ = np.logical_or(FN_, np.logical_xor(gt_skeleton, np.logical_and(TP, gt_skeleton)))
+
+        # Find exact polygon control points and independent regions.
+        # find contours from FP_ and control points and independent regions for human correction
+        contours_FP, _ = cv2.findContours(FP_.astype(np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)
+        condition_FP = np.logical_or(TP, FN_)
+        bdies_FP, indep_cnt_FP = self.filter_conditional_boundary(contours_FP, FP_, condition_FP)
+        # find contours from FN_ and control points and independent regions for human correction
+        contours_FN, _ = cv2.findContours(FN_.astype(np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)
+        condition_FN = 1 - np.logical_or(np.logical_or(TP, FP_), FN_)
+        bdies_FN, indep_cnt_FN = self.filter_conditional_boundary(contours_FN, FN_, condition_FN)
+
+        poly_FP_point_cnt = self.count_polygon_control_points(bdies_FP, epsilon=self.epsilon)
+        poly_FN_point_cnt = self.count_polygon_control_points(bdies_FN, epsilon=self.epsilon)
+        return poly_FP_point_cnt + indep_cnt_FP + poly_FN_point_cnt + indep_cnt_FN
+
+    def filter_conditional_boundary(self, contours: list, mask: np.ndarray, condition: np.ndarray):
+        """Filter boundary segments based on a given condition mask and compute
+        the number of independent connected regions that require human correction.
+
+        Args:
+            contours (List[np.ndarray]): List of boundary contours (OpenCV format).
+            mask (np.ndarray): Binary mask representing the region of interest.
+            condition (np.ndarray): Condition mask used to determine which
+                boundary points need to be considered.
+
+        Returns:
+            Tuple[List[np.ndarray], int]:
+                - boundaries (List[np.ndarray]): Filtered boundary segments that require correction.
+                - independent_count (int): Number of independent connected regions
+                  that need correction (i.e., human editing effort).
+        """
+        condition = cv2.dilate(condition.astype(np.uint8), self.morphology_kernel)
+
+        labels = measure.label(mask)  # find the connected regions
+        independent_flags = np.ones(labels.max() + 1, dtype=int)  # one flag per connected region
+        independent_flags[0] = 0  # 0 indicates the background region
+
+        boundaries = []
+        visited_map = np.zeros(condition.shape[:2], dtype=int)
+        for i in range(len(contours)):
+            temp_boundaries = []
+            temp_boundary = []
+            for pt in contours[i]:
+                row, col = pt[0, 1], pt[0, 0]
+
+                if condition[row, col].sum() == 0 or visited_map[row, col] != 0:
+                    if temp_boundary:  # the current point breaks the segment, so flush the collected boundary
+                        temp_boundaries.append(temp_boundary)
+                        temp_boundary = []
+                    continue
+
+                temp_boundary.append([col, row])
+                visited_map[row, col] = visited_map[row, col] + 1
+                independent_flags[labels[row, col]] = 0  # mark region as requiring correction
+
+            if temp_boundary:
+                temp_boundaries.append(temp_boundary)
+
+            # check if the first and the last boundaries are connected.
+            # if yes, invert the first boundary and attach it after the last boundary
+            if len(temp_boundaries) > 1:
+                first_x, first_y = temp_boundaries[0][0]
+                last_x, last_y = temp_boundaries[-1][-1]
+                if (
+                    (abs(first_x - last_x) == 1 and first_y == last_y)
+                    or (first_x == last_x and abs(first_y - last_y) == 1)
+                    or (abs(first_x - last_x) == 1 and abs(first_y - last_y) == 1)
+                ):
+                    temp_boundaries[-1].extend(temp_boundaries[0][::-1])
+                    del temp_boundaries[0]
+
+            for k in range(len(temp_boundaries)):
+                temp_boundaries[k] = np.array(temp_boundaries[k])[:, np.newaxis, :]
+
+            if temp_boundaries:
+                boundaries.extend(temp_boundaries)
+        return boundaries, independent_flags.sum()
+
+    def count_polygon_control_points(self, boundaries: list, epsilon: float = 1.0) -> int:
+        """Approximate each boundary using the Ramer-Douglas-Peucker (RDP) algorithm
+        and count the total number of control points of all approximated polygons.
+
+        Args:
+            boundaries (List[np.ndarray]): List of boundary contours.
+                Each contour is an Nx1x2 numpy array (OpenCV contour format).
+            epsilon (float): RDP approximation tolerance.
+                Larger values result in fewer control points.
+
+        Returns:
+            int: The total number of control points across all approximated polygons.
+
+        Reference:
+            https://en.wikipedia.org/wiki/Ramer-Douglas-Peucker_algorithm
+        """
+        num_points = 0
+        for boundary in boundaries:
+            approx_poly = cv2.approxPolyDP(boundary, epsilon, False)  # approximate boundary
+            num_points += len(approx_poly)  # count vertices (control points)
+        return num_points
+
+    def get_results(self) -> dict:
+        hce = np.mean(np.array(self.hces, dtype=TYPE))
+        return dict(hce=hce)
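
To make the accumulation pattern concrete, here is a self-contained toy run of the new class; the masks are invented purely for illustration (the prediction is the GT square shifted sideways, so relaxed FP and FN bands appear along the vertical edges):

```python
import numpy as np
import py_sod_metrics

# GT: a filled square; prediction: the same square shifted right,
# producing both false-positive and false-negative strips.
gt = np.zeros((64, 64), dtype=np.uint8)
gt[16:48, 16:48] = 255
pred = np.zeros((64, 64), dtype=np.uint8)
pred[16:48, 22:54] = 255

hce = py_sod_metrics.HumanCorrectionEffortMeasure(relax=5, epsilon=2.0)
hce.step(pred=pred, gt=gt)  # one HCE value appended per call
print(hce.get_results()["hce"])  # mean over all appended values
```

get_results averages over every step call, so per-dataset evaluation simply loops step over all image pairs before reading the mean.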
