Diff of /model.py [000000] .. [fcace9]

Switch to unified view

a b/model.py
1
# Keras implementation of the paper:
2
# 3D MRI Brain Tumor Segmentation Using Autoencoder Regularization
3
# by Myronenko A. (https://arxiv.org/pdf/1810.11654.pdf)
4
# Author of this code: Suyog Jadhav (https://github.com/IAmSUyogJadhav)
5
6
import keras.backend as K
7
from keras.losses import mse
8
from keras.layers import Conv3D, Activation, Add, UpSampling3D, Lambda, Dense
9
from keras.layers import Input, Reshape, Flatten, Dropout, SpatialDropout3D
10
from keras.optimizers import adam
11
from keras.models import Model
12
try:
13
    from group_norm import GroupNormalization
14
except ImportError:
15
    import urllib.request
16
    print('Downloading group_norm.py in the current directory...')
17
    url = 'https://raw.githubusercontent.com/titu1994/Keras-Group-Normalization/master/group_norm.py'
18
    urllib.request.urlretrieve(url, "group_norm.py")
19
    from group_norm import GroupNormalization
20
21
22
def green_block(inp, filters, data_format='channels_first', name=None):
    """
    green_block(inp, filters, data_format='channels_first', name=None)
    ------------------------------------------------------------------
    Implementation of the special residual block used in the paper. The block
    consists of two (GroupNorm --> ReLu --> 3x3x3 non-strided Convolution)
    units, with a residual connection from the input `inp` to the output. Used
    internally in the model. Can be used independently as well.

    Parameters
    ----------
    `inp`: A keras tensor, required
        The output of the keras layer just preceding the green block.
    `filters`: integer, required
        No. of filters to use in the 3D convolutional block. The output
        layer of this green block will have this many no. of channels.
    `data_format`: string, optional
        The format of the input data. Must be either 'channels_first' or
        'channels_last'. Defaults to `channels_first`, as used in the paper.
    `name`: string, optional
        The name to be given to this green block. Defaults to None, in which
        case, keras uses generated names for the involved layers. If a string
        is provided, the individual layers are named `Res_<name>`,
        `GroupNorm_1_<name>`, `Relu_1_<name>`, `Conv3D_1_<name>`,
        `GroupNorm_2_<name>`, `Relu_2_<name>`, `Conv3D_2_<name>` and
        `Out_<name>`.

    Returns
    -------
    `out`: A keras tensor
        The output of the green block. Has no. of channels equal to `filters`.
        The size of the rest of the dimensions remains same as in `inp`.

    Raises
    ------
    ValueError
        If `data_format` is neither 'channels_first' nor 'channels_last'.
    """
    # Fail fast with a clear message instead of deep inside a keras layer.
    if data_format not in ('channels_first', 'channels_last'):
        raise ValueError(
            "data_format must be 'channels_first' or 'channels_last', "
            f"got {data_format!r}")

    # The channel axis is 1 for channels_first and the last axis for
    # channels_last. (Axis 0 is the batch axis and must never be normalized.)
    channel_axis = 1 if data_format == 'channels_first' else -1

    # 1x1x1 convolution on the skip path so that its channel count matches
    # `filters` before the final addition.
    inp_res = Conv3D(
        filters=filters,
        kernel_size=(1, 1, 1),
        strides=1,
        data_format=data_format,
        name=f'Res_{name}' if name else None)(inp)

    # No. of groups = 8, as given in the paper
    x = GroupNormalization(
        groups=8,
        axis=channel_axis,
        name=f'GroupNorm_1_{name}' if name else None)(inp)
    x = Activation('relu', name=f'Relu_1_{name}' if name else None)(x)
    x = Conv3D(
        filters=filters,
        kernel_size=(3, 3, 3),
        strides=1,
        padding='same',
        data_format=data_format,
        name=f'Conv3D_1_{name}' if name else None)(x)

    x = GroupNormalization(
        groups=8,
        axis=channel_axis,
        name=f'GroupNorm_2_{name}' if name else None)(x)
    x = Activation('relu', name=f'Relu_2_{name}' if name else None)(x)
    x = Conv3D(
        filters=filters,
        kernel_size=(3, 3, 3),
        strides=1,
        padding='same',
        data_format=data_format,
        name=f'Conv3D_2_{name}' if name else None)(x)

    out = Add(name=f'Out_{name}' if name else None)([x, inp_res])
    return out
91
92
93
# From keras-team/keras/blob/master/examples/variational_autoencoder.py
94
def sampling(args):
    """Reparameterization trick by sampling from an isotropic unit Gaussian.

    # Arguments
        args (tensor): pair `(z_mean, z_var)` holding the mean and the log of
            the variance of Q(z|X). `z_var` is treated as a log-variance:
            the standard deviation is obtained as `exp(0.5 * z_var)`.

    # Returns
        z (tensor): sampled latent vector, `z_mean + std * epsilon`
    """
    z_mean, z_var = args
    # The batch size is only known at run time, so take it from the dynamic
    # shape; the latent dimension is taken from the static shape.
    batch = K.shape(z_mean)[0]
    dim = K.int_shape(z_mean)[1]
    # by default, random_normal has mean = 0 and std = 1.0
    epsilon = K.random_normal(shape=(batch, dim))
    return z_mean + K.exp(0.5 * z_var) * epsilon
107
108
109
def dice_coefficient(y_true, y_pred):
    """
    dice_coefficient(y_true, y_pred)
    --------------------------------
    Soft dice score between the true and predicted labels, intended for use
    as a keras metric.

    The products and squares are summed over the last three axes, and the
    resulting ratio is averaged over axes 0 and 1.
    NOTE(review): this presumably expects 5D `(batch, channels, H, W, D)`
    tensors, as produced by the channels_first model in this file -- confirm
    before using it with another layout.

    Parameters
    ----------
    `y_true`: tensor, required
        The ground truth labels.
    `y_pred`: tensor, required
        The predicted labels.

    Returns
    -------
    A scalar tensor holding the mean dice score.
    """
    spatial_axes = [-3, -2, -1]
    intersection = K.sum(K.abs(y_true * y_pred), axis=spatial_axes)
    # Small constant keeps the denominator non-zero when both inputs are empty.
    dn = K.sum(K.square(y_true) + K.square(y_pred), axis=spatial_axes) + 1e-8
    return K.mean(2 * intersection / dn, axis=[0, 1])
113
114
115
def loss_gt(e=1e-8):
    """
    loss_gt(e=1e-8)
    ------------------------------------------------------
    Since keras does not allow custom loss functions to have arguments
    other than the true and predicted labels, this function acts as a wrapper
    that allows us to implement the custom loss used in the paper. This function
    only calculates - L<dice> term of the following equation. (i.e. GT Decoder part loss)

    L = - L<dice> + weight_L2 * L<L2> + weight_KL * L<KL>

    Parameters
    ----------
    `e`: Float, optional
        A small epsilon term to add in the denominator to avoid dividing by
        zero and possible gradient explosion.

    Returns
    -------
    loss_gt_(y_true, y_pred): A custom keras loss function
        This function takes as input the predicted and ground labels, uses them
        to calculate the dice loss (the negated soft dice score).
    """
    def loss_gt_(y_true, y_pred):
        # Sum over the last three axes, then average the per-entry dice
        # ratios over axes 0 and 1.
        intersection = K.sum(K.abs(y_true * y_pred), axis=[-3, -2, -1])
        dn = K.sum(K.square(y_true) + K.square(y_pred), axis=[-3, -2, -1]) + e

        # Negated because the optimizer minimizes while dice should grow.
        return - K.mean(2 * intersection / dn, axis=[0, 1])

    return loss_gt_
146
147
def loss_VAE(input_shape, z_mean, z_var, weight_L2=0.1, weight_KL=0.1):
    """
    loss_VAE(input_shape, z_mean, z_var, weight_L2=0.1, weight_KL=0.1)
    ------------------------------------------------------
    Since keras does not allow custom loss functions to have arguments
    other than the true and predicted labels, this function acts as a wrapper
    that allows us to implement the custom loss used in the paper. This function
    calculates the following equation, except for -L<dice> term. (i.e. VAE decoder part loss)

    L = - L<dice> + weight_L2 * L<L2> + weight_KL * L<KL>

    Parameters
    ----------
    `input_shape`: A 4-tuple, required
        The shape of an image as the tuple (c, H, W, D), where c is
        the no. of channels; H, W and D is the height, width and depth of the
        input image, respectively.
    `z_mean`: A keras tensor, required
        The vector representing values of mean for the learned distribution
        in the VAE part. Used internally.
    `z_var`: A keras tensor, required
        The vector representing the log of the variance for the learned
        distribution in the VAE part (the loss applies `exp` to it to get the
        variance). Used internally.
    `weight_L2`: A real number, optional
        The weight to be given to the L2 loss term in the loss function. Adjust to get best
        results for your task. Defaults to 0.1.
    `weight_KL`: A real number, optional
        The weight to be given to the KL loss term in the loss function. Adjust to get best
        results for your task. Defaults to 0.1.

    Returns
    -------
    loss_VAE_(y_true, y_pred): A custom keras loss function
        This function takes as input the predicted and ground labels, uses them
        to calculate the L2 and KL loss.
    """
    # Total no. of voxels in one input image; it normalizes the KL term and
    # is constant, so it is computed once here rather than on every call.
    c, H, W, D = input_shape
    n = c * H * W * D

    def loss_VAE_(y_true, y_pred):
        # Mean squared reconstruction error for each sample in the batch.
        loss_L2 = K.mean(K.square(y_true - y_pred), axis=(1, 2, 3, 4))

        # KL divergence term, with z_var interpreted as the log-variance.
        loss_KL = (1 / n) * K.sum(
            K.exp(z_var) + K.square(z_mean) - 1. - z_var,
            axis=-1
        )

        return weight_L2 * loss_L2 + weight_KL * loss_KL

    return loss_VAE_
198
199
def build_model(input_shape=(4, 160, 192, 128), output_channels=3, weight_L2=0.1, weight_KL=0.1, dice_e=1e-8):
    """
    build_model(input_shape=(4, 160, 192, 128), output_channels=3, weight_L2=0.1, weight_KL=0.1, dice_e=1e-8)
    -------------------------------------------
    Creates the model used in the BRATS2018 winning solution
    by Myronenko A. (https://arxiv.org/pdf/1810.11654.pdf)

    The model has a shared encoder and two decoders: the GT decoder, which
    produces the segmentation, and the VAE decoder, which reconstructs the
    input image and acts as a regularizer. The returned model is already
    compiled with the dice loss on the first output and the L2 + KL loss on
    the second.

    Parameters
    ----------
    `input_shape`: A 4-tuple, optional.
        Shape of the input image. Must be a 4D image of shape (c, H, W, D),
        where, each of H, W and D are divisible by 2^4, and c is divisible by 4.
        Defaults to the crop size used in the paper, i.e., (4, 160, 192, 128).
    `output_channels`: An integer, optional.
        The no. of channels in the output. Defaults to 3 (BraTS 2018 format).
    `weight_L2`: A real number, optional
        The weight to be given to the L2 loss term in the loss function. Adjust to get best
        results for your task. Defaults to 0.1.
    `weight_KL`: A real number, optional
        The weight to be given to the KL loss term in the loss function. Adjust to get best
        results for your task. Defaults to 0.1.
    `dice_e`: Float, optional
        A small epsilon term to add in the denominator of dice loss to avoid dividing by
        zero and possible gradient explosion. This argument will be passed to loss_gt function.

    Returns
    -------
    `model`: A keras.models.Model instance
        The created (and compiled) model, with outputs
        [segmentation, reconstruction].

    Raises
    ------
    AssertionError
        If `input_shape` is not a 4-tuple, `c` is not divisible by 4, or any
        of H, W, D is not divisible by 16.
    """
    # Check the length before unpacking; otherwise a wrong-length shape fails
    # in the unpacking itself and the message below is never shown.
    assert len(input_shape) == 4, "Input shape must be a 4-tuple"
    c, H, W, D = input_shape
    assert (c % 4) == 0, "The no. of channels must be divisible by 4"
    assert (H % 16) == 0 and (W % 16) == 0 and (D % 16) == 0, \
        "All the input dimensions must be divisible by 16"

    # -------------------------------------------------------------------------
    # Encoder
    # -------------------------------------------------------------------------

    ## Input Layer
    inp = Input(input_shape)

    ## The Initial Block
    x = Conv3D(
        filters=32,
        kernel_size=(3, 3, 3),
        strides=1,
        padding='same',
        data_format='channels_first',
        name='Input_x1')(inp)

    ## Dropout (0.2)
    x = SpatialDropout3D(0.2, data_format='channels_first')(x)

    ## Green Block x1 (output filters = 32)
    x1 = green_block(x, 32, name='x1')
    x = Conv3D(
        filters=32,
        kernel_size=(3, 3, 3),
        strides=2,
        padding='same',
        data_format='channels_first',
        name='Enc_DownSample_32')(x1)

    ## Green Block x2 (output filters = 64)
    x = green_block(x, 64, name='Enc_64_1')
    x2 = green_block(x, 64, name='x2')
    x = Conv3D(
        filters=64,
        kernel_size=(3, 3, 3),
        strides=2,
        padding='same',
        data_format='channels_first',
        name='Enc_DownSample_64')(x2)

    ## Green Blocks x2 (output filters = 128)
    x = green_block(x, 128, name='Enc_128_1')
    x3 = green_block(x, 128, name='x3')
    x = Conv3D(
        filters=128,
        kernel_size=(3, 3, 3),
        strides=2,
        padding='same',
        data_format='channels_first',
        name='Enc_DownSample_128')(x3)

    ## Green Blocks x4 (output filters = 256)
    x = green_block(x, 256, name='Enc_256_1')
    x = green_block(x, 256, name='Enc_256_2')
    x = green_block(x, 256, name='Enc_256_3')
    x4 = green_block(x, 256, name='x4')

    # -------------------------------------------------------------------------
    # Decoder
    # -------------------------------------------------------------------------

    ## GT (Ground Truth) Part
    # -------------------------------------------------------------------------

    ### Green Block x1 (output filters=128)
    x = Conv3D(
        filters=128,
        kernel_size=(1, 1, 1),
        strides=1,
        data_format='channels_first',
        name='Dec_GT_ReduceDepth_128')(x4)
    x = UpSampling3D(
        size=2,
        data_format='channels_first',
        name='Dec_GT_UpSample_128')(x)
    # Skip connection from the encoder level with the same resolution.
    x = Add(name='Input_Dec_GT_128')([x, x3])
    x = green_block(x, 128, name='Dec_GT_128')

    ### Green Block x1 (output filters=64)
    x = Conv3D(
        filters=64,
        kernel_size=(1, 1, 1),
        strides=1,
        data_format='channels_first',
        name='Dec_GT_ReduceDepth_64')(x)
    x = UpSampling3D(
        size=2,
        data_format='channels_first',
        name='Dec_GT_UpSample_64')(x)
    x = Add(name='Input_Dec_GT_64')([x, x2])
    x = green_block(x, 64, name='Dec_GT_64')

    ### Green Block x1 (output filters=32)
    x = Conv3D(
        filters=32,
        kernel_size=(1, 1, 1),
        strides=1,
        data_format='channels_first',
        name='Dec_GT_ReduceDepth_32')(x)
    x = UpSampling3D(
        size=2,
        data_format='channels_first',
        name='Dec_GT_UpSample_32')(x)
    x = Add(name='Input_Dec_GT_32')([x, x1])
    x = green_block(x, 32, name='Dec_GT_32')

    ### Blue Block x1 (output filters=32)
    x = Conv3D(
        filters=32,
        kernel_size=(3, 3, 3),
        strides=1,
        padding='same',
        data_format='channels_first',
        name='Input_Dec_GT_Output')(x)

    ### Output Block
    out_GT = Conv3D(
        filters=output_channels,  # No. of tumor classes is 3
        kernel_size=(1, 1, 1),
        strides=1,
        data_format='channels_first',
        activation='sigmoid',
        name='Dec_GT_Output')(x)

    ## VAE (Variational Auto Encoder) Part
    # -------------------------------------------------------------------------

    ### VD Block (Reducing dimensionality of the data)
    x = GroupNormalization(groups=8, axis=1, name='Dec_VAE_VD_GN')(x4)
    x = Activation('relu', name='Dec_VAE_VD_relu')(x)
    x = Conv3D(
        filters=16,
        kernel_size=(3, 3, 3),
        strides=2,
        padding='same',
        data_format='channels_first',
        name='Dec_VAE_VD_Conv3D')(x)

    # Not mentioned in the paper, but the author used a Flattening layer here.
    x = Flatten(name='Dec_VAE_VD_Flatten')(x)
    x = Dense(256, name='Dec_VAE_VD_Dense')(x)

    ### VDraw Block (Sampling)
    z_mean = Dense(128, name='Dec_VAE_VDraw_Mean')(x)
    z_var = Dense(128, name='Dec_VAE_VDraw_Var')(x)
    x = Lambda(sampling, name='Dec_VAE_VDraw_Sampling')([z_mean, z_var])

    ### VU Block (Upsizing back to a depth of 256)
    x = Dense((c//4) * (H//16) * (W//16) * (D//16))(x)
    x = Activation('relu')(x)
    x = Reshape(((c//4), (H//16), (W//16), (D//16)))(x)
    x = Conv3D(
        filters=256,
        kernel_size=(1, 1, 1),
        strides=1,
        data_format='channels_first',
        name='Dec_VAE_ReduceDepth_256')(x)
    x = UpSampling3D(
        size=2,
        data_format='channels_first',
        name='Dec_VAE_UpSample_256')(x)

    ### Green Block x1 (output filters=128)
    x = Conv3D(
        filters=128,
        kernel_size=(1, 1, 1),
        strides=1,
        data_format='channels_first',
        name='Dec_VAE_ReduceDepth_128')(x)
    x = UpSampling3D(
        size=2,
        data_format='channels_first',
        name='Dec_VAE_UpSample_128')(x)
    x = green_block(x, 128, name='Dec_VAE_128')

    ### Green Block x1 (output filters=64)
    x = Conv3D(
        filters=64,
        kernel_size=(1, 1, 1),
        strides=1,
        data_format='channels_first',
        name='Dec_VAE_ReduceDepth_64')(x)
    x = UpSampling3D(
        size=2,
        data_format='channels_first',
        name='Dec_VAE_UpSample_64')(x)
    x = green_block(x, 64, name='Dec_VAE_64')

    ### Green Block x1 (output filters=32)
    x = Conv3D(
        filters=32,
        kernel_size=(1, 1, 1),
        strides=1,
        data_format='channels_first',
        name='Dec_VAE_ReduceDepth_32')(x)
    x = UpSampling3D(
        size=2,
        data_format='channels_first',
        name='Dec_VAE_UpSample_32')(x)
    x = green_block(x, 32, name='Dec_VAE_32')

    ### Blue Block x1 (output filters=32)
    x = Conv3D(
        filters=32,
        kernel_size=(3, 3, 3),
        strides=1,
        padding='same',
        data_format='channels_first',
        name='Input_Dec_VAE_Output')(x)

    ### Output Block
    # The VAE branch reconstructs the input image, so it must emit as many
    # channels as the input has (c); a fixed 4 only works for the default.
    out_VAE = Conv3D(
        filters=c,
        kernel_size=(1, 1, 1),
        strides=1,
        data_format='channels_first',
        name='Dec_VAE_Output')(x)

    # Build and Compile the model
    model = Model(inp, outputs=[out_GT, out_VAE])  # Create the model
    # One loss per output, in the same order as `outputs`.
    model.compile(
        adam(lr=1e-4),
        [loss_gt(dice_e), loss_VAE(input_shape, z_mean, z_var, weight_L2=weight_L2, weight_KL=weight_KL)],
        metrics=[dice_coefficient]
    )

    return model