Diff of /model.py [000000] .. [fcace9]

Switch to side-by-side view

--- a
+++ b/model.py
@@ -0,0 +1,464 @@
+# Keras implementation of the paper:
+# 3D MRI Brain Tumor Segmentation Using Autoencoder Regularization
+# by Myronenko A. (https://arxiv.org/pdf/1810.11654.pdf)
+# Author of this code: Suyog Jadhav (https://github.com/IAmSUyogJadhav)
+
+import keras.backend as K
+from keras.losses import mse
+from keras.layers import Conv3D, Activation, Add, UpSampling3D, Lambda, Dense
+from keras.layers import Input, Reshape, Flatten, Dropout, SpatialDropout3D
+from keras.optimizers import adam
+from keras.models import Model
# Import GroupNormalization: prefer a local copy; otherwise fetch the
# reference implementation once into the current working directory.
try:
    from group_norm import GroupNormalization
except ImportError:
    import urllib.request
    print('Downloading group_norm.py in the current directory...')
    url = 'https://raw.githubusercontent.com/titu1994/Keras-Group-Normalization/master/group_norm.py'
    urllib.request.urlretrieve(url, "group_norm.py")
    # NOTE(review): importing code downloaded at import time is a
    # supply-chain risk — consider vendoring group_norm.py instead.
    from group_norm import GroupNormalization
+
+
def green_block(inp, filters, data_format='channels_first', name=None):
    """
    green_block(inp, filters, name=None)
    ------------------------------------
    Implementation of the special residual block used in the paper. The block
    consists of two (GroupNorm --> ReLu --> 3x3x3 non-strided Convolution)
    units, with a residual connection from the input `inp` to the output. Used
    internally in the model. Can be used independently as well.

    Parameters
    ----------
    `inp`: An keras.layers.layer instance, required
        The keras layer just preceding the green block.
    `filters`: integer, required
        No. of filters to use in the 3D convolutional block. The output
        layer of this green block will have this many no. of channels.
    `data_format`: string, optional
        The format of the input data. Must be either 'channels_first' or
        'channels_last'. Defaults to `channels_first`, as used in the paper.
    `name`: string, optional
        The name to be given to this green block. Defaults to None, in which
        case, keras uses generated names for the involved layers. If a string
        is provided, the names of individual layers are generated by attaching
        a relevant prefix from [GroupNorm_, Res_, Conv3D_, Relu_, ], followed
        by _1 or _2.

    Returns
    -------
    `out`: A keras.layers.Layer instance
        The output of the green block. Has no. of channels equal to `filters`.
        The size of the rest of the dimensions remains same as in `inp`.
    """
    # Channel axis for GroupNormalization: 1 for channels_first, -1 for
    # channels_last. (The original used `else 0`, which is the batch axis
    # and would normalize across the batch for channels_last inputs.)
    channel_axis = 1 if data_format == 'channels_first' else -1

    # 1x1x1 convolution on the skip path so the residual tensor has
    # `filters` channels and can be added to the main path's output.
    inp_res = Conv3D(
        filters=filters,
        kernel_size=(1, 1, 1),
        strides=1,
        data_format=data_format,
        name=f'Res_{name}' if name else None)(inp)

    # First (GroupNorm -> ReLU -> 3x3x3 Conv) unit.
    # No. of groups = 8, as given in the paper.
    x = GroupNormalization(
        groups=8,
        axis=channel_axis,
        name=f'GroupNorm_1_{name}' if name else None)(inp)
    x = Activation('relu', name=f'Relu_1_{name}' if name else None)(x)
    x = Conv3D(
        filters=filters,
        kernel_size=(3, 3, 3),
        strides=1,
        padding='same',
        data_format=data_format,
        name=f'Conv3D_1_{name}' if name else None)(x)

    # Second (GroupNorm -> ReLU -> 3x3x3 Conv) unit.
    x = GroupNormalization(
        groups=8,
        axis=channel_axis,
        name=f'GroupNorm_2_{name}' if name else None)(x)
    x = Activation('relu', name=f'Relu_2_{name}' if name else None)(x)
    x = Conv3D(
        filters=filters,
        kernel_size=(3, 3, 3),
        strides=1,
        padding='same',
        data_format=data_format,
        name=f'Conv3D_2_{name}' if name else None)(x)

    # Residual connection: main path + projected input.
    out = Add(name=f'Out_{name}' if name else None)([x, inp_res])
    return out
+
+
# From keras-team/keras/blob/master/examples/variational_autoencoder.py
def sampling(args):
    """Reparameterization trick by sampling from an isotropic unit Gaussian.
    # Arguments
        args (tensor): mean and log of variance of Q(z|X)
    # Returns
        z (tensor): sampled latent vector
    """
    z_mean, z_var = args
    # Draw epsilon ~ N(0, I) with the same (batch, latent_dim) shape as
    # z_mean; random_normal defaults to mean=0, std=1.
    batch_size = K.shape(z_mean)[0]
    latent_dim = K.int_shape(z_mean)[1]
    epsilon = K.random_normal(shape=(batch_size, latent_dim))
    # Shift and scale the noise: z = mu + sigma * eps, with sigma = exp(var/2).
    return z_mean + K.exp(0.5 * z_var) * epsilon
+
+
def dice_coefficient(y_true, y_pred):
    """Soft dice coefficient over the three spatial axes, averaged over
    batch and channel axes. A small epsilon guards against division by zero."""
    spatial_axes = [-3, -2, -1]
    overlap = K.sum(K.abs(y_true * y_pred), axis=spatial_axes)
    denom = K.sum(K.square(y_true) + K.square(y_pred), axis=spatial_axes) + 1e-8
    return K.mean(2 * overlap / denom, axis=[0, 1])
+
+
def loss_gt(e=1e-8):
    """
    loss_gt(e=1e-8)
    ------------------------------------------------------
    Keras loss functions only receive the true and predicted labels, so this
    wrapper closes over the epsilon argument and returns the actual loss.
    It computes only the - L<dice> term of the paper's total loss
    (i.e. the GT decoder part):

    L = - L<dice> + weight_L2 * L<L2> + weight_KL * L<KL>

    Parameters
    ----------
    `e`: Float, optional
        A small epsilon term to add in the denominator to avoid dividing by
        zero and possible gradient explosion.

    Returns
    -------
    loss_gt_(y_true, y_pred): A custom keras loss function
        Takes the ground-truth and predicted labels and returns the negated
        soft dice coefficient.
    """
    def loss_gt_(y_true, y_pred):
        spatial_axes = [-3, -2, -1]
        overlap = K.sum(K.abs(y_true * y_pred), axis=spatial_axes)
        denom = K.sum(K.square(y_true) + K.square(y_pred), axis=spatial_axes) + e
        # Negated dice so that minimizing the loss maximizes the overlap.
        return -K.mean(2 * overlap / denom, axis=[0, 1])

    return loss_gt_
+
def loss_VAE(input_shape, z_mean, z_var, weight_L2=0.1, weight_KL=0.1):
    """
    loss_VAE(input_shape, z_mean, z_var, weight_L2=0.1, weight_KL=0.1)
    ------------------------------------------------------
    Keras loss functions only receive the true and predicted labels, so this
    wrapper closes over the extra arguments and returns the actual loss.
    It computes every term of the paper's total loss except - L<dice>
    (i.e. the VAE decoder part):

    L = - L<dice> + weight_L2 * L<L2> + weight_KL * L<KL>

    Parameters
    ----------
     `input_shape`: A 4-tuple, required
        The shape of an image as the tuple (c, H, W, D), where c is
        the no. of channels; H, W and D is the height, width and depth of the
        input image, respectively.
    `z_mean`: An keras.layers.Layer instance, required
        The vector representing values of mean for the learned distribution
        in the VAE part. Used internally.
    `z_var`: An keras.layers.Layer instance, required
        The vector representing values of variance for the learned distribution
        in the VAE part. Used internally.
    `weight_L2`: A real number, optional
        The weight to be given to the L2 loss term in the loss function. Adjust to get best
        results for your task. Defaults to 0.1.
    `weight_KL`: A real number, optional
        The weight to be given to the KL loss term in the loss function. Adjust to get best
        results for your task. Defaults to 0.1.

    Returns
    -------
    loss_VAE_(y_true, y_pred): A custom keras loss function
        Takes the ground-truth and reconstructed images and returns the
        weighted sum of the L2 and KL losses.
    """
    def loss_VAE_(y_true, y_pred):
        c, H, W, D = input_shape
        # Total no. of voxels (times channels) used to normalize the KL term,
        # as in the paper's (1/N) factor.
        n = c * H * W * D

        # Mean squared reconstruction error over channel + spatial axes.
        # original axis value is (1,2,3,4).
        l2_term = K.mean(K.square(y_true - y_pred), axis=(1, 2, 3, 4))

        # KL divergence of N(z_mean, exp(z_var)) from the unit Gaussian,
        # in the paper's closed form: (1/N) * sum(mu^2 + sigma^2 - log sigma^2 - 1).
        kl_term = (1 / n) * K.sum(
            K.exp(z_var) + K.square(z_mean) - 1. - z_var,
            axis=-1
        )

        return weight_L2 * l2_term + weight_KL * kl_term

    return loss_VAE_
+
def build_model(input_shape=(4, 160, 192, 128), output_channels=3, weight_L2=0.1, weight_KL=0.1, dice_e=1e-8):
    """
    build_model(input_shape=(4, 160, 192, 128), output_channels=3, weight_L2=0.1, weight_KL=0.1)
    -------------------------------------------
    Creates the model used in the BRATS2018 winning solution
    by Myronenko A. (https://arxiv.org/pdf/1810.11654.pdf)

    Parameters
    ----------
    `input_shape`: A 4-tuple, optional.
        Shape of the input image. Must be a 4D image of shape (c, H, W, D),
        where, each of H, W and D are divisible by 2^4, and c is divisible by 4.
        Defaults to the crop size used in the paper, i.e., (4, 160, 192, 128).
    `output_channels`: An integer, optional.
        The no. of channels in the output. Defaults to 3 (BraTS 2018 format).
    `weight_L2`: A real number, optional
        The weight to be given to the L2 loss term in the loss function. Adjust to get best
        results for your task. Defaults to 0.1.
    `weight_KL`: A real number, optional
        The weight to be given to the KL loss term in the loss function. Adjust to get best
        results for your task. Defaults to 0.1.
    `dice_e`: Float, optional
        A small epsilon term to add in the denominator of dice loss to avoid dividing by
        zero and possible gradient explosion. This argument will be passed to loss_gt function.

    Returns
    -------
    `model`: A keras.models.Model instance
        The created model.
    """
    # Validate the tuple length BEFORE unpacking it. The original unpacked
    # first, so a wrong-length input raised a bare ValueError instead of the
    # intended assertion message.
    assert len(input_shape) == 4, "Input shape must be a 4-tuple"
    c, H, W, D = input_shape
    assert (c % 4) == 0, "The no. of channels must be divisible by 4"
    assert (H % 16) == 0 and (W % 16) == 0 and (D % 16) == 0, \
        "All the input dimensions must be divisible by 16"


    # -------------------------------------------------------------------------
    # Encoder
    # -------------------------------------------------------------------------

    ## Input Layer
    inp = Input(input_shape)

    ## The Initial Block
    x = Conv3D(
        filters=32,
        kernel_size=(3, 3, 3),
        strides=1,
        padding='same',
        data_format='channels_first',
        name='Input_x1')(inp)

    ## Dropout (0.2)
    x = SpatialDropout3D(0.2, data_format='channels_first')(x)

    ## Green Block x1 (output filters = 32)
    x1 = green_block(x, 32, name='x1')
    x = Conv3D(
        filters=32,
        kernel_size=(3, 3, 3),
        strides=2,
        padding='same',
        data_format='channels_first',
        name='Enc_DownSample_32')(x1)

    ## Green Block x2 (output filters = 64)
    x = green_block(x, 64, name='Enc_64_1')
    x2 = green_block(x, 64, name='x2')
    x = Conv3D(
        filters=64,
        kernel_size=(3, 3, 3),
        strides=2,
        padding='same',
        data_format='channels_first',
        name='Enc_DownSample_64')(x2)

    ## Green Blocks x2 (output filters = 128)
    x = green_block(x, 128, name='Enc_128_1')
    x3 = green_block(x, 128, name='x3')
    x = Conv3D(
        filters=128,
        kernel_size=(3, 3, 3),
        strides=2,
        padding='same',
        data_format='channels_first',
        name='Enc_DownSample_128')(x3)

    ## Green Blocks x4 (output filters = 256)
    x = green_block(x, 256, name='Enc_256_1')
    x = green_block(x, 256, name='Enc_256_2')
    x = green_block(x, 256, name='Enc_256_3')
    x4 = green_block(x, 256, name='x4')

    # -------------------------------------------------------------------------
    # Decoder
    # -------------------------------------------------------------------------

    ## GT (Ground Truth) Part: upsample back to full resolution, adding the
    ## encoder features (x3, x2, x1) at each scale, U-Net style.
    # -------------------------------------------------------------------------

    ### Green Block x1 (output filters=128)
    x = Conv3D(
        filters=128,
        kernel_size=(1, 1, 1),
        strides=1,
        data_format='channels_first',
        name='Dec_GT_ReduceDepth_128')(x4)
    x = UpSampling3D(
        size=2,
        data_format='channels_first',
        name='Dec_GT_UpSample_128')(x)
    x = Add(name='Input_Dec_GT_128')([x, x3])
    x = green_block(x, 128, name='Dec_GT_128')

    ### Green Block x1 (output filters=64)
    x = Conv3D(
        filters=64,
        kernel_size=(1, 1, 1),
        strides=1,
        data_format='channels_first',
        name='Dec_GT_ReduceDepth_64')(x)
    x = UpSampling3D(
        size=2,
        data_format='channels_first',
        name='Dec_GT_UpSample_64')(x)
    x = Add(name='Input_Dec_GT_64')([x, x2])
    x = green_block(x, 64, name='Dec_GT_64')

    ### Green Block x1 (output filters=32)
    x = Conv3D(
        filters=32,
        kernel_size=(1, 1, 1),
        strides=1,
        data_format='channels_first',
        name='Dec_GT_ReduceDepth_32')(x)
    x = UpSampling3D(
        size=2,
        data_format='channels_first',
        name='Dec_GT_UpSample_32')(x)
    x = Add(name='Input_Dec_GT_32')([x, x1])
    x = green_block(x, 32, name='Dec_GT_32')

    ### Blue Block x1 (output filters=32)
    x = Conv3D(
        filters=32,
        kernel_size=(3, 3, 3),
        strides=1,
        padding='same',
        data_format='channels_first',
        name='Input_Dec_GT_Output')(x)

    ### Output Block
    out_GT = Conv3D(
        filters=output_channels,  # No. of tumor classes is 3
        kernel_size=(1, 1, 1),
        strides=1,
        data_format='channels_first',
        activation='sigmoid',
        name='Dec_GT_Output')(x)

    ## VAE (Variational Auto Encoder) Part: regularizes the shared encoder by
    ## reconstructing the input image from a sampled latent vector.
    # -------------------------------------------------------------------------

    ### VD Block (Reducing dimensionality of the data)
    x = GroupNormalization(groups=8, axis=1, name='Dec_VAE_VD_GN')(x4)
    x = Activation('relu', name='Dec_VAE_VD_relu')(x)
    x = Conv3D(
        filters=16,
        kernel_size=(3, 3, 3),
        strides=2,
        padding='same',
        data_format='channels_first',
        name='Dec_VAE_VD_Conv3D')(x)

    # Not mentioned in the paper, but the author used a Flattening layer here.
    x = Flatten(name='Dec_VAE_VD_Flatten')(x)
    x = Dense(256, name='Dec_VAE_VD_Dense')(x)

    ### VDraw Block (Sampling)
    z_mean = Dense(128, name='Dec_VAE_VDraw_Mean')(x)
    z_var = Dense(128, name='Dec_VAE_VDraw_Var')(x)
    x = Lambda(sampling, name='Dec_VAE_VDraw_Sampling')([z_mean, z_var])

    ### VU Block (Upsizing back to a depth of 256)
    x = Dense((c//4) * (H//16) * (W//16) * (D//16))(x)
    x = Activation('relu')(x)
    x = Reshape(((c//4), (H//16), (W//16), (D//16)))(x)
    x = Conv3D(
        filters=256,
        kernel_size=(1, 1, 1),
        strides=1,
        data_format='channels_first',
        name='Dec_VAE_ReduceDepth_256')(x)
    x = UpSampling3D(
        size=2,
        data_format='channels_first',
        name='Dec_VAE_UpSample_256')(x)

    ### Green Block x1 (output filters=128)
    x = Conv3D(
        filters=128,
        kernel_size=(1, 1, 1),
        strides=1,
        data_format='channels_first',
        name='Dec_VAE_ReduceDepth_128')(x)
    x = UpSampling3D(
        size=2,
        data_format='channels_first',
        name='Dec_VAE_UpSample_128')(x)
    x = green_block(x, 128, name='Dec_VAE_128')

    ### Green Block x1 (output filters=64)
    x = Conv3D(
        filters=64,
        kernel_size=(1, 1, 1),
        strides=1,
        data_format='channels_first',
        name='Dec_VAE_ReduceDepth_64')(x)
    x = UpSampling3D(
        size=2,
        data_format='channels_first',
        name='Dec_VAE_UpSample_64')(x)
    x = green_block(x, 64, name='Dec_VAE_64')

    ### Green Block x1 (output filters=32)
    x = Conv3D(
        filters=32,
        kernel_size=(1, 1, 1),
        strides=1,
        data_format='channels_first',
        name='Dec_VAE_ReduceDepth_32')(x)
    x = UpSampling3D(
        size=2,
        data_format='channels_first',
        name='Dec_VAE_UpSample_32')(x)
    x = green_block(x, 32, name='Dec_VAE_32')

    ### Blue Block x1 (output filters=32)
    x = Conv3D(
        filters=32,
        kernel_size=(3, 3, 3),
        strides=1,
        padding='same',
        data_format='channels_first',
        name='Input_Dec_VAE_Output')(x)

    ### Output Block
    # The VAE branch reconstructs the input image, so its no. of output
    # channels must equal the input's channel count `c`. (The original
    # hard-coded 4, which only matched the default input_shape.)
    out_VAE = Conv3D(
        filters=c,
        kernel_size=(1, 1, 1),
        strides=1,
        data_format='channels_first',
        name='Dec_VAE_Output')(x)

    # Build and Compile the model. The GT head is trained with the dice loss;
    # the VAE head with the weighted L2 + KL loss.
    model = Model(inp, outputs=[out_GT, out_VAE])  # Create the model
    model.compile(
        adam(lr=1e-4),
        [loss_gt(dice_e), loss_VAE(input_shape, z_mean, z_var, weight_L2=weight_L2, weight_KL=weight_KL)],
        metrics=[dice_coefficient]
    )

    return model