Diff of /bc-count/model.py [000000] .. [0be6a8]

Switch to side-by-side view

--- a
+++ b/bc-count/model.py
@@ -0,0 +1,294 @@
+##############################################
+#                                            #
+#                 DO-U-Net                   #
+#                   and                      #
+#                 DO-SegNet                  #
+#                                            #
+# Author: Amine Neggazi                      #
+# Email: neggazimedlamine@gmail/com          #
+# Nick: nemo256                              #
+#                                            #
+# Please read bc-count/LICENSE               #
+#                                            #
+##############################################
+
+import tensorflow as tf
+import tensorflow_addons as tfa
+
+# custom imports
+from config import *
+
+
+def conv_bn(filters,
+            model,
+            model_type,
+            kernel=(3, 3),
+            activation='relu', 
+            strides=(1, 1),
+            padding='valid',
+            type='normal'):
+    '''
+    This is a custom convolution function:
+    :param filters --> number of filters for each convolution
+    :param kernel --> the kernel size
+    :param activation --> the general activation function (relu)
+    :param strides --> number of strides
+    :param padding --> model padding (can be valid or same)
+    :param type --> to indicate if it is a transpose or normal convolution
+
+    :return --> returns the output after the convolution and batch normalization and activation.
+    '''
+    if model_type == 'segnet':
+        kernel=3
+        activation='relu'
+        strides=(1, 1)
+        padding='same'
+        type='normal'
+
+    if type == 'transpose':
+        kernel = (2, 2)
+        strides = 2
+        conv = tf.keras.layers.Conv2DTranspose(filters, kernel, strides, padding)(model)
+    else:
+        conv = tf.keras.layers.Conv2D(filters, kernel, strides, padding)(model)
+
+    conv = tf.keras.layers.BatchNormalization()(conv)
+    conv = tf.keras.layers.Activation(activation)(conv)
+
+    return conv
+
+
+def max_pool(input):
+    '''
+    This is a general max pool function with custom parameters.
+    '''
+    return tf.keras.layers.MaxPooling2D(pool_size=(2, 2), strides=2)(input)
+
+
+def concatenate(input1, input2, crop):
+    '''
+    This is a general concatenation function with custom parameters.
+    '''
+    return tf.keras.layers.concatenate([tf.keras.layers.Cropping2D(crop)(input1), input2])
+
+
+def get_callbacks(name):
+    '''
+    This is a custom function to save only the best checkpoint.
+    :param name --> the input model name
+    '''
+    return [
+        tf.keras.callbacks.ModelCheckpoint(f'models/{name}.h5',
+                                           save_best_only=True,
+                                           save_weights_only=True,
+                                           verbose=1)
+    ]
+
+
+# loss functions
+@tf.function
+def dsc(y_true, y_pred):
+    smooth = 1.0
+    y_true_f = tf.reshape(y_true, [-1])
+    y_pred_f = tf.reshape(y_pred, [-1])
+    intersection = tf.reduce_sum(y_true_f * y_pred_f)
+    return (2.0 * intersection + smooth) / (tf.reduce_sum(y_true_f) +
+                                            tf.reduce_sum(y_pred_f) +
+                                            smooth)
+
+
+@tf.function
+def dice_loss(y_true, y_pred):
+    return 1 - dsc(y_true, y_pred)
+
+
+@tf.function
+def tversky(y_true, y_pred):
+    alpha = 0.7
+    smooth = 1.0
+    y_true_pos = tf.reshape(y_true, [-1])
+    y_pred_pos = tf.reshape(y_pred, [-1])
+    true_pos = tf.reduce_sum(y_true_pos * y_pred_pos)
+    false_neg = tf.reduce_sum(y_true_pos * (1 - y_pred_pos))
+    false_pos = tf.reduce_sum((1 - y_true_pos) * y_pred_pos)
+    return (true_pos + smooth) / (true_pos + alpha * false_neg + (1 - alpha) * false_pos + smooth)
+
+
+@tf.function
+def tversky_loss(y_true, y_pred):
+    return 1 - tversky(y_true, y_pred)
+
+
+@tf.function
+def focal_tversky(y_true, y_pred):
+    return tf.pow((1 - tversky(y_true, y_pred)), 0.75)
+
+
+@tf.function
+def iou(y_true, y_pred):
+    intersect = tf.reduce_sum(y_true * y_pred, axis=(1, 2))
+    union = tf.reduce_sum(y_true + y_pred, axis=(1, 2))
+    return tf.reduce_mean(tf.math.divide_no_nan(intersect, (union - intersect)), axis=1)
+
+
+@tf.function
+def mean_iou(y_true, y_pred):
+    y_true_32 = tf.cast(y_true, tf.float32)
+    y_pred_32 = tf.cast(y_pred, tf.float32)
+    score = tf.map_fn(lambda x: iou(y_true_32, tf.cast(y_pred_32 > x, tf.float32)),
+                      tf.range(0.5, 1.0, 0.05, tf.float32),
+                      tf.float32)
+    return tf.reduce_mean(score)
+
+
+@tf.function
+def iou_loss(y_true, y_pred):
+    return -1*mean_iou(y_true, y_pred)
+
+
+def do_unet():
+    '''
+    This is the dual output U-Net model.
+    It is a custom U-Net with optimized number of layers.
+    Please read model.summary()
+    '''
+    inputs = tf.keras.layers.Input((188, 188, 3))
+
+    # encoder
+    filters = 32
+    encoder1 = conv_bn(3*filters, inputs, model_type)
+    encoder1 = conv_bn(filters, encoder1, model_type, kernel=(1, 1))
+    encoder1 = conv_bn(filters, encoder1, model_type)
+    pool1 = max_pool(encoder1)
+
+    filters *= 2
+    encoder2 = conv_bn(filters, pool1, model_type)
+    encoder2 = conv_bn(filters, encoder2, model_type)
+    pool2 = max_pool(encoder2)
+
+    filters *= 2
+    encoder3 = conv_bn(filters, pool2, model_type)
+    encoder3 = conv_bn(filters, encoder3, model_type)
+    pool3 = max_pool(encoder3)
+
+    filters *= 2
+    encoder4 = conv_bn(filters, pool3, model_type)
+    encoder4 = conv_bn(filters, encoder4, model_type)
+
+    # decoder
+    filters /= 2
+    decoder1 = conv_bn(filters, encoder4, model_type, type='transpose')
+    decoder1 = concatenate(encoder3, decoder1, 4)
+    decoder1 = conv_bn(filters, decoder1, model_type)
+    decoder1 = conv_bn(filters, decoder1, model_type)
+
+    filters /= 2
+    decoder2 = conv_bn(filters, decoder1, model_type, type='transpose')
+    decoder2 = concatenate(encoder2, decoder2, 16)
+    decoder2 = conv_bn(filters, decoder2, model_type)
+    decoder2 = conv_bn(filters, decoder2, model_type)
+
+    filters /= 2
+    decoder3 = conv_bn(filters, decoder2, model_type, type='transpose')
+    decoder3 = concatenate(encoder1, decoder3, 40)
+    decoder3 = conv_bn(filters, decoder3, model_type)
+    decoder3 = conv_bn(filters, decoder3, model_type)
+
+    out_mask = tf.keras.layers.Conv2D(1, (1, 1), activation='sigmoid', name='mask')(decoder3)
+
+    if cell_type == 'rbc':
+        out_edge = tf.keras.layers.Conv2D(1, (1, 1), activation='sigmoid', name='edge')(decoder3)
+        model = tf.keras.models.Model(inputs=inputs, outputs=(out_mask, out_edge))
+    elif cell_type == 'wbc' or cell_type == 'plt':
+        model = tf.keras.models.Model(inputs=inputs, outputs=(out_mask))
+
+    opt = tf.optimizers.Adam(learning_rate=0.0001)
+
+    if cell_type == 'rbc':
+        model.compile(loss='mse',
+                      loss_weights=[0.1, 0.9],
+                      optimizer=opt,
+                      metrics=['accuracy'])
+    elif cell_type == 'wbc' or cell_type == 'plt':
+        model.compile(loss='mse',
+                      optimizer=opt,
+                      metrics='accuracy')
+    return model
+
+def segnet():
+    inputs = tf.keras.layers.Input((128, 128, 3))
+
+    # encoder
+    filters = 64
+    encoder1 = conv_bn(filters, inputs, model_type)
+    encoder1 = conv_bn(filters, encoder1, model_type)
+    pool1, mask1 = tf.nn.max_pool_with_argmax(encoder1, 3, 2, padding="SAME")
+
+    filters *= 2
+    encoder2 = conv_bn(filters, pool1, model_type)
+    encoder2 = conv_bn(filters, encoder2, model_type)
+    pool2, mask2 = tf.nn.max_pool_with_argmax(encoder2, 3, 2, padding="SAME")
+
+    filters *= 2
+    encoder3 = conv_bn(filters, pool2, model_type)
+    encoder3 = conv_bn(filters, encoder3, model_type)
+    encoder3 = conv_bn(filters, encoder3, model_type)
+    pool3, mask3 = tf.nn.max_pool_with_argmax(encoder3, 3, 2, padding="SAME")
+
+    filters *= 2
+    encoder4 = conv_bn(filters, pool3, model_type)
+    encoder4 = conv_bn(filters, encoder4, model_type)
+    encoder4 = conv_bn(filters, encoder4, model_type)
+    pool4, mask4 = tf.nn.max_pool_with_argmax(encoder4, 3, 2, padding="SAME")
+
+    encoder5 = conv_bn(filters, pool4, model_type)
+    encoder5 = conv_bn(filters, encoder5, model_type)
+    encoder5 = conv_bn(filters, encoder5, model_type)
+    pool5, mask5 = tf.nn.max_pool_with_argmax(encoder5, 3, 2, padding="SAME")
+
+    # decoder
+    unpool1 = tfa.layers.MaxUnpooling2D()(pool5, mask5)
+    decoder1 = conv_bn(filters, unpool1, model_type)
+    decoder1 = conv_bn(filters, decoder1, model_type)
+    decoder1 = conv_bn(filters, decoder1, model_type)
+
+    unpool2 = tfa.layers.MaxUnpooling2D()(decoder1, mask4)
+    decoder2 = conv_bn(filters, unpool2, model_type)
+    decoder2 = conv_bn(filters, decoder2, model_type)
+    decoder2 = conv_bn(filters/2, decoder2, model_type)
+
+    filters /= 2
+    unpool3 = tfa.layers.MaxUnpooling2D()(decoder2, mask3)
+    decoder3 = conv_bn(filters, unpool3, model_type)
+    decoder3 = conv_bn(filters, decoder3, model_type)
+    decoder3 = conv_bn(filters/2, decoder3, model_type)
+
+    filters /= 2
+    unpool4 = tfa.layers.MaxUnpooling2D()(decoder3, mask2)
+    decoder4 = conv_bn(filters, unpool4, model_type)
+    decoder4 = conv_bn(filters/2, decoder4, model_type)
+
+    filters /= 2
+    unpool5 = tfa.layers.MaxUnpooling2D()(decoder4, mask1)
+    decoder5 = conv_bn(filters, unpool5, model_type)
+
+    out_mask = tf.keras.layers.Conv2D(1, (1, 1), activation='sigmoid', name='mask')(decoder5)
+
+    if cell_type == 'rbc':
+        out_edge = tf.keras.layers.Conv2D(1, (1, 1), activation='sigmoid', name='edge')(decoder5)
+        model = tf.keras.models.Model(inputs=inputs, outputs=(out_mask, out_edge))
+    elif cell_type == 'wbc' or cell_type == 'plt':
+        model = tf.keras.models.Model(inputs=inputs, outputs=(out_mask))
+
+    opt = tf.optimizers.Adam(learning_rate=0.0001)
+
+    if cell_type == 'rbc':
+        model.compile(loss='mse',
+                      loss_weights=[0.1, 0.9],
+                      optimizer=opt,
+                      metrics=[mean_iou, dsc, tversky, 'accuracy'])
+    elif cell_type == 'wbc' or cell_type == 'plt':
+        model.compile(loss='mse',
+                      optimizer=opt,
+                      metrics=[mean_iou, dsc, tversky, 'accuracy'])
+    return model