Diff of /utils.py [000000] .. [390c2f]

Switch to unified view

a b/utils.py
1
import numpy as np
2
import pandas as pd
3
from matplotlib import pyplot as plt
4
import plotly.express as px
5
from mpl_toolkits.mplot3d import Axes3D
6
import torch
7
import torch.nn as nn
8
EPS = 1e-4
9
10
class ProductOfExperts(nn.Module):
11
    """Return parameters for product of independent experts.
12
    See https://arxiv.org/pdf/1410.7827.pdf for equations.
13
14
    Args:
15
    mu (torch.Tensor): Mean of experts distribution. M x D for M experts
16
    logvar (torch.Tensor): Log of variance of experts distribution. M x D for M experts
17
    """
18
19
    def forward(self, mu, logvar):
20
        var = torch.exp(logvar) + EPS
21
        T = 1. / (var + EPS)
22
        pd_mu = torch.sum(mu * T, dim=0) / torch.sum(T, dim=0)
23
        pd_var = 1. / torch.sum(T, dim=0)
24
        pd_logvar = torch.log(pd_var + EPS)
25
26
        return pd_mu, pd_logvar
27
28
class alphaProductOfExperts(nn.Module):
29
    """Return parameters for weighted product of independent experts (mmJSD implementation).
30
    See https://arxiv.org/pdf/1410.7827.pdf for equations.
31
32
    Args:
33
    mu (torch.Tensor): Mean of experts distribution. M x D for M experts
34
    logvar (torch.Tensor): Log of variance of experts distribution. M x D for M experts
35
    """
36
37
    def forward(self, mu, logvar, weights=None):
38
        if weights is None:
39
            num_components = mu.shape[0]
40
            weights = (1/num_components) * torch.ones(mu.shape).to(mu.device)
41
    
42
        var = torch.exp(logvar) + EPS
43
        T = 1. / (var + EPS)
44
        weights = torch.broadcast_to(weights, mu.shape)
45
        pd_var = 1. / torch.sum(weights * T + EPS, dim=0)
46
        pd_mu = pd_var * torch.sum(weights * mu * T, dim=0)
47
        pd_logvar = torch.log(pd_var + EPS)
48
        
49
        return pd_mu, pd_logvar
50
    
51
class weightedProductOfExperts(nn.Module):
52
    """Return parameters for weighted product of independent experts.
53
    See https://arxiv.org/pdf/1410.7827.pdf for equations.
54
55
    Args:
56
    mu (torch.Tensor): Mean of experts distribution. M x D for M experts
57
    logvar (torch.Tensor): Log of variance of experts distribution. M x D for M experts
58
    """
59
60
    def forward(self, mu, logvar, weight):
61
62
        var = torch.exp(logvar) + EPS     
63
        weight = weight[:, None, :].repeat(1, mu.shape[1],1)
64
        T = 1.0 / (var + EPS)
65
        pd_var = 1. / torch.sum(weight * T + EPS, dim=0)
66
        pd_mu = pd_var * torch.sum(weight * mu * T, dim=0)
67
        pd_logvar = torch.log(pd_var + EPS)
68
        return pd_mu, pd_logvar
69
70
class MixtureOfExperts(nn.Module):
71
    """Return parameters for mixture of independent experts.
72
    Implementation from: https://github.com/thomassutter/MoPoE
73
74
    Args:
75
    mus (torch.Tensor): Mean of experts distribution. M x D for M experts
76
    logvars (torch.Tensor): Log of variance of experts distribution. M x D for M experts
77
    """
78
79
    def forward(self, mus, logvars):
80
81
        num_components = mus.shape[0]
82
        num_samples = mus.shape[1]
83
        weights = (1/num_components) * torch.ones(num_components).to(mus[0].device)
84
        idx_start = []
85
        idx_end = []
86
        for k in range(0, num_components):
87
            if k == 0:
88
                i_start = 0
89
            else:
90
                i_start = int(idx_end[k-1])
91
            if k == num_components-1:
92
                i_end = num_samples
93
            else:
94
                i_end = i_start + int(torch.floor(num_samples*weights[k]))
95
            idx_start.append(i_start)
96
            idx_end.append(i_end)
97
        idx_end[-1] = num_samples
98
99
        mu_sel = torch.cat([mus[k, idx_start[k]:idx_end[k], :] for k in range(num_components)])
100
        logvar_sel = torch.cat([logvars[k, idx_start[k]:idx_end[k], :] for k in range(num_components)])
101
102
        return mu_sel, logvar_sel
103
104
class MeanRepresentation(nn.Module):
105
    """Return mean of separate VAE representations.
106
    
107
    Args:
108
    mu (torch.Tensor): Mean of distributions. M x D for M views.
109
    logvar (torch.Tensor): Log of Variance of distributions. M x D for M views.
110
    """
111
112
    def forward(self, mu, logvar):
113
        mean_mu = torch.mean(mu, axis=0)
114
        mean_logvar = torch.mean(logvar, axis=0)
115
        
116
        return mean_mu, mean_logvar
117
118
119
def visualize_PC_with_twolabel_rotated(nodes_xyz_pre, labels_pre, labels_gd, filename='PC_label.pdf'):
120
    # Define custom colors for labels
121
    color_dict = {0: '#BCB6AE', 1: '#288596', 2: '#7D9083'}
122
123
    df = pd.DataFrame(nodes_xyz_pre, columns=['x', 'y', 'z'])
124
    colors_gd = [color_dict[label] for label in labels_gd]
125
    colors_pre = [color_dict[label] for label in labels_pre]
126
    
127
128
    fig, (ax1, ax2) = plt.subplots(1, 2, subplot_kw={'projection': '3d'})
129
    ax1.scatter(df['x'], df['y'], df['z'], c=colors_gd, s=1.5)  
130
    ax1.set_title('Ground truth')
131
    ax2.scatter(df['x'], df['y'], df['z'], c=colors_pre, s=1.5) 
132
    ax2.set_title('Prediction')
133
    ax1.set_axis_off() # Hide coordinate space 
134
    ax2.set_axis_off() # Hide coordinate space
135
136
    # 定义交互事件函数
137
    def on_rotate(event):
138
        # 获取当前旋转的角度
139
        elev = ax1.elev
140
        azim = ax1.azim
141
        
142
        # 设置两个子图的视角
143
        ax1.view_init(elev=elev, azim=azim)
144
        ax2.view_init(elev=elev, azim=azim)
145
        
146
        # 更新图形
147
        fig.canvas.draw()
148
149
    # 绑定交互事件
150
    fig.canvas.mpl_connect('motion_notify_event', on_rotate)
151
152
    plt.show()
153
154
def visualize_PC_with_twolabel(nodes_xyz_pre, labels_pre, labels_gd, filename='PC_label.pdf'):
155
    # Define custom colors for labels
156
    color_dict = {0: '#BCB6AE', 1: '#288596', 2: '#7D9083'}
157
158
    df = pd.DataFrame(nodes_xyz_pre, columns=['x', 'y', 'z'])
159
    colors_pre = [color_dict[label] for label in labels_pre]
160
    colors_gd = [color_dict[label] for label in labels_gd]
161
162
    fig = plt.figure(figsize=(6, 4))
163
    ax1 = fig.add_subplot(122, projection='3d')
164
    ax1.scatter(df['x'], df['y'], df['z'], c=colors_pre, s=1.5)  
165
    ax1.set_axis_off() # Hide coordinate space
166
    ax2 = fig.add_subplot(121, projection='3d')
167
    ax2.scatter(df['x'], df['y'], df['z'], c=colors_gd, s=1.5)    
168
    ax2.set_axis_off() # Hide coordinate space
169
    plt.subplots_adjust(wspace=0)
170
    plt.savefig(filename)
171
    # plt.show()
172
    plt.close(fig)
173
174
def visualize_two_PC(nodes_xyz_pre, nodes_xyz_gd, labels, filename='PC_recon.pdf'):
175
    color_dict = {0: '#BCB6AE', 1: '#BCB6AE', 2: '#BCB6AE'}
176
    colors = [color_dict[label] for label in labels]
177
178
    df_pre = pd.DataFrame(nodes_xyz_pre, columns=['x', 'y', 'z'])
179
    df_gd = pd.DataFrame(nodes_xyz_gd, columns=['x', 'y', 'z'])
180
181
    fig = plt.figure(figsize=(4, 6))
182
    ax1 = fig.add_subplot(212, projection='3d')
183
    ax1.scatter(df_pre['x'], df_pre['y'], df_pre['z'], c=colors, s=1.5)  
184
    ax1.set_axis_off() # Hide coordinate space
185
    ax2 = fig.add_subplot(211, projection='3d')
186
    ax2.scatter(df_gd['x'], df_gd['y'], df_gd['z'], c=colors, s=1.5)    
187
    ax2.set_axis_off() # Hide coordinate space
188
    plt.subplots_adjust(hspace=0)
189
    plt.savefig(filename)
190
    # plt.show()
191
    plt.close(fig)
192
193
def visualize_PC_with_label(nodes_xyz, labels=1, filename='PC_label.pdf'):
194
    # plot in 3d using plotly
195
    df = pd.DataFrame(nodes_xyz, columns=['x', 'y', 'z'])
196
    # define custom colors for each category
197
    # colors = {'0': '#BCB6AE', '1': '#288596', '3': '#7D9083'}
198
    # colors = {'0': 'grey', '1': 'blue', '3': 'red'}
199
    # df['color'] = label.astype(int)
200
    # fig = px.scatter_3d(df, x='x', y='y', z='z', color = 'color', color_discrete_sequence=[colors[k] for k in sorted(colors.keys())])
201
    # # fig = px.scatter_3d(df, x='x', y='y', z='z', color = clr_nodes, color_continuous_scale=px.colors.sequential.Viridis)
202
    # fig.update_traces(marker_size = 1.5)  # increase marker_size for bigger node size
203
    # fig.show()   
204
    # plotly.offline.plot(fig)
205
    # fig.write_image(filename) 
206
207
    # Define custom colors for labels
208
    color_dict = {0: '#BCB6AE', 1: '#288596', 2: '#7D9083'}
209
    # color_dict = {0: '#BCB6AE', 1: '#288596'}
210
    colors = [color_dict[label] for label in labels]
211
212
    fig = plt.figure()
213
    ax = fig.add_subplot(111, projection='3d')
214
    ax.scatter(df['x'], df['y'], df['z'], c=colors, s=1.5)  
215
    ax.set_axis_off() # Hide coordinate space
216
    plt.savefig(filename)
217
    plt.close(fig)
218
219
def save_coord_for_visualization(data, savename):
220
    with open('./log/' + savename+'_LVendo.csv', 'w') as f:
221
        f.write('"Points:0","Points:1","Points:2"\n')
222
        for i in range(0, len(data)):
223
            f.write(str(data[i, 0]) + ',' + str(data[i, 1]) + ',' + str(data[i, 2]) + '\n')
224
    with open('./log/' + savename+'_epi.csv', 'w') as f:
225
        f.write('"Points:0","Points:1","Points:2"\n')
226
        for i in range(0, len(data)):
227
            f.write(str(data[i, 3]) + ',' + str(data[i, 4]) + ',' + str(data[i, 5]) + '\n')
228
    with open('./log/' + savename+'_RVendo.csv', 'w') as f:
229
        f.write('"Points:0","Points:1","Points:2"\n')
230
        for i in range(0, len(data)):
231
            f.write(str(data[i, 6]) + ',' + str(data[i, 7]) + ',' + str(data[i, 8]) + '\n')
232
233
def lossplot_detailed(lossfile_train, lossfile_val, lossfile_mesh_train, lossfile_mesh_val, lossfile_KL_train, lossfile_KL_val, lossfile_compactness_train, lossfile_compactness_val, lossfile_PC_train, lossfile_PC_val, lossfile_ecg_train, lossfile_ecg_val, lossfile_RVp_train, lossfile_RVp_val, lossfile_size_train, lossfile_size_val):
234
    ax = plt.subplot(331)
235
    ax.set_title('total loss')
236
    lossplot(lossfile_train, lossfile_val)
237
238
    ax = plt.subplot(332)
239
    ax.set_title('MI Dice + CE loss')
240
    lossplot(lossfile_mesh_train, lossfile_mesh_val)
241
242
    ax = plt.subplot(333)
243
    ax.set_title('MI compactness loss')
244
    lossplot(lossfile_compactness_train, lossfile_compactness_val)
245
246
    ax = plt.subplot(334)
247
    ax.set_title('KL loss')
248
    lossplot(lossfile_KL_train, lossfile_KL_val)
249
250
    ax = plt.subplot(335)
251
    ax.set_title('PC recon loss')
252
    lossplot(lossfile_PC_train, lossfile_PC_val)
253
254
    ax = plt.subplot(336)
255
    ax.set_title('ECG recon loss')
256
    lossplot(lossfile_ecg_train, lossfile_ecg_val)
257
258
    ax = plt.subplot(337)
259
    ax.set_title('MI size loss')
260
    lossplot(lossfile_size_train, lossfile_size_val)
261
262
    ax = plt.subplot(338)
263
    ax.set_title('MI RVpenalty loss')
264
    lossplot(lossfile_RVp_train, lossfile_RVp_val)
265
266
    # set the spacing between subplots
267
    plt.subplots_adjust(left=0.1,
268
                    bottom=0.1, 
269
                    right=0.9, 
270
                    top=0.9, 
271
                    wspace=0.4, 
272
                    hspace=0.4)
273
274
    plt.savefig("img.png")
275
    # plt.show()
276
277
def lossplot_classify(lossfile_train, lossfile_val, lossfile_mesh_train, lossfile_mesh_val, lossfile_KL_train, lossfile_KL_val, lossfile_ecg_train, lossfile_ecg_val):
278
    ax = plt.subplot(221)
279
    ax.set_title('total loss')
280
    lossplot(lossfile_train, lossfile_val)
281
282
    ax = plt.subplot(222)
283
    ax.set_title('MI classfication loss')
284
    lossplot(lossfile_mesh_train, lossfile_mesh_val)
285
286
    ax = plt.subplot(223)
287
    ax.set_title('KL loss')
288
    lossplot(lossfile_KL_train, lossfile_KL_val)
289
290
291
    ax = plt.subplot(224)
292
    ax.set_title('ECG recon loss')
293
    lossplot(lossfile_ecg_train, lossfile_ecg_val)
294
295
296
    # set the spacing between subplots
297
    plt.subplots_adjust(left=0.1,
298
                    bottom=0.1, 
299
                    right=0.9, 
300
                    top=0.9, 
301
                    wspace=0.4, 
302
                    hspace=0.4)
303
304
    plt.savefig("img_classify.png")
305
    # plt.show()
306
307
def lossplot(lossfile1, lossfile2):
308
    loss = np.loadtxt(lossfile1)
309
    x = range(0, loss.size)
310
    y = loss
311
    plt.plot(x, y, '#FF7F61') # , label='train'
312
    plt.legend(frameon=False)
313
314
    loss = np.loadtxt(lossfile2)
315
    x = range(0, loss.size)
316
    y = loss
317
    plt.plot(x, y, '#2C4068') # , label='val'
318
    plt.legend(frameon=False)
319
    # plt.show()
320
    # plt.savefig("img.png")
321
322
def ECG_visual_two(prop_data, target_ecg):   
323
    prop_data[target_ecg[np.newaxis, ...] == 0.0], target_ecg[target_ecg == 0.0] = np.nan, np.nan
324
325
    leadNames = ['I', 'II', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6']
326
327
    fig, axs = plt.subplots(2, 8, constrained_layout=True, figsize=(40, 10))
328
    for i in range(8):
329
        leadName = leadNames[i]
330
        axs[0, i].plot(prop_data[0, i, :], color=[223/256,176/256,160/256], label='pred', linewidth=4)
331
        for j in range(1, prop_data.shape[0]):
332
            axs[0, i].plot(prop_data[j, i, :], color=[223/256,176/256,160/256], linewidth=4) 
333
        axs[0, i].plot(target_ecg[i, :], color=[154/256,181/256,174/256], label='true', linewidth=4)
334
        axs[0, i].set_title('Lead ' + leadName, fontsize=20)
335
        axs[0, i].set_axis_off() 
336
        axs[1, i].set_axis_off() 
337
    axs[0, i].legend(loc='center left', bbox_to_anchor=(1, 0.5), fontsize=20)
338
    fig.savefig("ECG_visual.pdf")
339
    # plt.show()
340
    plt.close(fig)
341
342
if __name__ == '__main__':
343
    # input_data_dir = 'C:/Users/lilei/OneDrive - Nexus365/2021_Oxford/Oxford Research/BivenMesh_Script/dataset/gt/'
344
    # pc = input_data_dir + 'dense_RV_endo_output_labeled_ES_pc_6003744.ply'
345
    # pc_volume = calculate_pointcloudvolume(pc)
346
    # F_visual_CV()
347
348
    log_dir = 'E:/2022_ECG_inference/Cardiac_Personalisation/log'
349
    lossfile_train = log_dir + "/training_loss.txt"
350
    lossfile_val = log_dir + "/val_loss.txt"
351
    lossfile_geometry_train = log_dir + "/training_calculate_inference_loss.txt"
352
    lossfile_geometry_val = log_dir + "/val_calculate_inference_loss.txt"
353
    lossfile_compactness_train = log_dir + "/training_compactness_loss.txt"
354
    lossfile_compactness_val = log_dir + "/val_compactness_loss.txt"
355
    lossfile_KL_train = log_dir + "/training_KL_loss.txt"
356
    lossfile_KL_val = log_dir + "/val_KL_loss.txt"
357
    lossfile_PC_train = log_dir + "/training_PC_loss.txt"
358
    lossfile_PC_val = log_dir + "/val_PC_loss.txt"
359
    lossfile_ecg_train = log_dir + "/training_ecg_loss.txt"
360
    lossfile_ecg_val = log_dir + "/val_ecg_loss.txt"
361
    lossfile_RVp_train = log_dir + "/training_RVp_loss.txt"
362
    lossfile_RVp_val = log_dir + "/val_RVp_loss.txt"
363
    lossfile_size_train = log_dir + "/training_MIsize_loss.txt"
364
    lossfile_size_val = log_dir + "/val_MIsize_loss.txt"
365
366
    lossplot_detailed(lossfile_train, lossfile_val, lossfile_geometry_train, lossfile_geometry_val, lossfile_KL_train, lossfile_KL_val, lossfile_compactness_train, lossfile_compactness_val, lossfile_PC_train, lossfile_PC_val, lossfile_ecg_train, lossfile_ecg_val, lossfile_RVp_train, lossfile_RVp_val, lossfile_size_train, lossfile_size_val)