From 817c18bf2cc471f8ccc77ff0fb629dca9063436f Mon Sep 17 00:00:00 2001
From: yeedrag <46050186+yeedrag@users.noreply.github.com>
Date: Wed, 9 Aug 2023 10:16:22 +0800
Subject: [PATCH] [Fix] Added ignore_index and one hot encoding for dice loss
 (#3237)

Added an ignore_index param to forward(), and implemented one-hot encoding
to ensure the dims of target match pred.

## Motivation

Attempts to solve the problems mentioned in #3172.

## Modification

Added ignore_index to the forward function (although the dice loss itself
does not actually account for it). Added _expand_onehot_labels_dice, which
converts a target with shape [N, H, W] into one with shape
[N, num_classes, H, W].

This is my first time contributing to open-source code, so I might have made
some mistakes. Please don't hesitate to point them out.
---
 mmseg/models/losses/dice_loss.py              | 69 +++++++++++++------
 .../test_models/test_losses/test_dice_loss.py | 51 ++++++++++++--
 2 files changed, 93 insertions(+), 27 deletions(-)

diff --git a/mmseg/models/losses/dice_loss.py b/mmseg/models/losses/dice_loss.py
index 2b7c9d4cf1..65eae8aebc 100644
--- a/mmseg/models/losses/dice_loss.py
+++ b/mmseg/models/losses/dice_loss.py
@@ -1,4 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from typing import Union
+
 import torch
 import torch.nn as nn
 
@@ -6,15 +8,35 @@
 from .utils import weight_reduce_loss
 
 
-def dice_loss(
-    pred,
-    target,
-    weight,
-    eps=1e-3,
-    reduction='mean',
-    naive_dice=False,
-    avg_factor=None,
-):
+def _expand_onehot_labels_dice(pred: torch.Tensor,
+                               target: torch.Tensor) -> torch.Tensor:
+    """Expand onehot labels to match the size of prediction.
+
+    Args:
+        pred (torch.Tensor): The prediction, has a shape (N, num_class, H, W).
+        target (torch.Tensor): The learning label of the prediction,
+            has a shape (N, H, W).
+
+    Returns:
+        torch.Tensor: The target after one-hot encoding,
+            has a shape (N, num_class, H, W).
+ """ + num_classes = pred.shape[1] + one_hot_target = torch.clamp(target, min=0, max=num_classes) + one_hot_target = torch.nn.functional.one_hot(one_hot_target, + num_classes + 1) + one_hot_target = one_hot_target[..., :num_classes].permute(0, 3, 1, 2) + return one_hot_target + + +def dice_loss(pred: torch.Tensor, + target: torch.Tensor, + weight: Union[torch.Tensor, None], + eps: float = 1e-3, + reduction: Union[str, None] = 'mean', + naive_dice: Union[bool, None] = False, + avg_factor: Union[int, None] = None, + ignore_index: Union[int, None] = 255) -> float: """Calculate dice loss, there are two forms of dice loss is supported: - the one proposed in `V-Net: Fully Convolutional Neural @@ -41,11 +63,15 @@ def dice_loss( power.Defaults to False. avg_factor (int, optional): Average factor that is used to average the loss. Defaults to None. + ignore_index (int, optional): The label index to be ignored. + Defaults to 255. """ - + num_classes = pred.shape[1] + pred = pred[:, torch.arange(num_classes) != ignore_index, :, :] + target = target[:, torch.arange(num_classes) != ignore_index, :, :] + assert pred.shape[1] != 0 # if the ignored index is the only class input = pred.flatten(1) target = target.flatten(1).float() - a = torch.sum(input * target, 1) if naive_dice: b = torch.sum(input, 1) @@ -93,7 +119,7 @@ def __init__(self, denominator is the first power instead of the second power. Defaults to False. loss_weight (float, optional): Weight of loss. Defaults to 1.0. - ignore_index (int | None): The label index to be ignored. + ignore_index (int, optional): The label index to be ignored. Default: 255. eps (float): Avoid dividing by zero. Defaults to 1e-3. loss_name (str, optional): Name of the loss item. If you want this @@ -116,7 +142,9 @@ def forward(self, target, weight=None, avg_factor=None, - reduction_override=None): + reduction_override=None, + ignore_index=255, + **kwargs): """Forward function. 
         Args:
@@ -134,26 +162,27 @@ def forward(self,
         Returns:
             torch.Tensor: The calculated loss
         """
-
+        one_hot_target = target
+        if (pred.shape != target.shape):
+            one_hot_target = _expand_onehot_labels_dice(pred, target)
         assert reduction_override in (None, 'none', 'mean', 'sum')
         reduction = (
             reduction_override if reduction_override else self.reduction)
-
         if self.activate:
             if self.use_sigmoid:
                 pred = pred.sigmoid()
-            else:
-                raise NotImplementedError
-
+            elif pred.shape[1] != 1:
+                # softmax does not work when there is only 1 class
+                pred = pred.softmax(dim=1)
         loss = self.loss_weight * dice_loss(
             pred,
-            target,
+            one_hot_target,
             weight,
             eps=self.eps,
             reduction=reduction,
             naive_dice=self.naive_dice,
             avg_factor=avg_factor,
-        )
+            ignore_index=self.ignore_index)
         return loss
 
diff --git a/tests/test_models/test_losses/test_dice_loss.py b/tests/test_models/test_losses/test_dice_loss.py
index 4095a6ad2d..34253dae12 100644
--- a/tests/test_models/test_losses/test_dice_loss.py
+++ b/tests/test_models/test_losses/test_dice_loss.py
@@ -8,10 +8,9 @@
 @pytest.mark.parametrize('naive_dice', [True, False])
 def test_dice_loss(naive_dice):
     loss_class = DiceLoss
-    pred = torch.rand((10, 4, 4))
-    target = torch.rand((10, 4, 4))
-    weight = torch.rand(10)
-
+    pred = torch.rand((1, 10, 4, 4))
+    target = torch.randint(0, 10, (1, 4, 4))
+    weight = torch.rand(1)
     # Test loss forward
     loss = loss_class(naive_dice=naive_dice)(pred, target)
     assert isinstance(loss, torch.Tensor)
@@ -43,10 +42,11 @@ def test_dice_loss(naive_dice):
     assert isinstance(loss, torch.Tensor)
 
-    # Test loss forward with has_acted=False and use_sigmoid=False
-    with pytest.raises(NotImplementedError):
-        loss_class(
-            use_sigmoid=False, activate=True, naive_dice=naive_dice)(pred,
-                                                                     target)
+    # Test loss forward with activate=True and both use_sigmoid settings
+    for use_sigmoid in [True, False]:
+        loss = loss_class(
+            use_sigmoid=use_sigmoid, activate=True,
+            naive_dice=naive_dice)(pred, target)
+        assert isinstance(loss, torch.Tensor)
 
     # Test loss forward with weight.ndim != loss.ndim
     with pytest.raises(AssertionError):
@@ -57,3 +57,40 @@ def test_dice_loss(naive_dice):
     with pytest.raises(AssertionError):
         weight = torch.rand(8)
         loss_class(naive_dice=naive_dice)(pred, target, weight)
+
+    # Test _expand_onehot_labels_dice
+    pred = torch.tensor([[[[1, 1], [1, 0]], [[0, 1], [1, 1]]]]).float()
+    target = torch.tensor([[[0, 0], [0, 1]]])
+    target_onehot = torch.tensor([[[[1, 1], [1, 0]], [[0, 0], [0, 1]]]])
+    weight = torch.rand(1)
+    loss = loss_class(naive_dice=naive_dice)(pred, target, weight)
+    loss_onehot = loss_class(naive_dice=naive_dice)(pred, target_onehot,
+                                                    weight)
+    assert torch.equal(loss, loss_onehot)
+
+    # Test whether loss is 0 when pred == target, eps == 0, naive_dice=False
+    target = torch.randint(0, 2, (1, 10, 4, 4))
+    pred = target.float()
+    target = target.sigmoid()
+    weight = torch.rand(1)
+    loss = loss_class(
+        naive_dice=False, use_sigmoid=True, eps=0)(pred, target, weight)
+    assert loss.item() == 0
+
+    # Test ignore_index when ignore_index is the only class
+    with pytest.raises(AssertionError):
+        pred = torch.ones((1, 1, 4, 4))
+        target = torch.randint(0, 1, (1, 4, 4))
+        weight = torch.rand(1)
+        loss = loss_class(
+            naive_dice=naive_dice, use_sigmoid=False, ignore_index=0,
+            eps=0)(pred, target, weight)
+
+    # Test ignore_index with naive_dice=False
+    pred = torch.tensor([[[[1, 1], [1, 0]], [[0, 1], [1, 1]]]]).float()
+    target = torch.tensor([[[[1, 1], [1, 0]], [[1, 0], [0, 1]]]]).sigmoid()
+    weight = torch.rand(1)
+    loss = loss_class(
+        naive_dice=False, use_sigmoid=True, ignore_index=1,
+        eps=0)(pred, target, weight)
+    assert loss.item() == 0
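
For review convenience, here is a minimal sketch of how the patched loss behaves end to end. It assumes the patched mmseg.models.losses.dice_loss module is on the import path; the shapes follow the docstrings in the diff above (pred is (N, num_class, H, W), target is (N, H, W)):

```python
import torch

# Both names come from the diff above; this assumes the patched
# mmseg.models.losses.dice_loss module is importable.
from mmseg.models.losses.dice_loss import DiceLoss, _expand_onehot_labels_dice

pred = torch.rand(2, 4, 8, 8)            # logits, (N, num_class, H, W)
target = torch.randint(0, 4, (2, 8, 8))  # hard labels, (N, H, W)

# The helper expands (N, H, W) indices into a (N, num_class, H, W) one-hot
# map, which is what DiceLoss.forward() now does internally whenever
# pred.shape != target.shape.
assert _expand_onehot_labels_dice(pred, target).shape == pred.shape

# ignore_index=255 is a no-op for 4 classes: no channel index equals 255,
# so no channel is dropped inside dice_loss().
loss = DiceLoss(use_sigmoid=True, naive_dice=False, ignore_index=255)(pred,
                                                                      target)
print(loss)  # scalar tensor (reduction='mean' by default)
```

Note that because dice_loss() masks channels by index, ignore_index only has an effect when it is smaller than pred.shape[1]; that is why the new tests exercise ignore_index=0 and ignore_index=1 with small class counts.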