|
a |
|
b/Segmentation/model/vnet_build_blocks.py |
|
|
1 |
import tensorflow as tf |
|
|
2 |
import tensorflow.keras.layers as tfkl |
|
|
3 |
from Segmentation.model.unet_build_blocks import Conv_Block, Up_Conv |
|
|
4 |
|
|
|
5 |
class Conv_ResBlock(tf.keras.Model):
    """Encoder stage: residual ``Conv_Block`` followed by a strided down-conv.

    The input is passed through ``Conv_Block`` and added back onto itself
    (residual connection). The sum is then downsampled by a strided
    convolution with ``2 * num_channels`` filters and a kernel of 2 along
    every spatial axis, followed by ``res_activation``.

    Args:
        num_channels: Filters used by the inner ``Conv_Block``; the strided
            down-conv uses twice this many.
        use_2d: If True the down-conv is a ``Conv2D``, otherwise a ``Conv3D``.
        num_conv_layers: Number of conv layers, forwarded to ``Conv_Block``.
        kernel_size: Kernel size forwarded to ``Conv_Block`` (the down-conv
            always uses a kernel of 2 per spatial axis).
        strides: Strides of the down-conv.
        res_activation: ``'prelu'`` for a learnable ``PReLU`` layer, otherwise
            any identifier accepted by ``tf.keras.layers.Activation``.
        data_format: ``'channels_last'`` or ``'channels_first'``; applied to
            both the inner ``Conv_Block`` and the strided down-conv.
        **kwargs: Forwarded to both ``tf.keras.Model.__init__`` and
            ``Conv_Block``.

    ``call`` returns a tuple ``(down_x, x)``: the downsampled, activated
    features and the full-resolution residual output (presumably consumed by
    the caller as a skip connection).
    """

    def __init__(self,
                 num_channels,
                 use_2d=False,
                 num_conv_layers=2,
                 kernel_size=3,
                 strides=2,
                 res_activation='relu',
                 data_format='channels_last',
                 **kwargs):

        super(Conv_ResBlock, self).__init__(**kwargs)

        self.num_channels = num_channels
        self.use_2d = use_2d
        self.num_conv_layers = num_conv_layers
        self.kernel_size = kernel_size
        self.strides = strides
        self.res_activation = res_activation
        self.data_format = data_format

        self.conv_block = Conv_Block(num_channels=self.num_channels,
                                     use_2d=self.use_2d,
                                     num_conv_layers=self.num_conv_layers,
                                     kernel_size=self.kernel_size,
                                     data_format=self.data_format,
                                     **kwargs)

        # data_format is passed explicitly so the down-conv agrees with
        # conv_block on which axis holds channels, instead of depending on
        # the global Keras image_data_format() setting.
        if self.use_2d:
            self.conv_stride = tfkl.Conv2D(num_channels * 2,
                                           kernel_size=(2, 2),
                                           strides=strides,
                                           padding='same',
                                           data_format=self.data_format)
        else:
            self.conv_stride = tfkl.Conv3D(num_channels * 2,
                                           kernel_size=(2, 2, 2),
                                           strides=strides,
                                           padding='same',
                                           data_format=self.data_format)

        # The string stored above is replaced by the activation layer itself;
        # call() invokes self.res_activation as a layer.
        if res_activation == 'prelu':
            self.res_activation = tfkl.PReLU()
        else:
            self.res_activation = tfkl.Activation(res_activation)

    def call(self, inputs, training=None):
        """Apply the residual conv block, then downsample.

        Args:
            inputs: Feature tensor laid out according to ``data_format``.
            training: Forwarded to ``conv_block``; ``None`` lets Keras resolve
                the training phase.

        Returns:
            ``(down_x, x)`` — the activated strided-conv output and the
            residual sum at the input resolution.
        """
        # Residual connection. NOTE(review): requires conv_block's output to
        # be shape-compatible with `inputs` (same channel count) — confirm
        # for a first stage whose input may have fewer channels.
        # A plain tensor add is used instead of tfkl.add, which would build a
        # new, untracked Add layer on every call.
        x = self.conv_block(inputs, training=training) + inputs
        down_x = self.res_activation(self.conv_stride(x))
        return down_x, x
|
|
55 |
|
|
|
56 |
class Up_ResBlock(tf.keras.Model):
    """Decoder stage: ``Up_Conv`` wrapped with a residual connection.

    ``call`` takes ``(x, x_highway)``, runs ``Up_Conv`` on ``x`` together with
    the skip tensor ``x_highway``, and adds the result back onto ``x``.

    Args:
        num_channels: Forwarded to ``Up_Conv``.
        use_2d: Forwarded to ``Up_Conv``.
        kernel_size: Forwarded to ``Up_Conv``.
        **kwargs: Forwarded to both ``tf.keras.Model.__init__`` and
            ``Up_Conv``.
    """

    def __init__(self,
                 num_channels,
                 use_2d=False,
                 kernel_size=3,
                 **kwargs):
        super(Up_ResBlock, self).__init__(**kwargs)

        self.num_channels = num_channels
        self.use_2d = use_2d
        self.kernel_size = kernel_size
        self.up_conv = Up_Conv(num_channels=self.num_channels,
                               use_2d=self.use_2d,
                               kernel_size=self.kernel_size,
                               **kwargs)

    def call(self, inputs, training=None):
        """Run ``Up_Conv`` on ``(x, x_highway)`` and add the result to ``x``.

        Args:
            inputs: A pair ``(x, x_highway)``; ``x_highway`` is passed to
                ``Up_Conv`` as its second positional argument.
            training: Forwarded to ``up_conv``; ``None`` lets Keras resolve
                the training phase.

        Returns:
            ``x + up_conv(x, x_highway)``.
        """
        x, x_highway = inputs
        x_res_start = self.up_conv(x, x_highway, training=training)
        # NOTE(review): the residual adds the *pre*-Up_Conv tensor `x`. If
        # Up_Conv upsamples or changes the channel count, the two operands
        # will not be shape-compatible — confirm against Up_Conv.
        # A plain tensor add avoids building a new, untracked Add layer on
        # every call (which tfkl.add does).
        return x + x_res_start