Diff of /bc-count/model.py [000000] .. [0be6a8]

Switch to unified view

a b/bc-count/model.py
1
##############################################
2
#                                            #
3
#                 DO-U-Net                   #
4
#                   and                      #
5
#                 DO-SegNet                  #
6
#                                            #
7
# Author: Amine Neggazi                      #
8
# Email: neggazimedlamine@gmail/com          #
9
# Nick: nemo256                              #
10
#                                            #
11
# Please read bc-count/LICENSE               #
12
#                                            #
13
##############################################
14
15
import tensorflow as tf
16
import tensorflow_addons as tfa
17
18
# custom imports
19
from config import *
20
21
22
def conv_bn(filters,
23
            model,
24
            model_type,
25
            kernel=(3, 3),
26
            activation='relu', 
27
            strides=(1, 1),
28
            padding='valid',
29
            type='normal'):
30
    '''
31
    This is a custom convolution function:
32
    :param filters --> number of filters for each convolution
33
    :param kernel --> the kernel size
34
    :param activation --> the general activation function (relu)
35
    :param strides --> number of strides
36
    :param padding --> model padding (can be valid or same)
37
    :param type --> to indicate if it is a transpose or normal convolution
38
39
    :return --> returns the output after the convolution and batch normalization and activation.
40
    '''
41
    if model_type == 'segnet':
42
        kernel=3
43
        activation='relu'
44
        strides=(1, 1)
45
        padding='same'
46
        type='normal'
47
48
    if type == 'transpose':
49
        kernel = (2, 2)
50
        strides = 2
51
        conv = tf.keras.layers.Conv2DTranspose(filters, kernel, strides, padding)(model)
52
    else:
53
        conv = tf.keras.layers.Conv2D(filters, kernel, strides, padding)(model)
54
55
    conv = tf.keras.layers.BatchNormalization()(conv)
56
    conv = tf.keras.layers.Activation(activation)(conv)
57
58
    return conv
59
60
61
def max_pool(input):
62
    '''
63
    This is a general max pool function with custom parameters.
64
    '''
65
    return tf.keras.layers.MaxPooling2D(pool_size=(2, 2), strides=2)(input)
66
67
68
def concatenate(input1, input2, crop):
69
    '''
70
    This is a general concatenation function with custom parameters.
71
    '''
72
    return tf.keras.layers.concatenate([tf.keras.layers.Cropping2D(crop)(input1), input2])
73
74
75
def get_callbacks(name):
76
    '''
77
    This is a custom function to save only the best checkpoint.
78
    :param name --> the input model name
79
    '''
80
    return [
81
        tf.keras.callbacks.ModelCheckpoint(f'models/{name}.h5',
82
                                           save_best_only=True,
83
                                           save_weights_only=True,
84
                                           verbose=1)
85
    ]
86
87
88
# loss functions
89
@tf.function
90
def dsc(y_true, y_pred):
91
    smooth = 1.0
92
    y_true_f = tf.reshape(y_true, [-1])
93
    y_pred_f = tf.reshape(y_pred, [-1])
94
    intersection = tf.reduce_sum(y_true_f * y_pred_f)
95
    return (2.0 * intersection + smooth) / (tf.reduce_sum(y_true_f) +
96
                                            tf.reduce_sum(y_pred_f) +
97
                                            smooth)
98
99
100
@tf.function
101
def dice_loss(y_true, y_pred):
102
    return 1 - dsc(y_true, y_pred)
103
104
105
@tf.function
106
def tversky(y_true, y_pred):
107
    alpha = 0.7
108
    smooth = 1.0
109
    y_true_pos = tf.reshape(y_true, [-1])
110
    y_pred_pos = tf.reshape(y_pred, [-1])
111
    true_pos = tf.reduce_sum(y_true_pos * y_pred_pos)
112
    false_neg = tf.reduce_sum(y_true_pos * (1 - y_pred_pos))
113
    false_pos = tf.reduce_sum((1 - y_true_pos) * y_pred_pos)
114
    return (true_pos + smooth) / (true_pos + alpha * false_neg + (1 - alpha) * false_pos + smooth)
115
116
117
@tf.function
118
def tversky_loss(y_true, y_pred):
119
    return 1 - tversky(y_true, y_pred)
120
121
122
@tf.function
123
def focal_tversky(y_true, y_pred):
124
    return tf.pow((1 - tversky(y_true, y_pred)), 0.75)
125
126
127
@tf.function
128
def iou(y_true, y_pred):
129
    intersect = tf.reduce_sum(y_true * y_pred, axis=(1, 2))
130
    union = tf.reduce_sum(y_true + y_pred, axis=(1, 2))
131
    return tf.reduce_mean(tf.math.divide_no_nan(intersect, (union - intersect)), axis=1)
132
133
134
@tf.function
135
def mean_iou(y_true, y_pred):
136
    y_true_32 = tf.cast(y_true, tf.float32)
137
    y_pred_32 = tf.cast(y_pred, tf.float32)
138
    score = tf.map_fn(lambda x: iou(y_true_32, tf.cast(y_pred_32 > x, tf.float32)),
139
                      tf.range(0.5, 1.0, 0.05, tf.float32),
140
                      tf.float32)
141
    return tf.reduce_mean(score)
142
143
144
@tf.function
145
def iou_loss(y_true, y_pred):
146
    return -1*mean_iou(y_true, y_pred)
147
148
149
def do_unet():
150
    '''
151
    This is the dual output U-Net model.
152
    It is a custom U-Net with optimized number of layers.
153
    Please read model.summary()
154
    '''
155
    inputs = tf.keras.layers.Input((188, 188, 3))
156
157
    # encoder
158
    filters = 32
159
    encoder1 = conv_bn(3*filters, inputs, model_type)
160
    encoder1 = conv_bn(filters, encoder1, model_type, kernel=(1, 1))
161
    encoder1 = conv_bn(filters, encoder1, model_type)
162
    pool1 = max_pool(encoder1)
163
164
    filters *= 2
165
    encoder2 = conv_bn(filters, pool1, model_type)
166
    encoder2 = conv_bn(filters, encoder2, model_type)
167
    pool2 = max_pool(encoder2)
168
169
    filters *= 2
170
    encoder3 = conv_bn(filters, pool2, model_type)
171
    encoder3 = conv_bn(filters, encoder3, model_type)
172
    pool3 = max_pool(encoder3)
173
174
    filters *= 2
175
    encoder4 = conv_bn(filters, pool3, model_type)
176
    encoder4 = conv_bn(filters, encoder4, model_type)
177
178
    # decoder
179
    filters /= 2
180
    decoder1 = conv_bn(filters, encoder4, model_type, type='transpose')
181
    decoder1 = concatenate(encoder3, decoder1, 4)
182
    decoder1 = conv_bn(filters, decoder1, model_type)
183
    decoder1 = conv_bn(filters, decoder1, model_type)
184
185
    filters /= 2
186
    decoder2 = conv_bn(filters, decoder1, model_type, type='transpose')
187
    decoder2 = concatenate(encoder2, decoder2, 16)
188
    decoder2 = conv_bn(filters, decoder2, model_type)
189
    decoder2 = conv_bn(filters, decoder2, model_type)
190
191
    filters /= 2
192
    decoder3 = conv_bn(filters, decoder2, model_type, type='transpose')
193
    decoder3 = concatenate(encoder1, decoder3, 40)
194
    decoder3 = conv_bn(filters, decoder3, model_type)
195
    decoder3 = conv_bn(filters, decoder3, model_type)
196
197
    out_mask = tf.keras.layers.Conv2D(1, (1, 1), activation='sigmoid', name='mask')(decoder3)
198
199
    if cell_type == 'rbc':
200
        out_edge = tf.keras.layers.Conv2D(1, (1, 1), activation='sigmoid', name='edge')(decoder3)
201
        model = tf.keras.models.Model(inputs=inputs, outputs=(out_mask, out_edge))
202
    elif cell_type == 'wbc' or cell_type == 'plt':
203
        model = tf.keras.models.Model(inputs=inputs, outputs=(out_mask))
204
205
    opt = tf.optimizers.Adam(learning_rate=0.0001)
206
207
    if cell_type == 'rbc':
208
        model.compile(loss='mse',
209
                      loss_weights=[0.1, 0.9],
210
                      optimizer=opt,
211
                      metrics=['accuracy'])
212
    elif cell_type == 'wbc' or cell_type == 'plt':
213
        model.compile(loss='mse',
214
                      optimizer=opt,
215
                      metrics='accuracy')
216
    return model
217
218
def segnet():
219
    inputs = tf.keras.layers.Input((128, 128, 3))
220
221
    # encoder
222
    filters = 64
223
    encoder1 = conv_bn(filters, inputs, model_type)
224
    encoder1 = conv_bn(filters, encoder1, model_type)
225
    pool1, mask1 = tf.nn.max_pool_with_argmax(encoder1, 3, 2, padding="SAME")
226
227
    filters *= 2
228
    encoder2 = conv_bn(filters, pool1, model_type)
229
    encoder2 = conv_bn(filters, encoder2, model_type)
230
    pool2, mask2 = tf.nn.max_pool_with_argmax(encoder2, 3, 2, padding="SAME")
231
232
    filters *= 2
233
    encoder3 = conv_bn(filters, pool2, model_type)
234
    encoder3 = conv_bn(filters, encoder3, model_type)
235
    encoder3 = conv_bn(filters, encoder3, model_type)
236
    pool3, mask3 = tf.nn.max_pool_with_argmax(encoder3, 3, 2, padding="SAME")
237
238
    filters *= 2
239
    encoder4 = conv_bn(filters, pool3, model_type)
240
    encoder4 = conv_bn(filters, encoder4, model_type)
241
    encoder4 = conv_bn(filters, encoder4, model_type)
242
    pool4, mask4 = tf.nn.max_pool_with_argmax(encoder4, 3, 2, padding="SAME")
243
244
    encoder5 = conv_bn(filters, pool4, model_type)
245
    encoder5 = conv_bn(filters, encoder5, model_type)
246
    encoder5 = conv_bn(filters, encoder5, model_type)
247
    pool5, mask5 = tf.nn.max_pool_with_argmax(encoder5, 3, 2, padding="SAME")
248
249
    # decoder
250
    unpool1 = tfa.layers.MaxUnpooling2D()(pool5, mask5)
251
    decoder1 = conv_bn(filters, unpool1, model_type)
252
    decoder1 = conv_bn(filters, decoder1, model_type)
253
    decoder1 = conv_bn(filters, decoder1, model_type)
254
255
    unpool2 = tfa.layers.MaxUnpooling2D()(decoder1, mask4)
256
    decoder2 = conv_bn(filters, unpool2, model_type)
257
    decoder2 = conv_bn(filters, decoder2, model_type)
258
    decoder2 = conv_bn(filters/2, decoder2, model_type)
259
260
    filters /= 2
261
    unpool3 = tfa.layers.MaxUnpooling2D()(decoder2, mask3)
262
    decoder3 = conv_bn(filters, unpool3, model_type)
263
    decoder3 = conv_bn(filters, decoder3, model_type)
264
    decoder3 = conv_bn(filters/2, decoder3, model_type)
265
266
    filters /= 2
267
    unpool4 = tfa.layers.MaxUnpooling2D()(decoder3, mask2)
268
    decoder4 = conv_bn(filters, unpool4, model_type)
269
    decoder4 = conv_bn(filters/2, decoder4, model_type)
270
271
    filters /= 2
272
    unpool5 = tfa.layers.MaxUnpooling2D()(decoder4, mask1)
273
    decoder5 = conv_bn(filters, unpool5, model_type)
274
275
    out_mask = tf.keras.layers.Conv2D(1, (1, 1), activation='sigmoid', name='mask')(decoder5)
276
277
    if cell_type == 'rbc':
278
        out_edge = tf.keras.layers.Conv2D(1, (1, 1), activation='sigmoid', name='edge')(decoder5)
279
        model = tf.keras.models.Model(inputs=inputs, outputs=(out_mask, out_edge))
280
    elif cell_type == 'wbc' or cell_type == 'plt':
281
        model = tf.keras.models.Model(inputs=inputs, outputs=(out_mask))
282
283
    opt = tf.optimizers.Adam(learning_rate=0.0001)
284
285
    if cell_type == 'rbc':
286
        model.compile(loss='mse',
287
                      loss_weights=[0.1, 0.9],
288
                      optimizer=opt,
289
                      metrics=[mean_iou, dsc, tversky, 'accuracy'])
290
    elif cell_type == 'wbc' or cell_type == 'plt':
291
        model.compile(loss='mse',
292
                      optimizer=opt,
293
                      metrics=[mean_iou, dsc, tversky, 'accuracy'])
294
    return model