# File: Segmentation/model/segnet.py
1
import tensorflow as tf
2
import tensorflow.keras.layers as tfkl
3
4
class SegNet(tf.keras.Model):
    """TensorFlow 2 implementation of 'SegNet: A Deep Convolutional Encoder-Decoder
    Architecture for Image Segmentation' (Badrinarayanan et al.),
    https://arxiv.org/abs/1511.00561

    Unlike the paper, the decoder does not max-unpool with stored pooling
    indices; it upsamples with ``UpSampling2D`` (or ``Conv2DTranspose`` when
    ``use_transpose`` is True).

    Args:
        num_channels: sequence of output channel counts, one per encoder stage.
            The decoder walks the same sequence in reverse.
        num_classes: number of output channels of the final 1x1 convolution.
            With 1 a sigmoid is applied, otherwise a softmax over the channel axis.
        backbone: stored on the instance; not used by this class.
        kernel_size: encoder convolution kernel size.
        pool_size: encoder max-pool size. Also used as the decoder upsampling
            factor / transposed-conv stride so the decoder undoes the encoder's
            downsampling. NOTE(review): input height/width presumably need to be
            divisible by pool_size ** len(num_channels) for the output to match
            the input resolution — confirm with callers.
        nonlinearity: Keras activation name applied after every convolution.
        use_batchnorm: insert BatchNormalization after each convolution.
        use_bias: whether the encoder/decoder convolutions use a bias term.
        use_transpose: decode with Conv2DTranspose instead of UpSampling2D.
        use_dropout: add dropout at the end of each encoder stage (before pooling).
        dropout_rate: dropout rate when ``use_dropout`` is True.
        use_spatial_dropout: use SpatialDropout2D instead of Dropout.
        data_format: 'channels_last' or 'channels_first'.
    """

    def __init__(self,
                 num_channels,
                 num_classes,
                 backbone='default',
                 kernel_size=(3, 3),
                 pool_size=(2, 2),
                 nonlinearity='relu',
                 use_batchnorm=True,
                 use_bias=True,
                 use_transpose=False,
                 use_dropout=False,
                 dropout_rate=0.25,
                 use_spatial_dropout=True,
                 data_format='channels_last',
                 **kwargs):

        super(SegNet, self).__init__(**kwargs)

        self.num_classes = num_classes
        self.num_channels = num_channels
        self.backbone = backbone
        self.kernel_size = kernel_size
        self.pool_size = pool_size
        self.nonlinearity = nonlinearity
        self.use_batchnorm = use_batchnorm
        self.use_bias = use_bias
        self.use_transpose = use_transpose
        self.use_dropout = use_dropout
        self.dropout_rate = dropout_rate
        self.use_spatial_dropout = use_spatial_dropout
        self.data_format = data_format

        # Encoder: one conv block (ending in max pooling) per entry of num_channels.
        self.conv_list = tf.keras.Sequential()
        for i, output_ch in enumerate(self.num_channels):
            # First two stages use two convolutions, deeper stages three.
            num_conv = 2 if i < 2 else 3
            self.conv_list.add(SegNet_Conv2D_Block(output_ch,
                                                   num_conv_layers=num_conv,
                                                   kernel_size=self.kernel_size,
                                                   pool_size=self.pool_size,
                                                   nonlinearity=self.nonlinearity,
                                                   use_batchnorm=self.use_batchnorm,
                                                   use_bias=self.use_bias,
                                                   use_dropout=self.use_dropout,
                                                   dropout_rate=self.dropout_rate,
                                                   use_spatial_dropout=self.use_spatial_dropout,
                                                   data_format=self.data_format))

        # Decoder: one upsampling block per encoder stage, deepest stage first.
        self.up_conv_list = tf.keras.Sequential()
        n = len(self.num_channels) - 1
        for j in range(n, -1, -1):
            # NOTE(review): the three deepest decoder stages get three convs.
            # This mirrors the encoder only when len(num_channels) == 5
            # (2,2,3,3,3); for other lengths the depths differ — confirm intended.
            num_conv = 3 if j in (n, n - 1, n - 2) else 2
            self.up_conv_list.add(segnet_Up_Conv2D_block(self.num_channels[j],
                                                         num_conv_layers=num_conv,
                                                         # NOTE(review): decoder uses a fixed (2, 2)
                                                         # kernel rather than kernel_size; kept as-is
                                                         # for weight compatibility.
                                                         kernel_size=(2, 2),
                                                         # Upsample by the same factor the encoder
                                                         # pooled by, so resolutions line up.
                                                         upsampling_size=self.pool_size,
                                                         nonlinearity=self.nonlinearity,
                                                         use_batchnorm=self.use_batchnorm,
                                                         use_transpose=self.use_transpose,
                                                         use_bias=self.use_bias,
                                                         strides=self.pool_size,
                                                         data_format=self.data_format))

        # Per-pixel projection to class scores.
        self.conv_1x1 = tfkl.Conv2D(num_classes,
                                    (1, 1),
                                    activation='linear',
                                    padding='same',
                                    data_format=data_format)

    def call(self, x, training=False):
        """Run encoder, decoder and the 1x1 head.

        Returns per-pixel probabilities: sigmoid when ``num_classes == 1``,
        otherwise softmax over the channel axis (axis 1 for 'channels_first',
        the last axis otherwise).
        """
        encoded = self.conv_list(x, training=training)
        decoded = self.up_conv_list(encoded, training=training)
        logits = self.conv_1x1(decoded)

        # Apply the activation as a plain op instead of building a new
        # Activation layer on every call.
        if self.num_classes == 1:
            return tf.math.sigmoid(logits)
        channel_axis = 1 if self.data_format == 'channels_first' else -1
        return tf.nn.softmax(logits, axis=channel_axis)
97
98
99
class SegNet_Conv2D_Block(tf.keras.Sequential):
    """One SegNet encoder stage.

    Layout: ``num_conv_layers`` x (Conv2D -> [BatchNormalization] -> activation),
    then optional (spatial) dropout, then 2-D max pooling.

    Args:
        num_channels: number of filters of every convolution in the stage.
        num_conv_layers: how many conv/[BN]/activation groups to stack.
        kernel_size: convolution kernel size ('same' padding).
        pool_size: max-pooling window.
        nonlinearity: Keras activation name.
        use_batchnorm: insert BatchNormalization after each convolution.
        use_bias: whether the convolutions use a bias term.
        use_dropout: append a dropout layer before pooling.
        dropout_rate: dropout rate when ``use_dropout`` is True.
        use_spatial_dropout: use SpatialDropout2D instead of Dropout.
        data_format: 'channels_last' or 'channels_first'; forwarded to every
            layer that depends on the channel position.
    """

    def __init__(self,
                 num_channels,
                 num_conv_layers=2,
                 kernel_size=(3, 3),
                 pool_size=(2, 2),
                 nonlinearity='relu',
                 use_batchnorm=True,
                 use_bias=True,
                 use_dropout=False,
                 dropout_rate=0.25,
                 use_spatial_dropout=True,
                 data_format='channels_last',
                 **kwargs):

        super(SegNet_Conv2D_Block, self).__init__(**kwargs)

        # BatchNormalization must normalise over the channel axis, whose
        # position depends on data_format.
        channel_axis = 1 if data_format == 'channels_first' else -1

        for _ in range(num_conv_layers):
            self.add(tfkl.Conv2D(num_channels,
                                 kernel_size,
                                 padding='same',
                                 use_bias=use_bias,
                                 data_format=data_format))
            if use_batchnorm:
                self.add(tfkl.BatchNormalization(axis=channel_axis,
                                                 momentum=0.95,
                                                 epsilon=0.001))
            self.add(tfkl.Activation(nonlinearity))

        if use_dropout:
            if use_spatial_dropout:
                # Spatial dropout drops whole feature maps, so it needs to know
                # which axis holds the channels.
                self.add(tfkl.SpatialDropout2D(rate=dropout_rate,
                                               data_format=data_format))
            else:
                self.add(tfkl.Dropout(rate=dropout_rate))

        self.add(tfkl.MaxPool2D(pool_size, data_format=data_format))

    def call(self, x, training=False):
        """Apply the stacked layers, forwarding ``training`` (BN / dropout)."""
        return super(SegNet_Conv2D_Block, self).call(x, training=training)
141
142
143
class segnet_Up_Conv2D_block(tf.keras.Sequential):
    """One SegNet decoder stage.

    Layout: an upsampling step (UpSampling2D, or Conv2DTranspose when
    ``use_transpose`` is True), then ``num_conv_layers`` x
    (Conv2D -> [BatchNormalization] -> activation).

    Args:
        num_channels: number of filters of the transposed conv and every conv.
        num_conv_layers: how many conv/[BN]/activation groups follow upsampling.
        kernel_size: kernel size of the transposed conv and the convolutions.
        upsampling_size: UpSampling2D factor (ignored when ``use_transpose``).
        nonlinearity: Keras activation name.
        use_batchnorm: insert BatchNormalization after each convolution.
        use_transpose: upsample with Conv2DTranspose instead of UpSampling2D.
        use_bias: whether the transposed conv and convolutions use a bias term.
        strides: Conv2DTranspose strides (only used when ``use_transpose``).
        data_format: 'channels_last' or 'channels_first'; forwarded to every
            layer that depends on the channel position.
    """

    def __init__(self,
                 num_channels,
                 num_conv_layers,
                 kernel_size=(3, 3),
                 upsampling_size=(2, 2),
                 nonlinearity='relu',
                 use_batchnorm=True,
                 use_transpose=False,
                 use_bias=True,
                 strides=(2, 2),
                 data_format='channels_last',
                 **kwargs):

        super(segnet_Up_Conv2D_block, self).__init__(**kwargs)

        # BatchNormalization must normalise over the channel axis, whose
        # position depends on data_format.
        channel_axis = 1 if data_format == 'channels_first' else -1

        if use_transpose:
            self.add(tfkl.Conv2DTranspose(num_channels,
                                          kernel_size,
                                          padding='same',
                                          strides=strides,
                                          use_bias=use_bias,
                                          data_format=data_format))
        else:
            self.add(tfkl.UpSampling2D(size=upsampling_size,
                                       data_format=data_format))

        for _ in range(num_conv_layers):
            # use_bias was previously accepted but never forwarded.
            self.add(tfkl.Conv2D(num_channels,
                                 kernel_size,
                                 padding='same',
                                 use_bias=use_bias,
                                 data_format=data_format))
            if use_batchnorm:
                self.add(tfkl.BatchNormalization(axis=channel_axis,
                                                 momentum=0.95,
                                                 epsilon=0.001))
            self.add(tfkl.Activation(nonlinearity))

    def call(self, x, training=False):
        """Apply the stacked layers, forwarding ``training`` (BN)."""
        return super(segnet_Up_Conv2D_block, self).call(x, training=training)