|
a |
|
b/utils/metrics.py |
|
|
1 |
import torch |
|
|
2 |
import numpy as np |
|
|
3 |
from hausdorff import hausdorff_distance |
|
|
4 |
from medpy.metric.binary import hd, dc |
|
|
5 |
|
|
|
6 |
def dice(pred, target):
    """Return the smoothed Dice coefficient between two tensors.

    Both tensors are flattened and multiplied element-wise; the score is
    ``(2 * sum(pred * target) + eps) / (sum(pred) + sum(target) + eps)``
    with ``eps = 1e-5``. The smoothing term means two all-zero inputs
    score 1.0 instead of dividing by zero.

    Args:
        pred: Tensor of predictions (presumably a binary / one-hot mask --
            confirm against callers).
        target: Tensor with the same number of elements as ``pred``.

    Returns:
        The Dice score as a plain Python number.
    """
    eps = 0.00001

    # Collapse everything into a single row so the sums cover all elements.
    flat_pred = pred.contiguous().view(1, -1)
    flat_target = target.contiguous().view(1, -1)

    overlap = (flat_pred * flat_target).sum().item()
    total = flat_pred.sum().item() + flat_target.sum().item()

    return (2 * overlap + eps) / (total + eps)
|
|
20 |
|
|
|
21 |
def dice3D(img_gt, img_pred, voxel_size):
    """
    Compute one Dice score per class between two segmentation maps.

    Parameters
    ----------
    img_gt: np.array
    Array of the ground truth segmentation map.

    img_pred: np.array
    Array of the predicted segmentation map.

    voxel_size: list, tuple or np.array
    Kept in the signature for callers; the current implementation does not
    read it (the volume computation it was meant for is not performed).

    Return
    ------
    A list of three Dice values, one per label, in the label order
    [3, 1, 2].

    Raises
    ------
    ValueError
    If the two arrays do not have the same number of dimensions.
    """
    if img_gt.ndim != img_pred.ndim:
        raise ValueError(
            "The arrays 'img_gt' and 'img_pred' should have the "
            "same dimension, {} against {}".format(img_gt.ndim, img_pred.ndim)
        )

    def _binary_mask(volume, label):
        # Work on a copy so the caller's array is never modified.
        mask = np.copy(volume)
        mask[mask != label] = 0
        # Whatever the label value was, squash it to 1 to get a 0/1 mask.
        return np.clip(mask, 0, 1)

    return [
        dc(_binary_mask(img_gt, label), _binary_mask(img_pred, label))
        for label in (3, 1, 2)
    ]
|
|
73 |
|
|
|
74 |
def hd_3D(img_pred, img_gt, labels=(3, 1, 2)):
    """Compute one Hausdorff distance per label between two label maps.

    For every label a binary mask is extracted from each map and passed to
    ``hd`` (imported from ``medpy.metric.binary``).

    Args:
        img_pred: Array-like predicted segmentation map.
        img_gt: Array-like ground-truth segmentation map.
        labels: Iterable of label values to evaluate. The default is an
            immutable tuple so the default cannot be shared and mutated
            between calls.

    Returns:
        A list with one entry per label, in the order given by ``labels``.
        The entry is 0 when the label is absent from either map, because
        the distance is undefined for an empty mask; otherwise it is the
        value returned by ``hd``.
    """
    # Convert once; the comparisons below build new arrays, so the inputs
    # are never modified.
    pred_arr = np.asarray(img_pred)
    gt_arr = np.asarray(img_gt)

    res = []
    for label in labels:
        gt_mask = gt_arr == label
        pred_mask = pred_arr == label

        if pred_mask.any() and gt_mask.any():
            res.append(hd(pred_mask, gt_mask))
        else:
            # At least one mask is empty: report 0 rather than calling hd.
            res.append(0)

    return res
|
|
94 |
|
|
|
95 |
def cal_hausdorff_distance(pred,target):
    """Return the Euclidean Hausdorff distance between two tensors.

    Both inputs are made contiguous, converted to NumPy arrays and handed to
    ``hausdorff_distance`` with ``distance="euclidean"``.

    Args:
        pred: Tensor-like object exposing ``contiguous()`` (presumably a CPU
            torch tensor -- confirm against callers).
        target: Tensor-like object of the same kind as ``pred``.

    Returns:
        Whatever ``hausdorff_distance`` returns for the two arrays.
    """
    pred_array, target_array = (
        np.array(tensor.contiguous()) for tensor in (pred, target)
    )
    return hausdorff_distance(pred_array, target_array, distance="euclidean")
|
|
102 |
|
|
|
103 |
def make_one_hot(input, num_classes):
    """Convert class index tensor to one hot encoding tensor.

    Args:
        input: A tensor of shape [N, 1, *] holding class indices.
        num_classes: An int of number of class
    Returns:
        A CPU float tensor of shape [N, num_classes, *] with a 1 at the
        position of each index along dimension 1 and 0 elsewhere.
    """
    # Same shape as the input, except the channel axis grows to num_classes.
    dims = list(input.shape)
    dims[1] = num_classes

    # scatter_ needs integer indices; the output is built on the CPU.
    indices = input.cpu().long()

    one_hot = torch.zeros(tuple(dims))
    one_hot.scatter_(1, indices, 1)
    return one_hot
|
|
118 |
|
|
|
119 |
def match_pred_gt(pred, gt):
    """Relabel predicted regions so they line up with the ground truth.

    pred: (1, C, H, W)
    gt: (1, C, H, W)

    Every non-background label of ``pred`` is compared with every
    non-background label of ``gt`` through ``dice``. For each ground-truth
    label, the predicted label with the highest Dice score is selected and
    its pixels are written into the output.

    Returns:
        A tensor shaped like ``pred``. Pixels of the predicted label matched
        to the i-th ground-truth label (sorted order, 0-based) hold the value
        ``i + 1``; all other pixels are 0. If either input has no
        non-background label the result is all zeros. When one predicted
        label is the best match for several ground-truth labels, the last
        one processed determines the written value.
    """
    # torch.unique returns sorted values; the smallest one is dropped.
    # NOTE(review): this assumes the smallest value is background and is
    # always present in both maps -- confirm against callers.
    gt_labels = torch.unique(gt, sorted=True)[1:]
    pred_labels = torch.unique(pred, sorted=True)[1:]

    if len(gt_labels) == 0 or len(pred_labels) == 0:
        return torch.zeros_like(pred)

    dice_matrix = torch.zeros((len(pred_labels), len(gt_labels)))
    for i, pl in enumerate(pred_labels):
        # .float() converts the boolean mask directly; wrapping an existing
        # tensor in torch.tensor(...) triggers a copy-construct warning.
        # The one-hot does not depend on the inner loop, so build it once.
        pred_one_hot = make_one_hot((pred == pl).float(), 2)[0]
        for j, gl in enumerate(gt_labels):
            dice_matrix[i, j] = dice(pred_one_hot, make_one_hot(gt == gl, 2)[0])

    # Row index of the best predicted label for each ground-truth column.
    # Convert explicitly instead of relying on NumPy coercing a tensor.
    best_pred_index = np.argmax(dice_matrix.numpy(), axis=0)

    pred_match = torch.zeros_like(pred)
    for i, arg in enumerate(best_pred_index):
        pred_match[pred == pred_labels[arg]] = i + 1
    return pred_match
|
|
142 |
|
|
|
143 |
if __name__ == "__main__":
    # Ad-hoc manual check: load a saved array and unpack it into two parts.
    npy_path = "/home/fcheng/Cardia/source_code/logs/logs_df_50000/eval_pp_test/200.npy"
    # Load from the variable defined above; the name `npy_p` does not exist.
    pred_df, gt_df = np.load(npy_path)