a b/model.py
1
from tkinter import Y
2
import numpy as np
3
import torch
4
import torch.nn as nn
5
import torch.nn.functional as F
6
from torch.nn.functional import relu
7
8
from models.pointnet_utils import PointNetEncoder
9
from models.pointnet2_utils import PointNetSetAbstraction,PointNetFeaturePropagation
10
11
class ECGnet(nn.Module):
    """ECG-only variational model.

    A ``CRNN`` encoder maps the ECG signal to a latent code. The sampled code
    is used in two ways: it is decoded back to an (8, 512) signal map, and it
    is fed to a small MLP head whose output is softmax-normalized.

    Args:
        in_ch: unused here; kept for signature parity with ``InferenceNet``.
        out_ch: number of outputs of the ``decoder_MI`` head.
        num_input: unused here; kept for signature parity with ``InferenceNet``.
        z_dims: latent size. It is forwarded to the ``CRNN`` encoder so that
            the encoder output and the decoder inputs always agree.
    """

    def __init__(self, in_ch=3+4, out_ch=3, num_input=1024, z_dims=16):
        super(ECGnet, self).__init__()

        # Pass z_dims through. A bare CRNN() always produces 16-dim codes,
        # which breaks fc1 / decoder_MI for any other z_dims.
        self.encoder_signal = CRNN(z_dims=z_dims)

        # Signal decoder: z -> 1024 features -> (1, 8, 128) map
        # -> bilinear upsample to (8, 512) -> two transposed convolutions.
        self.elu = nn.ELU(inplace=True)
        self.fc1 = nn.Linear(z_dims, 256*2)
        self.fc2 = nn.Linear(256*2, 512*2)
        self.up = nn.Upsample(size=(8, 512), mode='bilinear')
        self.deconv = DoubleDeConv(1, 1)

        # Classification head; forward() applies a softmax over dim 1.
        self.decoder_MI = nn.Sequential(
            nn.Linear(z_dims, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, out_ch),
        )

    def reparameterize(self, mu, log_var):
        """Sample ``mu + eps * exp(0.5 * log_var)`` with ``eps ~ N(0, I)``.

        :param mu: mean from the encoder's latent space
        :param log_var: log variance from the encoder's latent space
        :return: a tensor with the same shape as ``mu``
        """
        std = torch.exp(0.5 * log_var)
        # randn_like matches the shape, dtype and device of std in one call.
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode_signal(self, latent_z):  # P(x|z, c)
        '''
        Decode a latent code into a signal map.

        z: (bs, latent_size)
        returns: (bs, 1, 8, 512). The upsample fixes (8, 512), and the 3x3
        stride-1 padding-1 transposed convs keep that size.
        '''
        f = self.elu(self.fc1(latent_z))
        f = self.elu(self.fc2(f))
        # 1024 features -> (bs, 1, 8, 128) before upsampling
        u = self.up(f.reshape(f.shape[0], 1, 8, -1))
        return self.deconv(u)

    def forward(self, partial_input, signal_input):
        """Encode the ECG, then decode the signal and the MI prediction.

        :param partial_input: unused; kept for interface parity with
            ``InferenceNet``.
        :param signal_input: ECG batch for ``CRNN``. Its Conv1d expects 8 input
            channels, so presumably (B, 8, T) — confirm.
        :return: ``(y_MI, y_ECG, mu_signal, std_signal)``. ``y_MI`` is
            softmax-normalized over dim 1. ``std_signal`` is CRNN's second
            output, which ``reparameterize`` treats as a log-variance.
        """
        mu_signal, std_signal = self.encoder_signal(signal_input)
        latent_z_signal = self.reparameterize(mu_signal, std_signal)
        y_ECG = self.decode_signal(latent_z_signal)
        y_MI = F.softmax(self.decoder_MI(latent_z_signal), dim=1)

        return y_MI, y_ECG, mu_signal, std_signal
65
66
class InferenceNet(nn.Module):
    """Joint anatomy (point cloud) + ECG network.

    * ECG branch: a ``CRNN`` encoder. Its mean is used directly as the signal
      code, which is decoded back to an (8, 512) signal map.
    * Geometry branch: three PointNet++ set-abstraction levels. Their output is
      flattened and projected by ``fc11`` to a geometry mean / log-variance.
      A geometry code is sampled from these and decoded by ``BetaVAE_Decoder``.
    * Segmentation: the concatenated [signal | geometry] code is projected to
      1024 features and propagated back to the input points via fp3/fp2/fp1.

    Args:
        in_ch: ``in_channel`` of sa1. It is also used as the output channel
            count of the point decoder.
        out_ch: number of segmentation classes (channels of ``conv2``).
        num_input: centroids sampled by sa1, and dense points of the decoder.
        z_dims: size of each of the signal and geometry codes.
    """

    def __init__(self, in_ch=3+4, out_ch=3, num_input=1024, z_dims=16):
        super(InferenceNet, self).__init__()

        self.z_dims = z_dims

        # PointNet++ Encoder
        self.sa1 = PointNetSetAbstraction(npoint=num_input, radius=0.2, nsample=64, in_channel=in_ch, mlp=[64, 64, 128], group_all=False)
        self.sa2 = PointNetSetAbstraction(128, 0.4, 64, 128 + 3, [128, 128, 256], False)
        self.sa3 = PointNetSetAbstraction(16, 0.8, 32, 256 + 3, [256, 512, 1024], False)
        # flattened sa3 output (1024 channels x 16 centroids) -> [mu | log_var]
        self.fc11 = nn.Linear(1024*16, z_dims*2)

        # PointNet++ Decoder
        self.fc12 = nn.Linear(z_dims*2, 1024)  # fused [signal | geometry] code -> 1024
        self.fp3 = PointNetFeaturePropagation(1280, [256, 256])  # 256 (sa2) + 1024 (fused code)
        self.fp2 = PointNetFeaturePropagation(384, [256, 128])
        self.fp1 = PointNetFeaturePropagation(128, [128, 128, 128])
        self.conv1 = nn.Conv1d(128, 128, 1)
        self.bn1 = nn.BatchNorm1d(128)
        self.drop1 = nn.Dropout(0.5)
        self.conv2 = nn.Conv1d(128, out_ch, 1)

        self.decoder_geometry = BetaVAE_Decoder(num_input, num_input//4, in_ch, z_dims)

        # Pass z_dims through; a bare CRNN() always produces 16-dim codes.
        self.encoder_signal = CRNN(z_dims=z_dims)

        # decoder for the signal
        self.elu = nn.ELU(inplace=True)
        self.fc1 = nn.Linear(z_dims, 256*2)
        self.fc2 = nn.Linear(256*2, 512*2)
        self.up = nn.Upsample(size=(8, 512), mode='bilinear')
        self.deconv = DoubleDeConv(1, 1)

    def reparameterize(self, mu, log_var):
        """Sample ``mu + eps * exp(0.5 * log_var)`` with ``eps ~ N(0, I)``.

        :param mu: mean from the encoder's latent space
        :param log_var: log variance from the encoder's latent space
        """
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)  # same shape, dtype and device as std
        return mu + eps * std

    def decode_signal(self, latent_z):  # P(x|z, c)
        '''
        z: (bs, latent_size)
        returns: (bs, 1, 8, 512)
        '''
        f = self.elu(self.fc1(latent_z))
        f = self.elu(self.fc2(f))
        # 1024 features -> (bs, 1, 8, 128) before upsampling
        u = self.up(f.reshape(f.shape[0], 1, 8, -1))
        return self.deconv(u)

    def forward(self, partial_input, signal_input):
        """Run both encoders, then segment and reconstruct.

        :param partial_input: (B, C, N) point cloud. The first 3 channels are
            xyz and the rest are per-point features.
        :param signal_input: ECG batch for ``CRNN`` (8 leads expected by its
            first Conv1d).
        :return: ``(y_seg, y_coarse, y_detail, y_ECG, mu_signal, std_signal)``.
            ``y_seg`` is softmax-normalized over dim 1. ``y_coarse`` and
            ``y_detail`` are sigmoid-squashed.
        """
        # ECG branch: the mean is used directly as the signal code (no sampling).
        mu_signal, std_signal = self.encoder_signal(signal_input)
        latent_z_signal = mu_signal

        # Geometry branch.
        l0_xyz = partial_input[:, :3, :]
        l0_points = partial_input[:, 3:, :]
        l1_xyz, l1_points = self.sa1(l0_xyz, l0_points)
        l2_xyz, l2_points = self.sa2(l1_xyz, l1_points)
        l3_xyz, l3_points = self.sa3(l2_xyz, l2_points)
        features = self.fc11(l3_points.view(-1, 1024*16))
        mu_geometry = features[:, :self.z_dims]
        std_geometry = features[:, self.z_dims:] + 1e-6
        # Sample from the *geometry* posterior. This previously reused the
        # signal statistics, which left fc11 and mu/std_geometry unused.
        # NOTE(review): the geometry statistics are not returned, so callers
        # cannot add a KL term for them — confirm whether they should be.
        latent_geometry = self.reparameterize(mu_geometry, std_geometry)

        # Fuse the two codes.
        latent_z = torch.cat((latent_z_signal, latent_geometry), dim=1)

        # Segment the point cloud. The fused code is copied onto each sa3
        # centroid, i.e. the point set that fp3 interpolates from.
        anatomy_signal_feat = F.relu(self.fc12(latent_z))
        anatomy_signal_feat = anatomy_signal_feat.view(-1, 1024, 1).repeat(1, 1, l3_points.shape[-1])
        l2_points = self.fp3(l2_xyz, l3_xyz, l2_points, anatomy_signal_feat)
        l1_points = self.fp2(l1_xyz, l2_xyz, l1_points, l2_points)
        l0_points = self.fp1(l0_xyz, l1_xyz, None, l1_points)
        y_seg = self.drop1(F.relu(self.bn1(self.conv1(l0_points))))
        y_seg = F.softmax(self.conv2(y_seg), dim=1)

        # Reconstruct the point cloud and the ECG.
        y_coarse, y_detail = self.decoder_geometry(latent_geometry)
        y_coarse, y_detail = torch.sigmoid(y_coarse), torch.sigmoid(y_detail)
        y_ECG = self.decode_signal(latent_z_signal)

        return y_seg, y_coarse, y_detail, y_ECG, mu_signal, std_signal
161
162
class CRNN(nn.Module):
    '''
    Convolutional-recurrent encoder for multi-lead ECG signals.

    The encoder has three stages:
    * Two strided Conv1d layers turn (B, n_lead, T) into (B, 256, w).
    * A bidirectional LSTM runs over the w axis.
    * A max over time gives one 2*z_dims vector per sample, which is split
      into two halves.

    :param n_lead: number of input channels (ECG leads), default 8
    :param z_dims: size of each returned half, default 16
    :return (forward): ``(mean, std)``, each of shape (B, z_dims).
        NOTE: the callers in this file pass the second value to
        ``reparameterize`` as a log-variance, despite its name.
    '''

    def __init__(self, n_lead=8, z_dims=16):
        super(CRNN, self).__init__()

        n_out = 128
        self.z_dims = z_dims

        self.cnn = nn.Sequential(
            nn.Conv1d(n_lead, n_out, kernel_size=16, stride=2, padding=2),
            nn.BatchNorm1d(n_out),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv1d(n_out, n_out*2, kernel_size=16, stride=2, padding=2),
            nn.BatchNorm1d(n_out*2),
            nn.LeakyReLU(0.2, inplace=True)
            )

        # LSTM input width = channel count of the last conv (n_out*2 = 256)
        self.rnn = BidirectionalLSTM(n_out*2, z_dims*4, z_dims*2)

    def forward(self, input):
        # conv features: (B, n_lead, T) -> (B, 256, w)
        conv = self.cnn(input)
        # nn.LSTM (batch_first=False) expects sequence-first: (w, B, 256)
        conv = conv.permute(2, 0, 1)

        # rnn features: (w, B, 2*z_dims) -> (B, w, 2*z_dims), then max over time
        output = self.rnn(conv).permute(1, 0, 2)
        features = torch.max(output, 1)[0]
        mean = features[:, :self.z_dims]
        std = features[:, self.z_dims:] + 1e-6

        return mean, std

    def backward_hook(self, module, grad_input, grad_output):
        """Zero out NaN / +inf / -inf entries in the incoming gradients.

        Returns a new ``grad_input`` tuple instead of editing the gradients in
        place, because PyTorch hooks must not modify their arguments. ``None``
        entries (inputs that need no gradient) are passed through unchanged.
        """
        return tuple(
            None if g is None else torch.where(torch.isfinite(g), g, torch.zeros_like(g))
            for g in grad_input
        )
212
213
class BidirectionalLSTM(nn.Module):
    """Bidirectional LSTM followed by a per-timestep linear projection.

    Tensors are sequence-first (``nn.LSTM`` default):
    (T, b, nIn) -> (T, b, nOut).
    """

    def __init__(self, nIn, nHidden, nOut):
        super(BidirectionalLSTM, self).__init__()

        self.rnn = nn.LSTM(nIn, nHidden, bidirectional=True)
        # forward and backward hidden states are concatenated -> 2 * nHidden
        self.embedding = nn.Linear(2 * nHidden, nOut)

    def forward(self, input):
        hidden_seq = self.rnn(input)[0]
        steps, batch, width = hidden_seq.size()

        # project every (timestep, sample) pair independently
        flat = hidden_seq.view(steps * batch, width)
        projected = self.embedding(flat)  # [T * b, nOut]

        return projected.view(steps, batch, -1)
230
231
class PointNet(nn.Module):
    """PointNet per-point predictor conditioned on an ECG embedding.

    Every point's PointNet feature is concatenated with the same ECG embedding
    (the mean vector of a ``CRNN`` encoder). The result is mapped by 1x1
    convolutions to ``num_classes`` outputs per point.

    Args:
        num_classes: outputs per point.
        n_signal, n_param: size the ``inference_model`` head, which
            ``forward`` does not use.
        n_ECG: size of the ECG embedding appended to each point.
    """

    def __init__(self, num_classes=10, n_signal=10, n_param=4, n_ECG=128):
        super(PointNet, self).__init__()
        self.k = num_classes
        self.n_signal = n_signal
        self.feat = PointNetEncoder(global_feat=False, feature_transform=True, channel=4)
        # 1024 + 64 presumably = global + local per-point features from
        # PointNetEncoder(global_feat=False) — confirm; + n_ECG ECG channels.
        self.conv1 = torch.nn.Conv1d(1024+64+n_ECG, 512, 1)
        self.conv2 = torch.nn.Conv1d(512, 256, 1)
        self.conv3 = torch.nn.Conv1d(256, 128, 1)
        self.conv4 = torch.nn.Conv1d(128, self.k, 1)
        self.bn1 = nn.BatchNorm1d(512)
        self.bn2 = nn.BatchNorm1d(256)
        self.bn3 = nn.BatchNorm1d(128)

        # The CRNN mean must have n_ECG entries to match conv1's input width.
        # A bare CRNN() yields 16, which conv1 cannot accept.
        self.ECG_model = CRNN(z_dims=n_ECG)

        # NOTE(review): not used by forward(); kept so existing state_dicts load.
        self.inference_model = nn.Sequential(
            nn.Linear(1024+n_ECG, 512),
            nn.Dropout(0.5),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.Dropout(0.5),
            nn.ReLU(),
            nn.Linear(256, self.n_signal*n_param),
            nn.Sigmoid()
            )

    def forward(self, x, signal):
        """
        :param x: (B, C, N) point cloud for PointNetEncoder (built with channel=4).
        :param signal: ECG batch for ``CRNN`` (8 leads expected by its first conv).
        :return: (B, N, num_classes) per-point outputs, with no final activation.
        """
        n_pts = x.size()[2]
        anatomy_signal_feature, global_feature, trans_feat = self.feat(x)
        # CRNN returns a (mean, second-half) pair; the mean is the embedding.
        ECG_feature, _ = self.ECG_model(signal)
        # (B, n_ECG) -> (B, n_ECG, N): the same embedding is attached to every point
        ECG_feature_extend = ECG_feature.unsqueeze(-1).repeat(1, 1, n_pts)

        anatomy_signal_feat = torch.cat([anatomy_signal_feature, ECG_feature_extend], 1)
        y1 = F.relu(self.bn1(self.conv1(anatomy_signal_feat)))
        y1 = F.relu(self.bn2(self.conv2(y1)))
        y1 = F.relu(self.bn3(self.conv3(y1)))
        y1 = self.conv4(y1)
        out_ATM = y1.transpose(2, 1).contiguous()

        return out_ATM
274
275
class PointNet_plusplus(nn.Module):
    """PointNet++ per-point predictor conditioned on an ECG embedding.

    The ECG embedding (the mean vector of a ``CRNN`` encoder) is appended to
    every sa3 centroid feature. It is then propagated back to the input points
    via fp3/fp2/fp1.

    Args:
        num_classes: outputs per point.
        n_signal, n_param: size the ``inference_model`` head, which
            ``forward`` does not use.
        n_ECG: size of the ECG embedding.
    """

    def __init__(self, num_classes=10, n_signal=10, n_param=4, n_ECG=128):
        super(PointNet_plusplus, self).__init__()
        self.n_signal = n_signal
        self.sa1 = PointNetSetAbstraction(npoint=512, radius=0.2, nsample=64, in_channel= 3 + 4, mlp=[64, 64, 128], group_all=False)
        self.sa2 = PointNetSetAbstraction(128, 0.4, 64, 128 + 3, [128, 128, 256], False)
        self.sa3 = PointNetSetAbstraction(16, 0.8, 32, 256 + 3, [256, 512, 1024], False)
        # 1280 = 256 (sa2 features) + 1024 (sa3 features); + n_ECG ECG channels
        self.fp3 = PointNetFeaturePropagation(1280+n_ECG, [256, 256])
        self.fp2 = PointNetFeaturePropagation(384, [256, 128])
        self.fp1 = PointNetFeaturePropagation(128, [128, 128, 128])
        self.conv1 = nn.Conv1d(128, 128, 1)
        self.bn1 = nn.BatchNorm1d(128)
        self.drop1 = nn.Dropout(0.5)
        self.conv2 = nn.Conv1d(128, num_classes, 1)

        # The CRNN mean must have n_ECG entries to match fp3's input width.
        self.ECG_model = CRNN(z_dims=n_ECG)
        # NOTE(review): not used by forward(); kept so existing state_dicts load.
        self.inference_model = nn.Sequential(
            nn.Linear(1024+n_ECG, 512),
            nn.Dropout(0.5),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.Dropout(0.5),
            nn.ReLU(),
            nn.Linear(256, self.n_signal*n_param),
            nn.Sigmoid())

    def forward(self, x, signal):
        """
        :param x: (B, C, N). The first 3 channels are xyz; the full tensor is
            passed to sa1 as point features (sa1 built for in_channel=3+4).
        :param signal: ECG batch for ``CRNN``.
        :return: (B, N, num_classes) per-point outputs, with no final activation.
        """
        l0_points = x
        l0_xyz = x[:, :3, :]

        # CRNN returns a (mean, second-half) pair; the mean is the embedding.
        ECG_feature, _ = self.ECG_model(signal)

        # Set Abstraction layers
        l1_xyz, l1_points = self.sa1(l0_xyz, l0_points)
        l2_xyz, l2_points = self.sa2(l1_xyz, l1_points)
        l3_xyz, l3_points = self.sa3(l2_xyz, l2_points)

        # (B, n_ECG) -> (B, n_ECG, S): the same embedding on every sa3 centroid
        ECG_feature_extend = ECG_feature.unsqueeze(-1).repeat(1, 1, l3_points.size()[2])
        anatomy_signal_feat = torch.cat([l3_points, ECG_feature_extend], 1)

        # Feature Propagation layers
        l2_points = self.fp3(l2_xyz, l3_xyz, l2_points, anatomy_signal_feat)
        l1_points = self.fp2(l1_xyz, l2_xyz, l1_points, l2_points)
        l0_points = self.fp1(l0_xyz, l1_xyz, None, l1_points)

        y1 = self.drop1(F.relu(self.bn1(self.conv1(l0_points))))
        y1 = self.conv2(y1)
        out_ATM = y1.permute(0, 2, 1)

        return out_ATM
326
327
class BetaVAE(nn.Module):
    """Point-cloud encoder/decoder producing per-point class probabilities.

    The encoder returns ``(mean, std)`` halves of size ``z_dims``. They are
    concatenated back into a single ``2 * z_dims`` code, which is the input
    width the decoder is built for.

    NOTE(review): no sampling or KL term happens here, so this behaves as a
    deterministic autoencoder — confirm that is intended.

    Args:
        in_ch: channels of the input point cloud (B, in_ch, N).
        num_input: number of points the decoder outputs.
        num_class: classes per output point.
        z_dims: size of each encoder half.
    """

    def __init__(self, in_ch=4, num_input=1024, num_class=2, z_dims=16):
        super(BetaVAE, self).__init__()

        self.encoder = BetaVAE_Encoder(in_ch, z_dims)
        # The decoder consumes the concatenated (mean, std) code, hence
        # 2 * z_dims (this equals the decoder's default of 32 when z_dims=16).
        self.decoder = BetaVAE_Decoder_new(num_input, num_class, z_dims=z_dims*2)

    def forward(self, x):
        """Return (B, num_class, num_input) softmax probabilities for ``x``."""
        mean, std = self.encoder(x)
        # The encoder returns a tuple; the MLP decoder needs a single tensor.
        latent_z = torch.cat([mean, std], dim=1)
        return self.decoder(latent_z)
338
339
class BetaVAE_Encoder(nn.Module):
    """Point-cloud encoder returning the two halves of a latent vector.

    Processing steps:
    1. Per-point features from ``mlp_conv1`` are max-pooled into a global
       feature.
    2. The global feature is tiled back onto every point and concatenated
       with the local features.
    3. The result passes through ``mlp_conv2`` and is max-pooled again.
    4. Three linear layers map the pooled 1024-d vector to ``2 * z_dims``
       values, which are split into ``(mean, std)``.

    NOTE(review): fc1/fc2/fc3 have no activation between them, so together
    they form a single affine map.
    """

    def __init__(self, in_ch, z_dims):
        super(BetaVAE_Encoder, self).__init__()
        self.z_dims = z_dims
        self.mlp_conv1 = mlp_conv(in_ch, layer_dims=[128, 256])
        # 512 input channels = 256 local + 256 tiled global features
        self.mlp_conv2 = mlp_conv(512, layer_dims=[512, 1024])

        self.fc1 = nn.Linear(1024, 1024)
        self.fc2 = nn.Linear(1024, 256)
        self.fc3 = nn.Linear(256, z_dims*2)

    def forward(self, inputs):
        npts = [inputs.shape[2]]

        local_feat = self.mlp_conv1(inputs)
        tiled_global = point_unpool(point_maxpool(local_feat, npts, keepdim=True), npts)
        fused = self.mlp_conv2(torch.cat([local_feat, tiled_global], dim=1))
        pooled = point_maxpool(fused, npts)

        code = pooled.view(pooled.size()[0], -1)
        for fc in (self.fc1, self.fc2, self.fc3):
            code = fc(code)

        mean, std = code[:, :self.z_dims], code[:, self.z_dims:] + 1e-6
        return mean, std
367
368
class BetaVAE_Decoder_new(nn.Module):
    """Decode a latent vector into per-point class probabilities.

    Output shape is (batch, num_class, num_input), softmax-normalized over the
    class axis (dim 1).
    """

    def __init__(self, num_input, num_class=2, z_dims=16*2):
        super(BetaVAE_Decoder_new, self).__init__()
        self.out_ch = num_class
        self.n_pts = num_input
        widths = [128, 256, 512, 1024, self.n_pts * self.out_ch]
        self.mlp = mlp(in_channels=z_dims, layer_dims=widths)

    def forward(self, features):
        logits = self.mlp(features)
        logits = logits.reshape(-1, self.out_ch, self.n_pts)
        return torch.softmax(logits, dim=1)
379
380
class BetaVAE_Decoder_plus(nn.Module):
    """PointNet++ feature-propagation decoder conditioned on a latent vector.

    ``fc12`` maps the latent vector to a 1024-d feature. That feature is copied
    onto ``num_coarse`` points and propagated through fp3/fp2/fp1 using the
    encoder's xyz levels and skip features.

    NOTE(review): not instantiated anywhere in this file.
    """

    def __init__(self, num_dense, num_coarse, out_ch, z_dims):
        super(BetaVAE_Decoder_plus, self).__init__()
        self.out_ch = out_ch
        self.num_coarse = num_coarse
        self.grid_size = int(np.sqrt(num_dense//num_coarse))
        self.num_fine = num_dense

        # PointNet++ Decoder
        self.fc12 = nn.Linear(z_dims*2, 1024)
        self.fp3 = PointNetFeaturePropagation(1280, [256, 256])
        self.fp2 = PointNetFeaturePropagation(384, [256, 128])
        self.fp1 = PointNetFeaturePropagation(128, [128, 128, 128])
        self.conv1 = nn.Conv1d(128, 128, 1)
        self.bn1 = nn.BatchNorm1d(128)
        self.drop1 = nn.Dropout(0.5)
        self.conv2 = nn.Conv1d(128, out_ch, 1)

    def forward(self, latent_z, l0_xyz, l1_xyz, l2_xyz, l3_xyz,
                l1_points=None, l2_points=None):
        """Decode ``latent_z`` into per-point outputs.

        :param latent_z: (B, 2 * z_dims) code.
        :param l0_xyz..l3_xyz: xyz of the encoder's levels.
        :param l1_points, l2_points: encoder skip features for levels 1 and 2.
            They are required. fp3 is built for 1280 input channels and fp2
            for 384; as in ``InferenceNet``, that presumably means 256-channel
            l2_points and 128-channel l1_points — confirm.
        :return: ``(coarse, fine)``. ``coarse`` is the (B, 1024, num_coarse)
            copied latent feature; ``fine`` is ``conv2``'s output.
        :raises ValueError: if either skip feature is missing. The previous
            code read them before assignment, which raised UnboundLocalError.
        """
        if l1_points is None or l2_points is None:
            raise ValueError(
                "BetaVAE_Decoder_plus.forward requires the encoder skip "
                "features l1_points and l2_points")

        anatomy_signal_feat = F.relu(self.fc12(latent_z))
        coarse = anatomy_signal_feat.view(-1, 1024, 1).repeat(1, 1, self.num_coarse)
        l2_points = self.fp3(l2_xyz, l3_xyz, l2_points, coarse)
        l1_points = self.fp2(l1_xyz, l2_xyz, l1_points, l2_points)
        l0_points = self.fp1(l0_xyz, l1_xyz, None, l1_points)
        fine = self.drop1(F.relu(self.bn1(self.conv1(l0_points))))
        fine = self.conv2(fine)

        return coarse, fine
409
410
class BetaVAE_Decoder(nn.Module):
    """Coarse-to-fine point decoder that folds a small 2-D grid.

    Decoding steps:
    1. ``mlp`` maps the latent vector to ``num_coarse`` points with ``out_ch``
       channels.
    2. Each coarse point is copied ``grid_size ** 2`` times and tagged with an
       offset from a 2-D grid spanning [-0.05, 0.05].
    3. ``mlp_conv3`` predicts a residual, which is added to the copied coarse
       point.

    Returns ``(coarse, fine)`` with shapes (B, num_coarse, out_ch) and
    (B, num_dense, out_ch).
    NOTE(review): this assumes ``num_dense == num_coarse * grid_size ** 2``,
    i.e. ``num_dense / num_coarse`` is a perfect square (e.g. 1024 / 256 = 4).
    """

    def __init__(self, num_dense, num_coarse, out_ch, z_dims):
        super(BetaVAE_Decoder, self).__init__()
        self.out_ch = out_ch
        self.num_coarse = num_coarse
        self.grid_size = int(np.sqrt(num_dense//num_coarse))
        self.num_fine = num_dense

        self.mlp = mlp(in_channels=z_dims, layer_dims=[256, 512, 1024, 2048, self.num_coarse * self.out_ch])
        x = torch.linspace(-0.05, 0.05, self.grid_size)
        y = torch.linspace(-0.05, 0.05, self.grid_size)
        # (1, 2, grid_size**2): row 0 = flattened x-grid, row 1 = flattened y-grid
        self.grid = torch.cat(torch.meshgrid(x, y), dim=0).view(1, 2, self.grid_size ** 2)

        # "+2" refers to the two axes of the grid
        self.mlp_conv3 = mlp_conv(z_dims+2+out_ch, layer_dims=[512, 512, out_ch])

    def forward(self, latent_z):
        batch = latent_z.shape[0]
        reps = self.grid_size ** 2  # fine points generated per coarse point

        coarse = self.mlp(latent_z).reshape(-1, self.num_coarse, self.out_ch)

        # (B, num_coarse, out_ch) -> (B, num_fine, out_ch). Copies of one
        # coarse point stay contiguous, matching the grid layout below.
        # Using reps = grid_size * 2 here was only correct for grid_size == 2.
        center = coarse.unsqueeze(2).repeat(1, 1, reps, 1).reshape(-1, self.num_fine, self.out_ch)
        # Channel-first for the conv. Transpose, not reshape: a reshape would
        # interleave points and channels.
        point_feat = center.transpose(1, 2)

        grid_feat = self.grid.unsqueeze(2).repeat(batch, 1, self.num_coarse, 1).to(latent_z.device)
        grid_feat = grid_feat.reshape(batch, -1, self.num_fine)
        global_feat = latent_z.unsqueeze(2).repeat(1, 1, self.num_fine)
        feat = torch.cat([grid_feat, point_feat, global_feat], dim=1)

        fine = self.mlp_conv3(feat).transpose(1, 2) + center

        return coarse, fine
442
443
def point_maxpool(features, npts, keepdim=True):
    """Max-pool point features cloud by cloud along the point axis.

    :param features: (B, C, N) tensor, channels first and points last (the
        layout produced by Conv1d).
    :param npts: sequence of per-cloud point counts along the point axis; it
        must sum to N. ``[N]`` pools each sample over all of its points.
    :param keepdim: keep the pooled point axis (size 1) when True.
    :return: per-cloud maxima concatenated along dim 0, of shape
        (len(npts) * B, C, 1), or (len(npts) * B, C) without keepdim.
    """
    # Split along the point axis (dim 2). The previous version split dim 1
    # (channels) into chunks of npts[0]. Whenever N < C that silently divided
    # the channel axis and multiplied the batch size.
    clouds = torch.split(features, list(npts), dim=2)
    return torch.cat([torch.max(c, dim=2, keepdim=keepdim)[0] for c in clouds], dim=0)
448
449
def point_unpool(features, npts):
    """Tile pooled features back along the point axis.

    ``features`` is split along dim 0 into chunks of ``features.shape[0]``
    (a single chunk for non-empty input). Chunk ``i`` is repeated ``npts[i]``
    times along dim 2, and the chunks are concatenated along dim 0. For a
    (B, C, 1) input and ``npts=[N]`` the result is (B, C, N).
    """
    chunks = torch.split(features, features.shape[0], dim=0)
    tiled = []
    for idx, chunk in enumerate(chunks):
        tiled.append(chunk.repeat(1, 1, npts[idx]))
    return torch.cat(tiled, dim=0)
455
456
class mlp_conv(nn.Module):
    """Stack of 1x1 ``Conv1d`` layers with ReLU between them (none after the last).

    Layers are registered as ``conv_0``, ``conv_1``, ..., which keeps the
    state_dict keys stable.

    :param in_channels: channels of the input tensor (B, in_channels, N).
    :param layer_dims: output channels of each successive layer.
    """

    def __init__(self, in_channels, layer_dims):
        super(mlp_conv, self).__init__()
        self.layer_dims = layer_dims
        for i, out_channels in enumerate(self.layer_dims):
            layer = torch.nn.Conv1d(in_channels, out_channels, kernel_size=1)
            setattr(self, 'conv_' + str(i), layer)
            in_channels = out_channels

    # Implement forward() instead of overriding __call__, so nn.Module.__call__
    # still runs registered forward/backward hooks.
    def forward(self, inputs):
        outputs = inputs
        last = len(self.layer_dims) - 1
        for i in range(len(self.layer_dims)):
            outputs = getattr(self, 'conv_' + str(i))(outputs)
            if i != last:
                outputs = relu(outputs)
        return outputs
475
476
class mlp(nn.Module):
    """Stack of ``Linear`` layers with ReLU between them (none after the last).

    Layers are registered as ``fc_0``, ``fc_1``, ..., which keeps the
    state_dict keys stable.

    :param in_channels: size of the input's last dimension.
    :param layer_dims: output size of each successive layer.
    """

    def __init__(self, in_channels, layer_dims):
        super(mlp, self).__init__()
        self.layer_dims = layer_dims
        for i, out_channels in enumerate(layer_dims):
            layer = torch.nn.Linear(in_channels, out_channels)
            setattr(self, 'fc_' + str(i), layer)
            in_channels = out_channels

    # Implement forward() instead of overriding __call__, so nn.Module.__call__
    # still runs registered forward/backward hooks.
    def forward(self, inputs):
        outputs = inputs
        last = len(self.layer_dims) - 1
        for i in range(len(self.layer_dims)):
            outputs = getattr(self, 'fc_' + str(i))(outputs)
            if i != last:
                outputs = relu(outputs)
        return outputs
495
496
class DoubleDeConv(nn.Module):
    """Two stages of ConvTranspose2d(3x3) -> BatchNorm2d -> ELU.

    With stride 1 and padding 1, each transposed conv keeps the spatial size:
    (B, in_ch, H, W) -> (B, out_ch, H, W).
    """

    def __init__(self, in_ch, out_ch):
        super(DoubleDeConv, self).__init__()
        stages = []
        for stage_in in (in_ch, out_ch):
            stages += [
                nn.ConvTranspose2d(stage_in, out_ch, kernel_size=(3, 3), padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ELU(inplace=True),
            ]
        self.conv = nn.Sequential(*stages)

    def forward(self, input):
        block = self.conv
        return block(input)
510
511
class DoubleConv(nn.Module):
    """Conv2d(3x3) -> BatchNorm2d -> ELU -> Conv2d(3x3) -> BatchNorm2d.

    There is no activation after the final BatchNorm2d. With stride 1 and
    padding 1 the spatial size is preserved:
    (B, in_ch, H, W) -> (B, out_ch, H, W).
    """

    def __init__(self, in_ch, out_ch):
        super(DoubleConv, self).__init__()
        layers = [
            nn.Conv2d(in_ch, out_ch, kernel_size=(3, 3), padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ELU(inplace=True),
            nn.Conv2d(out_ch, out_ch, kernel_size=(3, 3), padding=1),
            nn.BatchNorm2d(out_ch),
        ]
        self.conv = nn.Sequential(*layers)

    def forward(self, input):
        block = self.conv
        return block(input)
525
526
527
if __name__ == "__main__":
    # Smoke test: a random batch of 3 point clouds with 4 channels and 2048 points.
    x = torch.rand(3, 4, 2048)

    network = BetaVAE()
    # BetaVAE.forward takes a single point-cloud tensor and returns one tensor
    # of per-point class probabilities.
    y = network(x)
    print(y.size())