losses.py

import torch
import torch.nn.functional as F
from torch import nn
from torch.nn import MSELoss, SmoothL1Loss, L1Loss


def compute_per_channel_dice(input, target, epsilon=1e-6, weight=None):
    """
    Computes the Dice coefficient, as defined in https://arxiv.org/abs/1606.04797, given a multi-channel input and target.
    Assumes the input is a normalized probability, e.g. a result of the Sigmoid or Softmax function.

    Args:
        input (torch.Tensor): NxCxSpatial input tensor
        target (torch.Tensor): NxCxSpatial target tensor
        epsilon (float): prevents division by zero
        weight (torch.Tensor): Cx1 tensor of weight per channel/class
    """

    # input and target shapes must match
    assert input.size() == target.size(), "'input' and 'target' must have the same shape"

    input = flatten(input)
    target = flatten(target)
    target = target.float()

    # compute per channel Dice Coefficient
    intersect = (input * target).sum(-1)
    if weight is not None:
        intersect = weight * intersect

    # here we can use the standard dice (input + target).sum(-1) or the extension (see V-Net) (input^2 + target^2).sum(-1)
    denominator = (input * input).sum(-1) + (target * target).sum(-1)
    return 2 * (intersect / denominator.clamp(min=epsilon))


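# Illustrative sketch (not part of the original module): a tiny worked example of
# `compute_per_channel_dice`. With a perfect prediction the per-channel score is ~1,
# so a Dice-based loss would be ~0. The tensor shapes below are hypothetical.
def _example_per_channel_dice():
    target = torch.randint(0, 2, (2, 3, 8, 8, 8)).float()   # NxCxDxHxW binary target
    perfect_input = target.clone()                           # probabilities equal to the target
    return compute_per_channel_dice(perfect_input, target)   # ~1.0 for every channel

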
class _MaskingLossWrapper(nn.Module):
    """
    Loss wrapper which prevents the gradient of the loss from being computed where the target is equal to `ignore_index`.
    """

    def __init__(self, loss, ignore_index):
        super(_MaskingLossWrapper, self).__init__()
        assert ignore_index is not None, 'ignore_index cannot be None'
        self.loss = loss
        self.ignore_index = ignore_index

    def forward(self, input, target):
        mask = target.clone().ne_(self.ignore_index)
        mask.requires_grad = False

        # mask out input/target so that the gradient is zero where the mask is zero
        input = input * mask
        target = target * mask

        # forward masked input and target to the loss
        return self.loss(input, target)


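# Illustrative sketch (not part of the original module): how the wrapper is meant to be used.
# Voxels labeled with `ignore_index` (here -1, a hypothetical choice) contribute no gradient
# because both input and target are zeroed at those positions.
def _example_masking_wrapper():
    loss = _MaskingLossWrapper(nn.MSELoss(), ignore_index=-1)
    input = torch.randn(2, 1, 8, 8, 8, requires_grad=True)
    target = torch.randn(2, 1, 8, 8, 8)
    target[:, :, :4] = -1  # mark a region to be ignored
    return loss(input, target)

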
class SkipLastTargetChannelWrapper(nn.Module):
    """
    Loss wrapper which removes the additional (last) target channel before passing the target to the wrapped loss.
    """

    def __init__(self, loss, squeeze_channel=False):
        super(SkipLastTargetChannelWrapper, self).__init__()
        self.loss = loss
        self.squeeze_channel = squeeze_channel

    def forward(self, input, target):
        assert target.size(1) > 1, 'Target tensor has a singleton channel dimension, cannot remove channel'

        # skip the last target channel
        target = target[:, :-1, ...]

        if self.squeeze_channel:
            # squeeze channel dimension if singleton
            target = torch.squeeze(target, dim=1)
        return self.loss(input, target)


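# Illustrative sketch (not part of the original module): SkipLastTargetChannelWrapper can be used
# when the target carries one extra channel (hypothetically, e.g. a weight map) that the wrapped
# loss should not see. The shapes below are hypothetical.
def _example_skip_last_target_channel():
    loss = SkipLastTargetChannelWrapper(nn.MSELoss())
    input = torch.randn(2, 1, 8, 8, 8)
    target = torch.randn(2, 2, 8, 8, 8)  # the last channel is dropped before the loss is applied
    return loss(input, target)

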
class _AbstractDiceLoss(nn.Module):
    """
    Base class for different implementations of the Dice loss.
    """

    def __init__(self, weight=None, normalization='sigmoid'):
        super(_AbstractDiceLoss, self).__init__()
        self.register_buffer('weight', weight)
        # The output from the network during training is assumed to be un-normalized (i.e. logits) and we would
        # like to normalize it. Since Dice (or soft Dice in this case) is usually used for binary data,
        # normalizing the channels with Sigmoid is the default choice even for multi-class segmentation problems.
        # However, if one would like to apply Softmax in order to get a proper probability distribution over the
        # channels, just specify `normalization='softmax'`.
        assert normalization in ['sigmoid', 'softmax', 'none']
        if normalization == 'sigmoid':
            self.normalization = nn.Sigmoid()
        elif normalization == 'softmax':
            self.normalization = nn.Softmax(dim=1)
        else:
            self.normalization = lambda x: x

    def dice(self, input, target, weight):
        # actual Dice score computation; to be implemented by the subclass
        raise NotImplementedError

    def forward(self, input, target):
        # get probabilities from logits
        input = self.normalization(input)

        # compute per channel Dice coefficient
        per_channel_dice = self.dice(input, target, weight=self.weight)

        # average Dice score across all channels/classes
        return 1. - torch.mean(per_channel_dice)


class DiceLoss(_AbstractDiceLoss):
    """Computes the Dice loss according to https://arxiv.org/abs/1606.04797.
    For multi-class segmentation the `weight` parameter can be used to assign different weights per class.
    The input to the loss function is assumed to be logits and will be normalized by the Sigmoid function by default.
    """

    def __init__(self, weight=None, normalization='sigmoid'):
        super().__init__(weight, normalization)

    def dice(self, input, target, weight):
        return compute_per_channel_dice(input, target, weight=self.weight)


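# Illustrative sketch (not part of the original module): DiceLoss applied to raw logits.
# The default 'sigmoid' normalization treats each channel as an independent binary mask;
# 'softmax' would be chosen when the channels are mutually exclusive classes.
# The tensor shapes are hypothetical.
def _example_dice_loss():
    logits = torch.randn(2, 3, 8, 8, 8, requires_grad=True)   # NxCxDxHxW logits
    target = torch.randint(0, 2, (2, 3, 8, 8, 8)).float()     # binary target per channel
    criterion = DiceLoss(normalization='sigmoid')
    return criterion(logits, target)

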
class GeneralizedDiceLoss(_AbstractDiceLoss):
    """Computes the Generalized Dice Loss (GDL) as described in https://arxiv.org/pdf/1707.03237.pdf.
    """

    def __init__(self, normalization='sigmoid', epsilon=1e-6):
        super().__init__(weight=None, normalization=normalization)
        self.epsilon = epsilon

    def dice(self, input, target, weight):
        assert input.size() == target.size(), "'input' and 'target' must have the same shape"

        input = flatten(input)
        target = flatten(target)
        target = target.float()

        if input.size(0) == 1:
            # for GDL to make sense we need at least 2 channels (see https://arxiv.org/pdf/1707.03237.pdf)
            # put foreground and background voxels in separate channels
            input = torch.cat((input, 1 - input), dim=0)
            target = torch.cat((target, 1 - target), dim=0)

        # GDL weighting: the contribution of each label is corrected by the inverse of its squared volume
        w_l = target.sum(-1)
        w_l = 1 / (w_l * w_l).clamp(min=self.epsilon)
        w_l.requires_grad = False

        intersect = (input * target).sum(-1)
        intersect = intersect * w_l

        denominator = (input + target).sum(-1)
        denominator = (denominator * w_l).clamp(min=self.epsilon)

        return 2 * (intersect.sum() / denominator.sum())


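# Illustrative sketch (not part of the original module): with the GDL weighting above,
# a label occupying few voxels gets weight 1 / |label|^2, so the loss is not dominated by
# large structures. The shapes below are hypothetical; with a single input channel the
# foreground/background split happens internally.
def _example_generalized_dice_loss():
    logits = torch.randn(2, 1, 8, 8, 8)                       # single-channel logits
    target = torch.randint(0, 2, (2, 1, 8, 8, 8)).float()
    criterion = GeneralizedDiceLoss(normalization='sigmoid')
    return criterion(logits, target)

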
class BCEDiceLoss(nn.Module):
    """Linear combination of BCE and Dice losses"""

    def __init__(self, alpha, beta):
        super(BCEDiceLoss, self).__init__()
        self.alpha = alpha
        self.bce = nn.BCEWithLogitsLoss()
        self.beta = beta
        self.dice = DiceLoss()

    def forward(self, input, target):
        return self.alpha * self.bce(input, target) + self.beta * self.dice(input, target)


class WeightedCrossEntropyLoss(nn.Module):
    """WeightedCrossEntropyLoss (WCE) as described in https://arxiv.org/pdf/1707.03237.pdf
    """

    def __init__(self, ignore_index=-1):
        super(WeightedCrossEntropyLoss, self).__init__()
        self.ignore_index = ignore_index

    def forward(self, input, target):
        weight = self._class_weights(input)
        return F.cross_entropy(input, target, weight=weight, ignore_index=self.ignore_index)

    @staticmethod
    def _class_weights(input):
        # normalize the input first
        input = F.softmax(input, dim=1)
        flattened = flatten(input)
        nominator = (1. - flattened).sum(-1)
        denominator = flattened.sum(-1)
        # detach so that the class weights are treated as constants and receive no gradient
        class_weights = (nominator / denominator).detach()
        return class_weights


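# Illustrative sketch (not part of the original module): WeightedCrossEntropyLoss derives
# per-class weights sum(1 - p_c) / sum(p_c) from the softmax of the current prediction, so
# classes the network currently predicts rarely are up-weighted. Shapes are hypothetical.
def _example_weighted_cross_entropy():
    logits = torch.randn(2, 4, 8, 8, 8)                       # NxCxDxHxW logits for 4 classes
    target = torch.randint(0, 4, (2, 8, 8, 8))                # NxDxHxW class indices
    criterion = WeightedCrossEntropyLoss(ignore_index=-1)
    return criterion(logits, target)

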
class WeightedSmoothL1Loss(nn.SmoothL1Loss):
    def __init__(self, threshold, initial_weight, apply_below_threshold=True):
        super().__init__(reduction="none")
        self.threshold = threshold
        self.apply_below_threshold = apply_below_threshold
        self.weight = initial_weight

    def forward(self, input, target):
        l1 = super().forward(input, target)

        if self.apply_below_threshold:
            mask = target < self.threshold
        else:
            mask = target >= self.threshold

        l1[mask] = l1[mask] * self.weight

        return l1.mean()


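# Illustrative sketch (not part of the original module): WeightedSmoothL1Loss scales the
# element-wise SmoothL1 term by `initial_weight` wherever the target falls below `threshold`
# (or above it, when apply_below_threshold=False). The numbers below are hypothetical.
def _example_weighted_smooth_l1():
    criterion = WeightedSmoothL1Loss(threshold=0.0, initial_weight=0.1, apply_below_threshold=True)
    prediction = torch.randn(2, 1, 8, 8, 8)
    target = torch.randn(2, 1, 8, 8, 8)
    return criterion(prediction, target)  # contributions where target < 0 are down-weighted by 0.1

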
def flatten(tensor):
    """Flattens a given tensor such that the channel axis is first.
    The shapes are transformed as follows:
       (N, C, D, H, W) -> (C, N * D * H * W)
    """
    # number of channels
    C = tensor.size(1)
    # new axis order
    axis_order = (1, 0) + tuple(range(2, tensor.dim()))
    # Transpose: (N, C, D, H, W) -> (C, N, D, H, W)
    transposed = tensor.permute(axis_order)
    # Flatten: (C, N, D, H, W) -> (C, N * D * H * W)
    return transposed.contiguous().view(C, -1)


def get_loss_criterion(config):
    """
    Returns the loss function based on provided configuration
    :param config: (dict) a top level configuration object containing the 'loss' key
    :return: an instance of the loss function
    """
    assert 'loss' in config, 'Could not find loss function configuration'
    loss_config = config['loss']
    name = loss_config.pop('name')

    ignore_index = loss_config.pop('ignore_index', None)
    skip_last_target = loss_config.pop('skip_last_target', False)
    weight = loss_config.pop('weight', None)

    if weight is not None:
        # convert to cuda tensor if necessary
        weight = torch.tensor(weight).to(config['device'])

    pos_weight = loss_config.pop('pos_weight', None)
    if pos_weight is not None:
        # convert to cuda tensor if necessary
        pos_weight = torch.tensor(pos_weight).to(config['device'])

    loss = _create_loss(name, loss_config, weight, ignore_index, pos_weight)

    if not (ignore_index is None or name in ['CrossEntropyLoss', 'WeightedCrossEntropyLoss']):
        # use the MaskingLossWrapper only for non-cross-entropy losses, since CE losses allow specifying 'ignore_index' directly
        loss = _MaskingLossWrapper(loss, ignore_index)

    if skip_last_target:
        loss = SkipLastTargetChannelWrapper(loss, loss_config.get('squeeze_channel', False))

    return loss


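# Illustrative sketch (not part of the original module): a minimal configuration dict that
# `get_loss_criterion` would accept. Which extra keys are read depends on the chosen loss
# (see _create_loss below); the values here are hypothetical.
def _example_get_loss_criterion():
    config = {
        'device': 'cpu',
        'loss': {
            'name': 'BCEDiceLoss',
            'alpha': 1.0,
            'beta': 1.0,
            'ignore_index': None,
            'skip_last_target': False,
        }
    }
    return get_loss_criterion(config)

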
#######################################################################################################################


def _create_loss(name, loss_config, weight, ignore_index, pos_weight):
    if name == 'BCEWithLogitsLoss':
        return nn.BCEWithLogitsLoss(pos_weight=pos_weight)
    elif name == 'BCEDiceLoss':
        alpha = loss_config.get('alpha', 1.)
        beta = loss_config.get('beta', 1.)
        return BCEDiceLoss(alpha, beta)
    elif name == 'CrossEntropyLoss':
        if ignore_index is None:
            ignore_index = -100  # use the default 'ignore_index' as defined in the CrossEntropyLoss
        return nn.CrossEntropyLoss(weight=weight, ignore_index=ignore_index)
    elif name == 'WeightedCrossEntropyLoss':
        if ignore_index is None:
            ignore_index = -100  # use the default 'ignore_index' as defined in the CrossEntropyLoss
        return WeightedCrossEntropyLoss(ignore_index=ignore_index)
    elif name == 'PixelWiseCrossEntropyLoss':
        # PixelWiseCrossEntropyLoss is not defined in this file; it is assumed to be provided elsewhere in the package
        return PixelWiseCrossEntropyLoss(class_weights=weight, ignore_index=ignore_index)
    elif name == 'GeneralizedDiceLoss':
        normalization = loss_config.get('normalization', 'sigmoid')
        return GeneralizedDiceLoss(normalization=normalization)
    elif name == 'DiceLoss':
        normalization = loss_config.get('normalization', 'sigmoid')
        return DiceLoss(weight=weight, normalization=normalization)
    elif name == 'MSELoss':
        return MSELoss()
    elif name == 'SmoothL1Loss':
        return SmoothL1Loss()
    elif name == 'L1Loss':
        return L1Loss()
    elif name == 'WeightedSmoothL1Loss':
        return WeightedSmoothL1Loss(threshold=loss_config['threshold'],
                                    initial_weight=loss_config['initial_weight'],
                                    apply_below_threshold=loss_config.get('apply_below_threshold', True))
    else:
        raise RuntimeError(f"Unsupported loss function: '{name}'")