adpkd_segmentation/utils/losses.py

"""Loss utilities and definitions"""
# %%
import torch
import torch.nn as nn
import torch.nn.functional as F
from adpkd_segmentation.data.data_utils import (
KIDNEY_PIXELS,
STUDY_TKV,
VOXEL_VOLUME,
)
# %%
def binarize_thresholds(pred, thresholds=[0.5]):
    """
    Args:
        pred: model pred tensor with shape b x c x (X x Y)
        thresholds: list of floats, e.g. [0.6, 0.5, 0.4]

    Returns:
        float tensor: binary values
    """
    C = len(thresholds)
    thresholds = torch.tensor(thresholds)
    thresholds = thresholds.reshape(1, C, 1, 1)
    # expand the per-channel thresholds to the prediction shape
    thresholds = thresholds.expand_as(pred)
    thresholds = thresholds.to(pred.device)
    res = pred > thresholds
    return res.float()
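

# Illustrative usage sketch (an addition; the shapes and the 0.6 / 0.4
# thresholds below are assumptions chosen for the example):
def _example_binarize_thresholds():
    probs = torch.rand(4, 2, 64, 64)  # N x C x H x W probability maps
    # channel 0 is thresholded at 0.6, channel 1 at 0.4
    return binarize_thresholds(probs, thresholds=[0.6, 0.4])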


# %%
def binarize_argmax(pred):
    """
    Args:
        pred: model pred tensor with shape b x c x (X x Y)

    Returns:
        float tensor: binary values
    """
    max_c = torch.argmax(pred, 1)  # argmax across C axis
    num_classes = pred.shape[1]
    encoded = torch.nn.functional.one_hot(max_c, num_classes)
    # one_hot returns (N, H, W, C); move the class axis back to dim 1
    encoded = encoded.permute([0, 3, 1, 2])
    return encoded.float()


class SigmoidBinarize:
    def __init__(self, thresholds):
        self.thresholds = thresholds

    def __call__(self, pred):
        # Expects (N, C, H, W) format
        return binarize_thresholds(torch.sigmoid(pred), self.thresholds)


class SigmoidForwardBinarize:
    def __init__(self, thresholds):
        self.thresholds = thresholds

    def __call__(self, pred):
        # Expects (N, C, H, W) format
        soft = torch.sigmoid(pred)
        hard = binarize_thresholds(soft, self.thresholds)
        # straight-through estimator: binary values in the forward pass,
        # sigmoid gradients in the backward pass
        return hard.detach() + soft - soft.detach()


class SoftmaxBinarize:
    def __call__(self, pred):
        # Expects (N, C, H, W) format
        return binarize_argmax(pred)


class SoftmaxForwardBinarize:
    def __call__(self, pred):
        # Expects (N, C, H, W) format
        soft = F.softmax(pred, dim=1)
        hard = binarize_argmax(soft)
        # straight-through estimator, as in SigmoidForwardBinarize
        return hard.detach() + soft - soft.detach()
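

# Illustrative usage sketch (an addition; the shapes and the 0.5 threshold
# are assumptions): the *ForwardBinarize classes output hard 0/1 masks in the
# forward pass while gradients still flow through the soft probabilities.
def _example_forward_binarize():
    logits = torch.randn(2, 1, 8, 8, requires_grad=True)
    hard = SigmoidForwardBinarize(thresholds=[0.5])(logits)
    hard.sum().backward()  # gradients come from the sigmoid term
    return logits.grad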


class StandardizeModels:
    def __init__(self, ignore_channel=2):
        # used for background in 3 channel setups
        self.ignore_channel = ignore_channel

    def __call__(self, binary_mask):
        # N, C, H, W mask
        num_channels = binary_mask.shape[1]
        if num_channels == 1:
            return binary_mask
        elif num_channels == 2:
            return torch.sum(binary_mask, dim=1, keepdim=True)
        elif num_channels == 3:
            sum_all = torch.sum(binary_mask, dim=1)
            sum_all = sum_all - binary_mask[:, self.ignore_channel, ...]
            sum_all = sum_all.unsqueeze(1)
            return sum_all
        else:
            raise ValueError(
                "Unsupported number of channels: {}".format(num_channels)
            )
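

# Illustrative usage sketch (an addition; the 3-channel layout with the
# background in channel 2 is an assumption for the example):
def _example_standardize_models():
    mask = torch.zeros(1, 3, 4, 4)
    mask[:, 0, :2] = 1.0  # kidney channel 0 in the top rows
    mask[:, 2] = 1.0 - mask[:, :2].sum(dim=1)  # background elsewhere
    # collapses to a single kidney channel of shape (1, 1, 4, 4)
    return StandardizeModels(ignore_channel=2)(mask)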


class Dice(nn.Module):
    """Dice metric/loss.

    Supports different Dice variants.
    """

    def __init__(
        self,
        pred_process=None,
        epsilon=1e-8,
        power=2,
        dim=(2, 3),
        standardize_func=None,
        use_as_loss=True,
    ):
        super().__init__()
        self.pred_process = pred_process
        self.epsilon = epsilon
        self.power = power
        self.dim = dim
        self.standardize_func = standardize_func
        self.use_as_loss = use_as_loss

    def __call__(self, pred, target):
        if self.pred_process is not None:
            pred = self.pred_process(pred)
        if self.standardize_func is not None:
            pred = self.standardize_func(pred)
            target = self.standardize_func(target)
        intersection = torch.sum(pred * target, dim=self.dim)
        set_add = torch.sum(
            pred ** self.power + target ** self.power, dim=self.dim
        )
        score = (2 * intersection + self.epsilon) / (set_add + self.epsilon)
        score = score.mean()
        if self.use_as_loss:
            return 1 - score
        return score
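

# Illustrative usage sketch (an addition; the shapes, the sigmoid
# pred_process, and the 0.5 threshold are assumptions): soft Dice as a loss
# on raw logits, and hard Dice as a metric after thresholding.
def _example_dice():
    logits = torch.randn(2, 1, 64, 64)
    target = (torch.rand(2, 1, 64, 64) > 0.5).float()
    soft_loss = Dice(pred_process=torch.sigmoid)(logits, target)
    hard_metric = Dice(
        pred_process=SigmoidBinarize(thresholds=[0.5]), use_as_loss=False
    )(logits, target)
    return soft_loss, hard_metric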


class PredictionEntropy(nn.Module):
    """
    Calculates average entropy of the predicted soft mask.

    Doesn't depend on the ground truth mask.
    """

    def __init__(self, pred_process, epsilon=1e-8, standardize_func=None):
        super().__init__()
        self.pred_process = pred_process
        self.epsilon = epsilon
        self.standardize_func = standardize_func

    def __call__(self, pred, target):
        pred = self.pred_process(pred)
        if self.standardize_func is not None:
            pred = self.standardize_func(pred)
            target = self.standardize_func(target)
        entropy = -pred * torch.log(pred + self.epsilon)
        return entropy.mean()


class KidneyPixelMAPE(nn.Module):
    """
    Calculates the absolute percentage error for predicted kidney pixel
    counts:
    |label kidney pixel count - predicted k.p. count| / (label k.p. count)

    By default, kidney pixel summation is done for each image separately,
    and averaged over the entire batch.

    Depending on the `pred_process` function,
    the predicted kidney pixel count can be soft or hard.
    """

    def __init__(
        self, pred_process, epsilon=1.0, dim=(2, 3), standardize_func=None
    ):
        super().__init__()
        self.pred_process = pred_process
        self.epsilon = epsilon
        self.dim = dim
        self.standardize_func = standardize_func

    def __call__(self, pred, target):
        pred = self.pred_process(pred)
        if self.standardize_func is not None:
            pred = self.standardize_func(pred)
            target = self.standardize_func(target)
        target_count = target.sum(dim=self.dim).detach()
        pred_count = pred.sum(dim=self.dim)
        kp_batch_MAPE = torch.abs(
            (target_count - pred_count) / (target_count + self.epsilon)
        ).mean()
        return kp_batch_MAPE


class KidneyPixelMSLE(nn.Module):
    """
    Mean squared error for the log of kidney pixel counts:
    MSE of ln(label kidney pixel count) - ln(predicted k.p. count)

    By default, pixels are counted separately for each image, with final
    averaging across all images.

    Depending on the `pred_process` function,
    the predicted kidney pixel count can be soft or hard.
    """

    def __init__(
        self, pred_process, epsilon=1.0, dim=(2, 3), standardize_func=None
    ):
        super().__init__()
        self.pred_process = pred_process
        self.epsilon = epsilon
        self.dim = dim
        self.standardize_func = standardize_func

    def __call__(self, pred, target):
        pred = self.pred_process(pred)
        if self.standardize_func is not None:
            pred = self.standardize_func(pred)
            target = self.standardize_func(target)
        target_count = target.sum(dim=self.dim).detach()
        pred_count = pred.sum(dim=self.dim)
        sle = (
            torch.log(target_count + self.epsilon)
            - torch.log(pred_count + self.epsilon)
        ) ** 2
        msle = torch.mean(sle)
        return msle
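

# Illustrative usage sketch (an addition; the sigmoid pred_process and the
# shapes are assumptions): both pixel-count criteria compare predicted and
# labeled foreground areas per image and average over the batch.
def _example_pixel_count_losses():
    logits = torch.randn(2, 1, 32, 32)
    target = (torch.rand(2, 1, 32, 32) > 0.5).float()
    mape = KidneyPixelMAPE(pred_process=torch.sigmoid)(logits, target)
    msle = KidneyPixelMSLE(pred_process=torch.sigmoid)(logits, target)
    return mape, msle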


class WeightedLosses(nn.Module):
    def __init__(self, criterions, weights, requires_extra_dict=None):
        super().__init__()
        self.criterions = criterions
        self.weights = weights
        self.requires_extra_dict = requires_extra_dict
        if requires_extra_dict is None:
            self.requires_extra_dict = [False for c in self.criterions]

    def __call__(self, pred, target, extra_dict=None):
        losses = []
        for c, w, e in zip(
            self.criterions, self.weights, self.requires_extra_dict
        ):
            loss = c(pred, target, extra_dict) if e else c(pred, target)
            losses.append(loss * w)
        return torch.sum(torch.stack(losses))
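

# Illustrative usage sketch (an addition; the particular criterions and the
# 0.8 / 0.2 weights are assumptions):
def _example_weighted_losses():
    criterion = WeightedLosses(
        criterions=[
            Dice(pred_process=torch.sigmoid),
            KidneyPixelMSLE(pred_process=torch.sigmoid),
        ],
        weights=[0.8, 0.2],
    )
    logits = torch.randn(2, 1, 32, 32)
    target = (torch.rand(2, 1, 32, 32) > 0.5).float()
    return criterion(logits, target)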


class DynamicBalanceLosses(nn.Module):
    def __init__(
        self, criterions, epsilon=1e-6, weights=None, requires_extra_dict=None
    ):
        super().__init__()
        self.criterions = criterions
        self.epsilon = epsilon
        self.requires_extra_dict = requires_extra_dict
        self.weights = weights
        if weights is None:
            self.weights = [1.0] * len(self.criterions)
        self.weights = torch.tensor(self.weights)
        if requires_extra_dict is None:
            self.requires_extra_dict = [False for c in self.criterions]

    def __call__(self, pred, target, extra_dict=None):
        # first, scale losses such that
        # L_1 * s_1 = L_2 * s_2 = ... = L_n * s_n = L_1 * L_2 * ... * L_n
        # e.g. s_2 = L_1 * L_3 * ... * L_n
        # calculate scaling factors dynamically
        partial_losses = []
        for c, e in zip(self.criterions, self.requires_extra_dict):
            loss = c(pred, target, extra_dict) if e else c(pred, target)
            partial_losses.append(loss)
        partial_losses = torch.stack(partial_losses) + self.epsilon
        # no backprop through dynamic scaling factors
        detached = partial_losses.detach()
        prod = torch.prod(detached)
        # divide the total product by the vector of loss values
        # to get scaling factors such as e.g. s_2 = L_1 * L_3 * ... * L_n
        scales = prod / detached
        # final weighting by external weights
        self.weights = self.weights.to(scales.device)
        scales = scales * self.weights
        normalization = torch.sum(scales)
        loss = (partial_losses * scales).sum() / normalization
        return loss
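

# Illustrative usage sketch (an addition; the criterion pair is an
# assumption): with two criterions the detached scales are s_1 = L_2 and
# s_2 = L_1, so both scaled terms have equal magnitude before the external
# weights are applied.
def _example_dynamic_balance():
    criterion = DynamicBalanceLosses(
        criterions=[
            Dice(pred_process=torch.sigmoid),
            KidneyPixelMSLE(pred_process=torch.sigmoid),
        ]
    )
    logits = torch.randn(2, 1, 32, 32)
    target = (torch.rand(2, 1, 32, 32) > 0.5).float()
    return criterion(logits, target)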


class ErrorLogTKVRelative(nn.Module):
    def __init__(self, pred_process, epsilon=1.0, standardize_func=None):
        super().__init__()
        self.pred_process = pred_process
        self.epsilon = epsilon
        self.standardize_func = standardize_func

    def __call__(self, pred, target, extra_dict):
        pred = self.pred_process(pred)
        if self.standardize_func is not None:
            pred = self.standardize_func(pred)
            target = self.standardize_func(target)
        intersection = torch.sum(pred * target, dim=(1, 2, 3))
        error = (
            torch.sum(pred ** 2, dim=(1, 2, 3))
            + torch.sum(target ** 2, dim=(1, 2, 3))
            - 2 * intersection
        )
        # augmentation correction for the original kidney pixel count,
        # and conversion to voxel volume
        scale = (extra_dict[KIDNEY_PIXELS] + self.epsilon) / (
            torch.sum(target, dim=(1, 2, 3)) + self.epsilon
        )
        scaled_vol_error = scale * error * extra_dict[VOXEL_VOLUME]
        # error is more important if the kidneys are smaller,
        # but for the same kidney volume, error on any slice matters equally;
        # log is used due to different orders of magnitude
        weight = 1 / (torch.log(extra_dict[STUDY_TKV]) + self.epsilon)
        log_error = (scaled_vol_error * weight).mean()
        return log_error
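

# Illustrative usage sketch (an addition; the extra_dict values and shapes
# below are placeholder assumptions, not values from the data pipeline):
def _example_tkv_relative_error():
    logits = torch.randn(2, 1, 32, 32)
    target = (torch.rand(2, 1, 32, 32) > 0.5).float()
    extra_dict = {
        KIDNEY_PIXELS: target.sum(dim=(1, 2, 3)),
        VOXEL_VOLUME: torch.full((2,), 1.5),
        STUDY_TKV: torch.full((2,), 1500.0),
    }
    criterion = ErrorLogTKVRelative(pred_process=torch.sigmoid)
    return criterion(logits, target, extra_dict)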


class BiasReductionLoss(nn.Module):
    def __init__(
        self, pred_process, standardize_func=None, w1=0.5, w2=0.5, epsilon=1e-8
    ):
        super().__init__()
        self.pred_process = pred_process
        self.standardize_func = standardize_func
        self.w1 = w1
        self.w2 = w2
        self.epsilon = epsilon

    def __call__(self, pred, target):
        pred = self.pred_process(pred)
        if self.standardize_func is not None:
            pred = self.standardize_func(pred)
            target = self.standardize_func(target)
        intersection = torch.sum(pred * target, dim=(1, 2, 3))
        # count what's missing from the target area
        missing = target.sum(dim=(1, 2, 3)) - intersection
        # count all extra predictions outside the target area
        false_pos = torch.sum(pred * (1 - target), dim=(1, 2, 3))
        # both error counts should go to zero, and they should also stay equal
        loss = (
            self.w1 * (missing ** 2 + false_pos ** 2)
            + self.w2 * (missing - false_pos) ** 2
        )
        # epsilon added because sqrt is not differentiable at zero
        loss = (loss.mean() + self.epsilon) ** 0.5
        return loss
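

# Illustrative usage sketch (an addition; using the straight-through
# SigmoidForwardBinarize as pred_process and the shapes below are
# assumptions): the loss penalizes both the size of the segmentation errors
# and any imbalance between missed and spurious kidney pixels.
def _example_bias_reduction():
    logits = torch.randn(2, 1, 32, 32, requires_grad=True)
    target = (torch.rand(2, 1, 32, 32) > 0.5).float()
    criterion = BiasReductionLoss(
        pred_process=SigmoidForwardBinarize(thresholds=[0.5])
    )
    loss = criterion(logits, target)
    loss.backward()  # gradients flow via the straight-through estimator
    return loss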