--- a +++ b/bc-count/model.py @@ -0,0 +1,294 @@ +############################################## +# # +# DO-U-Net # +# and # +# DO-SegNet # +# # +# Author: Amine Neggazi # +# Email: neggazimedlamine@gmail/com # +# Nick: nemo256 # +# # +# Please read bc-count/LICENSE # +# # +############################################## + +import tensorflow as tf +import tensorflow_addons as tfa + +# custom imports +from config import * + + +def conv_bn(filters, + model, + model_type, + kernel=(3, 3), + activation='relu', + strides=(1, 1), + padding='valid', + type='normal'): + ''' + This is a custom convolution function: + :param filters --> number of filters for each convolution + :param kernel --> the kernel size + :param activation --> the general activation function (relu) + :param strides --> number of strides + :param padding --> model padding (can be valid or same) + :param type --> to indicate if it is a transpose or normal convolution + + :return --> returns the output after the convolution and batch normalization and activation. + ''' + if model_type == 'segnet': + kernel=3 + activation='relu' + strides=(1, 1) + padding='same' + type='normal' + + if type == 'transpose': + kernel = (2, 2) + strides = 2 + conv = tf.keras.layers.Conv2DTranspose(filters, kernel, strides, padding)(model) + else: + conv = tf.keras.layers.Conv2D(filters, kernel, strides, padding)(model) + + conv = tf.keras.layers.BatchNormalization()(conv) + conv = tf.keras.layers.Activation(activation)(conv) + + return conv + + +def max_pool(input): + ''' + This is a general max pool function with custom parameters. + ''' + return tf.keras.layers.MaxPooling2D(pool_size=(2, 2), strides=2)(input) + + +def concatenate(input1, input2, crop): + ''' + This is a general concatenation function with custom parameters. + ''' + return tf.keras.layers.concatenate([tf.keras.layers.Cropping2D(crop)(input1), input2]) + + +def get_callbacks(name): + ''' + This is a custom function to save only the best checkpoint. + :param name --> the input model name + ''' + return [ + tf.keras.callbacks.ModelCheckpoint(f'models/{name}.h5', + save_best_only=True, + save_weights_only=True, + verbose=1) + ] + + +# loss functions +@tf.function +def dsc(y_true, y_pred): + smooth = 1.0 + y_true_f = tf.reshape(y_true, [-1]) + y_pred_f = tf.reshape(y_pred, [-1]) + intersection = tf.reduce_sum(y_true_f * y_pred_f) + return (2.0 * intersection + smooth) / (tf.reduce_sum(y_true_f) + + tf.reduce_sum(y_pred_f) + + smooth) + + +@tf.function +def dice_loss(y_true, y_pred): + return 1 - dsc(y_true, y_pred) + + +@tf.function +def tversky(y_true, y_pred): + alpha = 0.7 + smooth = 1.0 + y_true_pos = tf.reshape(y_true, [-1]) + y_pred_pos = tf.reshape(y_pred, [-1]) + true_pos = tf.reduce_sum(y_true_pos * y_pred_pos) + false_neg = tf.reduce_sum(y_true_pos * (1 - y_pred_pos)) + false_pos = tf.reduce_sum((1 - y_true_pos) * y_pred_pos) + return (true_pos + smooth) / (true_pos + alpha * false_neg + (1 - alpha) * false_pos + smooth) + + +@tf.function +def tversky_loss(y_true, y_pred): + return 1 - tversky(y_true, y_pred) + + +@tf.function +def focal_tversky(y_true, y_pred): + return tf.pow((1 - tversky(y_true, y_pred)), 0.75) + + +@tf.function +def iou(y_true, y_pred): + intersect = tf.reduce_sum(y_true * y_pred, axis=(1, 2)) + union = tf.reduce_sum(y_true + y_pred, axis=(1, 2)) + return tf.reduce_mean(tf.math.divide_no_nan(intersect, (union - intersect)), axis=1) + + +@tf.function +def mean_iou(y_true, y_pred): + y_true_32 = tf.cast(y_true, tf.float32) + y_pred_32 = tf.cast(y_pred, tf.float32) + score = tf.map_fn(lambda x: iou(y_true_32, tf.cast(y_pred_32 > x, tf.float32)), + tf.range(0.5, 1.0, 0.05, tf.float32), + tf.float32) + return tf.reduce_mean(score) + + +@tf.function +def iou_loss(y_true, y_pred): + return -1*mean_iou(y_true, y_pred) + + +def do_unet(): + ''' + This is the dual output U-Net model. + It is a custom U-Net with optimized number of layers. + Please read model.summary() + ''' + inputs = tf.keras.layers.Input((188, 188, 3)) + + # encoder + filters = 32 + encoder1 = conv_bn(3*filters, inputs, model_type) + encoder1 = conv_bn(filters, encoder1, model_type, kernel=(1, 1)) + encoder1 = conv_bn(filters, encoder1, model_type) + pool1 = max_pool(encoder1) + + filters *= 2 + encoder2 = conv_bn(filters, pool1, model_type) + encoder2 = conv_bn(filters, encoder2, model_type) + pool2 = max_pool(encoder2) + + filters *= 2 + encoder3 = conv_bn(filters, pool2, model_type) + encoder3 = conv_bn(filters, encoder3, model_type) + pool3 = max_pool(encoder3) + + filters *= 2 + encoder4 = conv_bn(filters, pool3, model_type) + encoder4 = conv_bn(filters, encoder4, model_type) + + # decoder + filters /= 2 + decoder1 = conv_bn(filters, encoder4, model_type, type='transpose') + decoder1 = concatenate(encoder3, decoder1, 4) + decoder1 = conv_bn(filters, decoder1, model_type) + decoder1 = conv_bn(filters, decoder1, model_type) + + filters /= 2 + decoder2 = conv_bn(filters, decoder1, model_type, type='transpose') + decoder2 = concatenate(encoder2, decoder2, 16) + decoder2 = conv_bn(filters, decoder2, model_type) + decoder2 = conv_bn(filters, decoder2, model_type) + + filters /= 2 + decoder3 = conv_bn(filters, decoder2, model_type, type='transpose') + decoder3 = concatenate(encoder1, decoder3, 40) + decoder3 = conv_bn(filters, decoder3, model_type) + decoder3 = conv_bn(filters, decoder3, model_type) + + out_mask = tf.keras.layers.Conv2D(1, (1, 1), activation='sigmoid', name='mask')(decoder3) + + if cell_type == 'rbc': + out_edge = tf.keras.layers.Conv2D(1, (1, 1), activation='sigmoid', name='edge')(decoder3) + model = tf.keras.models.Model(inputs=inputs, outputs=(out_mask, out_edge)) + elif cell_type == 'wbc' or cell_type == 'plt': + model = tf.keras.models.Model(inputs=inputs, outputs=(out_mask)) + + opt = tf.optimizers.Adam(learning_rate=0.0001) + + if cell_type == 'rbc': + model.compile(loss='mse', + loss_weights=[0.1, 0.9], + optimizer=opt, + metrics=['accuracy']) + elif cell_type == 'wbc' or cell_type == 'plt': + model.compile(loss='mse', + optimizer=opt, + metrics='accuracy') + return model + +def segnet(): + inputs = tf.keras.layers.Input((128, 128, 3)) + + # encoder + filters = 64 + encoder1 = conv_bn(filters, inputs, model_type) + encoder1 = conv_bn(filters, encoder1, model_type) + pool1, mask1 = tf.nn.max_pool_with_argmax(encoder1, 3, 2, padding="SAME") + + filters *= 2 + encoder2 = conv_bn(filters, pool1, model_type) + encoder2 = conv_bn(filters, encoder2, model_type) + pool2, mask2 = tf.nn.max_pool_with_argmax(encoder2, 3, 2, padding="SAME") + + filters *= 2 + encoder3 = conv_bn(filters, pool2, model_type) + encoder3 = conv_bn(filters, encoder3, model_type) + encoder3 = conv_bn(filters, encoder3, model_type) + pool3, mask3 = tf.nn.max_pool_with_argmax(encoder3, 3, 2, padding="SAME") + + filters *= 2 + encoder4 = conv_bn(filters, pool3, model_type) + encoder4 = conv_bn(filters, encoder4, model_type) + encoder4 = conv_bn(filters, encoder4, model_type) + pool4, mask4 = tf.nn.max_pool_with_argmax(encoder4, 3, 2, padding="SAME") + + encoder5 = conv_bn(filters, pool4, model_type) + encoder5 = conv_bn(filters, encoder5, model_type) + encoder5 = conv_bn(filters, encoder5, model_type) + pool5, mask5 = tf.nn.max_pool_with_argmax(encoder5, 3, 2, padding="SAME") + + # decoder + unpool1 = tfa.layers.MaxUnpooling2D()(pool5, mask5) + decoder1 = conv_bn(filters, unpool1, model_type) + decoder1 = conv_bn(filters, decoder1, model_type) + decoder1 = conv_bn(filters, decoder1, model_type) + + unpool2 = tfa.layers.MaxUnpooling2D()(decoder1, mask4) + decoder2 = conv_bn(filters, unpool2, model_type) + decoder2 = conv_bn(filters, decoder2, model_type) + decoder2 = conv_bn(filters/2, decoder2, model_type) + + filters /= 2 + unpool3 = tfa.layers.MaxUnpooling2D()(decoder2, mask3) + decoder3 = conv_bn(filters, unpool3, model_type) + decoder3 = conv_bn(filters, decoder3, model_type) + decoder3 = conv_bn(filters/2, decoder3, model_type) + + filters /= 2 + unpool4 = tfa.layers.MaxUnpooling2D()(decoder3, mask2) + decoder4 = conv_bn(filters, unpool4, model_type) + decoder4 = conv_bn(filters/2, decoder4, model_type) + + filters /= 2 + unpool5 = tfa.layers.MaxUnpooling2D()(decoder4, mask1) + decoder5 = conv_bn(filters, unpool5, model_type) + + out_mask = tf.keras.layers.Conv2D(1, (1, 1), activation='sigmoid', name='mask')(decoder5) + + if cell_type == 'rbc': + out_edge = tf.keras.layers.Conv2D(1, (1, 1), activation='sigmoid', name='edge')(decoder5) + model = tf.keras.models.Model(inputs=inputs, outputs=(out_mask, out_edge)) + elif cell_type == 'wbc' or cell_type == 'plt': + model = tf.keras.models.Model(inputs=inputs, outputs=(out_mask)) + + opt = tf.optimizers.Adam(learning_rate=0.0001) + + if cell_type == 'rbc': + model.compile(loss='mse', + loss_weights=[0.1, 0.9], + optimizer=opt, + metrics=[mean_iou, dsc, tversky, 'accuracy']) + elif cell_type == 'wbc' or cell_type == 'plt': + model.compile(loss='mse', + optimizer=opt, + metrics=[mean_iou, dsc, tversky, 'accuracy']) + return model