a b/Segmentation/model/vnet_build_blocks.py
1
import tensorflow as tf
2
import tensorflow.keras.layers as tfkl
3
from Segmentation.model.unet_build_blocks import Conv_Block, Up_Conv
4
5
class Conv_ResBlock(tf.keras.Model):
    """Residual down-sampling stage for a V-Net style encoder.

    The input is passed through a ``Conv_Block`` and then added back to the
    result (residual connection). That residual output is fed to a strided
    convolution with ``2 * num_channels`` filters and a 2-wide kernel,
    followed by ``res_activation``.

    Args:
        num_channels: Filter count forwarded to ``Conv_Block``. The strided
            convolution uses twice this many filters.
        use_2d: If True the strided convolution is a ``Conv2D``, otherwise a
            ``Conv3D``. Also forwarded to ``Conv_Block``.
        num_conv_layers: Forwarded to ``Conv_Block``.
        kernel_size: Forwarded to ``Conv_Block``. The strided convolution
            always uses a 2-wide kernel.
        strides: Strides of the down-sampling convolution.
        res_activation: ``'prelu'`` builds a ``tfkl.PReLU`` layer; any other
            value is passed to ``tfkl.Activation``. After construction the
            attribute of the same name holds that layer, not the argument.
        data_format: ``'channels_last'`` or ``'channels_first'``; forwarded to
            both ``Conv_Block`` and the strided convolution.
        **kwargs: Forwarded to both ``tf.keras.Model.__init__`` and
            ``Conv_Block``.
    """

    def __init__(self,
                 num_channels,
                 use_2d=False,
                 num_conv_layers=2,
                 kernel_size=3,
                 strides=2,
                 res_activation='relu',
                 data_format='channels_last',
                 **kwargs):

        super(Conv_ResBlock, self).__init__(**kwargs)

        self.num_channels = num_channels
        self.use_2d = use_2d
        self.num_conv_layers = num_conv_layers
        self.kernel_size = kernel_size
        self.strides = strides
        self.res_activation = res_activation
        self.data_format = data_format

        self.conv_block = Conv_Block(num_channels=self.num_channels,
                                     use_2d=self.use_2d,
                                     num_conv_layers=self.num_conv_layers,
                                     kernel_size=self.kernel_size,
                                     data_format=self.data_format,
                                     **kwargs)

        # data_format must be passed explicitly: otherwise Keras falls back to
        # the global image_data_format, which can disagree with conv_block and
        # make the down-convolution operate on the wrong axis.
        if self.use_2d:
            self.conv_stride = tfkl.Conv2D(num_channels * 2,
                                           kernel_size=(2, 2),
                                           strides=strides,
                                           padding='same',
                                           data_format=self.data_format)

        else:
            self.conv_stride = tfkl.Conv3D(num_channels * 2,
                                           kernel_size=(2, 2, 2),
                                           strides=strides,
                                           padding='same',
                                           data_format=self.data_format)

        # Replaces the string stored above with the actual activation layer.
        if res_activation == 'prelu':
            self.res_activation = tfkl.PReLU()
        else:
            self.res_activation = tfkl.Activation(res_activation)

    def call(self, inputs, training=None):
        """Apply the residual conv block, then down-sample.

        Args:
            inputs: Input tensor.
            training: Forwarded to ``Conv_Block``.

        Returns:
            A tuple ``(down_x, x)``:

            - ``down_x`` is the strided-convolution output after
              ``res_activation``.
            - ``x`` is the residual sum before down-sampling, presumably
              consumed as the decoder skip connection (confirm at the call
              site).
        """
        x = self.conv_block(inputs, training=training)
        # NOTE(review): tfkl.add needs conv_block's output and `inputs` to have
        # matching shapes, so `inputs` presumably already has num_channels
        # channels. Confirm against Conv_Block and the caller.
        x = tfkl.add([x, inputs])
        down_x = self.conv_stride(x)
        down_x = self.res_activation(down_x)
        return down_x, x
55
56
class Up_ResBlock(tf.keras.Model):
    """Residual up-sampling stage for a V-Net style decoder.

    Runs an ``Up_Conv`` block on the incoming tensor together with a highway
    tensor, then adds the incoming tensor back onto the ``Up_Conv`` output.

    Args:
        num_channels: Forwarded to ``Up_Conv``.
        use_2d: Forwarded to ``Up_Conv``.
        kernel_size: Forwarded to ``Up_Conv``.
        **kwargs: Forwarded to both ``tf.keras.Model.__init__`` and
            ``Up_Conv``.
    """

    def __init__(self, num_channels, use_2d=False, kernel_size=3, **kwargs):
        super().__init__(**kwargs)

        # Keep the constructor arguments around as public attributes.
        self.num_channels, self.use_2d, self.kernel_size = (
            num_channels, use_2d, kernel_size)

        self.up_conv = Up_Conv(num_channels=num_channels,
                               use_2d=use_2d,
                               kernel_size=kernel_size,
                               **kwargs)

    def call(self, inputs, training):
        """Apply ``Up_Conv`` and the residual addition.

        Args:
            inputs: A pair ``(x, x_highway)``. ``x`` is the tensor coming up
                the decoder; ``x_highway`` is passed to ``Up_Conv`` as its
                second positional argument.
            training: Forwarded to ``Up_Conv``.

        Returns:
            ``x`` added element-wise to the ``Up_Conv`` output.
        """
        incoming, highway = inputs
        # NOTE(review): the residual adds the *incoming* tensor to Up_Conv's
        # output. That only works if Up_Conv returns a tensor with the same
        # shape as `incoming`; if Up_Conv up-samples, the add will fail.
        # Verify against Up_Conv in unet_build_blocks.
        residual = self.up_conv(incoming, highway, training=training)
        return tfkl.add([incoming, residual])