|
a |
|
b/utils.py |
|
|
1 |
import numpy as np |
|
|
2 |
import pandas as pd |
|
|
3 |
from matplotlib import pyplot as plt |
|
|
4 |
import plotly.express as px |
|
|
5 |
from mpl_toolkits.mplot3d import Axes3D |
|
|
6 |
import torch |
|
|
7 |
import torch.nn as nn |
|
|
8 |
EPS = 1e-4 |
|
|
9 |
|
|
|
10 |
class ProductOfExperts(nn.Module): |
|
|
11 |
"""Return parameters for product of independent experts. |
|
|
12 |
See https://arxiv.org/pdf/1410.7827.pdf for equations. |
|
|
13 |
|
|
|
14 |
Args: |
|
|
15 |
mu (torch.Tensor): Mean of experts distribution. M x D for M experts |
|
|
16 |
logvar (torch.Tensor): Log of variance of experts distribution. M x D for M experts |
|
|
17 |
""" |
|
|
18 |
|
|
|
19 |
def forward(self, mu, logvar): |
|
|
20 |
var = torch.exp(logvar) + EPS |
|
|
21 |
T = 1. / (var + EPS) |
|
|
22 |
pd_mu = torch.sum(mu * T, dim=0) / torch.sum(T, dim=0) |
|
|
23 |
pd_var = 1. / torch.sum(T, dim=0) |
|
|
24 |
pd_logvar = torch.log(pd_var + EPS) |
|
|
25 |
|
|
|
26 |
return pd_mu, pd_logvar |
|
|
27 |
|
|
|
28 |
class alphaProductOfExperts(nn.Module): |
|
|
29 |
"""Return parameters for weighted product of independent experts (mmJSD implementation). |
|
|
30 |
See https://arxiv.org/pdf/1410.7827.pdf for equations. |
|
|
31 |
|
|
|
32 |
Args: |
|
|
33 |
mu (torch.Tensor): Mean of experts distribution. M x D for M experts |
|
|
34 |
logvar (torch.Tensor): Log of variance of experts distribution. M x D for M experts |
|
|
35 |
""" |
|
|
36 |
|
|
|
37 |
def forward(self, mu, logvar, weights=None): |
|
|
38 |
if weights is None: |
|
|
39 |
num_components = mu.shape[0] |
|
|
40 |
weights = (1/num_components) * torch.ones(mu.shape).to(mu.device) |
|
|
41 |
|
|
|
42 |
var = torch.exp(logvar) + EPS |
|
|
43 |
T = 1. / (var + EPS) |
|
|
44 |
weights = torch.broadcast_to(weights, mu.shape) |
|
|
45 |
pd_var = 1. / torch.sum(weights * T + EPS, dim=0) |
|
|
46 |
pd_mu = pd_var * torch.sum(weights * mu * T, dim=0) |
|
|
47 |
pd_logvar = torch.log(pd_var + EPS) |
|
|
48 |
|
|
|
49 |
return pd_mu, pd_logvar |
|
|
50 |
|
|
|
51 |
class weightedProductOfExperts(nn.Module): |
|
|
52 |
"""Return parameters for weighted product of independent experts. |
|
|
53 |
See https://arxiv.org/pdf/1410.7827.pdf for equations. |
|
|
54 |
|
|
|
55 |
Args: |
|
|
56 |
mu (torch.Tensor): Mean of experts distribution. M x D for M experts |
|
|
57 |
logvar (torch.Tensor): Log of variance of experts distribution. M x D for M experts |
|
|
58 |
""" |
|
|
59 |
|
|
|
60 |
def forward(self, mu, logvar, weight): |
|
|
61 |
|
|
|
62 |
var = torch.exp(logvar) + EPS |
|
|
63 |
weight = weight[:, None, :].repeat(1, mu.shape[1],1) |
|
|
64 |
T = 1.0 / (var + EPS) |
|
|
65 |
pd_var = 1. / torch.sum(weight * T + EPS, dim=0) |
|
|
66 |
pd_mu = pd_var * torch.sum(weight * mu * T, dim=0) |
|
|
67 |
pd_logvar = torch.log(pd_var + EPS) |
|
|
68 |
return pd_mu, pd_logvar |
|
|
69 |
|
|
|
70 |
class MixtureOfExperts(nn.Module): |
|
|
71 |
"""Return parameters for mixture of independent experts. |
|
|
72 |
Implementation from: https://github.com/thomassutter/MoPoE |
|
|
73 |
|
|
|
74 |
Args: |
|
|
75 |
mus (torch.Tensor): Mean of experts distribution. M x D for M experts |
|
|
76 |
logvars (torch.Tensor): Log of variance of experts distribution. M x D for M experts |
|
|
77 |
""" |
|
|
78 |
|
|
|
79 |
def forward(self, mus, logvars): |
|
|
80 |
|
|
|
81 |
num_components = mus.shape[0] |
|
|
82 |
num_samples = mus.shape[1] |
|
|
83 |
weights = (1/num_components) * torch.ones(num_components).to(mus[0].device) |
|
|
84 |
idx_start = [] |
|
|
85 |
idx_end = [] |
|
|
86 |
for k in range(0, num_components): |
|
|
87 |
if k == 0: |
|
|
88 |
i_start = 0 |
|
|
89 |
else: |
|
|
90 |
i_start = int(idx_end[k-1]) |
|
|
91 |
if k == num_components-1: |
|
|
92 |
i_end = num_samples |
|
|
93 |
else: |
|
|
94 |
i_end = i_start + int(torch.floor(num_samples*weights[k])) |
|
|
95 |
idx_start.append(i_start) |
|
|
96 |
idx_end.append(i_end) |
|
|
97 |
idx_end[-1] = num_samples |
|
|
98 |
|
|
|
99 |
mu_sel = torch.cat([mus[k, idx_start[k]:idx_end[k], :] for k in range(num_components)]) |
|
|
100 |
logvar_sel = torch.cat([logvars[k, idx_start[k]:idx_end[k], :] for k in range(num_components)]) |
|
|
101 |
|
|
|
102 |
return mu_sel, logvar_sel |
|
|
103 |
|
|
|
104 |
class MeanRepresentation(nn.Module): |
|
|
105 |
"""Return mean of separate VAE representations. |
|
|
106 |
|
|
|
107 |
Args: |
|
|
108 |
mu (torch.Tensor): Mean of distributions. M x D for M views. |
|
|
109 |
logvar (torch.Tensor): Log of Variance of distributions. M x D for M views. |
|
|
110 |
""" |
|
|
111 |
|
|
|
112 |
def forward(self, mu, logvar): |
|
|
113 |
mean_mu = torch.mean(mu, axis=0) |
|
|
114 |
mean_logvar = torch.mean(logvar, axis=0) |
|
|
115 |
|
|
|
116 |
return mean_mu, mean_logvar |
|
|
117 |
|
|
|
118 |
|
|
|
119 |
def visualize_PC_with_twolabel_rotated(nodes_xyz_pre, labels_pre, labels_gd, filename='PC_label.pdf'): |
|
|
120 |
# Define custom colors for labels |
|
|
121 |
color_dict = {0: '#BCB6AE', 1: '#288596', 2: '#7D9083'} |
|
|
122 |
|
|
|
123 |
df = pd.DataFrame(nodes_xyz_pre, columns=['x', 'y', 'z']) |
|
|
124 |
colors_gd = [color_dict[label] for label in labels_gd] |
|
|
125 |
colors_pre = [color_dict[label] for label in labels_pre] |
|
|
126 |
|
|
|
127 |
|
|
|
128 |
fig, (ax1, ax2) = plt.subplots(1, 2, subplot_kw={'projection': '3d'}) |
|
|
129 |
ax1.scatter(df['x'], df['y'], df['z'], c=colors_gd, s=1.5) |
|
|
130 |
ax1.set_title('Ground truth') |
|
|
131 |
ax2.scatter(df['x'], df['y'], df['z'], c=colors_pre, s=1.5) |
|
|
132 |
ax2.set_title('Prediction') |
|
|
133 |
ax1.set_axis_off() # Hide coordinate space |
|
|
134 |
ax2.set_axis_off() # Hide coordinate space |
|
|
135 |
|
|
|
136 |
# 定义交互事件函数 |
|
|
137 |
def on_rotate(event): |
|
|
138 |
# 获取当前旋转的角度 |
|
|
139 |
elev = ax1.elev |
|
|
140 |
azim = ax1.azim |
|
|
141 |
|
|
|
142 |
# 设置两个子图的视角 |
|
|
143 |
ax1.view_init(elev=elev, azim=azim) |
|
|
144 |
ax2.view_init(elev=elev, azim=azim) |
|
|
145 |
|
|
|
146 |
# 更新图形 |
|
|
147 |
fig.canvas.draw() |
|
|
148 |
|
|
|
149 |
# 绑定交互事件 |
|
|
150 |
fig.canvas.mpl_connect('motion_notify_event', on_rotate) |
|
|
151 |
|
|
|
152 |
plt.show() |
|
|
153 |
|
|
|
154 |
def visualize_PC_with_twolabel(nodes_xyz_pre, labels_pre, labels_gd, filename='PC_label.pdf'): |
|
|
155 |
# Define custom colors for labels |
|
|
156 |
color_dict = {0: '#BCB6AE', 1: '#288596', 2: '#7D9083'} |
|
|
157 |
|
|
|
158 |
df = pd.DataFrame(nodes_xyz_pre, columns=['x', 'y', 'z']) |
|
|
159 |
colors_pre = [color_dict[label] for label in labels_pre] |
|
|
160 |
colors_gd = [color_dict[label] for label in labels_gd] |
|
|
161 |
|
|
|
162 |
fig = plt.figure(figsize=(6, 4)) |
|
|
163 |
ax1 = fig.add_subplot(122, projection='3d') |
|
|
164 |
ax1.scatter(df['x'], df['y'], df['z'], c=colors_pre, s=1.5) |
|
|
165 |
ax1.set_axis_off() # Hide coordinate space |
|
|
166 |
ax2 = fig.add_subplot(121, projection='3d') |
|
|
167 |
ax2.scatter(df['x'], df['y'], df['z'], c=colors_gd, s=1.5) |
|
|
168 |
ax2.set_axis_off() # Hide coordinate space |
|
|
169 |
plt.subplots_adjust(wspace=0) |
|
|
170 |
plt.savefig(filename) |
|
|
171 |
# plt.show() |
|
|
172 |
plt.close(fig) |
|
|
173 |
|
|
|
174 |
def visualize_two_PC(nodes_xyz_pre, nodes_xyz_gd, labels, filename='PC_recon.pdf'): |
|
|
175 |
color_dict = {0: '#BCB6AE', 1: '#BCB6AE', 2: '#BCB6AE'} |
|
|
176 |
colors = [color_dict[label] for label in labels] |
|
|
177 |
|
|
|
178 |
df_pre = pd.DataFrame(nodes_xyz_pre, columns=['x', 'y', 'z']) |
|
|
179 |
df_gd = pd.DataFrame(nodes_xyz_gd, columns=['x', 'y', 'z']) |
|
|
180 |
|
|
|
181 |
fig = plt.figure(figsize=(4, 6)) |
|
|
182 |
ax1 = fig.add_subplot(212, projection='3d') |
|
|
183 |
ax1.scatter(df_pre['x'], df_pre['y'], df_pre['z'], c=colors, s=1.5) |
|
|
184 |
ax1.set_axis_off() # Hide coordinate space |
|
|
185 |
ax2 = fig.add_subplot(211, projection='3d') |
|
|
186 |
ax2.scatter(df_gd['x'], df_gd['y'], df_gd['z'], c=colors, s=1.5) |
|
|
187 |
ax2.set_axis_off() # Hide coordinate space |
|
|
188 |
plt.subplots_adjust(hspace=0) |
|
|
189 |
plt.savefig(filename) |
|
|
190 |
# plt.show() |
|
|
191 |
plt.close(fig) |
|
|
192 |
|
|
|
193 |
def visualize_PC_with_label(nodes_xyz, labels=1, filename='PC_label.pdf'): |
|
|
194 |
# plot in 3d using plotly |
|
|
195 |
df = pd.DataFrame(nodes_xyz, columns=['x', 'y', 'z']) |
|
|
196 |
# define custom colors for each category |
|
|
197 |
# colors = {'0': '#BCB6AE', '1': '#288596', '3': '#7D9083'} |
|
|
198 |
# colors = {'0': 'grey', '1': 'blue', '3': 'red'} |
|
|
199 |
# df['color'] = label.astype(int) |
|
|
200 |
# fig = px.scatter_3d(df, x='x', y='y', z='z', color = 'color', color_discrete_sequence=[colors[k] for k in sorted(colors.keys())]) |
|
|
201 |
# # fig = px.scatter_3d(df, x='x', y='y', z='z', color = clr_nodes, color_continuous_scale=px.colors.sequential.Viridis) |
|
|
202 |
# fig.update_traces(marker_size = 1.5) # increase marker_size for bigger node size |
|
|
203 |
# fig.show() |
|
|
204 |
# plotly.offline.plot(fig) |
|
|
205 |
# fig.write_image(filename) |
|
|
206 |
|
|
|
207 |
# Define custom colors for labels |
|
|
208 |
color_dict = {0: '#BCB6AE', 1: '#288596', 2: '#7D9083'} |
|
|
209 |
# color_dict = {0: '#BCB6AE', 1: '#288596'} |
|
|
210 |
colors = [color_dict[label] for label in labels] |
|
|
211 |
|
|
|
212 |
fig = plt.figure() |
|
|
213 |
ax = fig.add_subplot(111, projection='3d') |
|
|
214 |
ax.scatter(df['x'], df['y'], df['z'], c=colors, s=1.5) |
|
|
215 |
ax.set_axis_off() # Hide coordinate space |
|
|
216 |
plt.savefig(filename) |
|
|
217 |
plt.close(fig) |
|
|
218 |
|
|
|
219 |
def save_coord_for_visualization(data, savename): |
|
|
220 |
with open('./log/' + savename+'_LVendo.csv', 'w') as f: |
|
|
221 |
f.write('"Points:0","Points:1","Points:2"\n') |
|
|
222 |
for i in range(0, len(data)): |
|
|
223 |
f.write(str(data[i, 0]) + ',' + str(data[i, 1]) + ',' + str(data[i, 2]) + '\n') |
|
|
224 |
with open('./log/' + savename+'_epi.csv', 'w') as f: |
|
|
225 |
f.write('"Points:0","Points:1","Points:2"\n') |
|
|
226 |
for i in range(0, len(data)): |
|
|
227 |
f.write(str(data[i, 3]) + ',' + str(data[i, 4]) + ',' + str(data[i, 5]) + '\n') |
|
|
228 |
with open('./log/' + savename+'_RVendo.csv', 'w') as f: |
|
|
229 |
f.write('"Points:0","Points:1","Points:2"\n') |
|
|
230 |
for i in range(0, len(data)): |
|
|
231 |
f.write(str(data[i, 6]) + ',' + str(data[i, 7]) + ',' + str(data[i, 8]) + '\n') |
|
|
232 |
|
|
|
233 |
def lossplot_detailed(lossfile_train, lossfile_val, lossfile_mesh_train, lossfile_mesh_val, lossfile_KL_train, lossfile_KL_val, lossfile_compactness_train, lossfile_compactness_val, lossfile_PC_train, lossfile_PC_val, lossfile_ecg_train, lossfile_ecg_val, lossfile_RVp_train, lossfile_RVp_val, lossfile_size_train, lossfile_size_val): |
|
|
234 |
ax = plt.subplot(331) |
|
|
235 |
ax.set_title('total loss') |
|
|
236 |
lossplot(lossfile_train, lossfile_val) |
|
|
237 |
|
|
|
238 |
ax = plt.subplot(332) |
|
|
239 |
ax.set_title('MI Dice + CE loss') |
|
|
240 |
lossplot(lossfile_mesh_train, lossfile_mesh_val) |
|
|
241 |
|
|
|
242 |
ax = plt.subplot(333) |
|
|
243 |
ax.set_title('MI compactness loss') |
|
|
244 |
lossplot(lossfile_compactness_train, lossfile_compactness_val) |
|
|
245 |
|
|
|
246 |
ax = plt.subplot(334) |
|
|
247 |
ax.set_title('KL loss') |
|
|
248 |
lossplot(lossfile_KL_train, lossfile_KL_val) |
|
|
249 |
|
|
|
250 |
ax = plt.subplot(335) |
|
|
251 |
ax.set_title('PC recon loss') |
|
|
252 |
lossplot(lossfile_PC_train, lossfile_PC_val) |
|
|
253 |
|
|
|
254 |
ax = plt.subplot(336) |
|
|
255 |
ax.set_title('ECG recon loss') |
|
|
256 |
lossplot(lossfile_ecg_train, lossfile_ecg_val) |
|
|
257 |
|
|
|
258 |
ax = plt.subplot(337) |
|
|
259 |
ax.set_title('MI size loss') |
|
|
260 |
lossplot(lossfile_size_train, lossfile_size_val) |
|
|
261 |
|
|
|
262 |
ax = plt.subplot(338) |
|
|
263 |
ax.set_title('MI RVpenalty loss') |
|
|
264 |
lossplot(lossfile_RVp_train, lossfile_RVp_val) |
|
|
265 |
|
|
|
266 |
# set the spacing between subplots |
|
|
267 |
plt.subplots_adjust(left=0.1, |
|
|
268 |
bottom=0.1, |
|
|
269 |
right=0.9, |
|
|
270 |
top=0.9, |
|
|
271 |
wspace=0.4, |
|
|
272 |
hspace=0.4) |
|
|
273 |
|
|
|
274 |
plt.savefig("img.png") |
|
|
275 |
# plt.show() |
|
|
276 |
|
|
|
277 |
def lossplot_classify(lossfile_train, lossfile_val, lossfile_mesh_train, lossfile_mesh_val, lossfile_KL_train, lossfile_KL_val, lossfile_ecg_train, lossfile_ecg_val): |
|
|
278 |
ax = plt.subplot(221) |
|
|
279 |
ax.set_title('total loss') |
|
|
280 |
lossplot(lossfile_train, lossfile_val) |
|
|
281 |
|
|
|
282 |
ax = plt.subplot(222) |
|
|
283 |
ax.set_title('MI classfication loss') |
|
|
284 |
lossplot(lossfile_mesh_train, lossfile_mesh_val) |
|
|
285 |
|
|
|
286 |
ax = plt.subplot(223) |
|
|
287 |
ax.set_title('KL loss') |
|
|
288 |
lossplot(lossfile_KL_train, lossfile_KL_val) |
|
|
289 |
|
|
|
290 |
|
|
|
291 |
ax = plt.subplot(224) |
|
|
292 |
ax.set_title('ECG recon loss') |
|
|
293 |
lossplot(lossfile_ecg_train, lossfile_ecg_val) |
|
|
294 |
|
|
|
295 |
|
|
|
296 |
# set the spacing between subplots |
|
|
297 |
plt.subplots_adjust(left=0.1, |
|
|
298 |
bottom=0.1, |
|
|
299 |
right=0.9, |
|
|
300 |
top=0.9, |
|
|
301 |
wspace=0.4, |
|
|
302 |
hspace=0.4) |
|
|
303 |
|
|
|
304 |
plt.savefig("img_classify.png") |
|
|
305 |
# plt.show() |
|
|
306 |
|
|
|
307 |
def lossplot(lossfile1, lossfile2): |
|
|
308 |
loss = np.loadtxt(lossfile1) |
|
|
309 |
x = range(0, loss.size) |
|
|
310 |
y = loss |
|
|
311 |
plt.plot(x, y, '#FF7F61') # , label='train' |
|
|
312 |
plt.legend(frameon=False) |
|
|
313 |
|
|
|
314 |
loss = np.loadtxt(lossfile2) |
|
|
315 |
x = range(0, loss.size) |
|
|
316 |
y = loss |
|
|
317 |
plt.plot(x, y, '#2C4068') # , label='val' |
|
|
318 |
plt.legend(frameon=False) |
|
|
319 |
# plt.show() |
|
|
320 |
# plt.savefig("img.png") |
|
|
321 |
|
|
|
322 |
def ECG_visual_two(prop_data, target_ecg): |
|
|
323 |
prop_data[target_ecg[np.newaxis, ...] == 0.0], target_ecg[target_ecg == 0.0] = np.nan, np.nan |
|
|
324 |
|
|
|
325 |
leadNames = ['I', 'II', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6'] |
|
|
326 |
|
|
|
327 |
fig, axs = plt.subplots(2, 8, constrained_layout=True, figsize=(40, 10)) |
|
|
328 |
for i in range(8): |
|
|
329 |
leadName = leadNames[i] |
|
|
330 |
axs[0, i].plot(prop_data[0, i, :], color=[223/256,176/256,160/256], label='pred', linewidth=4) |
|
|
331 |
for j in range(1, prop_data.shape[0]): |
|
|
332 |
axs[0, i].plot(prop_data[j, i, :], color=[223/256,176/256,160/256], linewidth=4) |
|
|
333 |
axs[0, i].plot(target_ecg[i, :], color=[154/256,181/256,174/256], label='true', linewidth=4) |
|
|
334 |
axs[0, i].set_title('Lead ' + leadName, fontsize=20) |
|
|
335 |
axs[0, i].set_axis_off() |
|
|
336 |
axs[1, i].set_axis_off() |
|
|
337 |
axs[0, i].legend(loc='center left', bbox_to_anchor=(1, 0.5), fontsize=20) |
|
|
338 |
fig.savefig("ECG_visual.pdf") |
|
|
339 |
# plt.show() |
|
|
340 |
plt.close(fig) |
|
|
341 |
|
|
|
342 |
if __name__ == '__main__': |
|
|
343 |
# input_data_dir = 'C:/Users/lilei/OneDrive - Nexus365/2021_Oxford/Oxford Research/BivenMesh_Script/dataset/gt/' |
|
|
344 |
# pc = input_data_dir + 'dense_RV_endo_output_labeled_ES_pc_6003744.ply' |
|
|
345 |
# pc_volume = calculate_pointcloudvolume(pc) |
|
|
346 |
# F_visual_CV() |
|
|
347 |
|
|
|
348 |
log_dir = 'E:/2022_ECG_inference/Cardiac_Personalisation/log' |
|
|
349 |
lossfile_train = log_dir + "/training_loss.txt" |
|
|
350 |
lossfile_val = log_dir + "/val_loss.txt" |
|
|
351 |
lossfile_geometry_train = log_dir + "/training_calculate_inference_loss.txt" |
|
|
352 |
lossfile_geometry_val = log_dir + "/val_calculate_inference_loss.txt" |
|
|
353 |
lossfile_compactness_train = log_dir + "/training_compactness_loss.txt" |
|
|
354 |
lossfile_compactness_val = log_dir + "/val_compactness_loss.txt" |
|
|
355 |
lossfile_KL_train = log_dir + "/training_KL_loss.txt" |
|
|
356 |
lossfile_KL_val = log_dir + "/val_KL_loss.txt" |
|
|
357 |
lossfile_PC_train = log_dir + "/training_PC_loss.txt" |
|
|
358 |
lossfile_PC_val = log_dir + "/val_PC_loss.txt" |
|
|
359 |
lossfile_ecg_train = log_dir + "/training_ecg_loss.txt" |
|
|
360 |
lossfile_ecg_val = log_dir + "/val_ecg_loss.txt" |
|
|
361 |
lossfile_RVp_train = log_dir + "/training_RVp_loss.txt" |
|
|
362 |
lossfile_RVp_val = log_dir + "/val_RVp_loss.txt" |
|
|
363 |
lossfile_size_train = log_dir + "/training_MIsize_loss.txt" |
|
|
364 |
lossfile_size_val = log_dir + "/val_MIsize_loss.txt" |
|
|
365 |
|
|
|
366 |
lossplot_detailed(lossfile_train, lossfile_val, lossfile_geometry_train, lossfile_geometry_val, lossfile_KL_train, lossfile_KL_val, lossfile_compactness_train, lossfile_compactness_val, lossfile_PC_train, lossfile_PC_val, lossfile_ecg_train, lossfile_ecg_val, lossfile_RVp_train, lossfile_RVp_val, lossfile_size_train, lossfile_size_val) |