Diff of /networks/vnet.py [000000] .. [903821]

Switch to unified view

a b/networks/vnet.py
1
import torch
2
from torch import nn
3
4
5
class ConvBlock(nn.Module):
6
    def __init__(self, n_stages, n_filters_in, n_filters_out, normalization='none'):
7
        super(ConvBlock, self).__init__()
8
9
        ops = []
10
        for i in range(n_stages):
11
            if i==0:
12
                input_channel = n_filters_in
13
            else:
14
                input_channel = n_filters_out
15
16
            ops.append(nn.Conv3d(input_channel, n_filters_out, 3, padding=1))
17
            if normalization == 'batchnorm':
18
                ops.append(nn.BatchNorm3d(n_filters_out))
19
            elif normalization == 'groupnorm':
20
                ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out))
21
            elif normalization == 'instancenorm':
22
                ops.append(nn.InstanceNorm3d(n_filters_out))
23
            elif normalization != 'none':
24
                assert False
25
            ops.append(nn.ReLU(inplace=True))
26
27
        self.conv = nn.Sequential(*ops)
28
29
    def forward(self, x):
30
        x = self.conv(x)
31
        return x
32
33
34
class ResidualConvBlock(nn.Module):
35
    def __init__(self, n_stages, n_filters_in, n_filters_out, normalization='none'):
36
        super(ResidualConvBlock, self).__init__()
37
38
        ops = []
39
        for i in range(n_stages):
40
            if i == 0:
41
                input_channel = n_filters_in
42
            else:
43
                input_channel = n_filters_out
44
45
            ops.append(nn.Conv3d(input_channel, n_filters_out, 3, padding=1))
46
            if normalization == 'batchnorm':
47
                ops.append(nn.BatchNorm3d(n_filters_out))
48
            elif normalization == 'groupnorm':
49
                ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out))
50
            elif normalization == 'instancenorm':
51
                ops.append(nn.InstanceNorm3d(n_filters_out))
52
            elif normalization != 'none':
53
                assert False
54
55
            if i != n_stages-1:
56
                ops.append(nn.ReLU(inplace=True))
57
58
        self.conv = nn.Sequential(*ops)
59
        self.relu = nn.ReLU(inplace=True)
60
61
    def forward(self, x):
62
        x = (self.conv(x) + x)
63
        x = self.relu(x)
64
        return x
65
66
67
class DownsamplingConvBlock(nn.Module):
68
    def __init__(self, n_filters_in, n_filters_out, stride=2, normalization='none'):
69
        super(DownsamplingConvBlock, self).__init__()
70
71
        ops = []
72
        if normalization != 'none':
73
            ops.append(nn.Conv3d(n_filters_in, n_filters_out, stride, padding=0, stride=stride))
74
            if normalization == 'batchnorm':
75
                ops.append(nn.BatchNorm3d(n_filters_out))
76
            elif normalization == 'groupnorm':
77
                ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out))
78
            elif normalization == 'instancenorm':
79
                ops.append(nn.InstanceNorm3d(n_filters_out))
80
            else:
81
                assert False
82
        else:
83
            ops.append(nn.Conv3d(n_filters_in, n_filters_out, stride, padding=0, stride=stride))
84
85
        ops.append(nn.ReLU(inplace=True))
86
87
        self.conv = nn.Sequential(*ops)
88
89
    def forward(self, x):
90
        x = self.conv(x)
91
        return x
92
93
94
class UpsamplingDeconvBlock(nn.Module):
95
    def __init__(self, n_filters_in, n_filters_out, stride=2, normalization='none'):
96
        super(UpsamplingDeconvBlock, self).__init__()
97
98
        ops = []
99
        if normalization != 'none':
100
            ops.append(nn.ConvTranspose3d(n_filters_in, n_filters_out, stride, padding=0, stride=stride))
101
            if normalization == 'batchnorm':
102
                ops.append(nn.BatchNorm3d(n_filters_out))
103
            elif normalization == 'groupnorm':
104
                ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out))
105
            elif normalization == 'instancenorm':
106
                ops.append(nn.InstanceNorm3d(n_filters_out))
107
            else:
108
                assert False
109
        else:
110
            ops.append(nn.ConvTranspose3d(n_filters_in, n_filters_out, stride, padding=0, stride=stride))
111
112
        ops.append(nn.ReLU(inplace=True))
113
114
        self.conv = nn.Sequential(*ops)
115
116
    def forward(self, x):
117
        x = self.conv(x)
118
        return x
119
120
121
class Upsampling(nn.Module):
122
    def __init__(self, n_filters_in, n_filters_out, stride=2, normalization='none'):
123
        super(Upsampling, self).__init__()
124
125
        ops = []
126
        ops.append(nn.Upsample(scale_factor=stride, mode='trilinear',align_corners=False))
127
        ops.append(nn.Conv3d(n_filters_in, n_filters_out, kernel_size=3, padding=1))
128
        if normalization == 'batchnorm':
129
            ops.append(nn.BatchNorm3d(n_filters_out))
130
        elif normalization == 'groupnorm':
131
            ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out))
132
        elif normalization == 'instancenorm':
133
            ops.append(nn.InstanceNorm3d(n_filters_out))
134
        elif normalization != 'none':
135
            assert False
136
        ops.append(nn.ReLU(inplace=True))
137
138
        self.conv = nn.Sequential(*ops)
139
140
    def forward(self, x):
141
        x = self.conv(x)
142
        return x
143
144
145
class VNet(nn.Module):
146
    def __init__(self, n_channels=3, n_classes=2, n_filters=16, normalization='none', has_dropout=False):
147
        super(VNet, self).__init__()
148
        self.has_dropout = has_dropout
149
150
        self.block_one = ConvBlock(1, n_channels, n_filters, normalization=normalization)
151
        self.block_one_dw = DownsamplingConvBlock(n_filters, 2 * n_filters, normalization=normalization)
152
153
        self.block_two = ConvBlock(2, n_filters * 2, n_filters * 2, normalization=normalization)
154
        self.block_two_dw = DownsamplingConvBlock(n_filters * 2, n_filters * 4, normalization=normalization)
155
156
        self.block_three = ConvBlock(3, n_filters * 4, n_filters * 4, normalization=normalization)
157
        self.block_three_dw = DownsamplingConvBlock(n_filters * 4, n_filters * 8, normalization=normalization)
158
159
        self.block_four = ConvBlock(3, n_filters * 8, n_filters * 8, normalization=normalization)
160
        self.block_four_dw = DownsamplingConvBlock(n_filters * 8, n_filters * 16, stride=(2,2,1), normalization=normalization)
161
162
        self.block_five = ConvBlock(3, n_filters * 16, n_filters * 16, normalization=normalization)
163
        self.block_five_up = UpsamplingDeconvBlock(n_filters * 16, n_filters * 8, stride=(2,2,1), normalization=normalization)
164
165
        self.block_six = ConvBlock(3, n_filters * 8, n_filters * 8, normalization=normalization)
166
        self.block_six_up = UpsamplingDeconvBlock(n_filters * 8, n_filters * 4, normalization=normalization)
167
168
        self.block_seven = ConvBlock(3, n_filters * 4, n_filters * 4, normalization=normalization)
169
        self.block_seven_up = UpsamplingDeconvBlock(n_filters * 4, n_filters * 2, normalization=normalization)
170
171
        self.block_eight = ConvBlock(2, n_filters * 2, n_filters * 2, normalization=normalization)
172
        self.block_eight_up = UpsamplingDeconvBlock(n_filters * 2, n_filters, normalization=normalization)
173
174
        self.block_nine = ConvBlock(1, n_filters, n_filters, normalization=normalization)
175
        self.out_conv = nn.Conv3d(n_filters, n_classes, 1, padding=0)
176
177
        self.dropout = nn.Dropout3d(p=0.5, inplace=False)
178
        # self.__init_weight()
179
180
    def encoder(self, input):
181
        x1 = self.block_one(input)
182
        x1_dw = self.block_one_dw(x1)
183
184
        x2 = self.block_two(x1_dw)
185
        x2_dw = self.block_two_dw(x2)
186
187
        x3 = self.block_three(x2_dw)
188
        x3_dw = self.block_three_dw(x3)
189
190
        x4 = self.block_four(x3_dw)
191
        x4_dw = self.block_four_dw(x4)
192
193
        x5 = self.block_five(x4_dw)
194
        # x5 = F.dropout3d(x5, p=0.5, training=True)
195
        if self.has_dropout:
196
            x5 = self.dropout(x5)
197
198
        res = [x1, x2, x3, x4, x5]
199
        # print(x5.shape)
200
201
        return res
202
203
    def decoder(self, features):
204
        x1 = features[0]
205
        x2 = features[1]
206
        x3 = features[2]
207
        x4 = features[3]
208
        x5 = features[4]
209
210
        x5_up = self.block_five_up(x5)
211
        # print(x5_up.shape)
212
        x5_up = x5_up + x4
213
214
        x6 = self.block_six(x5_up)
215
        x6_up = self.block_six_up(x6)
216
        x6_up = x6_up + x3
217
218
        x7 = self.block_seven(x6_up)
219
        x7_up = self.block_seven_up(x7)
220
        x7_up = x7_up + x2
221
222
        x8 = self.block_eight(x7_up)
223
        x8_up = self.block_eight_up(x8)
224
        x8_up = x8_up + x1
225
        x9 = self.block_nine(x8_up)
226
        # x9 = F.dropout3d(x9, p=0.5, training=True)
227
        if self.has_dropout:
228
            x9 = self.dropout(x9)
229
        out = self.out_conv(x9)
230
        return out
231
232
    def forward(self, input, turnoff_drop=False):
233
        if turnoff_drop:
234
            has_dropout = self.has_dropout
235
            self.has_dropout = False
236
        features = self.encoder(input)
237
        out = self.decoder(features)
238
        if turnoff_drop:
239
            self.has_dropout = has_dropout
240
        return out
241
242
243
if __name__ == '__main__':
244
    x = torch.randn(2,1,112,112,80)
245
    model = VNet(n_channels=1,n_classes=2, normalization='batchnorm', has_dropout=True)
246
    y = model(x)
247
    print(y.shape)