Skip to content

Commit 90e0deb

Browse files
committed
Fix the bug for calculating auc-pr/roc.
1 parent 9180346 commit 90e0deb

File tree

2 files changed

+57
-9
lines changed

2 files changed

+57
-9
lines changed

examples/metric_recorder.py

Lines changed: 45 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,10 @@ def get_results(self, num_bits: int = 3, return_ndarray: bool = False) -> dict:
148148
SIZEINVARIANCE_METRIC_MAPPING = {
149149
"handler":{
150150
"si_fm": {"handler": py_sod_metrics.FmeasureHandler, "kwargs": dict(**sample_gray, beta=0.3)},
151+
"si_pre": {"handler": py_sod_metrics.PrecisionHandler, "kwargs": dict(with_adaptive=False, with_dynamic=True, sample_based=True)},
152+
"si_rec": {"handler": py_sod_metrics.RecallHandler, "kwargs": dict(with_adaptive=False, with_dynamic=True, sample_based=True)},
153+
"si_tpr": {"handler": py_sod_metrics.TPRHandler, "kwargs": dict(with_adaptive=False, with_dynamic=True, sample_based=True)},
154+
"si_fpr": {"handler": py_sod_metrics.FPRHandler, "kwargs": dict(with_adaptive=False, with_dynamic=True, sample_based=True)},
151155
},
152156
"si_fmeasurev2": py_sod_metrics.SizeInvarianceFmeasureV2,
153157
"si_mae": py_sod_metrics.SizeInvarianceMAE,
@@ -319,13 +323,49 @@ def step(self, pre: np.ndarray, gt: np.ndarray):
319323
for m_obj in self.metric_objs.values():
320324
m_obj.step(pre, gt)
321325

326+
def cal_auc(self, y, x):
327+
sorted_idx = np.argsort(x, axis=-1, kind="stable")
328+
x = np.take_along_axis(x, sorted_idx, axis=-1)
329+
y = np.take_along_axis(y, sorted_idx, axis=-1)
330+
return np.trapz(y, x, axis=-1)
331+
322332
def get_all_results(self, num_bits: int = 3, return_ndarray: bool = False) -> dict:
323333
sequential_results = {}
324334
numerical_results = {}
325335
for m_name, m_obj in self.metric_objs.items():
326336
info = m_obj.get_results()
327337

328338
if m_name == "si_fmeasurev2":
339+
# AUC-ROC
340+
if "si_tpr" in info and "si_fpr" in info:
341+
ys = info.pop("si_tpr")["dynamic"] # >=255,>=254,...>=1,>=0
342+
xs = info.pop("si_fpr")["dynamic"]
343+
if isinstance(ys, list) and isinstance(xs, list): # Nx[T'x256]
344+
auc_results = []
345+
for y, x in zip(ys, xs):
346+
# NOTE: before calculate the auc, we need to flip the y and x to corresponding to ascending thresholds
347+
# because the dynamic results from our metrics is based on the descending order of thresholds, i.e., >=255,>=254,...>=1,>=0
348+
y = np.flip(y, axis=-1)
349+
x = np.flip(x, axis=-1)
350+
auc_results.append(self.cal_auc(y, x).mean())
351+
numerical_results["si_sample_auc_roc"] = np.asarray(auc_results).mean()
352+
else: # 256
353+
numerical_results["si_overall_auc_roc"] = self.cal_auc(y=ys, x=xs).mean()
354+
355+
# AUC-PR
356+
if "si_pre" in info and "si_rec" in info:
357+
ys = info.pop("si_pre")["dynamic"] # >=255,>=254,...>=1,>=0
358+
xs = info.pop("si_rec")["dynamic"]
359+
if isinstance(ys, list) and isinstance(xs, list): # Nx[T'x256]
360+
auc_results = []
361+
for y, x in zip(ys, xs):
362+
y = np.flip(y, axis=-1)
363+
x = np.flip(x, axis=-1)
364+
auc_results.append(self.cal_auc(y, x).mean())
365+
numerical_results["si_sample_auc_pr"] = np.asarray(auc_results).mean()
366+
else: # 256
367+
numerical_results["si_overall_auc_pr"] = self.cal_auc(y=ys, x=xs).mean()
368+
329369
for _name, results in info.items():
330370
dynamic_results = results.get("dynamic")
331371
if dynamic_results is not None:
@@ -338,17 +378,17 @@ def get_all_results(self, num_bits: int = 3, return_ndarray: bool = False) -> di
338378
avg_results.append(s.mean(axis=-1).mean()) # 1
339379
seq_results.append(s.mean(axis=0)) # 256
340380
seq_results = np.mean(np.asarray(seq_results), axis=0)
341-
numerical_results[f"max{_name}"] = np.asarray(max_results).mean()
342-
numerical_results[f"avg{_name}"] = np.asarray(avg_results).mean()
381+
numerical_results[f"si_sample_max{_name}"] = np.asarray(max_results).mean()
382+
numerical_results[f"si_sample_avg{_name}"] = np.asarray(avg_results).mean()
343383
else: # 256
344384
seq_results = dynamic_results
345-
numerical_results[f"max{_name}"] = dynamic_results.max()
346-
numerical_results[f"avg{_name}"] = dynamic_results.mean()
385+
numerical_results[f"si_overall_max{_name}"] = dynamic_results.max()
386+
numerical_results[f"si_overall_avg{_name}"] = dynamic_results.mean()
347387
sequential_results[_name] = np.flip(seq_results)
348388

349389
adaptive_results = results.get("adaptive")
350390
if adaptive_results is not None:
351-
numerical_results[f"adp{_name}"] = adaptive_results
391+
numerical_results[f"si_sample_adp{_name}"] = adaptive_results
352392
else:
353393
results = info[m_name]
354394
if m_name in ("si_mae",):

examples/test_metrics.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,10 @@ def reduce_dynamic_results_for_max_avg(dynamic_results: list): # Nx[T'x256] ->
3434
def reduce_dynamic_results_for_auc(ys: list, xs: list): # Nx[T'x256] -> Nx[T'] -> N -> 1
3535
auc_results = []
3636
for y, x in zip(ys, xs):
37+
# NOTE: before calculate the auc, we need to flip the y and x to corresponding to ascending thresholds
38+
# because the dynamic results from our metrics is based on the descending order of thresholds, i.e., >=255,>=254,...>=1,>=0
39+
y = np.flip(y, -1)
40+
x = np.flip(x, -1)
3741
auc_results.append(cal_auc(y=y, x=x).mean())
3842
return np.asarray(auc_results).mean()
3943

@@ -230,14 +234,18 @@ def setUpClass(cls):
230234
pr_pre = fmv2["pre"]["dynamic"] # 256
231235
pr_rec = fmv2["rec"]["dynamic"] # 256
232236
roc_fpr = fmv2["fpr"]["dynamic"] # tpr is the same as recall
233-
cls.curr_results["auc_pr"] = cal_auc(y=pr_pre, x=pr_rec)
234-
cls.curr_results["auc_roc"] = cal_auc(y=pr_rec, x=roc_fpr)
237+
cls.curr_results["auc_pr"] = cal_auc(y=np.flip(pr_pre, -1), x=np.flip(pr_rec, -1))
238+
cls.curr_results["auc_roc"] = cal_auc(y=np.flip(pr_rec, -1), x=np.flip(roc_fpr, -1))
235239

236240
si_overall_pr_pre = si_fmv2["si_overall_pre"]["dynamic"] # 256
237241
si_overall_pr_rec = si_fmv2["si_overall_rec"]["dynamic"] # 256
238242
si_overall_roc_fpr = si_fmv2["si_overall_fpr"]["dynamic"] # 256
239-
cls.curr_results["si_overall_auc_pr"] = cal_auc(y=si_overall_pr_pre, x=si_overall_pr_rec)
240-
cls.curr_results["si_overall_auc_roc"] = cal_auc(y=si_overall_pr_rec, x=si_overall_roc_fpr)
243+
cls.curr_results["si_overall_auc_pr"] = cal_auc(
244+
y=np.flip(si_overall_pr_pre, -1), x=np.flip(si_overall_pr_rec, -1)
245+
)
246+
cls.curr_results["si_overall_auc_roc"] = cal_auc(
247+
y=np.flip(si_overall_pr_rec, -1), x=np.flip(si_overall_roc_fpr, -1)
248+
)
241249

242250
si_sample_pr_pre = si_fmv2["si_sample_pre"]["dynamic"] # Nx[T'x256]
243251
si_sample_pr_rec = si_fmv2["si_sample_rec"]["dynamic"] # Nx[T'x256]

0 commit comments

Comments
 (0)