# model.py
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.functional import relu

from models.pointnet_utils import PointNetEncoder
from models.pointnet2_utils import PointNetSetAbstraction, PointNetFeaturePropagation

class ECGnet(nn.Module):
    def __init__(self, in_ch=3+4, out_ch=3, num_input=1024, z_dims=16):
        super(ECGnet, self).__init__()

        self.encoder_signal = CRNN()

        # decoder for the ECG signal
        self.elu = nn.ELU(inplace=True)
        self.fc1 = nn.Linear(z_dims, 256*2)
        self.fc2 = nn.Linear(256*2, 512*2)
        self.up = nn.Upsample(size=(8, 512), mode='bilinear', align_corners=False)
        self.deconv = DoubleDeConv(1, 1)

        # classifier head for MI prediction
        self.decoder_MI = nn.Sequential(
            nn.Linear(z_dims, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, out_ch),
        )

    def reparameterize(self, mu, log_var):
        """
        :param mu: mean from the encoder's latent space
        :param log_var: log variance from the encoder's latent space
        """
        std = torch.exp(0.5 * log_var)  # standard deviation
        eps = torch.randn_like(std)     # noise with the same shape and device as std
        sample = mu + (eps * std)       # sample as if drawn from the latent distribution
        return sample

    def decode_signal(self, latent_z):  # P(x|z, c)
        '''
        latent_z: (bs, latent_size)
        '''
        f = self.elu(self.fc1(latent_z))
        f = self.elu(self.fc2(f))
        u = self.up(f.reshape(f.shape[0], 1, 8, -1))
        dc = self.deconv(u)

        return dc

    def forward(self, partial_input, signal_input):
        # partial_input is unused here: the ECG branch alone drives both heads
        mu_signal, std_signal = self.encoder_signal(signal_input)  # CRNN's "std" is consumed as a log-variance
        latent_z_signal = self.reparameterize(mu_signal, std_signal)
        y_ECG = self.decode_signal(latent_z_signal)
        y_MI = self.decoder_MI(latent_z_signal)
        y_MI = F.softmax(y_MI, dim=1)

        return y_MI, y_ECG, mu_signal, std_signal
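
# Usage sketch (illustrative; the input shape is an assumption, not part of the
# original file): an 8-lead ECG of length 512 yields a (B, 16) latent code, a
# (B, 1, 8, 512) reconstruction, and a (B, out_ch) MI softmax score.
#   net = ECGnet()
#   y_MI, y_ECG, mu, std = net(None, torch.rand(2, 8, 512))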

class InferenceNet(nn.Module):
    def __init__(self, in_ch=3+4, out_ch=3, num_input=1024, z_dims=16):
        super(InferenceNet, self).__init__()

        self.z_dims = z_dims

        # PointNet++ encoder
        self.sa1 = PointNetSetAbstraction(npoint=num_input, radius=0.2, nsample=64, in_channel=in_ch, mlp=[64, 64, 128], group_all=False)
        self.sa2 = PointNetSetAbstraction(128, 0.4, 64, 128 + 3, [128, 128, 256], False)
        self.sa3 = PointNetSetAbstraction(16, 0.8, 32, 256 + 3, [256, 512, 1024], False)
        self.fc11 = nn.Linear(1024*16, z_dims*2)

        # PointNet++ decoder
        self.fc12 = nn.Linear(z_dims*2, 1024)  # feat_ECG = H*feat_MI + epsilon
        self.fp3 = PointNetFeaturePropagation(1280, [256, 256])
        self.fp2 = PointNetFeaturePropagation(384, [256, 128])
        self.fp1 = PointNetFeaturePropagation(128, [128, 128, 128])
        self.conv1 = nn.Conv1d(128, 128, 1)
        self.bn1 = nn.BatchNorm1d(128)
        self.drop1 = nn.Dropout(0.5)
        self.conv2 = nn.Conv1d(128, out_ch, 1)

        self.decoder_geometry = BetaVAE_Decoder(num_input, num_input//4, in_ch, z_dims)  # in_ch -> out_ch*3

        self.encoder_signal = CRNN()

        # decoder for the ECG signal
        self.elu = nn.ELU(inplace=True)
        self.fc1 = nn.Linear(z_dims, 256*2)
        self.fc2 = nn.Linear(256*2, 512*2)
        self.up = nn.Upsample(size=(8, 512), mode='bilinear', align_corners=False)
        self.deconv = DoubleDeConv(1, 1)

    def reparameterize(self, mu, log_var):
        """
        :param mu: mean from the encoder's latent space
        :param log_var: log variance from the encoder's latent space
        """
        std = torch.exp(0.5 * log_var)  # standard deviation
        eps = torch.randn_like(std)     # noise with the same shape and device as std
        sample = mu + (eps * std)       # sample as if drawn from the latent distribution
        return sample

    def decode_signal(self, latent_z):  # P(x|z, c)
        '''
        latent_z: (bs, latent_size)
        '''
        f = self.elu(self.fc1(latent_z))
        f = self.elu(self.fc2(f))
        u = self.up(f.reshape(f.shape[0], 1, 8, -1))
        dc = self.deconv(u)

        return dc

    def forward(self, partial_input, signal_input):
        num_points = partial_input.shape[-1]

        # extract ECG features; the signal latent is taken deterministically as the mean
        mu_signal, std_signal = self.encoder_signal(signal_input)
        latent_z_signal = mu_signal  # self.reparameterize(mu_signal, std_signal)

        # extract point-cloud features
        l0_xyz = partial_input[:, :3, :]
        l0_points = partial_input[:, 3:, :]
        l1_xyz, l1_points = self.sa1(l0_xyz, l0_points)
        l2_xyz, l2_points = self.sa2(l1_xyz, l1_points)
        l3_xyz, l3_points = self.sa3(l2_xyz, l2_points)
        features = self.fc11(l3_points.view(-1, 1024*16))
        mu_geometry = features[:, : self.z_dims]
        std_geometry = features[:, self.z_dims:] + 1e-6
        latent_geometry = self.reparameterize(mu_geometry, std_geometry)  # sample the geometry latent from its own statistics
        # latent_geometry = self.fc11(l3_points.view(-1, 1024*16))

        # fuse the two features
        # mu = torch.cat((mu_geometry, mu_signal), dim=1)
        # log_var = torch.cat((std_geometry, std_signal), dim=1)
        # latent_z = self.reparameterize(mu, log_var)
        latent_z = torch.cat((latent_z_signal, latent_geometry), dim=1)

        # segment the point cloud
        anatomy_signal_feat = F.relu(self.fc12(latent_z))
        anatomy_signal_feat = anatomy_signal_feat.view(-1, 1024, 1).repeat(1, 1, num_points)
        l2_points = self.fp3(l2_xyz, l3_xyz, l2_points, anatomy_signal_feat)
        l1_points = self.fp2(l1_xyz, l2_xyz, l1_points, l2_points)
        l0_points = self.fp1(l0_xyz, l1_xyz, None, l1_points)
        y_seg = self.drop1(F.relu(self.bn1(self.conv1(l0_points))))
        y_seg = self.conv2(y_seg)
        y_seg = F.softmax(y_seg, dim=1)

        # reconstruct the point cloud and the ECG
        y_coarse, y_detail = self.decoder_geometry(latent_geometry)
        y_coarse, y_detail = torch.sigmoid(y_coarse), torch.sigmoid(y_detail)
        y_ECG = self.decode_signal(latent_z_signal)

        return y_seg, y_coarse, y_detail, y_ECG, mu_signal, std_signal
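
# Usage sketch (illustrative; shapes follow the defaults above): partial_input
# is a (B, 7, N) point cloud whose first three channels are xyz, signal_input
# is a (B, 8, L) ECG. The two 16-dim latents are concatenated, so fc12 sees
# z_dims*2 = 32 features.
#   net = InferenceNet()
#   y_seg, y_coarse, y_detail, y_ECG, mu, std = net(point_cloud, ecg)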

class CRNN(nn.Module):
    '''
    nh: default=256, size of the LSTM hidden state
    imgH: default=8, the height of the input image to the network
    imgW: default=256, the width of the input image to the network

    :param class_labels: list[n_class]
    :return: (n_batch, n_class)
    '''

    def __init__(self, n_lead=8, z_dims=16):
        super(CRNN, self).__init__()

        n_out = 128
        self.z_dims = z_dims

        self.cnn = nn.Sequential(
            nn.Conv1d(n_lead, n_out, kernel_size=16, stride=2, padding=2),
            nn.BatchNorm1d(n_out),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv1d(n_out, n_out*2, kernel_size=16, stride=2, padding=2),
            nn.BatchNorm1d(n_out*2),
            nn.LeakyReLU(0.2, inplace=True)
        )

        self.rnn = BidirectionalLSTM(256, z_dims*4, z_dims*2)
        # self.rnn = nn.Sequential(
        #     BidirectionalLSTM(512, nh, nh),
        #     BidirectionalLSTM(nh, nh, 1))

    def forward(self, input):
        # conv features: (b, n_lead, length) -> (b, 256, w)
        conv = self.cnn(input)
        b, c, w = conv.size()
        conv = conv.permute(2, 0, 1)  # [w, b, c]

        # rnn features: max-pool over the sequence axis, then split in half
        output = self.rnn(conv).permute(1, 0, 2)  # [b, w, z_dims*2]
        features = torch.max(output, 1)[0]
        mean = features[:, : self.z_dims]
        std = features[:, self.z_dims:] + 1e-6  # treated as a log-variance by reparameterize()

        return mean, std

    def backward_hook(self, module, grad_input, grad_output):
        for g in grad_input:
            if g is not None:
                g[g != g] = 0  # replace all NaNs in the gradients with zero (NaN != NaN)
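
# Shape walkthrough for CRNN (a length-512 input is assumed for illustration):
# (B, 8, 512) -cnn-> (B, 256, 120) -permute-> (120, B, 256)
# -BidirectionalLSTM-> (120, B, 32) -max over time-> (B, 32), split into a
# (B, 16) mean and a (B, 16) second half. Note the NaN-scrubbing hook above is
# only active once registered, e.g. (hypothetical usage):
#   crnn = CRNN()
#   crnn.cnn.register_full_backward_hook(crnn.backward_hook)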

class BidirectionalLSTM(nn.Module):

    def __init__(self, nIn, nHidden, nOut):
        super(BidirectionalLSTM, self).__init__()

        self.rnn = nn.LSTM(nIn, nHidden, bidirectional=True)
        self.embedding = nn.Linear(nHidden * 2, nOut)

    def forward(self, input):
        recurrent, _ = self.rnn(input)  # input: [T, b, nIn] -> recurrent: [T, b, nHidden*2]
        T, b, h = recurrent.size()
        t_rec = recurrent.view(T * b, h)

        output = self.embedding(t_rec)  # [T * b, nOut]
        output = output.view(T, b, -1)

        return output

class PointNet(nn.Module):
    def __init__(self, num_classes=10, n_signal=10, n_param=4, n_ECG=128):
        super(PointNet, self).__init__()
        self.k = num_classes
        self.n_signal = n_signal
        self.feat = PointNetEncoder(global_feat=False, feature_transform=True, channel=4)
        self.conv1 = torch.nn.Conv1d(1024+64+n_ECG, 512, 1)
        self.conv2 = torch.nn.Conv1d(512, 256, 1)
        self.conv3 = torch.nn.Conv1d(256, 128, 1)
        self.conv4 = torch.nn.Conv1d(128, self.k, 1)
        self.bn1 = nn.BatchNorm1d(512)
        self.bn2 = nn.BatchNorm1d(256)
        self.bn3 = nn.BatchNorm1d(128)

        self.ECG_model = CRNN()

        self.inference_model = nn.Sequential(
            nn.Linear(1024+n_ECG, 512),
            nn.Dropout(0.5),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.Dropout(0.5),
            nn.ReLU(),
            nn.Linear(256, self.n_signal*n_param),
            nn.Sigmoid()
        )

    def forward(self, x, signal):
        n_pts = x.size()[2]
        anatomy_signal_feature, global_feature, trans_feat = self.feat(x)

        # CRNN returns (mean, std); concatenate them into one per-sample vector
        # and tile it across all points (its width must match n_ECG)
        mu_ECG, std_ECG = self.ECG_model(signal)
        ECG_feature = torch.cat([mu_ECG, std_ECG], dim=1).unsqueeze(2)
        ECG_feature_extend = ECG_feature.repeat(1, 1, n_pts)

        anatomy_signal_feat = torch.cat([anatomy_signal_feature, ECG_feature_extend], 1)
        y1 = F.relu(self.bn1(self.conv1(anatomy_signal_feat)))
        y1 = F.relu(self.bn2(self.conv2(y1)))
        y1 = F.relu(self.bn3(self.conv3(y1)))
        y1 = self.conv4(y1)
        y1 = y1.transpose(2, 1).contiguous()
        out_ATM = y1  # nn.Sigmoid()(y1)

        return out_ATM
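
# Note (assumption about the intended sizing): CRNN's tiled output is
# 2*z_dims = 32 channels wide, so conv1's 1024+64+n_ECG input only lines up
# when the model is built with n_ECG=32 (or with a CRNN widened to n_ECG=128).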

class PointNet_plusplus(nn.Module):
    def __init__(self, num_classes=10, n_signal=10, n_param=4, n_ECG=128):
        super(PointNet_plusplus, self).__init__()
        self.n_signal = n_signal
        self.sa1 = PointNetSetAbstraction(npoint=512, radius=0.2, nsample=64, in_channel=3 + 4, mlp=[64, 64, 128], group_all=False)
        self.sa2 = PointNetSetAbstraction(128, 0.4, 64, 128 + 3, [128, 128, 256], False)
        self.sa3 = PointNetSetAbstraction(16, 0.8, 32, 256 + 3, [256, 512, 1024], False)
        self.fp3 = PointNetFeaturePropagation(1280+n_ECG, [256, 256])
        self.fp2 = PointNetFeaturePropagation(384, [256, 128])
        self.fp1 = PointNetFeaturePropagation(128, [128, 128, 128])
        self.conv1 = nn.Conv1d(128, 128, 1)
        self.bn1 = nn.BatchNorm1d(128)
        self.drop1 = nn.Dropout(0.5)
        self.conv2 = nn.Conv1d(128, num_classes, 1)

        self.ECG_model = CRNN()
        self.inference_model = nn.Sequential(
            nn.Linear(1024+n_ECG, 512),
            nn.Dropout(0.5),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.Dropout(0.5),
            nn.ReLU(),
            nn.Linear(256, self.n_signal*n_param),
            nn.Sigmoid())

    def forward(self, x, signal):
        l0_points = x
        l0_xyz = x[:, :3, :]

        # CRNN returns (mean, std); fuse them into one per-sample feature vector
        mu_ECG, std_ECG = self.ECG_model(signal)
        ECG_feature = torch.cat([mu_ECG, std_ECG], dim=1).unsqueeze(2)

        # Set Abstraction layers
        l1_xyz, l1_points = self.sa1(l0_xyz, l0_points)
        l2_xyz, l2_points = self.sa2(l1_xyz, l1_points)
        l3_xyz, l3_points = self.sa3(l2_xyz, l2_points)

        ECG_feature_extend = ECG_feature.repeat(1, 1, l3_points.size()[2])
        anatomy_signal_feat = torch.cat([l3_points, ECG_feature_extend], 1)

        # Feature Propagation layers
        l2_points = self.fp3(l2_xyz, l3_xyz, l2_points, anatomy_signal_feat)
        l1_points = self.fp2(l1_xyz, l2_xyz, l1_points, l2_points)
        l0_points = self.fp1(l0_xyz, l1_xyz, None, l1_points)

        y1 = self.drop1(F.relu(self.bn1(self.conv1(l0_points))))
        y1 = self.conv2(y1)
        out_ATM = y1  # nn.Sigmoid()(y1)
        out_ATM = out_ATM.permute(0, 2, 1)

        return out_ATM
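
# Usage sketch (illustrative; the same n_ECG caveat as PointNet applies, since
# CRNN contributes 2*z_dims = 32 channels):
#   net = PointNet_plusplus(num_classes=10, n_ECG=32)
#   out_ATM = net(torch.rand(2, 7, 2048), torch.rand(2, 8, 512))  # (2, 2048, 10)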

class BetaVAE(nn.Module):
    def __init__(self, in_ch=4, num_input=1024, num_class=2, z_dims=16):
        super(BetaVAE, self).__init__()

        self.encoder = BetaVAE_Encoder(in_ch, z_dims)
        self.decoder = BetaVAE_Decoder_new(num_input, num_class)

    def forward(self, x):
        # the encoder returns (mean, std); concatenating them restores the
        # z_dims*2-wide feature vector the decoder expects
        mean, std = self.encoder(x)
        latent_z = torch.cat([mean, std], dim=1)
        y = self.decoder(latent_z)
        return y

class BetaVAE_Encoder(nn.Module):
    def __init__(self, in_ch, z_dims):
        super(BetaVAE_Encoder, self).__init__()
        self.z_dims = z_dims
        self.mlp_conv1 = mlp_conv(in_ch, layer_dims=[128, 256])
        self.mlp_conv2 = mlp_conv(512, layer_dims=[512, 1024])

        self.fc1 = nn.Linear(1024, 1024)
        self.fc2 = nn.Linear(1024, 256)
        self.fc3 = nn.Linear(256, z_dims*2)

    def forward(self, inputs):
        num_points = [inputs.shape[2]]
        features = self.mlp_conv1(inputs)
        features_global = point_maxpool(features, num_points, keepdim=True)
        features_global = point_unpool(features_global, num_points)
        features = torch.cat([features, features_global], dim=1)
        features = self.mlp_conv2(features)
        features = point_maxpool(features, num_points)

        features = features.view(features.size()[0], -1)
        features = self.fc1(features)
        features = self.fc2(features)
        features = self.fc3(features)
        mean = features[:, : self.z_dims]
        std = features[:, self.z_dims:] + 1e-6

        return mean, std

class BetaVAE_Decoder_new(nn.Module):
    def __init__(self, num_input, num_class=2, z_dims=16*2):
        super(BetaVAE_Decoder_new, self).__init__()
        self.out_ch = num_class
        self.n_pts = num_input
        self.mlp = mlp(in_channels=z_dims, layer_dims=[128, 256, 512, 1024, self.n_pts * self.out_ch])

    def forward(self, features):
        y = self.mlp(features).reshape(-1, self.out_ch, self.n_pts)

        return F.softmax(y, dim=1)

class BetaVAE_Decoder_plus(nn.Module):
    def __init__(self, num_dense, num_coarse, out_ch, z_dims):
        super(BetaVAE_Decoder_plus, self).__init__()
        self.out_ch = out_ch
        self.num_coarse = num_coarse
        self.grid_size = int(np.sqrt(num_dense//num_coarse))
        self.num_fine = num_dense

        # PointNet++ decoder
        self.fc12 = nn.Linear(z_dims*2, 1024)
        self.fp3 = PointNetFeaturePropagation(1280, [256, 256])
        self.fp2 = PointNetFeaturePropagation(384, [256, 128])
        self.fp1 = PointNetFeaturePropagation(128, [128, 128, 128])
        self.conv1 = nn.Conv1d(128, 128, 1)
        self.bn1 = nn.BatchNorm1d(128)
        self.drop1 = nn.Dropout(0.5)
        self.conv2 = nn.Conv1d(128, out_ch, 1)

    def forward(self, latent_z, l0_xyz, l1_xyz, l2_xyz, l3_xyz, l1_points, l2_points):
        # l1_points / l2_points are the skip features saved from the
        # set-abstraction stage (cf. InferenceNet.forward)
        anatomy_signal_feat = F.relu(self.fc12(latent_z))
        coarse = anatomy_signal_feat.view(-1, 1024, 1).repeat(1, 1, self.num_coarse)
        l2_points = self.fp3(l2_xyz, l3_xyz, l2_points, coarse)
        l1_points = self.fp2(l1_xyz, l2_xyz, l1_points, l2_points)
        l0_points = self.fp1(l0_xyz, l1_xyz, None, l1_points)
        fine = self.drop1(F.relu(self.bn1(self.conv1(l0_points))))
        fine = self.conv2(fine)

        return coarse, fine
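
# Note: callers of BetaVAE_Decoder_plus must supply the skip features
# (l1_points, l2_points) alongside the abstraction coordinates, mirroring how
# InferenceNet.forward threads them into fp2/fp3.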

class BetaVAE_Decoder(nn.Module):
    def __init__(self, num_dense, num_coarse, out_ch, z_dims):
        super(BetaVAE_Decoder, self).__init__()
        self.out_ch = out_ch
        self.num_coarse = num_coarse
        self.grid_size = int(np.sqrt(num_dense//num_coarse))
        self.num_fine = num_dense

        self.mlp = mlp(in_channels=z_dims, layer_dims=[256, 512, 1024, 2048, self.num_coarse * self.out_ch])
        # a small 2-D grid that is folded onto each coarse point to produce the dense output
        x = torch.linspace(-0.05, 0.05, self.grid_size)
        y = torch.linspace(-0.05, 0.05, self.grid_size)
        self.grid = torch.cat(torch.meshgrid(x, y), dim=0).view(1, 2, self.grid_size ** 2)
        # self.grid = torch.stack(torch.meshgrid(x, y), dim=2)
        # self.grid = torch.reshape(self.grid.transpose(1, 0), [-1, 2]).unsqueeze(0)

        self.mlp_conv3 = mlp_conv(z_dims+2+out_ch, layer_dims=[512, 512, out_ch])  # here "+2" refers to the two axes of the grid

    def forward(self, latent_z):
        features = latent_z
        coarse = self.mlp(features).reshape(-1, self.num_coarse, self.out_ch)

        # replicate each coarse point over its grid_size**2 grid patch, then
        # move channels to dim 1 without scrambling the point order
        point_feat = coarse.unsqueeze(2).repeat(1, 1, self.grid_size ** 2, 1)
        point_feat = point_feat.reshape(-1, self.num_fine, self.out_ch).transpose(1, 2)

        grid_feat = self.grid.unsqueeze(2).repeat(features.shape[0], 1, self.num_coarse, 1).to(features.device)
        grid_feat = grid_feat.reshape(features.shape[0], -1, self.num_fine)
        global_feat = features.unsqueeze(2).repeat(1, 1, self.num_fine)
        feat = torch.cat([grid_feat, point_feat, global_feat], dim=1)

        center = point_feat.transpose(1, 2)  # (B, num_fine, out_ch)
        fine = self.mlp_conv3(feat).transpose(1, 2) + center

        return coarse, fine
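
# Folding sketch (illustrative numbers): with num_dense=1024 and num_coarse=256
# each coarse point owns a 2x2 grid patch, so 256 * 4 = 1024 fine points.
#   dec = BetaVAE_Decoder(num_dense=1024, num_coarse=256, out_ch=3, z_dims=16)
#   coarse, fine = dec(torch.rand(2, 16))  # (2, 256, 3) and (2, 1024, 3)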

def point_maxpool(features, npts, keepdim=True):
    splitted = torch.split(features, npts[0], dim=1)
    outputs = [torch.max(f, dim=2, keepdim=keepdim)[0] for f in splitted]  # modified by Lei on 2022/02/10
    return torch.cat(outputs, dim=0)
    # return torch.max(features, dim=2, keepdims=keepdims)[0]


def point_unpool(features, npts):
    features = torch.split(features, features.shape[0], dim=0)
    outputs = [f.repeat(1, 1, npts[i]) for i, f in enumerate(features)]
    # outputs = [torch.tile(f, [1, 1, npts[i]]) for i, f in enumerate(features)]
    return torch.cat(outputs, dim=0)
    # return features.repeat([1, 1, 256])
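
# Toy example (illustrative): a (B, C, N) feature map is max-pooled over the
# point axis and tiled back to full resolution.
#   feats = torch.rand(2, 16, 100)
#   pooled = point_maxpool(feats, [100], keepdim=True)  # (2, 16, 1)
#   tiled = point_unpool(pooled, [100])                 # (2, 16, 100)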

class mlp_conv(nn.Module):
    def __init__(self, in_channels, layer_dims):
        super(mlp_conv, self).__init__()
        self.layer_dims = layer_dims
        for i, out_channels in enumerate(self.layer_dims):
            layer = torch.nn.Conv1d(in_channels, out_channels, kernel_size=1)
            setattr(self, 'conv_' + str(i), layer)
            in_channels = out_channels

    def forward(self, inputs):
        # defined as forward (rather than overriding __call__) so nn.Module hooks keep working
        outputs = inputs
        dims = len(self.layer_dims)
        for i in range(dims):
            layer = getattr(self, 'conv_' + str(i))
            if i == dims - 1:
                outputs = layer(outputs)  # the last layer stays linear
            else:
                outputs = relu(layer(outputs))
        return outputs

class mlp(nn.Module):
    def __init__(self, in_channels, layer_dims):
        super(mlp, self).__init__()
        self.layer_dims = layer_dims
        for i, out_channels in enumerate(layer_dims):
            layer = torch.nn.Linear(in_channels, out_channels)
            setattr(self, 'fc_' + str(i), layer)
            in_channels = out_channels

    def forward(self, inputs):
        outputs = inputs
        dims = len(self.layer_dims)
        for i in range(dims):
            layer = getattr(self, 'fc_' + str(i))
            if i == dims - 1:
                outputs = layer(outputs)  # the last layer stays linear
            else:
                outputs = relu(layer(outputs))
        return outputs
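
# Both mlp and mlp_conv apply ReLU between layers and leave the final layer
# linear; e.g. (illustrative) mlp(3, [64, 32])(torch.rand(5, 3)) has shape (5, 32).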

class DoubleDeConv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(DoubleDeConv, self).__init__()
        # kernel_size=3 with stride=1 and padding=1 keeps the spatial size unchanged
        self.conv = nn.Sequential(
            nn.ConvTranspose2d(in_ch, out_ch, kernel_size=(3, 3), padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ELU(inplace=True),
            nn.ConvTranspose2d(out_ch, out_ch, kernel_size=(3, 3), padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ELU(inplace=True)
        )

    def forward(self, input):
        return self.conv(input)

class DoubleConv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(DoubleConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=(3, 3), padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ELU(inplace=True),
            nn.Conv2d(out_ch, out_ch, kernel_size=(3, 3), padding=1),
            nn.BatchNorm2d(out_ch),
            # nn.ELU(inplace=True)
        )

    def forward(self, input):
        return self.conv(input)

if __name__ == "__main__":
    x = torch.rand(3, 4, 2048)
    conditions = torch.rand(3, 2, 1)  # placeholder conditioning input; BetaVAE does not consume it

    # BetaVAE takes a single point-cloud input and returns one softmax map
    network = BetaVAE()
    y = network(x)
    print(y.size())  # (3, 2, 1024)
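
    # Additional smoke tests (illustrative shapes, not from the original file):
    # CRNN expects (batch, n_lead=8, length) and ECGnet couples it to the
    # signal decoder and the MI classifier head.
    signal = torch.rand(3, 8, 512)
    mu, std = CRNN()(signal)
    print(mu.size(), std.size())  # torch.Size([3, 16]) torch.Size([3, 16])

    ecg_net = ECGnet()
    y_MI, y_ECG, mu_s, std_s = ecg_net(None, signal)
    print(y_MI.size(), y_ECG.size())  # (3, 3) and (3, 1, 8, 512)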