Diff of /utils.py [000000] .. [721c7a]

Switch to unified view

a b/utils.py
1
import torch
2
import numpy as np
3
import cv2
4
import SimpleITK as sitk
5
import matplotlib
6
import matplotlib.pyplot as plt
7
from scipy import ndimage
8
import pdb
9
import math
10
import vtk
11
from torch.autograd import Variable
12
from skimage.morphology import binary_dilation, disk
13
import imageio
14
import os
15
16
17
def cdist(x, y):
18
    """
19
    Compute distance between each pair of the two collections of inputs.
20
    :param x: Nxd Tensor
21
    :param y: Mxd Tensor
22
    :res: NxM matrix where dist[i,j] is the norm between x[i,:] and y[j,:],
23
          i.e. dist[i,j] = ||x[i,:]-y[j,:]||
24
    """
25
    differences = x.unsqueeze(1) - y.unsqueeze(0)
26
    distances = torch.sum((differences+1e-6)**2, -1).sqrt()
27
    return distances
28
29
def generaliz_mean(tensor, dim, p=-9, keepdim=False):
30
    # """
31
    # Computes the softmin along some axes.
32
    # Softmin is the same as -softmax(-x), i.e,
33
    # softmin(x) = -log(sum_i(exp(-x_i)))
34
35
    # The smoothness of the operator is controlled with k:
36
    # softmin(x) = -log(sum_i(exp(-k*x_i)))/k
37
38
    # :param input: Tensor of any dimension.
39
    # :param dim: (int or tuple of ints) The dimension or dimensions to reduce.
40
    # :param keepdim: (bool) Whether the output tensor has dim retained or not.
41
    # :param k: (float>0) How similar softmin is to min (the lower the more smooth).
42
    # """
43
    # return -torch.log(torch.sum(torch.exp(-k*input), dim, keepdim))/k
44
    """
45
    The generalized mean. It corresponds to the minimum when p = -inf.
46
    https://en.wikipedia.org/wiki/Generalized_mean
47
    :param tensor: Tensor of any dimension.
48
    :param dim: (int or tuple of ints) The dimension or dimensions to reduce.
49
    :param keepdim: (bool) Whether the output tensor has dim retained or not.
50
    :param p: (float<0).
51
    """
52
    assert p < 0
53
    res= torch.mean((tensor + 1e-6)**p, dim, keepdim=keepdim)**(1./p)
54
    return res
55
56
57
58
def weightedHausdorff_batch(prob_loc, prob_vec, gt, height, width, temper, status):
59
    max_dist = math.sqrt(height ** 2 + width ** 2)
60
61
    # print (gt.shape)
62
    # print (gt.sum())
63
    # print (prob_vec.sum())
64
    batch_size = prob_loc.shape[0]
65
    # print (batch_size)
66
67
68
    term_1 = []
69
    term_2 = []
70
71
    for i in range(batch_size):
72
        prob_vec_sele = prob_vec[i, :, 0][prob_vec[i, :, 0] > torch.exp(torch.tensor((-1) * temper).cuda())]
73
        idx_sele_x = prob_loc[i, :, 0][prob_vec[i, :, 0] > torch.exp(torch.tensor((-1) * temper).cuda())]
74
        idx_sele_y = prob_loc[i, :, 1][prob_vec[i, :, 0] > torch.exp(torch.tensor((-1) * temper).cuda())]
75
        idx_sele = torch.stack((idx_sele_x, idx_sele_y), 1)
76
77
78
        # For case GT=0
79
        if gt[i,:,:].sum() == 0:
80
            if prob_vec_sele.sum() < 1e-3:
81
                if status=='train':
82
                    term_1.append(Variable(torch.tensor(0.0).cuda(), requires_grad=True))
83
                    term_2.append(Variable(torch.tensor(0.0).cuda(), requires_grad=True))
84
                else:
85
                    term_1.append(torch.tensor(0.0).cuda())
86
                    term_2.append(torch.tensor(0.0).cuda())
87
            else:
88
                if status == 'train':
89
                    term_1.append(Variable(torch.tensor(0.0).cuda(), requires_grad=True))
90
                    term_2.append(Variable((torch.tensor(max_dist)).cuda(), requires_grad=True))
91
                else:
92
                    term_1.append(torch.tensor(0.0).cuda())
93
                    term_2.append(torch.tensor(max_dist).cuda())
94
        else:
95
            if prob_vec_sele.sum() < 1e-3:
96
                if status == 'train':
97
                    term_1.append(Variable((torch.tensor(max_dist)).cuda(), requires_grad=True))
98
                    term_2.append(Variable(torch.tensor(0.0).cuda(), requires_grad=True))
99
                else:
100
                    term_1.append(torch.tensor(max_dist).cuda())
101
                    term_2.append(torch.tensor(0.0).cuda())
102
            else:
103
                # find nonzero point in gt
104
                idx_gt = torch.nonzero(gt[i, :, :])
105
                d_matrix = cdist(idx_sele, idx_gt)
106
                # print (d_matrix.shape) # N*M
107
108
109
                term_1.append(
110
                    (1 / (prob_vec_sele.sum() + 1e-6)) * torch.sum(prob_vec_sele * torch.min(d_matrix, 1)[0]))
111
                p_replicated = prob_vec_sele.view(-1, 1).repeat(1, idx_gt.shape[0])
112
                weighted_d_matrix = (1 - p_replicated) * max_dist + p_replicated * d_matrix
113
                minn = generaliz_mean(weighted_d_matrix, p=-7, dim=0, keepdim=False)
114
                term_2.append(torch.mean(minn))
115
116
117
    # print (term_1)
118
    # print (term_2)
119
    term_1 = torch.stack(term_1)
120
    term_2 = torch.stack(term_2)
121
122
    res = term_1.mean()+term_2.mean()
123
124
125
    return res
126
127
128
129
def huber_loss_3d(x):
130
    bsize, csize, depth, height, width = x.size()
131
    d_x = torch.index_select(x, 4, torch.arange(1, width).cuda()) - torch.index_select(x, 4, torch.arange(width-1).cuda())
132
    d_y = torch.index_select(x, 3, torch.arange(1, height).cuda()) - torch.index_select(x, 3, torch.arange(height-1).cuda())
133
    d_z = torch.index_select(x, 2, torch.arange(1, depth).cuda()) - torch.index_select(x, 2, torch.arange(depth-1).cuda())
134
    err = torch.sum(torch.mul(d_x, d_x))/width + torch.sum(torch.mul(d_y, d_y))/height + torch.sum(torch.mul(d_z, d_z))/depth
135
    err /= bsize
136
    tv_err = torch.sqrt(0.01+err)
137
    return tv_err
138
139
140
141
142
def projection(voxels, z_target, temper):
143
    # voxels are transformed from meshes based on affine information of different target plane
144
    # z_target is the z coordinate of the target plane, e.g., SAX is 12,17,22,27,32,37,42,47,52, 2CH is 0, 4CH is 0
145
    v_idx = voxels[:,:,0:2]  # [bs, numer_of verties, x/y coordinate]
146
    v_probability = torch.exp((-1) * temper * torch.square(voxels[:, :, 2:3] - z_target)) # [bs, numer_of verties, probability]
147
148
149
    return v_idx, v_probability
150
151
152
153
def distance_metric(pts_A, pts_B, dx):
154
    # Measure the distance errors between the contours of two segmentations
155
    # The manual contours are drawn on 2D slices.
156
    # We calculate contour to contour distance for each slice.
157
    # pts_A is N*2, pts_B is M*2
158
    if pts_A.shape[0] > 0 and pts_B.shape[0] > 0:
159
        # Distance matrix between point sets
160
        M = np.zeros((pts_A.shape[0], pts_B.shape[0]))
161
        for i in range(pts_A.shape[0]):
162
            for j in range(pts_B.shape[0]):
163
                M[i, j] = np.linalg.norm(pts_A[i, :] - pts_B[j, :])
164
165
        # Mean distance and hausdorff distance
166
        md = 0.5 * (np.mean(np.min(M, axis=0)) + np.mean(np.min(M, axis=1))) * dx
167
        hd = np.max([np.max(np.min(M, axis=0)), np.max(np.min(M, axis=1))]) * dx
168
    else:
169
        md = None
170
        hd = None
171
172
    return md, hd
173
174
175
def slice_2D(v_hat_es_cp, slice_num):
176
    idx_x = v_hat_es_cp[0, :, 0][torch.abs(v_hat_es_cp[0, :, 2] - slice_num) < 0.3]
177
    idx_y = v_hat_es_cp[0, :, 1][torch.abs(v_hat_es_cp[0, :, 2] - slice_num) < 0.3]
178
    idx_x_t = np.round(idx_x.detach().cpu().numpy()).astype(np.int16)
179
    idx_y_t = np.round(idx_y.detach().cpu().numpy()).astype(np.int16)
180
    idx = np.stack((idx_x_t, idx_y_t), 1)
181
182
    return idx
183
184
185
def compute_sa_mcd_hd(v_sa_hat_es_cp, contour_sa_es, sliceall):
186
    mcd_sa_allslice = []
187
    hd_sa_allslice = []
188
189
    slice_number = [1,4,7]
190
    threeslice = [sliceall[slice_number[0]], sliceall[slice_number[1]], sliceall[slice_number[2]]]
191
192
    print (threeslice)
193
    for i in range(len(threeslice)):
194
        idx_sa = slice_2D(v_sa_hat_es_cp, threeslice[i])
195
        idx_sa_gt = np.stack(np.nonzero(contour_sa_es[slice_number[i], :, :]), 1)
196
197
        mcd_sa, hd_sa = distance_metric(idx_sa, idx_sa_gt, 1.25)
198
        if (mcd_sa != None) and (hd_sa != None):
199
            mcd_sa_allslice.append(mcd_sa)
200
            hd_sa_allslice.append(hd_sa)
201
202
203
    mean_mcd_sa_allslices = np.mean(mcd_sa_allslice) if mcd_sa_allslice else None
204
    mean_hd_sa_allslices = np.mean(hd_sa_allslice) if hd_sa_allslice else None
205
206
    return mean_mcd_sa_allslices, mean_hd_sa_allslices
207
208
209
210
def FBoundary(pred_contour, gt_contour, bound_th=2):
211
    bound_pix = bound_th if bound_th >= 1 else \
212
        np.ceil(bound_th * np.linalg.norm(pred_contour.shape))
213
214
    pred_dil = binary_dilation(pred_contour, disk(bound_pix))
215
    gt_dil = binary_dilation(gt_contour, disk(bound_pix))
216
217
    # Get the intersection
218
    gt_match = gt_contour * pred_dil
219
    pred_match = pred_contour * gt_dil
220
221
    # Area of the intersection
222
    n_pred = np.sum(pred_contour)
223
    n_gt = np.sum(gt_contour)
224
225
    # % Compute precision and recall
226
    if n_pred == 0 and n_gt > 0:
227
        precision = 1
228
        recall = 0
229
    elif n_pred > 0 and n_gt == 0:
230
        precision = 0
231
        recall = 1
232
    elif n_pred == 0 and n_gt == 0:
233
        precision = 1
234
        recall = 1
235
    else:
236
        precision = np.sum(pred_match) / float(n_pred)
237
        recall = np.sum(gt_match) / float(n_gt)
238
239
    # Compute F measure
240
    if precision + recall == 0:
241
        Fscore = None
242
    else:
243
        Fscore = 2 * precision * recall / (precision + recall)
244
245
    return Fscore
246
247
def compute_sa_Fboundary(v_sa_hat_es_cp, contour_sa_es, sliceall, height, width):
248
249
    bfscore_all = []
250
    for i in range(len(sliceall)):
251
        idx_sa = slice_2D(v_sa_hat_es_cp, sliceall[i])
252
        sa_pred = np.zeros(shape=(height, width))
253
        for j in range(idx_sa.shape[0]):
254
            sa_pred[idx_sa[j,0], idx_sa[j,1]] = 1
255
256
        Fscore_1 = FBoundary(sa_pred, contour_sa_es[i,:,:], 1)
257
        Fscore_2 = FBoundary(sa_pred, contour_sa_es[i,:,:], 2)
258
        Fscore_3 = FBoundary(sa_pred, contour_sa_es[i,:,:], 3)
259
        Fscore_4 = FBoundary(sa_pred, contour_sa_es[i,:,:], 4)
260
        Fscore_5 = FBoundary(sa_pred, contour_sa_es[i,:,:], 5)
261
262
263
        if (Fscore_1 != None):
264
            Fscore = (Fscore_1+Fscore_2+Fscore_3+Fscore_4+Fscore_5)/5.0
265
            bfscore_all.append(Fscore)
266
267
    mean_bfscore = np.mean(bfscore_all) if bfscore_all else None
268
269
270
    return mean_bfscore
271
272
def compute_la_Fboundary(pred_contour, gt_contour):
273
274
    Fscore_1 = FBoundary(pred_contour, gt_contour, 1)
275
    Fscore_2 = FBoundary(pred_contour, gt_contour, 2)
276
    Fscore_3 = FBoundary(pred_contour, gt_contour, 3)
277
    Fscore_4 = FBoundary(pred_contour, gt_contour, 4)
278
    Fscore_5 = FBoundary(pred_contour, gt_contour, 5)
279
280
281
    if (Fscore_1 != None):
282
        Fscore = (Fscore_1+Fscore_2+Fscore_3+Fscore_4+Fscore_5)/5.0
283
    else:
284
        Fscore = None
285
286
287
    return Fscore
288
289
290
291
292
def projection_weightHD_loss_SA(v_sa_hat_ed_cp, temper, height, width, depth, gt_mesh2seg_sa, status):
293
294
    weightHD_loss = []
295
296
    for i in range(depth-1):
297
        v_sa_idx_ed, w_sa_ed = projection(v_sa_hat_ed_cp, i, temper)
298
        slice_loss = weightedHausdorff_batch(v_sa_idx_ed, w_sa_ed, gt_mesh2seg_sa[:,:,:,i], height, width, temper, status)
299
300
        weightHD_loss.append(slice_loss)
301
302
    weightHD_loss = torch.stack(weightHD_loss)
303
304
    loss_aver = torch.mean(weightHD_loss)
305
306
307
308
    return loss_aver
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335