[96354c]: / src / losses / new_losses.py

Download this file

132 lines (99 with data), 5.2 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import torch
from src.losses import utils
from torch import nn
def flatten(tensor):
    """Move the channel axis to the front and collapse every other axis into one.

    A tensor laid out as (N, C, *spatial) becomes (C, N * prod(spatial)),
    e.g. (N, C, D, H, W) -> (C, N * D * H * W). Along the collapsed axis the
    elements are ordered batch-major, then by spatial position.

    Args:
        tensor (torch.Tensor): tensor with at least two dimensions; dim 1 is
            treated as the channel axis.

    Returns:
        torch.Tensor: 2-D tensor of shape (C, N * prod(spatial)).
    """
    num_channels = tensor.size(1)
    # Swapping dims 0 and 1 is exactly the permutation (1, 0, 2, 3, ...).
    channels_first = tensor.transpose(0, 1)
    # The transposed tensor is not contiguous, so materialize it before view().
    return channels_first.contiguous().view(num_channels, -1)
def compute_per_channel_dice(input, target, epsilon=1e-6, weight=None):
    """
    Computes DiceCoefficient as defined in https://arxiv.org/abs/1606.04797 given a multi channel input and target.
    Assumes the input is a normalized probability, e.g. a result of Sigmoid or Softmax function.

    The denominator uses the squared (V-Net) form sum(p^2) + sum(g^2) rather than
    the plain sum(p) + sum(g).

    Args:
        input (torch.Tensor): NxCxSpatial input tensor
        target (torch.Tensor): NxCxSpatial target tensor
        epsilon (float): lower bound applied to the denominator to prevent division by zero
        weight (torch.Tensor): optional per-channel weights of shape (C,) or (C, 1);
            they multiply the intersection term of each channel

    Returns:
        torch.Tensor: tensor of shape (C,) holding one Dice coefficient per channel.
    """
    # input and target shapes must match
    assert input.size() == target.size(), "'input' and 'target' must have the same shape"

    # (N, C, *spatial) -> (C, N * prod(spatial))
    input = flatten(input)
    target = flatten(target).float()

    # compute per channel Dice Coefficient
    intersect = (input * target).sum(-1)
    if weight is not None:
        if isinstance(weight, torch.Tensor):
            # A (C, 1) weight would broadcast against the (C,) intersection into a
            # (C, C) matrix and silently corrupt the result; flatten it so both the
            # (C,) and the (C, 1) layouts give one weight per channel.
            weight = weight.reshape(-1)
        intersect = weight * intersect

    # here we can use standard dice (input + target).sum(-1) or extension (see V-Net) (input^2 + target^2).sum(-1)
    denominator = (input * input).sum(-1) + (target * target).sum(-1)
    return 2 * (intersect / denominator.clamp(min=epsilon))
class _AbstractDiceLoss(nn.Module):
    """
    Shared scaffolding for Dice-style losses.

    Subclasses implement ``dice(input, target, weight)``. This base class
    normalizes the raw network outputs, delegates the score computation, averages
    the returned score(s) and converts the average into a loss.
    """

    def __init__(self, weight=None, sigmoid_normalization=True):
        super().__init__()
        # Stored as a module buffer (None is accepted) and handed to dice() on every forward call.
        self.register_buffer('weight', weight)
        # Network outputs are expected to be logits. Sigmoid (the default) normalizes each
        # channel independently, which suits soft Dice even for multi-class segmentation;
        # pass sigmoid_normalization=False to apply Softmax across the channel axis (dim=1)
        # and obtain a proper probability distribution instead.
        self.normalization = nn.Sigmoid() if sigmoid_normalization else nn.Softmax(dim=1)

    def dice(self, input, target, weight):
        """Return the Dice score(s) of probabilities ``input`` vs ``target``; subclasses must override."""
        raise NotImplementedError

    def forward(self, input, target):
        """
        Normalize ``input`` and compute the Dice loss against ``target``.

        Returns:
            tuple: ``(loss, score)`` where ``score`` is the mean of whatever ``dice``
            returns and ``loss`` is ``1 - score``.
        """
        # turn logits into probabilities
        probabilities = self.normalization(input.float())
        dice_scores = self.dice(probabilities, target, weight=self.weight)
        # average Dice score across all channels/classes
        mean_score = dice_scores.mean()
        return 1. - mean_score, mean_score
class DiceLoss(_AbstractDiceLoss):
    """Computes Dice Loss according to https://arxiv.org/abs/1606.04797.
    For multi-class segmentation `weight` parameter can be used to assign different weights per class.
    The input to the loss function is assumed to be a logit and will be normalized by the Sigmoid function
    (or by Softmax over the channel axis when ``sigmoid_normalization=False``).
    """

    def __init__(self, weight=None, sigmoid_normalization=True):
        super().__init__(weight=weight, sigmoid_normalization=sigmoid_normalization)

    def dice(self, input, target, weight):
        """Return per-channel Dice coefficients of probabilities ``input`` vs ``target``.

        Uses the ``weight`` argument it receives instead of reading ``self.weight``
        directly. ``forward`` passes ``self.weight``, so training behavior is unchanged,
        but an explicit ``weight`` in a direct call is no longer silently ignored.
        """
        return compute_per_channel_dice(input, target, weight=weight)
class GeneralizedDiceLoss(_AbstractDiceLoss):
    """Computes Generalized Dice Loss (GDL) as described in https://arxiv.org/pdf/1707.03237.pdf.

    Args:
        sigmoid_normalization (bool): normalize logits with Sigmoid (True) or with
            Softmax over dim=1 (False).
        epsilon (float): lower bound used when clamping the squared class volumes and the
            weighted denominator, to avoid division by zero.
        num_classes (int): number of classes the ``target`` labels are one-hot expanded to.
            Defaults to 4, the value that was previously hard-coded.
    """

    def __init__(self, sigmoid_normalization=True, epsilon=1e-6, num_classes=4):
        super().__init__(weight=None, sigmoid_normalization=sigmoid_normalization)
        self.epsilon = epsilon
        self.num_classes = num_classes

    def dice(self, input, target, weight):
        """Return the scalar generalized Dice score.

        ``weight`` is ignored: GDL derives its own per-class weights from the target volumes.
        """
        # NOTE(review): assumes `target` holds integer labels that utils.expand_as_one_hot
        # turns into a one-hot tensor with `num_classes` channels matching `input` -- confirm
        # against utils.
        target = utils.expand_as_one_hot(target.long(), num_classes=self.num_classes)
        assert input.size() == target.size(), "'input' and 'target' must have the same shape"

        input = flatten(input)
        target = flatten(target).float()

        if input.size(0) == 1:
            # for GDL to make sense we need at least 2 channels (see https://arxiv.org/pdf/1707.03237.pdf)
            # put foreground and background voxels in separate channels
            input = torch.cat((input, 1 - input), dim=0)
            target = torch.cat((target, 1 - target), dim=0)

        # GDL weighting: the contribution of each label is corrected by the inverse of its
        # squared volume. detach() keeps the weights out of the autograd graph; assigning
        # `requires_grad = False` instead raises on non-leaf tensors.
        w_l = target.sum(-1)
        w_l = (1 / (w_l * w_l).clamp(min=self.epsilon)).detach()

        intersect = (input * target).sum(-1) * w_l
        denominator = ((input + target).sum(-1) * w_l).clamp(min=self.epsilon)
        return 2 * (intersect.sum() / denominator.sum())