--- a/model.py
+++ b/model.py
@@ -0,0 +1,464 @@
+# Keras implementation of the paper:
+# 3D MRI Brain Tumor Segmentation Using Autoencoder Regularization
+# by Myronenko A. (https://arxiv.org/pdf/1810.11654.pdf)
+# Author of this code: Suyog Jadhav (https://github.com/IAmSUyogJadhav)
+
+import keras.backend as K
+from keras.losses import mse
+from keras.layers import Conv3D, Activation, Add, UpSampling3D, Lambda, Dense
+from keras.layers import Input, Reshape, Flatten, Dropout, SpatialDropout3D
+from keras.optimizers import Adam
+from keras.models import Model
+try:
+    from group_norm import GroupNormalization
+except ImportError:
+    import urllib.request
+    print('Downloading group_norm.py in the current directory...')
+    url = 'https://raw.githubusercontent.com/titu1994/Keras-Group-Normalization/master/group_norm.py'
+    urllib.request.urlretrieve(url, "group_norm.py")
+    from group_norm import GroupNormalization
+
+
+def green_block(inp, filters, data_format='channels_first', name=None):
+    """
+    green_block(inp, filters, name=None)
+    ------------------------------------
+    Implementation of the special residual block used in the paper. The block
+    consists of two (GroupNorm --> ReLU --> 3x3x3 non-strided Convolution)
+    units, with a residual connection from the input `inp` to the output. Used
+    internally in the model. Can be used independently as well.
+
+    Parameters
+    ----------
+    `inp`: A keras.layers.Layer instance, required
+        The keras layer just preceding the green block.
+    `filters`: integer, required
+        No. of filters to use in the 3D convolutional block. The output of
+        this green block will have this many channels.
+    `data_format`: string, optional
+        The format of the input data. Must be either 'channels_first' or
+        'channels_last'. Defaults to `channels_first`, as used in the paper.
+    `name`: string, optional
+        The name to be given to this green block. Defaults to None, in which
+        case keras uses auto-generated names for the involved layers. If a
+        string is provided, the names of individual layers are generated by
+        attaching a relevant prefix from [GroupNorm_, Res_, Conv3D_, Relu_],
+        followed by _1 or _2.
+
+    Returns
+    -------
+    `out`: A keras.layers.Layer instance
+        The output of the green block. Has no. of channels equal to `filters`.
+        The size of the rest of the dimensions remains the same as in `inp`.
+    """
+    inp_res = Conv3D(
+        filters=filters,
+        kernel_size=(1, 1, 1),
+        strides=1,
+        data_format=data_format,
+        name=f'Res_{name}' if name else None)(inp)
+
+    # Normalize over the channels axis (1 for channels_first, -1 for channels_last)
+    # No. of groups = 8, as given in the paper
+    x = GroupNormalization(
+        groups=8,
+        axis=1 if data_format == 'channels_first' else -1,
+        name=f'GroupNorm_1_{name}' if name else None)(inp)
+    x = Activation('relu', name=f'Relu_1_{name}' if name else None)(x)
+    x = Conv3D(
+        filters=filters,
+        kernel_size=(3, 3, 3),
+        strides=1,
+        padding='same',
+        data_format=data_format,
+        name=f'Conv3D_1_{name}' if name else None)(x)
+
+    x = GroupNormalization(
+        groups=8,
+        axis=1 if data_format == 'channels_first' else -1,
+        name=f'GroupNorm_2_{name}' if name else None)(x)
+    x = Activation('relu', name=f'Relu_2_{name}' if name else None)(x)
+    x = Conv3D(
+        filters=filters,
+        kernel_size=(3, 3, 3),
+        strides=1,
+        padding='same',
+        data_format=data_format,
+        name=f'Conv3D_2_{name}' if name else None)(x)
+
+    out = Add(name=f'Out_{name}' if name else None)([x, inp_res])
+    return out
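+
+
+# Example (not part of the original file): a quick, hypothetical shape check of
+# `green_block` on a dummy channels_first input. Note that the no. of input
+# channels must be divisible by the 8 groups used by GroupNormalization:
+#
+#     x = Input((32, 16, 16, 16))           # (c, H, W, D)
+#     y = green_block(x, 16, name='demo')   # shape: (None, 16, 16, 16, 16)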
+
+
+# From keras-team/keras/blob/master/examples/variational_autoencoder.py
+def sampling(args):
+    """Reparameterization trick by sampling from an isotropic unit Gaussian.
+    # Arguments
+        args (tensor): mean and log of variance of Q(z|X)
+    # Returns
+        z (tensor): sampled latent vector
+    """
+    z_mean, z_var = args
+    batch = K.shape(z_mean)[0]
+    dim = K.int_shape(z_mean)[1]
+    # by default, random_normal has mean = 0 and std = 1.0
+    epsilon = K.random_normal(shape=(batch, dim))
+    return z_mean + K.exp(0.5 * z_var) * epsilon
+
+
+def dice_coefficient(y_true, y_pred):
+    intersection = K.sum(K.abs(y_true * y_pred), axis=[-3, -2, -1])
+    dn = K.sum(K.square(y_true) + K.square(y_pred), axis=[-3, -2, -1]) + 1e-8
+    return K.mean(2 * intersection / dn, axis=[0, 1])
+
+
+def loss_gt(e=1e-8):
+    """
+    loss_gt(e=1e-8)
+    ------------------------------------------------------
+    Since keras does not allow custom loss functions to have arguments
+    other than the true and predicted labels, this function acts as a wrapper
+    that allows us to implement the custom loss used in the paper. This function
+    only calculates the -L<dice> term of the following equation (i.e. the GT
+    decoder part of the loss).
+
+    L = - L<dice> + weight_L2 ∗ L<L2> + weight_KL ∗ L<KL>
+
+    Parameters
+    ----------
+    `e`: Float, optional
+        A small epsilon term to add in the denominator to avoid dividing by
+        zero and possible gradient explosion.
+
+    Returns
+    -------
+    loss_gt_(y_true, y_pred): A custom keras loss function
+        This function takes the ground truth and predicted labels and uses them
+        to calculate the dice loss.
+    """
+    def loss_gt_(y_true, y_pred):
+        intersection = K.sum(K.abs(y_true * y_pred), axis=[-3, -2, -1])
+        dn = K.sum(K.square(y_true) + K.square(y_pred), axis=[-3, -2, -1]) + e
+
+        return - K.mean(2 * intersection / dn, axis=[0, 1])
+
+    return loss_gt_
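+
+
+# Example (not part of the original file): the dice term can be sanity-checked
+# on tiny tensors; two identical masks should give a loss close to -1, e.g.:
+#
+#     y = K.ones((1, 3, 8, 8, 8))
+#     K.eval(loss_gt()(y, y))   # ~ -1.0 (perfect overlap)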
+
+
+def loss_VAE(input_shape, z_mean, z_var, weight_L2=0.1, weight_KL=0.1):
+    """
+    loss_VAE(input_shape, z_mean, z_var, weight_L2=0.1, weight_KL=0.1)
+    ------------------------------------------------------
+    Since keras does not allow custom loss functions to have arguments
+    other than the true and predicted labels, this function acts as a wrapper
+    that allows us to implement the custom loss used in the paper. This function
+    calculates the following equation, except for the -L<dice> term (i.e. the
+    VAE decoder part of the loss).
+
+    L = - L<dice> + weight_L2 ∗ L<L2> + weight_KL ∗ L<KL>
+
+    Parameters
+    ----------
+    `input_shape`: A 4-tuple, required
+        The shape of an image as the tuple (c, H, W, D), where c is
+        the no. of channels; H, W and D are the height, width and depth of the
+        input image, respectively.
+    `z_mean`: A keras.layers.Layer instance, required
+        The vector representing values of mean for the learned distribution
+        in the VAE part. Used internally.
+    `z_var`: A keras.layers.Layer instance, required
+        The vector representing values of log variance for the learned
+        distribution in the VAE part. Used internally.
+    `weight_L2`: A real number, optional
+        The weight to be given to the L2 loss term in the loss function. Adjust to get best
+        results for your task. Defaults to 0.1.
+    `weight_KL`: A real number, optional
+        The weight to be given to the KL loss term in the loss function. Adjust to get best
+        results for your task. Defaults to 0.1.
+
+    Returns
+    -------
+    loss_VAE_(y_true, y_pred): A custom keras loss function
+        This function takes as input the predicted reconstruction and the
+        original input image, and uses them to calculate the L2 and KL loss.
+    """
+    def loss_VAE_(y_true, y_pred):
+        c, H, W, D = input_shape
+        n = c * H * W * D
+
+        # Per-sample mean squared reconstruction error over all voxels
+        loss_L2 = K.mean(K.square(y_true - y_pred), axis=(1, 2, 3, 4))
+
+        # KL divergence between N(z_mean, exp(z_var)) and N(0, I), as given in the paper
+        loss_KL = (1 / n) * K.sum(
+            K.exp(z_var) + K.square(z_mean) - 1. - z_var,
+            axis=-1
+        )
+
+        return weight_L2 * loss_L2 + weight_KL * loss_KL
+
+    return loss_VAE_
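+
+
+# Example (not part of the original file): like `loss_gt`, `loss_VAE` returns a
+# closure for keras. It needs the model's `z_mean`/`z_var` layers, so it can only
+# be created once the model is wired up, as done in `build_model` below, e.g.:
+#
+#     vae_loss = loss_VAE((4, 160, 192, 128), z_mean, z_var, weight_L2=0.1, weight_KL=0.1)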
+
+
+def build_model(input_shape=(4, 160, 192, 128), output_channels=3, weight_L2=0.1, weight_KL=0.1, dice_e=1e-8):
+    """
+    build_model(input_shape=(4, 160, 192, 128), output_channels=3, weight_L2=0.1, weight_KL=0.1, dice_e=1e-8)
+    -------------------------------------------
+    Creates the model used in the BraTS 2018 winning solution
+    by Myronenko A. (https://arxiv.org/pdf/1810.11654.pdf)
+
+    Parameters
+    ----------
+    `input_shape`: A 4-tuple, optional.
+        Shape of the input image. Must be a 4D image of shape (c, H, W, D),
+        where each of H, W and D is divisible by 2^4 (= 16), and c is divisible by 4.
+        Defaults to the crop size used in the paper, i.e., (4, 160, 192, 128).
+    `output_channels`: An integer, optional.
+        The no. of channels in the output. Defaults to 3 (BraTS 2018 format).
+    `weight_L2`: A real number, optional
+        The weight to be given to the L2 loss term in the loss function. Adjust to get best
+        results for your task. Defaults to 0.1.
+    `weight_KL`: A real number, optional
+        The weight to be given to the KL loss term in the loss function. Adjust to get best
+        results for your task. Defaults to 0.1.
+    `dice_e`: Float, optional
+        A small epsilon term to add in the denominator of the dice loss to avoid
+        dividing by zero and possible gradient explosion. This argument is passed
+        to the loss_gt function.
+
+    Returns
+    -------
+    `model`: A keras.models.Model instance
+        The created model.
+    """
+    assert len(input_shape) == 4, "Input shape must be a 4-tuple"
+    c, H, W, D = input_shape
+    assert (c % 4) == 0, "The no. of channels must be divisible by 4"
+    assert (H % 16) == 0 and (W % 16) == 0 and (D % 16) == 0, \
+        "All the input dimensions must be divisible by 16"
+
+    # -------------------------------------------------------------------------
+    # Encoder
+    # -------------------------------------------------------------------------
+
+    ## Input Layer
+    inp = Input(input_shape)
+
+    ## The Initial Block
+    x = Conv3D(
+        filters=32,
+        kernel_size=(3, 3, 3),
+        strides=1,
+        padding='same',
+        data_format='channels_first',
+        name='Input_x1')(inp)
+
+    ## Dropout (0.2)
+    x = SpatialDropout3D(0.2, data_format='channels_first')(x)
+
+    ## Green Block x1 (output filters = 32)
+    x1 = green_block(x, 32, name='x1')
+    x = Conv3D(
+        filters=32,
+        kernel_size=(3, 3, 3),
+        strides=2,
+        padding='same',
+        data_format='channels_first',
+        name='Enc_DownSample_32')(x1)
+
+    ## Green Block x2 (output filters = 64)
+    x = green_block(x, 64, name='Enc_64_1')
+    x2 = green_block(x, 64, name='x2')
+    x = Conv3D(
+        filters=64,
+        kernel_size=(3, 3, 3),
+        strides=2,
+        padding='same',
+        data_format='channels_first',
+        name='Enc_DownSample_64')(x2)
+
+    ## Green Blocks x2 (output filters = 128)
+    x = green_block(x, 128, name='Enc_128_1')
+    x3 = green_block(x, 128, name='x3')
+    x = Conv3D(
+        filters=128,
+        kernel_size=(3, 3, 3),
+        strides=2,
+        padding='same',
+        data_format='channels_first',
+        name='Enc_DownSample_128')(x3)
+
+    ## Green Blocks x4 (output filters = 256)
+    x = green_block(x, 256, name='Enc_256_1')
+    x = green_block(x, 256, name='Enc_256_2')
+    x = green_block(x, 256, name='Enc_256_3')
+    x4 = green_block(x, 256, name='x4')
+
+    # -------------------------------------------------------------------------
+    # Decoder
+    # -------------------------------------------------------------------------
+
+    ## GT (Ground Truth) Part
+    # -------------------------------------------------------------------------
+
+    ### Green Block x1 (output filters=128)
+    x = Conv3D(
+        filters=128,
+        kernel_size=(1, 1, 1),
+        strides=1,
+        data_format='channels_first',
+        name='Dec_GT_ReduceDepth_128')(x4)
+    x = UpSampling3D(
+        size=2,
+        data_format='channels_first',
+        name='Dec_GT_UpSample_128')(x)
+    x = Add(name='Input_Dec_GT_128')([x, x3])
+    x = green_block(x, 128, name='Dec_GT_128')
+
+    ### Green Block x1 (output filters=64)
+    x = Conv3D(
+        filters=64,
+        kernel_size=(1, 1, 1),
+        strides=1,
+        data_format='channels_first',
+        name='Dec_GT_ReduceDepth_64')(x)
+    x = UpSampling3D(
+        size=2,
+        data_format='channels_first',
+        name='Dec_GT_UpSample_64')(x)
+    x = Add(name='Input_Dec_GT_64')([x, x2])
+    x = green_block(x, 64, name='Dec_GT_64')
+
+    ### Green Block x1 (output filters=32)
+    x = Conv3D(
+        filters=32,
+        kernel_size=(1, 1, 1),
+        strides=1,
+        data_format='channels_first',
+        name='Dec_GT_ReduceDepth_32')(x)
+    x = UpSampling3D(
+        size=2,
+        data_format='channels_first',
+        name='Dec_GT_UpSample_32')(x)
+    x = Add(name='Input_Dec_GT_32')([x, x1])
+    x = green_block(x, 32, name='Dec_GT_32')
+
+    ### Blue Block x1 (output filters=32)
+    x = Conv3D(
+        filters=32,
+        kernel_size=(3, 3, 3),
+        strides=1,
+        padding='same',
+        data_format='channels_first',
+        name='Input_Dec_GT_Output')(x)
+
+    ### Output Block
+    out_GT = Conv3D(
+        filters=output_channels,  # No. of tumor classes is 3
+        kernel_size=(1, 1, 1),
+        strides=1,
+        data_format='channels_first',
+        activation='sigmoid',
+        name='Dec_GT_Output')(x)
+
+    ## VAE (Variational Auto Encoder) Part
+    # -------------------------------------------------------------------------
+
+    ### VD Block (Reducing dimensionality of the data)
+    x = GroupNormalization(groups=8, axis=1, name='Dec_VAE_VD_GN')(x4)
+    x = Activation('relu', name='Dec_VAE_VD_relu')(x)
+    x = Conv3D(
+        filters=16,
+        kernel_size=(3, 3, 3),
+        strides=2,
+        padding='same',
+        data_format='channels_first',
+        name='Dec_VAE_VD_Conv3D')(x)
+
+    # Not mentioned in the paper, but the author used a Flattening layer here.
+    x = Flatten(name='Dec_VAE_VD_Flatten')(x)
+    x = Dense(256, name='Dec_VAE_VD_Dense')(x)
+
+    ### VDraw Block (Sampling)
+    z_mean = Dense(128, name='Dec_VAE_VDraw_Mean')(x)
+    z_var = Dense(128, name='Dec_VAE_VDraw_Var')(x)
+    x = Lambda(sampling, name='Dec_VAE_VDraw_Sampling')([z_mean, z_var])
+
+    ### VU Block (Upsizing back to a depth of 256)
+    x = Dense((c//4) * (H//16) * (W//16) * (D//16))(x)
+    x = Activation('relu')(x)
+    x = Reshape(((c//4), (H//16), (W//16), (D//16)))(x)
+    x = Conv3D(
+        filters=256,
+        kernel_size=(1, 1, 1),
+        strides=1,
+        data_format='channels_first',
+        name='Dec_VAE_ReduceDepth_256')(x)
+    x = UpSampling3D(
+        size=2,
+        data_format='channels_first',
+        name='Dec_VAE_UpSample_256')(x)
+
+    ### Green Block x1 (output filters=128)
+    x = Conv3D(
+        filters=128,
+        kernel_size=(1, 1, 1),
+        strides=1,
+        data_format='channels_first',
+        name='Dec_VAE_ReduceDepth_128')(x)
+    x = UpSampling3D(
+        size=2,
+        data_format='channels_first',
+        name='Dec_VAE_UpSample_128')(x)
+    x = green_block(x, 128, name='Dec_VAE_128')
+
+    ### Green Block x1 (output filters=64)
+    x = Conv3D(
+        filters=64,
+        kernel_size=(1, 1, 1),
+        strides=1,
+        data_format='channels_first',
+        name='Dec_VAE_ReduceDepth_64')(x)
+    x = UpSampling3D(
+        size=2,
+        data_format='channels_first',
+        name='Dec_VAE_UpSample_64')(x)
+    x = green_block(x, 64, name='Dec_VAE_64')
+
+    ### Green Block x1 (output filters=32)
+    x = Conv3D(
+        filters=32,
+        kernel_size=(1, 1, 1),
+        strides=1,
+        data_format='channels_first',
+        name='Dec_VAE_ReduceDepth_32')(x)
+    x = UpSampling3D(
+        size=2,
+        data_format='channels_first',
+        name='Dec_VAE_UpSample_32')(x)
+    x = green_block(x, 32, name='Dec_VAE_32')
+
+    ### Blue Block x1 (output filters=32)
+    x = Conv3D(
+        filters=32,
+        kernel_size=(3, 3, 3),
+        strides=1,
+        padding='same',
+        data_format='channels_first',
+        name='Input_Dec_VAE_Output')(x)
+
+    ### Output Block
+    out_VAE = Conv3D(
+        filters=c,  # reconstruct all input channels (4 modalities for BraTS)
+        kernel_size=(1, 1, 1),
+        strides=1,
+        data_format='channels_first',
+        name='Dec_VAE_Output')(x)
+
+    # Build and Compile the model
+    model = Model(inp, outputs=[out_GT, out_VAE])  # Create the model
+    model.compile(
+        optimizer=Adam(lr=1e-4),
+        loss=[loss_gt(dice_e), loss_VAE(input_shape, z_mean, z_var, weight_L2=weight_L2, weight_KL=weight_KL)],
+        metrics=[dice_coefficient]
+    )
+
+    return model
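+
+
+# Minimal usage sketch (not part of the original file): builds the model on a
+# small, hypothetical crop and fits it on random arrays purely as a shape/smoke
+# check. The VAE branch is trained to reconstruct the input, so the input volume
+# itself is passed as the second target. Note that channels_first 3D convolutions
+# may require a GPU backend.
+if __name__ == '__main__':
+    import numpy as np
+
+    crop = (4, 32, 32, 32)  # (c, H, W, D), all spatial dims divisible by 16
+    model = build_model(input_shape=crop, output_channels=3)
+
+    x_dummy = np.random.rand(1, *crop).astype('float32')                          # input MRI volume
+    y_dummy = np.random.randint(0, 2, size=(1, 3, 32, 32, 32)).astype('float32')  # GT masks
+    model.fit(x_dummy, [y_dummy, x_dummy], epochs=1, batch_size=1)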