96 changes: 75 additions & 21 deletions monai/losses/unified_focal_loss.py
@@ -39,19 +39,22 @@ def __init__(
         gamma: float = 0.75,
         epsilon: float = 1e-7,
         reduction: LossReduction | str = LossReduction.MEAN,
+        include_background: bool = True,
     ) -> None:
         """
         Args:
             to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
             delta : weight of the background. Defaults to 0.7.
             gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 0.75.
             epsilon : a small constant added for numerical stability, acting like a smoothing term. Defaults to 1e-7.
+            include_background: whether to include the background class in the loss calculation. Defaults to True.
         """
         super().__init__(reduction=LossReduction(reduction).value)
         self.to_onehot_y = to_onehot_y
         self.delta = delta
         self.gamma = gamma
         self.epsilon = epsilon
+        self.include_background: bool = include_background

     def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
         n_pred_ch = y_pred.shape[1]
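Below, a minimal usage sketch of the new `include_background` flag, assuming this branch of MONAI is installed; the shapes and values are illustrative only:

```python
# Illustrative only -- assumes this PR's branch of MONAI is installed.
import torch
from monai.losses.unified_focal_loss import AsymmetricFocalTverskyLoss

pred = torch.softmax(torch.randn(2, 2, 32, 32), dim=1)  # B=2, N=2 class probabilities
grnd = torch.randint(0, 2, (2, 1, 32, 32))              # integer label map

loss_fn = AsymmetricFocalTverskyLoss(to_onehot_y=True, include_background=False)
loss = loss_fn(pred, grnd)  # the background term is zeroed before averaging
```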
@@ -74,13 +77,19 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
         fn = torch.sum(y_true * (1 - y_pred), dim=axis)
         fp = torch.sum((1 - y_true) * y_pred, dim=axis)
         dice_class = (tp + self.epsilon) / (tp + self.delta * fn + (1 - self.delta) * fp + self.epsilon)
+        dice_class = torch.clamp(dice_class, self.epsilon, 1.0 - self.epsilon)

         # Calculate losses separately for each class, enhancing both classes
-        back_dice = 1 - dice_class[:, 0]
-        fore_dice = (1 - dice_class[:, 1]) * torch.pow(1 - dice_class[:, 1], -self.gamma)
+        back_dice = 1 - dice_class[:, 0:1]
+        fore_dice = (1 - dice_class[:, 1:]) * torch.pow(1 - dice_class[:, 1:], -self.gamma)
+
+        if not self.include_background:
+            back_dice = back_dice * 0.0
+
+        all_dice = torch.cat([back_dice, fore_dice], dim=1)

         # Average class scores
-        loss = torch.mean(torch.stack([back_dice, fore_dice], dim=-1))
+        loss = torch.mean(all_dice)
         return loss
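The clamp added above guards the foreground term, which algebraically reduces to (1 - dice)^(1 - gamma): at dice == 1 the `torch.pow` factor is infinite and the product becomes NaN. A small standalone check, assuming only stock PyTorch:

```python
import torch

gamma, eps = 0.75, 1e-7
d = torch.tensor([1.0])                    # perfect overlap for a class
print((1 - d) * torch.pow(1 - d, -gamma))  # tensor([nan]): 0 * inf
d = torch.clamp(d, eps, 1.0 - eps)         # clamped as in the patch
print((1 - d) * torch.pow(1 - d, -gamma))  # finite, roughly eps ** (1 - gamma)
```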


@@ -103,19 +112,22 @@ def __init__(
         gamma: float = 2,
         epsilon: float = 1e-7,
         reduction: LossReduction | str = LossReduction.MEAN,
+        include_background: bool = True,
     ):
         """
         Args:
             to_onehot_y : whether to convert `y` into the one-hot format. Defaults to False.
             delta : weight of the background. Defaults to 0.7.
-            gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 0.75.
+            gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 2.
             epsilon : a small constant added for numerical stability, acting like a smoothing term. Defaults to 1e-7.
+            include_background: whether to include the background class in the loss calculation. Defaults to True.
         """
         super().__init__(reduction=LossReduction(reduction).value)
         self.to_onehot_y = to_onehot_y
         self.delta = delta
         self.gamma = gamma
         self.epsilon = epsilon
+        self.include_background: bool = include_background

     def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
         n_pred_ch = y_pred.shape[1]
@@ -132,13 +144,18 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
         y_pred = torch.clamp(y_pred, self.epsilon, 1.0 - self.epsilon)
         cross_entropy = -y_true * torch.log(y_pred)

-        back_ce = torch.pow(1 - y_pred[:, 0], self.gamma) * cross_entropy[:, 0]
+        back_ce = torch.pow(1 - y_pred[:, 0:1], self.gamma) * cross_entropy[:, 0:1]
         back_ce = (1 - self.delta) * back_ce

-        fore_ce = cross_entropy[:, 1]
+        fore_ce = cross_entropy[:, 1:]
         fore_ce = self.delta * fore_ce

-        loss = torch.mean(torch.sum(torch.stack([back_ce, fore_ce], dim=1), dim=1))
+        if not self.include_background:
+            back_ce = back_ce * 0.0
+
+        all_ce = torch.cat([back_ce, fore_ce], dim=1)
+
+        loss = torch.mean(all_ce)
         return loss
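For intuition, the asymmetric focal term damps background cross-entropy by (1 - p)^gamma and scales it by (1 - delta), while foreground cross-entropy stays unmodulated and scaled by delta. A quick numeric sketch with assumed values (not taken from the PR):

```python
import torch

delta, gamma = 0.7, 2.0                 # defaults used by AsymmetricFocalLoss
p = torch.tensor(0.8)                   # predicted probability of the true class
ce = -torch.log(p)                      # plain cross-entropy, ~0.223

back = (1 - delta) * (1 - p) ** gamma * ce  # ~0.0027: damped and down-weighted
fore = delta * ce                           # ~0.156: full CE, up-weighted
print(fore / back)                          # ~58x harsher on foreground errors
```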


@@ -162,15 +179,20 @@ def __init__(
         gamma: float = 0.5,
         delta: float = 0.7,
         reduction: LossReduction | str = LossReduction.MEAN,
+        include_background: bool = True,
+        use_softmax: bool = False,
     ):
         """
         Args:
             to_onehot_y : whether to convert `y` into the one-hot format. Defaults to False.
-            num_classes : number of classes, it only supports 2 now. Defaults to 2.
+            num_classes : number of classes. Defaults to 2.
+            weight : weight for combining focal loss and focal tversky loss. Defaults to 0.5.
+            gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 0.5.
             delta : weight of the background. Defaults to 0.7.
-            gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 0.75.
-            epsilon : it defines a very small number each time. simmily smooth value. Defaults to 1e-7.
-            weight : weight for each loss function, if it's none it's 0.5. Defaults to None.
+            reduction : reduction mode for the loss. Defaults to LossReduction.MEAN.
+            include_background : whether to include the background class in the loss calculation. Defaults to True.
+            use_softmax: whether to use softmax to transform the original logits into probabilities.
+                If True, softmax is used. If False, sigmoid is used. Defaults to False.

         Example:
             >>> import torch
@@ -186,8 +208,22 @@ def __init__(
         self.gamma = gamma
         self.delta = delta
         self.weight: float = weight
-        self.asy_focal_loss = AsymmetricFocalLoss(gamma=self.gamma, delta=self.delta)
-        self.asy_focal_tversky_loss = AsymmetricFocalTverskyLoss(gamma=self.gamma, delta=self.delta)
+        self.include_background: bool = include_background
+        self.use_softmax = use_softmax
+        self.asy_focal_loss = AsymmetricFocalLoss(
+            to_onehot_y=self.to_onehot_y,
+            gamma=self.gamma,
+            delta=self.delta,
+            include_background=self.include_background,
+            reduction=LossReduction.NONE,
+        )
+        self.asy_focal_tversky_loss = AsymmetricFocalTverskyLoss(
+            to_onehot_y=self.to_onehot_y,
+            gamma=self.gamma,
+            delta=self.delta,
+            include_background=self.include_background,
+            reduction=LossReduction.NONE,
+        )

     # TODO: Implement this function to support multiple classes segmentation
     def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
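With the constructor now forwarding `to_onehot_y`, `include_background`, and `reduction` to both sub-losses, a hedged end-to-end sketch of the softmax path might look like this (assumes this branch; one-hot targets are passed directly, so the sub-losses skip re-encoding):

```python
import torch
from monai.losses.unified_focal_loss import AsymmetricUnifiedFocalLoss

loss_fn = AsymmetricUnifiedFocalLoss(num_classes=2, use_softmax=True)
logits = torch.randn(4, 2, 64, 64)       # raw two-channel logits, B=4
fg = torch.randint(0, 2, (4, 1, 64, 64)).float()
y_true = torch.cat([1 - fg, fg], dim=1)  # one-hot targets, values in {0, 1}
loss = loss_fn(logits, y_true)           # softmax applied inside forward
```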
@@ -196,8 +232,8 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
             y_pred : the shape should be BNH[WD], where N is the number of classes.
                 It only supports binary segmentation.
                 The input should be the original logits since it will be transformed by
-                a sigmoid in the forward function.
-            y_true : the shape should be BNH[WD], where N is the number of classes.
+                a sigmoid or softmax in the forward function.
+            y_true : the shape should be BNH[WD] or B1H[WD], where N is the number of classes.
                 It only supports binary segmentation.

         Raises:
@@ -212,12 +248,18 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
         if len(y_pred.shape) != 4 and len(y_pred.shape) != 5:
             raise ValueError(f"input shape must be 4 or 5, but got {y_pred.shape}")

-        if y_pred.shape[1] == 1:
-            y_pred = one_hot(y_pred, num_classes=self.num_classes)
-            y_true = one_hot(y_true, num_classes=self.num_classes)
-
-        if torch.max(y_true) != self.num_classes - 1:
-            raise ValueError(f"Please make sure the number of classes is {self.num_classes-1}")
+        if y_true.shape[1] == self.num_classes:
+            if not torch.all((y_true == 0) | (y_true == 1)):
+                raise ValueError("y_true appears to be one-hot but contains values other than 0 and 1")
+        elif y_true.shape[1] == 1:
+            if torch.max(y_true) >= self.num_classes:
+                raise ValueError(
+                    f"y_true labels must be in [0, {self.num_classes - 1}], but got max {torch.max(y_true)}"
+                )
+        else:
+            raise ValueError(
+                f"y_true must have {self.num_classes} channels (one-hot) or 1 channel (labels), got {y_true.shape[1]}"
+            )

         n_pred_ch = y_pred.shape[1]
         if self.to_onehot_y:
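The replacement validation distinguishes one-hot targets from integer label maps instead of rejecting any `y_true` whose maximum is not `num_classes - 1`. Sketched for `num_classes=2` (illustrative, mirroring the checks above):

```python
import torch

onehot = torch.zeros(2, 2, 8, 8)
onehot[:, 0] = 1.0                           # [B, N, ...] with only 0s and 1s -> accepted as one-hot
labels = torch.randint(0, 2, (2, 1, 8, 8))   # [B, 1, ...] integer labels < num_classes -> accepted
bad = torch.full((2, 1, 8, 8), 5)            # label 5 >= num_classes -> ValueError
three_ch = torch.zeros(2, 3, 8, 8)           # neither N channels nor 1 channel -> ValueError
```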
@@ -226,6 +268,18 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
             else:
                 y_true = one_hot(y_true, num_classes=n_pred_ch)

+        if y_pred.shape[1] == 1:
+            y_pred_sigmoid = torch.sigmoid(y_pred.float())
+            y_pred = torch.cat([1 - y_pred_sigmoid, y_pred_sigmoid], dim=1)
+
+            if y_true.shape[1] == 1:
+                y_true = one_hot(y_true, num_classes=self.num_classes)
+        else:
+            if self.use_softmax:
+                y_pred = torch.softmax(y_pred.float(), dim=1)
+            else:
+                y_pred = torch.sigmoid(y_pred.float())
+
         asy_focal_loss = self.asy_focal_loss(y_pred, y_true)
         asy_focal_tversky_loss = self.asy_focal_tversky_loss(y_pred, y_true)
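And a sketch of the new single-channel path, where the forward pass sigmoid-activates the logits and expands them to an explicit [background, foreground] pair (again assuming this branch):

```python
import torch
from monai.losses.unified_focal_loss import AsymmetricUnifiedFocalLoss

loss_fn = AsymmetricUnifiedFocalLoss(num_classes=2)  # use_softmax=False -> sigmoid path
logits = torch.randn(2, 1, 32, 32)                   # raw single-channel logits
labels = torch.randint(0, 2, (2, 1, 32, 32))         # integer label map
loss = loss_fn(logits, labels)                       # y_true is one-hot encoded internally
```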
