Diff of /losses.py [000000] .. [dff9e0]


b/losses.py
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn import MSELoss, SmoothL1Loss, L1Loss

def compute_per_channel_dice(input, target, epsilon=1e-6, weight=None):
    """
    Computes the Dice coefficient, as defined in https://arxiv.org/abs/1606.04797, given a multi-channel input and target.
    Assumes the input is a normalized probability, e.g. the result of a Sigmoid or Softmax function.

    Args:
         input (torch.Tensor): NxCxSpatial input tensor
         target (torch.Tensor): NxCxSpatial target tensor
         epsilon (float): prevents division by zero
         weight (torch.Tensor): Cx1 tensor of weight per channel/class
    """

    # input and target shapes must match
    assert input.size() == target.size(), "'input' and 'target' must have the same shape"

    input = flatten(input)
    target = flatten(target)
    target = target.float()

    # compute per channel Dice Coefficient
    intersect = (input * target).sum(-1)
    if weight is not None:
        intersect = weight * intersect

    # here we can use the standard dice denominator (input + target).sum(-1) or the V-Net extension (input^2 + target^2).sum(-1)
    denominator = (input * input).sum(-1) + (target * target).sum(-1)
    return 2 * (intersect / denominator.clamp(min=epsilon))

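# Illustrative usage, not part of the original file (shapes and values below are assumed):
#
#   probs = torch.rand(2, 3, 8, 16, 16)                    # N=2, C=3 normalized predictions
#   labels = (torch.rand(2, 3, 8, 16, 16) > 0.5).float()   # binary target of the same shape
#   per_channel = compute_per_channel_dice(probs, labels)  # tensor of shape (3,): one Dice score per channel
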
class _MaskingLossWrapper(nn.Module):
    """
    Loss wrapper which prevents the gradient of the loss from being computed where the target is equal to `ignore_index`.
    """

    def __init__(self, loss, ignore_index):
        super(_MaskingLossWrapper, self).__init__()
        assert ignore_index is not None, 'ignore_index cannot be None'
        self.loss = loss
        self.ignore_index = ignore_index

    def forward(self, input, target):
        mask = target.clone().ne_(self.ignore_index)
        mask.requires_grad = False

        # zero out input/target where the mask is zero, so no gradient flows through the ignored entries
        input = input * mask
        target = target * mask

        # forward the masked input and target to the wrapped loss
        return self.loss(input, target)

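# Illustrative usage, not part of the original file (the wrapped loss and ignore_index are assumed):
#
#   criterion = _MaskingLossWrapper(nn.MSELoss(), ignore_index=-1)
#   loss = criterion(prediction, target)   # entries where target == -1 are zeroed in both tensors first
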
class SkipLastTargetChannelWrapper(nn.Module):
    """
    Loss wrapper which removes the last channel from the target before passing it to the wrapped loss
    """

    def __init__(self, loss, squeeze_channel=False):
        super(SkipLastTargetChannelWrapper, self).__init__()
        self.loss = loss
        self.squeeze_channel = squeeze_channel

    def forward(self, input, target):
        assert target.size(1) > 1, 'Target tensor has a singleton channel dimension, cannot remove channel'

        # skip the last target channel
        target = target[:, :-1, ...]

        if self.squeeze_channel:
            # squeeze the channel dimension if it is a singleton
            target = torch.squeeze(target, dim=1)
        return self.loss(input, target)

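# Illustrative usage, not part of the original file (shapes and the wrapped loss are assumed):
#
#   wrapped = SkipLastTargetChannelWrapper(nn.BCEWithLogitsLoss())
#   # if the target carries an extra channel (e.g. an appended weight map), the wrapper drops it,
#   # so a (N, 2, D, H, W) target becomes (N, 1, D, H, W) before the loss is evaluated
#   loss = wrapped(prediction, target)
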
class _AbstractDiceLoss(nn.Module):
    """
    Base class for different implementations of Dice loss.
    """

    def __init__(self, weight=None, normalization='sigmoid'):
        super(_AbstractDiceLoss, self).__init__()
        self.register_buffer('weight', weight)
        # The output from the network during training is assumed to be un-normalized (i.e. logits) and we would
        # like to normalize it. Since Dice (or soft Dice in this case) is usually used for binary data,
        # normalizing the channels with Sigmoid is the default choice even for multi-class segmentation problems.
        # However, if one would like to apply Softmax in order to get a proper probability distribution over the
        # channels, just specify `normalization='softmax'`.
        assert normalization in ['sigmoid', 'softmax', 'none']
        if normalization == 'sigmoid':
            self.normalization = nn.Sigmoid()
        elif normalization == 'softmax':
            self.normalization = nn.Softmax(dim=1)
        else:
            self.normalization = lambda x: x

    def dice(self, input, target, weight):
        # actual Dice score computation; to be implemented by the subclass
        raise NotImplementedError

    def forward(self, input, target):
        # get probabilities from logits
        input = self.normalization(input)

        # compute per channel Dice coefficient
        per_channel_dice = self.dice(input, target, weight=self.weight)

        # average Dice score across all channels/classes
        return 1. - torch.mean(per_channel_dice)

class DiceLoss(_AbstractDiceLoss):
    """Computes Dice Loss according to https://arxiv.org/abs/1606.04797.
    For multi-class segmentation the `weight` parameter can be used to assign different weights per class.
    The input to the loss function is assumed to be a logit and will be normalized by the Sigmoid function.
    """

    def __init__(self, weight=None, normalization='sigmoid'):
        super().__init__(weight, normalization)

    def dice(self, input, target, weight):
        return compute_per_channel_dice(input, target, weight=self.weight)

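# Illustrative usage, not part of the original file (shapes are assumed):
#
#   criterion = DiceLoss()                                  # Sigmoid is applied to the logits internally
#   logits = torch.randn(2, 1, 8, 16, 16)                   # raw network output
#   target = (torch.rand(2, 1, 8, 16, 16) > 0.5).float()
#   loss = criterion(logits, target)                        # scalar: 1 - mean per-channel Dice
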
class GeneralizedDiceLoss(_AbstractDiceLoss):
    """Computes Generalized Dice Loss (GDL) as described in https://arxiv.org/pdf/1707.03237.pdf.
    """

    def __init__(self, normalization='sigmoid', epsilon=1e-6):
        super().__init__(weight=None, normalization=normalization)
        self.epsilon = epsilon

    def dice(self, input, target, weight):
        assert input.size() == target.size(), "'input' and 'target' must have the same shape"

        input = flatten(input)
        target = flatten(target)
        target = target.float()

        if input.size(0) == 1:
            # for GDL to make sense we need at least 2 channels (see https://arxiv.org/pdf/1707.03237.pdf)
            # put foreground and background voxels in separate channels
            input = torch.cat((input, 1 - input), dim=0)
            target = torch.cat((target, 1 - target), dim=0)

        # GDL weighting: the contribution of each label is corrected by the inverse of its volume
        w_l = target.sum(-1)
        w_l = 1 / (w_l * w_l).clamp(min=self.epsilon)
        w_l.requires_grad = False

        intersect = (input * target).sum(-1)
        intersect = intersect * w_l

        denominator = (input + target).sum(-1)
        denominator = (denominator * w_l).clamp(min=self.epsilon)

        return 2 * (intersect.sum() / denominator.sum())

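# For reference, the computation above follows the GDL definition from the paper, roughly
# (ASCII notation for illustration only):
#
#   w_l = 1 / (sum_x g_lx)^2
#   GDL = 1 - 2 * sum_l(w_l * sum_x(p_lx * g_lx)) / sum_l(w_l * sum_x(p_lx + g_lx))
#
# where p is the normalized prediction, g the target, and the final '1 - ...' is applied in
# _AbstractDiceLoss.forward.
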
class BCEDiceLoss(nn.Module):
    """Linear combination of BCE and Dice losses"""

    def __init__(self, alpha, beta):
        super(BCEDiceLoss, self).__init__()
        self.alpha = alpha
        self.bce = nn.BCEWithLogitsLoss()
        self.beta = beta
        self.dice = DiceLoss()

    def forward(self, input, target):
        return self.alpha * self.bce(input, target) + self.beta * self.dice(input, target)

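# Illustrative usage, not part of the original file (the coefficients are assumed):
#
#   criterion = BCEDiceLoss(alpha=1.0, beta=1.0)   # loss = 1.0 * BCEWithLogitsLoss + 1.0 * DiceLoss
#   loss = criterion(logits, target)               # both terms expect raw logits as input
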
class WeightedCrossEntropyLoss(nn.Module):
    """WeightedCrossEntropyLoss (WCE) as described in https://arxiv.org/pdf/1707.03237.pdf
    """

    def __init__(self, ignore_index=-1):
        super(WeightedCrossEntropyLoss, self).__init__()
        self.ignore_index = ignore_index

    def forward(self, input, target):
        weight = self._class_weights(input)
        return F.cross_entropy(input, target, weight=weight, ignore_index=self.ignore_index)

    @staticmethod
    def _class_weights(input):
        # normalize the input first
        input = F.softmax(input, dim=1)
        flattened = flatten(input)
        numerator = (1. - flattened).sum(-1)
        denominator = flattened.sum(-1)
        class_weights = (numerator / denominator).detach()
        return class_weights

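# For reference, the per-class weights above are recomputed from the current prediction, roughly
# (ASCII notation for illustration only):
#
#   w_c = sum_x(1 - softmax(input)_cx) / sum_x(softmax(input)_cx)
#
# so classes the network currently predicts with low probability receive a larger weight.
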
class WeightedSmoothL1Loss(nn.SmoothL1Loss):
    def __init__(self, threshold, initial_weight, apply_below_threshold=True):
        super().__init__(reduction="none")
        self.threshold = threshold
        self.apply_below_threshold = apply_below_threshold
        self.weight = initial_weight

    def forward(self, input, target):
        l1 = super().forward(input, target)

        if self.apply_below_threshold:
            mask = target < self.threshold
        else:
            mask = target >= self.threshold

        l1[mask] = l1[mask] * self.weight

        return l1.mean()

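# Illustrative usage, not part of the original file (threshold and weight values are assumed):
#
#   criterion = WeightedSmoothL1Loss(threshold=0.0, initial_weight=4.0)
#   # with apply_below_threshold=True (the default), element-wise SmoothL1 terms whose *target*
#   # value is below 0.0 are scaled by 4.0 before the mean is taken
#   loss = criterion(prediction, target)
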
def flatten(tensor):
    """Flattens a given tensor such that the channel axis is first.
    The shapes are transformed as follows:
       (N, C, D, H, W) -> (C, N * D * H * W)
    """
    # number of channels
    C = tensor.size(1)
    # new axis order
    axis_order = (1, 0) + tuple(range(2, tensor.dim()))
    # Transpose: (N, C, D, H, W) -> (C, N, D, H, W)
    transposed = tensor.permute(axis_order)
    # Flatten: (C, N, D, H, W) -> (C, N * D * H * W)
    return transposed.contiguous().view(C, -1)

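# Illustrative shape check, not part of the original file:
#
#   x = torch.zeros(2, 3, 4, 5, 6)   # (N, C, D, H, W)
#   flatten(x).shape                 # torch.Size([3, 240]) == (C, N * D * H * W)
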
def get_loss_criterion(config):
    """
    Returns the loss function based on the provided configuration
    :param config: (dict) a top level configuration object containing the 'loss' key
    :return: an instance of the loss function
    """
    assert 'loss' in config, 'Could not find loss function configuration'
    loss_config = config['loss']
    name = loss_config.pop('name')

    ignore_index = loss_config.pop('ignore_index', None)
    skip_last_target = loss_config.pop('skip_last_target', False)
    weight = loss_config.pop('weight', None)

    if weight is not None:
        # move the class weights to the configured device
        weight = torch.tensor(weight).to(config['device'])

    pos_weight = loss_config.pop('pos_weight', None)
    if pos_weight is not None:
        # move the positive-class weight to the configured device
        pos_weight = torch.tensor(pos_weight).to(config['device'])

    loss = _create_loss(name, loss_config, weight, ignore_index, pos_weight)

    if not (ignore_index is None or name in ['CrossEntropyLoss', 'WeightedCrossEntropyLoss']):
        # use MaskingLossWrapper only for non-cross-entropy losses, since CE losses allow specifying 'ignore_index' directly
        loss = _MaskingLossWrapper(loss, ignore_index)

    if skip_last_target:
        loss = SkipLastTargetChannelWrapper(loss, loss_config.get('squeeze_channel', False))

    return loss

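# Illustrative configuration, not part of the original file (only the 'loss'/'name' keys are required
# by the code above; 'device' is needed whenever 'weight' or 'pos_weight' is given):
#
#   config = {
#       'device': 'cpu',
#       'loss': {'name': 'DiceLoss', 'normalization': 'sigmoid'}
#   }
#   criterion = get_loss_criterion(config)   # returns a DiceLoss instance
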
#######################################################################################################################

def _create_loss(name, loss_config, weight, ignore_index, pos_weight):
    if name == 'BCEWithLogitsLoss':
        return nn.BCEWithLogitsLoss(pos_weight=pos_weight)
    elif name == 'BCEDiceLoss':
        alpha = loss_config.get('alpha', 1.)
        beta = loss_config.get('beta', 1.)
        return BCEDiceLoss(alpha, beta)
    elif name == 'CrossEntropyLoss':
        if ignore_index is None:
            ignore_index = -100  # use the default 'ignore_index' as defined in nn.CrossEntropyLoss
        return nn.CrossEntropyLoss(weight=weight, ignore_index=ignore_index)
    elif name == 'WeightedCrossEntropyLoss':
        if ignore_index is None:
            ignore_index = -100  # use the default 'ignore_index' as defined in nn.CrossEntropyLoss
        return WeightedCrossEntropyLoss(ignore_index=ignore_index)
    elif name == 'PixelWiseCrossEntropyLoss':
        return PixelWiseCrossEntropyLoss(class_weights=weight, ignore_index=ignore_index)
    elif name == 'GeneralizedDiceLoss':
        normalization = loss_config.get('normalization', 'sigmoid')
        return GeneralizedDiceLoss(normalization=normalization)
    elif name == 'DiceLoss':
        normalization = loss_config.get('normalization', 'sigmoid')
        return DiceLoss(weight=weight, normalization=normalization)
    elif name == 'MSELoss':
        return MSELoss()
    elif name == 'SmoothL1Loss':
        return SmoothL1Loss()
    elif name == 'L1Loss':
        return L1Loss()
    elif name == 'WeightedSmoothL1Loss':
        return WeightedSmoothL1Loss(threshold=loss_config['threshold'],
                                    initial_weight=loss_config['initial_weight'],
                                    apply_below_threshold=loss_config.get('apply_below_threshold', True))
    else:
        raise RuntimeError(f"Unsupported loss function: '{name}'")