|
a |
|
b/loss.py |
|
|
1 |
import torch |
|
|
2 |
import torch.nn as nn |
|
|
3 |
import numpy as np |
|
|
4 |
import torch.nn.functional as F |
|
|
5 |
from utils import visualize_PC_with_label |
|
|
6 |
import torch_geometric.transforms as T |
|
|
7 |
from sklearn.metrics import roc_auc_score |
|
|
8 |
import matplotlib.pyplot as plt |
|
|
9 |
from scipy.spatial import KDTree |
|
|
10 |
from utils import visualize_PC_with_label |
|
|
11 |
# from distance.chamfer_distance import ChamferDistanceFunction |
|
|
12 |
# from distance.emd_module import emdFunction |
|
|
13 |
|
|
|
14 |
def dtw_loss(ecg1, ecg2): # to do: plot the curve of x-y axis. |
|
|
15 |
""" |
|
|
16 |
计算两个ECG序列之间的Dynamic Time Warping(DTW)损失。 |
|
|
17 |
|
|
|
18 |
参数: |
|
|
19 |
- ecg1: 第一个ECG序列,形状为 (batch_size, seq_len1, num_features) |
|
|
20 |
- ecg2: 第二个ECG序列,形状为 (batch_size, seq_len2, num_features) |
|
|
21 |
|
|
|
22 |
返回: |
|
|
23 |
- dtw_loss: DTW损失,标量张量 |
|
|
24 |
""" |
|
|
25 |
batch_size, seq_len1, num_features = ecg1.size() |
|
|
26 |
_, seq_len2, _ = ecg2.size() |
|
|
27 |
|
|
|
28 |
# 计算两个ECG序列之间的距离矩阵 |
|
|
29 |
distance_matrix = torch.cdist(ecg1, ecg2) # 形状为 (batch_size, seq_len1, seq_len2) |
|
|
30 |
|
|
|
31 |
# 初始化动态规划表格 |
|
|
32 |
torch.autograd.set_detect_anomaly(True) |
|
|
33 |
dp = torch.zeros((batch_size, seq_len1, seq_len2)).to(ecg1.device) |
|
|
34 |
|
|
|
35 |
# 填充动态规划表格 |
|
|
36 |
dp[:, 0, 0] = distance_matrix[:, 0, 0] |
|
|
37 |
for i in range(1, seq_len1): |
|
|
38 |
dp[:, i, 0] = distance_matrix[:, i, 0] + dp[:, i-1, 0].clone() |
|
|
39 |
for j in range(1, seq_len2): |
|
|
40 |
dp[:, 0, j] = distance_matrix[:, 0, j] + dp[:, 0, j-1].clone() |
|
|
41 |
for i in range(1, seq_len1): |
|
|
42 |
for j in range(1, seq_len2): |
|
|
43 |
dp[:, i, j] = distance_matrix[:, i, j] + torch.min(torch.stack([ |
|
|
44 |
dp[:, i-1, j].clone(), |
|
|
45 |
dp[:, i, j-1].clone(), |
|
|
46 |
dp[:, i-1, j-1].clone() |
|
|
47 |
], dim=1), dim=1).values |
|
|
48 |
|
|
|
49 |
dtw_loss = torch.mean(dp[:, seq_len1-1, seq_len2-1] / (seq_len1 + seq_len2)) |
|
|
50 |
|
|
|
51 |
return dtw_loss |
|
|
52 |
|
|
|
53 |
def calculate_classify_loss(y_MI, gt_MI_label, mu, log_var): |
|
|
54 |
|
|
|
55 |
loss_func_CE = nn.CrossEntropyLoss() # weight=PC_weight |
|
|
56 |
loss_CE = loss_func_CE(y_MI, gt_MI_label) |
|
|
57 |
|
|
|
58 |
KL_loss = -0.5 * torch.sum(1 + log_var - torch.square(mu) - torch.exp(log_var)) |
|
|
59 |
|
|
|
60 |
return loss_CE, KL_loss |
|
|
61 |
|
|
|
62 |
def calculate_ECG_reconstruction_loss(y_signal, signal_input): |
|
|
63 |
|
|
|
64 |
y_signal = y_signal.squeeze(1) |
|
|
65 |
|
|
|
66 |
loss_signal = torch.mean(torch.square(y_signal-signal_input)) |
|
|
67 |
|
|
|
68 |
return loss_signal |
|
|
69 |
|
|
|
70 |
def calculate_reconstruction_loss(y_coarse, y_detail, coarse_gt, dense_gt, y_signal, signal_input): |
|
|
71 |
dense_gt = dense_gt.permute(0, 2, 1) |
|
|
72 |
y_signal = y_signal.squeeze(1) |
|
|
73 |
loss_coarse = calculate_chamfer_distance(y_coarse[:, :, 0:3], coarse_gt[:, :, 0:3]) + calculate_chamfer_distance(y_coarse[:, :, 3:], coarse_gt[:, :, 3:7]) |
|
|
74 |
loss_fine = calculate_chamfer_distance(y_detail[:, :, 0:3], dense_gt[:, :, 0:3]) + calculate_chamfer_distance(y_coarse[:, :, 3:], coarse_gt[:, :, 3:7]) |
|
|
75 |
# loss_coarse_emd = calculate_emd(y_coarse[:, :, 0:3], coarse_gt[:, :, 0:3]) + calculate_emd(y_coarse[:, :, 3:], coarse_gt[:, :, 3:]) |
|
|
76 |
|
|
|
77 |
# Per-class chamfer losses as reconstruction loss |
|
|
78 |
# loss_coarse = per_class_PCdist(y_coarse, coarse_gt, dist_type='chamfer') + per_class_PCdist(y_coarse, coarse_gt, dist_type = 'EDM') |
|
|
79 |
# loss_fine = per_class_PCdist(y_detail, dense_gt, dist_type='chamfer') |
|
|
80 |
|
|
|
81 |
loss_signal = torch.mean(torch.square(y_signal-signal_input)) |
|
|
82 |
loss_DTW = dtw_loss(y_signal, signal_input) # dynamic time warping |
|
|
83 |
|
|
|
84 |
# ECG_dist = torch.sqrt(torch.sum((y_signal - signal_input) ** 2)) |
|
|
85 |
# PC_dist = torch.sqrt(torch.sum((y_coarse[:, :, 3:7] - coarse_gt[:, :, 3:7]) ** 2)) + torch.sqrt(torch.sum((y_detail[:, :, 3:7] - dense_gt[:, :, 3:7]) ** 2)) |
|
|
86 |
|
|
|
87 |
return loss_coarse + 5*loss_fine, loss_signal + loss_DTW #0.5*(loss_coarse + loss_fine), loss_signal + loss_DTW # |
|
|
88 |
|
|
|
89 |
def evaluate_AHA_localization(predicted_center_id, predicted_covered_ids, gt_center_id, gt_covered_ids, center_distance): |
|
|
90 |
# Center ID Comparison |
|
|
91 |
center_id_match = predicted_center_id == gt_center_id |
|
|
92 |
center_id_score = 1 if center_id_match else 0 |
|
|
93 |
|
|
|
94 |
# Covered ID Comparison |
|
|
95 |
common_ids = set(predicted_covered_ids.tolist()) & set(gt_covered_ids.tolist()) |
|
|
96 |
intersection = len(common_ids) |
|
|
97 |
union = len(set(predicted_covered_ids.tolist()).union(set(gt_covered_ids.tolist()))) |
|
|
98 |
iou_score = intersection / union if union != 0 else 0 |
|
|
99 |
|
|
|
100 |
# Weighting |
|
|
101 |
center_id_weight = 0.5 |
|
|
102 |
center_distance_weight = 0.3 |
|
|
103 |
covered_id_weight = 0.2 |
|
|
104 |
|
|
|
105 |
# Overall Evaluation Metric |
|
|
106 |
evaluation_metric = (center_id_weight * center_id_score) + (covered_id_weight * iou_score) + (center_distance_weight*(1-center_distance)) |
|
|
107 |
|
|
|
108 |
return evaluation_metric |
|
|
109 |
|
|
|
110 |
def evaluate_pointcloud(predictions, target, partial_input, n_classes=3): |
|
|
111 |
# To address the issue of class imbalance and obtain a more comprehensive evaluation of model performance, |
|
|
112 |
# you may consider using other metrics such as precision, recall (or sensitivity), F1-score, and area under the |
|
|
113 |
# receiver operating characteristic (ROC) curve. These metrics provide a more nuanced evaluation of model performance, |
|
|
114 |
# taking into account both true positive and false positive/negative rates for each class separately. |
|
|
115 |
|
|
|
116 |
PC_xyz = partial_input[:, 0:3, :].permute(0, 2, 1).squeeze(0) |
|
|
117 |
AHA_id = partial_input[:, 7, :].squeeze(0) |
|
|
118 |
|
|
|
119 |
targets = F.one_hot(target, n_classes).permute(0, 2, 1) |
|
|
120 |
|
|
|
121 |
"""Function to evaluate point cloud predictions with multiple classes""" |
|
|
122 |
assert predictions.shape == targets.shape, "Input shapes must be the same" |
|
|
123 |
assert predictions.shape[0] == 1, "Batch size must be 1" |
|
|
124 |
|
|
|
125 |
# Convert predictions and targets to boolean values based on threshold |
|
|
126 |
# predictions = torch.ge(predictions, threshold).bool() |
|
|
127 |
predictions = one_hot_argmax(predictions).bool() |
|
|
128 |
targets = targets.bool().squeeze(0) |
|
|
129 |
|
|
|
130 |
MI_size_pre = torch.sum(predictions, dim=1).tolist() |
|
|
131 |
MI_size_gd = torch.sum(targets, dim=1).tolist() |
|
|
132 |
|
|
|
133 |
y_MI_center = torch.mean(PC_xyz[predictions[1]], dim=0) |
|
|
134 |
gt_MI_center = torch.mean(PC_xyz[targets[1]], dim=0) |
|
|
135 |
|
|
|
136 |
# calculate and compare the covered AHA IDs and the centered AHA ID of prediction and ground truth |
|
|
137 |
kdtree = KDTree(PC_xyz.cpu().detach().numpy()) |
|
|
138 |
distance_pre, index_pre = kdtree.query(y_MI_center.cpu().detach().numpy()) |
|
|
139 |
distance_gd, index_gd = kdtree.query(gt_MI_center.cpu().detach().numpy()) # to do: check whether its AHA=0 |
|
|
140 |
max_distance = torch.max(torch.sqrt(torch.sum((PC_xyz[AHA_id!=0.0][:, None] - PC_xyz[AHA_id!=0.0]) ** 2, dim=2))) |
|
|
141 |
if index_pre == 4096: |
|
|
142 |
center_distance = 1 |
|
|
143 |
AHA_center_pre = 0 |
|
|
144 |
print('no valid nearest neighbor was found') |
|
|
145 |
else: |
|
|
146 |
center_distance = (torch.sqrt(torch.sum((PC_xyz[index_pre] - PC_xyz[index_gd]) ** 2))/max_distance).cpu().detach().numpy() |
|
|
147 |
AHA_center_pre = AHA_id[index_pre] |
|
|
148 |
AHA_center_gd = AHA_id[index_gd] |
|
|
149 |
AHA_list_pre, AHA_list_gd = torch.unique(AHA_id[predictions[1]]), torch.unique(AHA_id[targets[1]]) |
|
|
150 |
AHA_loc_score = evaluate_AHA_localization(AHA_center_pre, AHA_list_pre, AHA_center_gd, AHA_list_gd, center_distance) |
|
|
151 |
|
|
|
152 |
# Calculate True Positives (TP), False Positives (FP), and False Negatives (FN) for each class |
|
|
153 |
tp = torch.sum(predictions & targets, dim=1).tolist() |
|
|
154 |
fp = torch.sum(predictions & ~targets, dim=1).tolist() |
|
|
155 |
fn = torch.sum(~predictions & targets, dim=1).tolist() |
|
|
156 |
tn = torch.sum(~predictions & ~targets, dim=1).tolist() |
|
|
157 |
|
|
|
158 |
# Calculate Accuracy, Precision, Recall (Sensitivity), Specificity, and F1-score for each class |
|
|
159 |
accuracy = sum(tp) / (sum(tp) + sum(fp) + sum(fn) + sum(tn)) |
|
|
160 |
precision = [tp[i] / (tp[i] + fp[i]) if (tp[i] + fp[i]) > 0 else 0.0 for i in range(n_classes)] |
|
|
161 |
recall = [tp[i] / (tp[i] + fn[i]) if (tp[i] + fn[i]) > 0 else 0.0 for i in range(n_classes)] |
|
|
162 |
specificity = [tn[i] / (fp[i] + tn[i]) if (fp[i] + tn[i]) > 0 else 0.0 for i in range(n_classes)] |
|
|
163 |
f1_score = [2 * (precision[i] * recall[i]) / (precision[i] + recall[i]) if (precision[i] + recall[i]) > 0 else 0.0 for i in range(n_classes)] |
|
|
164 |
roc_auc = [roc_auc_score(targets[i, :].detach().cpu().numpy(), predictions[i, :].detach().cpu().numpy()) for i in range(n_classes)] |
|
|
165 |
|
|
|
166 |
visualize_ROC = False |
|
|
167 |
if visualize_ROC: |
|
|
168 |
# Create a figure and axes |
|
|
169 |
fig, ax = plt.subplots() |
|
|
170 |
|
|
|
171 |
# Plot ROC curve for each class |
|
|
172 |
for i in range(len(roc_auc)): |
|
|
173 |
ax.plot([0, 1], [0, 1], 'k--') # Plot diagonal line |
|
|
174 |
ax.plot(1 - specificity[i], recall[i], label='Class {} (AUC = {:.2f})'.format(i, roc_auc[i])) |
|
|
175 |
|
|
|
176 |
# Set labels and title |
|
|
177 |
ax.set_xlabel('False Positive Rate (1 - Specificity)') |
|
|
178 |
ax.set_ylabel('True Positive Rate (Sensitivity / Recall)') |
|
|
179 |
ax.set_title('Receiver Operating Characteristic (ROC) Curve') |
|
|
180 |
|
|
|
181 |
# Set legend |
|
|
182 |
ax.legend() |
|
|
183 |
# Show the plot |
|
|
184 |
plt.show() |
|
|
185 |
|
|
|
186 |
# precision, recall (or sensitivity), F1-score, roc_auc |
|
|
187 |
return precision, recall, f1_score, roc_auc, MI_size_pre, MI_size_gd, center_distance, AHA_loc_score |
|
|
188 |
|
|
|
189 |
def calculate_chamfer_distance_old(x, y): |
|
|
190 |
""" |
|
|
191 |
Computes the Chamfer distance between two point clouds. |
|
|
192 |
|
|
|
193 |
Args: |
|
|
194 |
x: Tensor of shape (n_batch, n_point, n_label). |
|
|
195 |
y: Tensor of shape (n_batch, n_point, n_label). |
|
|
196 |
|
|
|
197 |
Returns: |
|
|
198 |
chamfer_distance: Tensor of shape (1,) |
|
|
199 |
""" |
|
|
200 |
x_expand = x.unsqueeze(2) # Shape: (n_batch, n_point, 1, n_label) |
|
|
201 |
y_expand = y.unsqueeze(1) # Shape: (n_batch, 1, n_point, n_label) |
|
|
202 |
diff = x_expand - y_expand |
|
|
203 |
dist = torch.sum(diff**2, dim=-1) # Shape: (n_batch, n_point, n_point) |
|
|
204 |
dist_x2y = torch.min(dist, dim=2).values # Shape: (n_batch, n_point) |
|
|
205 |
dist_y2x = torch.min(dist, dim=1).values # Shape: (n_batch, n_point) |
|
|
206 |
chamfer_distance = torch.mean(dist_x2y, dim=1) + torch.mean(dist_y2x, dim=1) # Shape: (n_batch,) |
|
|
207 |
return torch.mean(chamfer_distance) |
|
|
208 |
|
|
|
209 |
def calculate_chamfer_distance(x, y): |
|
|
210 |
dist_x_y = torch.cdist(x, y) |
|
|
211 |
min_dist_x_y, _ = torch.min(dist_x_y, dim=1) |
|
|
212 |
min_dist_y_x, _ = torch.min(dist_x_y, dim=0) |
|
|
213 |
chamfer_distance = torch.mean(min_dist_x_y) + torch.mean(min_dist_y_x) |
|
|
214 |
|
|
|
215 |
return torch.mean(chamfer_distance) |
|
|
216 |
|
|
|
217 |
def per_class_PCdist(pcd1, pcd2, dist_type='EDM', n_class=3): |
|
|
218 |
|
|
|
219 |
# Extract points from prediction and ground truth for each class |
|
|
220 |
LV_endo_pcd1, LV_epi_pcd1, RV_endo_pcd1 = torch.split(pcd1, n_class, dim=2) |
|
|
221 |
LV_endo_pcd2, LV_epi_pcd2, RV_endo_pcd2 = torch.split(pcd2, n_class, dim=2) |
|
|
222 |
|
|
|
223 |
# Note that ChamferDistance has O(n log n) complexity, while EMD has O(n2), which is too expensive to compute during training |
|
|
224 |
if dist_type == 'EDM': |
|
|
225 |
PCdist = calculate_emd |
|
|
226 |
else: |
|
|
227 |
PCdist = calculate_chamfer_distance |
|
|
228 |
LV_endo_loss = PCdist(LV_endo_pcd1, LV_endo_pcd2) |
|
|
229 |
LV_epi_loss = PCdist(LV_epi_pcd1, LV_epi_pcd2) |
|
|
230 |
RV_endo_loss = PCdist(RV_endo_pcd1, RV_endo_pcd2) |
|
|
231 |
combined_loss = (LV_endo_loss + LV_epi_loss + RV_endo_loss) / n_class |
|
|
232 |
|
|
|
233 |
return combined_loss |
|
|
234 |
|
|
|
235 |
def calculate_emd(x1, x2, eps=1e-8, norm=1): |
|
|
236 |
""" |
|
|
237 |
Calculates the Earth Mover's Distance (EMD) between two batches of point clouds. |
|
|
238 |
|
|
|
239 |
Args: |
|
|
240 |
- x1: A tensor of shape (batch_size, num_points, num_dims) representing the first batch of point clouds. |
|
|
241 |
- x2: A tensor of shape (batch_size, num_points, num_dims) representing the second batch of point clouds. |
|
|
242 |
- eps: A small constant added to the distance matrix to prevent numerical instability. |
|
|
243 |
- norm: The order of the norm used to calculate the distance matrix (default is L1 norm). |
|
|
244 |
|
|
|
245 |
Returns: |
|
|
246 |
- A tensor of shape (batch_size,) representing the EMD between each pair of point clouds in the batches. |
|
|
247 |
""" |
|
|
248 |
batch_size, num_points, num_dims = x1.size() |
|
|
249 |
|
|
|
250 |
# Calculate distance matrix between points in each batch |
|
|
251 |
dist_mat = torch.cdist(x1, x2, p=norm) |
|
|
252 |
|
|
|
253 |
# Initialize flow matrix with zeros |
|
|
254 |
flow = torch.zeros(batch_size, num_points, num_points, requires_grad=True).to(x1.get_device()) |
|
|
255 |
|
|
|
256 |
# Compute EMD using PyTorch's Sinkhorn algorithm |
|
|
257 |
for i in range(batch_size): |
|
|
258 |
flow[i] = F.sinkhorn_knopp(dist_mat[i], eps=eps) |
|
|
259 |
|
|
|
260 |
# Calculate total EMD for each pair of point clouds in the batches |
|
|
261 |
emd = torch.sum(flow * dist_mat, dim=(1, 2)) |
|
|
262 |
|
|
|
263 |
return emd |
|
|
264 |
|
|
|
265 |
def calculate_inference_loss(y_MI, gt_MI_label, mu, log_var, partial_input): |
|
|
266 |
PC_xyz = partial_input[:, 0:3, :] |
|
|
267 |
PC_tv = torch.where((partial_input[:, 7, :] == 0.0) & (partial_input[:, 6, :] > 0), 1, 0) |
|
|
268 |
|
|
|
269 |
# x_input = partial_input[0].cpu().detach().numpy() |
|
|
270 |
# x_input_lab = PC_tv[0].cpu().detach().numpy().astype(int) |
|
|
271 |
# visualize_PC_with_label(x_input[0:3, :].transpose(), x_input_lab, filename='RNmap_pre.pdf') |
|
|
272 |
|
|
|
273 |
class_weights = torch.FloatTensor([1, 10, 10]).to(y_MI.get_device()) |
|
|
274 |
loss_func_CE = nn.CrossEntropyLoss() # weight=class_weights |
|
|
275 |
|
|
|
276 |
y_MI_label = torch.argmax(y_MI, dim=1) |
|
|
277 |
loss_compactness, loss_MI_size, loss_MI_RVpenalty = calculate_MI_distribution_loss(y_MI_label, gt_MI_label, PC_xyz.permute(0, 2, 1), PC_tv) |
|
|
278 |
loss_CE = loss_func_CE(y_MI, gt_MI_label) |
|
|
279 |
Dice = calculate_Dice(y_MI, gt_MI_label, num_classes=3) |
|
|
280 |
loss_Dice = torch.sum((1.0-Dice) * class_weights) |
|
|
281 |
|
|
|
282 |
KL_loss = -0.5 * torch.sum(1 + log_var - torch.square(mu) - torch.exp(log_var)) |
|
|
283 |
|
|
|
284 |
return loss_CE + 0.1*loss_Dice, loss_compactness, loss_MI_RVpenalty, loss_MI_size, KL_loss |
|
|
285 |
|
|
|
286 |
def calculate_MI_distribution_loss(y_MI_label, gt_MI_label, PC_xyz, PC_tv): |
|
|
287 |
""" |
|
|
288 |
计算点云数据的compactness |
|
|
289 |
|
|
|
290 |
Args: |
|
|
291 |
point_cloud: 点云数据,shape为(B, N, 3), only work when B=1 |
|
|
292 |
|
|
|
293 |
Returns: |
|
|
294 |
compactness: 点云数据的compactness |
|
|
295 |
""" |
|
|
296 |
y_MI_label_mask = torch.where((y_MI_label % 3) == 0, 0, 1).bool() |
|
|
297 |
gt_MI_label_mask = torch.where((gt_MI_label % 3) == 0, 0, 1).bool() |
|
|
298 |
|
|
|
299 |
compactness_sum = torch.tensor(0.0, requires_grad=True).to(y_MI_label.get_device()) |
|
|
300 |
MI_size_div_sum = torch.tensor(0.0, requires_grad=True).to(y_MI_label.get_device()) |
|
|
301 |
MI_RVpenalty_sum = torch.tensor(0.0, requires_grad=True).to(y_MI_label.get_device()) |
|
|
302 |
|
|
|
303 |
num_iter = 0 |
|
|
304 |
for i_batch in range(PC_xyz.shape[0]): |
|
|
305 |
y_PC_xyz_masked = PC_xyz[i_batch][y_MI_label_mask[i_batch]] |
|
|
306 |
gt_PC_xyz_masked = PC_xyz[i_batch][gt_MI_label_mask[i_batch]] |
|
|
307 |
|
|
|
308 |
if gt_PC_xyz_masked.shape[0]==0 or y_PC_xyz_masked.shape[0]==0: |
|
|
309 |
continue |
|
|
310 |
|
|
|
311 |
MI_size_div = abs(gt_PC_xyz_masked.size(0) - y_PC_xyz_masked.size(0))/gt_PC_xyz_masked.size(0) |
|
|
312 |
MI_size_div_sum = MI_size_div_sum.add(torch.tensor(MI_size_div, dtype=torch.float32).to(y_MI_label.get_device())) |
|
|
313 |
|
|
|
314 |
MI_RVpenalty = torch.sum(PC_tv[i_batch]*y_MI_label[i_batch])/y_PC_xyz_masked.shape[0] |
|
|
315 |
MI_RVpenalty_sum = MI_RVpenalty_sum.add(MI_RVpenalty) |
|
|
316 |
|
|
|
317 |
visual_check = False |
|
|
318 |
if visual_check: |
|
|
319 |
y_predict = y_MI_label_mask[i_batch].cpu().detach().numpy() |
|
|
320 |
x_input = PC_xyz[i_batch].cpu().detach().numpy() |
|
|
321 |
visualize_PC_with_label(x_input[y_predict], y_predict[y_predict], filename='RNmap_gd.jpg') |
|
|
322 |
visualize_PC_with_label(x_input, y_predict, filename='RNmap_pre.jpg') |
|
|
323 |
|
|
|
324 |
y_MI_center = torch.mean(y_PC_xyz_masked, dim=0).unsqueeze(0) |
|
|
325 |
gt_MI_center = torch.mean(gt_PC_xyz_masked, dim=0).unsqueeze(0) |
|
|
326 |
y_dist_sq = torch.sum((y_PC_xyz_masked - y_MI_center) ** 2, dim=1) |
|
|
327 |
gt_dist_sq = torch.sum((y_PC_xyz_masked - gt_MI_center) ** 2, dim=1) |
|
|
328 |
|
|
|
329 |
# max_distance = torch.max(torch.sqrt(torch.sum((PC_xyz[AHA_id>0][:, None] - PC_xyz[AHA_id>0]) ** 2, dim=2))) |
|
|
330 |
max_distance = torch.max(torch.sqrt(torch.sum((gt_PC_xyz_masked - gt_MI_center) ** 2, dim=1))) |
|
|
331 |
y_compactness = torch.mean(torch.sqrt(y_dist_sq))/max_distance |
|
|
332 |
gt_compactness = torch.mean(torch.sqrt(gt_dist_sq))/max_distance |
|
|
333 |
|
|
|
334 |
compactness_sum = compactness_sum.add(y_compactness + gt_compactness) |
|
|
335 |
num_iter += (num_iter + 1) |
|
|
336 |
if num_iter != 0: |
|
|
337 |
return compactness_sum/num_iter, MI_size_div_sum/num_iter, MI_RVpenalty_sum/num_iter |
|
|
338 |
else: |
|
|
339 |
return compactness_sum, MI_size_div_sum, MI_RVpenalty_sum |
|
|
340 |
|
|
|
341 |
def calculate_Dice(inputs, target, num_classes): |
|
|
342 |
|
|
|
343 |
target_onehot = F.one_hot(target, num_classes).permute(0, 2, 1) |
|
|
344 |
|
|
|
345 |
eps = 1e-6 |
|
|
346 |
intersection = torch.sum(inputs * target_onehot, dim=[0, 2]) |
|
|
347 |
cardinality = torch.sum(inputs + target_onehot, dim=[0, 2]) |
|
|
348 |
Dice = (2.0 * intersection + eps) / (cardinality + eps) |
|
|
349 |
|
|
|
350 |
return Dice |
|
|
351 |
|
|
|
352 |
def one_hot_argmax(input_tensor): |
|
|
353 |
""" |
|
|
354 |
This function takes a PyTorch tensor as input and returns a tuple of two tensors: |
|
|
355 |
- One-hot tensor: a binary tensor with the same shape as the input tensor, where the value 1 |
|
|
356 |
is placed in the position of the maximum element of the input tensor and 0 elsewhere. |
|
|
357 |
- Argmax tensor: a tensor with the same shape as the input tensor, where the value is the index |
|
|
358 |
of the maximum element of the input tensor. |
|
|
359 |
""" |
|
|
360 |
input_tensor = input_tensor.permute(0, 2, 1).squeeze(0) |
|
|
361 |
max_indices = torch.argmax(input_tensor, dim=1) |
|
|
362 |
one_hot_tensor = torch.zeros_like(input_tensor) |
|
|
363 |
one_hot_tensor.scatter_(1, max_indices.view(-1, 1), 1) |
|
|
364 |
|
|
|
365 |
return one_hot_tensor.permute(1, 0) |
|
|
366 |
|
|
|
367 |
if __name__ == '__main__': |
|
|
368 |
|
|
|
369 |
pcs1 = torch.rand(10, 1024, 4) |
|
|
370 |
pcs2 = torch.rand(10, 1024, 4) |
|
|
371 |
|
|
|
372 |
|
|
|
373 |
|