Diff of /utils.py [000000] .. [390c2f]

Switch to side-by-side view

--- a
+++ b/utils.py
@@ -0,0 +1,366 @@
+import numpy as np
+import pandas as pd
+from matplotlib import pyplot as plt
+import plotly.express as px
+from mpl_toolkits.mplot3d import Axes3D
+import torch
+import torch.nn as nn
+EPS = 1e-4
+
class ProductOfExperts(nn.Module):
    """Fuse M independent Gaussian experts into one product-of-experts Gaussian.

    See https://arxiv.org/pdf/1410.7827.pdf for the underlying equations.

    Args:
        mu (torch.Tensor): means of the expert distributions, M x D for M experts.
        logvar (torch.Tensor): log-variances of the expert distributions, M x D.

    Returns:
        Tuple (fused mean, fused log-variance), each reduced over the expert axis.
    """

    def forward(self, mu, logvar):
        eps = 1e-4  # numerical-stability constant (same value as module-level EPS)
        variance = torch.exp(logvar) + eps
        # Precision (inverse variance) of each expert.
        precision = 1.0 / (variance + eps)
        total_precision = torch.sum(precision, dim=0)
        # Product of Gaussians: precision-weighted mean, summed precisions.
        fused_mu = torch.sum(mu * precision, dim=0) / total_precision
        fused_var = 1.0 / total_precision
        return fused_mu, torch.log(fused_var + eps)
+
class alphaProductOfExperts(nn.Module):
    """Weighted product of independent Gaussian experts (mmJSD implementation).

    See https://arxiv.org/pdf/1410.7827.pdf for the underlying equations.

    Args:
        mu (torch.Tensor): means of the expert distributions, M x D for M experts.
        logvar (torch.Tensor): log-variances of the expert distributions, M x D.
        weights (torch.Tensor, optional): per-expert weights; uniform when None.

    Returns:
        Tuple (fused mean, fused log-variance), reduced over the expert axis.
    """

    def forward(self, mu, logvar, weights=None):
        eps = 1e-4  # numerical-stability constant (same value as module-level EPS)
        if weights is None:
            # Default to a uniform weighting over the M experts.
            weights = torch.ones(mu.shape, device=mu.device) / mu.shape[0]

        variance = torch.exp(logvar) + eps
        precision = 1.0 / (variance + eps)
        # Allow weights supplied per-expert (or any broadcastable shape).
        weights = torch.broadcast_to(weights, mu.shape)
        fused_var = 1.0 / torch.sum(weights * precision + eps, dim=0)
        fused_mu = fused_var * torch.sum(weights * mu * precision, dim=0)
        return fused_mu, torch.log(fused_var + eps)
+    
class weightedProductOfExperts(nn.Module):
    """Weighted product of independent Gaussian experts with per-dimension weights.

    See https://arxiv.org/pdf/1410.7827.pdf for the underlying equations.

    Args:
        mu (torch.Tensor): means of the expert distributions, M x N x D.
        logvar (torch.Tensor): log-variances of the expert distributions, M x N x D.
        weight (torch.Tensor): per-expert, per-dimension weights, M x D.

    Returns:
        Tuple (fused mean, fused log-variance), reduced over the expert axis.
    """

    def forward(self, mu, logvar, weight):
        eps = 1e-4  # numerical-stability constant (same value as module-level EPS)
        variance = torch.exp(logvar) + eps
        # Expand M x D weights across the sample dimension -> M x N x D.
        expanded = weight[:, None, :].repeat(1, mu.shape[1], 1)
        precision = 1.0 / (variance + eps)
        fused_var = 1.0 / torch.sum(expanded * precision + eps, dim=0)
        fused_mu = fused_var * torch.sum(expanded * mu * precision, dim=0)
        return fused_mu, torch.log(fused_var + eps)
+
class MixtureOfExperts(nn.Module):
    """Sample a mixture of independent experts by stratified row selection.

    Implementation from: https://github.com/thomassutter/MoPoE

    Args:
        mus (torch.Tensor): means of the expert distributions, M x N x D.
        logvars (torch.Tensor): log-variances of the expert distributions, M x N x D.

    Returns:
        Tuple (mu_sel, logvar_sel): N x D tensors built by taking a contiguous,
        weight-proportional slice of samples from each expert and concatenating.
    """

    def forward(self, mus, logvars):
        n_experts = mus.shape[0]
        n_samples = mus.shape[1]
        # Uniform mixture weights over the experts.
        weights = torch.full((n_experts,), 1.0 / n_experts).to(mus[0].device)

        # Partition [0, n_samples) into one contiguous slice per expert,
        # each slice holding floor(n_samples * weight) rows; the final expert
        # absorbs any remainder.
        starts = []
        ends = []
        for k in range(n_experts):
            start = 0 if k == 0 else int(ends[k - 1])
            if k == n_experts - 1:
                end = n_samples
            else:
                end = start + int(torch.floor(n_samples * weights[k]))
            starts.append(start)
            ends.append(end)
        ends[-1] = n_samples  # redundant clamp, kept from the reference code

        mu_sel = torch.cat([mus[k, starts[k]:ends[k], :] for k in range(n_experts)])
        logvar_sel = torch.cat([logvars[k, starts[k]:ends[k], :] for k in range(n_experts)])

        return mu_sel, logvar_sel
+
class MeanRepresentation(nn.Module):
    """Average the latent representations of separate VAEs.

    Args:
        mu (torch.Tensor): means of the distributions, M x D for M views.
        logvar (torch.Tensor): log-variances of the distributions, M x D for M views.

    Returns:
        Tuple (mean of means, mean of log-variances), each of shape D.
    """

    def forward(self, mu, logvar):
        # Simple arithmetic mean over the view axis for both statistics.
        return torch.mean(mu, dim=0), torch.mean(logvar, dim=0)
+
+
def visualize_PC_with_twolabel_rotated(nodes_xyz_pre, labels_pre, labels_gd, filename='PC_label.pdf'):
    """Show ground-truth vs. predicted point-cloud labels side by side in 3D,
    keeping both subplots' camera angles synchronised while the user rotates.

    Args:
        nodes_xyz_pre: N x 3 array-like of point coordinates (same cloud in both panels).
        labels_pre: per-point predicted labels; must be keys of color_dict (0, 1 or 2).
        labels_gd: per-point ground-truth labels; must be keys of color_dict.
        filename: unused in this function (the figure is shown interactively,
            never saved); kept for signature symmetry with the saving variant.
    """
    # Define custom colors for labels
    color_dict = {0: '#BCB6AE', 1: '#288596', 2: '#7D9083'}

    df = pd.DataFrame(nodes_xyz_pre, columns=['x', 'y', 'z'])
    colors_gd = [color_dict[label] for label in labels_gd]
    colors_pre = [color_dict[label] for label in labels_pre]


    fig, (ax1, ax2) = plt.subplots(1, 2, subplot_kw={'projection': '3d'})
    ax1.scatter(df['x'], df['y'], df['z'], c=colors_gd, s=1.5)
    ax1.set_title('Ground truth')
    ax2.scatter(df['x'], df['y'], df['z'], c=colors_pre, s=1.5)
    ax2.set_title('Prediction')
    ax1.set_axis_off() # Hide coordinate space
    ax2.set_axis_off() # Hide coordinate space

    # Interactive callback: mirror ax1's camera angles onto ax2 on every mouse move.
    def on_rotate(event):
        # Read the current rotation angles (from ax1, the panel the user drags).
        elev = ax1.elev
        azim = ax1.azim

        # Apply the same viewpoint to both subplots.
        ax1.view_init(elev=elev, azim=azim)
        ax2.view_init(elev=elev, azim=azim)

        # Redraw the figure with the synchronised views.
        fig.canvas.draw()

    # Bind the callback to mouse-motion events.
    fig.canvas.mpl_connect('motion_notify_event', on_rotate)

    plt.show()
+
def visualize_PC_with_twolabel(nodes_xyz_pre, labels_pre, labels_gd, filename='PC_label.pdf'):
    """Save a side-by-side 3D scatter of one point cloud colored two ways:
    ground truth on the left, prediction on the right.

    Args:
        nodes_xyz_pre: N x 3 array-like of point coordinates.
        labels_pre: per-point predicted labels (keys of the color table: 0, 1, 2).
        labels_gd: per-point ground-truth labels (same key set).
        filename: output path for the saved figure.
    """
    label_colors = {0: '#BCB6AE', 1: '#288596', 2: '#7D9083'}
    points = pd.DataFrame(nodes_xyz_pre, columns=['x', 'y', 'z'])

    fig = plt.figure(figsize=(6, 4))
    # Right panel: predicted labels.
    ax_pred = fig.add_subplot(122, projection='3d')
    ax_pred.scatter(points['x'], points['y'], points['z'],
                    c=[label_colors[l] for l in labels_pre], s=1.5)
    ax_pred.set_axis_off()  # hide the coordinate axes
    # Left panel: ground-truth labels.
    ax_gt = fig.add_subplot(121, projection='3d')
    ax_gt.scatter(points['x'], points['y'], points['z'],
                  c=[label_colors[l] for l in labels_gd], s=1.5)
    ax_gt.set_axis_off()  # hide the coordinate axes
    plt.subplots_adjust(wspace=0)
    plt.savefig(filename)
    plt.close(fig)
+
def visualize_two_PC(nodes_xyz_pre, nodes_xyz_gd, labels, filename='PC_recon.pdf'):
    """Save two stacked 3D scatters: ground-truth cloud on top, reconstruction below.

    Args:
        nodes_xyz_pre: N x 3 array-like of reconstructed point coordinates.
        nodes_xyz_gd: N x 3 array-like of ground-truth point coordinates.
        labels: per-point labels; every key maps to the same neutral color here.
        filename: output path for the saved figure.
    """
    # All three label classes intentionally share one neutral color.
    label_colors = {0: '#BCB6AE', 1: '#BCB6AE', 2: '#BCB6AE'}
    point_colors = [label_colors[l] for l in labels]

    recon = pd.DataFrame(nodes_xyz_pre, columns=['x', 'y', 'z'])
    truth = pd.DataFrame(nodes_xyz_gd, columns=['x', 'y', 'z'])

    fig = plt.figure(figsize=(4, 6))
    # Bottom panel: reconstruction.
    ax_recon = fig.add_subplot(212, projection='3d')
    ax_recon.scatter(recon['x'], recon['y'], recon['z'], c=point_colors, s=1.5)
    ax_recon.set_axis_off()  # hide the coordinate axes
    # Top panel: ground truth.
    ax_truth = fig.add_subplot(211, projection='3d')
    ax_truth.scatter(truth['x'], truth['y'], truth['z'], c=point_colors, s=1.5)
    ax_truth.set_axis_off()  # hide the coordinate axes
    plt.subplots_adjust(hspace=0)
    plt.savefig(filename)
    plt.close(fig)
+
def visualize_PC_with_label(nodes_xyz, labels=None, filename='PC_label.pdf'):
    """Save a single 3D scatter plot of a labelled point cloud.

    Args:
        nodes_xyz: N x 3 array-like of point coordinates.
        labels: iterable of per-point labels drawn from {0, 1, 2}, each selecting
            a color from color_dict. If None or a bare scalar, every point falls
            back to label 0. (The previous default of ``1`` always crashed,
            because an int is not iterable.)
        filename: output path for the saved figure.
    """
    df = pd.DataFrame(nodes_xyz, columns=['x', 'y', 'z'])

    # Define custom colors for labels.
    color_dict = {0: '#BCB6AE', 1: '#288596', 2: '#7D9083'}
    if labels is None or np.isscalar(labels):
        # No per-point labels given: draw the whole cloud in the neutral color.
        labels = [0] * len(df)
    colors = [color_dict[label] for label in labels]

    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    ax.scatter(df['x'], df['y'], df['z'], c=colors, s=1.5)
    ax.set_axis_off()  # Hide coordinate space
    plt.savefig(filename)
    plt.close(fig)
+
def save_coord_for_visualization(data, savename):
    """Write three ParaView-style CSVs (LV endo, epi, RV endo) under ./log/.

    Args:
        data: N x 9 array-like; columns 0-2 are LV-endo xyz, 3-5 epi xyz,
            6-8 RV-endo xyz.
        savename: filename stem; each surface gets its own suffix.

    The three identical write loops of the original are collapsed into one
    (suffix, start-column) table; output bytes are unchanged.
    """
    surfaces = [('_LVendo', 0), ('_epi', 3), ('_RVendo', 6)]
    for suffix, col in surfaces:
        with open('./log/' + savename + suffix + '.csv', 'w') as f:
            f.write('"Points:0","Points:1","Points:2"\n')
            for i in range(0, len(data)):
                f.write(str(data[i, col]) + ',' + str(data[i, col + 1]) + ',' + str(data[i, col + 2]) + '\n')
+
def lossplot_detailed(lossfile_train, lossfile_val, lossfile_mesh_train, lossfile_mesh_val, lossfile_KL_train, lossfile_KL_val, lossfile_compactness_train, lossfile_compactness_val, lossfile_PC_train, lossfile_PC_val, lossfile_ecg_train, lossfile_ecg_val, lossfile_RVp_train, lossfile_RVp_val, lossfile_size_train, lossfile_size_val):
    """Draw a 3x3 grid of train/val loss curves (one panel per loss term)
    and save the figure to img.png.

    Each argument pair is (train-file, val-file) for one loss term; the
    panels are delegated to lossplot().
    """
    # (subplot position, panel title, train file, val file) for every loss term.
    panels = [
        (331, 'total loss', lossfile_train, lossfile_val),
        (332, 'MI Dice + CE loss', lossfile_mesh_train, lossfile_mesh_val),
        (333, 'MI compactness loss', lossfile_compactness_train, lossfile_compactness_val),
        (334, 'KL loss', lossfile_KL_train, lossfile_KL_val),
        (335, 'PC recon loss', lossfile_PC_train, lossfile_PC_val),
        (336, 'ECG recon loss', lossfile_ecg_train, lossfile_ecg_val),
        (337, 'MI size loss', lossfile_size_train, lossfile_size_val),
        (338, 'MI RVpenalty loss', lossfile_RVp_train, lossfile_RVp_val),
    ]
    for position, title, train_file, val_file in panels:
        ax = plt.subplot(position)
        ax.set_title(title)
        lossplot(train_file, val_file)

    # Set the spacing between subplots.
    plt.subplots_adjust(left=0.1, bottom=0.1, right=0.9, top=0.9,
                        wspace=0.4, hspace=0.4)

    plt.savefig("img.png")
+
def lossplot_classify(lossfile_train, lossfile_val, lossfile_mesh_train, lossfile_mesh_val, lossfile_KL_train, lossfile_KL_val, lossfile_ecg_train, lossfile_ecg_val):
    """Draw a 2x2 grid of train/val loss curves for the classification run
    and save the figure to img_classify.png.

    Each argument pair is (train-file, val-file) for one loss term; the
    panels are delegated to lossplot().
    """
    # (subplot position, panel title, train file, val file) for every loss term.
    panels = [
        (221, 'total loss', lossfile_train, lossfile_val),
        (222, 'MI classfication loss', lossfile_mesh_train, lossfile_mesh_val),
        (223, 'KL loss', lossfile_KL_train, lossfile_KL_val),
        (224, 'ECG recon loss', lossfile_ecg_train, lossfile_ecg_val),
    ]
    for position, title, train_file, val_file in panels:
        ax = plt.subplot(position)
        ax.set_title(title)
        lossplot(train_file, val_file)

    # Set the spacing between subplots.
    plt.subplots_adjust(left=0.1, bottom=0.1, right=0.9, top=0.9,
                        wspace=0.4, hspace=0.4)

    plt.savefig("img_classify.png")
+
def lossplot(lossfile1, lossfile2):
    """Overlay two loss curves on the current axes: train (warm orange),
    then validation (dark blue).

    Args:
        lossfile1: path to a text file of train-loss values (np.loadtxt format).
        lossfile2: path to the matching validation-loss file.
    """
    for path, curve_color in ((lossfile1, '#FF7F61'), (lossfile2, '#2C4068')):
        values = np.loadtxt(path)
        plt.plot(range(values.size), values, curve_color)
        # NOTE(review): no plotted artist carries a label, so this legend call
        # currently draws nothing; kept to preserve behavior.
        plt.legend(frameon=False)
+
def ECG_visual_two(prop_data, target_ecg):
    """Plot predicted vs. true 8-lead ECG traces and save to ECG_visual.pdf.

    Args:
        prop_data: predicted traces, shape (num_samples, 8, T) — TODO confirm
            the leading sample axis with callers.
        target_ecg: ground-truth traces, shape (8, T). Entries equal to 0.0 are
            treated as padding and blanked (set to NaN) in both signals so
            matplotlib leaves gaps instead of drawing them.

    Fixes over the original:
      * inputs are copied before masking, so the caller's arrays are no longer
        mutated in place;
      * the padding mask is explicitly broadcast to prop_data's shape — the
        previous (1, 8, T) boolean index raised IndexError whenever prop_data
        held more than one sample.
    """
    prop_data = np.array(prop_data, copy=True)
    target_ecg = np.array(target_ecg, copy=True)
    pad_mask = target_ecg == 0.0
    prop_data[np.broadcast_to(pad_mask[np.newaxis, ...], prop_data.shape)] = np.nan
    target_ecg[pad_mask] = np.nan

    leadNames = ['I', 'II', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6']

    fig, axs = plt.subplots(2, 8, constrained_layout=True, figsize=(40, 10))
    for i in range(8):
        leadName = leadNames[i]
        # First sample carries the 'pred' legend label; the rest share its color.
        axs[0, i].plot(prop_data[0, i, :], color=[223/256, 176/256, 160/256], label='pred', linewidth=4)
        for j in range(1, prop_data.shape[0]):
            axs[0, i].plot(prop_data[j, i, :], color=[223/256, 176/256, 160/256], linewidth=4)
        axs[0, i].plot(target_ecg[i, :], color=[154/256, 181/256, 174/256], label='true', linewidth=4)
        axs[0, i].set_title('Lead ' + leadName, fontsize=20)
        axs[0, i].set_axis_off()
        axs[1, i].set_axis_off()  # bottom row is intentionally left blank
    # Attach one legend to the last lead's panel, outside the axes.
    axs[0, -1].legend(loc='center left', bbox_to_anchor=(1, 0.5), fontsize=20)
    fig.savefig("ECG_visual.pdf")
    plt.close(fig)
+
if __name__ == '__main__':
    # Plot all training/validation loss curves logged by the training run.
    log_dir = 'E:/2022_ECG_inference/Cardiac_Personalisation/log'

    # Each stem maps to a "training_<stem>.txt" / "val_<stem>.txt" pair in log_dir.
    stems = ['loss', 'calculate_inference_loss', 'compactness_loss', 'KL_loss',
             'PC_loss', 'ecg_loss', 'RVp_loss', 'MIsize_loss']
    train_files = {s: f"{log_dir}/training_{s}.txt" for s in stems}
    val_files = {s: f"{log_dir}/val_{s}.txt" for s in stems}

    lossplot_detailed(
        train_files['loss'], val_files['loss'],
        train_files['calculate_inference_loss'], val_files['calculate_inference_loss'],
        train_files['KL_loss'], val_files['KL_loss'],
        train_files['compactness_loss'], val_files['compactness_loss'],
        train_files['PC_loss'], val_files['PC_loss'],
        train_files['ecg_loss'], val_files['ecg_loss'],
        train_files['RVp_loss'], val_files['RVp_loss'],
        train_files['MIsize_loss'], val_files['MIsize_loss'],
    )