diff --git a/monai/losses/unified_focal_loss.py b/monai/losses/unified_focal_loss.py index 8484eb67ed..b64c6c3628 100644 --- a/monai/losses/unified_focal_loss.py +++ b/monai/losses/unified_focal_loss.py @@ -22,9 +22,9 @@ class AsymmetricFocalTverskyLoss(_Loss): """ - AsymmetricFocalTverskyLoss is a variant of FocalTverskyLoss, which attentions to the foreground class. + AsymmetricFocalTverskyLoss is a variant of FocalTverskyLoss that focuses on foreground classes. - Actually, it's only supported for binary image segmentation now. + Supports multi-class segmentation with optional background inclusion. Reimplementation of the Asymmetric Focal Tversky Loss described in: @@ -39,6 +39,7 @@ def __init__( gamma: float = 0.75, epsilon: float = 1e-7, reduction: LossReduction | str = LossReduction.MEAN, + include_background: bool = True, ) -> None: """ Args: @@ -46,19 +47,21 @@ def __init__( delta : weight of the background. Defaults to 0.7. gamma : value of the exponent gamma in the definition of the Focal loss . Defaults to 0.75. epsilon : it defines a very small number each time. simmily smooth value. Defaults to 1e-7. + include_background: whether to include background class in loss calculation. Defaults to True. """ super().__init__(reduction=LossReduction(reduction).value) self.to_onehot_y = to_onehot_y self.delta = delta self.gamma = gamma self.epsilon = epsilon + self.include_background: bool = include_background def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: n_pred_ch = y_pred.shape[1] if self.to_onehot_y: if n_pred_ch == 1: - warnings.warn("single channel prediction, `to_onehot_y=True` ignored.") + warnings.warn("single channel prediction, `to_onehot_y=True` ignored.", stacklevel=2) else: y_true = one_hot(y_true, num_classes=n_pred_ch) @@ -74,21 +77,27 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: fn = torch.sum(y_true * (1 - y_pred), dim=axis) fp = torch.sum((1 - y_true) * y_pred, dim=axis) dice_class = (tp + self.epsilon) / (tp + self.delta * fn + (1 - self.delta) * fp + self.epsilon) + dice_class = torch.clamp(dice_class, self.epsilon, 1.0 - self.epsilon) # Calculate losses separately for each class, enhancing both classes - back_dice = 1 - dice_class[:, 0] - fore_dice = (1 - dice_class[:, 1]) * torch.pow(1 - dice_class[:, 1], -self.gamma) + back_dice = 1 - dice_class[:, 0:1] + fore_dice = (1 - dice_class[:, 1:]) * torch.pow(1 - dice_class[:, 1:], -self.gamma) + + if not self.include_background: + back_dice = back_dice * 0.0 + + all_dice = torch.cat([back_dice, fore_dice], dim=1) # Average class scores - loss = torch.mean(torch.stack([back_dice, fore_dice], dim=-1)) + loss = torch.mean(all_dice) return loss class AsymmetricFocalLoss(_Loss): """ - AsymmetricFocalLoss is a variant of FocalTverskyLoss, which attentions to the foreground class. + AsymmetricFocalLoss is a variant of Focal Loss that focuses on foreground classes. - Actually, it's only supported for binary image segmentation now. + Supports multi-class segmentation with optional background inclusion. Reimplementation of the Asymmetric Focal Loss described in: @@ -103,26 +112,29 @@ def __init__( gamma: float = 2, epsilon: float = 1e-7, reduction: LossReduction | str = LossReduction.MEAN, + include_background: bool = True, ): """ Args: to_onehot_y : whether to convert `y` into the one-hot format. Defaults to False. delta : weight of the background. Defaults to 0.7. - gamma : value of the exponent gamma in the definition of the Focal loss . Defaults to 0.75. + gamma : value of the exponent gamma in the definition of the Focal loss . Defaults to 2. epsilon : it defines a very small number each time. simmily smooth value. Defaults to 1e-7. + include_background: whether to include background class in loss calculation. Defaults to True. """ super().__init__(reduction=LossReduction(reduction).value) self.to_onehot_y = to_onehot_y self.delta = delta self.gamma = gamma self.epsilon = epsilon + self.include_background: bool = include_background def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: n_pred_ch = y_pred.shape[1] if self.to_onehot_y: if n_pred_ch == 1: - warnings.warn("single channel prediction, `to_onehot_y=True` ignored.") + warnings.warn("single channel prediction, `to_onehot_y=True` ignored.", stacklevel=2) else: y_true = one_hot(y_true, num_classes=n_pred_ch) @@ -132,21 +144,26 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: y_pred = torch.clamp(y_pred, self.epsilon, 1.0 - self.epsilon) cross_entropy = -y_true * torch.log(y_pred) - back_ce = torch.pow(1 - y_pred[:, 0], self.gamma) * cross_entropy[:, 0] + back_ce = torch.pow(1 - y_pred[:, 0:1], self.gamma) * cross_entropy[:, 0:1] back_ce = (1 - self.delta) * back_ce - fore_ce = cross_entropy[:, 1] + fore_ce = cross_entropy[:, 1:] fore_ce = self.delta * fore_ce - loss = torch.mean(torch.sum(torch.stack([back_ce, fore_ce], dim=1), dim=1)) + if not self.include_background: + back_ce = back_ce * 0.0 + + all_ce = torch.cat([back_ce, fore_ce], dim=1) + + loss = torch.mean(all_ce) return loss class AsymmetricUnifiedFocalLoss(_Loss): """ - AsymmetricUnifiedFocalLoss is a variant of Focal Loss. + AsymmetricUnifiedFocalLoss combines Asymmetric Focal Loss and Asymmetric Focal Tversky Loss. - Actually, it's only supported for binary image segmentation now + Supports multi-class segmentation with configurable activation (sigmoid/softmax) and optional background inclusion. Reimplementation of the Asymmetric Unified Focal Tversky Loss described in: @@ -162,15 +179,20 @@ def __init__( gamma: float = 0.5, delta: float = 0.7, reduction: LossReduction | str = LossReduction.MEAN, + include_background: bool = True, + use_softmax: bool = False, ): """ Args: to_onehot_y : whether to convert `y` into the one-hot format. Defaults to False. - num_classes : number of classes, it only supports 2 now. Defaults to 2. + num_classes : number of classes. Defaults to 2. + weight : weight for combining focal loss and focal tversky loss. Defaults to 0.5. + gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 0.5. delta : weight of the background. Defaults to 0.7. - gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 0.75. - epsilon : it defines a very small number each time. simmily smooth value. Defaults to 1e-7. - weight : weight for each loss function, if it's none it's 0.5. Defaults to None. + reduction : reduction mode for the loss. Defaults to LossReduction.MEAN. + include_background : whether to include the background class in loss calculation. Defaults to True. + use_softmax: whether to use softmax to transform the original logits into probabilities. + If True, softmax is used. If False, sigmoid is used. Defaults to False. Example: >>> import torch @@ -179,6 +201,11 @@ def __init__( >>> grnd = torch.ones((1,1,32,32), dtype=torch.int64) >>> fl = AsymmetricUnifiedFocalLoss(to_onehot_y=True) >>> fl(pred, grnd) + >>> # Multiclass example with 3 classes + >>> pred_mc = torch.randn((1,3,32,32), dtype=torch.float32) + >>> grnd_mc = torch.randint(0, 3, (1,1,32,32), dtype=torch.int64) + >>> fl_mc = AsymmetricUnifiedFocalLoss(to_onehot_y=True, num_classes=3, use_softmax=True) + >>> fl_mc(pred_mc, grnd_mc) """ super().__init__(reduction=LossReduction(reduction).value) self.to_onehot_y = to_onehot_y @@ -186,19 +213,30 @@ def __init__( self.gamma = gamma self.delta = delta self.weight: float = weight - self.asy_focal_loss = AsymmetricFocalLoss(gamma=self.gamma, delta=self.delta) - self.asy_focal_tversky_loss = AsymmetricFocalTverskyLoss(gamma=self.gamma, delta=self.delta) + self.include_background: bool = include_background + self.use_softmax = use_softmax + self.asy_focal_loss = AsymmetricFocalLoss( + to_onehot_y=self.to_onehot_y, + gamma=self.gamma, + delta=self.delta, + include_background=self.include_background, + reduction=LossReduction.NONE, + ) + self.asy_focal_tversky_loss = AsymmetricFocalTverskyLoss( + to_onehot_y=self.to_onehot_y, + gamma=self.gamma, + delta=self.delta, + include_background=self.include_background, + reduction=LossReduction.NONE, + ) - # TODO: Implement this function to support multiple classes segmentation def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: """ Args: y_pred : the shape should be BNH[WD], where N is the number of classes. - It only supports binary segmentation. The input should be the original logits since it will be transformed by - a sigmoid in the forward function. - y_true : the shape should be BNH[WD], where N is the number of classes. - It only supports binary segmentation. + a sigmoid or softmax in the forward function. + y_true : the shape should be BNH[WD] or B1H[WD], where N is the number of classes. Raises: ValueError: When input and target are different shape @@ -212,20 +250,39 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: if len(y_pred.shape) != 4 and len(y_pred.shape) != 5: raise ValueError(f"input shape must be 4 or 5, but got {y_pred.shape}") - if y_pred.shape[1] == 1: - y_pred = one_hot(y_pred, num_classes=self.num_classes) - y_true = one_hot(y_true, num_classes=self.num_classes) - - if torch.max(y_true) != self.num_classes - 1: - raise ValueError(f"Please make sure the number of classes is {self.num_classes-1}") + if y_true.shape[1] == self.num_classes: + if not torch.all((y_true == 0) | (y_true == 1)): + raise ValueError("y_true appears to be one-hot but contains values other than 0 and 1") + elif y_true.shape[1] == 1: + if torch.max(y_true) >= self.num_classes: + raise ValueError( + f"y_true labels must be in [0, {self.num_classes - 1}], but got max {torch.max(y_true)}" + ) + else: + raise ValueError( + f"y_true must have {self.num_classes} channels (one-hot) or 1 channel (labels), got {y_true.shape[1]}" + ) n_pred_ch = y_pred.shape[1] if self.to_onehot_y: if n_pred_ch == 1: - warnings.warn("single channel prediction, `to_onehot_y=True` ignored.") + warnings.warn("single channel prediction, `to_onehot_y=True` ignored.", stacklevel=2) else: y_true = one_hot(y_true, num_classes=n_pred_ch) + if y_pred.shape[1] == 1: + warnings.warn("single channel prediction, augmenting with background channel.", stacklevel=2) + y_pred_sigmoid = torch.sigmoid(y_pred.float()) + y_pred = torch.cat([1 - y_pred_sigmoid, y_pred_sigmoid], dim=1) + + if y_true.shape[1] == 1: + y_true = one_hot(y_true, num_classes=self.num_classes) + else: + if self.use_softmax: + y_pred = torch.softmax(y_pred.float(), dim=1) + else: + y_pred = y_pred.float() + asy_focal_loss = self.asy_focal_loss(y_pred, y_true) asy_focal_tversky_loss = self.asy_focal_tversky_loss(y_pred, y_true)