a b/networks/vnet_sdf.py
1
import torch
2
from torch import nn
3
import torch.nn.functional as F
4
5
"""
6
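Differences from the original V-Net:
an nn.Tanh is applied after the first output convolution (``out_conv``), so that head's outputs lie in [-1, 1]; a second output convolution (``out_conv2``) is returned without any activation.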
8
"""
9
10
class ConvBlock(nn.Module):
    """Plain stack of ``n_stages`` (3x3x3 conv -> [norm] -> ReLU) stages.

    The first stage maps ``n_filters_in`` to ``n_filters_out`` channels; every
    later stage keeps ``n_filters_out`` channels. Convolutions use
    ``padding=1`` with a 3x3x3 kernel, so the spatial size is preserved.

    Args:
        n_stages: Number of conv/norm/ReLU stages.
        n_filters_in: Channels of the input tensor.
        n_filters_out: Channels produced by every stage.
        normalization: 'batchnorm', 'groupnorm' (16 groups, so
            ``n_filters_out`` must be divisible by 16), 'instancenorm' or
            'none'.

    Raises:
        ValueError: If ``normalization`` is not one of the supported values.
    """

    def __init__(self, n_stages, n_filters_in, n_filters_out, normalization='none'):
        super(ConvBlock, self).__init__()

        # Explicit exception rather than ``assert False``: asserts are stripped
        # under ``python -O``, which would silently build the block without
        # any normalization layer.
        if normalization not in ('batchnorm', 'groupnorm', 'instancenorm', 'none'):
            raise ValueError(
                "unsupported normalization {!r}; expected 'batchnorm', "
                "'groupnorm', 'instancenorm' or 'none'".format(normalization))

        ops = []
        for i in range(n_stages):
            input_channel = n_filters_in if i == 0 else n_filters_out
            ops.append(nn.Conv3d(input_channel, n_filters_out, 3, padding=1))
            if normalization == 'batchnorm':
                ops.append(nn.BatchNorm3d(n_filters_out))
            elif normalization == 'groupnorm':
                ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out))
            elif normalization == 'instancenorm':
                ops.append(nn.InstanceNorm3d(n_filters_out))
            ops.append(nn.ReLU(inplace=True))

        self.conv = nn.Sequential(*ops)

    def forward(self, x):
        """Apply all stages to ``x`` and return the result."""
        return self.conv(x)
37
38
39
class ResidualConvBlock(nn.Module):
    """Residual stack of ``n_stages`` 3x3x3 conv -> [norm] stages.

    A ReLU follows every stage except the last. The stack output is added to
    the (possibly projected) input and passed through a final ReLU.

    When ``n_filters_in != n_filters_out`` the identity shortcut cannot be
    added to the stack output. A 1x1x1 convolution is then used to project the
    input to ``n_filters_out`` channels. When the channel counts match, no
    projection is created, so such blocks have exactly the same parameters as
    before.

    Args:
        n_stages: Number of conv/norm stages.
        n_filters_in: Channels of the input tensor.
        n_filters_out: Channels produced by every stage.
        normalization: 'batchnorm', 'groupnorm' (16 groups, so
            ``n_filters_out`` must be divisible by 16), 'instancenorm' or
            'none'.

    Raises:
        ValueError: If ``normalization`` is not one of the supported values.
    """

    def __init__(self, n_stages, n_filters_in, n_filters_out, normalization='none'):
        super(ResidualConvBlock, self).__init__()

        # Explicit exception rather than ``assert False`` (stripped under -O).
        if normalization not in ('batchnorm', 'groupnorm', 'instancenorm', 'none'):
            raise ValueError(
                "unsupported normalization {!r}; expected 'batchnorm', "
                "'groupnorm', 'instancenorm' or 'none'".format(normalization))

        ops = []
        for i in range(n_stages):
            input_channel = n_filters_in if i == 0 else n_filters_out
            ops.append(nn.Conv3d(input_channel, n_filters_out, 3, padding=1))
            if normalization == 'batchnorm':
                ops.append(nn.BatchNorm3d(n_filters_out))
            elif normalization == 'groupnorm':
                ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out))
            elif normalization == 'instancenorm':
                ops.append(nn.InstanceNorm3d(n_filters_out))
            # The last stage has no ReLU: it is applied after the residual sum.
            if i != n_stages - 1:
                ops.append(nn.ReLU(inplace=True))

        self.conv = nn.Sequential(*ops)
        # Projection shortcut only when the channel counts differ; previously
        # this case crashed in ``forward`` with a shape mismatch.
        if n_filters_in != n_filters_out:
            self.shortcut = nn.Conv3d(n_filters_in, n_filters_out, 1)
        else:
            self.shortcut = None
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return ``relu(conv(x) + shortcut(x))``, where the shortcut is the identity when channels match."""
        identity = x if self.shortcut is None else self.shortcut(x)
        # The sum is a fresh tensor, so the in-place ReLU does not touch ``x``.
        return self.relu(self.conv(x) + identity)
70
71
72
class DownsamplingConvBlock(nn.Module):
    """Strided conv -> [norm] -> ReLU that shrinks the spatial size.

    The convolution uses ``kernel_size == stride`` and no padding, so each
    spatial dimension ``D`` becomes ``D // stride``.

    Args:
        n_filters_in: Channels of the input tensor.
        n_filters_out: Channels of the output tensor.
        stride: Kernel size and stride of the convolution.
        normalization: 'batchnorm', 'groupnorm' (16 groups, so
            ``n_filters_out`` must be divisible by 16), 'instancenorm' or
            'none'.

    Raises:
        ValueError: If ``normalization`` is not one of the supported values.
    """

    def __init__(self, n_filters_in, n_filters_out, stride=2, normalization='none'):
        super(DownsamplingConvBlock, self).__init__()

        # Explicit exception rather than ``assert False`` (stripped under -O).
        if normalization not in ('batchnorm', 'groupnorm', 'instancenorm', 'none'):
            raise ValueError(
                "unsupported normalization {!r}; expected 'batchnorm', "
                "'groupnorm', 'instancenorm' or 'none'".format(normalization))

        # Layer order (conv, optional norm, relu) matches the previous
        # implementation, so state_dict keys are unchanged.
        ops = [nn.Conv3d(n_filters_in, n_filters_out, stride, padding=0, stride=stride)]
        if normalization == 'batchnorm':
            ops.append(nn.BatchNorm3d(n_filters_out))
        elif normalization == 'groupnorm':
            ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out))
        elif normalization == 'instancenorm':
            ops.append(nn.InstanceNorm3d(n_filters_out))
        ops.append(nn.ReLU(inplace=True))

        self.conv = nn.Sequential(*ops)

    def forward(self, x):
        """Downsample ``x`` and return the result."""
        return self.conv(x)
97
98
99
class UpsamplingDeconvBlock(nn.Module):
    """Transposed conv -> [norm] -> ReLU that enlarges the spatial size.

    The transposed convolution uses ``kernel_size == stride`` and no padding,
    so each spatial dimension ``D`` becomes ``D * stride``.

    Args:
        n_filters_in: Channels of the input tensor.
        n_filters_out: Channels of the output tensor.
        stride: Kernel size and stride of the transposed convolution.
        normalization: 'batchnorm', 'groupnorm' (16 groups, so
            ``n_filters_out`` must be divisible by 16), 'instancenorm' or
            'none'.

    Raises:
        ValueError: If ``normalization`` is not one of the supported values.
    """

    def __init__(self, n_filters_in, n_filters_out, stride=2, normalization='none'):
        super(UpsamplingDeconvBlock, self).__init__()

        # Explicit exception rather than ``assert False`` (stripped under -O).
        if normalization not in ('batchnorm', 'groupnorm', 'instancenorm', 'none'):
            raise ValueError(
                "unsupported normalization {!r}; expected 'batchnorm', "
                "'groupnorm', 'instancenorm' or 'none'".format(normalization))

        # Layer order (deconv, optional norm, relu) matches the previous
        # implementation, so state_dict keys are unchanged.
        ops = [nn.ConvTranspose3d(n_filters_in, n_filters_out, stride, padding=0, stride=stride)]
        if normalization == 'batchnorm':
            ops.append(nn.BatchNorm3d(n_filters_out))
        elif normalization == 'groupnorm':
            ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out))
        elif normalization == 'instancenorm':
            ops.append(nn.InstanceNorm3d(n_filters_out))
        ops.append(nn.ReLU(inplace=True))

        self.conv = nn.Sequential(*ops)

    def forward(self, x):
        """Upsample ``x`` and return the result."""
        return self.conv(x)
124
125
126
class Upsampling(nn.Module):
    """Trilinear upsampling followed by 3x3x3 conv -> [norm] -> ReLU.

    An alternative to ``UpsamplingDeconvBlock``. Upsampling is done by
    interpolation (``scale_factor=stride``, ``align_corners=False``). A
    padded 3x3x3 convolution then changes the channel count without changing
    the spatial size.

    Args:
        n_filters_in: Channels of the input tensor.
        n_filters_out: Channels of the output tensor.
        stride: Spatial scale factor of the interpolation.
        normalization: 'batchnorm', 'groupnorm' (16 groups, so
            ``n_filters_out`` must be divisible by 16), 'instancenorm' or
            'none'.

    Raises:
        ValueError: If ``normalization`` is not one of the supported values.
    """

    def __init__(self, n_filters_in, n_filters_out, stride=2, normalization='none'):
        super(Upsampling, self).__init__()

        # Explicit exception rather than ``assert False`` (stripped under -O).
        if normalization not in ('batchnorm', 'groupnorm', 'instancenorm', 'none'):
            raise ValueError(
                "unsupported normalization {!r}; expected 'batchnorm', "
                "'groupnorm', 'instancenorm' or 'none'".format(normalization))

        ops = [
            nn.Upsample(scale_factor=stride, mode='trilinear', align_corners=False),
            nn.Conv3d(n_filters_in, n_filters_out, kernel_size=3, padding=1),
        ]
        if normalization == 'batchnorm':
            ops.append(nn.BatchNorm3d(n_filters_out))
        elif normalization == 'groupnorm':
            ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out))
        elif normalization == 'instancenorm':
            ops.append(nn.InstanceNorm3d(n_filters_out))
        ops.append(nn.ReLU(inplace=True))

        self.conv = nn.Sequential(*ops)

    def forward(self, x):
        """Upsample ``x`` and return the result."""
        return self.conv(x)
148
149
150
class VNet(nn.Module):
    """V-Net style 3D encoder/decoder with two output heads.

    Four stride-2 downsampling stages encode the input and four stride-2
    transposed-conv stages decode it. At each resolution, encoder features
    are *added* (not concatenated) to the decoder features.

    The decoder ends in two 1x1x1 convolutions applied to the same features:

    * ``out_conv`` followed by ``tanh``, so those outputs lie in [-1, 1];
    * ``out_conv2`` with no activation.

    Args:
        n_channels: Channels of the input volume.
        n_classes: Output channels of each head.
        n_filters: Channels after the first block. The count is doubled at
            every downsampling stage, up to ``16 * n_filters`` at the
            bottleneck.
        normalization: Passed to every block ('batchnorm', 'groupnorm',
            'instancenorm' or 'none').
        has_dropout: If True, ``Dropout3d(p=0.5)`` is applied to the bottleneck
            features and to the last decoder features.
        has_residual: Use ``ResidualConvBlock`` instead of ``ConvBlock``.

    Note:
        The skip connections are element-wise sums and the stride-2 stages
        floor odd sizes. Every spatial dimension of the input should therefore
        be divisible by 16; otherwise the additions in ``decoder`` fail with a
        shape mismatch.
    """

    def __init__(self, n_channels=3, n_classes=2, n_filters=16, normalization='none', has_dropout=False, has_residual=False):
        super(VNet, self).__init__()
        self.has_dropout = has_dropout
        convBlock = ConvBlock if not has_residual else ResidualConvBlock

        # Encoder: conv block at each resolution, then a stride-2 downsample
        # that doubles the channel count.
        self.block_one = convBlock(1, n_channels, n_filters, normalization=normalization)
        self.block_one_dw = DownsamplingConvBlock(n_filters, 2 * n_filters, normalization=normalization)

        self.block_two = convBlock(2, n_filters * 2, n_filters * 2, normalization=normalization)
        self.block_two_dw = DownsamplingConvBlock(n_filters * 2, n_filters * 4, normalization=normalization)

        self.block_three = convBlock(3, n_filters * 4, n_filters * 4, normalization=normalization)
        self.block_three_dw = DownsamplingConvBlock(n_filters * 4, n_filters * 8, normalization=normalization)

        self.block_four = convBlock(3, n_filters * 8, n_filters * 8, normalization=normalization)
        self.block_four_dw = DownsamplingConvBlock(n_filters * 8, n_filters * 16, normalization=normalization)

        # Bottleneck, then decoder: stride-2 transposed conv that halves the
        # channel count, followed by a conv block at the new resolution.
        self.block_five = convBlock(3, n_filters * 16, n_filters * 16, normalization=normalization)
        self.block_five_up = UpsamplingDeconvBlock(n_filters * 16, n_filters * 8, normalization=normalization)

        self.block_six = convBlock(3, n_filters * 8, n_filters * 8, normalization=normalization)
        self.block_six_up = UpsamplingDeconvBlock(n_filters * 8, n_filters * 4, normalization=normalization)

        self.block_seven = convBlock(3, n_filters * 4, n_filters * 4, normalization=normalization)
        self.block_seven_up = UpsamplingDeconvBlock(n_filters * 4, n_filters * 2, normalization=normalization)

        self.block_eight = convBlock(2, n_filters * 2, n_filters * 2, normalization=normalization)
        self.block_eight_up = UpsamplingDeconvBlock(n_filters * 2, n_filters, normalization=normalization)

        self.block_nine = convBlock(1, n_filters, n_filters, normalization=normalization)
        # Two 1x1x1 heads on the same features; only the first gets a tanh.
        self.out_conv = nn.Conv3d(n_filters, n_classes, 1, padding=0)
        self.out_conv2 = nn.Conv3d(n_filters, n_classes, 1, padding=0)
        self.tanh = nn.Tanh()

        self.dropout = nn.Dropout3d(p=0.5, inplace=False)

    def encoder(self, input):
        """Run the encoder and return the features ``[x1, x2, x3, x4, x5]``.

        ``x1``..``x4`` are the outputs of the conv blocks at successively
        halved resolutions, used as skip connections. ``x5`` is the bottleneck
        output; dropout is applied to it when ``self.has_dropout`` is set.
        """
        x1 = self.block_one(input)
        x1_dw = self.block_one_dw(x1)

        x2 = self.block_two(x1_dw)
        x2_dw = self.block_two_dw(x2)

        x3 = self.block_three(x2_dw)
        x3_dw = self.block_three_dw(x3)

        x4 = self.block_four(x3_dw)
        x4_dw = self.block_four_dw(x4)

        x5 = self.block_five(x4_dw)
        if self.has_dropout:
            x5 = self.dropout(x5)

        return [x1, x2, x3, x4, x5]

    def decoder(self, features):
        """Decode the features produced by ``encoder``.

        Args:
            features: The five-element list ``[x1, x2, x3, x4, x5]`` returned
                by ``encoder``.

        Returns:
            ``(out_tanh, out_seg)``: ``tanh(out_conv(x9))`` and the raw
            ``out_conv2(x9)``. Both have ``n_classes`` channels at the
            resolution of ``x1``.
        """
        x1, x2, x3, x4, x5 = features[0], features[1], features[2], features[3], features[4]

        # Each upsampled tensor is summed with the same-resolution encoder
        # feature (additive skip connection).
        x5_up = self.block_five_up(x5)
        x5_up = x5_up + x4

        x6 = self.block_six(x5_up)
        x6_up = self.block_six_up(x6)
        x6_up = x6_up + x3

        x7 = self.block_seven(x6_up)
        x7_up = self.block_seven_up(x7)
        x7_up = x7_up + x2

        x8 = self.block_eight(x7_up)
        x8_up = self.block_eight_up(x8)
        x8_up = x8_up + x1

        x9 = self.block_nine(x8_up)
        if self.has_dropout:
            x9 = self.dropout(x9)
        out = self.out_conv(x9)
        out_tanh = self.tanh(out)
        out_seg = self.out_conv2(x9)
        return out_tanh, out_seg

    def forward(self, input, turnoff_drop=False):
        """Run encoder and decoder and return ``(out_tanh, out_seg)``.

        Args:
            input: Input volume with ``n_channels`` channels.
            turnoff_drop: If True, skip the dropout layers for this call only,
                even when the model was built with ``has_dropout=True``.
        """
        if not turnoff_drop:
            return self.decoder(self.encoder(input))

        saved_has_dropout = self.has_dropout
        self.has_dropout = False
        try:
            features = self.encoder(input)
            return self.decoder(features)
        finally:
            # Always restore the flag, even if the forward pass raises;
            # otherwise dropout would stay disabled for all later calls.
            self.has_dropout = saved_has_dropout
257
258
# if __name__ == '__main__':
259
#     # compute FLOPS & PARAMETERS
260
#     # from thop import profile
261
#     # from thop import clever_format
262
#     # model = VNet(n_channels=1, n_classes=2)
263
#     # input = torch.randn(4, 1, 112, 112, 80)
264
#     # flops, params = profile(model, inputs=(input,))
265
#     # macs, params = clever_format([flops, params], "%.3f")
266
#     # print(macs, params)