|
a |
|
b/loss.py |
|
|
1 |
import tensorflow.keras.backend as K |
|
|
2 |
from tensorflow.keras.losses import categorical_crossentropy |
|
|
3 |
|
|
|
4 |
|
|
|
5 |
def generalized_dice(y_true, y_pred):
    """Generalized Dice Score (Sudre et al., 2017).

    https://arxiv.org/pdf/1707.03237

    Computes one scalar overlap score pooled over every position in the whole
    batch. Each class ``l`` is weighted by ``1 / (sum(r_l) ** 2 + eps)``, where
    ``r_l`` is the reference (ground-truth) mask of that class and ``eps`` is the
    backend epsilon. Small structures therefore count as much as large ones.

    Classes are read from the LAST axis (channels-last layout), the same axis
    ``categorical_crossentropy`` reduces by default. Any number of classes is
    supported; for 4-class inputs the result equals flattening to ``(-1, 4)`` and
    summing over positions.

    Args:
        y_true: Reference labels, one-hot along the last axis. Cast to the
            dtype of ``y_pred`` so integer one-hot labels are accepted.
        y_pred: Predicted per-class scores, same shape as ``y_true``.

    Returns:
        Scalar tensor holding the generalized Dice score.

    Note:
        A class absent from ``y_true`` gets a weight of about ``1 / eps``. Any
        predicted mass for that class then dominates the denominator. This
        follows the paper's weighting; the epsilon only keeps the value finite.
    """
    y_true = K.cast(y_true, K.dtype(y_pred))

    # Sum over every axis except the class axis -> per-class totals of shape (C,).
    reduce_axes = list(range(K.ndim(y_pred) - 1))
    sum_p = K.sum(y_pred, axis=reduce_axes)
    sum_r = K.sum(y_true, axis=reduce_axes)
    sum_pr = K.sum(y_true * y_pred, axis=reduce_axes)

    # Inverse squared reference volume per class (epsilon avoids division by zero).
    weights = K.pow(K.square(sum_r) + K.epsilon(), -1)

    numerator = 2 * K.sum(weights * sum_pr)
    denominator = K.sum(weights * (sum_r + sum_p))
    return numerator / denominator
|
|
22 |
|
|
|
23 |
def generalized_dice_loss(y_true, y_pred):
    """Return the generalized Dice loss, i.e. ``1 - generalized_dice(y_true, y_pred)``.

    A perfect overlap gives a loss of 0; lower overlap gives a larger loss.
    """
    score = generalized_dice(y_true, y_pred)
    return 1 - score
|
|
25 |
|
|
|
26 |
|
|
|
27 |
def custom_loss(y_true, y_pred, ce_weight=1.25):
    """Combined segmentation loss: generalized Dice loss plus weighted cross-entropy.

        loss = generalized_dice_loss(y_true, y_pred)
               + ce_weight * categorical_crossentropy(y_true, y_pred)

    ``ce_weight`` is a fixed weighting coefficient on the cross-entropy term.
    There is no separate regularization term. Keras invokes losses as
    ``fn(y_true, y_pred)``, so the default of 1.25 is used unless the caller
    wraps this function to pass a different value.

    Args:
        y_true: One-hot reference labels, classes on the last axis.
        y_pred: Predicted class probabilities, same shape as ``y_true``.
            Cross-entropy is called with its default ``from_logits=False``, so
            this presumably expects softmax outputs, not logits.
        ce_weight: Multiplier applied to the cross-entropy term.

    Returns:
        Tensor: the scalar Dice loss broadcast-added to the weighted
        cross-entropy, which keeps ``y_true``'s shape minus the class axis.
    """
    dice_term = generalized_dice_loss(y_true, y_pred)
    ce_term = categorical_crossentropy(y_true, y_pred)
    return dice_term + ce_weight * ce_term
|
|
35 |
|
|
|
36 |
|
|
|
37 |
|
|
|
38 |
|
|
|
39 |
|
|
|
40 |
|