-
Notifications
You must be signed in to change notification settings - Fork 2.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Feature] huasdorff distance loss (#2820)
Thanks for your contribution and we appreciate it a lot. The following instructions would make your pull request more healthy and more easily get feedback. If you do not understand some items, don't worry, just make the pull request and seek help from maintainers. ## Motivation Add Huasdorff distance loss --------- Co-authored-by: xiexinch <[email protected]>
- Loading branch information
1 parent
b2f4b4f
commit bb93b48
Showing
3 changed files
with
192 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,160 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
"""Modified from https://github.com/JunMa11/SegWithDistMap/blob/ | ||
master/code/train_LA_HD.py (Apache-2.0 License)""" | ||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
from scipy.ndimage import distance_transform_edt as distance | ||
from torch import Tensor | ||
|
||
from mmseg.registry import MODELS | ||
from .utils import get_class_weight, weighted_loss | ||
|
||
|
||
def compute_dtm(img_gt: Tensor, pred: Tensor) -> Tensor: | ||
""" | ||
compute the distance transform map of foreground in mask | ||
Args: | ||
img_gt: Ground truth of the image, (b, h, w) | ||
pred: Predictions of the segmentation head after softmax, (b, c, h, w) | ||
Returns: | ||
output: the foreground Distance Map (SDM) | ||
dtm(x) = 0; x in segmentation boundary | ||
inf|x-y|; x in segmentation | ||
""" | ||
|
||
fg_dtm = torch.zeros_like(pred) | ||
out_shape = pred.shape | ||
for b in range(out_shape[0]): # batch size | ||
for c in range(1, out_shape[1]): # default 0 channel is background | ||
posmask = img_gt[b].byte() | ||
if posmask.any(): | ||
posdis = distance(posmask) | ||
fg_dtm[b][c] = torch.from_numpy(posdis) | ||
|
||
return fg_dtm | ||
|
||
|
||
@weighted_loss | ||
def hd_loss(seg_soft: Tensor, | ||
gt: Tensor, | ||
seg_dtm: Tensor, | ||
gt_dtm: Tensor, | ||
class_weight=None, | ||
ignore_index=255) -> Tensor: | ||
""" | ||
compute huasdorff distance loss for segmentation | ||
Args: | ||
seg_soft: softmax results, shape=(b,c,x,y) | ||
gt: ground truth, shape=(b,x,y) | ||
seg_dtm: segmentation distance transform map, shape=(b,c,x,y) | ||
gt_dtm: ground truth distance transform map, shape=(b,c,x,y) | ||
Returns: | ||
output: hd_loss | ||
""" | ||
assert seg_soft.shape[0] == gt.shape[0] | ||
total_loss = 0 | ||
num_class = seg_soft.shape[1] | ||
if class_weight is not None: | ||
assert class_weight.ndim == num_class | ||
for i in range(1, num_class): | ||
if i != ignore_index: | ||
delta_s = (seg_soft[:, i, ...] - gt.float())**2 | ||
s_dtm = seg_dtm[:, i, ...]**2 | ||
g_dtm = gt_dtm[:, i, ...]**2 | ||
dtm = s_dtm + g_dtm | ||
multiplied = torch.einsum('bxy, bxy->bxy', delta_s, dtm) | ||
hd_loss = multiplied.mean() | ||
if class_weight is not None: | ||
hd_loss *= class_weight[i] | ||
total_loss += hd_loss | ||
|
||
return total_loss / num_class | ||
|
||
|
||
@MODELS.register_module() | ||
class HuasdorffDisstanceLoss(nn.Module): | ||
"""HuasdorffDisstanceLoss. This loss is proposed in `How Distance Transform | ||
Maps Boost Segmentation CNNs: An Empirical Study. | ||
<http://proceedings.mlr.press/v121/ma20b.html>`_. | ||
Args: | ||
reduction (str, optional): The method used to reduce the loss into | ||
a scalar. Defaults to 'mean'. | ||
class_weight (list[float] | str, optional): Weight of each class. If in | ||
str format, read them from a file. Defaults to None. | ||
loss_weight (float): Weight of the loss. Defaults to 1.0. | ||
ignore_index (int | None): The label index to be ignored. Default: 255. | ||
loss_name (str): Name of the loss item. If you want this loss | ||
item to be included into the backward graph, `loss_` must be the | ||
prefix of the name. Defaults to 'loss_boundary'. | ||
""" | ||
|
||
def __init__(self, | ||
reduction='mean', | ||
class_weight=None, | ||
loss_weight=1.0, | ||
ignore_index=255, | ||
loss_name='loss_huasdorff_disstance', | ||
**kwargs): | ||
super().__init__() | ||
self.reduction = reduction | ||
self.loss_weight = loss_weight | ||
self.class_weight = get_class_weight(class_weight) | ||
self._loss_name = loss_name | ||
self.ignore_index = ignore_index | ||
|
||
def forward(self, | ||
pred: Tensor, | ||
target: Tensor, | ||
avg_factor=None, | ||
reduction_override=None, | ||
**kwargs) -> Tensor: | ||
"""Forward function. | ||
Args: | ||
pred (Tensor): Predictions of the segmentation head. (B, C, H, W) | ||
target (Tensor): Ground truth of the image. (B, H, W) | ||
avg_factor (int, optional): Average factor that is used to | ||
average the loss. Defaults to None. | ||
reduction_override (str, optional): The reduction method used | ||
to override the original reduction method of the loss. | ||
Options are "none", "mean" and "sum". | ||
Returns: | ||
Tensor: Loss tensor. | ||
""" | ||
assert reduction_override in (None, 'none', 'mean', 'sum') | ||
reduction = ( | ||
reduction_override if reduction_override else self.reduction) | ||
if self.class_weight is not None: | ||
class_weight = pred.new_tensor(self.class_weight) | ||
else: | ||
class_weight = None | ||
|
||
pred_soft = F.softmax(pred, dim=1) | ||
valid_mask = (target != self.ignore_index).long() | ||
target = target * valid_mask | ||
|
||
with torch.no_grad(): | ||
gt_dtm = compute_dtm(target.cpu(), pred_soft) | ||
gt_dtm = gt_dtm.float() | ||
seg_dtm2 = compute_dtm( | ||
pred_soft.argmax(dim=1, keepdim=False).cpu(), pred_soft) | ||
seg_dtm2 = seg_dtm2.float() | ||
|
||
loss_hd = self.loss_weight * hd_loss( | ||
pred_soft, | ||
target, | ||
seg_dtm=seg_dtm2, | ||
gt_dtm=gt_dtm, | ||
reduction=reduction, | ||
avg_factor=avg_factor, | ||
class_weight=class_weight, | ||
ignore_index=self.ignore_index) | ||
return loss_hd | ||
|
||
@property | ||
def loss_name(self): | ||
return self._loss_name |
29 changes: 29 additions & 0 deletions
29
tests/test_models/test_losses/test_huasdorff_distance_loss.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
import pytest | ||
import torch | ||
|
||
from mmseg.models.losses import HuasdorffDisstanceLoss | ||
|
||
|
||
def test_huasdorff_distance_loss(): | ||
loss_class = HuasdorffDisstanceLoss | ||
pred = torch.rand((10, 8, 6, 6)) | ||
target = torch.rand((10, 6, 6)) | ||
class_weight = torch.rand(8) | ||
|
||
# Test loss forward | ||
loss = loss_class()(pred, target) | ||
assert isinstance(loss, torch.Tensor) | ||
|
||
# Test loss forward with avg_factor | ||
loss = loss_class()(pred, target, avg_factor=10) | ||
assert isinstance(loss, torch.Tensor) | ||
|
||
# Test loss forward with avg_factor and reduction is None, 'sum' and 'mean' | ||
for reduction in [None, 'sum', 'mean']: | ||
loss = loss_class()(pred, target, avg_factor=10, reduction=reduction) | ||
assert isinstance(loss, torch.Tensor) | ||
|
||
# Test loss forward with class_weight | ||
with pytest.raises(AssertionError): | ||
loss_class(class_weight=class_weight)(pred, target) |