Skip to content

Commit

Permalink
[Feature] huasdorff distance loss (#2820)
Browse files Browse the repository at this point in the history
Thanks for your contribution and we appreciate it a lot. The following
instructions would make your pull request more healthy and more easily
get feedback. If you do not understand some items, don't worry, just
make the pull request and seek help from maintainers.

## Motivation

Add Huasdorff distance loss

---------

Co-authored-by: xiexinch <[email protected]>
  • Loading branch information
jinxianwei and xiexinch committed Jun 19, 2023
1 parent b2f4b4f commit bb93b48
Show file tree
Hide file tree
Showing 3 changed files with 192 additions and 1 deletion.
4 changes: 3 additions & 1 deletion mmseg/models/losses/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
cross_entropy, mask_cross_entropy)
from .dice_loss import DiceLoss
from .focal_loss import FocalLoss
from .huasdorff_distance_loss import HuasdorffDisstanceLoss
from .lovasz_loss import LovaszLoss
from .ohem_cross_entropy_loss import OhemCrossEntropy
from .tversky_loss import TverskyLoss
Expand All @@ -14,5 +15,6 @@
'accuracy', 'Accuracy', 'cross_entropy', 'binary_cross_entropy',
'mask_cross_entropy', 'CrossEntropyLoss', 'reduce_loss',
'weight_reduce_loss', 'weighted_loss', 'LovaszLoss', 'DiceLoss',
'FocalLoss', 'TverskyLoss', 'OhemCrossEntropy', 'BoundaryLoss'
'FocalLoss', 'TverskyLoss', 'OhemCrossEntropy', 'BoundaryLoss',
'HuasdorffDisstanceLoss'
]
160 changes: 160 additions & 0 deletions mmseg/models/losses/huasdorff_distance_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
# Copyright (c) OpenMMLab. All rights reserved.
"""Modified from https://github.com/JunMa11/SegWithDistMap/blob/
master/code/train_LA_HD.py (Apache-2.0 License)"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.ndimage import distance_transform_edt as distance
from torch import Tensor

from mmseg.registry import MODELS
from .utils import get_class_weight, weighted_loss


def compute_dtm(img_gt: Tensor, pred: Tensor) -> Tensor:
"""
compute the distance transform map of foreground in mask
Args:
img_gt: Ground truth of the image, (b, h, w)
pred: Predictions of the segmentation head after softmax, (b, c, h, w)
Returns:
output: the foreground Distance Map (SDM)
dtm(x) = 0; x in segmentation boundary
inf|x-y|; x in segmentation
"""

fg_dtm = torch.zeros_like(pred)
out_shape = pred.shape
for b in range(out_shape[0]): # batch size
for c in range(1, out_shape[1]): # default 0 channel is background
posmask = img_gt[b].byte()
if posmask.any():
posdis = distance(posmask)
fg_dtm[b][c] = torch.from_numpy(posdis)

return fg_dtm


@weighted_loss
def hd_loss(seg_soft: Tensor,
gt: Tensor,
seg_dtm: Tensor,
gt_dtm: Tensor,
class_weight=None,
ignore_index=255) -> Tensor:
"""
compute huasdorff distance loss for segmentation
Args:
seg_soft: softmax results, shape=(b,c,x,y)
gt: ground truth, shape=(b,x,y)
seg_dtm: segmentation distance transform map, shape=(b,c,x,y)
gt_dtm: ground truth distance transform map, shape=(b,c,x,y)
Returns:
output: hd_loss
"""
assert seg_soft.shape[0] == gt.shape[0]
total_loss = 0
num_class = seg_soft.shape[1]
if class_weight is not None:
assert class_weight.ndim == num_class
for i in range(1, num_class):
if i != ignore_index:
delta_s = (seg_soft[:, i, ...] - gt.float())**2
s_dtm = seg_dtm[:, i, ...]**2
g_dtm = gt_dtm[:, i, ...]**2
dtm = s_dtm + g_dtm
multiplied = torch.einsum('bxy, bxy->bxy', delta_s, dtm)
hd_loss = multiplied.mean()
if class_weight is not None:
hd_loss *= class_weight[i]
total_loss += hd_loss

return total_loss / num_class


@MODELS.register_module()
class HuasdorffDisstanceLoss(nn.Module):
"""HuasdorffDisstanceLoss. This loss is proposed in `How Distance Transform
Maps Boost Segmentation CNNs: An Empirical Study.
<http://proceedings.mlr.press/v121/ma20b.html>`_.
Args:
reduction (str, optional): The method used to reduce the loss into
a scalar. Defaults to 'mean'.
class_weight (list[float] | str, optional): Weight of each class. If in
str format, read them from a file. Defaults to None.
loss_weight (float): Weight of the loss. Defaults to 1.0.
ignore_index (int | None): The label index to be ignored. Default: 255.
loss_name (str): Name of the loss item. If you want this loss
item to be included into the backward graph, `loss_` must be the
prefix of the name. Defaults to 'loss_boundary'.
"""

def __init__(self,
reduction='mean',
class_weight=None,
loss_weight=1.0,
ignore_index=255,
loss_name='loss_huasdorff_disstance',
**kwargs):
super().__init__()
self.reduction = reduction
self.loss_weight = loss_weight
self.class_weight = get_class_weight(class_weight)
self._loss_name = loss_name
self.ignore_index = ignore_index

def forward(self,
pred: Tensor,
target: Tensor,
avg_factor=None,
reduction_override=None,
**kwargs) -> Tensor:
"""Forward function.
Args:
pred (Tensor): Predictions of the segmentation head. (B, C, H, W)
target (Tensor): Ground truth of the image. (B, H, W)
avg_factor (int, optional): Average factor that is used to
average the loss. Defaults to None.
reduction_override (str, optional): The reduction method used
to override the original reduction method of the loss.
Options are "none", "mean" and "sum".
Returns:
Tensor: Loss tensor.
"""
assert reduction_override in (None, 'none', 'mean', 'sum')
reduction = (
reduction_override if reduction_override else self.reduction)
if self.class_weight is not None:
class_weight = pred.new_tensor(self.class_weight)
else:
class_weight = None

pred_soft = F.softmax(pred, dim=1)
valid_mask = (target != self.ignore_index).long()
target = target * valid_mask

with torch.no_grad():
gt_dtm = compute_dtm(target.cpu(), pred_soft)
gt_dtm = gt_dtm.float()
seg_dtm2 = compute_dtm(
pred_soft.argmax(dim=1, keepdim=False).cpu(), pred_soft)
seg_dtm2 = seg_dtm2.float()

loss_hd = self.loss_weight * hd_loss(
pred_soft,
target,
seg_dtm=seg_dtm2,
gt_dtm=gt_dtm,
reduction=reduction,
avg_factor=avg_factor,
class_weight=class_weight,
ignore_index=self.ignore_index)
return loss_hd

@property
def loss_name(self):
return self._loss_name
29 changes: 29 additions & 0 deletions tests/test_models/test_losses/test_huasdorff_distance_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch

from mmseg.models.losses import HuasdorffDisstanceLoss


def test_huasdorff_distance_loss():
loss_class = HuasdorffDisstanceLoss
pred = torch.rand((10, 8, 6, 6))
target = torch.rand((10, 6, 6))
class_weight = torch.rand(8)

# Test loss forward
loss = loss_class()(pred, target)
assert isinstance(loss, torch.Tensor)

# Test loss forward with avg_factor
loss = loss_class()(pred, target, avg_factor=10)
assert isinstance(loss, torch.Tensor)

# Test loss forward with avg_factor and reduction is None, 'sum' and 'mean'
for reduction in [None, 'sum', 'mean']:
loss = loss_class()(pred, target, avg_factor=10, reduction=reduction)
assert isinstance(loss, torch.Tensor)

# Test loss forward with class_weight
with pytest.raises(AssertionError):
loss_class(class_weight=class_weight)(pred, target)

0 comments on commit bb93b48

Please sign in to comment.