Diff of /loss.py [000000] .. [390c2f]

import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
from utils import visualize_PC_with_label
import torch_geometric.transforms as T
from sklearn.metrics import roc_auc_score
import matplotlib.pyplot as plt
from scipy.spatial import KDTree
# from distance.chamfer_distance import ChamferDistanceFunction
# from distance.emd_module import emdFunction

def dtw_loss(ecg1, ecg2): # to do: plot the curve of x-y axis.
    """
    Compute the Dynamic Time Warping (DTW) loss between two ECG sequences.

    Args:
    - ecg1: first ECG sequence, shape (batch_size, seq_len1, num_features)
    - ecg2: second ECG sequence, shape (batch_size, seq_len2, num_features)

    Returns:
    - dtw_loss: the DTW loss, a scalar tensor
    """
    batch_size, seq_len1, num_features = ecg1.size()
    _, seq_len2, _ = ecg2.size()

    # Pairwise distance matrix between the two ECG sequences
    distance_matrix = torch.cdist(ecg1, ecg2)  # shape (batch_size, seq_len1, seq_len2)

    # Dynamic-programming table of cumulative alignment costs
    dp = torch.zeros((batch_size, seq_len1, seq_len2)).to(ecg1.device)

    # Fill the DP table; the .clone() calls avoid in-place autograd conflicts
    dp[:, 0, 0] = distance_matrix[:, 0, 0]
    for i in range(1, seq_len1):
        dp[:, i, 0] = distance_matrix[:, i, 0] + dp[:, i-1, 0].clone()
    for j in range(1, seq_len2):
        dp[:, 0, j] = distance_matrix[:, 0, j] + dp[:, 0, j-1].clone()
    for i in range(1, seq_len1):
        for j in range(1, seq_len2):
            dp[:, i, j] = distance_matrix[:, i, j] + torch.min(torch.stack([
                dp[:, i-1, j].clone(),
                dp[:, i, j-1].clone(),
                dp[:, i-1, j-1].clone()
            ], dim=1), dim=1).values

    # Normalise the terminal cost by an upper bound on the warping path length
    dtw_loss = torch.mean(dp[:, seq_len1-1, seq_len2-1] / (seq_len1 + seq_len2))

    return dtw_loss
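
# Example (sketch): unlike a plain pointwise MSE term, DTW tolerates sequences
# of different lengths —
#   >>> a, b = torch.rand(2, 50, 1), torch.rand(2, 60, 1)
#   >>> dtw_loss(a, b)  # scalar tensor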

def calculate_classify_loss(y_MI, gt_MI_label, mu, log_var):

    loss_func_CE = nn.CrossEntropyLoss() # weight=PC_weight
    loss_CE = loss_func_CE(y_MI, gt_MI_label)

    KL_loss = -0.5 * torch.sum(1 + log_var - torch.square(mu) - torch.exp(log_var))

    return loss_CE, KL_loss
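
# The KL term above is the closed-form divergence between the approximate
# posterior N(mu, diag(sigma^2)) and a standard normal prior, with
# log_var = log(sigma^2), so it needs no sampling to evaluate:
#     KL = -1/2 * sum(1 + log(sigma^2) - mu^2 - sigma^2)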

def calculate_ECG_reconstruction_loss(y_signal, signal_input):

    y_signal = y_signal.squeeze(1)

    # Mean-squared error over the reconstructed ECG signal
    loss_signal = torch.mean(torch.square(y_signal - signal_input))

    return loss_signal

def calculate_reconstruction_loss(y_coarse, y_detail, coarse_gt, dense_gt, y_signal, signal_input):
    dense_gt = dense_gt.permute(0, 2, 1)
    y_signal = y_signal.squeeze(1)
    # Chamfer distance on the xyz coordinates plus the label channels (3:7)
    loss_coarse = calculate_chamfer_distance(y_coarse[:, :, 0:3], coarse_gt[:, :, 0:3]) + calculate_chamfer_distance(y_coarse[:, :, 3:], coarse_gt[:, :, 3:7])
    loss_fine = calculate_chamfer_distance(y_detail[:, :, 0:3], dense_gt[:, :, 0:3]) + calculate_chamfer_distance(y_detail[:, :, 3:], dense_gt[:, :, 3:7])
    # loss_coarse_emd = calculate_emd(y_coarse[:, :, 0:3], coarse_gt[:, :, 0:3]) + calculate_emd(y_coarse[:, :, 3:], coarse_gt[:, :, 3:])

    # Per-class chamfer losses as reconstruction loss
    # loss_coarse = per_class_PCdist(y_coarse, coarse_gt, dist_type='chamfer') + per_class_PCdist(y_coarse, coarse_gt, dist_type='EMD')
    # loss_fine = per_class_PCdist(y_detail, dense_gt, dist_type='chamfer')

    loss_signal = torch.mean(torch.square(y_signal - signal_input))
    loss_DTW = dtw_loss(y_signal, signal_input) # dynamic time warping

    # ECG_dist = torch.sqrt(torch.sum((y_signal - signal_input) ** 2))
    # PC_dist = torch.sqrt(torch.sum((y_coarse[:, :, 3:7] - coarse_gt[:, :, 3:7]) ** 2)) + torch.sqrt(torch.sum((y_detail[:, :, 3:7] - dense_gt[:, :, 3:7]) ** 2))

    return loss_coarse + 5*loss_fine, loss_signal + loss_DTW  # alternative weighting: 0.5*(loss_coarse + loss_fine), loss_signal + loss_DTW

def evaluate_AHA_localization(predicted_center_id, predicted_covered_ids, gt_center_id, gt_covered_ids, center_distance):
    # Center ID Comparison
    center_id_match = predicted_center_id == gt_center_id
    center_id_score = 1 if center_id_match else 0

    # Covered ID Comparison: intersection-over-union of the covered AHA segments
    common_ids = set(predicted_covered_ids.tolist()) & set(gt_covered_ids.tolist())
    intersection = len(common_ids)
    union = len(set(predicted_covered_ids.tolist()).union(set(gt_covered_ids.tolist())))
    iou_score = intersection / union if union != 0 else 0

    # Weighting
    center_id_weight = 0.5
    center_distance_weight = 0.3
    covered_id_weight = 0.2

    # Overall Evaluation Metric (higher is better; center_distance is normalised to [0, 1])
    evaluation_metric = (center_id_weight * center_id_score) + (covered_id_weight * iou_score) + (center_distance_weight*(1-center_distance))

    return evaluation_metric
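
# Worked example: a prediction that hits the correct centre segment (score 1),
# covers 2 of 3 ground-truth segments with one extra prediction (IoU = 2/4),
# and has a normalised centre distance of 0.1 scores
#     0.5*1 + 0.2*0.5 + 0.3*(1 - 0.1) = 0.87.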

def evaluate_pointcloud(predictions, target, partial_input, n_classes=3):
    """Evaluate point cloud predictions with multiple classes."""
    # To address class imbalance and obtain a more comprehensive picture of model
    # performance, this uses metrics such as precision, recall (sensitivity),
    # F1-score, and area under the receiver operating characteristic (ROC) curve,
    # which account for both true positive and false positive/negative rates
    # for each class separately.

    PC_xyz = partial_input[:, 0:3, :].permute(0, 2, 1).squeeze(0)
    AHA_id = partial_input[:, 7, :].squeeze(0)

    targets = F.one_hot(target, n_classes).permute(0, 2, 1)

    assert predictions.shape == targets.shape, "Input shapes must be the same"
    assert predictions.shape[0] == 1, "Batch size must be 1"

    # Convert predictions and targets to boolean values
    # predictions = torch.ge(predictions, threshold).bool()
    predictions = one_hot_argmax(predictions).bool()
    targets = targets.bool().squeeze(0)

    MI_size_pre = torch.sum(predictions, dim=1).tolist()
    MI_size_gd = torch.sum(targets, dim=1).tolist()

    y_MI_center = torch.mean(PC_xyz[predictions[1]], dim=0)
    gt_MI_center = torch.mean(PC_xyz[targets[1]], dim=0)

    # Calculate and compare the covered AHA IDs and the central AHA ID of prediction and ground truth
    kdtree = KDTree(PC_xyz.cpu().detach().numpy())
    distance_pre, index_pre = kdtree.query(y_MI_center.cpu().detach().numpy())
    distance_gd, index_gd = kdtree.query(gt_MI_center.cpu().detach().numpy()) # to do: check whether its AHA=0
    max_distance = torch.max(torch.sqrt(torch.sum((PC_xyz[AHA_id!=0.0][:, None] - PC_xyz[AHA_id!=0.0]) ** 2, dim=2)))
    if index_pre == PC_xyz.shape[0]:  # scipy's KDTree.query returns n when no valid neighbour exists
        center_distance = 1
        AHA_center_pre = 0
        print('no valid nearest neighbor was found')
    else:
        center_distance = (torch.sqrt(torch.sum((PC_xyz[index_pre] - PC_xyz[index_gd]) ** 2))/max_distance).cpu().detach().numpy()
        AHA_center_pre = AHA_id[index_pre]
    AHA_center_gd = AHA_id[index_gd]
    AHA_list_pre, AHA_list_gd = torch.unique(AHA_id[predictions[1]]), torch.unique(AHA_id[targets[1]])
    AHA_loc_score = evaluate_AHA_localization(AHA_center_pre, AHA_list_pre, AHA_center_gd, AHA_list_gd, center_distance)

    # Calculate True Positives (TP), False Positives (FP), False Negatives (FN) and True Negatives (TN) per class
    tp = torch.sum(predictions & targets, dim=1).tolist()
    fp = torch.sum(predictions & ~targets, dim=1).tolist()
    fn = torch.sum(~predictions & targets, dim=1).tolist()
    tn = torch.sum(~predictions & ~targets, dim=1).tolist()

    # Calculate Accuracy, Precision, Recall (Sensitivity), Specificity, and F1-score for each class
    accuracy = sum(tp) / (sum(tp) + sum(fp) + sum(fn) + sum(tn))
    precision = [tp[i] / (tp[i] + fp[i]) if (tp[i] + fp[i]) > 0 else 0.0 for i in range(n_classes)]
    recall = [tp[i] / (tp[i] + fn[i]) if (tp[i] + fn[i]) > 0 else 0.0 for i in range(n_classes)]
    specificity = [tn[i] / (fp[i] + tn[i]) if (fp[i] + tn[i]) > 0 else 0.0 for i in range(n_classes)]
    f1_score = [2 * (precision[i] * recall[i]) / (precision[i] + recall[i]) if (precision[i] + recall[i]) > 0 else 0.0 for i in range(n_classes)]
    roc_auc = [roc_auc_score(targets[i, :].detach().cpu().numpy(), predictions[i, :].detach().cpu().numpy()) for i in range(n_classes)]

    visualize_ROC = False
    if visualize_ROC:
        # Create a figure and axes
        fig, ax = plt.subplots()

        # Plot the ROC operating point for each class (hard predictions give a
        # single point rather than a full curve)
        for i in range(len(roc_auc)):
            ax.plot([0, 1], [0, 1], 'k--')  # diagonal chance line
            ax.plot(1 - specificity[i], recall[i], label='Class {} (AUC = {:.2f})'.format(i, roc_auc[i]))

        # Set labels and title
        ax.set_xlabel('False Positive Rate (1 - Specificity)')
        ax.set_ylabel('True Positive Rate (Sensitivity / Recall)')
        ax.set_title('Receiver Operating Characteristic (ROC) Curve')

        # Set legend and show the plot
        ax.legend()
        plt.show()

    # precision, recall (or sensitivity), F1-score, roc_auc
    return precision, recall, f1_score, roc_auc, MI_size_pre, MI_size_gd, center_distance, AHA_loc_score

def calculate_chamfer_distance_old(x, y):
    """
    Computes the Chamfer distance between two point clouds.

    Args:
        x: Tensor of shape (n_batch, n_point, n_label).
        y: Tensor of shape (n_batch, n_point, n_label).

    Returns:
        chamfer_distance: Tensor of shape (1,).
    """
    x_expand = x.unsqueeze(2)  # Shape: (n_batch, n_point, 1, n_label)
    y_expand = y.unsqueeze(1)  # Shape: (n_batch, 1, n_point, n_label)
    diff = x_expand - y_expand
    dist = torch.sum(diff**2, dim=-1)  # Shape: (n_batch, n_point, n_point)
    dist_x2y = torch.min(dist, dim=2).values  # Shape: (n_batch, n_point)
    dist_y2x = torch.min(dist, dim=1).values  # Shape: (n_batch, n_point)
    chamfer_distance = torch.mean(dist_x2y, dim=1) + torch.mean(dist_y2x, dim=1)  # Shape: (n_batch,)
    return torch.mean(chamfer_distance)

def calculate_chamfer_distance(x, y):
    # torch.cdist on (n_batch, n_point, n_dim) inputs returns a distance matrix
    # of shape (n_batch, n_point, n_point), so the nearest-neighbour minima must
    # run over the point dimensions (2 and 1); reducing over dims 1 and 0 would
    # mix up the batch dimension.
    dist_x_y = torch.cdist(x, y)
    min_dist_x_y, _ = torch.min(dist_x_y, dim=2)
    min_dist_y_x, _ = torch.min(dist_x_y, dim=1)
    chamfer_distance = torch.mean(min_dist_x_y) + torch.mean(min_dist_y_x)

    return chamfer_distance
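
# Note: unlike calculate_chamfer_distance_old, which accumulates *squared*
# distances, torch.cdist returns unsquared Euclidean distances, so the two
# functions are not numerically identical. A quick shape sanity check:
#   >>> a, b = torch.rand(2, 128, 3), torch.rand(2, 256, 3)
#   >>> calculate_chamfer_distance(a, b)  # scalar tensor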

def per_class_PCdist(pcd1, pcd2, dist_type='EMD', n_class=3):

    # Extract points from prediction and ground truth for each class
    LV_endo_pcd1, LV_epi_pcd1, RV_endo_pcd1 = torch.split(pcd1, n_class, dim=2)
    LV_endo_pcd2, LV_epi_pcd2, RV_endo_pcd2 = torch.split(pcd2, n_class, dim=2)

    # Note that the Chamfer distance has O(n log n) complexity, while EMD is O(n^2),
    # which is too expensive to compute during training
    if dist_type == 'EMD':
        PCdist = calculate_emd
    else:
        PCdist = calculate_chamfer_distance
    LV_endo_loss = PCdist(LV_endo_pcd1, LV_endo_pcd2)
    LV_epi_loss = PCdist(LV_epi_pcd1, LV_epi_pcd2)
    RV_endo_loss = PCdist(RV_endo_pcd1, RV_endo_pcd2)
    combined_loss = (LV_endo_loss + LV_epi_loss + RV_endo_loss) / n_class

    return combined_loss
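
# Note: torch.split with an integer size yields chunks of size n_class along
# dim 2, so the three-way unpacking above assumes the channel count of the
# inputs produces exactly three chunks.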

def calculate_emd(x1, x2, reg=0.1, n_iters=50, norm=1):
    """
    Approximates the Earth Mover's Distance (EMD) between two batches of point
    clouds via entropy-regularised optimal transport (Sinkhorn iterations).

    Args:
    - x1: A tensor of shape (batch_size, num_points, num_dims) representing the first batch of point clouds.
    - x2: A tensor of shape (batch_size, num_points, num_dims) representing the second batch of point clouds.
    - reg: The entropic regularisation strength of the Sinkhorn iterations.
    - n_iters: The number of Sinkhorn iterations.
    - norm: The order of the norm used to calculate the distance matrix (default is the L1 norm).

    Returns:
    - A tensor of shape (batch_size,) representing the EMD between each pair of point clouds in the batches.
    """
    batch_size, num_points, num_dims = x1.size()

    # Distance matrix between the points of each batch element
    dist_mat = torch.cdist(x1, x2, p=norm)

    # PyTorch ships no built-in Sinkhorn solver, so run the iterations
    # explicitly with uniform marginals on both point sets
    mu = torch.full((batch_size, num_points), 1.0 / num_points, device=x1.device)
    nu = torch.full((batch_size, num_points), 1.0 / num_points, device=x1.device)
    K = torch.exp(-dist_mat / reg)  # Gibbs kernel
    u = torch.ones_like(mu)
    for _ in range(n_iters):
        v = nu / (torch.bmm(K.transpose(1, 2), u.unsqueeze(2)).squeeze(2) + 1e-8)
        u = mu / (torch.bmm(K, v.unsqueeze(2)).squeeze(2) + 1e-8)

    # Transport plan and the resulting (regularised) EMD per batch element
    flow = u.unsqueeze(2) * K * v.unsqueeze(1)
    emd = torch.sum(flow * dist_mat, dim=(1, 2))

    return emd
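
# Example (sketch):
#   >>> a, b = torch.rand(4, 256, 3), torch.rand(4, 256, 3)
#   >>> calculate_emd(a, b).shape   # torch.Size([4])
# Smaller `reg` tightens the approximation to the exact EMD but needs more
# iterations and can underflow the Gibbs kernel in float32.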

def calculate_inference_loss(y_MI, gt_MI_label, mu, log_var, partial_input):
    PC_xyz = partial_input[:, 0:3, :]
    # Mask of points with AHA id 0 but a positive channel-6 value (used for the RV penalty)
    PC_tv = torch.where((partial_input[:, 7, :] == 0.0) & (partial_input[:, 6, :] > 0), 1, 0)

    # x_input = partial_input[0].cpu().detach().numpy()
    # x_input_lab = PC_tv[0].cpu().detach().numpy().astype(int)
    # visualize_PC_with_label(x_input[0:3, :].transpose(), x_input_lab, filename='RNmap_pre.pdf')

    class_weights = torch.FloatTensor([1, 10, 10]).to(y_MI.get_device())
    loss_func_CE = nn.CrossEntropyLoss() # weight=class_weights

    y_MI_label = torch.argmax(y_MI, dim=1)
    loss_compactness, loss_MI_size, loss_MI_RVpenalty = calculate_MI_distribution_loss(y_MI_label, gt_MI_label, PC_xyz.permute(0, 2, 1), PC_tv)
    loss_CE = loss_func_CE(y_MI, gt_MI_label)
    Dice = calculate_Dice(y_MI, gt_MI_label, num_classes=3)
    loss_Dice = torch.sum((1.0 - Dice) * class_weights)

    KL_loss = -0.5 * torch.sum(1 + log_var - torch.square(mu) - torch.exp(log_var))

    return loss_CE + 0.1*loss_Dice, loss_compactness, loss_MI_RVpenalty, loss_MI_size, KL_loss

def calculate_MI_distribution_loss(y_MI_label, gt_MI_label, PC_xyz, PC_tv):
    """
    Compute the compactness of the predicted MI point distribution.

    Args:
        PC_xyz: point cloud coordinates, shape (B, N, 3); only works when B = 1

    Returns:
        compactness: the compactness of the point cloud data
    """
    y_MI_label_mask = torch.where((y_MI_label % 3) == 0, 0, 1).bool()
    gt_MI_label_mask = torch.where((gt_MI_label % 3) == 0, 0, 1).bool()

    compactness_sum = torch.tensor(0.0, requires_grad=True).to(y_MI_label.get_device())
    MI_size_div_sum = torch.tensor(0.0, requires_grad=True).to(y_MI_label.get_device())
    MI_RVpenalty_sum = torch.tensor(0.0, requires_grad=True).to(y_MI_label.get_device())

    num_iter = 0
    for i_batch in range(PC_xyz.shape[0]):
        y_PC_xyz_masked = PC_xyz[i_batch][y_MI_label_mask[i_batch]]
        gt_PC_xyz_masked = PC_xyz[i_batch][gt_MI_label_mask[i_batch]]

        if gt_PC_xyz_masked.shape[0] == 0 or y_PC_xyz_masked.shape[0] == 0:
            continue

        # Relative deviation of the predicted MI size from the ground-truth size
        MI_size_div = abs(gt_PC_xyz_masked.size(0) - y_PC_xyz_masked.size(0)) / gt_PC_xyz_masked.size(0)
        MI_size_div_sum = MI_size_div_sum.add(torch.tensor(MI_size_div, dtype=torch.float32).to(y_MI_label.get_device()))

        # Penalise MI predictions that fall on the PC_tv mask
        MI_RVpenalty = torch.sum(PC_tv[i_batch]*y_MI_label[i_batch]) / y_PC_xyz_masked.shape[0]
        MI_RVpenalty_sum = MI_RVpenalty_sum.add(MI_RVpenalty)

        visual_check = False
        if visual_check:
            y_predict = y_MI_label_mask[i_batch].cpu().detach().numpy()
            x_input = PC_xyz[i_batch].cpu().detach().numpy()
            visualize_PC_with_label(x_input[y_predict], y_predict[y_predict], filename='RNmap_gd.jpg')
            visualize_PC_with_label(x_input, y_predict, filename='RNmap_pre.jpg')

        # Compactness: mean distance of the masked points to a centroid,
        # normalised by the maximum ground-truth spread
        y_MI_center = torch.mean(y_PC_xyz_masked, dim=0).unsqueeze(0)
        gt_MI_center = torch.mean(gt_PC_xyz_masked, dim=0).unsqueeze(0)
        y_dist_sq = torch.sum((y_PC_xyz_masked - y_MI_center) ** 2, dim=1)
        gt_dist_sq = torch.sum((y_PC_xyz_masked - gt_MI_center) ** 2, dim=1)

        # max_distance = torch.max(torch.sqrt(torch.sum((PC_xyz[AHA_id>0][:, None] - PC_xyz[AHA_id>0]) ** 2, dim=2)))
        max_distance = torch.max(torch.sqrt(torch.sum((gt_PC_xyz_masked - gt_MI_center) ** 2, dim=1)))
        y_compactness = torch.mean(torch.sqrt(y_dist_sq)) / max_distance
        gt_compactness = torch.mean(torch.sqrt(gt_dist_sq)) / max_distance

        compactness_sum = compactness_sum.add(y_compactness + gt_compactness)
        num_iter += 1
    if num_iter != 0:
        return compactness_sum/num_iter, MI_size_div_sum/num_iter, MI_RVpenalty_sum/num_iter
    else:
        return compactness_sum, MI_size_div_sum, MI_RVpenalty_sum

def calculate_Dice(inputs, target, num_classes):

    target_onehot = F.one_hot(target, num_classes).permute(0, 2, 1)

    # Soft Dice per class: 2|X ∩ Y| / (|X| + |Y|), summed over batch and points
    eps = 1e-6
    intersection = torch.sum(inputs * target_onehot, dim=[0, 2])
    cardinality = torch.sum(inputs + target_onehot, dim=[0, 2])
    Dice = (2.0 * intersection + eps) / (cardinality + eps)

    return Dice
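
# Example (sketch): class probabilities of shape (B, C, N) against integer
# labels of shape (B, N) —
#   >>> probs = torch.softmax(torch.rand(1, 3, 100), dim=1)
#   >>> labels = torch.randint(0, 3, (1, 100))
#   >>> calculate_Dice(probs, labels, num_classes=3)  # 3 per-class scores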

def one_hot_argmax(input_tensor):
    """
    Takes a tensor of per-point class scores of shape (1, n_classes, n_points)
    and returns a one-hot tensor of shape (n_classes, n_points): a binary
    tensor with the value 1 in the position of the maximum element along the
    class dimension and 0 elsewhere.
    """
    input_tensor = input_tensor.permute(0, 2, 1).squeeze(0)
    max_indices = torch.argmax(input_tensor, dim=1)
    one_hot_tensor = torch.zeros_like(input_tensor)
    one_hot_tensor.scatter_(1, max_indices.view(-1, 1), 1)

    return one_hot_tensor.permute(1, 0)
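
# Example (sketch): scores of shape (1, 3 classes, 2 points) —
#   >>> scores = torch.tensor([[[0.1, 0.7], [0.8, 0.2], [0.1, 0.1]]])
#   >>> one_hot_argmax(scores)
#   tensor([[0., 1.],
#           [1., 0.],
#           [0., 0.]])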

if __name__ == '__main__':

    pcs1 = torch.rand(10, 1024, 4)
    pcs2 = torch.rand(10, 1024, 4)
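    # Minimal smoke test (a sketch: the random clouds above stand in for real
    # cardiac point clouds, so the printed values are shape/sanity checks only)
    print('chamfer     :', calculate_chamfer_distance(pcs1, pcs2).item())
    print('chamfer_old :', calculate_chamfer_distance_old(pcs1[:, :256], pcs2[:, :256]).item())
    print('dtw         :', dtw_loss(pcs1[:, :64], pcs2[:, :64]).item())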