|
a |
|
b/model.py |
|
|
1 |
import sys |
|
|
2 |
|
|
|
3 |
from torch.distributions.normal import Normal |
|
|
4 |
import torch |
|
|
5 |
import torch.nn as nn |
|
|
6 |
import torch.nn.functional as F |
|
|
7 |
|
|
|
8 |
|
|
|
9 |
class Encoder(nn.Module):
    """Residual block: two (ReLU -> Conv3d -> GroupNorm) stages plus shortcut.

    The activation is applied *before* each convolution and once more after
    the second normalisation.  The block input is then added to the result.
    When ``in_channels`` and ``out_channels`` differ, the shortcut is padded
    (or trimmed) along dim 1 so the addition is shape-compatible.

    Args:
        in_channels: Channel count of the input tensor.
        out_channels: Channel count produced by both convolutions.
        kernel_size: Kernel size passed to both ``nn.Conv3d`` layers.
        stride: Stride passed to both ``nn.Conv3d`` layers.
        padding: Padding passed to both ``nn.Conv3d`` layers.
        bias: Whether the convolutions carry a bias term.
        bn: Accepted for interface compatibility; not read by this class.
        num_groups: Group count for both ``nn.GroupNorm`` layers.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=3,
                 stride=1,
                 padding=1,
                 bias=True,
                 bn=False,
                 num_groups=8):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.relu = nn.ReLU()

        # The two convolutions differ only in their input width.
        shared = dict(stride=stride, padding=padding, bias=bias)
        self.conv1 = nn.Conv3d(in_channels, out_channels, kernel_size,
                               **shared)
        self.gn1 = nn.GroupNorm(num_groups, out_channels)
        self.conv2 = nn.Conv3d(out_channels, out_channels, kernel_size,
                               **shared)
        self.gn2 = nn.GroupNorm(num_groups, out_channels)

    def forward(self, x):
        """Return ``branch(x) + shortcut(x)``.

        Args:
            x: Input tensor; assumed to be the 5-D layout ``nn.Conv3d``
                takes, with channels on dim 1 -- TODO confirm with callers.

        Returns:
            Tensor with ``out_channels`` channels.
        """
        branch = self.relu(self.gn1(self.conv1(self.relu(x))))
        branch = self.relu(self.gn2(self.conv2(branch)))

        shortcut = x
        if self.in_channels != self.out_channels:
            # F.pad lists (before, after) pairs starting from the LAST
            # dimension, so for a 5-D tensor slot 6 is the leading side of
            # dim 1.  A positive difference prepends zero channels; a
            # negative one removes leading channels.
            widths = [0] * (2 * x.dim())
            widths[6] = self.out_channels - self.in_channels
            shortcut = F.pad(x, widths, mode='constant', value=0)
        return branch + shortcut
|
|
52 |
|
|
|
53 |
|
|
|
54 |
class UNet3D(nn.Module):
    """3-D U-Net built from residual ``Encoder`` blocks.

    The contracting path applies ``ec0`` and then three stages of
    ``MaxPool3d(2)`` followed by two ``Encoder`` blocks.  The expanding path
    mirrors it with stride-2 transposed convolutions (``up3``/``up2``/
    ``up1``), concatenating the matching contracting-path feature map on
    dim 1 after every upsampling step.  ``dc0b`` is a kernel-size-1
    transposed convolution with no normalisation or activation that maps to
    ``n_classes`` channels.

    Channel widths are ``inplanes * 2**i``; the first block already outputs
    ``2 * inplanes`` channels (the ``inplanes * 1`` width is never used).

    Args:
        in_channel: Channel count of the network input.
        n_classes: Channel count of the network output.
        use_bias: Passed as ``bias`` to every convolution and transposed
            convolution.
        inplanes: Base channel width.
        num_groups: Group count for every ``nn.GroupNorm`` layer.
    """

    def __init__(self,
                 in_channel,
                 n_classes,
                 use_bias=True,
                 inplanes=32,
                 num_groups=8):
        # nn.Module must be initialised before any attribute is assigned on
        # the subclass; doing it first also matches Encoder.
        super().__init__()
        self.in_channel = in_channel
        self.n_classes = n_classes
        self.inplanes = inplanes
        self.num_groups = num_groups
        planes = [inplanes * 2 ** i for i in range(5)]

        def res_block(c_in, c_out):
            return Encoder(c_in, c_out, bias=use_bias, num_groups=num_groups)

        def upsample(c_in, c_out):
            return self.decoder(c_in, c_out, kernel_size=2, stride=2,
                                bias=use_bias)

        # NOTE: creation order is kept stable on purpose -- it fixes both the
        # order in which random initialisation is drawn and the order of
        # self.modules().
        self.ec0 = res_block(in_channel, planes[1])
        self.ec1 = res_block(planes[1], planes[2])
        self.ec1_2 = res_block(planes[2], planes[2])
        self.ec2 = res_block(planes[2], planes[3])
        self.ec2_2 = res_block(planes[3], planes[3])
        self.ec3 = res_block(planes[3], planes[4])
        self.ec3_2 = res_block(planes[4], planes[4])
        self.maxpool = nn.MaxPool3d(2)
        self.dc3 = res_block(planes[4], planes[4])
        self.dc3_2 = res_block(planes[4], planes[4])
        self.up3 = upsample(planes[4], planes[3])
        # Each dcN block consumes upsampled features concatenated with the
        # skip connection, hence the doubled input width.
        self.dc2 = res_block(planes[4], planes[3])
        self.dc2_2 = res_block(planes[3], planes[3])
        self.up2 = upsample(planes[3], planes[2])
        self.dc1 = res_block(planes[3], planes[2])
        self.dc1_2 = res_block(planes[2], planes[2])
        self.up1 = upsample(planes[2], planes[1])
        self.dc0a = res_block(planes[2], planes[1])
        self.dc0b = self.decoder(planes[1],
                                 n_classes,
                                 kernel_size=1,
                                 stride=1,
                                 bias=use_bias,
                                 relu=False)

        # Re-initialise every (transposed) convolution weight and reset every
        # GroupNorm to the identity affine transform.  Convolution biases
        # keep their default initialisation.
        for m in self.modules():
            if isinstance(m, (nn.Conv3d, nn.ConvTranspose3d)):
                nn.init.kaiming_normal_(m.weight,
                                        mode='fan_out',
                                        nonlinearity='relu')
            elif isinstance(m, nn.GroupNorm):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def decoder(self,
                in_channels,
                out_channels,
                kernel_size,
                stride=1,
                padding=0,
                output_padding=0,
                bias=True,
                relu=True):
        """Build a transposed-convolution stage.

        Args:
            in_channels: Input channels of the ``nn.ConvTranspose3d``.
            out_channels: Output channels of the ``nn.ConvTranspose3d``.
            kernel_size: Kernel size of the transposed convolution.
            stride: Stride of the transposed convolution.
            padding: Padding of the transposed convolution.
            output_padding: Output padding of the transposed convolution.
            bias: Whether the transposed convolution carries a bias term.
            relu: When true, append ``GroupNorm(self.num_groups,
                out_channels)`` and ``ReLU`` after the transposed
                convolution; when false, neither is added.

        Returns:
            An ``nn.Sequential`` holding the layers described above.
        """
        layers = [
            nn.ConvTranspose3d(in_channels,
                               out_channels,
                               kernel_size,
                               stride=stride,
                               padding=padding,
                               output_padding=output_padding,
                               bias=bias),
        ]
        if relu:
            layers.append(nn.GroupNorm(self.num_groups, out_channels))
            layers.append(nn.ReLU())
        return nn.Sequential(*layers)

    @staticmethod
    def _merge_skip(upsampled, skip):
        """Match ``upsampled`` to ``skip``'s spatial size, then concatenate.

        Sizes from dim 2 onwards are compared; on mismatch ``upsampled`` is
        resized by trilinear interpolation.  The result is
        ``cat((upsampled, skip), dim=1)``.
        """
        target = skip.size()[2:]
        if upsampled.size()[2:] != target:
            upsampled = F.interpolate(upsampled,
                                      target,
                                      mode='trilinear',
                                      align_corners=False)
        return torch.cat((upsampled, skip), 1)

    def forward(self, x):
        """Run the network.

        Args:
            x: Input tensor; presumably ``(N, in_channel, D, H, W)`` --
                TODO confirm with callers.

        Returns:
            Tensor with ``n_classes`` channels, produced by ``dc0b`` from
            features aligned to the spatial size of ``ec0``'s output.
        """
        # Contracting path.
        e0 = self.ec0(x)
        e1 = self.ec1_2(self.ec1(self.maxpool(e0)))
        e2 = self.ec2_2(self.ec2(self.maxpool(e1)))
        e3 = self.ec3_2(self.ec3(self.maxpool(e2)))

        # Expanding path: upsample, then fuse with the matching skip tensor.
        # Resizing covers inputs whose sizes are not divisible by 8, where
        # pooling rounds down and a 2x upsampling comes up short.
        d3 = self._merge_skip(self.up3(self.dc3_2(self.dc3(e3))), e2)
        d2 = self._merge_skip(self.up2(self.dc2_2(self.dc2(d3))), e1)
        d1 = self._merge_skip(self.up1(self.dc1_2(self.dc1(d2))), e0)
        return self.dc0b(self.dc0a(d1))