a b/Segmentation/model/deeplabv3.py
1
import tensorflow as tf
2
import tensorflow.keras.layers as tfkl
3
4
class Deeplabv3_plus(tf.keras.Model):
    """DeepLabv3+-style encoder/decoder network for semantic segmentation.

    Encoder: a stride-2 convolution and three pre-activation
    ``resnet_block``s, followed by a cascade of atrous convolutions
    (``Atrous_conv``), atrous spatial pyramid pooling and a closing 1x1
    ``aspp_block``. Decoder: ``Decoder`` fuses the encoder output with the
    atrous features and the stride-4 backbone features while upsampling,
    and a final 1x1 convolution maps the result to ``num_classes``
    channels.
    """

    def __init__(self,
                 num_classes,
                 kernel_size_initial_conv,
                 num_channels_atrous=512,
                 num_channels_DCNN=(256, 512, 1024),
                 num_channels_ASPP=256,
                 kernel_size_atrous=3,
                 kernel_size_DCNN=(1, 3),
                 kernel_size_ASPP=(1, 3, 3, 3),
                 num_filters_final_encoder=512,
                 num_channels_from_backcone=(128, 96),
                 num_channels_UpConv=(512, 256, 128),
                 kernel_size_UpConv=3,
                 stride_UpConv=(2, 2),
                 use_batchnorm_UpConv=False,
                 use_transpose_UpConv=False,
                 padding='same',
                 nonlinearity='relu',
                 use_batchnorm=True,
                 use_bias=True,
                 data_format='channels_last',
                 MultiGrid=(1, 2, 4),
                 rate_ASPP=(1, 6, 12, 18),
                 atrous_output_stride=16,
                 # Not adapted code for any other out stride
                 **kwargs):
        """Build the encoder, decoder and output layers.

        Args:
            num_classes: number of output channels; 1 selects a sigmoid
                output, any other value a softmax.
            kernel_size_initial_conv: kernel size of the first (stride-2)
                convolution.
            num_channels_atrous: filters of each atrous convolution.
            num_channels_DCNN: three values: channels of the initial
                convolution and first block, of the second block and of
                the third block of the backbone.
            num_channels_ASPP: filters of each ASPP branch.
            kernel_size_atrous: kernel size of the atrous convolutions.
            kernel_size_DCNN: two-element tuple: kernel size of the first
                and last convolution of each ``resnet_block`` (first
                element) and of its middle convolution (second element).
            kernel_size_ASPP: kernel size of each ASPP branch.
            num_filters_final_encoder: filters of the 1x1 convolution that
                closes the encoder.
            num_channels_from_backcone: filters of the decoder's 1x1
                projections of the atrous features (first value) and of
                the stride-4 backbone features (second value).
            num_channels_UpConv: filters of the three decoder upsampling
                stages.
            kernel_size_UpConv, stride_UpConv, use_batchnorm_UpConv,
            use_transpose_UpConv: settings of the decoder upsampling
                stages.
            padding, use_bias, data_format: forwarded to the encoder layers
                and to the decoder.
            nonlinearity: activation used inside the backbone blocks.
            use_batchnorm: whether the encoder layers use batch norm.
            MultiGrid: per-layer dilation factors of the atrous cascade.
            rate_ASPP: dilation rate of each ASPP branch.
            atrous_output_stride: the code is only adapted for 16.
        """
        super(Deeplabv3_plus, self).__init__(**kwargs)

        self.num_classes = num_classes

        # ResNet backbone
        self.first_conv = tfkl.Conv2D(num_channels_DCNN[0],
                                      kernel_size_initial_conv,
                                      strides=2,
                                      padding=padding,
                                      use_bias=use_bias,
                                      data_format=data_format)

        self.block1 = resnet_block(False,
                                   num_channels_DCNN[0],
                                   kernel_size_DCNN,
                                   padding,
                                   nonlinearity,
                                   use_batchnorm,
                                   use_bias,
                                   data_format)

        self.block2 = resnet_block(True,
                                   num_channels_DCNN[1],
                                   kernel_size_DCNN,
                                   padding,
                                   nonlinearity,
                                   use_batchnorm,
                                   use_bias,
                                   data_format)

        self.block3 = resnet_block(True,
                                   num_channels_DCNN[2],
                                   kernel_size_DCNN,
                                   padding,
                                   nonlinearity,
                                   use_batchnorm,
                                   use_bias,
                                   data_format)

        # Atrous components
        self.atrous_conv = Atrous_conv(num_channels_atrous,
                                       kernel_size_atrous,
                                       MultiGrid,
                                       padding,
                                       use_batchnorm,
                                       'linear',
                                       use_bias,
                                       data_format,
                                       atrous_output_stride)

        self.aspp_term = atrous_spatial_pyramid_pooling(num_channels_ASPP,
                                                        kernel_size_ASPP,
                                                        rate_ASPP,
                                                        padding,
                                                        use_batchnorm,
                                                        'linear',
                                                        use_bias,
                                                        data_format)

        # Final convolution of encoder
        self.final_encoder_conv = aspp_block(1,
                                             1,
                                             num_filters_final_encoder,
                                             padding,
                                             use_batchnorm,
                                             'linear',
                                             use_bias,
                                             data_format)

        # Decoder
        self.decoder_term = Decoder(num_channels_from_backcone=num_channels_from_backcone,
                                    num_channels_UpConv=num_channels_UpConv,
                                    kernel_size_UpConv=kernel_size_UpConv,
                                    stride_UpConv=stride_UpConv,
                                    use_batchnorm_UpConv=use_batchnorm_UpConv,
                                    use_transpose_UpConv=use_transpose_UpConv,
                                    use_bias=use_bias,
                                    padding=padding,
                                    data_format=data_format)

        self.output_conv = tfkl.Conv2D(num_classes,
                                       1,
                                       activation='linear',
                                       padding='same',
                                       data_format=data_format)

    def call(self, x, training=False):
        """Compute per-pixel class probabilities for the batch ``x``.

        ``training`` is forwarded to every sub-layer. With the default
        decoder strides, the three upsampling stages undo the encoder's
        total stride of 8. The output therefore has the input's spatial
        size when that size is divisible by 8.

        Returns:
            Sigmoid scores when ``num_classes == 1``, otherwise a softmax
            over the last axis.
        """
        # Encoder: backbone
        features = self.first_conv(x, training=training)  # output stride 2
        features = self.block1(features, training=training)  # output stride 2
        low_level = self.block2(features, training=training)  # output stride 4
        atrous_out = self.block3(low_level, training=training)  # output stride 8

        # Encoder: atrous cascade, ASPP and closing 1x1 convolution
        atrous_out = self.atrous_conv(atrous_out, training=training)
        encoded = self.aspp_term(atrous_out, training=training)
        encoded = self.final_encoder_conv(encoded, training=training)

        # Decoder: fuse with the atrous and stride-4 features while upsampling
        out = self.decoder_term(atrous_out, encoded, low_level, training=training)

        logits = self.output_conv(out, training=training)
        # Apply the activation function directly instead of building a new
        # Activation layer on every forward pass.
        if self.num_classes == 1:
            return tf.keras.activations.sigmoid(logits)
        return tf.keras.activations.softmax(logits)
156
157
158
class Deeplabv3(tf.keras.Sequential):
    """TensorFlow 2 implementation of a DeepLabv3-style segmentation model.

    Layers, in order:
      1. a ``ResNet_Backbone`` built with ``use_pooling=False``;
      2. an ``Atrous_conv`` cascade;
      3. ``atrous_spatial_pyramid_pooling``;
      4. a 1x1 ``aspp_block`` producing ``num_classes`` channels.

    ``call`` turns those scores into probabilities and bilinearly resizes
    them to the input's spatial size.
    """

    def __init__(self,
                 num_classes,
                 kernel_size_initial_conv,
                 num_channels_atrous,
                 num_channels_DCNN=(256, 512, 1024),
                 num_channels_ASPP=256,
                 kernel_size_atrous=3,
                 kernel_size_DCNN=(1, 3),
                 kernel_size_ASPP=(1, 3, 3, 3),
                 padding='same',
                 nonlinearity='relu',
                 use_batchnorm=True,
                 use_bias=True,
                 data_format='channels_last',
                 MultiGrid=(1, 2, 4),
                 rate_ASPP=(1, 6, 12, 18),
                 atrous_output_stride=16,
                 # Not adapted code for any other out stride
                 **kwargs):
        """Build the layer stack.

        Args:
            num_classes: number of output channels; 1 selects a sigmoid
                output, any other value a softmax.
            kernel_size_initial_conv: kernel size of the first convolution.
            num_channels_atrous: filters of each atrous convolution.
            num_channels_DCNN: tuple with the number of channels for the
                first three blocks of the backbone.
            num_channels_ASPP: filters of each ASPP branch.
            kernel_size_atrous: kernel size of the atrous convolutions.
            kernel_size_DCNN: two-element tuple: kernel size of the first
                and last convolution of each ``resnet_block`` (first
                element) and of its middle convolution (second element).
            kernel_size_ASPP: kernel size of each ASPP branch.
            padding, nonlinearity, use_batchnorm, use_bias, data_format:
                layer settings forwarded to the sub-layers.
            MultiGrid: per-layer dilation factors of the atrous cascade.
            rate_ASPP: dilation rate of each ASPP branch.
            atrous_output_stride: the code is only adapted for 16.
        """
        super(Deeplabv3, self).__init__(**kwargs)

        self.num_classes = num_classes

        self.add(ResNet_Backbone(kernel_size_initial_conv,
                                 num_channels_DCNN,
                                 kernel_size_DCNN,
                                 padding,
                                 nonlinearity,
                                 use_batchnorm,
                                 use_bias,
                                 False,  # use_pooling
                                 data_format))

        self.add(Atrous_conv(num_channels_atrous,
                             kernel_size_atrous,
                             MultiGrid,
                             padding,
                             use_batchnorm,
                             'linear',
                             use_bias,
                             data_format,
                             atrous_output_stride))

        self.add(atrous_spatial_pyramid_pooling(num_channels_ASPP,
                                                kernel_size_ASPP,
                                                rate_ASPP,
                                                padding,
                                                use_batchnorm,
                                                'linear',
                                                use_bias,
                                                data_format))

        self.add(aspp_block(1,
                            1,
                            num_classes,
                            padding,
                            use_batchnorm,
                            'linear',
                            use_bias,
                            data_format))

    def call(self, x, training=False):
        """Run the layer stack, apply the output activation, then resize.

        Returns sigmoid scores when ``num_classes == 1`` and a softmax over
        the last axis otherwise, resized (bilinear, ``tf.image.resize``
        default) to ``tf.shape(x)[1:3]``. That is, it assumes channels_last
        inputs.
        """
        scores = super(Deeplabv3, self).call(x, training=training)

        # Apply the activation function directly instead of building a new
        # Activation layer on every forward pass.
        if self.num_classes == 1:
            probabilities = tf.keras.activations.sigmoid(scores)
        else:
            probabilities = tf.keras.activations.softmax(scores)

        # Upsample to same size as the input
        return tf.image.resize(probabilities, tf.shape(x)[1:3])
245
246
class ResNet_Backbone(tf.keras.Model):
    """Pre-activation ResNet feature extractor.

    Layers, in order:
      * a stride-2 convolution;
      * an optional 2x2 max pooling (``use_pooling``);
      * three ``resnet_block``s, the last two of which use a stride of 2.

    Total output stride is 8, or 16 when ``use_pooling`` is set.
    """

    def __init__(self,
                 kernel_size_initial_conv,
                 num_channels=(256, 512, 1024),
                 kernel_size_blocks=(1, 3),
                 padding='same',
                 nonlinearity='relu',
                 use_batchnorm=True,
                 use_bias=True,
                 use_pooling=False,
                 data_format='channels_last',
                 **kwargs):
        """Build the backbone.

        Args:
            kernel_size_initial_conv: kernel size of the first convolution.
            num_channels: channels of the initial convolution and first
                block, of the second block and of the third block.
            kernel_size_blocks: ``kernel_size`` passed to each
                ``resnet_block``.
            padding, nonlinearity, use_batchnorm, use_bias, data_format:
                layer settings forwarded to the sub-layers.
            use_pooling: apply 2x2 max pooling after the first convolution.
        """
        super(ResNet_Backbone, self).__init__(**kwargs)
        self.first_conv = tfkl.Conv2D(num_channels[0],
                                      kernel_size_initial_conv,
                                      strides=2,
                                      padding=padding,
                                      use_bias=use_bias,
                                      data_format=data_format)

        # data_format is forwarded so pooling runs over the spatial axes for
        # either layout, like every other layer of this backbone.
        self.max_pool = tfkl.MaxPool2D(pool_size=(2, 2),
                                       padding='valid',
                                       data_format=data_format)

        self.use_pooling = use_pooling

        self.block1 = resnet_block(False,
                                   num_channels[0],
                                   kernel_size_blocks,
                                   padding,
                                   nonlinearity,
                                   use_batchnorm,
                                   use_bias,
                                   data_format)
        self.block2 = resnet_block(True,
                                   num_channels[1],
                                   kernel_size_blocks,
                                   padding,
                                   nonlinearity,
                                   use_batchnorm,
                                   use_bias,
                                   data_format)

        self.block3 = resnet_block(True,
                                   num_channels[2],
                                   kernel_size_blocks,
                                   padding,
                                   nonlinearity,
                                   use_batchnorm,
                                   use_bias,
                                   data_format)

    def call(self, x, training=False):
        """Return the backbone features of ``x`` (output stride 8 or 16)."""
        x = self.first_conv(x, training=training)  # output stride 2
        if self.use_pooling:
            x = self.max_pool(x)  # output stride 4

        x = self.block1(x, training=training)  # output stride 2 or 4
        x = self.block2(x, training=training)  # output stride 4 or 8
        x = self.block3(x, training=training)  # output stride 8 or 16
        return x
308
309
310
# full pre-activation residual unit
311
class resnet_block(tf.keras.Model):
    """Bottleneck residual unit built from pre-activation conv blocks.

    Residual path, where each step is a ``basic_conv_block``:
      1. a ``kernel_size[0]`` convolution to ``num_channels // 4`` channels,
         strided by 2 when ``use_stride`` is set;
      2. a ``kernel_size[1]`` convolution;
      3. a ``kernel_size[0]`` convolution back to ``num_channels`` channels.

    Shortcut path: a strided 1x1 ``basic_conv_block`` projection to
    ``num_channels`` when ``use_stride`` is set. Otherwise it is the
    identity, so the input must have a shape compatible with the residual
    output for the final addition.
    """

    def __init__(self,
                 use_stride,
                 num_channels,
                 kernel_size=(1, 3),
                 padding='same',
                 nonlinearity='relu',
                 use_batchnorm=True,
                 use_bias=True,
                 data_format='channels_last',
                 **kwargs):
        """Build the residual and (when strided) projection layers.

        Args:
            use_stride: downsample by 2 and project the shortcut.
            num_channels: output channels; the bottleneck uses a quarter.
            kernel_size: two-element tuple: kernel size of the first and
                last convolution (first element) and of the middle
                convolution (second element).
            padding, nonlinearity, use_batchnorm, use_bias, data_format:
                forwarded to every ``basic_conv_block``.
        """
        super(resnet_block, self).__init__(**kwargs)
        self.use_stride = use_stride
        inner_num_channels = num_channels // 4
        stride = 2 if use_stride else 1

        if use_stride:
            # Strided 1x1 projection so the shortcut matches the residual
            # path's spatial size and channel count.
            self.input_conv = basic_conv_block(num_channels,
                                               1,
                                               2,
                                               padding,
                                               nonlinearity,
                                               use_batchnorm,
                                               use_bias,
                                               data_format)

        self.first_conv = basic_conv_block(inner_num_channels,
                                           kernel_size[0],
                                           stride,
                                           padding,
                                           nonlinearity,
                                           use_batchnorm,
                                           use_bias,
                                           data_format)

        self.second_conv = basic_conv_block(inner_num_channels,
                                            kernel_size[1],
                                            1,
                                            padding,
                                            nonlinearity,
                                            use_batchnorm,
                                            use_bias,
                                            data_format)

        self.third_conv = basic_conv_block(num_channels,
                                           kernel_size[0],
                                           1,
                                           padding,
                                           nonlinearity,
                                           use_batchnorm,
                                           use_bias,
                                           data_format)

    def call(self, x, training=False):
        """Return ``residual(x) + shortcut(x)``."""
        residual = self.first_conv(x, training=training)

        shortcut = self.input_conv(x, training=training) if self.use_stride else x

        residual = self.second_conv(residual, training=training)
        residual = self.third_conv(residual, training=training)

        # Plain tensor addition: same result as tfkl.Add()([residual, x])
        # without instantiating a new layer on every call.
        return residual + shortcut
381
382
383
class basic_conv_block(tf.keras.Sequential):
    """Pre-activation convolution: [BatchNormalization] -> activation -> Conv2D."""

    def __init__(self,
                 num_channels,
                 kernel_size,
                 stride=1,
                 padding='same',
                 nonlinearity='relu',
                 use_batchnorm=True,
                 use_bias=True,
                 data_format='channels_last',
                 rate=1,
                 **kwargs):
        """Build the layer stack.

        Args:
            num_channels: filters of the convolution.
            kernel_size: kernel size of the convolution.
            stride: convolution stride.
            padding: convolution padding.
            nonlinearity: activation applied before the convolution.
            use_batchnorm: prepend a BatchNormalization layer.
            use_bias: whether the convolution has a bias.
            data_format: 'channels_last' or 'channels_first'; ``None``
                means the Keras global image data format.
            rate: dilation rate of the convolution.
        """
        super(basic_conv_block, self).__init__(**kwargs)

        if data_format is None:
            data_format = tf.keras.backend.image_data_format()
        # Batch norm must normalize over the channel axis, whose position
        # depends on the layout (it used to be hard-coded to -1).
        channel_axis = 1 if data_format == 'channels_first' else -1

        if use_batchnorm:
            self.add(tfkl.BatchNormalization(axis=channel_axis,
                                             momentum=0.95,
                                             epsilon=0.001))
        self.add(tfkl.Activation(nonlinearity))

        self.add(tfkl.Conv2D(num_channels,
                             kernel_size,
                             strides=stride,
                             padding=padding,
                             use_bias=use_bias,
                             data_format=data_format,
                             dilation_rate=rate))

    def call(self, x, training=False):
        """Apply the stack; ``training`` controls batch-norm behaviour."""
        output = super(basic_conv_block, self).call(x, training=training)
        return output
417
418
# ####################### Atrous Convolution ####################### #
419
class Atrous_conv(tf.keras.Model):
    """Cascade of three dilated (atrous) pre-activation convolutions.

    Layer ``i`` uses dilation rate ``int(multiplier * MultiGrid[i])``.
    ``multiplier`` is 2 when ``output_stride == 16`` and 1 otherwise; the
    code is only adapted for an output stride of 16. All three layers use
    stride 1.
    """

    def __init__(self,
                 num_channels,
                 kernel_size=3,
                 MultiGrid=(1, 2, 4),
                 padding='same',
                 use_batchnorm=True,
                 nonlinearity='linear',
                 use_bias=True,
                 data_format='channels_last',
                 output_stride=16,
                 **kwargs):
        """Build the three dilated ``basic_conv_block``s.

        Args:
            num_channels: filters of each convolution.
            kernel_size: kernel size of each convolution.
            MultiGrid: at least three per-layer dilation factors.
            padding, use_batchnorm, nonlinearity, use_bias, data_format:
                forwarded to each ``basic_conv_block``.
            output_stride: network output stride (only 16 is adapted).
        """
        super(Atrous_conv, self).__init__(**kwargs)

        multiplier = 2 if output_stride == 16 else 1
        rates = [int(multiplier * grid) for grid in MultiGrid]

        self.first_conv = basic_conv_block(num_channels,
                                           kernel_size,
                                           1,
                                           padding,
                                           nonlinearity,
                                           use_batchnorm,
                                           use_bias,
                                           data_format,
                                           rate=rates[0])

        self.second_conv = basic_conv_block(num_channels,
                                            kernel_size,
                                            1,
                                            padding,
                                            nonlinearity,
                                            use_batchnorm,
                                            use_bias,
                                            data_format,
                                            rate=rates[1])

        self.third_conv = basic_conv_block(num_channels,
                                           kernel_size,
                                           1,
                                           padding,
                                           nonlinearity,
                                           use_batchnorm,
                                           use_bias,
                                           data_format,
                                           rate=rates[2])

    def call(self, x, training=False):
        """Apply the three dilated convolutions in sequence."""
        # training is passed by keyword, as everywhere else in this file.
        x = self.first_conv(x, training=training)
        x = self.second_conv(x, training=training)
        x = self.third_conv(x, training=training)
        return x
476
477
478
# ####################### ASPP ####################### #
479
class atrous_spatial_pyramid_pooling(tf.keras.Model):
    """Atrous spatial pyramid pooling (ASPP).

    The following branches run in parallel on the same input:
      * an image-level branch: global average pooling, a 1x1 convolution,
        then a resize back to the input's spatial size;
      * one ``aspp_block`` per ``(kernel_size, rate)`` pair.

    All branches are concatenated and fused by a 1x1 convolution with
    ``num_channels`` filters.

    NOTE: ``call`` pools over axes 1-2 and concatenates on axis 3, so it
    assumes channels_last inputs.
    """

    def __init__(self,
                 num_channels=256,
                 kernel_size=(1, 3, 3, 3),
                 rate=(1, 6, 12, 18),
                 padding='same',
                 use_batchnorm=True,
                 nonlinearity='linear',
                 use_bias=True,
                 data_format='channels_last',
                 **kwargs):
        """Build the ASPP branches.

        Args:
            num_channels: filters of every branch and of the fusing conv.
            kernel_size: kernel size of each dilated branch.
            rate: dilation rate of each dilated branch (same length as
                ``kernel_size``).
            padding: padding of all convolutions.
            use_batchnorm, nonlinearity, data_format: forwarded to each
                ``aspp_block``.
            use_bias: whether all convolutions have a bias.

        Raises:
            ValueError: if ``kernel_size`` and ``rate`` differ in length.
        """
        super(atrous_spatial_pyramid_pooling, self).__init__(**kwargs)

        if len(kernel_size) != len(rate):
            raise ValueError(
                'kernel_size and rate must have the same length, got '
                f'{len(kernel_size)} and {len(rate)}')

        self.block_list = []

        # 1x1 convolution of the globally pooled (image-level) features.
        self.basic_conv1 = tfkl.Conv2D(num_channels,
                                       kernel_size=1,
                                       padding=padding,
                                       use_bias=use_bias)

        # 1x1 convolution fusing the concatenated branches.
        self.basic_conv2 = tfkl.Conv2D(num_channels,
                                       kernel_size=1,
                                       padding=padding,
                                       use_bias=use_bias)

        for branch_kernel, branch_rate in zip(kernel_size, rate):
            self.block_list.append(aspp_block(branch_kernel,
                                              branch_rate,
                                              num_channels,
                                              padding,
                                              use_batchnorm,
                                              nonlinearity,
                                              use_bias,
                                              data_format))

    def call(self, x, training=False):
        """Return the fused ASPP features of ``x`` (same spatial size)."""
        spatial_size = tf.shape(x)[1:3]

        # Image-level branch: global average pooling, 1x1 conv, then resize
        # the 1x1 map back to the input's spatial size.
        pooled = tf.math.reduce_mean(x, axis=[1, 2], keepdims=True)
        pooled = self.basic_conv1(pooled, training=training)
        branches = [tf.image.resize(pooled, spatial_size)]

        # Dilated branches, one per (kernel_size, rate) pair.
        branches.extend(block(x, training=training) for block in self.block_list)

        out = tf.concat(branches, axis=3)
        return self.basic_conv2(out, training=training)
531
532
533
class aspp_block(tf.keras.Sequential):
    """Post-activation convolution: Conv2D (optionally dilated) -> [BatchNorm] -> activation."""

    def __init__(self,
                 kernel_size,
                 rate,
                 num_channels=256,
                 padding='same',
                 use_batchnorm=True,
                 nonlinearity='linear',
                 use_bias=True,
                 data_format='channels_last',
                 **kwargs):
        """Build the layer stack.

        Args:
            kernel_size: kernel size of the convolution.
            rate: dilation rate of the convolution.
            num_channels: filters of the convolution.
            padding: convolution padding.
            use_batchnorm: append a BatchNormalization layer.
            nonlinearity: activation applied last.
            use_bias: whether the convolution has a bias.
            data_format: 'channels_last' or 'channels_first'; ``None``
                means the Keras global image data format.
        """
        super(aspp_block, self).__init__(**kwargs)

        if data_format is None:
            data_format = tf.keras.backend.image_data_format()
        # Batch norm must normalize over the channel axis, whose position
        # depends on the layout (it used to be hard-coded to -1).
        channel_axis = 1 if data_format == 'channels_first' else -1

        self.add(tfkl.Conv2D(num_channels,
                             kernel_size,
                             padding=padding,
                             use_bias=use_bias,
                             data_format=data_format,
                             dilation_rate=rate))

        if use_batchnorm:
            self.add(tfkl.BatchNormalization(axis=channel_axis,
                                             momentum=0.95,
                                             epsilon=0.001))

        self.add(tfkl.Activation(nonlinearity))

    def call(self, x, training=False):
        """Apply the stack; ``training`` controls batch-norm behaviour."""
        output = super(aspp_block, self).call(x, training=training)
        return output
566
567
# ####################### Decoder ####################### #
568
class Decoder(tf.keras.Model):
    """Decoder that fuses encoder output with skip features while upsampling.

    Steps, in order:
      1. Both skip inputs are projected by 1x1 convolutions.
      2. The projected atrous features are concatenated with the encoder
         output and upsampled (``conv1``).
      3. The result is concatenated with the projected backbone features
         and upsampled twice more (``conv2``, ``conv3``).
    """

    def __init__(self,
                 num_channels_from_backcone=(48,),
                 num_channels_UpConv=(512, 256, 128),
                 kernel_size_UpConv=3,
                 stride_UpConv=(2, 2),
                 use_batchnorm_UpConv=False,
                 use_transpose_UpConv=False,
                 use_bias=True,
                 padding='same',
                 data_format='channels_last',
                 **kwargs):
        """Build the projections and upsampling stages.

        Args:
            num_channels_from_backcone: filters of the 1x1 projection of
                ``in_atrous`` (first value) and of ``in_DCNN`` (second
                value). A single value is used for both projections.
            num_channels_UpConv: filters of the three upsampling stages.
            kernel_size_UpConv: kernel size of the upsampling-stage convs.
            stride_UpConv: upsampling factor of each stage.
            use_batchnorm_UpConv: batch norm in the upsampling stages.
            use_transpose_UpConv: upsample with transposed convolutions
                instead of ``UpSampling2D``.
            use_bias: whether the convolutions have a bias.
            padding: padding of the 1x1 projections. The upsampling stages
                keep their own 'same' padding so their outputs line up for
                concatenation.
            data_format: 'channels_last' or 'channels_first'; ``None``
                means the Keras global image data format.

        Raises:
            ValueError: if ``num_channels_from_backcone`` is empty.
        """
        super(Decoder, self).__init__(**kwargs)

        channels_from_backbone = tuple(num_channels_from_backcone)
        if not channels_from_backbone:
            raise ValueError('num_channels_from_backcone must not be empty')
        if len(channels_from_backbone) == 1:
            # One value shared by both projections; this is how the default
            # ``(48,)`` is interpreted (indexing [1] would otherwise fail).
            channels_from_backbone *= 2

        if data_format is None:
            data_format = tf.keras.backend.image_data_format()
        # Channel axis used when concatenating feature maps in ``call``.
        self._concat_axis = 1 if data_format == 'channels_first' else 3

        self.first_conv1x1 = tfkl.Conv2D(channels_from_backbone[0],
                                         kernel_size=1,
                                         padding=padding,
                                         use_bias=use_bias,
                                         data_format=data_format)

        self.second_conv1x1 = tfkl.Conv2D(channels_from_backbone[1],
                                          kernel_size=1,
                                          padding=padding,
                                          use_bias=use_bias,
                                          data_format=data_format)

        self.conv1 = Up_Conv2D(num_channels_conv=num_channels_UpConv[0],
                               kernel_size=kernel_size_UpConv,
                               use_batchnorm=use_batchnorm_UpConv,
                               use_transpose=use_transpose_UpConv,
                               use_bias=use_bias,
                               strides=stride_UpConv,
                               data_format=data_format)

        self.conv2 = Up_Conv2D(num_channels_conv=num_channels_UpConv[1],
                               kernel_size=kernel_size_UpConv,
                               use_batchnorm=use_batchnorm_UpConv,
                               use_transpose=use_transpose_UpConv,
                               use_bias=use_bias,
                               strides=stride_UpConv,
                               data_format=data_format)

        self.conv3 = Up_Conv2D(num_channels_conv=num_channels_UpConv[2],
                               kernel_size=kernel_size_UpConv,
                               use_batchnorm=use_batchnorm_UpConv,
                               use_transpose=use_transpose_UpConv,
                               use_bias=use_bias,
                               strides=stride_UpConv,
                               data_format=data_format)

    def call(self, in_atrous, in_encoder, in_DCNN, training=False):
        """Decode the encoder output.

        Args:
            in_atrous: skip features with the same spatial size as
                ``in_encoder``.
            in_encoder: encoder output.
            in_DCNN: skip features whose spatial size must match the output
                of the first upsampling stage.
            training: forwarded to every sub-layer.
        """
        in_atrous = self.first_conv1x1(in_atrous, training=training)
        in_DCNN = self.second_conv1x1(in_DCNN, training=training)

        out = tf.concat([in_atrous, in_encoder], axis=self._concat_axis)
        out = self.conv1(out, training=training)

        out = tf.concat([in_DCNN, out], axis=self._concat_axis)
        out = self.conv2(out, training=training)
        return self.conv3(out, training=training)
625
626
627
class Up_Conv2D(tf.keras.Sequential):
    """Upsampling stage followed by a convolution block.

    The stage first upsamples by ``strides``: with nearest-neighbour
    ``UpSampling2D`` by default, or with a ``Conv2DTranspose`` of
    ``num_channels_UpConv`` filters when ``use_transpose`` is set. It then
    applies an ``aspp_block`` (rate 1) with ``num_channels_conv`` filters,
    optional batch norm and ``nonlinearity``.
    """

    def __init__(self,
                 num_channels_conv,
                 num_channels_UpConv=256,
                 kernel_size=3,
                 nonlinearity='relu',
                 use_batchnorm=False,
                 use_transpose=False,
                 use_bias=True,
                 strides=(2, 2),
                 padding='same',
                 data_format='channels_last',
                 **kwargs):
        """Build the upsampling layer and the convolution block.

        Args:
            num_channels_conv: filters of the convolution after upsampling.
            num_channels_UpConv: filters of the transposed convolution;
                only used when ``use_transpose`` is set.
            kernel_size: kernel size of both convolutions.
            nonlinearity: activation of the convolution block.
            use_batchnorm: batch norm in the convolution block.
            use_transpose: upsample with ``Conv2DTranspose``.
            use_bias: whether the convolutions have a bias.
            strides: upsampling factor.
            padding: padding of the convolution after upsampling.
            data_format: image layout forwarded to every layer.
        """
        super(Up_Conv2D, self).__init__(**kwargs)

        if use_transpose:
            # 'same' padding stays fixed so the transposed convolution
            # multiplies the spatial size by exactly ``strides``.
            self.add(tfkl.Conv2DTranspose(num_channels_UpConv,
                                          kernel_size,
                                          padding='same',
                                          strides=strides,
                                          use_bias=use_bias,
                                          data_format=data_format))
        else:
            self.add(tfkl.UpSampling2D(size=strides, data_format=data_format))

        self.add(aspp_block(kernel_size=kernel_size,
                            rate=1,
                            num_channels=num_channels_conv,
                            padding=padding,
                            use_batchnorm=use_batchnorm,
                            nonlinearity=nonlinearity,
                            use_bias=use_bias,
                            data_format=data_format))

    def call(self, x, training=False):
        """Apply the stack; ``training`` is forwarded to the sub-layers."""
        out = super(Up_Conv2D, self).call(x, training=training)
        return out
666
667