# tool/Code/utilities/loss.py
1
2
# Copyright 2019 Population Health Sciences and Image Analysis, German Center for Neurodegenerative Diseases(DZNE)
3
#
4
#    Licensed under the Apache License, Version 2.0 (the "License");
5
#    you may not use this file except in compliance with the License.
6
#    You may obtain a copy of the License at
7
#
8
#        http://www.apache.org/licenses/LICENSE-2.0
9
#
10
#    Unless required by applicable law or agreed to in writing, software
11
#    distributed under the License is distributed on an "AS IS" BASIS,
12
#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
#    See the License for the specific language governing permissions and
14
#    limitations under the License.
15
16
import numpy as np
17
from keras import backend as K
18
from keras import metrics
19
20
# %% 1. DICE LOSS

# Additive smoothing term for the Dice/Jaccard ratios below: keeps the
# quotient finite (and gradients stable) when a class is absent from
# both the prediction and the ground truth.
smooth = 1
# Relative weight of the Dice term in mixed losses (not referenced in
# this module's visible code; kept for backward compatibility).
w_dice = 0.5

K.set_epsilon(1e-7)  # fuzz factor returned by K.epsilon() in the clips below
np.set_printoptions(threshold=np.inf)  # print full arrays when debugging
K.set_image_data_format('channels_last')  # tensors are (..., n_classes)
def average_dice_coef(y_true, y_pred):
    """Mean of the per-class (channel-wise) soft Dice coefficients.

    Averages ``dice_coef_axis`` over the last (channel) axis of ``y_pred``.
    """
    n_channels = int(y_pred.shape[-1])
    avg_dice = 0
    for i in range(n_channels):
        avg_dice += dice_coef_axis(y_true, y_pred, i)
    # Divide by the channel count instead of the leaked loop variable:
    # the original ``avg_dice / (i + 1)`` raised NameError when the loop
    # body never ran (zero-channel tensor).
    return avg_dice / n_channels
def dice_coef(y_true, y_pred):
    """Soft Dice coefficient summed over all voxels and classes.

    Summing the per-channel intersections/unions (as the original 4D/5D
    loops did) is identical to summing over the whole tensor, so the two
    rank-specific branches collapse into a single formula.  This also
    fixes the original's implicit ``None`` return for tensors whose rank
    was neither 4 nor 5.
    """
    intersection = K.sum(y_true * y_pred)
    union = K.sum(y_true + y_pred)
    return (2. * intersection + smooth) / (union + smooth)
# %% CLASS-WISE DICE
def dice_coef_axis(y_true, y_pred, i):
    """Soft Dice coefficient for the single class channel ``i``.

    ``[..., i]`` selects the last (channel) axis for any tensor rank, so
    the separate 4D/5D branches — and the original's implicit ``None``
    return for other ranks — are no longer needed.
    """
    intersection = K.sum(y_true[..., i] * y_pred[..., i])
    union = K.sum(y_true[..., i] + y_pred[..., i])
    return (2. * intersection + smooth) / (union + smooth)
def dice_coef_0(y_true, y_pred):
    """Dice coefficient for class channel 0 (Keras metric wrapper)."""
    return dice_coef_axis(y_true, y_pred, 0)


def dice_coef_1(y_true, y_pred):
    """Dice coefficient for class channel 1 (Keras metric wrapper)."""
    return dice_coef_axis(y_true, y_pred, 1)


def dice_coef_2(y_true, y_pred):
    """Dice coefficient for class channel 2 (Keras metric wrapper)."""
    return dice_coef_axis(y_true, y_pred, 2)


def dice_coef_3(y_true, y_pred):
    """Dice coefficient for class channel 3 (Keras metric wrapper)."""
    return dice_coef_axis(y_true, y_pred, 3)


def dice_coef_4(y_true, y_pred):
    """Dice coefficient for class channel 4 (Keras metric wrapper)."""
    return dice_coef_axis(y_true, y_pred, 4)
def dice_coef_loss(y_true, y_pred):
    """Negated Dice coefficient: minimising this loss maximises Dice."""
    return -dice_coef(y_true, y_pred)
def jaccard_coef(y_true, y_pred):
    """Smoothed soft Jaccard (IoU) coefficient.

    Both tensors are clipped away from {0, 1} to keep gradients finite.

    Fix: the original added ``smooth`` to the intersection *before*
    computing ``union = sum - intersection + smooth``, so the two
    smoothing terms cancelled and the denominator was left unsmoothed.
    Here the raw intersection/union are computed first and ``smooth`` is
    applied symmetrically to numerator and denominator.
    """
    y_true = K.clip(y_true, K.epsilon(), 1. - K.epsilon())
    y_pred = K.clip(y_pred, K.epsilon(), 1. - K.epsilon())
    intersection = K.tf.reduce_sum(y_pred * y_true)
    union = K.tf.reduce_sum(y_true) + K.tf.reduce_sum(y_pred) - intersection
    return (intersection + smooth) / (union + smooth)
def custom_loss(MedBalFactor, sigma=3, loss_type='Dice'):
    """Factory building a Keras segmentation loss closed over fixed kernels.

    Parameters
    ----------
    MedBalFactor : sequence of float
        Per-class balancing weights; its length defines the number of
        classes (and hence the depth of every kernel built below).
    sigma : float, optional
        Standard deviation of the Gaussian used to blur edge-weight maps.
    loss_type : str, optional
        One of 'Dice', 'Logistic', 'Weighted_Logistic',
        'Weighted_Grad_Logistic', 'Mixed_Grad_Weighted', 'Mixed',
        'Mixed_Weighted'.

    Returns
    -------
    callable
        ``loss(y_true, y_pred)`` suitable for ``model.compile``.

    Raises
    ------
    ValueError
        If ``loss_type`` is not one of the supported names (the original
        silently returned ``None`` in that case).
    """
    n_classes = len(MedBalFactor)

    def get_gauss_kernel_3D(sigma):
        """3x3x3 Gaussian kernel placed on the class diagonal (no mixing)."""
        ker = np.zeros(shape=(3, 3, 3, n_classes, n_classes), dtype='float32')
        # NOTE(review): the sample grid is linspace(-3, 3, 3) == [-3, 0, 3],
        # not the conventional [-1, 0, 1]; kept as-is to preserve the
        # effective (coarser) blur the model was trained with.
        ind = np.linspace(-np.floor(ker.shape[1]), np.floor(ker.shape[1]), ker.shape[1])
        ind2 = np.linspace(-np.floor(ker.shape[2]), np.floor(ker.shape[2]), ker.shape[2])
        x, y = np.meshgrid(ind, ind2)
        G = np.zeros((3, 3, 3))
        for i in range(ker.shape[0]):
            # The same 2D Gaussian slice is stacked along the first axis.
            G[i, :, :] = np.exp((-1 / (2 * sigma ** 2)) * (x ** 2 + y ** 2))
        G = G / np.sum(G)
        for i in range(n_classes):
            ker[:, :, :, i, i] = G
        return K.constant(ker)

    def get_gauss_kernel(sigma):
        """3x3 Gaussian kernel placed on the class diagonal (no mixing)."""
        ker = np.zeros(shape=(3, 3, n_classes, n_classes), dtype='float32')
        # Same coarse [-3, 0, 3] grid as the 3D variant — see note above.
        ind = np.linspace(-np.floor(ker.shape[0]), np.floor(ker.shape[0]), ker.shape[0])
        ind2 = np.linspace(-np.floor(ker.shape[1]), np.floor(ker.shape[1]), ker.shape[1])
        x, y = np.meshgrid(ind, ind2)
        G = np.exp((-1 / (2 * sigma ** 2)) * (x ** 2 + y ** 2))
        G = G / np.sum(G)
        for i in range(n_classes):
            ker[:, :, i, i] = G
        return K.constant(ker)

    def get_sobel_kernel_3D(axis):
        """3D Sobel kernel along ``axis`` ('x', 'y' or 'z'), one per class."""
        ker = np.zeros(shape=(3, 3, 3, n_classes, n_classes), dtype='float32')
        if axis == 'z':
            S = np.array([[[1, 2, 1], [2, 4, 2], [1, 2, 1]],
                          [[0, 0, 0], [0, 0, 0], [0, 0, 0]],
                          [[-1, -2, -1], [-2, -4, -2], [-1, -2, -1]]])
        else:
            s = np.array([[1, 2, 1],
                          [0, 0, 0],
                          [-1, -2, -1]], dtype='float32')
            if axis == 'x':
                s = np.transpose(s)
            # Replicate the 2D filter along the depth axis.
            S = np.zeros((3, 3, 3))
            for i in range(ker.shape[0]):
                S[i, :, :] = s[:]
        for i in range(n_classes):
            ker[:, :, :, i, i] = S
        return K.constant(ker)

    def get_sobel_kernel(axis):
        """2D Sobel kernel along ``axis`` ('x' or 'y'), one per class."""
        s = np.array([[1, 2, 1],
                      [0, 0, 0],
                      [-1, -2, -1]], dtype='float32')
        if axis == 'x':
            s = np.transpose(s)
        ker = np.zeros(shape=(3, 3, n_classes, n_classes), dtype='float32')
        for i in range(n_classes):
            ker[:, :, i, i] = s
        return K.constant(ker)

    # Kernels are built once when the loss is created, not per batch.
    GAUSS_KERNEL_3D = get_gauss_kernel_3D(sigma)
    GAUSS_KERNEL = get_gauss_kernel(sigma)
    SOBEL_X = get_sobel_kernel('x')
    SOBEL_Y = get_sobel_kernel('y')
    SOBEL_X_3D = get_sobel_kernel_3D('x')
    SOBEL_Y_3D = get_sobel_kernel_3D('y')
    SOBEL_Z_3D = get_sobel_kernel_3D('z')

    def get_grad_tensor_3d(img_tensor, apply_gauss=True):
        """Binary edge map of a 3D label tensor, tiled to one copy per class.

        Sobel gradients are thresholded into a {0, 1} edge mask, merged
        across class channels, tiled back to ``n_classes`` channels and
        (optionally) blurred so the weights decay away from boundaries.
        """
        grad_x = K.conv3d(img_tensor, SOBEL_X_3D, padding='same')
        grad_y = K.conv3d(img_tensor, SOBEL_Y_3D, padding='same')
        grad_z = K.conv3d(img_tensor, SOBEL_Z_3D, padding='same')
        grad_tensor = K.sqrt(grad_x * grad_x + grad_y * grad_y + grad_z * grad_z)
        # Threshold the gradient magnitude into a binary edge mask.
        grad_tensor = K.cast(K.greater(grad_tensor, 100.0 * K.epsilon()), K.floatx())
        grad_tensor = K.clip(grad_tensor, K.epsilon(), 1.0)
        # Collapse the per-class edge masks into one map...
        grad_map = K.sum(grad_tensor, axis=-1, keepdims=True)
        # ...then tile that map back to one channel per class.
        for i in range(n_classes):
            if i == 0:
                grad_tensor = grad_map[:]
            else:
                grad_tensor = K.concatenate([grad_tensor, grad_map], axis=-1)
        grad_tensor = K.cast(K.greater(grad_tensor, 100.0 * K.epsilon()), K.floatx())
        if apply_gauss:
            grad_tensor = K.conv3d(grad_tensor, GAUSS_KERNEL_3D, padding='same')
        return grad_tensor

    def get_grad_tensor(img_tensor, apply_gauss=True):
        """2D counterpart of :func:`get_grad_tensor_3d` (same pipeline)."""
        grad_x = K.conv2d(img_tensor, SOBEL_X, padding='same')
        grad_y = K.conv2d(img_tensor, SOBEL_Y, padding='same')
        grad_tensor = K.sqrt(grad_x * grad_x + grad_y * grad_y)
        # Threshold the gradient magnitude into a binary edge mask.
        grad_tensor = K.cast(K.greater(grad_tensor, 100.0 * K.epsilon()), K.floatx())
        grad_tensor = K.clip(grad_tensor, K.epsilon(), 1.0)
        # Collapse per-class edges into one map, then tile per class.
        grad_map = K.sum(grad_tensor, axis=-1, keepdims=True)
        for i in range(n_classes):
            if i == 0:
                grad_tensor = grad_map[:]
            else:
                grad_tensor = K.concatenate([grad_tensor, grad_map], axis=-1)
        grad_tensor = K.cast(K.greater(grad_tensor, 100.0 * K.epsilon()), K.floatx())
        if apply_gauss:
            grad_tensor = K.conv2d(grad_tensor, GAUSS_KERNEL, padding='same')
        return grad_tensor

    def weighted_gradient_loss(y_true, y_pred):
        """Cross-entropy weighted by class balance plus edge emphasis."""
        y_true = K.clip(y_true, K.epsilon(), 1. - K.epsilon())
        y_pred = K.clip(y_pred, K.epsilon(), 1. - K.epsilon())

        weights = []
        if len(y_pred.shape) == 4:
            axis = [0, 1, 2]
            # Cap the edge boost at 10 so dominant balance factors do not
            # make boundary voxels overwhelm the loss.
            if np.max(MedBalFactor) > 5:
                edge_weights = 10 * get_grad_tensor(y_true, True)
            else:
                edge_weights = 2 * np.max(MedBalFactor) * get_grad_tensor(y_true, True)
            for i in range(len(MedBalFactor)):
                weights.append(MedBalFactor[i] * K.ones_like(y_true[:, :, :, i:i + 1]))
        elif len(y_pred.shape) == 5:
            axis = [0, 1, 2, 3]
            if np.max(MedBalFactor) > 5:
                edge_weights = 10 * get_grad_tensor_3d(y_true, True)
            else:
                edge_weights = 2 * np.max(MedBalFactor) * get_grad_tensor_3d(y_true, True)
            for i in range(len(MedBalFactor)):
                weights.append(MedBalFactor[i] * K.ones_like(y_true[:, :, :, :, i:i + 1]))

        class_weights = K.concatenate(weights, axis=-1)
        class_weights = K.tf.add(class_weights, edge_weights)
        cross_entropy_part = -1.0 * K.tf.reduce_sum(
            K.tf.reduce_mean(K.tf.multiply(y_true * K.tf.log(y_pred), class_weights),
                             axis=axis, keepdims=True))
        return cross_entropy_part

    def weighted_logistic_loss(y_true, y_pred):
        """Cross-entropy weighted per class by ``MedBalFactor``."""
        y_true = K.clip(y_true, K.epsilon(), 1. - K.epsilon())
        y_pred = K.clip(y_pred, K.epsilon(), 1. - K.epsilon())

        weights = []
        if len(y_pred.shape) == 4:
            axis = [0, 1, 2]
            for i in range(len(MedBalFactor)):
                weights.append(MedBalFactor[i] * K.ones_like(y_true[:, :, :, i:i + 1]))
        elif len(y_pred.shape) == 5:
            axis = [0, 1, 2, 3]
            for i in range(len(MedBalFactor)):
                weights.append(MedBalFactor[i] * K.ones_like(y_true[:, :, :, :, i:i + 1]))

        class_weights = K.concatenate(weights, axis=-1)
        cross_entropy_part = -1.0 * K.tf.reduce_sum(
            K.tf.reduce_mean(K.tf.multiply(y_true * K.tf.log(y_pred), class_weights),
                             axis=axis, keepdims=True))
        return cross_entropy_part

    def logistic_loss(y_true, y_pred):
        """Plain (unweighted) categorical cross-entropy."""
        y_true = K.clip(y_true, K.epsilon(), 1. - K.epsilon())
        y_pred = K.clip(y_pred, K.epsilon(), 1. - K.epsilon())
        if len(y_pred.shape) == 4:
            axis = [0, 1, 2]
        elif len(y_pred.shape) == 5:
            axis = [0, 1, 2, 3]
        cross_entropy_part = -1.0 * K.tf.reduce_sum(
            K.tf.reduce_mean(y_true * K.tf.log(y_pred), axis=axis, keepdims=True))
        return cross_entropy_part

    def dice_loss(y_true, y_pred):
        """Negated smoothed soft Dice over all voxels and classes."""
        y_true = K.clip(y_true, K.epsilon(), 1. - K.epsilon())
        y_pred = K.clip(y_pred, K.epsilon(), 1. - K.epsilon())
        intersection = K.tf.reduce_sum(y_pred * y_true) + smooth
        union = (K.tf.reduce_sum(y_true) + K.tf.reduce_sum(y_pred)) + smooth
        dice_part = -2.0 * (intersection / union)
        return dice_part

    def mixed_loss(y_true, y_pred):
        """Dispatch to the loss selected by ``loss_type`` at factory time."""
        if loss_type == 'Dice':
            return dice_loss(y_true, y_pred)
        elif loss_type == 'Logistic':
            return logistic_loss(y_true, y_pred)
        elif loss_type == 'Weighted_Logistic':
            return weighted_logistic_loss(y_true, y_pred)
        elif loss_type == 'Weighted_Grad_Logistic':
            return weighted_gradient_loss(y_true, y_pred)
        elif loss_type == 'Mixed_Grad_Weighted':
            return dice_loss(y_true, y_pred) + weighted_gradient_loss(y_true, y_pred)
        elif loss_type == 'Mixed':
            return dice_loss(y_true, y_pred) + logistic_loss(y_true, y_pred)
        elif loss_type == 'Mixed_Weighted':
            return dice_loss(y_true, y_pred) + weighted_logistic_loss(y_true, y_pred)
        # Fail loudly instead of silently returning None as before.
        raise ValueError('Unknown loss_type: %r' % (loss_type,))

    return mixed_loss