# networks/vnet_sdf.py
import torch
from torch import nn
import torch.nn.functional as F

"""
Differences from V-Net:
an nn.Tanh is appended after the final conv so the outputs lie in [-1, 1].
"""


class ConvBlock(nn.Module):
    def __init__(self, n_stages, n_filters_in, n_filters_out, normalization='none'):
        super(ConvBlock, self).__init__()

        ops = []
        for i in range(n_stages):
            # The first stage maps n_filters_in -> n_filters_out; later stages keep the width.
            if i == 0:
                input_channel = n_filters_in
            else:
                input_channel = n_filters_out

            ops.append(nn.Conv3d(input_channel, n_filters_out, 3, padding=1))
            if normalization == 'batchnorm':
                ops.append(nn.BatchNorm3d(n_filters_out))
            elif normalization == 'groupnorm':
                ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out))
            elif normalization == 'instancenorm':
                ops.append(nn.InstanceNorm3d(n_filters_out))
            elif normalization != 'none':
                assert False, 'unsupported normalization: {}'.format(normalization)
            ops.append(nn.ReLU(inplace=True))

        self.conv = nn.Sequential(*ops)

    def forward(self, x):
        x = self.conv(x)
        return x
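
# Shape note: every conv above uses kernel 3 with padding 1, so ConvBlock preserves
# the spatial size and only changes the channel count.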


class ResidualConvBlock(nn.Module):
    def __init__(self, n_stages, n_filters_in, n_filters_out, normalization='none'):
        super(ResidualConvBlock, self).__init__()

        ops = []
        for i in range(n_stages):
            if i == 0:
                input_channel = n_filters_in
            else:
                input_channel = n_filters_out

            ops.append(nn.Conv3d(input_channel, n_filters_out, 3, padding=1))
            if normalization == 'batchnorm':
                ops.append(nn.BatchNorm3d(n_filters_out))
            elif normalization == 'groupnorm':
                ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out))
            elif normalization == 'instancenorm':
                ops.append(nn.InstanceNorm3d(n_filters_out))
            elif normalization != 'none':
                assert False, 'unsupported normalization: {}'.format(normalization)

            # No activation after the last stage; ReLU is applied after the residual add.
            if i != n_stages - 1:
                ops.append(nn.ReLU(inplace=True))

        self.conv = nn.Sequential(*ops)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x) + x  # identity shortcut
        x = self.relu(x)
        return x
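
# Note: the identity shortcut has no 1x1 projection, so this block assumes
# n_filters_in == n_filters_out; VNet's block_one (n_channels -> n_filters) does not
# satisfy this when has_residual=True and n_channels != n_filters.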


class DownsamplingConvBlock(nn.Module):
    def __init__(self, n_filters_in, n_filters_out, stride=2, normalization='none'):
        super(DownsamplingConvBlock, self).__init__()

        ops = []
        # kernel_size == stride, so downsampling reads non-overlapping patches.
        ops.append(nn.Conv3d(n_filters_in, n_filters_out, stride, padding=0, stride=stride))
        if normalization == 'batchnorm':
            ops.append(nn.BatchNorm3d(n_filters_out))
        elif normalization == 'groupnorm':
            ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out))
        elif normalization == 'instancenorm':
            ops.append(nn.InstanceNorm3d(n_filters_out))
        elif normalization != 'none':
            assert False, 'unsupported normalization: {}'.format(normalization)

        ops.append(nn.ReLU(inplace=True))

        self.conv = nn.Sequential(*ops)

    def forward(self, x):
        x = self.conv(x)
        return x


class UpsamplingDeconvBlock(nn.Module):
    def __init__(self, n_filters_in, n_filters_out, stride=2, normalization='none'):
        super(UpsamplingDeconvBlock, self).__init__()

        ops = []
        # kernel_size == stride, mirroring DownsamplingConvBlock's spatial reduction.
        ops.append(nn.ConvTranspose3d(n_filters_in, n_filters_out, stride, padding=0, stride=stride))
        if normalization == 'batchnorm':
            ops.append(nn.BatchNorm3d(n_filters_out))
        elif normalization == 'groupnorm':
            ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out))
        elif normalization == 'instancenorm':
            ops.append(nn.InstanceNorm3d(n_filters_out))
        elif normalization != 'none':
            assert False, 'unsupported normalization: {}'.format(normalization)

        ops.append(nn.ReLU(inplace=True))

        self.conv = nn.Sequential(*ops)

    def forward(self, x):
        x = self.conv(x)
        return x
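
# Note: with kernel_size == stride, DownsamplingConvBlock and UpsamplingDeconvBlock are
# exact spatial inverses for even spatial sizes at the same stride, which is what keeps
# the decoder's additive skip connections (x5_up + x4, etc.) shape-compatible.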


class Upsampling(nn.Module):
    def __init__(self, n_filters_in, n_filters_out, stride=2, normalization='none'):
        super(Upsampling, self).__init__()

        ops = []
        ops.append(nn.Upsample(scale_factor=stride, mode='trilinear', align_corners=False))
        ops.append(nn.Conv3d(n_filters_in, n_filters_out, kernel_size=3, padding=1))
        if normalization == 'batchnorm':
            ops.append(nn.BatchNorm3d(n_filters_out))
        elif normalization == 'groupnorm':
            ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out))
        elif normalization == 'instancenorm':
            ops.append(nn.InstanceNorm3d(n_filters_out))
        elif normalization != 'none':
            assert False, 'unsupported normalization: {}'.format(normalization)
        ops.append(nn.ReLU(inplace=True))

        self.conv = nn.Sequential(*ops)

    def forward(self, x):
        x = self.conv(x)
        return x
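
# Note: Upsampling is an interpolation-based alternative to UpsamplingDeconvBlock
# (trilinear resize followed by a 3x3x3 conv); the VNet below does not use it.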


class VNet(nn.Module):
    def __init__(self, n_channels=3, n_classes=2, n_filters=16, normalization='none', has_dropout=False, has_residual=False):
        super(VNet, self).__init__()
        self.has_dropout = has_dropout
        convBlock = ConvBlock if not has_residual else ResidualConvBlock

        self.block_one = convBlock(1, n_channels, n_filters, normalization=normalization)
        self.block_one_dw = DownsamplingConvBlock(n_filters, 2 * n_filters, normalization=normalization)

        self.block_two = convBlock(2, n_filters * 2, n_filters * 2, normalization=normalization)
        self.block_two_dw = DownsamplingConvBlock(n_filters * 2, n_filters * 4, normalization=normalization)

        self.block_three = convBlock(3, n_filters * 4, n_filters * 4, normalization=normalization)
        self.block_three_dw = DownsamplingConvBlock(n_filters * 4, n_filters * 8, normalization=normalization)

        self.block_four = convBlock(3, n_filters * 8, n_filters * 8, normalization=normalization)
        self.block_four_dw = DownsamplingConvBlock(n_filters * 8, n_filters * 16, normalization=normalization)

        self.block_five = convBlock(3, n_filters * 16, n_filters * 16, normalization=normalization)
        self.block_five_up = UpsamplingDeconvBlock(n_filters * 16, n_filters * 8, normalization=normalization)

        self.block_six = convBlock(3, n_filters * 8, n_filters * 8, normalization=normalization)
        self.block_six_up = UpsamplingDeconvBlock(n_filters * 8, n_filters * 4, normalization=normalization)

        self.block_seven = convBlock(3, n_filters * 4, n_filters * 4, normalization=normalization)
        self.block_seven_up = UpsamplingDeconvBlock(n_filters * 4, n_filters * 2, normalization=normalization)

        self.block_eight = convBlock(2, n_filters * 2, n_filters * 2, normalization=normalization)
        self.block_eight_up = UpsamplingDeconvBlock(n_filters * 2, n_filters, normalization=normalization)

        self.block_nine = convBlock(1, n_filters, n_filters, normalization=normalization)
        # Two 1x1x1 heads: out_conv feeds the tanh (outputs in [-1, 1], the SDF head);
        # out_conv2 produces the segmentation logits.
        self.out_conv = nn.Conv3d(n_filters, n_classes, 1, padding=0)
        self.out_conv2 = nn.Conv3d(n_filters, n_classes, 1, padding=0)
        self.tanh = nn.Tanh()

        self.dropout = nn.Dropout3d(p=0.5, inplace=False)
        # self.__init_weight()

    def encoder(self, input):
        x1 = self.block_one(input)
        x1_dw = self.block_one_dw(x1)

        x2 = self.block_two(x1_dw)
        x2_dw = self.block_two_dw(x2)

        x3 = self.block_three(x2_dw)
        x3_dw = self.block_three_dw(x3)

        x4 = self.block_four(x3_dw)
        x4_dw = self.block_four_dw(x4)

        x5 = self.block_five(x4_dw)
        # x5 = F.dropout3d(x5, p=0.5, training=True)
        if self.has_dropout:
            x5 = self.dropout(x5)

        # Skip features x1..x4 plus the bottleneck x5, consumed by the decoder.
        res = [x1, x2, x3, x4, x5]

        return res

    def decoder(self, features):
        x1 = features[0]
        x2 = features[1]
        x3 = features[2]
        x4 = features[3]
        x5 = features[4]

        x5_up = self.block_five_up(x5)
        x5_up = x5_up + x4  # additive skip connection

        x6 = self.block_six(x5_up)
        x6_up = self.block_six_up(x6)
        x6_up = x6_up + x3

        x7 = self.block_seven(x6_up)
        x7_up = self.block_seven_up(x7)
        x7_up = x7_up + x2

        x8 = self.block_eight(x7_up)
        x8_up = self.block_eight_up(x8)
        x8_up = x8_up + x1
        x9 = self.block_nine(x8_up)
        # x9 = F.dropout3d(x9, p=0.5, training=True)
        if self.has_dropout:
            x9 = self.dropout(x9)
        out = self.out_conv(x9)
        out_tanh = self.tanh(out)      # SDF head, values in [-1, 1]
        out_seg = self.out_conv2(x9)   # segmentation logits
        return out_tanh, out_seg

    def forward(self, input, turnoff_drop=False):
        # turnoff_drop temporarily disables dropout (e.g. for evaluation) and
        # restores the previous setting afterwards.
        if turnoff_drop:
            has_dropout = self.has_dropout
            self.has_dropout = False
        features = self.encoder(input)
        out_tanh, out_seg = self.decoder(features)
        if turnoff_drop:
            self.has_dropout = has_dropout
        return out_tanh, out_seg

    # def __init_weight(self):
    #     for m in self.modules():
    #         if isinstance(m, nn.Conv3d):
    #             torch.nn.init.kaiming_normal_(m.weight)
    #         elif isinstance(m, nn.BatchNorm3d):
    #             m.weight.data.fill_(1)


# if __name__ == '__main__':
#     # compute FLOPS & PARAMETERS
#     from thop import profile
#     from thop import clever_format
#     model = VNet(n_channels=1, n_classes=2)
#     input = torch.randn(4, 1, 112, 112, 80)
#     flops, params = profile(model, inputs=(input,))
#     macs, params = clever_format([flops, params], "%.3f")
#     print(macs, params)
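

# Minimal smoke test (a sketch, not part of the original file; the input size and
# hyper-parameters below are illustrative assumptions). It checks that both heads
# return tensors with the input's spatial size and that the tanh head stays in [-1, 1].
if __name__ == '__main__':
    model = VNet(n_channels=1, n_classes=2, normalization='batchnorm', has_dropout=True)
    x = torch.randn(1, 1, 64, 64, 64)  # (N, C, D, H, W); dims must be divisible by 16
    out_tanh, out_seg = model(x, turnoff_drop=True)
    print(out_tanh.shape, out_seg.shape)  # both: torch.Size([1, 2, 64, 64, 64])
    print(out_tanh.min().item(), out_tanh.max().item())  # within [-1, 1]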