Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
161 changes: 160 additions & 1 deletion monai/metrics/meandice.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,22 @@

from __future__ import annotations

import warnings

import numpy as np
import torch

from monai.metrics.utils import do_metric_reduction
from monai.utils import MetricReduction, deprecated_arg
from monai.utils.module import optional_import

from .metric import CumulativeIterationMetric

scipy_ndimage, has_scipy_ndimage = optional_import("scipy.ndimage")
cupy, has_cupy = optional_import("cupy")
cupy_ndimage, has_cupy_ndimage = optional_import("cupyx.scipy.ndimage")


__all__ = ["DiceMetric", "compute_dice", "DiceHelper"]


Expand All @@ -41,6 +50,9 @@ class DiceMetric(CumulativeIterationMetric):
image size they can get overwhelmed by the signal from the background. This assumes the shape of both prediction
and ground truth is BCHW[D].

The ``per_component`` parameter can be set to `True` to compute the Dice metric per connected component in the ground truth
, and then average. This requires binary segmentations with 2 channels (background + foreground) as input.

The typical execution steps of this metric class follows :py:class:`monai.metrics.metric.Cumulative`.

Further information can be found in the official
Expand Down Expand Up @@ -95,6 +107,9 @@ class DiceMetric(CumulativeIterationMetric):
If `True`, use "label_{index}" as the key corresponding to C channels; if ``include_background`` is True,
the index begins at "0", otherwise at "1". It can also take a list of label names.
The outcome will then be returned as a dictionary.
per_component: whether to compute the Dice metric per connected component. If `True`, the metric will be
computed for each connected component in the ground truth, and then averaged. This requires binary
segmentations with 2 channels (background + foreground) as input. This is a more fine-grained computation.

"""

Expand All @@ -106,6 +121,7 @@ def __init__(
ignore_empty: bool = True,
num_classes: int | None = None,
return_with_label: bool | list[str] = False,
per_component: bool = False,
) -> None:
super().__init__()
self.include_background = include_background
Expand All @@ -114,13 +130,15 @@ def __init__(
self.ignore_empty = ignore_empty
self.num_classes = num_classes
self.return_with_label = return_with_label
self.per_component = per_component
self.dice_helper = DiceHelper(
include_background=self.include_background,
reduction=MetricReduction.NONE,
get_not_nans=False,
apply_argmax=False,
ignore_empty=self.ignore_empty,
num_classes=self.num_classes,
per_component=self.per_component,
)

def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: # type: ignore[override]
Expand Down Expand Up @@ -175,6 +193,7 @@ def compute_dice(
include_background: bool = True,
ignore_empty: bool = True,
num_classes: int | None = None,
per_component: bool = False,
) -> torch.Tensor:
"""
Computes Dice score metric for a batch of predictions. This performs the same computation as
Expand All @@ -192,6 +211,9 @@ def compute_dice(
num_classes: number of input channels (always including the background). When this is ``None``,
``y_pred.shape[1]`` will be used. This option is useful when both ``y_pred`` and ``y`` are
single-channel class indices and the number of classes is not automatically inferred from data.
per_component: whether to compute the Dice metric per connected component. If `True`, the metric will be
computed for each connected component in the ground truth, and then averaged. This requires binary
segmentations with 2 channels (background + foreground) as input. This is a more fine-grained computation.

Returns:
Dice scores per batch and per class, (shape: [batch_size, num_classes]).
Expand All @@ -204,6 +226,7 @@ def compute_dice(
apply_argmax=False,
ignore_empty=ignore_empty,
num_classes=num_classes,
per_component=per_component,
)(y_pred=y_pred, y=y)


Expand Down Expand Up @@ -246,6 +269,9 @@ class DiceHelper:
num_classes: number of input channels (always including the background). When this is ``None``,
``y_pred.shape[1]`` will be used. This option is useful when both ``y_pred`` and ``y`` are
single-channel class indices and the number of classes is not automatically inferred from data.
per_component: whether to compute the Dice metric per connected component. If `True`, the metric will be
computed for each connected component in the ground truth, and then averaged. This requires binary
segmentations with 2 channels (background + foreground) as input. This is a more fine-grained computation.
"""

@deprecated_arg("softmax", "1.5", "1.7", "Use `apply_argmax` instead.", new_name="apply_argmax")
Expand All @@ -262,6 +288,7 @@ def __init__(
num_classes: int | None = None,
sigmoid: bool | None = None,
softmax: bool | None = None,
per_component: bool = False,
) -> None:
# handling deprecated arguments
if sigmoid is not None:
Expand All @@ -277,6 +304,117 @@ def __init__(
self.activate = activate
self.ignore_empty = ignore_empty
self.num_classes = num_classes
self.per_component = per_component

def compute_voronoi_regions_fast(self, labels):
"""
Voronoi assignment to connected components (CPU, single EDT) without cc3d.
Returns the ID of the nearest component for each voxel.

Args:
labels (np.ndarray | torch.Tensor): Label map where values > 0 are seeds.

Raises:
RuntimeError: when `scipy.ndimage` is not available.
ValueError: when `labels` has fewer than two dimensions.

Returns:
torch.Tensor: Voronoi region IDs (int32) on CPU.
"""
if isinstance(labels, torch.Tensor) and labels.is_cuda and has_cupy and has_cupy_ndimage:
xp = cupy
nd_distance_transform_edt = cupy_ndimage.distance_transform_edt
nd_generate_binary_structure = cupy_ndimage.generate_binary_structure
nd_label = cupy_ndimage.label
x = cupy.asarray(labels.detach())
else:
xp = np
nd_distance_transform_edt = scipy_ndimage.distance_transform_edt
nd_generate_binary_structure = scipy_ndimage.generate_binary_structure
nd_label = scipy_ndimage.label

if not has_scipy_ndimage:
raise RuntimeError("scipy.ndimage is required for per_component Dice computation.")

if isinstance(labels, torch.Tensor):
warnings.warn(
"Voronoi computation is running on CPU. "
"To accelerate, move the input tensor to GPU and ensure 'cupy' with 'cupyx.scipy.ndimage' is installed."
)
x = labels.cpu().numpy()
else:
x = np.asarray(labels)
rank = x.ndim
if rank == 3:
conn_map = {6: 1, 18: 2, 26: 3}
connectivity = 26
elif rank == 2:
conn_map = {4: 1, 8: 2}
connectivity = 8
else:
raise ValueError("Only 2D or 3D inputs supported")
conn_rank = conn_map.get(connectivity, max(conn_map.values()))
structure = nd_generate_binary_structure(rank=rank, connectivity=conn_rank)
cc, num = nd_label(x > 0, structure=structure)
if num == 0:
return torch.zeros_like(torch.from_numpy(x), dtype=torch.int32)
edt_input = xp.ones(cc.shape, dtype=xp.uint8)
edt_input[cc > 0] = 0
indices = nd_distance_transform_edt(edt_input, sampling=None, return_distances=False, return_indices=True)
voronoi = cc[tuple(indices)]
if xp is cupy:
return torch.as_tensor(cupy.asnumpy(voronoi), dtype=torch.int32)
else:
return torch.as_tensor(voronoi, dtype=torch.int32)

def compute_cc_dice(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
"""
Compute per-component Dice for a single batch item.

Args:
y_pred (torch.Tensor): Predictions with shape (1, 2, D, H, W) or (1, 2, H, W).
y (torch.Tensor): Ground truth with shape (1, 2, D, H, W) or (1, 2, H, W).

Returns:
torch.Tensor: Mean Dice over connected components.
"""
data = []
if y_pred.ndim == y.ndim:
y_pred_idx = torch.argmax(y_pred, dim=1)
y_idx = torch.argmax(y, dim=1)
else:
y_pred_idx = y_pred
y_idx = y
if y_idx[0].sum() == 0:
if self.ignore_empty:
data.append(torch.tensor(float("nan"), device=y_idx.device))
elif y_pred_idx.sum() == 0:
data.append(torch.tensor(1.0, device=y_idx.device))
else:
data.append(torch.tensor(0.0, device=y_idx.device))
else:
cc_assignment = self.compute_voronoi_regions_fast(y_idx[0])
if cc_assignment.device != y_idx.device:
cc_assignment = cc_assignment.to(y_idx.device)
uniq, inv = torch.unique(cc_assignment.view(-1), return_inverse=True)
nof_components = uniq.numel()
code = (y_idx.view(-1) << 1) | y_pred_idx.view(-1)
idx = (inv << 2) | code
hist = torch.bincount(idx, minlength=nof_components * 4).reshape(-1, 4)
_, fp, fn, tp = hist[:, 0], hist[:, 1], hist[:, 2], hist[:, 3]
denom = 2 * tp + fp + fn
dice_scores = torch.where(
denom > 0, (2 * tp).float() / denom.float(), torch.tensor(1.0, device=denom.device)
)
data.append(dice_scores.unsqueeze(-1))
data = [
torch.where(torch.isinf(x), torch.tensor(0.0, dtype=torch.float32, device=x.device), x) for x in data
]
data = [
torch.where(torch.isnan(x), torch.tensor(0.0, dtype=torch.float32, device=x.device), x) for x in data
]
data = [x.reshape(-1, 1) for x in data]
return torch.stack([x.mean() for x in data])

def compute_channel(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
"""
Expand Down Expand Up @@ -305,6 +443,9 @@ def __call__(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor | tupl
y_pred: input predictions with shape (batch_size, num_classes or 1, spatial_dims...).
the number of channels is inferred from ``y_pred.shape[1]`` when ``num_classes is None``.
y: ground truth with shape (batch_size, num_classes or 1, spatial_dims...).

Raises:
ValueError: when the shapes of `y_pred` and `y` are not compatible for the per-component computation.
"""
_apply_argmax, _threshold = self.apply_argmax, self.threshold
if self.num_classes is None:
Expand All @@ -322,15 +463,33 @@ def __call__(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor | tupl
y_pred = torch.sigmoid(y_pred)
y_pred = y_pred > 0.5

first_ch = 0 if self.include_background else 1
if self.per_component:
if y_pred.ndim not in (4, 5) or y.ndim not in (4, 5) or y_pred.shape[1] != 2 or y.shape[1] != 2:
same_rank = y_pred.ndim == y.ndim and y_pred.ndim in (4, 5)
binary_channels = y_pred.shape[1] == 2 and y.shape[1] == 2
same_shape = y_pred.shape == y.shape
if not (same_rank and binary_channels and same_shape):
raise ValueError(
"per_component requires matching 4D/5D binary tensors "
"(B, 2, H, W) or (B, 2, D, H, W). "
f"Got y_pred={tuple(y_pred.shape)}, y={tuple(y.shape)}."
)

first_ch = 0 if self.include_background and not self.per_component else 1
data = []
for b in range(y_pred.shape[0]):
if self.per_component:
data.append(self.compute_cc_dice(y_pred=y_pred[b].unsqueeze(0), y=y[b].unsqueeze(0)).reshape(-1))
continue
c_list = []
for c in range(first_ch, n_pred_ch) if n_pred_ch > 1 else [1]:
x_pred = (y_pred[b, 0] == c) if (y_pred.shape[1] == 1) else y_pred[b, c].bool()
x = (y[b, 0] == c) if (y.shape[1] == 1) else y[b, c]
c_list.append(self.compute_channel(x_pred, x))
# if self.per_component:
# c_list = [self.compute_cc_dice(y_pred=y_pred[b].unsqueeze(0), y=y[b].unsqueeze(0))]
data.append(torch.stack(c_list))

data = torch.stack(data, dim=0).contiguous() # type: ignore

f, not_nans = do_metric_reduction(data, self.reduction) # type: ignore
Expand Down
64 changes: 64 additions & 0 deletions tests/metrics/test_compute_meandice.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@
from parameterized import parameterized

from monai.metrics import DiceHelper, DiceMetric, compute_dice
from monai.utils.module import optional_import

_, has_ndimage = optional_import("scipy.ndimage")
_, has_cupy_ndimage = optional_import("cupyx.scipy.ndimage")

_device = "cuda:0" if torch.cuda.is_available() else "cpu"
# keep background
Expand Down Expand Up @@ -250,6 +254,42 @@
{"label_1": 0.4000, "label_2": 0.6667},
]

# Testcase for per_component DiceMetric - 3D input
y = torch.zeros((5, 2, 64, 64, 64), device=_device)
y_hat = torch.zeros((5, 2, 64, 64, 64), device=_device)

y[0, 1, 20:25, 20:25, 20:25] = 1
y[0, 1, 40:45, 40:45, 40:45] = 1
y[0, 0] = 1 - y[0, 1]

y_hat[0, 1, 21:26, 21:26, 21:26] = 1
y_hat[0, 1, 41:46, 39:44, 41:46] = 1
y_hat[0, 0] = 1 - y_hat[0, 1]

TEST_CASE_16 = [
{"per_component": True, "ignore_empty": False},
{"y": y, "y_pred": y_hat},
[[0.5120], [1.0], [1.0], [1.0], [1.0]],
]

# Testcase for per_component DiceMetric - 2D input
y = torch.zeros((5, 2, 64, 64), device=_device)
y_hat = torch.zeros((5, 2, 64, 64), device=_device)

y[0, 1, 20:25, 20:25] = 1
y[0, 1, 40:45, 40:45] = 1
y[0, 0] = 1 - y[0, 1]

y_hat[0, 1, 21:26, 21:26] = 1
y_hat[0, 1, 41:46, 39:44] = 1
y_hat[0, 0] = 1 - y_hat[0, 1]

TEST_CASE_17 = [
{"per_component": True, "ignore_empty": False},
{"y": y, "y_pred": y_hat},
[[0.6400], [1.0], [1.0], [1.0], [1.0]],
]


class TestComputeMeanDice(unittest.TestCase):

Expand Down Expand Up @@ -301,6 +341,30 @@ def test_nans_class(self, params, input_data, expected_value):
else:
np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4)

# CC DiceMetric tests
@parameterized.expand([TEST_CASE_16, TEST_CASE_17])
@unittest.skipUnless(has_ndimage, "Requires scipy.ndimage.")
def test_cc_dice_value_nogpu(self, params, input_data, expected_value):
dice_metric = DiceMetric(**params)
cpu_inputs = {"y": input_data["y"].cpu(), "y_pred": input_data["y_pred"].cpu()}
dice_metric(**cpu_inputs)
result = dice_metric.aggregate(reduction="none")
np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4)

@parameterized.expand([TEST_CASE_16, TEST_CASE_17])
@unittest.skipUnless(has_ndimage, "Requires scipy.ndimage.")
@unittest.skipUnless(torch.cuda.is_available() and has_cupy_ndimage, "Requires CUDA and cupyx.scipy.ndimage.")
def test_cc_dice_value_gpu(self, params, input_data, expected_value):
dice_metric = DiceMetric(**params)
dice_metric(**input_data)
result = dice_metric.aggregate(reduction="none")
np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4)

@unittest.skipUnless(has_ndimage, "Requires scipy.ndimage.")
def test_channel_dimensions(self):
with self.assertRaises(ValueError):
DiceMetric(per_component=True)(torch.ones([3, 3, 144, 144]), torch.ones([3, 3, 144, 144]))


if __name__ == "__main__":
unittest.main()
Loading