Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix generalized dice computation #7970

Merged
merged 18 commits into from
Sep 7, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 58 additions & 54 deletions monai/metrics/generalized_dice.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,109 +20,108 @@


class GeneralizedDiceScore(CumulativeIterationMetric):
"""Compute the Generalized Dice Score metric between tensors, as the complement of the Generalized Dice Loss defined in:
"""
Compute the Generalized Dice Score metric between tensors.

This metric is the complement of the Generalized Dice Loss defined in:
Sudre, C. et. al. (2017) Generalised Dice overlap as a deep learning
loss function for highly unbalanced segmentations. DLMIA 2017.
loss function for highly unbalanced segmentations. DLMIA 2017.

The inputs `y_pred` and `y` are expected to be one-hot, binarized channel-first
or batch-first tensors, i.e., CHW[D] or BCHW[D].
The inputs `y_pred` and `y` are expected to be one-hot, binarized batch-first tensors, i.e., NCHW[D].

Example of the typical execution steps of this metric class follows :py:class:`monai.metrics.metric.Cumulative`.

Args:
include_background (bool, optional): whether to include the background class (assumed to be in channel 0), in the
include_background: Whether to include the background class (assumed to be in channel 0) in the
score computation. Defaults to True.
reduction (str, optional): define mode of reduction to the metrics. Available reduction modes:
{``"none"``, ``"mean_batch"``, ``"sum_batch"``}. Defaults to ``"mean_batch"``. If "none", will not do reduction.
weight_type (Union[Weight, str], optional): {``"square"``, ``"simple"``, ``"uniform"``}. Type of function to transform
reduction: Define mode of reduction to the metrics. Available reduction modes:
{``"none"``, ``"mean_batch"``, ``"sum_batch"``, ``"mean"``, ``"sum"``}. Defaults to ``"mean"``.
If "none", will not do reduction.
weight_type: {``"square"``, ``"simple"``, ``"uniform"``}. Type of function to transform
ground truth volume into a weight factor. Defaults to ``"square"``.

Raises:
ValueError: when the `weight_type` is not one of {``"square"``, ``"simple"``, ``"uniform"``}.
ValueError: When the `reduction` is not one of MetricReduction enum.
"""

def __init__(
self,
include_background: bool = True,
reduction: MetricReduction | str = MetricReduction.MEAN_BATCH,
weight_type: Weight | str = Weight.SQUARE,
self, include_background: bool = True, reduction: str = "mean", weight_type: Weight | str = Weight.SQUARE
KumoLiu marked this conversation as resolved.
Show resolved Hide resolved
) -> None:
super().__init__()
self.include_background = include_background
reduction_options = [
"none",
"mean_batch",
"sum_batch",
MetricReduction.NONE,
MetricReduction.MEAN_BATCH,
MetricReduction.SUM_BATCH,
]
self.reduction = reduction
if self.reduction not in reduction_options:
raise ValueError(f"reduction must be one of {reduction_options}")
self.reduction = look_up_option(reduction, MetricReduction)
self.weight_type = look_up_option(weight_type, Weight)
self.sum_over_labels = self.reduction == MetricReduction.SUM or self.reduction == MetricReduction.MEAN

def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: # type: ignore[override]
"""Computes the Generalized Dice Score and returns a tensor with its per image values.
def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
"""
Computes the Generalized Dice Score and returns a tensor with its per image values.

Args:
y_pred (torch.Tensor): binarized segmentation model output. It must be in one-hot format and in the NCHW[D] format,
y_pred: Binarized segmentation model output. It must be in one-hot format and in the NCHW[D] format,
where N is the batch dimension, C is the channel dimension, and the remaining are the spatial dimensions.
y (torch.Tensor): binarized ground-truth. It must be in one-hot format and have the same shape as `y_pred`.
y: Binarized ground-truth. It must be in one-hot format and have the same shape as `y_pred`.

Returns:
torch.Tensor: Per batch and per class Generalized Dice Score.

Raises:
ValueError: if `y_pred` and `y` have less than 3 dimensions, or `y_pred` and `y` don't have the same shape.
ValueError: If `y_pred` and `y` have less than 3 dimensions, or `y_pred` and `y` don't have the same shape.
"""
return compute_generalized_dice(
y_pred=y_pred, y=y, include_background=self.include_background, weight_type=self.weight_type
y_pred=y_pred,
y=y,
include_background=self.include_background,
weight_type=self.weight_type,
sum_over_labels=self.sum_over_labels,
)

def aggregate(self, reduction: MetricReduction | str | None = None) -> torch.Tensor:
KumoLiu marked this conversation as resolved.
Show resolved Hide resolved
def aggregate(self) -> torch.Tensor:
"""
Execute reduction logic for the output of `compute_generalized_dice`.

Args:
reduction (Union[MetricReduction, str, None], optional): define mode of reduction to the metrics.
Available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``}.
Defaults to ``"mean"``. If "none", will not do reduction.
Returns:
torch.Tensor: Aggregated metric value.

Raises:
ValueError: If the data to aggregate is not a PyTorch Tensor.
"""
data = self.get_buffer()
if not isinstance(data, torch.Tensor):
raise ValueError("The data to aggregate must be a PyTorch Tensor.")

# Validate reduction argument if specified
if reduction is not None:
reduction_options = ["none", "mean", "sum", "mean_batch", "sum_batch"]
if reduction not in reduction_options:
raise ValueError(f"reduction must be one of {reduction_options}")

# Do metric reduction and return
f, _ = do_metric_reduction(data, reduction or self.reduction)
f, _ = do_metric_reduction(data, self.reduction)

return f


def compute_generalized_dice(
y_pred: torch.Tensor, y: torch.Tensor, include_background: bool = True, weight_type: Weight | str = Weight.SQUARE
y_pred: torch.Tensor,
y: torch.Tensor,
include_background: bool = True,
weight_type: Weight | str = Weight.SQUARE,
sum_over_labels: bool = False,
surajpaib marked this conversation as resolved.
Show resolved Hide resolved
) -> torch.Tensor:
"""Computes the Generalized Dice Score and returns a tensor with its per image values.
"""
Computes the Generalized Dice Score and returns a tensor with its per image values.

Args:
y_pred (torch.Tensor): binarized segmentation model output. It should be binarized, in one-hot format
y_pred: Binarized segmentation model output. It should be binarized, in one-hot format
and in the NCHW[D] format, where N is the batch dimension, C is the channel dimension, and the
remaining are the spatial dimensions.
y (torch.Tensor): binarized ground-truth. It should be binarized, in one-hot format and have the same shape as `y_pred`.
include_background (bool, optional): whether to include score computation on the first channel of the
y: Binarized ground-truth. It should be binarized, in one-hot format and have the same shape as `y_pred`.
include_background: Whether to include score computation on the first channel of the
predicted output. Defaults to True.
weight_type (Union[Weight, str], optional): {``"square"``, ``"simple"``, ``"uniform"``}. Type of function to
weight_type: {``"square"``, ``"simple"``, ``"uniform"``}. Type of function to
transform ground truth volume into a weight factor. Defaults to ``"square"``.
sum_over_labels: Whether to sum the numerator and denominator across all labels before the final computation.
KumoLiu marked this conversation as resolved.
Show resolved Hide resolved

Returns:
torch.Tensor: per batch and per class Generalized Dice Score, i.e., with the shape [batch_size, num_classes].
torch.Tensor: Per batch and per class Generalized Dice Score, i.e., with the shape [batch_size, num_classes].

Raises:
ValueError: if `y_pred` or `y` are not PyTorch tensors, if `y_pred` and `y` have less than three dimensions,
ValueError: If `y_pred` or `y` are not PyTorch tensors, if `y_pred` and `y` have less than three dimensions,
or `y_pred` and `y` don't have the same shape.
"""
# Ensure tensors have at least 3 dimensions and have the same shape
Expand Down Expand Up @@ -158,16 +157,21 @@ def compute_generalized_dice(
b[infs] = 0
b[infs] = torch.max(b)

# Compute the weighted numerator and denominator, summing along the class axis
numer = 2.0 * (intersection * w).sum(dim=1)
denom = (denominator * w).sum(dim=1)
# Compute the weighted numerator and denominator, summing along the class axis when sum_over_labels is True
surajpaib marked this conversation as resolved.
Show resolved Hide resolved
if sum_over_labels:
numer = 2.0 * (intersection * w).sum(dim=1, keepdim=True)
denom = (denominator * w).sum(dim=1, keepdim=True)
y_pred_o = y_pred_o.sum(dim=-1, keepdim=True)
else:
numer = 2.0 * (intersection * w)
denom = denominator * w
y_pred_o = y_pred_o

# Compute the score
generalized_dice_score = numer / denom

# Handle zero division. Where denom == 0 and the prediction volume is 0, score is 1.
# Where denom == 0 but the prediction volume is not 0, score is 0
y_pred_o = y_pred_o.sum(dim=-1)
denom_zeros = denom == 0
generalized_dice_score[denom_zeros] = torch.where(
(y_pred_o == 0)[denom_zeros],
Expand Down
40 changes: 26 additions & 14 deletions tests/test_compute_generalized_dice.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
_device = "cuda:0" if torch.cuda.is_available() else "cpu"

# keep background
TEST_CASE_1 = [ # y (1, 1, 2, 2), y_pred (1, 1, 2, 2), expected out (1)
surajpaib marked this conversation as resolved.
Show resolved Hide resolved
TEST_CASE_1 = [ # y (1, 1, 2, 2), y_pred (1, 1, 2, 2), expected out (1, 1)
{
"y_pred": torch.tensor([[[[1.0, 0.0], [0.0, 1.0]]]], device=_device),
"y": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]], device=_device),
Expand All @@ -32,7 +32,7 @@
]

# remove background
TEST_CASE_2 = [ # y (2, 1, 2, 2), y_pred (2, 3, 2, 2), expected out (2) (no background)
TEST_CASE_2 = [ # y (2, 3, 2, 2), y_pred (2, 3, 2, 2), expected out (2, 2) (no background)
{
"y_pred": torch.tensor(
[
Expand All @@ -48,11 +48,11 @@
),
"include_background": False,
},
[0.1667, 0.6667],
[0.416667],
]

# should return 0 for both cases
TEST_CASE_3 = [
TEST_CASE_3 = [ # y (2, 3, 2, 2), y_pred (2, 3, 2, 2), expected out (2, 3)
{
"y_pred": torch.tensor(
[
Expand All @@ -68,7 +68,7 @@
),
"include_background": True,
},
[0.0, 0.0],
[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
]

TEST_CASE_4 = [
Expand All @@ -87,11 +87,11 @@
]
),
},
[0.5455],
[0.678571, 0.2, 0.333333],
]

TEST_CASE_5 = [
{"include_background": True, "reduction": "sum_batch"},
{"include_background": True, "reduction": "sum"},
{
"y_pred": torch.tensor(
[
Expand All @@ -106,16 +106,28 @@
]
),
},
1.0455,
[1.045455],
]

TEST_CASE_6 = [{"y": torch.ones((2, 2, 3, 3)), "y_pred": torch.ones((2, 2, 3, 3))}, [1.0000, 1.0000]]
TEST_CASE_6 = [
{"y": torch.ones((2, 2, 3, 3)), "y_pred": torch.ones((2, 2, 3, 3))},
[[1.0000, 1.0000], [1.0000, 1.0000]],
]

TEST_CASE_7 = [{"y": torch.zeros((2, 2, 3, 3)), "y_pred": torch.ones((2, 2, 3, 3))}, [0.0000, 0.0000]]
TEST_CASE_7 = [
{"y": torch.zeros((2, 2, 3, 3)), "y_pred": torch.ones((2, 2, 3, 3))},
[[0.0000, 0.0000], [0.0000, 0.0000]],
]

TEST_CASE_8 = [{"y": torch.ones((2, 2, 3, 3)), "y_pred": torch.zeros((2, 2, 3, 3))}, [0.0000, 0.0000]]
TEST_CASE_8 = [
{"y": torch.ones((2, 2, 3, 3)), "y_pred": torch.zeros((2, 2, 3, 3))},
[[0.0000, 0.0000], [0.0000, 0.0000]],
]

TEST_CASE_9 = [{"y": torch.zeros((2, 2, 3, 3)), "y_pred": torch.zeros((2, 2, 3, 3))}, [1.0000, 1.0000]]
TEST_CASE_9 = [
{"y": torch.zeros((2, 2, 3, 3)), "y_pred": torch.zeros((2, 2, 3, 3))},
[[1.0000, 1.0000], [1.0000, 1.0000]],
]


class TestComputeGeneralizedDiceScore(unittest.TestCase):
Expand All @@ -126,7 +138,7 @@ def test_device(self, input_data, _expected_value):
np.testing.assert_equal(result.device, input_data["y_pred"].device)

# Functional part tests
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_6, TEST_CASE_7, TEST_CASE_8, TEST_CASE_9])
@parameterized.expand([TEST_CASE_6, TEST_CASE_7, TEST_CASE_8, TEST_CASE_9])
def test_value(self, input_data, expected_value):
result = compute_generalized_dice(**input_data)
np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4)
Expand All @@ -146,7 +158,7 @@ def test_value_class(self, input_data, expected_value):
vals["y"] = input_data.pop("y")
generalized_dice_score = GeneralizedDiceScore(**input_data)
generalized_dice_score(**vals)
result = generalized_dice_score.aggregate(reduction="none")
result = generalized_dice_score.aggregate()
np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4)

# Aggregation tests
Expand Down
Loading