diff --git a/mmseg/models/losses/dice_loss.py b/mmseg/models/losses/dice_loss.py
index 2b7c9d4cf1..65eae8aebc 100644
--- a/mmseg/models/losses/dice_loss.py
+++ b/mmseg/models/losses/dice_loss.py
@@ -1,4 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from typing import Union
+
 import torch
 import torch.nn as nn
 
@@ -6,15 +8,35 @@
 from .utils import weight_reduce_loss
 
 
-def dice_loss(
-    pred,
-    target,
-    weight,
-    eps=1e-3,
-    reduction='mean',
-    naive_dice=False,
-    avg_factor=None,
-):
+def _expand_onehot_labels_dice(pred: torch.Tensor,
+                               target: torch.Tensor) -> torch.Tensor:
+    """Expand onehot labels to match the size of prediction.
+
+    Args:
+        pred (torch.Tensor): The prediction, has a shape (N, num_class, H, W).
+        target (torch.Tensor): The learning label of the prediction,
+            has a shape (N, H, W).
+
+    Returns:
+        torch.Tensor: The target after one-hot encoding,
+            has a shape (N, num_class, H, W).
+    """
+    num_classes = pred.shape[1]
+    one_hot_target = torch.clamp(target, min=0, max=num_classes)
+    one_hot_target = torch.nn.functional.one_hot(one_hot_target,
+                                                 num_classes + 1)
+    one_hot_target = one_hot_target[..., :num_classes].permute(0, 3, 1, 2)
+    return one_hot_target
+
+
+def dice_loss(pred: torch.Tensor,
+              target: torch.Tensor,
+              weight: Union[torch.Tensor, None],
+              eps: float = 1e-3,
+              reduction: Union[str, None] = 'mean',
+              naive_dice: Union[bool, None] = False,
+              avg_factor: Union[int, None] = None,
+              ignore_index: Union[int, None] = 255) -> float:
     """Calculate dice loss, there are two forms of dice loss is supported:
 
         - the one proposed in `V-Net: Fully Convolutional Neural
@@ -41,11 +63,15 @@ def dice_loss(
             power.Defaults to False.
         avg_factor (int, optional): Average factor that is used to
            average the loss. Defaults to None.
+        ignore_index (int, optional): The label index to be ignored.
+            Defaults to 255.
     """
-
+    num_classes = pred.shape[1]
+    pred = pred[:, torch.arange(num_classes) != ignore_index, :, :]
+    target = target[:, torch.arange(num_classes) != ignore_index, :, :]
+    assert pred.shape[1] != 0  # if the ignored index is the only class
     input = pred.flatten(1)
     target = target.flatten(1).float()
-
     a = torch.sum(input * target, 1)
     if naive_dice:
         b = torch.sum(input, 1)
@@ -93,7 +119,7 @@ def __init__(self,
                 denominator is the first power instead of the second
                 power. Defaults to False.
             loss_weight (float, optional): Weight of loss. Defaults to 1.0.
-            ignore_index (int | None): The label index to be ignored.
+            ignore_index (int, optional): The label index to be ignored.
                 Default: 255.
             eps (float): Avoid dividing by zero. Defaults to 1e-3.
             loss_name (str, optional): Name of the loss item. If you want this
@@ -116,7 +142,9 @@ def forward(self,
                 target,
                 weight=None,
                 avg_factor=None,
-                reduction_override=None):
+                reduction_override=None,
+                ignore_index=255,
+                **kwargs):
         """Forward function.
 
         Args:
@@ -134,26 +162,27 @@
 
         Returns:
             torch.Tensor: The calculated loss
         """
-
+        one_hot_target = target
+        if (pred.shape != target.shape):
+            one_hot_target = _expand_onehot_labels_dice(pred, target)
         assert reduction_override in (None, 'none', 'mean', 'sum')
         reduction = (
             reduction_override if reduction_override else self.reduction)
-
         if self.activate:
             if self.use_sigmoid:
                 pred = pred.sigmoid()
-            else:
-                raise NotImplementedError
-
+            elif pred.shape[1] != 1:
+                # softmax does not work when there is only 1 class
+                pred = pred.softmax(dim=1)
         loss = self.loss_weight * dice_loss(
             pred,
-            target,
+            one_hot_target,
             weight,
             eps=self.eps,
             reduction=reduction,
             naive_dice=self.naive_dice,
             avg_factor=avg_factor,
-        )
+            ignore_index=self.ignore_index)
         return loss
 
diff --git a/tests/test_models/test_losses/test_dice_loss.py b/tests/test_models/test_losses/test_dice_loss.py
index 4095a6ad2d..34253dae12 100644
--- a/tests/test_models/test_losses/test_dice_loss.py
+++ b/tests/test_models/test_losses/test_dice_loss.py
@@ -8,10 +8,9 @@
 @pytest.mark.parametrize('naive_dice', [True, False])
 def test_dice_loss(naive_dice):
     loss_class = DiceLoss
-    pred = torch.rand((10, 4, 4))
-    target = torch.rand((10, 4, 4))
-    weight = torch.rand(10)
-
+    pred = torch.rand((1, 10, 4, 4))
+    target = torch.randint(0, 10, (1, 4, 4))
+    weight = torch.rand(1)
     # Test loss forward
     loss = loss_class(naive_dice=naive_dice)(pred, target)
     assert isinstance(loss, torch.Tensor)
@@ -43,10 +42,11 @@ def test_dice_loss(naive_dice):
     assert isinstance(loss, torch.Tensor)
 
     # Test loss forward with has_acted=False and use_sigmoid=False
-    with pytest.raises(NotImplementedError):
+    for use_sigmoid in [True, False]:
         loss_class(
-            use_sigmoid=False, activate=True, naive_dice=naive_dice)(pred,
-                                                                     target)
+            use_sigmoid=use_sigmoid, activate=True,
+            naive_dice=naive_dice)(pred, target)
+        assert isinstance(loss, torch.Tensor)
 
     # Test loss forward with weight.ndim != loss.ndim
     with pytest.raises(AssertionError):
@@ -57,3 +57,40 @@ def test_dice_loss(naive_dice):
     with pytest.raises(AssertionError):
         weight = torch.rand(8)
         loss_class(naive_dice=naive_dice)(pred, target, weight)
+
+    # Test _expand_onehot_labels_dice
+    pred = torch.tensor([[[[1, 1], [1, 0]], [[0, 1], [1, 1]]]]).float()
+    target = torch.tensor([[[0, 0], [0, 1]]])
+    target_onehot = torch.tensor([[[[1, 1], [1, 0]], [[0, 0], [0, 1]]]])
+    weight = torch.rand(1)
+    loss = loss_class(naive_dice=naive_dice)(pred, target, weight)
+    loss_onehot = loss_class(naive_dice=naive_dice)(pred, target_onehot,
+                                                    weight)
+    assert torch.equal(loss, loss_onehot)
+
+    # Test whether loss is 0 when pred == target, eps == 0, naive_dice=False
+    target = torch.randint(0, 2, (1, 10, 4, 4))
+    pred = target.float()
+    target = target.sigmoid()
+    weight = torch.rand(1)
+    loss = loss_class(
+        naive_dice=False, use_sigmoid=True, eps=0)(pred, target, weight)
+    assert loss.item() == 0
+
+    # Test ignore_index when ignore_index is the only class
+    with pytest.raises(AssertionError):
+        pred = torch.ones((1, 1, 4, 4))
+        target = torch.randint(0, 1, (1, 4, 4))
+        weight = torch.rand(1)
+        loss = loss_class(
+            naive_dice=naive_dice, use_sigmoid=False, ignore_index=0,
+            eps=0)(pred, target, weight)
+
+    # Test ignore_index with naive_dice = False
+    pred = torch.tensor([[[[1, 1], [1, 0]], [[0, 1], [1, 1]]]]).float()
+    target = torch.tensor([[[[1, 1], [1, 0]], [[1, 0], [0, 1]]]]).sigmoid()
+    weight = torch.rand(1)
+    loss = loss_class(
+        naive_dice=False, use_sigmoid=True, ignore_index=1,
+        eps=0)(pred, target, weight)
+    assert loss.item() == 0
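
For context, here is a small standalone sketch (not part of the patch) of the two mechanisms the diff relies on: `_expand_onehot_labels_dice` folds out-of-range labels into an overflow bin before one-hot encoding, and `dice_loss` drops the class channel matching `ignore_index`. The tensors and the `num_classes`/`ignore_index` values below are made up for illustration.

```python
import torch
import torch.nn.functional as F

num_classes, ignore_index = 4, 255  # illustrative values, not from the diff

# (N, H, W) integer labels; one pixel carries the ignore label 255.
target = torch.tensor([[[0, 1], [3, 255]]])

# Clamping folds every out-of-range label onto the extra index
# `num_classes`, so one_hot over `num_classes + 1` bins cannot fail.
clamped = torch.clamp(target, min=0, max=num_classes)
one_hot = F.one_hot(clamped, num_classes + 1)  # (N, H, W, num_classes + 1)

# Dropping the overflow bin leaves ignored pixels as all-zero vectors.
one_hot = one_hot[..., :num_classes].permute(0, 3, 1, 2)  # (N, C, H, W)
assert one_hot[0, :, 1, 1].sum() == 0  # the 255-labelled pixel is zeroed out

# Channel-level ignoring, as at the top of the new dice_loss(): keep every
# class channel except `ignore_index`. With the default of 255 and a small
# num_classes, the boolean mask keeps all channels.
pred = torch.rand(1, num_classes, 2, 2)
keep = torch.arange(num_classes) != ignore_index
assert pred[:, keep, :, :].shape[1] == num_classes
```

The clamp-then-slice trick avoids indexing errors for labels outside `[0, num_classes)` and makes ignored pixels contribute zero to both the intersection and the union terms, while the channel mask in `dice_loss` only takes effect when `ignore_index` actually names one of the class channels.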
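The `pred.shape[1] != 1` guard in `forward()` can also be checked in isolation: softmax normalizes across the class dimension, so with a single channel every value maps to exactly 1.0 and the prediction would be erased. A quick sanity check (illustrative, not part of the test suite):

```python
import torch

# Softmax over a singleton class dimension is degenerate: exp(x) / exp(x) == 1
# for every pixel, which is why the patch skips softmax for 1-class inputs.
one_class_pred = torch.rand(1, 1, 4, 4)
assert torch.all(one_class_pred.softmax(dim=1) == 1.0)
```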