# Segmentation/model/unet.py
import tensorflow as tf
import tensorflow.keras.layers as tfkl
from Segmentation.model.unet_build_blocks import Conv_Block, Up_Conv
from Segmentation.model.unet_build_blocks import Attention_Gate
from Segmentation.model.unet_build_blocks import Recurrent_ResConv_block
from Segmentation.model.backbone import Encoder

class UNet(tf.keras.Model):
    """ TensorFlow 2 implementation of 'U-Net: Convolutional Networks for
    Biomedical Image Segmentation', https://arxiv.org/abs/1505.04597."""

    def __init__(self,
                 num_channels,
                 num_classes,
                 use_2d=True,
                 backbone_name='default',
                 num_conv_layers=2,
                 kernel_size=3,
                 nonlinearity='relu',
                 use_attention=False,
                 use_batchnorm=True,
                 use_bias=True,
                 use_dropout=False,
                 dropout_rate=0.25,
                 use_spatial_dropout=True,
                 data_format='channels_last',
                 **kwargs):

        super(UNet, self).__init__(**kwargs)

        self.backbone_name = backbone_name
        self.contracting_path = []
        self.upsampling_path = []

        if self.backbone_name == 'default':
            for i in range(len(num_channels)):
                output = num_channels[i]
                self.contracting_path.append(Conv_Block(num_channels=output,
                                                        use_2d=use_2d,
                                                        num_conv_layers=num_conv_layers,
                                                        kernel_size=kernel_size,
                                                        nonlinearity=nonlinearity,
                                                        use_batchnorm=use_batchnorm,
                                                        use_bias=use_bias,
                                                        use_dropout=use_dropout,
                                                        dropout_rate=dropout_rate,
                                                        use_spatial_dropout=use_spatial_dropout,
                                                        data_format=data_format))
                if i != len(num_channels) - 1:
                    if use_2d:
                        self.contracting_path.append(tfkl.MaxPooling2D())
                    else:
                        self.contracting_path.append(tfkl.MaxPooling3D())
        else:
            assert use_2d is True
            encoder = Encoder(weights_init='imagenet', model_architecture=backbone_name)
            encoder.freeze_pretrained_layers()
            self.backbone = encoder.construct_backbone()

        n = len(num_channels) - 2
        for i in range(n, -1, -1):
            output = num_channels[i]
            self.upsampling_path.append(Up_Conv(output,
                                                use_2d=use_2d,
                                                kernel_size=2,
                                                nonlinearity=nonlinearity,
                                                use_attention=use_attention,
                                                use_batchnorm=use_batchnorm,
                                                use_transpose=False,
                                                use_bias=use_bias,
                                                strides=2,
                                                data_format=data_format))

        if use_2d:
            self.conv_1x1 = tfkl.Conv2D(num_classes,
                                        (1, 1),
                                        activation='sigmoid' if num_classes == 1 else 'softmax',
                                        padding='same',
                                        data_format=data_format)
        else:
            # NOTE: the 3D head uses a linear activation for the single-class
            # case, unlike the sigmoid used by the 2D head above
            self.conv_1x1 = tfkl.Conv3D(num_classes,
                                        (1, 1, 1),
                                        activation='linear' if num_classes == 1 else 'softmax',
                                        padding='same',
                                        data_format=data_format)

    def call(self, x, training=False):
        blocks = []
        if self.backbone_name == 'default':
            for i, down in enumerate(self.contracting_path):
                x = down(x, training=training)
                if i != len(self.contracting_path) - 1:
                    blocks.append(x)
        else:
            bridge_1, bridge_2, bridge_3, bridge_4, x = self.backbone(x, training=training)
            blocks.extend([bridge_1, bridge_2, bridge_3, bridge_4])

        for j, up in enumerate(self.upsampling_path):
            if self.backbone_name in ['default']:
                # blocks alternates conv outputs and pooled outputs, so the skip
                # connection at the matching resolution sits at index -2 * j - 2
                x = up(x, blocks[-2 * j - 2], training=training)
            else:
                x = up(x, blocks[-j - 1], training=training)

        del blocks

        if self.backbone_name not in ['default', 'vgg16', 'vgg19']:
            # backbones other than the default and VGG get one extra x2 upsampling here
            x = tfkl.UpSampling2D()(x)

        output = self.conv_1x1(x)
        return output

class R2_UNet(tf.keras.Model):
    """ TensorFlow 2 implementation of 'Recurrent Residual Convolutional
    Neural Network based on U-Net (R2U-Net) for Medical Image Segmentation',
    https://arxiv.org/ftp/arxiv/papers/1802/1802.06955.pdf."""

    def __init__(self,
                 num_channels,
                 num_classes,
                 use_2d=True,
                 num_conv_layers=2,
                 kernel_size=3,
                 nonlinearity='relu',
                 t=2,
                 use_attention=False,
                 use_batchnorm=True,
                 use_bias=True,
                 data_format='channels_last',
                 **kwargs):

        super(R2_UNet, self).__init__(**kwargs)

        self.contracting_path = []
        self.upsampling_path = []

        for i in range(len(num_channels)):
            output = num_channels[i]
            self.contracting_path.append(Recurrent_ResConv_block(num_channels=output,
                                                                 use_2d=use_2d,
                                                                 kernel_size=kernel_size,
                                                                 nonlinearity=nonlinearity,
                                                                 padding='same',
                                                                 strides=1,
                                                                 t=t,
                                                                 use_batchnorm=use_batchnorm,
                                                                 data_format=data_format))
            if i != len(num_channels) - 1:
                if use_2d:
                    self.contracting_path.append(tfkl.MaxPooling2D())
                else:
                    self.contracting_path.append(tfkl.MaxPooling3D())

        n = len(num_channels) - 2
        for i in range(n, -1, -1):
            output = num_channels[i]
            up_conv = Up_Conv(output,
                              use_2d,
                              kernel_size=2,
                              nonlinearity=nonlinearity,
                              use_attention=use_attention,
                              use_batchnorm=use_batchnorm,
                              use_transpose=False,
                              use_bias=use_bias,
                              strides=2,
                              data_format=data_format)

            # override default conv block with recurrent-residual conv block
            up_conv.conv_block = Recurrent_ResConv_block(num_channels=output,
                                                         use_2d=use_2d,
                                                         kernel_size=kernel_size,
                                                         nonlinearity=nonlinearity,
                                                         padding='same',
                                                         strides=1,
                                                         t=t,
                                                         use_batchnorm=use_batchnorm,
                                                         data_format=data_format)

            self.upsampling_path.append(up_conv)

        if use_2d:
            self.conv_1x1 = tfkl.Conv2D(filters=num_classes,
                                        kernel_size=(1, 1),
                                        activation='sigmoid' if num_classes == 1 else 'softmax',
                                        padding='same',
                                        data_format=data_format)
        else:
            self.conv_1x1 = tfkl.Conv3D(filters=num_classes,
                                        kernel_size=(1, 1, 1),
                                        activation='sigmoid' if num_classes == 1 else 'softmax',
                                        padding='same',
                                        data_format=data_format)

    def call(self, x, training=False):
        blocks = []
        for i, down in enumerate(self.contracting_path):
            x = down(x, training=training)
            if i != len(self.contracting_path) - 1:
                blocks.append(x)

        for j, up in enumerate(self.upsampling_path):
            x = up(x, blocks[-2 * j - 2], training=training)

        del blocks

        output = self.conv_1x1(x)

        return output

class Nested_UNet(tf.keras.Model):
    """ TensorFlow 2 implementation of 'UNet++: A Nested
    U-Net Architecture for Medical Image Segmentation',
    https://arxiv.org/pdf/1807.10165.pdf """

    def __init__(self,
                 num_channels,
                 num_classes,
                 use_2d=True,
                 num_conv_layers=2,
                 kernel_size=(3, 3),
                 nonlinearity='relu',
                 use_batchnorm=True,
                 use_bias=True,
                 data_format='channels_last',
                 **kwargs):

        super(Nested_UNet, self).__init__(**kwargs)

        self.conv_block_lists = []
        self.pool = tfkl.MaxPooling2D() if use_2d else tfkl.MaxPooling3D()
        self.up = tfkl.UpSampling2D() if use_2d else tfkl.UpSampling3D()

        for i in range(len(num_channels)):
            output_ch = num_channels[i]
            conv_layer_lists = []
            num_conv_blocks = len(num_channels) - i

            for _ in range(num_conv_blocks):
                conv_layer_lists.append(Conv_Block(num_channels=output_ch,
                                                   use_2d=use_2d,
                                                   num_conv_layers=num_conv_layers,
                                                   kernel_size=kernel_size,
                                                   nonlinearity=nonlinearity,
                                                   use_batchnorm=use_batchnorm,
                                                   use_bias=use_bias,
                                                   data_format=data_format))

            self.conv_block_lists.append(conv_layer_lists)

        if use_2d:
            self.conv_1x1 = tfkl.Conv2D(num_classes,
                                        (1, 1),
                                        activation='sigmoid' if num_classes == 1 else 'softmax',
                                        padding='same',
                                        data_format=data_format)
        else:
            self.conv_1x1 = tfkl.Conv3D(num_classes,
                                        (1, 1, 1),
                                        activation='sigmoid' if num_classes == 1 else 'softmax',
                                        padding='same',
                                        data_format=data_format)

    def call(self, input, training=False):

        # UNet++ nodes X^(depth, column) are built one anti-diagonal at a time;
        # block_list[s][k] holds the node at depth s - k, column k
        block_list = []
        x = self.conv_block_lists[0][0](input, training=training)
        block_list.append([x])
        for sum_idx in range(1, len(self.conv_block_lists)):
            left_idx = sum_idx
            right_idx = 0
            layer_list = []
            while right_idx <= sum_idx:
                if left_idx == sum_idx:
                    # encoder node: downsample the node one level above
                    x = self.conv_block_lists[left_idx][right_idx](self.pool(block_list[left_idx - 1][right_idx]),
                                                                   training=training)
                else:
                    # nested node: concatenate the upsampled deeper node with all
                    # same-depth nodes from earlier columns (dense skip connections)
                    concat_list = [self.up(x)]
                    for idx in range(1, right_idx + 1):
                        concat_list.append(block_list[left_idx + idx - 1][-1 + idx])
                    x = self.conv_block_lists[left_idx][right_idx](tfkl.concatenate(concat_list),
                                                                   training=training)
                left_idx -= 1
                right_idx += 1
                layer_list.append(x)
            block_list.append(layer_list)
        output = self.conv_1x1(x)

        return output

class Nested_UNet_v2(tf.keras.Model):

    def __init__(self,
                 num_channels,
                 num_classes,
                 use_2d=True,
                 num_conv_layers=2,
                 kernel_size=(3, 3),
                 nonlinearity='relu',
                 use_batchnorm=True,
                 use_bias=True,
                 data_format='channels_last',
                 **kwargs):

        super(Nested_UNet_v2, self).__init__(**kwargs)

        self.conv_block_lists = []
        self.pool = tfkl.MaxPooling2D() if use_2d else tfkl.MaxPooling3D()
        self.up = tfkl.UpSampling2D() if use_2d else tfkl.UpSampling3D()

        for i in range(len(num_channels)):
            output_ch = num_channels[i]
            conv_layer_lists = []
            num_conv_blocks = len(num_channels) - i

            for _ in range(num_conv_blocks):
                conv_layer_lists.append(Conv_Block(num_channels=output_ch,
                                                   use_2d=use_2d,
                                                   num_conv_layers=num_conv_layers,
                                                   kernel_size=kernel_size,
                                                   nonlinearity=nonlinearity,
                                                   use_batchnorm=use_batchnorm,
                                                   use_bias=use_bias,
                                                   data_format=data_format))

            self.conv_block_lists.append(conv_layer_lists)

        if use_2d:
            self.conv_1x1 = tfkl.Conv2D(num_classes,
                                        (1, 1),
                                        activation='sigmoid' if num_classes == 1 else 'softmax',
                                        padding='same',
                                        data_format=data_format)
        else:
            self.conv_1x1 = tfkl.Conv3D(num_classes,
                                        (1, 1, 1),
                                        activation='sigmoid' if num_classes == 1 else 'softmax',
                                        padding='same',
                                        data_format=data_format)

    def call(self, input, training=False):

        # nodes are stored in a dict keyed by 'depth_column' (e.g. '2_1') and
        # built one anti-diagonal (depth + column = sum_idx) at a time
        x = dict()
        use_x = list()
        x['0_0'] = self.conv_block_lists[0][0](input, training=training)
        last_0_name = '0_0'
        last_name = last_0_name

        for sum_idx in range(1, len(self.conv_block_lists)):
            i, j = sum_idx, 0
            while j <= sum_idx:

                name = str(i) + '_' + str(j)

                if i == sum_idx:
                    # encoder node: downsample the previous encoder node
                    x[name] = self.conv_block_lists[i][j](self.pool(x[last_0_name]), training=training)
                    last_0_name = name

                else:
                    # nested node: gather same-depth nodes from earlier columns ...
                    for temp_right in range(0, j):
                        string = str(i) + '_' + str(temp_right)
                        use_x.append(x[string])

                    # ... plus the upsampled node computed just before on this anti-diagonal
                    use_x.append(self.up(x[last_name]))
                    x[name] = self.conv_block_lists[i][j](tfkl.concatenate(use_x), training=training)

                use_x.clear()
                last_name = name
                i = i - 1
                j = j + 1

        output = self.conv_1x1(x[last_name])

        return output
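

# Illustrative smoke test (not part of the original module): builds each model on a
# dummy batch and prints the output shape. This is only a sketch; it assumes that the
# Conv_Block / Up_Conv / Recurrent_ResConv_block building blocks accept the arguments
# used above and that the spatial size divides evenly by the number of pooling steps.
# Adjust the shape and channel progression for your own data.
if __name__ == '__main__':
    dummy = tf.random.uniform((1, 96, 96, 1))  # (batch, height, width, channels)
    channels = (32, 64, 128, 256)

    for model_cls in (UNet, R2_UNet, Nested_UNet, Nested_UNet_v2):
        model = model_cls(num_channels=channels, num_classes=1)
        out = model(dummy, training=False)
        print(model_cls.__name__, out.shape)  # expected: (1, 96, 96, 1)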