Skip to content

Commit

Permalink
[Fix] Added ignore_index and one hot encoding for dice loss (#3237)
Browse files Browse the repository at this point in the history
Added ignore_index param to forward(),
also implemented one hot encoding to ensure the dims of target matches
pred.

Thanks for your contribution and we appreciate it a lot. The following
instructions would make your pull request more healthy and more easily
get feedback. If you do not understand some items, don't worry, just
make the pull request and seek help from maintainers.

## Motivation

Please describe the motivation of this PR and the goal you want to
achieve through this PR.

Attempted to solve the problems mentioned by #3172 

## Modification

Please briefly describe what modification is made in this PR.

Added ignore_index into forward function (although the dice loss itself
does not actually take account for it for some reason).
Added _expand_onehot_labels_dice, which takes the target with shape [N,
H, W] into [N, num_classes, H, W].

## BC-breaking (Optional)

Does the modification introduce changes that break the
backward-compatibility of the downstream repos?
If so, please describe how it breaks the compatibility and how the
downstream projects should modify their code to keep compatibility with
this PR.

## Use cases (Optional)

If this PR introduces a new feature, it is better to list some use cases
here, and update the documentation.

## Checklist

1. Pre-commit or other linting tools are used to fix the potential lint
issues.
2. The modification is covered by complete unit tests. If not, please
add more unit test to ensure the correctness.
3. If the modification has potential influence on downstream projects,
this PR should be tested with downstream projects, like MMDet or
MMDet3D.
4. The documentation has been modified accordingly, like docstring or
example tutorials.

This is my first time contributing to open-source code, so I might have
made some stupid mistakes. Please don't hesitate to point it out.
  • Loading branch information
yeedrag committed Aug 9, 2023
1 parent 4927b0e commit 817c18b
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 27 deletions.
69 changes: 49 additions & 20 deletions mmseg/models/losses/dice_loss.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,42 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Union

import torch
import torch.nn as nn

from mmseg.registry import MODELS
from .utils import weight_reduce_loss


def dice_loss(
pred,
target,
weight,
eps=1e-3,
reduction='mean',
naive_dice=False,
avg_factor=None,
):
def _expand_onehot_labels_dice(pred: torch.Tensor,
target: torch.Tensor) -> torch.Tensor:
"""Expand onehot labels to match the size of prediction.
Args:
pred (torch.Tensor): The prediction, has a shape (N, num_class, H, W).
target (torch.Tensor): The learning label of the prediction,
has a shape (N, H, W).
Returns:
torch.Tensor: The target after one-hot encoding,
has a shape (N, num_class, H, W).
"""
num_classes = pred.shape[1]
one_hot_target = torch.clamp(target, min=0, max=num_classes)
one_hot_target = torch.nn.functional.one_hot(one_hot_target,
num_classes + 1)
one_hot_target = one_hot_target[..., :num_classes].permute(0, 3, 1, 2)
return one_hot_target


def dice_loss(pred: torch.Tensor,
target: torch.Tensor,
weight: Union[torch.Tensor, None],
eps: float = 1e-3,
reduction: Union[str, None] = 'mean',
naive_dice: Union[bool, None] = False,
avg_factor: Union[int, None] = None,
ignore_index: Union[int, None] = 255) -> float:
"""Calculate dice loss, there are two forms of dice loss is supported:
- the one proposed in `V-Net: Fully Convolutional Neural
Expand All @@ -41,11 +63,15 @@ def dice_loss(
power.Defaults to False.
avg_factor (int, optional): Average factor that is used to average
the loss. Defaults to None.
ignore_index (int, optional): The label index to be ignored.
Defaults to 255.
"""

num_classes = pred.shape[1]
pred = pred[:, torch.arange(num_classes) != ignore_index, :, :]
target = target[:, torch.arange(num_classes) != ignore_index, :, :]
assert pred.shape[1] != 0 # if the ignored index is the only class
input = pred.flatten(1)
target = target.flatten(1).float()

a = torch.sum(input * target, 1)
if naive_dice:
b = torch.sum(input, 1)
Expand Down Expand Up @@ -93,7 +119,7 @@ def __init__(self,
denominator is the first power instead of the second
power. Defaults to False.
loss_weight (float, optional): Weight of loss. Defaults to 1.0.
ignore_index (int | None): The label index to be ignored.
ignore_index (int, optional): The label index to be ignored.
Default: 255.
eps (float): Avoid dividing by zero. Defaults to 1e-3.
loss_name (str, optional): Name of the loss item. If you want this
Expand All @@ -116,7 +142,9 @@ def forward(self,
target,
weight=None,
avg_factor=None,
reduction_override=None):
reduction_override=None,
ignore_index=255,
**kwargs):
"""Forward function.
Args:
Expand All @@ -134,26 +162,27 @@ def forward(self,
Returns:
torch.Tensor: The calculated loss
"""

one_hot_target = target
if (pred.shape != target.shape):
one_hot_target = _expand_onehot_labels_dice(pred, target)
assert reduction_override in (None, 'none', 'mean', 'sum')
reduction = (
reduction_override if reduction_override else self.reduction)

if self.activate:
if self.use_sigmoid:
pred = pred.sigmoid()
else:
raise NotImplementedError

elif pred.shape[1] != 1:
# softmax does not work when there is only 1 class
pred = pred.softmax(dim=1)
loss = self.loss_weight * dice_loss(
pred,
target,
one_hot_target,
weight,
eps=self.eps,
reduction=reduction,
naive_dice=self.naive_dice,
avg_factor=avg_factor,
)
ignore_index=self.ignore_index)

return loss

Expand Down
51 changes: 44 additions & 7 deletions tests/test_models/test_losses/test_dice_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,9 @@
@pytest.mark.parametrize('naive_dice', [True, False])
def test_dice_loss(naive_dice):
loss_class = DiceLoss
pred = torch.rand((10, 4, 4))
target = torch.rand((10, 4, 4))
weight = torch.rand(10)

pred = torch.rand((1, 10, 4, 4))
target = torch.randint(0, 10, (1, 4, 4))
weight = torch.rand(1)
# Test loss forward
loss = loss_class(naive_dice=naive_dice)(pred, target)
assert isinstance(loss, torch.Tensor)
Expand Down Expand Up @@ -43,10 +42,11 @@ def test_dice_loss(naive_dice):
assert isinstance(loss, torch.Tensor)

# Test loss forward with has_acted=False and use_sigmoid=False
with pytest.raises(NotImplementedError):
for use_sigmoid in [True, False]:
loss_class(
use_sigmoid=False, activate=True, naive_dice=naive_dice)(pred,
target)
use_sigmoid=use_sigmoid, activate=True,
naive_dice=naive_dice)(pred, target)
assert isinstance(loss, torch.Tensor)

# Test loss forward with weight.ndim != loss.ndim
with pytest.raises(AssertionError):
Expand All @@ -57,3 +57,40 @@ def test_dice_loss(naive_dice):
with pytest.raises(AssertionError):
weight = torch.rand(8)
loss_class(naive_dice=naive_dice)(pred, target, weight)

# Test _expand_onehot_labels_dice
pred = torch.tensor([[[[1, 1], [1, 0]], [[0, 1], [1, 1]]]]).float()
target = torch.tensor([[[0, 0], [0, 1]]])
target_onehot = torch.tensor([[[[1, 1], [1, 0]], [[0, 0], [0, 1]]]])
weight = torch.rand(1)
loss = loss_class(naive_dice=naive_dice)(pred, target, weight)
loss_onehot = loss_class(naive_dice=naive_dice)(pred, target_onehot,
weight)
assert torch.equal(loss, loss_onehot)

# Test Whether Loss is 0 when pred == target, eps == 0 and naive_dice=False
target = torch.randint(0, 2, (1, 10, 4, 4))
pred = target.float()
target = target.sigmoid()
weight = torch.rand(1)
loss = loss_class(
naive_dice=False, use_sigmoid=True, eps=0)(pred, target, weight)
assert loss.item() == 0

# Test ignore_index when ignore_index is the only class
with pytest.raises(AssertionError):
pred = torch.ones((1, 1, 4, 4))
target = torch.randint(0, 1, (1, 4, 4))
weight = torch.rand(1)
loss = loss_class(
naive_dice=naive_dice, use_sigmoid=False, ignore_index=0,
eps=0)(pred, target, weight)

# Test ignore_index with naive_dice = False
pred = torch.tensor([[[[1, 1], [1, 0]], [[0, 1], [1, 1]]]]).float()
target = torch.tensor([[[[1, 1], [1, 0]], [[1, 0], [0, 1]]]]).sigmoid()
weight = torch.rand(1)
loss = loss_class(
naive_dice=False, use_sigmoid=True, ignore_index=1,
eps=0)(pred, target, weight)
assert loss.item() == 0

0 comments on commit 817c18b

Please sign in to comment.