--- a +++ b/model.py @@ -0,0 +1,170 @@ +import tensorflow.keras.backend as K +from tensorflow.keras.models import Model +from tensorflow.keras import Input +from tensorflow.keras.layers import Conv2D, PReLU, UpSampling2D, concatenate , Reshape, Dense, Permute, MaxPool2D +from tensorflow.keras.layers import GlobalAveragePooling2D, Activation, add, GaussianNoise, BatchNormalization, multiply +from tensorflow.keras.optimizers import SGD +from loss import custom_loss +K.set_image_data_format("channels_last") + + + +def unet_model(input_shape, modified_unet=True, learning_rate=0.01, start_channel=64, + number_of_levels=3, inc_rate=2, output_channels=4, saved_model_dir=None): + """ + Builds UNet model + + Parameters + ---------- + input_shape : tuple + Shape of the input data (height, width, channel) + modified_unet : bool + Whether to use modified UNet or the original UNet + learning_rate : float + Learning rate for the model. The default is 0.01. + start_channel : int + Number of channels of the first conv. The default is 64. + number_of_levels : int + The depth size of the U-structure. The default is 3. + inc_rate : int + Rate at which the conv channels will increase. The default is 2. + output_channels : int + The number of output layer channels. The default is 4 + saved_model_dir : str + If spesified, the model weights will be loaded from this path. The default is None. + + Returns + ------- + model : keras.model + The created keras model with respect to the input parameters + + """ + + + input_layer = Input(shape=input_shape, name='the_input_layer') + + if modified_unet: + x = GaussianNoise(0.01, name='Gaussian_Noise')(input_layer) + x = Conv2D(64, 2, padding='same')(x) + x = level_block_modified(x, start_channel, number_of_levels, inc_rate) + x = BatchNormalization(axis = -1)(x) + x = PReLU(shared_axes=[1, 2])(x) + else: + x = level_block(input_layer, start_channel, number_of_levels, inc_rate) + + x = Conv2D(output_channels, 1, padding='same')(x) + output_layer = Activation('softmax')(x) + + model = Model(inputs = input_layer, outputs = output_layer) + + if modified_unet: + print("The modified UNet was built!") + else: + print("The original UNet was built!") + + if saved_model_dir: + model.load_weights(saved_model_dir) + print("the model weights were successfully loaded!") + + sgd = SGD(lr=learning_rate, momentum=0.9, decay=0) + model.compile(optimizer=sgd, loss=custom_loss) + + return model + + +def se_block(x, ratio=16): + + """ + creates a squeeze and excitation block + https://arxiv.org/abs/1709.01507 + + Parameters + ---------- + x : tensor + Input keras tensor + ratio : int + The reduction ratio. The default is 16. + + Returns + ------- + x : tensor + A keras tensor + """ + + + channel_axis = 1 if K.image_data_format() == "channels_first" else -1 + filters = x.shape[channel_axis] + se_shape = (1, 1, filters) + + se = GlobalAveragePooling2D()(x) + se = Reshape(se_shape)(se) + se = Dense(filters // ratio, activation='relu', kernel_initializer='he_normal', use_bias=False)(se) + se = Dense(filters, activation='sigmoid', kernel_initializer='he_normal', use_bias=False)(se) + + if K.image_data_format() == 'channels_first': + se = Permute((3, 1, 2))(se) + + x = multiply([x, se]) + return x + + +def level_block(x, dim, level, inc): + + if level > 0: + m = conv_layers(x, dim) + x = MaxPool2D(pool_size=(2, 2))(m) + x = level_block(x,int(inc*dim), level-1, inc) + x = UpSampling2D(size=(2, 2))(x) + x = Conv2D(dim, 2, padding='same')(x) + m = concatenate([m,x]) + x = conv_layers(m, dim) + else: + x = conv_layers(x, dim) + return x + + +def level_block_modified(x, dim, level, inc): + + if level > 0: + m = res_block(x, dim, encoder_path=True)########## + x = Conv2D(int(inc*dim), 2, strides=2, padding='same')(m) + x = level_block_modified(x, int(inc*dim), level-1, inc) + + x = UpSampling2D(size=(2, 2))(x) + x = Conv2D(dim, 2, padding='same')(x) + + m = concatenate([m,x]) + m = se_block(m, 8) + x = res_block(m, dim, encoder_path=False) + else: + x = res_block(x, dim, encoder_path=True) ############# + return x + + +def conv_layers(x, dim): + + x = Conv2D(dim, 3, padding='same')(x) + x = Activation("relu")(x) + + x = Conv2D(dim, 3, padding='same')(x) + x = Activation("relu")(x) + + return x + +def res_block(x, dim, encoder_path=True): + + m = BatchNormalization(axis = -1)(x) + m = PReLU(shared_axes = [1, 2])(m) + m = Conv2D(dim, 3, padding='same')(m) + + m = BatchNormalization(axis = -1)(m) + m = PReLU(shared_axes = [1, 2])(m) + m = Conv2D(dim, 3, padding='same')(m) + + if encoder_path: + x = add([x, m]) + else: + x = Conv2D(dim, 1, padding='same', use_bias=False)(x) + x = add([x,m]) + return x +