a b/Segmentation/model/vnet.py
1
import tensorflow as tf
2
import tensorflow.keras.layers as tfkl
3
import inspect
4
from Segmentation.model.vnet_build_blocks import Conv_ResBlock, Up_ResBlock
5
6
class VNet(tf.keras.Model):
    """V-Net style encoder/decoder segmentation model.

    The encoder is one ``Conv_ResBlock`` per entry of ``num_channels``; each
    block returns ``(x, x_before)`` and ``x_before`` is kept as a skip
    connection. The decoder is one ``Up_ResBlock`` per entry of
    ``num_channels`` in reverse order, each consuming the current features and
    the matching skip (deepest skip first). The head is a ``kernel_size``
    convolution to ``num_classes`` filters, the configured activation, a 1x1
    convolution to ``num_classes`` filters and a final sigmoid (when
    ``num_classes == 1``) or softmax.

    Args:
        num_channels: Sequence of filter counts, one per encoder/decoder level.
        num_classes: Number of output channels of the head.
        use_2d: Use ``Conv2D`` in the head instead of ``Conv3D``; also
            forwarded to every res-block. NOTE(review): the default
            ``kernel_size`` is a 3-tuple, so pass a 2-tuple with ``use_2d=True``.
        num_conv_layers: Forwarded to every res-block.
        kernel_size: Forwarded to every res-block and used by the head conv.
        activation: ``'prelu'`` (a ``PReLU`` layer) or any name accepted by
            ``tf.keras.layers.Activation``; forwarded to every res-block and
            used in the head.
        use_batchnorm: Forwarded to every res-block.
        noise: Stddev of Gaussian noise added to the input when training;
            a falsy value disables it.
        dropout_rate: Forwarded to every res-block.
        use_spatial_dropout: Forwarded to every res-block.
        predict_slice: If True, reduce the head logits over axis -4 (see
            ``slice_format``) and re-insert a size-1 axis at position 1.
        slice_format: ``"mean"`` or ``"sum"``; any other value leaves the
            logits unreduced.
        **kwargs: Passed to ``tf.keras.Model.__init__`` and to every res-block.
    """

    def __init__(self,
                 num_channels,
                 num_classes,
                 use_2d=False,
                 num_conv_layers=2,
                 kernel_size=(3, 3, 3),
                 activation='prelu',
                 use_batchnorm=True,
                 noise=0.0,
                 dropout_rate=0.25,
                 use_spatial_dropout=True,
                 predict_slice=False,
                 slice_format="mean",
                 **kwargs):

        # Snapshot of the constructor arguments as a string. This must stay the
        # first statement so the frame's locals hold only ``self`` and the
        # parameters.
        self.params = str(inspect.currentframe().f_locals)
        super(VNet, self).__init__(**kwargs)
        self.noise = noise
        self.predict_slice = predict_slice
        self.slice_format = slice_format

        # Options shared by every encoder and decoder block.
        block_args = {
            'use_2d': use_2d,
            'num_conv_layers': num_conv_layers,
            'kernel_size': kernel_size,
            'activation': activation,
            'use_batchnorm': use_batchnorm,
            'dropout_rate': dropout_rate,
            'use_spatial_dropout': use_spatial_dropout,
        }

        # Encoder: one block per level, shallowest first.
        self.contracting_path = [
            Conv_ResBlock(output_ch, **block_args, **kwargs)
            for output_ch in num_channels
        ]

        # Decoder: one block per level, deepest first (mirror of the encoder).
        self.upsampling_path = [
            Up_ResBlock(output_ch, **block_args, **kwargs)
            for output_ch in reversed(num_channels)
        ]

        # Head: convolution producing num_classes channels. Both branches use
        # num_classes; the 2D branch previously passed the num_channels
        # sequence as ``filters`` by mistake.
        conv_layer = tfkl.Conv2D if use_2d else tfkl.Conv3D
        self.conv_output = conv_layer(filters=num_classes,
                                      kernel_size=kernel_size,
                                      activation=None,
                                      padding='same')

        if activation == 'prelu':
            # Keras' default alpha initializer is used here;
            # tf.keras.initializers.Constant(value=0.25) is an alternative.
            self.activation = tfkl.PReLU()
        else:
            self.activation = tfkl.Activation(activation)

        self.conv_1x1 = conv_layer(filters=num_classes,
                                   kernel_size=(1, 1) if use_2d else (1, 1, 1),
                                   padding='same')

        self.output_act = tfkl.Activation('sigmoid' if num_classes == 1 else 'softmax')

    def call(self, x, training=None):
        """Run the encoder, decoder and head on ``x``.

        Args:
            x: Input tensor for the first ``Conv_ResBlock``.
            training: Keras training flag. Enables the input noise and is
                forwarded to every res-block.

        Returns:
            The head output after sigmoid (one class) or softmax.
        """
        if self.noise and training:
            # Pass ``training`` explicitly so the noise is applied even when
            # ``call`` is invoked directly (outside a Keras call context).
            x = tfkl.GaussianNoise(self.noise)(x, training=training)

        # Encoder: keep each block's pre-downsampling features as a skip.
        skips = []
        for down in self.contracting_path:
            x, x_before = down(x, training=training)
            skips.append(x_before)

        # Decoder: pair each up-block with the skips in reverse (deepest first).
        for up, skip in zip(self.upsampling_path, reversed(skips)):
            x = up([x, skip], training=training)

        output = self.conv_output(x)
        output = self.activation(output)
        output = self.conv_1x1(output)

        if self.predict_slice:
            # Collapse axis -4 and restore it as a size-1 axis at position 1.
            # For 3D (rank-5) outputs these are the same axis.
            # NOTE(review): for 2D (rank-4) outputs axis -4 is the batch axis;
            # confirm predict_slice is only meant for 3D.
            if self.slice_format == "mean":
                output = tf.expand_dims(tf.reduce_mean(output, -4), 1)
            elif self.slice_format == "sum":
                output = tf.expand_dims(tf.reduce_sum(output, -4), 1)

        return self.output_act(output)