import torch
from torch import nn

from src.losses import utils
def flatten(tensor):
"""Flattens a given tensor such that the channel axis is first.
The shapes are transformed as follows:
(N, C, D, H, W) -> (C, N * D * H * W)
"""
# number of channels
C = tensor.size(1)
# new axis order
axis_order = (1, 0) + tuple(range(2, tensor.dim()))
# Transpose: (N, C, D, H, W) -> (C, N, D, H, W)
transposed = tensor.permute(axis_order)
# Flatten: (C, N, D, H, W) -> (C, N * D * H * W)
return transposed.contiguous().view(C, -1)
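# A quick shape sanity check (illustrative sketch, not part of the library API):
#   >>> x = torch.randn(2, 3, 4, 4, 4)   # (N, C, D, H, W)
#   >>> flatten(x).shape
#   torch.Size([3, 128])                 # (C, N * D * H * W) = (3, 2 * 4 * 4 * 4)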
def compute_per_channel_dice(input, target, epsilon=1e-6, weight=None):
    """
    Computes the Dice coefficient, as defined in https://arxiv.org/abs/1606.04797, given a multi-channel input and target.
    Assumes the input is a normalized probability, e.g. the result of a Sigmoid or Softmax function.

    Args:
        input (torch.Tensor): NxCxSpatial input tensor
        target (torch.Tensor): NxCxSpatial target tensor
        epsilon (float): prevents division by zero
        weight (torch.Tensor): Cx1 tensor of weights, one per channel/class

    Returns:
        torch.Tensor: per-channel Dice coefficient of shape (C,)
    """
# input and target shapes must match
assert input.size() == target.size(), "'input' and 'target' must have the same shape"
input = flatten(input)
target = flatten(target)
target = target.float()
# compute per channel Dice Coefficient
intersect = (input * target).sum(-1)
if weight is not None:
intersect = weight * intersect
    # here we use the V-Net extension (input^2 + target^2).sum(-1) in the denominator;
    # the standard Dice formulation would use (input + target).sum(-1) instead
    denominator = (input * input).sum(-1) + (target * target).sum(-1)
return 2 * (intersect / denominator.clamp(min=epsilon))
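# Worked example (illustrative sketch): a perfect prediction of a one-hot target
# yields Dice = 1 on the occupied channel; note that an empty channel scores 0,
# since both the intersection and the (clamped) denominator vanish:
#   >>> t = torch.zeros(1, 2, 2, 2, 2); t[:, 0] = 1   # every voxel belongs to class 0
#   >>> compute_per_channel_dice(t, t)
#   tensor([1., 0.])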
class _AbstractDiceLoss(nn.Module):
"""
Base class for different implementations of Dice loss.
"""
    def __init__(self, weight=None, sigmoid_normalization=True):
        super().__init__()
        self.register_buffer('weight', weight)
        # The network output during training is assumed to be un-normalized logits that
        # need to be converted to probabilities. Since Dice (soft Dice here) is usually
        # used with binary data, normalizing the channels with Sigmoid is the default
        # choice, even for multi-class segmentation problems. To apply Softmax instead
        # and obtain a proper probability distribution over the channels, specify
        # sigmoid_normalization=False.
        if sigmoid_normalization:
            self.normalization = nn.Sigmoid()
        else:
            self.normalization = nn.Softmax(dim=1)
def dice(self, input, target, weight):
# actual Dice score computation; to be implemented by the subclass
raise NotImplementedError
def forward(self, input, target):
# get probabilities from logits
input = self.normalization(input.float())
        # compute per-channel Dice coefficient
        per_channel_dice = self.dice(input, target, weight=self.weight)
        # average the Dice score across all channels/classes
        score = torch.mean(per_channel_dice)
        # return both the loss (to minimize) and the score (to monitor as a metric)
        return 1. - score, score
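# Illustration (sketch) of the normalization choice above: Sigmoid squashes each
# channel independently, so the per-voxel channel sums need not be 1, while Softmax
# yields a proper probability distribution over the channels:
#   >>> x = torch.tensor([[[2.0], [0.0]]])     # (N=1, C=2, spatial=1)
#   >>> torch.sigmoid(x).sum(dim=1)            # ~1.3808, not a distribution
#   >>> torch.softmax(x, dim=1).sum(dim=1)     # 1.0, channels sum to one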
class DiceLoss(_AbstractDiceLoss):
    """Computes Dice Loss according to https://arxiv.org/abs/1606.04797.
    For multi-class segmentation the `weight` parameter can be used to assign different weights per class.
    The input to the loss function is assumed to be logits; it is normalized with Sigmoid by default,
    or with Softmax when sigmoid_normalization=False.
    """
def __init__(self, weight=None, sigmoid_normalization=True):
super().__init__(weight, sigmoid_normalization)
    def dice(self, input, target, weight):
        return compute_per_channel_dice(input, target, weight=weight)
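# Usage sketch (an assumed training step, not code from this repo): forward returns
# a (loss, score) tuple, so unpack both; the score can be logged as a metric.
#   >>> criterion = DiceLoss()
#   >>> logits = torch.randn(2, 3, 8, 8, 8, requires_grad=True)
#   >>> target = torch.randint(0, 2, (2, 3, 8, 8, 8)).float()
#   >>> loss, score = criterion(logits, target)
#   >>> loss.backward()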
class GeneralizedDiceLoss(_AbstractDiceLoss):
    """Computes Generalized Dice Loss (GDL) as described in https://arxiv.org/pdf/1707.03237.pdf."""

    def __init__(self, sigmoid_normalization=True, epsilon=1e-6, num_classes=4):
        super().__init__(weight=None, sigmoid_normalization=sigmoid_normalization)
        self.epsilon = epsilon
        # number of classes used to expand the class-index target to one-hot encoding
        self.num_classes = num_classes

    def dice(self, input, target, weight):
        # expand the class-index target to a one-hot encoding matching the input shape
        target = utils.expand_as_one_hot(target.long(), num_classes=self.num_classes)
assert input.size() == target.size(), "'input' and 'target' must have the same shape"
input = flatten(input)
target = flatten(target)
target = target.float()
if input.size(0) == 1:
# for GDL to make sense we need at least 2 channels (see https://arxiv.org/pdf/1707.03237.pdf)
# put foreground and background voxels in separate channels
input = torch.cat((input, 1 - input), dim=0)
target = torch.cat((target, 1 - target), dim=0)
        # GDL weighting: the contribution of each label is corrected by the inverse of its squared volume
        w_l = target.sum(-1)
        w_l = 1 / (w_l * w_l).clamp(min=self.epsilon)
        # the class weights are constants; detach them from the computation graph
        w_l = w_l.detach()
intersect = (input * target).sum(-1)
intersect = intersect * w_l
denominator = (input + target).sum(-1)
denominator = (denominator * w_l).clamp(min=self.epsilon)
return 2 * (intersect.sum() / denominator.sum())
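# Minimal smoke test (a sketch, assuming this module is runnable as a script and that
# utils.expand_as_one_hot maps an (N, D, H, W) label map to an (N, C, D, H, W) one-hot
# tensor, consistent with the shape assert in GeneralizedDiceLoss.dice above).
if __name__ == '__main__':
    N, C, D, H, W = 2, 4, 8, 8, 8
    logits = torch.randn(N, C, D, H, W)
    # DiceLoss expects a target with the same shape as the input
    binary_target = torch.randint(0, 2, (N, C, D, H, W)).float()
    loss, score = DiceLoss()(logits, binary_target)
    print(f'DiceLoss:            loss={loss.item():.4f} score={score.item():.4f}')
    # GeneralizedDiceLoss expands a class-index target to one-hot internally;
    # use Softmax normalization to get a proper distribution over the classes
    label_target = torch.randint(0, C, (N, D, H, W))
    loss, score = GeneralizedDiceLoss(sigmoid_normalization=False, num_classes=C)(logits, label_target)
    print(f'GeneralizedDiceLoss: loss={loss.item():.4f} score={score.item():.4f}')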