a b/loss.py
1
import tensorflow.keras.backend as K
2
from tensorflow.keras.losses import categorical_crossentropy
3
4
5
def generalized_dice(y_true, y_pred, num_classes=4):
    """
    Generalized Dice Score
    https://arxiv.org/pdf/1707.03237

    Both tensors are flattened to ``(-1, num_classes)`` and every per-class
    quantity is obtained by summing over the flattened (first) axis, so all
    leading dimensions of the inputs are pooled into a single score.

    Args:
        y_true: ground-truth tensor; presumably one-hot encoded with the
            class axis last -- confirm against the data pipeline.
        y_pred: prediction tensor, flattened the same way as ``y_true``.
        num_classes: size of the class axis used when flattening. Defaults
            to 4, the value that used to be hard-coded here.

    Returns:
        Scalar tensor equal to
        ``2 * sum(w * sum_pr) / sum(w * (sum_r + sum_p))`` where
        ``w = 1 / (sum_r ** 2 + epsilon)`` is computed per class.
    """
    flat_shape = (-1, num_classes)
    y_true = K.reshape(y_true, shape=flat_shape)
    y_pred = K.reshape(y_pred, shape=flat_shape)

    # After the reshape the tensors are 2-D, so axis -2 is the flattened
    # axis: each sum below yields one value per class.
    sum_p = K.sum(y_pred, -2)
    sum_r = K.sum(y_true, -2)
    sum_pr = K.sum(y_true * y_pred, -2)

    # Inverse squared reference volume per class. The epsilon keeps the
    # weight finite when a class has no reference mass; such a class then
    # receives the weight 1 / epsilon.
    weights = K.pow(K.square(sum_r) + K.epsilon(), -1)

    numerator = 2 * K.sum(weights * sum_pr)
    denominator = K.sum(weights * (sum_r + sum_p))

    return numerator / denominator
22
23
def generalized_dice_loss(y_true, y_pred):
    """Return one minus the Generalized Dice score of the two tensors."""
    dice_score = generalized_dice(y_true, y_pred)
    return 1 - dice_score
25
    
26
    
27
def custom_loss(y_true, y_pred):
    """
    Combined training loss: the Generalized Dice loss ("GDL") plus the
    categorical cross-entropy ("CE"), with the CE term scaled by 1.25.
    """
    dice_term = generalized_dice_loss(y_true, y_pred)
    cross_entropy_term = categorical_crossentropy(y_true, y_pred)
    return dice_term + 1.25 * cross_entropy_term
35
    
36
    
37
    
38
    
39
    
40