--- a +++ b/utils.py @@ -0,0 +1,335 @@ +import torch +import numpy as np +import cv2 +import SimpleITK as sitk +import matplotlib +import matplotlib.pyplot as plt +from scipy import ndimage +import pdb +import math +import vtk +from torch.autograd import Variable +from skimage.morphology import binary_dilation, disk +import imageio +import os + + +def cdist(x, y): + """ + Compute distance between each pair of the two collections of inputs. + :param x: Nxd Tensor + :param y: Mxd Tensor + :res: NxM matrix where dist[i,j] is the norm between x[i,:] and y[j,:], + i.e. dist[i,j] = ||x[i,:]-y[j,:]|| + """ + differences = x.unsqueeze(1) - y.unsqueeze(0) + distances = torch.sum((differences+1e-6)**2, -1).sqrt() + return distances + +def generaliz_mean(tensor, dim, p=-9, keepdim=False): + # """ + # Computes the softmin along some axes. + # Softmin is the same as -softmax(-x), i.e, + # softmin(x) = -log(sum_i(exp(-x_i))) + + # The smoothness of the operator is controlled with k: + # softmin(x) = -log(sum_i(exp(-k*x_i)))/k + + # :param input: Tensor of any dimension. + # :param dim: (int or tuple of ints) The dimension or dimensions to reduce. + # :param keepdim: (bool) Whether the output tensor has dim retained or not. + # :param k: (float>0) How similar softmin is to min (the lower the more smooth). + # """ + # return -torch.log(torch.sum(torch.exp(-k*input), dim, keepdim))/k + """ + The generalized mean. It corresponds to the minimum when p = -inf. + https://en.wikipedia.org/wiki/Generalized_mean + :param tensor: Tensor of any dimension. + :param dim: (int or tuple of ints) The dimension or dimensions to reduce. + :param keepdim: (bool) Whether the output tensor has dim retained or not. + :param p: (float<0). + """ + assert p < 0 + res= torch.mean((tensor + 1e-6)**p, dim, keepdim=keepdim)**(1./p) + return res + + + +def weightedHausdorff_batch(prob_loc, prob_vec, gt, height, width, temper, status): + max_dist = math.sqrt(height ** 2 + width ** 2) + + # print (gt.shape) + # print (gt.sum()) + # print (prob_vec.sum()) + batch_size = prob_loc.shape[0] + # print (batch_size) + + + term_1 = [] + term_2 = [] + + for i in range(batch_size): + prob_vec_sele = prob_vec[i, :, 0][prob_vec[i, :, 0] > torch.exp(torch.tensor((-1) * temper).cuda())] + idx_sele_x = prob_loc[i, :, 0][prob_vec[i, :, 0] > torch.exp(torch.tensor((-1) * temper).cuda())] + idx_sele_y = prob_loc[i, :, 1][prob_vec[i, :, 0] > torch.exp(torch.tensor((-1) * temper).cuda())] + idx_sele = torch.stack((idx_sele_x, idx_sele_y), 1) + + + # For case GT=0 + if gt[i,:,:].sum() == 0: + if prob_vec_sele.sum() < 1e-3: + if status=='train': + term_1.append(Variable(torch.tensor(0.0).cuda(), requires_grad=True)) + term_2.append(Variable(torch.tensor(0.0).cuda(), requires_grad=True)) + else: + term_1.append(torch.tensor(0.0).cuda()) + term_2.append(torch.tensor(0.0).cuda()) + else: + if status == 'train': + term_1.append(Variable(torch.tensor(0.0).cuda(), requires_grad=True)) + term_2.append(Variable((torch.tensor(max_dist)).cuda(), requires_grad=True)) + else: + term_1.append(torch.tensor(0.0).cuda()) + term_2.append(torch.tensor(max_dist).cuda()) + else: + if prob_vec_sele.sum() < 1e-3: + if status == 'train': + term_1.append(Variable((torch.tensor(max_dist)).cuda(), requires_grad=True)) + term_2.append(Variable(torch.tensor(0.0).cuda(), requires_grad=True)) + else: + term_1.append(torch.tensor(max_dist).cuda()) + term_2.append(torch.tensor(0.0).cuda()) + else: + # find nonzero point in gt + idx_gt = torch.nonzero(gt[i, :, :]) + d_matrix = cdist(idx_sele, idx_gt) + # print (d_matrix.shape) # N*M + + + term_1.append( + (1 / (prob_vec_sele.sum() + 1e-6)) * torch.sum(prob_vec_sele * torch.min(d_matrix, 1)[0])) + p_replicated = prob_vec_sele.view(-1, 1).repeat(1, idx_gt.shape[0]) + weighted_d_matrix = (1 - p_replicated) * max_dist + p_replicated * d_matrix + minn = generaliz_mean(weighted_d_matrix, p=-7, dim=0, keepdim=False) + term_2.append(torch.mean(minn)) + + + # print (term_1) + # print (term_2) + term_1 = torch.stack(term_1) + term_2 = torch.stack(term_2) + + res = term_1.mean()+term_2.mean() + + + return res + + + +def huber_loss_3d(x): + bsize, csize, depth, height, width = x.size() + d_x = torch.index_select(x, 4, torch.arange(1, width).cuda()) - torch.index_select(x, 4, torch.arange(width-1).cuda()) + d_y = torch.index_select(x, 3, torch.arange(1, height).cuda()) - torch.index_select(x, 3, torch.arange(height-1).cuda()) + d_z = torch.index_select(x, 2, torch.arange(1, depth).cuda()) - torch.index_select(x, 2, torch.arange(depth-1).cuda()) + err = torch.sum(torch.mul(d_x, d_x))/width + torch.sum(torch.mul(d_y, d_y))/height + torch.sum(torch.mul(d_z, d_z))/depth + err /= bsize + tv_err = torch.sqrt(0.01+err) + return tv_err + + + + +def projection(voxels, z_target, temper): + # voxels are transformed from meshes based on affine information of different target plane + # z_target is the z coordinate of the target plane, e.g., SAX is 12,17,22,27,32,37,42,47,52, 2CH is 0, 4CH is 0 + v_idx = voxels[:,:,0:2] # [bs, numer_of verties, x/y coordinate] + v_probability = torch.exp((-1) * temper * torch.square(voxels[:, :, 2:3] - z_target)) # [bs, numer_of verties, probability] + + + return v_idx, v_probability + + + +def distance_metric(pts_A, pts_B, dx): + # Measure the distance errors between the contours of two segmentations + # The manual contours are drawn on 2D slices. + # We calculate contour to contour distance for each slice. + # pts_A is N*2, pts_B is M*2 + if pts_A.shape[0] > 0 and pts_B.shape[0] > 0: + # Distance matrix between point sets + M = np.zeros((pts_A.shape[0], pts_B.shape[0])) + for i in range(pts_A.shape[0]): + for j in range(pts_B.shape[0]): + M[i, j] = np.linalg.norm(pts_A[i, :] - pts_B[j, :]) + + # Mean distance and hausdorff distance + md = 0.5 * (np.mean(np.min(M, axis=0)) + np.mean(np.min(M, axis=1))) * dx + hd = np.max([np.max(np.min(M, axis=0)), np.max(np.min(M, axis=1))]) * dx + else: + md = None + hd = None + + return md, hd + + +def slice_2D(v_hat_es_cp, slice_num): + idx_x = v_hat_es_cp[0, :, 0][torch.abs(v_hat_es_cp[0, :, 2] - slice_num) < 0.3] + idx_y = v_hat_es_cp[0, :, 1][torch.abs(v_hat_es_cp[0, :, 2] - slice_num) < 0.3] + idx_x_t = np.round(idx_x.detach().cpu().numpy()).astype(np.int16) + idx_y_t = np.round(idx_y.detach().cpu().numpy()).astype(np.int16) + idx = np.stack((idx_x_t, idx_y_t), 1) + + return idx + + +def compute_sa_mcd_hd(v_sa_hat_es_cp, contour_sa_es, sliceall): + mcd_sa_allslice = [] + hd_sa_allslice = [] + + slice_number = [1,4,7] + threeslice = [sliceall[slice_number[0]], sliceall[slice_number[1]], sliceall[slice_number[2]]] + + print (threeslice) + for i in range(len(threeslice)): + idx_sa = slice_2D(v_sa_hat_es_cp, threeslice[i]) + idx_sa_gt = np.stack(np.nonzero(contour_sa_es[slice_number[i], :, :]), 1) + + mcd_sa, hd_sa = distance_metric(idx_sa, idx_sa_gt, 1.25) + if (mcd_sa != None) and (hd_sa != None): + mcd_sa_allslice.append(mcd_sa) + hd_sa_allslice.append(hd_sa) + + + mean_mcd_sa_allslices = np.mean(mcd_sa_allslice) if mcd_sa_allslice else None + mean_hd_sa_allslices = np.mean(hd_sa_allslice) if hd_sa_allslice else None + + return mean_mcd_sa_allslices, mean_hd_sa_allslices + + + +def FBoundary(pred_contour, gt_contour, bound_th=2): + bound_pix = bound_th if bound_th >= 1 else \ + np.ceil(bound_th * np.linalg.norm(pred_contour.shape)) + + pred_dil = binary_dilation(pred_contour, disk(bound_pix)) + gt_dil = binary_dilation(gt_contour, disk(bound_pix)) + + # Get the intersection + gt_match = gt_contour * pred_dil + pred_match = pred_contour * gt_dil + + # Area of the intersection + n_pred = np.sum(pred_contour) + n_gt = np.sum(gt_contour) + + # % Compute precision and recall + if n_pred == 0 and n_gt > 0: + precision = 1 + recall = 0 + elif n_pred > 0 and n_gt == 0: + precision = 0 + recall = 1 + elif n_pred == 0 and n_gt == 0: + precision = 1 + recall = 1 + else: + precision = np.sum(pred_match) / float(n_pred) + recall = np.sum(gt_match) / float(n_gt) + + # Compute F measure + if precision + recall == 0: + Fscore = None + else: + Fscore = 2 * precision * recall / (precision + recall) + + return Fscore + +def compute_sa_Fboundary(v_sa_hat_es_cp, contour_sa_es, sliceall, height, width): + + bfscore_all = [] + for i in range(len(sliceall)): + idx_sa = slice_2D(v_sa_hat_es_cp, sliceall[i]) + sa_pred = np.zeros(shape=(height, width)) + for j in range(idx_sa.shape[0]): + sa_pred[idx_sa[j,0], idx_sa[j,1]] = 1 + + Fscore_1 = FBoundary(sa_pred, contour_sa_es[i,:,:], 1) + Fscore_2 = FBoundary(sa_pred, contour_sa_es[i,:,:], 2) + Fscore_3 = FBoundary(sa_pred, contour_sa_es[i,:,:], 3) + Fscore_4 = FBoundary(sa_pred, contour_sa_es[i,:,:], 4) + Fscore_5 = FBoundary(sa_pred, contour_sa_es[i,:,:], 5) + + + if (Fscore_1 != None): + Fscore = (Fscore_1+Fscore_2+Fscore_3+Fscore_4+Fscore_5)/5.0 + bfscore_all.append(Fscore) + + mean_bfscore = np.mean(bfscore_all) if bfscore_all else None + + + return mean_bfscore + +def compute_la_Fboundary(pred_contour, gt_contour): + + Fscore_1 = FBoundary(pred_contour, gt_contour, 1) + Fscore_2 = FBoundary(pred_contour, gt_contour, 2) + Fscore_3 = FBoundary(pred_contour, gt_contour, 3) + Fscore_4 = FBoundary(pred_contour, gt_contour, 4) + Fscore_5 = FBoundary(pred_contour, gt_contour, 5) + + + if (Fscore_1 != None): + Fscore = (Fscore_1+Fscore_2+Fscore_3+Fscore_4+Fscore_5)/5.0 + else: + Fscore = None + + + return Fscore + + + + +def projection_weightHD_loss_SA(v_sa_hat_ed_cp, temper, height, width, depth, gt_mesh2seg_sa, status): + + weightHD_loss = [] + + for i in range(depth-1): + v_sa_idx_ed, w_sa_ed = projection(v_sa_hat_ed_cp, i, temper) + slice_loss = weightedHausdorff_batch(v_sa_idx_ed, w_sa_ed, gt_mesh2seg_sa[:,:,:,i], height, width, temper, status) + + weightHD_loss.append(slice_loss) + + weightHD_loss = torch.stack(weightHD_loss) + + loss_aver = torch.mean(weightHD_loss) + + + + return loss_aver + + + + + + + + + + + + + + + + + + + + + + + + + + +