# utils.py
import math
import os
import pdb

import cv2
import imageio
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import SimpleITK as sitk
import torch
import vtk
from scipy import ndimage
from skimage.morphology import binary_dilation, disk
from torch.autograd import Variable


def cdist(x, y):
    """
    Compute the distance between each pair of the two collections of inputs.
    :param x: Nxd Tensor
    :param y: Mxd Tensor
    :res: NxM matrix where dist[i,j] is the norm between x[i,:] and y[j,:],
          i.e. dist[i,j] = ||x[i,:] - y[j,:]||
    """
    differences = x.unsqueeze(1) - y.unsqueeze(0)  # [N, M, d] pairwise differences
    # The 1e-6 offset keeps sqrt() differentiable when two points coincide.
    distances = torch.sum((differences + 1e-6) ** 2, -1).sqrt()
    return distances
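
# A minimal usage sketch for cdist (shapes here are illustrative, not from the pipeline):
#     x = torch.rand(4, 2)   # N=4 points in 2-D
#     y = torch.rand(6, 2)   # M=6 points in 2-D
#     cdist(x, y).shape      # torch.Size([4, 6])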
|
|
def generaliz_mean(tensor, dim, p=-9, keepdim=False):
    """
    The generalized mean. It approaches the minimum as p -> -inf.
    https://en.wikipedia.org/wiki/Generalized_mean
    :param tensor: Tensor of any dimension.
    :param dim: (int or tuple of ints) The dimension or dimensions to reduce.
    :param keepdim: (bool) Whether the output tensor has dim retained or not.
    :param p: (float < 0) More negative values approximate the minimum more closely.
    """
    assert p < 0
    # The 1e-6 offset avoids 0 ** p (a division by zero) for zero entries.
    res = torch.mean((tensor + 1e-6) ** p, dim, keepdim=keepdim) ** (1. / p)
    return res
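
# Illustrative example: with strongly negative p the generalized mean tracks the minimum.
#     t = torch.tensor([1.0, 2.0, 10.0])
#     generaliz_mean(t, dim=0, p=-9)   # ~1.13; approaches min(t) = 1.0 as p -> -inf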
|
|
def weightedHausdorff_batch(prob_loc, prob_vec, gt, height, width, temper, status):
    """
    Weighted Hausdorff distance between soft-projected points and a binary
    ground-truth map, averaged over the batch.
    :param prob_loc: [batch, n_points, 2] point coordinates.
    :param prob_vec: [batch, n_points, 1] per-point probabilities.
    :param gt: [batch, height, width] binary ground-truth map.
    :param status: 'train' makes the constant fallback terms require gradients.
    """
    max_dist = math.sqrt(height ** 2 + width ** 2)

    batch_size = prob_loc.shape[0]

    term_1 = []
    term_2 = []

    for i in range(batch_size):
        # Keep only points whose probability exceeds exp(-temper).
        thresh = torch.exp(torch.tensor(-1.0 * temper).cuda())
        mask = prob_vec[i, :, 0] > thresh
        prob_vec_sele = prob_vec[i, :, 0][mask]
        idx_sele_x = prob_loc[i, :, 0][mask]
        idx_sele_y = prob_loc[i, :, 1][mask]
        idx_sele = torch.stack((idx_sele_x, idx_sele_y), 1)

        # For case GT = 0 (empty ground truth)
        if gt[i, :, :].sum() == 0:
            if prob_vec_sele.sum() < 1e-3:
                # Neither prediction nor GT has mass: zero penalty.
                if status == 'train':
                    term_1.append(Variable(torch.tensor(0.0).cuda(), requires_grad=True))
                    term_2.append(Variable(torch.tensor(0.0).cuda(), requires_grad=True))
                else:
                    term_1.append(torch.tensor(0.0).cuda())
                    term_2.append(torch.tensor(0.0).cuda())
            else:
                # Prediction has mass but GT is empty: maximum penalty on term_2.
                if status == 'train':
                    term_1.append(Variable(torch.tensor(0.0).cuda(), requires_grad=True))
                    term_2.append(Variable(torch.tensor(max_dist).cuda(), requires_grad=True))
                else:
                    term_1.append(torch.tensor(0.0).cuda())
                    term_2.append(torch.tensor(max_dist).cuda())
        else:
            if prob_vec_sele.sum() < 1e-3:
                # GT has mass but prediction is empty: maximum penalty on term_1.
                if status == 'train':
                    term_1.append(Variable(torch.tensor(max_dist).cuda(), requires_grad=True))
                    term_2.append(Variable(torch.tensor(0.0).cuda(), requires_grad=True))
                else:
                    term_1.append(torch.tensor(max_dist).cuda())
                    term_2.append(torch.tensor(0.0).cuda())
            else:
                # Find the nonzero points in gt.
                idx_gt = torch.nonzero(gt[i, :, :])
                d_matrix = cdist(idx_sele, idx_gt)  # N*M distance matrix

                # term_1: probability-weighted distance from each point to its nearest GT pixel.
                term_1.append(
                    (1 / (prob_vec_sele.sum() + 1e-6)) * torch.sum(prob_vec_sele * torch.min(d_matrix, 1)[0]))
                # term_2: soft-min distance from each GT pixel to the weighted point set.
                p_replicated = prob_vec_sele.view(-1, 1).repeat(1, idx_gt.shape[0])
                weighted_d_matrix = (1 - p_replicated) * max_dist + p_replicated * d_matrix
                minn = generaliz_mean(weighted_d_matrix, p=-7, dim=0, keepdim=False)
                term_2.append(torch.mean(minn))

    term_1 = torch.stack(term_1)
    term_2 = torch.stack(term_2)

    res = term_1.mean() + term_2.mean()

    return res
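
# Commented usage sketch (assumes a CUDA device; all values are illustrative):
#     loc = torch.rand(2, 100, 2).cuda() * 64          # candidate point locations
#     p = torch.rand(2, 100, 1).cuda()                 # per-point probabilities
#     gt = torch.zeros(2, 64, 64).cuda(); gt[:, 30:34, 30:34] = 1
#     loss = weightedHausdorff_batch(loc, p, gt, 64, 64, temper=5.0, status='train')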
|
|
def huber_loss_3d(x):
    # A smoothed (Huber/Charbonnier-style) total-variation penalty over a
    # 5-D tensor [B, C, D, H, W].
    bsize, csize, depth, height, width = x.size()
    # Finite differences along width, height and depth.
    d_x = torch.index_select(x, 4, torch.arange(1, width).cuda()) - torch.index_select(x, 4, torch.arange(width - 1).cuda())
    d_y = torch.index_select(x, 3, torch.arange(1, height).cuda()) - torch.index_select(x, 3, torch.arange(height - 1).cuda())
    d_z = torch.index_select(x, 2, torch.arange(1, depth).cuda()) - torch.index_select(x, 2, torch.arange(depth - 1).cuda())
    err = torch.sum(torch.mul(d_x, d_x)) / width + torch.sum(torch.mul(d_y, d_y)) / height + torch.sum(torch.mul(d_z, d_z)) / depth
    err /= bsize
    tv_err = torch.sqrt(0.01 + err)
    return tv_err
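
# Usage sketch (illustrative): a smoothness regularizer on a dense 3-D field.
#     field = torch.rand(1, 3, 8, 32, 32).cuda()   # [B, C, D, H, W]
#     reg = huber_loss_3d(field)                   # scalar, sqrt(0.01 + TV energy)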
|
|
def projection(voxels, z_target, temper):
    # Voxels are transformed from meshes based on the affine information of each target plane.
    # z_target is the z coordinate of the target plane, e.g., SAX is 12, 17, 22, 27, 32, 37, 42, 47, 52; 2CH is 0; 4CH is 0.
    v_idx = voxels[:, :, 0:2]  # [bs, number_of_vertices, x/y coordinates]
    # Soft slice membership: probability decays with squared distance to the plane.
    v_probability = torch.exp(-temper * torch.square(voxels[:, :, 2:3] - z_target))  # [bs, number_of_vertices, 1]

    return v_idx, v_probability
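
# Illustrative example: soft-project mesh vertices onto the SAX plane z = 12.
#     v_idx, v_prob = projection(voxels, z_target=12, temper=5.0)
#     # vertices near z = 12 get v_prob close to 1; distant vertices decay to 0.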
|
|
def distance_metric(pts_A, pts_B, dx):
    # Measure the distance errors between the contours of two segmentations.
    # The manual contours are drawn on 2D slices.
    # We calculate the contour-to-contour distance for each slice.
    # pts_A is N*2, pts_B is M*2.
    if pts_A.shape[0] > 0 and pts_B.shape[0] > 0:
        # Distance matrix between the two point sets
        M = np.zeros((pts_A.shape[0], pts_B.shape[0]))
        for i in range(pts_A.shape[0]):
            for j in range(pts_B.shape[0]):
                M[i, j] = np.linalg.norm(pts_A[i, :] - pts_B[j, :])

        # Mean contour distance and Hausdorff distance, scaled to physical units by dx
        md = 0.5 * (np.mean(np.min(M, axis=0)) + np.mean(np.min(M, axis=1))) * dx
        hd = np.max([np.max(np.min(M, axis=0)), np.max(np.min(M, axis=1))]) * dx
    else:
        md = None
        hd = None

    return md, hd
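
# Worked toy example (pixel spacing 1.25, two 2-point contours):
#     a = np.array([[0, 0], [0, 1]]); b = np.array([[1, 0], [1, 1]])
#     distance_metric(a, b, dx=1.25)   # -> (1.25, 1.25)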
|
|
def slice_2D(v_hat_es_cp, slice_num):
    # Collect the vertices lying within 0.3 of the requested slice's z coordinate,
    # then round their in-plane coordinates to integer pixel indices.
    near = torch.abs(v_hat_es_cp[0, :, 2] - slice_num) < 0.3
    idx_x = v_hat_es_cp[0, :, 0][near]
    idx_y = v_hat_es_cp[0, :, 1][near]
    idx_x_t = np.round(idx_x.detach().cpu().numpy()).astype(np.int16)
    idx_y_t = np.round(idx_y.detach().cpu().numpy()).astype(np.int16)
    idx = np.stack((idx_x_t, idx_y_t), 1)

    return idx
|
|
def compute_sa_mcd_hd(v_sa_hat_es_cp, contour_sa_es, sliceall):
    # Mean contour distance (MCD) and Hausdorff distance (HD) on three SAX slices.
    mcd_sa_allslice = []
    hd_sa_allslice = []

    # The slice indices [1, 4, 7] and the in-plane spacing of 1.25 passed to
    # distance_metric below are dataset-specific constants.
    slice_number = [1, 4, 7]
    threeslice = [sliceall[slice_number[0]], sliceall[slice_number[1]], sliceall[slice_number[2]]]

    for i in range(len(threeslice)):
        idx_sa = slice_2D(v_sa_hat_es_cp, threeslice[i])
        idx_sa_gt = np.stack(np.nonzero(contour_sa_es[slice_number[i], :, :]), 1)

        mcd_sa, hd_sa = distance_metric(idx_sa, idx_sa_gt, 1.25)
        if (mcd_sa is not None) and (hd_sa is not None):
            mcd_sa_allslice.append(mcd_sa)
            hd_sa_allslice.append(hd_sa)

    mean_mcd_sa_allslices = np.mean(mcd_sa_allslice) if mcd_sa_allslice else None
    mean_hd_sa_allslices = np.mean(hd_sa_allslice) if hd_sa_allslice else None

    return mean_mcd_sa_allslices, mean_hd_sa_allslices
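
# Usage sketch (illustrative shapes): v_sa_hat_es_cp is a [1, n_vertices, 3] tensor
# of transformed mesh vertices; contour_sa_es is a [n_slices, H, W] contour stack.
#     mcd, hd = compute_sa_mcd_hd(v_sa_hat_es_cp, contour_sa_es, sliceall)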
|
|
def FBoundary(pred_contour, gt_contour, bound_th=2):
    # Boundary F-score: match the two contours within a tolerance of bound_th
    # pixels (or a fraction of the image diagonal when bound_th < 1).
    bound_pix = bound_th if bound_th >= 1 else \
        np.ceil(bound_th * np.linalg.norm(pred_contour.shape))

    pred_dil = binary_dilation(pred_contour, disk(bound_pix))
    gt_dil = binary_dilation(gt_contour, disk(bound_pix))

    # Get the intersections with the dilated counterparts
    gt_match = gt_contour * pred_dil
    pred_match = pred_contour * gt_dil

    # Number of contour pixels on each side
    n_pred = np.sum(pred_contour)
    n_gt = np.sum(gt_contour)

    # Compute precision and recall
    if n_pred == 0 and n_gt > 0:
        precision = 1
        recall = 0
    elif n_pred > 0 and n_gt == 0:
        precision = 0
        recall = 1
    elif n_pred == 0 and n_gt == 0:
        precision = 1
        recall = 1
    else:
        precision = np.sum(pred_match) / float(n_pred)
        recall = np.sum(gt_match) / float(n_gt)

    # Compute the F measure
    if precision + recall == 0:
        Fscore = None
    else:
        Fscore = 2 * precision * recall / (precision + recall)

    return Fscore
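
# Sanity check (illustrative): identical contours give a perfect boundary F-score.
#     c = np.zeros((16, 16)); c[8, 4:12] = 1
#     FBoundary(c, c, bound_th=2)   # -> 1.0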
|
|
def compute_sa_Fboundary(v_sa_hat_es_cp, contour_sa_es, sliceall, height, width):

    bfscore_all = []
    for i in range(len(sliceall)):
        idx_sa = slice_2D(v_sa_hat_es_cp, sliceall[i])
        # Rasterize the predicted contour points onto an empty slice.
        sa_pred = np.zeros(shape=(height, width))
        for j in range(idx_sa.shape[0]):
            sa_pred[idx_sa[j, 0], idx_sa[j, 1]] = 1

        # Average the boundary F-score over tolerances of 1-5 pixels.
        Fscore_1 = FBoundary(sa_pred, contour_sa_es[i, :, :], 1)
        Fscore_2 = FBoundary(sa_pred, contour_sa_es[i, :, :], 2)
        Fscore_3 = FBoundary(sa_pred, contour_sa_es[i, :, :], 3)
        Fscore_4 = FBoundary(sa_pred, contour_sa_es[i, :, :], 4)
        Fscore_5 = FBoundary(sa_pred, contour_sa_es[i, :, :], 5)

        if Fscore_1 is not None:
            Fscore = (Fscore_1 + Fscore_2 + Fscore_3 + Fscore_4 + Fscore_5) / 5.0
            bfscore_all.append(Fscore)

    mean_bfscore = np.mean(bfscore_all) if bfscore_all else None

    return mean_bfscore
|
|
def compute_la_Fboundary(pred_contour, gt_contour):

    # Average the boundary F-score over tolerances of 1-5 pixels.
    Fscore_1 = FBoundary(pred_contour, gt_contour, 1)
    Fscore_2 = FBoundary(pred_contour, gt_contour, 2)
    Fscore_3 = FBoundary(pred_contour, gt_contour, 3)
    Fscore_4 = FBoundary(pred_contour, gt_contour, 4)
    Fscore_5 = FBoundary(pred_contour, gt_contour, 5)

    if Fscore_1 is not None:
        Fscore = (Fscore_1 + Fscore_2 + Fscore_3 + Fscore_4 + Fscore_5) / 5.0
    else:
        Fscore = None

    return Fscore
|
|
def projection_weightHD_loss_SA(v_sa_hat_ed_cp, temper, height, width, depth, gt_mesh2seg_sa, status):

    weightHD_loss = []

    for i in range(depth - 1):
        # Soft-project the mesh vertices onto SAX slice i, then score the
        # projection against that slice's ground truth.
        v_sa_idx_ed, w_sa_ed = projection(v_sa_hat_ed_cp, i, temper)
        slice_loss = weightedHausdorff_batch(v_sa_idx_ed, w_sa_ed, gt_mesh2seg_sa[:, :, :, i], height, width, temper, status)

        weightHD_loss.append(slice_loss)

    weightHD_loss = torch.stack(weightHD_loss)

    loss_aver = torch.mean(weightHD_loss)

    return loss_aver
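

# The guarded block below is a small CPU-only smoke test added for illustration;
# the toy inputs are assumptions, not data from the original pipeline, and the
# CUDA-dependent losses above are deliberately not exercised here.
if __name__ == '__main__':
    x = torch.rand(4, 2)
    y = torch.rand(6, 2)
    print('cdist:', cdist(x, y).shape)                    # torch.Size([4, 6])
    print('gen. mean:', generaliz_mean(torch.tensor([1.0, 2.0, 10.0]), dim=0, p=-9))
    a = np.array([[0, 0], [0, 1]])
    b = np.array([[1, 0], [1, 1]])
    print('md/hd:', distance_metric(a, b, dx=1.25))       # (1.25, 1.25)
    c = np.zeros((16, 16))
    c[8, 4:12] = 1
    print('boundary F:', FBoundary(c, c, bound_th=2))     # 1.0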
|
|