|
a |
|
b/Segmentation/model/deeplabv3.py |
|
|
1 |
import tensorflow as tf |
|
|
2 |
import tensorflow.keras.layers as tfkl |
|
|
3 |
|
|
|
4 |
class Deeplabv3_plus(tf.keras.Model):
    """DeepLabv3+ encoder-decoder model for semantic segmentation.

    Encoder: an initial strided convolution plus three resnet_blocks
    (the backbone), an atrous-convolution stage, an ASPP module and a
    final 1x1 aspp_block projection.
    Decoder: fuses the atrous output, the encoder output and a low-level
    backbone feature map, then a 1x1 convolution maps to ``num_classes``
    channels followed by sigmoid (num_classes == 1) or softmax.
    """

    def __init__(self,
                 num_classes,
                 kernel_size_initial_conv,
                 num_channels_atrous=512,
                 num_channels_DCNN=[256, 512, 1024],
                 num_channels_ASPP=256,
                 kernel_size_atrous=3,
                 kernel_size_DCNN=[1, 3],
                 kernel_size_ASPP=[1, 3, 3, 3],
                 num_filters_final_encoder=512,
                 num_channels_from_backcone=[128, 96],
                 num_channels_UpConv=[512, 256, 128],
                 kernel_size_UpConv=3,
                 stride_UpConv=(2, 2),
                 use_batchnorm_UpConv=False,
                 use_transpose_UpConv=False,
                 padding='same',
                 nonlinearity='relu',
                 use_batchnorm=True,
                 use_bias=True,
                 data_format='channels_last',
                 MultiGrid=[1, 2, 4],
                 rate_ASPP=[1, 6, 12, 18],
                 atrous_output_stride=16,
                 # Not adapted code for any other out stride
                 **kwargs):

        """ Arguments:
                num_classes: number of segmentation classes / output channels
                kernel_size_initial_conv: the size of the kernel for the
                                          first convolution
                num_channels_DCNN: tuple with the number of channels for the
                                   first three blocks of the DCNN
                kernel_size_DCNN: two element tuple with the kernel size of the
                                  first and last convolution of the resnet_block
                                  (first element) and the middle convolution
                                  of the resnet_block (second element) """

        super(Deeplabv3_plus, self).__init__(**kwargs)

        self.num_classes = num_classes

        # ResNet backbone
        self.first_conv = tfkl.Conv2D(num_channels_DCNN[0],
                                      kernel_size_initial_conv,
                                      strides=2,
                                      padding=padding,
                                      use_bias=use_bias,
                                      data_format=data_format)

        self.block1 = resnet_block(False,
                                   num_channels_DCNN[0],
                                   kernel_size_DCNN,
                                   padding,
                                   nonlinearity,
                                   use_batchnorm,
                                   use_bias,
                                   data_format)

        self.block2 = resnet_block(True,
                                   num_channels_DCNN[1],
                                   kernel_size_DCNN,
                                   padding,
                                   nonlinearity,
                                   use_batchnorm,
                                   use_bias,
                                   data_format)

        self.block3 = resnet_block(True,
                                   num_channels_DCNN[2],
                                   kernel_size_DCNN,
                                   padding,
                                   nonlinearity,
                                   use_batchnorm,
                                   use_bias,
                                   data_format)

        # Atrous components
        self.atrous_conv = Atrous_conv(num_channels_atrous,
                                       kernel_size_atrous,
                                       MultiGrid,
                                       padding,
                                       use_batchnorm,
                                       'linear',
                                       use_bias,
                                       data_format,
                                       atrous_output_stride)

        self.aspp_term = atrous_spatial_pyramid_pooling(num_channels_ASPP,
                                                        kernel_size_ASPP,
                                                        rate_ASPP,
                                                        padding,
                                                        use_batchnorm,
                                                        'linear',
                                                        use_bias,
                                                        data_format)

        # Final convolution of encoder
        self.final_encoder_conv = aspp_block(1,
                                             1,
                                             num_filters_final_encoder,
                                             padding,
                                             use_batchnorm,
                                             'linear',
                                             use_bias,
                                             data_format)

        # Decoder
        self.decoder_term = Decoder(num_channels_from_backcone=num_channels_from_backcone,
                                    num_channels_UpConv=num_channels_UpConv,
                                    kernel_size_UpConv=kernel_size_UpConv,
                                    stride_UpConv=stride_UpConv,
                                    use_batchnorm_UpConv=use_batchnorm_UpConv,
                                    use_transpose_UpConv=use_transpose_UpConv,
                                    use_bias=use_bias,
                                    padding=padding,
                                    data_format=data_format)

        self.output_conv = tfkl.Conv2D(num_classes,
                                       1,
                                       activation='linear',
                                       padding='same',
                                       data_format=data_format)

        # Fix: build the output activation once here. Previously a new
        # tfkl.Activation layer was constructed inside call() on every
        # forward pass, allocating an untracked layer per invocation.
        self.final_activation = tfkl.Activation(
            'sigmoid' if num_classes == 1 else 'softmax')

    def call(self, x, training=False):

        ###Encoder
        before_final_stride = self.first_conv(x, training=training)  # output stride 2

        before_final_stride = self.block1(before_final_stride, training=training)  # output stride 2
        before_final_stride = self.block2(before_final_stride, training=training)  # output stride 4
        atrous_out = self.block3(before_final_stride, training=training)  # output stride 8

        atrous_out = self.atrous_conv(atrous_out, training=training)
        out = self.aspp_term(atrous_out, training=training)
        out = self.final_encoder_conv(out, training=training)

        ###Decoder
        # Fuse atrous features, encoder output and the stride-4 backbone map.
        out = self.decoder_term(atrous_out, out, before_final_stride, training=training)

        out = self.output_conv(out, training=training)
        out = self.final_activation(out)

        # NOTE(review): the output is not resized back to the exact input
        # resolution here; the decoder's upsampling is relied upon instead.
        return out
|
|
156 |
|
|
|
157 |
|
|
|
158 |
class Deeplabv3(tf.keras.Sequential):
    """TensorFlow 2 implementation of DeepLabv3 (encoder-only variant).

    A sequential stack of: ResNet backbone, atrous convolution stage,
    ASPP module and a final 1x1 aspp_block projecting to ``num_classes``
    channels. The forward pass applies sigmoid (num_classes == 1) or
    softmax and bilinearly resizes the logits to the input resolution.
    """

    def __init__(self,
                 num_classes,
                 kernel_size_initial_conv,
                 num_channels_atrous,
                 num_channels_DCNN=[256, 512, 1024],
                 num_channels_ASPP=256,
                 kernel_size_atrous=3,
                 kernel_size_DCNN=[1, 3],
                 kernel_size_ASPP=[1, 3, 3, 3],
                 padding='same',
                 nonlinearity='relu',
                 use_batchnorm=True,
                 use_bias=True,
                 data_format='channels_last',
                 MultiGrid=[1, 2, 4],
                 rate_ASPP=[1, 6, 12, 18],
                 atrous_output_stride=16,
                 # Not adapted code for any other out stride
                 **kwargs):

        """ Arguments:
                num_classes: number of segmentation classes / output channels
                kernel_size_initial_conv: the size of the kernel for the
                                          first convolution
                num_channels_DCNN: tuple with the number of channels for the
                                   first three blocks of the DCNN
                kernel_size_DCNN: two element tuple with the kernel size of the
                                  first and last convolution of the resnet_block
                                  (first element) and the middle convolution
                                  of the resnet_block (second element) """

        super(Deeplabv3, self).__init__(**kwargs)

        self.num_classes = num_classes

        self.add(ResNet_Backbone(kernel_size_initial_conv,
                                 num_channels_DCNN,
                                 kernel_size_DCNN,
                                 padding,
                                 nonlinearity,
                                 use_batchnorm,
                                 use_bias,
                                 False,
                                 data_format))

        self.add(Atrous_conv(num_channels_atrous,
                             kernel_size_atrous,
                             MultiGrid,
                             padding,
                             use_batchnorm,
                             'linear',
                             use_bias,
                             data_format,
                             atrous_output_stride))

        self.add(atrous_spatial_pyramid_pooling(num_channels_ASPP,
                                                kernel_size_ASPP,
                                                rate_ASPP,
                                                padding,
                                                use_batchnorm,
                                                'linear',
                                                use_bias,
                                                data_format))

        self.add(aspp_block(1,
                            1,
                            num_classes,
                            padding,
                            use_batchnorm,
                            'linear',
                            use_bias,
                            data_format))

        # Fix: build the output activation once here. Previously a new
        # tfkl.Activation layer was constructed inside call() on every
        # forward pass, allocating an untracked layer per invocation.
        # Kept as an attribute (not add()-ed) so it runs after the stack.
        self.final_activation = tfkl.Activation(
            'sigmoid' if num_classes == 1 else 'softmax')

    def call(self, x, training=False):

        out = super(Deeplabv3, self).call(x, training=training)
        out = self.final_activation(out)

        # Upsample to the same spatial size as the input.
        input_size = tf.shape(x)[1:3]
        out = tf.image.resize(out, input_size)

        return out
|
|
245 |
|
|
|
246 |
class ResNet_Backbone(tf.keras.Model):
    """Feature extractor: initial strided convolution + three residual stages.

    The first convolution halves the resolution; optional max pooling halves
    it again; blocks 2 and 3 are strided, giving an overall output stride of
    8 (or 16 with pooling enabled).
    """

    def __init__(self,
                 kernel_size_initial_conv,
                 num_channels=[256, 512, 1024],
                 kernel_size_blocks=[1, 3],
                 padding='same',
                 nonlinearity='relu',
                 use_batchnorm=True,
                 use_bias=True,
                 use_pooling=False,
                 data_format='channels_last',
                 **kwargs):

        super(ResNet_Backbone, self).__init__(**kwargs)

        self.first_conv = tfkl.Conv2D(num_channels[0],
                                      kernel_size_initial_conv,
                                      strides=2,
                                      padding=padding,
                                      use_bias=use_bias,
                                      data_format=data_format)

        self.max_pool = tfkl.MaxPool2D(pool_size=(2, 2),
                                       padding='valid')

        self.use_pooling = use_pooling

        # Three residual stages; all but the first downsample by 2.
        for index, (strided, width) in enumerate(
                zip((False, True, True), num_channels), start=1):
            setattr(self, 'block{}'.format(index),
                    resnet_block(strided,
                                 width,
                                 kernel_size_blocks,
                                 padding,
                                 nonlinearity,
                                 use_batchnorm,
                                 use_bias,
                                 data_format))

    def call(self, x, training=False):

        features = self.first_conv(x, training=training)  # output stride 2
        if self.use_pooling:
            features = self.max_pool(features)  # output stride 4

        # Stride doubles at each strided stage: 2/4 -> 4/8 -> 8/16.
        for stage in (self.block1, self.block2, self.block3):
            features = stage(features, training=training)
        return features
|
|
308 |
|
|
|
309 |
|
|
|
310 |
# full pre-activation residual unit |
|
|
311 |
class resnet_block(tf.keras.Model):
    """Bottleneck residual unit (full pre-activation ordering).

    The residual path is a 1x1 -> kxk -> 1x1 bottleneck whose inner convs
    run at a quarter of the output width. When ``use_stride`` is True the
    unit downsamples by 2 and the shortcut is projected with a strided
    1x1 conv so that spatial size and channel count match for the add.
    """

    def __init__(self,
                 use_stride,
                 num_channels,
                 kernel_size=[1, 3],
                 padding='same',
                 nonlinearity='relu',
                 use_batchnorm=True,
                 use_bias=True,
                 data_format='channels_last',
                 **kwargs):

        super(resnet_block, self).__init__(**kwargs)
        self.use_stride = use_stride
        # Bottleneck: inner convolutions use a quarter of the output width.
        inner_num_channels = num_channels // 4

        if use_stride:
            # Strided 1x1 projection so the shortcut matches the residual
            # path in both spatial size and channel count.
            self.input_conv = basic_conv_block(num_channels,
                                               1,
                                               2,
                                               padding,
                                               nonlinearity,
                                               use_batchnorm,
                                               use_bias,
                                               data_format)
            stride = 2

        else:
            stride = 1

        self.first_conv = basic_conv_block(inner_num_channels,
                                           kernel_size[0],
                                           stride,
                                           padding,
                                           nonlinearity,
                                           use_batchnorm,
                                           use_bias,
                                           data_format)

        self.second_conv = basic_conv_block(inner_num_channels,
                                            kernel_size[1],
                                            1,
                                            padding,
                                            nonlinearity,
                                            use_batchnorm,
                                            use_bias,
                                            data_format)

        self.third_conv = basic_conv_block(num_channels,
                                           kernel_size[0],
                                           1,
                                           padding,
                                           nonlinearity,
                                           use_batchnorm,
                                           use_bias,
                                           data_format)

        # Fix: build the Add layer once here. It was previously constructed
        # inside call() on every forward pass, creating an untracked layer
        # per invocation.
        self.residual_add = tfkl.Add()

    def call(self, x, training=False):

        residual = self.first_conv(x, training=training)

        if self.use_stride:
            # Project the shortcut when the residual path downsamples.
            x = self.input_conv(x, training=training)

        residual = self.second_conv(residual, training=training)
        residual = self.third_conv(residual, training=training)

        output = self.residual_add([residual, x])
        return output
|
|
381 |
|
|
|
382 |
|
|
|
383 |
class basic_conv_block(tf.keras.Sequential):
    """Pre-activation conv unit: optional BatchNorm -> activation -> Conv2D.

    Arguments:
        num_channels: number of output filters of the convolution.
        kernel_size: convolution kernel size.
        stride: convolution stride.
        rate: dilation rate of the convolution.
    """

    def __init__(self,
                 num_channels,
                 kernel_size,
                 stride=1,
                 padding='same',
                 nonlinearity='relu',
                 use_batchnorm=True,
                 use_bias=True,
                 data_format='channels_last',
                 rate=1,
                 **kwargs):

        super(basic_conv_block, self).__init__(**kwargs)

        if use_batchnorm:
            # Fix: normalize over the channel axis implied by data_format.
            # The axis was hardcoded to -1, which is wrong for
            # 'channels_first' inputs.
            self.add(tfkl.BatchNormalization(
                axis=1 if data_format == 'channels_first' else -1,
                momentum=0.95,
                epsilon=0.001))
        self.add(tfkl.Activation(nonlinearity))

        self.add(tfkl.Conv2D(num_channels,
                             kernel_size,
                             strides=stride,
                             padding=padding,
                             use_bias=use_bias,
                             data_format=data_format,
                             dilation_rate=rate))

    def call(self, x, training=False):
        """Run the stacked layers, forwarding the training flag."""
        output = super(basic_conv_block, self).call(x, training=training)
        return output
|
|
417 |
|
|
|
418 |
# ####################### Atrous Convolution ####################### # |
|
|
419 |
class Atrous_conv(tf.keras.Model):
    """Cascade of three dilated convolutions with MultiGrid rates.

    Each convolution uses dilation rate ``multiplier * MultiGrid[i]``,
    where the multiplier is 2 for output_stride == 16 and 1 otherwise
    (the code is not adapted for any other output stride).
    """

    def __init__(self,
                 num_channels,
                 kernel_size=3,
                 MultiGrid=[1, 2, 4],
                 padding='same',
                 use_batchnorm=True,
                 nonlinearity='linear',
                 use_bias=True,
                 data_format='channels_last',
                 output_stride=16,
                 **kwargs):

        super(Atrous_conv, self).__init__(**kwargs)

        # Double the dilation rates at output stride 16 to keep the
        # effective field of view.
        if output_stride == 16:
            multiplier = 2
        else:
            multiplier = 1

        self.first_conv = basic_conv_block(num_channels,
                                           kernel_size,
                                           1,
                                           padding,
                                           nonlinearity,
                                           use_batchnorm,
                                           use_bias,
                                           data_format,
                                           rate=int(multiplier * MultiGrid[0]))

        self.second_conv = basic_conv_block(num_channels,
                                            kernel_size,
                                            1,
                                            padding,
                                            nonlinearity,
                                            use_batchnorm,
                                            use_bias,
                                            data_format,
                                            rate=int(multiplier * MultiGrid[1]))

        self.third_conv = basic_conv_block(num_channels,
                                           kernel_size,
                                           1,
                                           padding,
                                           nonlinearity,
                                           use_batchnorm,
                                           use_bias,
                                           data_format,
                                           rate=int(multiplier * MultiGrid[2]))

    def call(self, x, training=False):

        # Fix: pass the training flag by keyword, consistent with every
        # other call site in this file (it was passed positionally).
        x = self.first_conv(x, training=training)
        x = self.second_conv(x, training=training)
        x = self.third_conv(x, training=training)
        return x
|
|
476 |
|
|
|
477 |
|
|
|
478 |
# ####################### ASPP ####################### # |
|
|
479 |
class atrous_spatial_pyramid_pooling(tf.keras.Model):
    """ASPP: parallel dilated convolutions plus an image-pooling branch.

    The image-pooling branch global-average-pools the input, applies a
    1x1 conv and resizes back to the feature-map resolution. Its output
    is concatenated with the dilated-conv branches and projected by a
    final 1x1 convolution.
    """

    def __init__(self,
                 num_channels=256,
                 kernel_size=[1, 3, 3, 3],
                 rate=[1, 6, 12, 18],
                 padding='same',
                 use_batchnorm=True,
                 nonlinearity='linear',
                 use_bias=True,
                 data_format='channels_last',
                 **kwargs):

        super(atrous_spatial_pyramid_pooling, self).__init__(**kwargs)
        self.block_list = []

        # Fix: use_bias and data_format were accepted but silently ignored
        # by these two 1x1 convolutions; pass them through for consistency
        # with the aspp_block branches.
        self.basic_conv1 = tfkl.Conv2D(num_channels,
                                       kernel_size=1,
                                       padding=padding,
                                       use_bias=use_bias,
                                       data_format=data_format)

        self.basic_conv2 = tfkl.Conv2D(num_channels,
                                       kernel_size=1,
                                       padding=padding,
                                       use_bias=use_bias,
                                       data_format=data_format)

        # One aspp_block per (kernel size, dilation rate) pair.
        for branch_kernel, branch_rate in zip(kernel_size, rate):
            self.block_list.append(aspp_block(branch_kernel,
                                              branch_rate,
                                              num_channels,
                                              padding,
                                              use_batchnorm,
                                              nonlinearity,
                                              use_bias,
                                              data_format))

    def call(self, x, training=False):

        feature_map_size = tf.shape(x)
        output_list = []

        # Image-level pooling branch: global average pool -> 1x1 conv ->
        # resize back to the feature-map resolution.
        y = tf.math.reduce_mean(x, axis=[1, 2], keepdims=True)  # ~ Average Pooling
        y = self.basic_conv1(y, training=training)
        output_list.append(tf.image.resize(y, (feature_map_size[1], feature_map_size[2])))  # ~ Upsampling

        # Series of dilated convolutions with rates e.g. (1, 6, 12, 18).
        for block in self.block_list:
            output_list.append(block(x, training=training))

        # Concatenate all branch outputs and project with a 1x1 conv.
        out = tf.concat(output_list, axis=3)
        out = self.basic_conv2(out, training=training)
        return out
|
|
531 |
|
|
|
532 |
|
|
|
533 |
class aspp_block(tf.keras.Sequential):
    """Conv2D (optionally dilated) -> optional BatchNorm -> activation.

    Arguments:
        kernel_size: convolution kernel size.
        rate: dilation rate of the convolution.
        num_channels: number of output filters.
    """

    def __init__(self,
                 kernel_size,
                 rate,
                 num_channels=256,
                 padding='same',
                 use_batchnorm=True,
                 nonlinearity='linear',
                 use_bias=True,
                 data_format='channels_last',
                 **kwargs):

        super(aspp_block, self).__init__(**kwargs)

        self.add(tfkl.Conv2D(num_channels,
                             kernel_size,
                             padding=padding,
                             use_bias=use_bias,
                             data_format=data_format,
                             dilation_rate=rate))

        if use_batchnorm:
            # Fix: normalize over the channel axis implied by data_format.
            # The axis was hardcoded to -1, which is wrong for
            # 'channels_first' inputs.
            self.add(tfkl.BatchNormalization(
                axis=1 if data_format == 'channels_first' else -1,
                momentum=0.95,
                epsilon=0.001))

        self.add(tfkl.Activation(nonlinearity))

    def call(self, x, training=False):
        """Run the stacked layers, forwarding the training flag."""
        output = super(aspp_block, self).call(x, training=training)
        return output
|
|
566 |
|
|
|
567 |
# ####################### Decoder ####################### # |
|
|
568 |
class Decoder(tf.keras.Model):
    """DeepLabv3+ decoder head.

    Projects the atrous feature map and a low-level backbone feature map
    through 1x1 convolutions, fuses them with the encoder output by
    concatenation, and upsamples through three Up_Conv2D stages.
    """

    def __init__(self,
                 # Fix: the default used to be [48], but index 1 is read
                 # below (second_conv1x1), so the old default raised
                 # IndexError. Explicit callers are unaffected.
                 num_channels_from_backcone=[48, 48],
                 num_channels_UpConv=[512, 256, 128],
                 kernel_size_UpConv=3,
                 stride_UpConv=(2, 2),
                 use_batchnorm_UpConv=False,
                 use_transpose_UpConv=False,
                 use_bias=True,
                 padding='same',
                 data_format='channels_last',
                 **kwargs):

        super(Decoder, self).__init__(**kwargs)

        # 1x1 projection for the atrous feature map.
        self.first_conv1x1 = tfkl.Conv2D(num_channels_from_backcone[0],
                                         kernel_size=1,
                                         padding=padding,
                                         data_format=data_format)

        # 1x1 projection for the low-level backbone feature map.
        self.second_conv1x1 = tfkl.Conv2D(num_channels_from_backcone[1],
                                          kernel_size=1,
                                          padding=padding,
                                          data_format=data_format)

        self.conv1 = Up_Conv2D(num_channels_conv=num_channels_UpConv[0],
                               kernel_size=kernel_size_UpConv,
                               use_batchnorm=use_batchnorm_UpConv,
                               use_transpose=use_transpose_UpConv,
                               strides=stride_UpConv)

        self.conv2 = Up_Conv2D(num_channels_conv=num_channels_UpConv[1],
                               kernel_size=kernel_size_UpConv,
                               use_batchnorm=use_batchnorm_UpConv,
                               use_transpose=use_transpose_UpConv,
                               strides=stride_UpConv)

        self.conv3 = Up_Conv2D(num_channels_conv=num_channels_UpConv[2],
                               kernel_size=kernel_size_UpConv,
                               use_batchnorm=use_batchnorm_UpConv,
                               use_transpose=use_transpose_UpConv,
                               strides=stride_UpConv)

    def call(self, in_atrous, in_encoder, in_DCNN, training=False):
        """Fuse the three encoder feature maps and upsample.

        in_atrous: output of the atrous convolution stage.
        in_encoder: output of the ASPP / final encoder convolution.
        in_DCNN: low-level feature map from the backbone.
        """
        in_atrous = self.first_conv1x1(in_atrous, training=training)
        in_DCNN = self.second_conv1x1(in_DCNN, training=training)

        out = tf.concat([in_atrous, in_encoder], axis=3)
        out = self.conv1(out, training=training)

        out = tf.concat([in_DCNN, out], axis=3)
        out = self.conv2(out, training=training)
        out = self.conv3(out, training=training)

        return out
|
|
625 |
|
|
|
626 |
|
|
|
627 |
class Up_Conv2D(tf.keras.Sequential):
    """Upsample-then-convolve stage used by the decoder.

    Increases the spatial resolution (by ``strides``) either with a
    transposed convolution or with plain upsampling, then applies an
    aspp_block (conv + optional batchnorm + activation).
    """

    def __init__(self,
                 num_channels_conv,
                 num_channels_UpConv=256,
                 kernel_size=3,
                 nonlinearity='relu',
                 use_batchnorm=False,
                 use_transpose=False,
                 use_bias=True,
                 strides=(2, 2),
                 padding='same',
                 data_format='channels_last',
                 **kwargs):

        super(Up_Conv2D, self).__init__(**kwargs)

        # Choose the upsampling operator first, then stack the conv block.
        if use_transpose:
            upsampler = tfkl.Conv2DTranspose(num_channels_UpConv,
                                             kernel_size,
                                             padding='same',
                                             strides=strides,
                                             data_format=data_format)
        else:
            upsampler = tfkl.UpSampling2D(size=strides)
        self.add(upsampler)

        self.add(aspp_block(kernel_size=kernel_size,
                            rate=1,
                            num_channels=num_channels_conv,
                            padding=padding,
                            use_batchnorm=use_batchnorm,
                            nonlinearity=nonlinearity,
                            use_bias=use_bias,
                            data_format=data_format))

    def call(self, x, training=False):
        """Run the stacked layers, forwarding the training flag."""
        return super(Up_Conv2D, self).call(x, training=training)
|
|
666 |
|
|
|
667 |
|