a b/Segmentation/model/unet_build_blocks.py
1
import tensorflow as tf
2
import tensorflow.keras.layers as tfkl
3
4
5
class Conv_Block(tf.keras.Sequential):
    """A stack of ``num_conv_layers`` convolution units, each unit being
    Conv2D/Conv3D -> (optional BatchNormalization) -> Activation, optionally
    followed by a single (spatial) dropout layer at the end.

    Args:
        num_channels: Number of filters for every convolution.
        use_2d: If True build Conv2D layers, otherwise Conv3D.
        num_conv_layers: How many conv->(bn)->activation units to stack.
        kernel_size: Convolution kernel size (int or tuple).
        nonlinearity: Activation name passed to ``tfkl.Activation``.
        use_batchnorm: Insert BatchNormalization after each convolution.
        use_bias: Whether the convolutions use a bias term.
        use_dropout: Append a dropout layer after all conv units.
        dropout_rate: Rate for the dropout layer.
        use_spatial_dropout: Use SpatialDropout{2,3}D instead of Dropout.
        data_format: 'channels_last' or 'channels_first'.
    """

    def __init__(self,
                 num_channels,
                 use_2d=True,
                 num_conv_layers=2,
                 kernel_size=3,
                 nonlinearity='relu',
                 use_batchnorm=False,
                 use_bias=True,
                 use_dropout=False,
                 dropout_rate=0.25,
                 use_spatial_dropout=True,
                 data_format='channels_last',
                 **kwargs):

        super(Conv_Block, self).__init__(**kwargs)

        # BUG FIX: the original iterated over ``self.num_conv_layers``, an
        # attribute that was never assigned, raising AttributeError on
        # construction. Use the constructor parameter directly.
        conv_layer = tfkl.Conv2D if use_2d else tfkl.Conv3D
        for _ in range(num_conv_layers):
            self.add(conv_layer(num_channels,
                                kernel_size,
                                padding='same',
                                use_bias=use_bias,
                                data_format=data_format))
            if use_batchnorm:
                # Channel axis depends on the data format.
                self.add(tfkl.BatchNormalization(axis=-1 if data_format == 'channels_last' else 1,
                                                 momentum=0.95,
                                                 epsilon=0.001))
            self.add(tfkl.Activation(nonlinearity))

        if use_dropout:
            if use_spatial_dropout:
                # Spatial dropout zeroes entire feature maps rather than
                # individual activations.
                if use_2d:
                    self.add(tfkl.SpatialDropout2D(rate=dropout_rate))
                else:
                    self.add(tfkl.SpatialDropout3D(rate=dropout_rate))
            else:
                self.add(tfkl.Dropout(rate=dropout_rate))

    def call(self, inputs, training=False):
        """Run the sequential stack; ``training`` gates batchnorm/dropout."""
        outputs = super(Conv_Block, self).call(inputs, training=training)
        return outputs
56
57
58
class Up_Conv(tf.keras.Model):
    """Decoder up-convolution block: upsample, project with a 1-layer conv
    block, optionally apply an attention gate against the encoder bridge,
    concatenate with the bridge, then refine with a 2-layer conv block.

    Args:
        num_channels: Number of filters for all convolutions in the block.
        use_2d: If True use 2D layers, otherwise 3D.
        kernel_size: Kernel size for the projection conv (and transpose conv).
        nonlinearity: Activation name for the conv blocks.
        use_attention: Gate the bridge features with an Attention_Gate.
        use_batchnorm: Use batch normalization inside the conv blocks.
        use_transpose: Upsample with Conv{2,3}DTranspose instead of
            UpSampling{2,3}D.
        use_bias: Bias flag forwarded to the attention gate convolutions.
        strides: Upsampling factor / transpose-conv stride.
        data_format: 'channels_last' or 'channels_first'.
    """

    def __init__(self,
                 num_channels,
                 use_2d=True,
                 kernel_size=2,
                 nonlinearity='relu',
                 use_attention=False,
                 use_batchnorm=False,
                 use_transpose=False,
                 use_bias=True,
                 strides=2,
                 data_format='channels_last',
                 **kwargs):

        super(Up_Conv, self).__init__(**kwargs)

        self.data_format = data_format
        # BUG FIX: ``self.use_attention`` was read below and in call() but
        # never assigned, raising AttributeError on construction.
        self.use_attention = use_attention

        if use_transpose:
            if use_2d:
                self.upconv_layer = tfkl.Conv2DTranspose(num_channels,
                                                         kernel_size,
                                                         padding='same',
                                                         strides=strides,
                                                         data_format=self.data_format)
            else:
                self.upconv_layer = tfkl.Conv3DTranspose(num_channels,
                                                         kernel_size,
                                                         padding='same',
                                                         strides=strides,
                                                         data_format=self.data_format)
        else:
            # FIX: forward data_format so channels_first inputs are upsampled
            # along the spatial axes (original omitted it, defaulting to
            # channels_last behaviour).
            if use_2d:
                self.upconv_layer = tfkl.UpSampling2D(size=strides,
                                                      data_format=self.data_format)
            else:
                self.upconv_layer = tfkl.UpSampling3D(size=strides,
                                                      data_format=self.data_format)

        if self.use_attention:
            self.attention = Attention_Gate(num_channels=num_channels,
                                            use_2d=use_2d,
                                            kernel_size=1,
                                            nonlinearity=nonlinearity,
                                            padding='same',
                                            strides=strides,
                                            use_bias=use_bias,
                                            data_format=self.data_format)

        # Single-conv projection applied right after upsampling.
        self.conv = Conv_Block(num_channels=num_channels,
                               use_2d=use_2d,
                               num_conv_layers=1,
                               kernel_size=kernel_size,
                               nonlinearity=nonlinearity,
                               use_batchnorm=use_batchnorm,
                               use_dropout=False,
                               data_format=self.data_format)

        # Two-conv refinement applied after concatenation with the bridge.
        self.conv_block = Conv_Block(num_channels=num_channels,
                                     use_2d=use_2d,
                                     num_conv_layers=2,
                                     kernel_size=3,
                                     nonlinearity=nonlinearity,
                                     use_batchnorm=use_batchnorm,
                                     use_dropout=False,
                                     data_format=self.data_format)

    def call(self, inputs, bridge, training=False):
        """Upsample ``inputs``, fuse with encoder ``bridge``, and refine.

        Args:
            inputs: Decoder feature map to be upsampled.
            bridge: Skip-connection feature map from the encoder.
            training: Forwarded to batchnorm/dropout sub-layers.

        Returns:
            The refined, concatenated feature map.
        """
        up = self.upconv_layer(inputs)
        up = self.conv(up, training=training)
        if self.use_attention:
            up = self.attention(bridge, up, training=training)
        out = tfkl.concatenate([up, bridge], axis=-1 if self.data_format == 'channels_last' else 1)
        out = self.conv_block(out, training=training)

        return out
134
135
136
class Attention_Gate(tf.keras.Model):
    """Additive-style attention gate: projects the gating signal and the
    skip features with 1x1 conv blocks, combines them, and produces a
    sigmoid mask that scales the skip features element-wise.

    Args:
        num_channels: Filters for each internal conv block.
        use_2d: If True use 2D convolutions, otherwise 3D.
        kernel_size: Kernel size for the projections (typically 1).
        nonlinearity: Activation used inside the conv blocks.
        padding: Unused here; accepted for signature compatibility.
        strides: Unused here; accepted for signature compatibility.
        use_bias: Unused here; accepted for signature compatibility.
        use_batchnorm: Batch normalization flag for the conv blocks.
        data_format: 'channels_last' or 'channels_first'.
    """

    def __init__(self,
                 num_channels,
                 use_2d=True,
                 kernel_size=1,
                 nonlinearity='relu',
                 padding='same',
                 strides=1,
                 use_bias=True,
                 use_batchnorm=True,
                 data_format='channels_last',
                 **kwargs):

        super(Attention_Gate, self).__init__(**kwargs)

        self.data_format = data_format
        # Three identical projections: one for the gating signal, one for
        # the skip features, and one applied to their fused combination.
        self.conv_blocks = [
            Conv_Block(num_channels,
                       use_2d=use_2d,
                       num_conv_layers=1,
                       kernel_size=kernel_size,
                       nonlinearity=nonlinearity,
                       use_batchnorm=use_batchnorm,
                       use_dropout=False,
                       data_format=self.data_format)
            for _ in range(3)
        ]

    def call(self, input_x, input_g, training=False):
        """Return ``input_x`` scaled by an attention mask derived from both
        ``input_x`` (skip features) and ``input_g`` (gating signal)."""
        channel_axis = -1 if self.data_format == 'channels_last' else 1

        gate_proj = self.conv_blocks[0](input_g, training=training)
        skip_proj = self.conv_blocks[1](input_x, training=training)

        fused = tfkl.concatenate([gate_proj, skip_proj], axis=channel_axis)
        fused = tfkl.Activation('relu')(fused)
        fused = self.conv_blocks[2](fused, training=training)

        # Sigmoid produces per-element attention coefficients in (0, 1).
        alpha = tfkl.Activation('sigmoid')(fused)

        return tf.math.multiply(alpha, input_x)
179
180
181
class Recurrent_Block(tf.keras.Model):
    """Recurrent convolution block (R2U-Net style): a single shared conv
    block is applied ``t`` times, each time on the sum of the original
    input and the previous activation.

    Args:
        num_channels: Filters of the shared convolution.
        use_2d: If True use a 2D convolution, otherwise 3D.
        kernel_size: Kernel size of the shared convolution.
        nonlinearity: Activation used inside the conv block.
        padding: Unused here; accepted for signature compatibility.
        strides: Unused here; accepted for signature compatibility.
        t: Number of recurrent applications of the shared conv block.
        use_batchnorm: Batch normalization flag for the conv block.
        data_format: 'channels_last' or 'channels_first'.
    """

    def __init__(self,
                 num_channels,
                 use_2d=True,
                 kernel_size=3,
                 nonlinearity='relu',
                 padding='same',
                 strides=1,
                 t=2,
                 use_batchnorm=True,
                 data_format='channels_last',
                 **kwargs):

        super(Recurrent_Block, self).__init__(**kwargs)

        # BUG FIX: ``self.t`` was read in call() but never assigned,
        # raising AttributeError on the first forward pass.
        self.t = t

        # One conv block whose weights are shared across the t iterations.
        self.conv = Conv_Block(num_channels=num_channels,
                               use_2d=use_2d,
                               num_conv_layers=1,
                               kernel_size=kernel_size,
                               nonlinearity=nonlinearity,
                               use_batchnorm=use_batchnorm,
                               data_format=data_format)

    def call(self, x, training=False):
        """Apply the shared conv block recurrently ``t`` times."""
        for i in range(self.t):

            if i == 0:
                # First pass initializes the recurrent state from the input.
                x1 = self.conv(x, training=training)

            # Residual-style combination of input and current state,
            # followed by the shared convolution.
            x1 = tfkl.Add()([x, x1])
            x1 = self.conv(x1, training=training)

        return x1
216
217
218
class Recurrent_ResConv_block(tf.keras.Model):
    """Recurrent residual convolution block (R2U-Net): a 1x1 channel
    projection followed by two stacked Recurrent_Blocks, with a residual
    connection from the projection output to the recurrent output.

    Args:
        num_channels: Output channels of the 1x1 projection and the
            recurrent blocks.
        use_2d: If True use 2D layers, otherwise 3D.
        kernel_size: Kernel size inside the recurrent blocks.
        nonlinearity: Activation used inside the recurrent blocks.
        padding: Padding for the 1x1 projection convolution.
        strides: Strides for the 1x1 projection convolution.
        t: Recurrence count forwarded to each Recurrent_Block.
        use_batchnorm: Batch normalization flag for the recurrent blocks.
        data_format: 'channels_last' or 'channels_first'.
    """

    def __init__(self,
                 num_channels,
                 use_2d=True,
                 kernel_size=3,
                 nonlinearity='relu',
                 padding='same',
                 strides=1,
                 t=2,
                 use_batchnorm=True,
                 data_format='channels_last',
                 **kwargs):

        super(Recurrent_ResConv_block, self).__init__(**kwargs)

        # Keyword arguments instead of the original fragile positional
        # list, so a signature reorder in Recurrent_Block cannot silently
        # scramble the configuration.
        self.Recurrent_CNN = tf.keras.Sequential([
            Recurrent_Block(num_channels,
                            use_2d=use_2d,
                            kernel_size=kernel_size,
                            nonlinearity=nonlinearity,
                            padding=padding,
                            strides=strides,
                            t=t,
                            use_batchnorm=use_batchnorm,
                            data_format=data_format),
            Recurrent_Block(num_channels,
                            use_2d=use_2d,
                            kernel_size=kernel_size,
                            nonlinearity=nonlinearity,
                            padding=padding,
                            strides=strides,
                            t=t,
                            use_batchnorm=use_batchnorm,
                            data_format=data_format)])

        # 1x1 (or 1x1x1) convolution matching channel counts for the
        # residual addition.
        if use_2d:
            self.Conv_1x1 = tf.keras.layers.Conv2D(num_channels,
                                                   kernel_size=(1, 1),
                                                   strides=strides,
                                                   padding=padding,
                                                   data_format=data_format)
        else:
            self.Conv_1x1 = tf.keras.layers.Conv3D(num_channels,
                                                   kernel_size=(1, 1, 1),
                                                   strides=strides,
                                                   padding=padding,
                                                   data_format=data_format)

    def call(self, x, training=False):
        """Project, run the recurrent stack, and add the residual.

        BUG FIX: the original call() had no ``training`` parameter, so the
        training flag never reached the BatchNormalization layers inside
        the recurrent blocks (they always ran with inference statistics).
        The new parameter defaults to False, keeping existing callers
        working unchanged.
        """
        x = self.Conv_1x1(x)
        x1 = self.Recurrent_CNN(x, training=training)
        output = tfkl.Add()([x, x1])

        return output