Diff of /losses.py [000000] .. [beb348]

Switch to unified view

a b/losses.py
1
import numpy as np
2
import keras.backend as K
3
4
5
def dice(y_true, y_pred):
6
    #computes the dice score on two tensors
7
8
    sum_p=K.sum(y_pred,axis=0)
9
    sum_r=K.sum(y_true,axis=0)
10
    sum_pr=K.sum(y_true * y_pred,axis=0)
11
    dice_numerator =2*sum_pr
12
    dice_denominator =sum_r+sum_p
13
    dice_score =(dice_numerator+K.epsilon() )/(dice_denominator+K.epsilon())
14
    return dice_score
15
16
17
def dice_whole_metric(y_true, y_pred):
18
    #computes the dice for the whole tumor
19
20
    y_true_f = K.reshape(y_true,shape=(-1,4))
21
    y_pred_f = K.reshape(y_pred,shape=(-1,4))
22
    y_whole=K.sum(y_true_f[:,1:],axis=1)
23
    p_whole=K.sum(y_pred_f[:,1:],axis=1)
24
    dice_whole=dice(y_whole,p_whole)
25
    return dice_whole
26
27
def dice_en_metric(y_true, y_pred):
28
    #computes the dice for the enhancing region
29
30
    y_true_f = K.reshape(y_true,shape=(-1,4))
31
    y_pred_f = K.reshape(y_pred,shape=(-1,4))
32
    y_enh=y_true_f[:,-1]
33
    p_enh=y_pred_f[:,-1]
34
    dice_en=dice(y_enh,p_enh)
35
    return dice_en
36
37
def dice_core_metric(y_true, y_pred):
38
    ##computes the dice for the core region
39
40
    y_true_f = K.reshape(y_true,shape=(-1,4))
41
    y_pred_f = K.reshape(y_pred,shape=(-1,4))
42
    
43
    #workaround for tf
44
    #y_core=K.sum(tf.gather(y_true_f, [1,3],axis =1),axis=1)
45
    #p_core=K.sum(tf.gather(y_pred_f, [1,3],axis =1),axis=1)
46
    
47
    y_core=K.sum(y_true_f[:,[1,3]],axis=1)
48
    p_core=K.sum(y_pred_f[:,[1,3]],axis=1)
49
    dice_core=dice(y_core,p_core)
50
    return dice_core
51
52
53
54
def weighted_log_loss(y_true, y_pred):
55
    # scale predictions so that the class probas of each sample sum to 1
56
    y_pred /= K.sum(y_pred, axis=-1, keepdims=True)
57
    # clip to prevent NaN's and Inf's
58
    y_pred = K.clip(y_pred, K.epsilon(), 1 - K.epsilon())
59
    # weights are assigned in this order : normal,necrotic,edema,enhancing 
60
    weights=np.array([1,5,2,4])
61
    weights = K.variable(weights)
62
    loss = y_true * K.log(y_pred) * weights
63
    loss = K.mean(-K.sum(loss, -1))
64
    return loss
65
66
def gen_dice_loss(y_true, y_pred):
67
    '''
68
    computes the sum of two losses : generalised dice loss and weighted cross entropy
69
    '''
70
71
    #generalised dice score is calculated as in this paper : https://arxiv.org/pdf/1707.03237
72
    y_true_f = K.reshape(y_true,shape=(-1,4))
73
    y_pred_f = K.reshape(y_pred,shape=(-1,4))
74
    sum_p=K.sum(y_pred_f,axis=-2)
75
    sum_r=K.sum(y_true_f,axis=-2)
76
    sum_pr=K.sum(y_true_f * y_pred_f,axis=-2)
77
    weights=K.pow(K.square(sum_r)+K.epsilon(),-1)
78
    generalised_dice_numerator =2*K.sum(weights*sum_pr)
79
    generalised_dice_denominator =K.sum(weights*(sum_r+sum_p))
80
    generalised_dice_score =generalised_dice_numerator /generalised_dice_denominator
81
    GDL=1-generalised_dice_score
82
    del sum_p,sum_r,sum_pr,weights
83
84
    return GDL+weighted_log_loss(y_true,y_pred)