Switch to unified view

a b/BraTs18Challege/Vnet/loss_metric.py
1
from __future__ import print_function, division
2
import tensorflow as tf
3
import numpy as np
4
5
6
# --------------------------- BINARY Evaluation ---------------------------
7
def binary_iou(Y_pred, Y_gt, prob=0.5):
8
    """
9
    binary iou
10
    :param Y_pred:A tensor resulting from a sigmod
11
    :param Y_gt:A tensor of the same shape as `output`
12
    :return: binary iou
13
    """
14
    Y_pred_part = tf.to_float(Y_pred > prob)
15
    Y_pred_part = tf.cast(Y_pred_part, tf.float32)
16
    Y_gt_part = tf.identity(Y_gt)
17
    Y_gt_part = tf.cast(Y_gt_part, tf.float32)
18
    Z, H, W, C = Y_gt.get_shape().as_list()[1:]
19
    smooth = 1.e-5
20
    smooth_tf = tf.constant(smooth, tf.float32)
21
    pred_flat = tf.reshape(Y_pred_part, [-1, H * W * C * Z])
22
    true_flat = tf.reshape(Y_gt_part, [-1, H * W * C * Z])
23
    intersection = tf.reduce_sum(pred_flat * true_flat, axis=-1)
24
    union = tf.reduce_sum(pred_flat, axis=-1) + tf.reduce_sum(true_flat, axis=-1) - intersection
25
    metric = tf.reduce_mean((intersection + smooth_tf) / (union + smooth_tf))
26
    metric = tf.cond(tf.is_inf(metric), lambda: smooth_tf, lambda: metric)
27
    return metric
28
29
30
# --------------------------- BINARY LOSSES ---------------------------
31
def binary_crossentropy(Y_pred, Y_gt):
32
    """
33
    Binary crossentropy between an output tensor and a target tensor.
34
    :param Y_pred:A tensor with (batchsize,z,x,y,channel)from a sigmod,probability distribution.
35
    :param Y_gt:A tensor with the same shape as `output`.
36
    :return:binary_crossentropy
37
    """
38
    epsilon = 1.e-5
39
    Y_pred = tf.clip_by_value(Y_pred, epsilon, 1. - epsilon)
40
    logits = tf.log(Y_pred / (1 - Y_pred))
41
    loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=Y_gt, logits=logits)
42
    loss = tf.reduce_mean(loss)
43
    return loss
44
45
46
def weighted_binary_crossentroy(Y_pred, Y_gt, beta):
47
    """
48
    Weighted cross entropy (WCE) is a variant of CE where all positive examples get weighted by some coefficient.
49
    It is used in the case of class imbalance.
50
    For example, when you have an image with 10% black pixels and 90% white pixels, regular CE won’t work very well.
51
    WCE define:wce(p',p)=-(b*p*log(p')+(1-p)*log(1-p'))
52
    :param Y_pred:A tensor with (batchsize,z,x,y,channel)from a sigmod,probability distribution.
53
    :param Y_gt:A tensor with the same shape as `output`.
54
    :param beta: To decrease the number of false negatives, setβ>1. To decrease the number of false positives, set β<1.
55
    :return:weighted_binary_crossentroy
56
    """
57
    epsilon = 1.e-5
58
    Y_pred = tf.clip_by_value(Y_pred, epsilon, 1. - epsilon)
59
    logits = tf.log(Y_pred / (1 - Y_pred))
60
    loss = tf.nn.weighted_cross_entropy_with_logits(targets=Y_gt, logits=logits, pos_weight=beta)
61
    loss = tf.reduce_mean(loss)
62
    return loss
63
64
65
def balanced_binary_crossentroy(Y_pred, Y_gt, beta):
66
    """
67
    Balanced cross entropy (BCE) is similar to WCE. The only difference is that we weight also the negative examples
68
    bce define:bce(p',p)=-(b*p*log(p')+(1-b)*(1-p)*log(1-p'))
69
    :param Y_pred:A tensor with (batchsize,z,x,y,channel)from a sigmod,probability distribution.
70
    :param Y_gt:A tensor with the same shape as `output`.
71
    :param beta: β!=1,the denominator in pos_weight is not defined
72
    such as:beta = tf.reduce_sum(1 - y_true) / (BATCH_SIZE * HEIGHT * WIDTH)
73
    :return:
74
    """
75
    epsilon = 1.e-5
76
    Y_pred = tf.clip_by_value(Y_pred, epsilon, 1. - epsilon)
77
    logits = tf.log(Y_pred / (1 - Y_pred))
78
    beta = tf.clip_by_value(beta, epsilon, 1 - epsilon)
79
    pos_weight = beta / (1 - beta)
80
    loss = tf.nn.weighted_cross_entropy_with_logits(targets=Y_gt, logits=logits, pos_weight=pos_weight)
81
    loss = tf.reduce_mean(loss * (1 - beta))
82
    return loss
83
84
85
def binary_dice(Y_pred, Y_gt):
86
    """
87
    binary dice loss
88
    loss=2*(p&p')/(p+p')
89
    :param Y_pred: A tensor resulting from a sigmod
90
    :param Y_gt:  A tensor of the same shape as `output`
91
    :return: binary dice loss
92
    """
93
    smooth = 1.e-5
94
    smooth_tf = tf.constant(smooth, tf.float32)
95
    pred_flat = tf.cast(Y_pred, tf.float32)
96
    true_flat = tf.cast(Y_gt, tf.float32)
97
    intersection = 2 * tf.reduce_sum(pred_flat * true_flat, axis=-1) + smooth_tf
98
    denominator = tf.reduce_sum(pred_flat, axis=-1) + tf.reduce_sum(true_flat, axis=-1) + smooth_tf
99
    loss = -tf.reduce_mean(intersection / denominator)
100
    return loss
101
102
103
def binary_tversky(Y_pred, Y_gt, beta):
104
    """
105
    Tversky loss (TL) is a generalization of Dice loss. TL adds a weight to FP and FN.
106
    define:TL(p,p')=(p&p')/(p&p'+b*((1-p)&p')+(1-b)*(p&(1-p')))
107
    :param Y_pred:A tensor resulting from a sigmod
108
    :param Y_gt:A tensor of the same shape as `output`
109
    :param beta:beta=1/2,just Dice loss,beta must(0,1)
110
    :return:
111
    """
112
    smooth = 1.e-5
113
    smooth_tf = tf.constant(smooth, tf.float32)
114
    pred_flat = tf.cast(Y_pred, tf.float32)
115
    true_flat = tf.cast(Y_gt, tf.float32)
116
    intersection = tf.reduce_sum(pred_flat * true_flat, axis=-1) + smooth_tf
117
    denominator = intersection + tf.reduce_sum(beta * pred_flat * (1 - true_flat), axis=-1) + \
118
                  tf.reduce_sum((1 - beta) * true_flat * (1 - pred_flat), axis=-1) + smooth_tf
119
    loss = -tf.reduce_mean(intersection / denominator)
120
    return loss
121
122
123
def binary_dicePcrossentroy(Y_pred, Y_gt, lamda=1):
124
    """
125
    plus dice and crossentroy loss
126
    :param Y_pred:A tensor resulting from a sigmod
127
    :param Y_gt:A tensor of the same shape as `output`
128
    :param lamda:can set 0.1,0.5,1
129
    :return:dice+crossentroy
130
    """
131
    # step 1,calculate binary crossentroy
132
    loss1 = binary_crossentropy(Y_pred, Y_gt)
133
    # step 2,calculate binary dice
134
    loss2 = 1 - binary_dice(Y_pred, Y_gt)
135
    # step 3,calculate all loss mean
136
    loss = lamda * loss1 + tf.log1p(loss2)
137
    return loss
138
139
140
def binary_focalloss(Y_pred, Y_gt, alpha=0.25, gamma=2.):
141
    """
142
    Binary focal loss.
143
    FL(p_t) = -alpha * (1 - p_t)**gamma * log(p_t)
144
    where p = sigmoid(x), p_t = p or 1 - p depending on if the label is 1 or 0, respectively.
145
    :param Y_gt: A tensor of the same shape as `y_pred`
146
    :param Y_pred: A tensor resulting from a sigmoid
147
    :param alpha: Sample category weight
148
    :param gamma: Difficult sample weight
149
    :return: Binary focal loss.
150
    """
151
    epsilon = 1.e-5
152
    pt_1 = tf.where(tf.equal(Y_gt, 1), Y_pred, tf.ones_like(Y_pred))
153
    pt_0 = tf.where(tf.equal(Y_gt, 0), Y_pred, tf.zeros_like(Y_pred))
154
    # clip to prevent NaN's and Inf's
155
    pt_1 = tf.clip_by_value(pt_1, epsilon, 1. - epsilon)
156
    pt_0 = tf.clip_by_value(pt_0, epsilon, 1. - epsilon)
157
    loss_1 = alpha * tf.pow(1. - pt_1, gamma) * tf.log(pt_1)
158
    loss_0 = (1 - alpha) * tf.pow(pt_0, gamma) * tf.log(1. - pt_0)
159
    loss = -tf.reduce_sum(loss_1 + loss_0)
160
    loss = tf.reduce_mean(loss)
161
    return loss
162
163
164
def binary_distanceloss(Y_pred, Y_gt):
165
    """
166
    distance loss,make segmentation network more sensitive to the boundaries
167
    can use with cross entroy ,dice together,but should have mutilfy weighting coefficient
168
    :param Y_pred:A tensor resulting from a sigmoid
169
    :param Y_gt:A tensor of the same shape as `y_pred`
170
    :return:
171
    """
172
    pred_flat = tf.cast(Y_pred, tf.float32)
173
    true_flat = tf.cast(Y_gt, tf.float32)
174
175
    def Edge_Extracted(y_pred):
176
        # Edge extracted by sobel filter
177
        min_x = tf.constant(0, tf.float32)
178
        max_x = tf.constant(1, tf.float32)
179
        sobel_x = tf.constant([[[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]],
180
                               [[-2, 0, 2], [-4, 0, 4], [-2, 0, 2]],
181
                               [[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]], tf.float32)
182
        sobel_x_filter = tf.reshape(sobel_x, [3, 3, 3, 1, 1])
183
        sobel_y_filter = tf.transpose(sobel_x_filter, [0, 2, 1, 3, 4])
184
        filters_x = tf.nn.conv3d(y_pred, sobel_x_filter, strides=[1, 1, 1, 1, 1], padding='SAME')
185
        filters_y = tf.nn.conv3d(y_pred, sobel_y_filter, strides=[1, 1, 1, 1, 1], padding='SAME')
186
        edge = tf.sqrt(filters_x * filters_x + filters_y * filters_y + 1e-16)
187
        edge = tf.clip_by_value(edge, min_x, max_x)
188
        return edge
189
190
    Y_pred_edge = Edge_Extracted(pred_flat)
191
    Y_gt_edge = Edge_Extracted(true_flat)
192
    distanceloss = tf.reduce_sum(Y_gt_edge * Y_pred_edge, axis=-1)
193
    return distanceloss
194
195
196
# --------------------------- MULTICLASS Evaluation ---------------------------
197
198
def mean_iou(Y_pred, Y_gt):
199
    """
200
    Mean Intersection-Over-Union is a common evaluation metric for
201
    semantic image segmentation, which first computes the IOU for each
202
    semantic class and then computes the average over classes,but label 0 is background,general background is more big,
203
    so mean iou calculate don't include background
204
    :param Y_pred: [None, self.image_depth, self.image_height, self.image_width,
205
                                                       self.numclass],Y_pred is softmax result
206
    :param Y_gt: [None, self.image_depth, self.image_height, self.image_width,
207
                                                       self.numclass],Y_gt is one hot result
208
    :return: mean_iou
209
    """
210
    num_class = Y_pred.get_shape().as_list()[-1]
211
    Y_pred_part = tf.one_hot(tf.argmax(Y_pred, axis=-1), num_class)
212
    Y_pred_part = tf.cast(Y_pred_part, tf.float32)
213
    Y_pred_part = Y_pred_part[:, :, :, :, 1:num_class]
214
    Y_gt_part = tf.cast(Y_gt, tf.float32)
215
    Y_gt_part = Y_gt_part[:, :, :, :, 1:num_class]
216
    Z, H, W, C = Y_gt.get_shape().as_list()[1:]
217
    smooth = 1.e-5
218
    smooth_tf = tf.constant(smooth, tf.float32)
219
    pred_flat = tf.reshape(Y_pred_part, [-1, H * W * Z])
220
    true_flat = tf.reshape(Y_gt_part, [-1, H * W * Z])
221
    intersection = tf.reduce_sum(pred_flat * true_flat, axis=-1)
222
    union = tf.reduce_sum(pred_flat, axis=-1) + tf.reduce_sum(true_flat, axis=-1) - intersection
223
    metric = tf.reduce_mean((intersection + smooth_tf) / (union + smooth_tf))
224
    metric = tf.cond(tf.is_inf(metric), lambda: smooth_tf, lambda: metric)
225
    return metric
226
227
228
def mean_dice(Y_pred, Y_gt):
229
    """
230
    Mean dice is a common evaluation metric for
231
    semantic image segmentation, which first computes the dice for each
232
    semantic class and then computes the average over classes,but label 0 is background,general background is more big,
233
    so mean dice calculate don't include background
234
    :param Y_pred: [None, self.image_depth, self.image_height, self.image_width,
235
                                                       self.numclass],Y_pred is softmax result
236
    :param Y_gt: [None, self.image_depth, self.image_height, self.image_width,
237
                                                       self.numclass],Y_gt is one hot result
238
    :return: mean_iou
239
    """
240
    num_class = Y_pred.get_shape().as_list()[-1]
241
    Y_pred_part = tf.one_hot(tf.argmax(Y_pred, axis=-1), num_class)
242
    Y_pred_part = tf.cast(Y_pred_part, tf.float32)
243
    Y_pred_part = Y_pred_part[:, :, :, :, 1:num_class]
244
    Y_gt_part = tf.cast(Y_gt, tf.float32)
245
    Y_gt_part = Y_gt_part[:, :, :, :, 1:num_class]
246
    Z, H, W, C = Y_gt.get_shape().as_list()[1:]
247
    smooth = 1.e-5
248
    smooth_tf = tf.constant(smooth, tf.float32)
249
    pred_flat = tf.reshape(Y_pred_part, [-1, H * W * Z])
250
    true_flat = tf.reshape(Y_gt_part, [-1, H * W * Z])
251
    intersection = 2 * tf.reduce_sum(pred_flat * true_flat, axis=-1)
252
    union = tf.reduce_sum(pred_flat, axis=-1) + tf.reduce_sum(true_flat, axis=-1)
253
    metric = tf.reduce_mean((intersection + smooth_tf) / (union + smooth_tf))
254
    metric = tf.cond(tf.is_inf(metric), lambda: smooth_tf, lambda: metric)
255
    return metric
256
257
258
# --------------------------- MULTICLASS LOSSES ---------------------------
259
def categorical_crossentropy(Y_pred, Y_gt):
260
    """
261
    Categorical crossentropy between an output and a target
262
    loss=-y*log(y')
263
    :param Y_pred: A tensor resulting from a softmax
264
    :param Y_gt:  A tensor of the same shape as `output`
265
    :return:categorical_crossentropy loss
266
    """
267
    epsilon = 1.e-5
268
    # scale preds so that the class probas of each sample sum to 1
269
    output = Y_pred / tf.reduce_sum(Y_pred, axis=- 1, keep_dims=True)
270
    # manual computation of crossentropy
271
    output = tf.clip_by_value(output, epsilon, 1. - epsilon)
272
    loss = -Y_gt * tf.log(output)
273
    loss = tf.reduce_sum(loss, axis=(1, 2, 3))
274
    loss = tf.reduce_mean(loss, axis=0)
275
    loss = tf.reduce_mean(loss)
276
    return loss
277
278
279
def weighted_categorical_crossentropy(Y_pred, Y_gt, weights):
280
    """
281
    weighted_categorical_crossentropy between an output and a target
282
    loss=-weight*y*log(y')
283
    :param Y_pred:A tensor resulting from a softmax
284
    :param Y_gt:A tensor of the same shape as `output`
285
    :param weights:numpy array of shape (C,) where C is the number of classes
286
    :return:categorical_crossentropy loss
287
    Usage:
288
    weights = np.array([0.5,2,10]) # Class one at 0.5, class 2 twice the normal weights, class 3 10x.
289
    """
290
    weights = np.array(weights)
291
    epsilon = 1.e-5
292
    # scale preds so that the class probas of each sample sum to 1
293
    output = Y_pred / tf.reduce_sum(Y_pred, axis=- 1, keep_dims=True)
294
    # manual computation of crossentropy
295
    output = tf.clip_by_value(output, epsilon, 1. - epsilon)
296
    loss = - Y_gt * tf.log(output)
297
    loss = tf.reduce_sum(loss, axis=(1, 2, 3))
298
    loss = tf.reduce_mean(loss, axis=0)
299
    loss = tf.reduce_mean(weights * loss)
300
    return loss
301
302
303
def categorical_dice(Y_pred, Y_gt, weight_loss):
304
    """
305
    multi label dice loss with weighted
306
    WDL=1-2*(sum(w*sum(r&p))/sum((w*sum(r+p)))),w=array of shape (C,)
307
    :param Y_pred: [None, self.image_depth, self.image_height, self.image_width,
308
                                                       self.numclass],Y_pred is softmax result
309
    :param Y_gt:[None, self.image_depth, self.image_height, self.image_width,
310
                                                       self.numclass],Y_gt is one hot result
311
    :param weight_loss: numpy array of shape (C,) where C is the number of classes
312
    :return:
313
    """
314
    weight_loss = np.array(weight_loss)
315
    smooth = 1.e-5
316
    smooth_tf = tf.constant(smooth, tf.float32)
317
    Y_pred = tf.cast(Y_pred, tf.float32)
318
    Y_gt = tf.cast(Y_gt, tf.float32)
319
    # Compute gen dice coef:
320
    numerator = Y_gt * Y_pred
321
    numerator = tf.reduce_sum(numerator, axis=(1, 2, 3))
322
    denominator = Y_gt + Y_pred
323
    denominator = tf.reduce_sum(denominator, axis=(1, 2, 3))
324
    gen_dice_coef = tf.reduce_mean(2. * (numerator + smooth_tf) / (denominator + smooth_tf), axis=0)
325
    loss = -tf.reduce_mean(weight_loss * gen_dice_coef)
326
    return loss
327
328
329
def categorical_tversky(Y_pred, Y_gt, beta, weight_loss):
330
    """
331
    multi label tversky with weighted
332
    Tversky loss (TL) is a generalization of Dice loss. TL adds a weight to FP and FN.
333
    define:TL(p,p')=(p&p')/(p&p'+b*((1-p)&p')+(1-b)*(p&(1-p')))
334
    :param Y_pred: [None, self.image_depth, self.image_height, self.image_width,
335
                                                       self.numclass],Y_pred is softmax result
336
    :param Y_gt:[None, self.image_depth, self.image_height, self.image_width,
337
                                                       self.numclass],Y_gt is one hot result
338
    :param beta:beta=1/2,just Dice loss,beta must(0,1)
339
    :return:
340
    """
341
    weight_loss = np.array(weight_loss)
342
    smooth = 1.e-5
343
    smooth_tf = tf.constant(smooth, tf.float32)
344
    Y_pred = tf.cast(Y_pred, tf.float32)
345
    Y_gt = tf.cast(Y_gt, tf.float32)
346
    p0 = Y_pred
347
    p1 = 1 - Y_pred
348
    g0 = Y_gt
349
    g1 = 1 - Y_gt
350
    # Compute gen dice coef:
351
    numerator = p0 * g0
352
    numerator = tf.reduce_sum(numerator, axis=(1, 2, 3))
353
    denominator = tf.reduce_sum(beta * p0 * g1, axis=(1, 2, 3)) + tf.reduce_sum((1 - beta) * p1 * g0,
354
                                                                                axis=(1, 2, 3)) + numerator
355
    gen_dice_coef = tf.reduce_mean((numerator + smooth_tf) / (denominator + smooth_tf), axis=0)
356
    loss = -tf.reduce_mean(weight_loss * gen_dice_coef)
357
    return loss
358
359
360
def generalized_dice_loss_w(Y_pred, Y_gt):
361
    """
362
    Generalized Dice Loss with class weights
363
    GDL=1-2*(sum(w*sum(r*p))/sum((w*sum(r+p)))),w=1/sum(r)*sum(r)
364
    rln为类别l在第n个像素的标准值(GT),而pln为相应的预测概率值。此处最关键的是wl,为每个类别的权重
365
    :param Y_gt:[None, self.image_depth, self.image_height, self.image_width,
366
                                                       self.numclass],Y_gt is one hot result
367
    :param Y_pred:[None, self.image_depth, self.image_height, self.image_width,
368
                                                       self.numclass],Y_pred is softmax result
369
    :return:
370
    """
371
    smooth = 1.e-5
372
    smooth_tf = tf.constant(smooth, tf.float32)
373
    Y_pred = tf.cast(Y_pred, tf.float32)
374
    Y_gt = tf.cast(Y_gt, tf.float32)
375
    # Compute weights: "the contribution of each label is corrected by the inverse of its volume"
376
    weight_loss = tf.reduce_sum(Y_gt, axis=(0, 1, 2, 3))
377
    weight_loss = 1 / (tf.pow(weight_loss, 2) + smooth_tf)
378
    # Compute gen dice coef:
379
    numerator = Y_gt * Y_pred
380
    numerator = weight_loss * tf.reduce_sum(numerator, axis=(0, 1, 2, 3))
381
    numerator = tf.reduce_sum(numerator)
382
    denominator = Y_gt + Y_pred
383
    denominator = weight_loss * tf.reduce_sum(denominator, axis=(0, 1, 2, 3))
384
    denominator = tf.reduce_sum(denominator)
385
    loss = -2 * (numerator + smooth_tf) / (denominator + smooth_tf)
386
    return loss
387
388
389
def categorical_focal_loss(Y_pred, Y_gt, gamma, alpha):
390
    """
391
     Categorical focal_loss between an output and a target
392
    :param Y_pred: A tensor of the same shape as `y_pred`
393
    :param Y_gt: A tensor resulting from a softmax(-1,z,h,w,numclass)
394
    :param alpha: Sample category weight,which is shape (C,) where C is the number of classes
395
    :param gamma: Difficult sample weight
396
    :return:
397
    """
398
    weight_loss = np.array(alpha)
399
    epsilon = 1.e-5
400
    # Scale predictions so that the class probas of each sample sum to 1
401
    output = Y_pred / tf.reduce_sum(Y_pred, axis=- 1, keepdims=True)
402
    # Clip the prediction value to prevent NaN's and Inf's
403
    output = tf.clip_by_value(output, epsilon, 1. - epsilon)
404
    # Calculate Cross Entropy
405
    cross_entropy = -Y_gt * tf.log(output)
406
    # Calculate Focal Loss
407
    loss = tf.pow(1 - output, gamma) * cross_entropy
408
    loss = tf.reduce_sum(loss, axis=(1, 2, 3))
409
    loss = tf.reduce_mean(loss, axis=0)
410
    loss = tf.reduce_mean(weight_loss * loss)
411
    return loss
412
413
414
def categorical_dicePcrossentroy(Y_pred, Y_gt, weight, lamda=0.5):
415
    """
416
    hybrid loss function from dice loss and crossentroy
417
    loss=Ldice+lamda*Lfocalloss
418
    :param Y_pred:A tensor resulting from a softmax(-1,z,h,w,numclass)
419
    :param Y_gt: A tensor of the same shape as `y_pred`
420
    :param gamma:Difficult sample weight
421
    :param alpha:Sample category weight,which is shape (C,) where C is the number of classes
422
    :param lamda:trade-off between dice loss and focal loss,can set 0.1,0.5,1
423
    :return:diceplusfocalloss
424
    """
425
    weight_loss = np.array(weight)
426
    smooth = 1.e-5
427
    smooth_tf = tf.constant(smooth, tf.float32)
428
    Y_pred = tf.cast(Y_pred, tf.float32)
429
    Y_gt = tf.cast(Y_gt, tf.float32)
430
    # Compute gen dice coef:
431
    numerator = Y_gt * Y_pred
432
    numerator = tf.reduce_sum(numerator, axis=(1, 2, 3))
433
    denominator = Y_gt + Y_pred
434
    denominator = tf.reduce_sum(denominator, axis=(1, 2, 3))
435
    gen_dice_coef = tf.reduce_sum(2. * (numerator + smooth_tf) / (denominator + smooth_tf), axis=0)
436
    loss1 = tf.reduce_mean(weight_loss * gen_dice_coef)
437
    epsilon = 1.e-5
438
    # scale preds so that the class probas of each sample sum to 1
439
    output = Y_pred / tf.reduce_sum(Y_pred, axis=- 1, keep_dims=True)
440
    # manual computation of crossentropy
441
    output = tf.clip_by_value(output, epsilon, 1. - epsilon)
442
    loss = -Y_gt * tf.log(output)
443
    loss = tf.reduce_mean(loss, axis=(1, 2, 3))
444
    loss = tf.reduce_mean(loss, axis=0)
445
    loss2 = tf.reduce_mean(weight_loss * loss)
446
    total_loss = (1 - lamda) * (1 - loss1) + lamda * loss2
447
    return total_loss
448
449
450
def categorical_dicePfocalloss(Y_pred, Y_gt, alpha, lamda=0.5, gamma=2.):
451
    """
452
    hybrid loss function from dice loss and focalloss
453
    loss=Ldice+lamda*Lfocalloss
454
    :param Y_pred:A tensor resulting from a softmax(-1,z,h,w,numclass)
455
    :param Y_gt: A tensor of the same shape as `y_pred`
456
    :param gamma:Difficult sample weight
457
    :param alpha:Sample category weight,which is shape (C,) where C is the number of classes
458
    :param lamda:trade-off between dice loss and focal loss,can set 0.1,0.5,1
459
    :return:dicePfocalloss
460
    """
461
    weight_loss = np.array(alpha)
462
    smooth = 1.e-5
463
    smooth_tf = tf.constant(smooth, tf.float32)
464
    Y_pred = tf.cast(Y_pred, tf.float32)
465
    Y_gt = tf.cast(Y_gt, tf.float32)
466
    # Compute gen dice coef:
467
    numerator = Y_gt * Y_pred
468
    numerator = tf.reduce_sum(numerator, axis=(1, 2, 3))
469
    denominator = Y_gt + Y_pred
470
    denominator = tf.reduce_sum(denominator, axis=(1, 2, 3))
471
    gen_dice_coef = tf.reduce_sum(2. * (numerator + smooth_tf) / (denominator + smooth_tf), axis=0)
472
    loss1 = tf.reduce_mean(weight_loss * gen_dice_coef)
473
    epsilon = 1.e-5
474
    # Scale predictions so that the class probas of each sample sum to 1
475
    output = Y_pred / tf.reduce_sum(Y_pred, axis=- 1, keepdims=True)
476
    # Clip the prediction value to prevent NaN's and Inf's
477
    output = tf.clip_by_value(output, epsilon, 1. - epsilon)
478
    # Calculate Cross Entropy
479
    cross_entropy = -Y_gt * tf.log(output)
480
    # Calculate Focal Loss
481
    loss = tf.pow(1 - output, gamma) * cross_entropy
482
    loss = tf.reduce_mean(loss, axis=(1, 2, 3))
483
    loss = tf.reduce_mean(loss, axis=0)
484
    loss2 = tf.reduce_mean(weight_loss * loss)
485
    total_loss = (1 - lamda) * (1 - loss1) + lamda * loss2
486
    return total_loss
487
488
489
def ssim2d_loss(Y_pred, Y_gt, maxlabel):
490
    """
491
    Computes SSIM index between Y_pred and Y_gt.only calculate 2d image,3d image can use it,but not actual ssim3d
492
    :param Y_pred:A tensor resulting from a softmax(-1,z,h,w,numclass)
493
    :param Y_gt:A tensor of the same shape as `y_pred`
494
    :param maxlabel:maxlabelvalue
495
    :return:ssim_loss
496
    """
497
    loss = tf.image.ssim(Y_pred, Y_gt, maxlabel)
498
    loss = tf.reduce_mean(loss)
499
    return loss
500
501
502
def multiscalessim2d_loss(Y_pred, Y_gt, maxlabel, downsampledfactor=4):
503
    """
504
    Computes the MS-SSIM between Y_pred and Y_gt.only calculate 2d image,3d image can use it,but not actual multiscalessim3d
505
    :param Y_pred:A tensor resulting from a softmax(-1,z,h,w,numclass)
506
    :param Y_gt:A tensor of the same shape as `y_pred`
507
    :param maxlabel:maxlabelvalue
508
    :param downsampledfactor:downsample factor depend on input imagesize
509
    :return:multiscalessim_loss
510
    """
511
    if downsampledfactor >= 5:
512
        _MSSSIM_WEIGHTS = (0.0448, 0.2856, 0.3001, 0.2363, 0.1333)
513
    if downsampledfactor == 4:
514
        _MSSSIM_WEIGHTS = (0.0448, 0.2856, 0.3001, 0.2363)
515
    if downsampledfactor == 3:
516
        _MSSSIM_WEIGHTS = (0.0448, 0.2856, 0.3001)
517
    if downsampledfactor == 2:
518
        _MSSSIM_WEIGHTS = (0.0448, 0.2856)
519
    if downsampledfactor <= 1:
520
        _MSSSIM_WEIGHTS = (0.0448)
521
    loss = tf.image.ssim_multiscale(Y_pred, Y_gt, maxlabel, power_factors=_MSSSIM_WEIGHTS)
522
    loss = tf.reduce_mean(loss)
523
    return loss