Diff of /model.py [000000] .. [390c2f]

--- a
+++ b/model.py
@@ -0,0 +1,533 @@
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn.functional import relu
+
+from models.pointnet_utils import PointNetEncoder
+from models.pointnet2_utils import PointNetSetAbstraction,PointNetFeaturePropagation
+
+class ECGnet(nn.Module):
+    def __init__(self, in_ch=3+4, out_ch=3, num_input=1024, z_dims=16):
+        super(ECGnet, self).__init__()
+
+
+        self.encoder_signal = CRNN()
+
+        # decode for signal
+        self.elu = nn.ELU(inplace=True)
+        self.fc1 = nn.Linear(z_dims, 256*2)
+        self.fc2 = nn.Linear(256*2, 512*2)
+        self.up = nn.Upsample(size=(8, 512), mode='bilinear')
+        self.deconv = DoubleDeConv(1, 1)
+
+        self.decoder_MI = nn.Sequential(
+            nn.Linear(z_dims, 128),
+            nn.ReLU(),
+            nn.Linear(128, 64),
+            nn.ReLU(),
+            nn.Linear(64, out_ch),
+        )
+
+        
+    def reparameterize(self, mu, log_var):
+        """
+        Reparameterization trick: z = mu + exp(0.5 * log_var) * eps, eps ~ N(0, I).
+
+        :param mu: mean from the encoder's latent space
+        :param log_var: log variance from the encoder's latent space
+        """
+        std = torch.exp(0.5 * log_var)  # standard deviation
+        eps = torch.randn_like(std)  # same shape and device as std
+        sample = mu + (eps * std)  # sample from N(mu, std^2)
+        return sample
+
+    def decode_signal(self, latent_z):  # p(x|z)
+        '''
+        latent_z: (bs, latent_size)
+        '''
+        f = self.elu(self.fc1(latent_z))
+        f = self.elu(self.fc2(f))
+        u = self.up(f.reshape(f.shape[0], 1, 8, -1))  # (bs, 1, 8, 128) -> (bs, 1, 8, 512)
+        dc = self.deconv(u)
+
+        return dc
+    
+    def forward(self, partial_input, signal_input):
+        # partial_input is unused here; kept for API parity with InferenceNet
+        mu_signal, std_signal = self.encoder_signal(signal_input)
+        latent_z_signal = self.reparameterize(mu_signal, std_signal)
+        y_ECG = self.decode_signal(latent_z_signal)
+        y_MI = self.decoder_MI(latent_z_signal)
+        y_MI = F.softmax(y_MI, dim=1)
+
+        return y_MI, y_ECG, mu_signal, std_signal
+
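+# Shape sketch for ECGnet (a hedged walkthrough assuming the CRNN defaults
+# below, i.e. 8 leads and z_dims=16; DoubleDeConv preserves the upsampled size):
+#   signal_input (B, 8, L) -> mu_signal, std_signal (B, 16)
+#   y_ECG (B, 1, 8, 512) reconstructed leads; y_MI (B, out_ch) class probabilities
+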
+class InferenceNet(nn.Module):
+    def __init__(self, in_ch=3+4, out_ch=3, num_input=1024, z_dims=16):
+        super(InferenceNet, self).__init__()
+
+        self.z_dims = z_dims
+
+        # PointNet++ Encoder
+        self.sa1 = PointNetSetAbstraction(npoint=num_input, radius=0.2, nsample=64, in_channel=in_ch, mlp=[64, 64, 128], group_all=False)
+        self.sa2 = PointNetSetAbstraction(128, 0.4, 64, 128 + 3, [128, 128, 256], False)
+        self.sa3 = PointNetSetAbstraction(16, 0.8, 32, 256 + 3, [256, 512, 1024], False)
+        self.fc11 = nn.Linear(1024*16, z_dims*2)
+
+        # PointNet++ Decoder
+        self.fc12 = nn.Linear(z_dims*2, 1024) # feat_ECG = H*feat_MI + epsilon
+        self.fp3 = PointNetFeaturePropagation(1280, [256, 256])
+        self.fp2 = PointNetFeaturePropagation(384, [256, 128])
+        self.fp1 = PointNetFeaturePropagation(128, [128, 128, 128])
+        self.conv1 = nn.Conv1d(128, 128, 1)
+        self.bn1 = nn.BatchNorm1d(128)
+        self.drop1 = nn.Dropout(0.5)
+        self.conv2 = nn.Conv1d(128, out_ch, 1) 
+
+        self.decoder_geometry = BetaVAE_Decoder(num_input, num_input//4, in_ch, z_dims) # in_ch -> out_ch*3
+        
+        self.encoder_signal = CRNN()
+
+        # decode for signal
+        self.elu = nn.ELU(inplace=True)
+        self.fc1 = nn.Linear(z_dims, 256*2)
+        self.fc2 = nn.Linear(256*2, 512*2)
+        self.up = nn.Upsample(size=(8, 512), mode='bilinear')
+        self.deconv = DoubleDeConv(1, 1)
+
+    def reparameterize(self, mu, log_var):
+        """
+        Reparameterization trick: z = mu + exp(0.5 * log_var) * eps, eps ~ N(0, I).
+
+        :param mu: mean from the encoder's latent space
+        :param log_var: log variance from the encoder's latent space
+        """
+        std = torch.exp(0.5 * log_var)  # standard deviation
+        eps = torch.randn_like(std)  # same shape and device as std
+        sample = mu + (eps * std)  # sample from N(mu, std^2)
+        return sample
+
+    def decode_signal(self, latent_z):  # p(x|z)
+        '''
+        latent_z: (bs, latent_size)
+        '''
+        f = self.elu(self.fc1(latent_z))
+        f = self.elu(self.fc2(f))
+        u = self.up(f.reshape(f.shape[0], 1, 8, -1))  # (bs, 1, 8, 128) -> (bs, 1, 8, 512)
+        dc = self.deconv(u)
+
+        return dc
+    
+    def forward(self, partial_input, signal_input):  
+        num_points = partial_input.shape[-1]
+        # extract ecg features
+        mu_signal, std_signal = self.encoder_signal(signal_input)
+        latent_z_signal = mu_signal # self.reparameterize(mu_signal, std_signal)
+
+        # extract point cloud features      
+        l0_xyz = partial_input[:,:3,:] 
+        l0_points = partial_input[:,3:,:] 
+        l1_xyz, l1_points = self.sa1(l0_xyz, l0_points)
+        l2_xyz, l2_points = self.sa2(l1_xyz, l1_points)
+        l3_xyz, l3_points = self.sa3(l2_xyz, l2_points)
+        features = self.fc11(l3_points.view(-1, 1024*16))
+        mu_geometry = features[:, : self.z_dims]
+        std_geometry = features[:, self.z_dims:] + 1e-6
+        # NOTE: the original passed (mu_signal, std_signal) here, which ignored the
+        # geometry statistics computed just above; use the geometry branch instead.
+        latent_geometry = self.reparameterize(mu_geometry, std_geometry)
+
+        # fuse two features 
+        # mu = torch.cat((mu_geometry, mu_signal), dim=1)
+        # log_var = torch.cat((std_geometry, std_signal), dim=1)
+        # latent_z = self.reparameterize(mu, log_var)
+        latent_z = torch.cat((latent_z_signal, latent_geometry), dim=1)
+
+        # segment point cloud
+        anatomy_signal_feat = F.relu(self.fc12(latent_z))
+        anatomy_signal_feat = anatomy_signal_feat.view(-1, 1024, 1).repeat(1, 1, num_points)      
+        l2_points = self.fp3(l2_xyz, l3_xyz, l2_points, anatomy_signal_feat)
+        l1_points = self.fp2(l1_xyz, l2_xyz, l1_points, l2_points)
+        l0_points = self.fp1(l0_xyz, l1_xyz, None, l1_points)
+        y_seg = self.drop1(F.relu(self.bn1(self.conv1(l0_points))))
+        y_seg = self.conv2(y_seg)
+        y_seg = F.softmax(y_seg, dim=1)
+
+        # reconstruct point cloud and ecg
+        y_coarse, y_detail = self.decoder_geometry(latent_geometry)
+        y_coarse, y_detail = torch.sigmoid(y_coarse), torch.sigmoid(y_detail)
+        y_ECG = self.decode_signal(latent_z_signal)
+
+        return y_seg, y_coarse, y_detail, y_ECG, mu_signal, std_signal
+
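+# Latent-fusion sketch for InferenceNet: the signal branch yields latent_z_signal
+# (B, z_dims) and the geometry branch latent_geometry (B, z_dims); their
+# concatenation (B, 2*z_dims) feeds fc12 for segmentation, while each branch
+# also drives its own reconstruction head (decode_signal / decoder_geometry).
+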
+class CRNN(nn.Module):
+    '''
+    Convolutional-recurrent encoder for multi-lead ECG signals.
+
+    n_lead: number of input leads (channels)
+    z_dims: latent size; the network outputs a mean and a (log-)variance
+            vector of this size
+
+    :param input: (n_batch, n_lead, signal_length)
+    :return: mean (n_batch, z_dims), std (n_batch, z_dims)
+    '''
+
+    def __init__(self, n_lead=8, z_dims=16):
+        super(CRNN, self).__init__()
+
+        n_out = 128
+        self.z_dims = z_dims
+
+        self.cnn = nn.Sequential(
+            nn.Conv1d(n_lead, n_out, kernel_size=16, stride=2, padding=2),
+            nn.BatchNorm1d(n_out),
+            nn.LeakyReLU(0.2, inplace=True),
+            nn.Conv1d(n_out, n_out*2, kernel_size=16, stride=2, padding=2),
+            nn.BatchNorm1d(n_out*2),
+            nn.LeakyReLU(0.2, inplace=True),
+        )
+
+        self.rnn = BidirectionalLSTM(256, z_dims*4, z_dims*2)  # 256 == n_out*2 from the CNN above
+        # self.rnn = nn.Sequential(
+        #     BidirectionalLSTM(512, nh, nh),
+        #     BidirectionalLSTM(nh, nh, 1))
+
+
+    def forward(self, input):
+        # conv features: (b, n_lead, L) -> (b, 2*n_out, w)
+        conv = self.cnn(input)
+        conv = conv.permute(2, 0, 1)  # [w, b, c], time-major for the LSTM
+
+        # rnn features: (w, b, 2*z_dims) -> (b, w, 2*z_dims)
+        output = self.rnn(conv).permute(1, 0, 2)
+        features = torch.max(output, 1)[0]  # max-pool over time: (b, 2*z_dims)
+        mean = features[:, : self.z_dims]
+        std = features[:, self.z_dims:] + 1e-6
+
+        return mean, std
+
+
+    def backward_hook(self, module, grad_input, grad_output):
+        for g in grad_input:
+            if g is not None:
+                g[g != g] = 0  # g != g is True only for NaN; zero those gradients
+
+class BidirectionalLSTM(nn.Module):
+
+    def __init__(self, nIn, nHidden, nOut):
+        super(BidirectionalLSTM, self).__init__()
+
+        self.rnn = nn.LSTM(nIn, nHidden, bidirectional=True)
+        self.embedding = nn.Linear(nHidden * 2, nOut)
+
+    def forward(self, input):
+        recurrent, _ = self.rnn(input)
+        T, b, h = recurrent.size()
+        t_rec = recurrent.view(T * b, h)
+
+        output = self.embedding(t_rec)  # [T * b, nOut]
+        output = output.view(T, b, -1)
+
+        return output
+
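+# Usage sketch for BidirectionalLSTM (time-major input, as produced by
+# CRNN.forward): (T, B, nIn) -> LSTM -> (T, B, 2*nHidden) -> Linear -> (T, B, nOut)
+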
+class PointNet(nn.Module):
+    def __init__(self, num_classes=10, n_signal=10, n_param=4, n_ECG=128):
+        super(PointNet, self).__init__()
+        self.k = num_classes
+        self.n_signal = n_signal
+        self.feat = PointNetEncoder(global_feat=False, feature_transform=True, channel=4)
+        self.conv1 = torch.nn.Conv1d(1024+64+n_ECG, 512, 1)
+        self.conv2 = torch.nn.Conv1d(512, 256, 1)
+        self.conv3 = torch.nn.Conv1d(256, 128, 1)
+        self.conv4 = torch.nn.Conv1d(128, self.k, 1)
+        self.bn1 = nn.BatchNorm1d(512)
+        self.bn2 = nn.BatchNorm1d(256)
+        self.bn3 = nn.BatchNorm1d(128)
+
+        self.ECG_model = CRNN(z_dims=n_ECG)  # latent size must match the n_ECG channels concatenated below
+
+        self.inference_model = nn.Sequential(
+            nn.Linear(1024+n_ECG, 512),
+            nn.Dropout(0.5),
+            nn.ReLU(),
+            nn.Linear(512, 256),
+            nn.Dropout(0.5),
+            nn.ReLU(),
+            nn.Linear(256, self.n_signal*n_param),
+            nn.Sigmoid()
+            )
+
+
+    def forward(self, x, signal):
+        n_pts = x.size()[2]
+        anatomy_signal_feature, global_feature, trans_feat = self.feat(x)
+        ECG_feature, _ = self.ECG_model(signal)  # CRNN returns (mean, std); use the mean
+        ECG_feature_extend = ECG_feature.unsqueeze(2).repeat(1, 1, n_pts)
+        
+        anatomy_signal_feat = torch.cat([anatomy_signal_feature, ECG_feature_extend], 1)
+        y1 = F.relu(self.bn1(self.conv1(anatomy_signal_feat)))
+        y1 = F.relu(self.bn2(self.conv2(y1)))
+        y1 = F.relu(self.bn3(self.conv3(y1)))
+        y1 = self.conv4(y1)
+        y1 = y1.transpose(2,1).contiguous()
+        out_ATM = y1 #nn.Sigmoid()(y1)
+
+        return out_ATM
+
+class PointNet_plusplus(nn.Module):
+    def __init__(self, num_classes=10, n_signal=10, n_param=4, n_ECG=128):
+        super(PointNet_plusplus, self).__init__()
+        self.n_signal = n_signal
+        self.sa1 = PointNetSetAbstraction(npoint=512, radius=0.2, nsample=64, in_channel= 3 + 4, mlp=[64, 64, 128], group_all=False)
+        self.sa2 = PointNetSetAbstraction(128, 0.4, 64, 128 + 3, [128, 128, 256], False)
+        self.sa3 = PointNetSetAbstraction(16, 0.8, 32, 256 + 3, [256, 512, 1024], False)
+        self.fp3 = PointNetFeaturePropagation(1280+n_ECG, [256, 256])
+        self.fp2 = PointNetFeaturePropagation(384, [256, 128])
+        self.fp1 = PointNetFeaturePropagation(128, [128, 128, 128])
+        self.conv1 = nn.Conv1d(128, 128, 1)
+        self.bn1 = nn.BatchNorm1d(128)
+        self.drop1 = nn.Dropout(0.5)
+        self.conv2 = nn.Conv1d(128, num_classes, 1)
+
+        self.ECG_model = CRNN(z_dims=n_ECG)  # latent size must match the n_ECG channels concatenated below
+        self.inference_model = nn.Sequential(
+            nn.Linear(1024+n_ECG, 512),
+            nn.Dropout(0.5),
+            nn.ReLU(),
+            nn.Linear(512, 256),
+            nn.Dropout(0.5),
+            nn.ReLU(),
+            nn.Linear(256, self.n_signal*n_param),
+            nn.Sigmoid())
+
+    def forward(self, x, signal):
+        l0_points = x
+        l0_xyz = x[:,:3,:] 
+
+        ECG_feature, _ = self.ECG_model(signal)  # CRNN returns (mean, std); use the mean
+ 
+        # Set Abstraction layers
+        l1_xyz, l1_points = self.sa1(l0_xyz, l0_points)
+        l2_xyz, l2_points = self.sa2(l1_xyz, l1_points)
+        l3_xyz, l3_points = self.sa3(l2_xyz, l2_points)
+
+        ECG_feature_extend = ECG_feature.unsqueeze(2).repeat(1, 1, l3_points.size()[2])
+        anatomy_signal_feat = torch.cat([l3_points, ECG_feature_extend], 1)
+
+        # Feature Propagation layers
+        l2_points = self.fp3(l2_xyz, l3_xyz, l2_points, anatomy_signal_feat)
+        l1_points = self.fp2(l1_xyz, l2_xyz, l1_points, l2_points)
+        l0_points = self.fp1(l0_xyz, l1_xyz, None, l1_points)
+
+        y1 = self.drop1(F.relu(self.bn1(self.conv1(l0_points))))
+        y1 = self.conv2(y1)        
+        out_ATM = y1 #nn.Sigmoid()(y1)
+        out_ATM = out_ATM.permute(0, 2, 1)
+
+        return out_ATM
+
+class BetaVAE(nn.Module):
+    def __init__(self, in_ch=4, num_input=1024, num_class=2, z_dims=16):
+        super(BetaVAE, self).__init__()
+
+        self.encoder = BetaVAE_Encoder(in_ch, z_dims)
+        self.decoder = BetaVAE_Decoder_new(num_input, num_class, z_dims*2)
+
+    def forward(self, x):
+        # the encoder returns (mean, std); the decoder expects z_dims*2 inputs,
+        # so feed it the concatenated statistics
+        mean, std = self.encoder(x)
+        latent_z = torch.cat([mean, std], dim=1)
+        y = self.decoder(latent_z)
+        return y
+
+class BetaVAE_Encoder(nn.Module):
+    def __init__(self, in_ch, z_dims):
+        super(BetaVAE_Encoder, self).__init__()
+        self.z_dims = z_dims
+        self.mlp_conv1 = mlp_conv(in_ch, layer_dims=[128, 256])
+        self.mlp_conv2 = mlp_conv(512, layer_dims=[512, 1024])
+
+        self.fc1 = nn.Linear(1024, 1024)
+        self.fc2 = nn.Linear(1024, 256)
+        self.fc3 = nn.Linear(256, z_dims*2) 
+
+    def forward(self, inputs):
+        num_points = [inputs.shape[2]]
+        features = self.mlp_conv1(inputs)
+        features_global = point_maxpool(features, num_points, keepdim=True)
+        features_global = point_unpool(features_global, num_points)
+        features = torch.cat([features, features_global], dim=1)
+        features = self.mlp_conv2(features)
+        features = point_maxpool(features, num_points)
+
+        features = features.view(features.size()[0], -1)
+        features = self.fc1(features)
+        features = self.fc2(features)
+        features = self.fc3(features)
+        mean = features[:, : self.z_dims]
+        std = features[:, self.z_dims:] + 1e-6
+
+        return mean, std
+
+class BetaVAE_Decoder_new(nn.Module):
+    def __init__(self, num_input, num_class=2, z_dims=16*2):
+        super(BetaVAE_Decoder_new, self).__init__()
+        self.out_ch = num_class
+        self.n_pts = num_input
+        self.mlp = mlp(in_channels=z_dims, layer_dims=[128, 256, 512, 1024, self.n_pts * self.out_ch])
+
+    def forward(self, features):
+        y = self.mlp(features).reshape(-1, self.out_ch, self.n_pts)
+        return F.softmax(y, dim=1)
+
+class BetaVAE_Decoder_plus(nn.Module):
+    def __init__(self, num_dense, num_coarse, out_ch, z_dims):
+        super(BetaVAE_Decoder_plus, self).__init__()
+        self.out_ch = out_ch
+        self.num_coarse = num_coarse
+        self.grid_size = int(np.sqrt(num_dense//num_coarse))
+        self.num_fine = num_dense
+
+        # PointNet++ Decoder
+        self.fc12 = nn.Linear(z_dims*2, 1024)
+        self.fp3 = PointNetFeaturePropagation(1280, [256, 256])
+        self.fp2 = PointNetFeaturePropagation(384, [256, 128])
+        self.fp1 = PointNetFeaturePropagation(128, [128, 128, 128])
+        self.conv1 = nn.Conv1d(128, 128, 1)
+        self.bn1 = nn.BatchNorm1d(128)
+        self.drop1 = nn.Dropout(0.5)
+        self.conv2 = nn.Conv1d(128, out_ch, 1)
+
+    def forward(self, latent_z, l0_xyz, l1_xyz, l2_xyz, l3_xyz, l1_points, l2_points):
+        # NOTE: the original signature omitted l1_points/l2_points even though the
+        # feature-propagation layers consume them; they are taken as inputs here,
+        # mirroring the skip connections used in InferenceNet.forward.
+        anatomy_signal_feat = F.relu(self.fc12(latent_z))
+        coarse = anatomy_signal_feat.view(-1, 1024, 1).repeat(1, 1, self.num_coarse)
+        l2_points = self.fp3(l2_xyz, l3_xyz, l2_points, coarse)
+        l1_points = self.fp2(l1_xyz, l2_xyz, l1_points, l2_points)
+        l0_points = self.fp1(l0_xyz, l1_xyz, None, l1_points)
+        fine = self.drop1(F.relu(self.bn1(self.conv1(l0_points))))
+        fine = self.conv2(fine)
+
+        return coarse, fine
+
+class BetaVAE_Decoder(nn.Module):
+    def __init__(self, num_dense, num_coarse, out_ch, z_dims):
+        super(BetaVAE_Decoder, self).__init__()
+        self.out_ch = out_ch
+        self.num_coarse = num_coarse
+        self.grid_size = int(np.sqrt(num_dense//num_coarse))
+        self.num_fine = num_dense
+
+        self.mlp = mlp(in_channels=z_dims, layer_dims=[256, 512, 1024, 2048, self.num_coarse * self.out_ch])
+        x = torch.linspace(-0.05, 0.05, self.grid_size)
+        y = torch.linspace(-0.05, 0.05, self.grid_size)
+        self.grid = torch.cat(torch.meshgrid(x, y), dim=0).view(1, 2, self.grid_size ** 2)
+
+        self.mlp_conv3 = mlp_conv(z_dims+2+out_ch, layer_dims=[512, 512, out_ch])  # "+2" is for the two grid axes
+
+    def forward(self, latent_z):
+        features = latent_z
+        coarse = self.mlp(features).reshape(-1, self.num_coarse, self.out_ch)
+        # tile each coarse point over its grid_size**2 folding grid
+        # (the original used grid_size * 2, which only coincides with grid_size ** 2 when grid_size == 2)
+        point_feat = coarse.unsqueeze(2).repeat(1, 1, self.grid_size ** 2, 1)
+        center = point_feat.reshape(-1, self.num_fine, self.out_ch)  # (B, num_fine, out_ch)
+        point_feat = center.transpose(1, 2)  # (B, out_ch, num_fine)
+
+        grid_feat = self.grid.unsqueeze(2).repeat(features.shape[0], 1, self.num_coarse, 1).to(features.device)
+        grid_feat = grid_feat.reshape(features.shape[0], -1, self.num_fine)
+        global_feat = features.unsqueeze(2).repeat(1, 1, self.num_fine)
+        feat = torch.cat([grid_feat, point_feat, global_feat], dim=1)
+
+        fine = self.mlp_conv3(feat).transpose(1, 2) + center
+
+        return coarse, fine
+
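+# Folding-decoder sketch (PCN-style; the numbers assume the InferenceNet usage
+# num_dense=1024, num_coarse=256 -> grid_size=2, z_dims=16):
+#   latent_z (B, 16) -> coarse (B, 256, out_ch)
+#   feat = cat[grid (B, 2, 1024), point_feat (B, out_ch, 1024), global (B, 16, 1024)]
+#   fine = mlp_conv3(feat).transpose(1, 2) + center -> (B, 1024, out_ch)
+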
+def point_maxpool(features, npts, keepdim=True):
+    # split the point dimension into per-cloud chunks and max-pool each chunk
+    # (the original split along dim=1, the channel dimension, which only works
+    # by accident when a single cloud is passed)
+    splitted = torch.split(features, npts[0], dim=2)
+    outputs = [torch.max(f, dim=2, keepdim=keepdim)[0] for f in splitted]
+    return torch.cat(outputs, dim=0)
+
+def point_unpool(features, npts):
+    # broadcast each pooled feature back over its cloud's npts points
+    features = torch.split(features, features.shape[0], dim=0)
+    outputs = [f.repeat(1, 1, npts[i]) for i, f in enumerate(features)]
+    return torch.cat(outputs, dim=0)
+
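+# Pairing sketch for point_maxpool/point_unpool (single-cloud usage, as in
+# BetaVAE_Encoder, where npts == [N]):
+#   g = point_maxpool(features, npts, keepdim=True)   # (B, C, N) -> (B, C, 1)
+#   point_unpool(g, npts)                             # (B, C, 1) -> (B, C, N)
+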
+class mlp_conv(nn.Module):
+    def __init__(self, in_channels, layer_dims):
+        super(mlp_conv, self).__init__()
+        self.layer_dims = layer_dims
+        for i, out_channels in enumerate(self.layer_dims):
+            layer = torch.nn.Conv1d(in_channels, out_channels, kernel_size=1)
+            setattr(self, 'conv_' + str(i), layer)
+            in_channels = out_channels
+
+    def forward(self, inputs):
+        # ReLU after every layer except the last
+        outputs = inputs
+        dims = len(self.layer_dims)
+        for i in range(dims):
+            layer = getattr(self, 'conv_' + str(i))
+            if i == dims - 1:
+                outputs = layer(outputs)
+            else:
+                outputs = relu(layer(outputs))
+        return outputs
+
+class mlp(nn.Module):
+    def __init__(self, in_channels, layer_dims):
+        super(mlp, self).__init__()
+        self.layer_dims = layer_dims
+        for i, out_channels in enumerate(layer_dims):
+            layer = torch.nn.Linear(in_channels, out_channels)
+            setattr(self, 'fc_' + str(i), layer)
+            in_channels = out_channels
+
+    def forward(self, inputs):
+        # ReLU after every layer except the last
+        outputs = inputs
+        dims = len(self.layer_dims)
+        for i in range(dims):
+            layer = getattr(self, 'fc_' + str(i))
+            if i == dims - 1:
+                outputs = layer(outputs)
+            else:
+                outputs = relu(layer(outputs))
+        return outputs
+
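+# Usage sketch: mlp_conv applies shared 1x1 Conv1d layers pointwise, e.g.
+# mlp_conv(4, [128, 256]) maps (B, 4, N) -> (B, 256, N); mlp is the
+# fully-connected analogue on (B, in_channels) inputs.
+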
+class DoubleDeConv(nn.Module):
+    def __init__(self, in_ch, out_ch):
+        super(DoubleDeConv, self).__init__()
+        self.conv = nn.Sequential(
+            nn.ConvTranspose2d(in_ch, out_ch, kernel_size=(3, 3), padding=1),
+            nn.BatchNorm2d(out_ch),
+            nn.ELU(inplace=True),
+            nn.ConvTranspose2d(out_ch, out_ch, kernel_size=(3, 3), padding=1),
+            nn.BatchNorm2d(out_ch),
+            nn.ELU(inplace=True)
+        )
+
+    def forward(self, input):
+        return self.conv(input)
+
+class DoubleConv(nn.Module):
+    def __init__(self, in_ch, out_ch):
+        super(DoubleConv, self).__init__()
+        self.conv = nn.Sequential(
+            nn.Conv2d(in_ch, out_ch, kernel_size=(3, 3), padding=1),
+            nn.BatchNorm2d(out_ch),
+            nn.ELU(inplace=True),
+            nn.Conv2d(out_ch, out_ch, kernel_size=(3, 3), padding=1),
+            nn.BatchNorm2d(out_ch),
+            # nn.ELU(inplace=True)
+        )
+
+    def forward(self, input):
+        return self.conv(input)
+
+
+if __name__ == "__main__":
+    # smoke test for BetaVAE: 3 clouds, 4 channels (xyz + 1 feature), 2048 points;
+    # the original passed an unused `conditions` tensor and expected two outputs,
+    # but the decoder returns a single softmax map
+    x = torch.rand(3, 4, 2048)
+
+    network = BetaVAE(num_input=2048)
+    y = network(x)
+    print(y.size())  # expected: (3, 2, 2048)
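+
+    # Additional smoke tests (hedged: the 8-lead, 2048-sample signal shape is an
+    # assumption for illustration, matching the CRNN defaults).
+    signal = torch.rand(3, 8, 2048)
+    crnn = CRNN()
+    mu, std = crnn(signal)
+    print(mu.size(), std.size())  # expected: (3, 16) (3, 16)
+
+    ecgnet = ECGnet()
+    y_MI, y_ECG, mu_s, std_s = ecgnet(None, signal)  # partial_input is unused
+    print(y_MI.size(), y_ECG.size())  # expected: (3, 3) (3, 1, 8, 512)
+
+    decoder = BetaVAE_Decoder(num_dense=1024, num_coarse=256, out_ch=3, z_dims=16)
+    coarse, fine = decoder(torch.rand(3, 16))
+    print(coarse.size(), fine.size())  # expected: (3, 256, 3) (3, 1024, 3)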