--- a
+++ b/utils.py
@@ -0,0 +1,335 @@
+import torch
+import numpy as np
+import cv2
+import SimpleITK as sitk
+import matplotlib
+import matplotlib.pyplot as plt
+from scipy import ndimage
+import pdb
+import math
+import vtk
+from torch.autograd import Variable
+from skimage.morphology import binary_dilation, disk
+import imageio
+import os
+
+
+def cdist(x, y):
+    """
+    Pairwise distance between every row of x and every row of y.
+
+    A constant 1e-6 is added to each coordinate difference before squaring,
+    so the value returned is ||x[i,:] - y[j,:] + 1e-6|| rather than the exact
+    Euclidean norm; identical points therefore give a tiny positive distance
+    instead of exactly 0.
+    :param x: Nxd Tensor
+    :param y: Mxd Tensor
+    :return: NxM Tensor where dist[i,j] ~= ||x[i,:]-y[j,:]||
+    """
+    offset = 1e-6
+    # Broadcast (N,1,d) against (1,M,d) to get all N*M coordinate differences.
+    pairwise_diff = x[:, None] - y[None]
+    shifted = pairwise_diff + offset
+    return (shifted ** 2).sum(-1).sqrt()
+
+def generaliz_mean(tensor, dim, p=-9, keepdim=False):
+    """
+    The generalized (power) mean, used here as a smooth stand-in for min.
+    It approaches the minimum as p -> -inf.
+    https://en.wikipedia.org/wiki/Generalized_mean
+
+    1e-6 is added to every element before raising to the (negative) power p,
+    which keeps exact zeros from producing infinities.
+    :param tensor: Tensor of any dimension.
+    :param dim: (int or tuple of ints) The dimension or dimensions to reduce.
+    :param p: (float<0) Exponent of the mean; must be negative.
+    :param keepdim: (bool) Whether the output tensor has dim retained or not.
+    :return: Tensor reduced along dim.
+    """
+    assert p < 0
+    powered = torch.pow(tensor + 1e-6, p)
+    averaged = torch.mean(powered, dim, keepdim=keepdim)
+    return torch.pow(averaged, 1. / p)
+
+
+
+def _hausdorff_constant(value, device, trainable):
+    """Build a scalar loss term on `device`; a grad-requiring leaf when `trainable`."""
+    return torch.tensor(float(value), device=device, requires_grad=trainable)
+
+
+def weightedHausdorff_batch(prob_loc, prob_vec, gt, height, width, temper, status):
+    """
+    Weighted Hausdorff-style loss between weighted points and binary masks,
+    averaged over the batch.
+
+    For each sample only the points whose weight exceeds exp(-temper) are kept.
+    Two terms are then accumulated per sample:
+      * term 1: weight-averaged distance from each kept point to its nearest
+        non-zero mask pixel;
+      * term 2: for each non-zero mask pixel, a soft minimum (generalized mean
+        with p=-7) over kept points of a distance in which low-weight points
+        are pushed towards max_dist; averaged over mask pixels.
+    Degenerate samples (empty mask and/or no kept weight) use the constants
+    0 or max_dist instead.
+
+    All constants are created on the device of `prob_vec`, so the function
+    runs on CPU as well as on CUDA tensors.
+    :param prob_loc: Tensor indexed as [batch, point, 0/1] giving point coordinates.
+    :param prob_vec: Tensor indexed as [batch, point, 0] giving point weights.
+    :param gt: Tensor indexed as [batch, :, :]; non-zero entries are target pixels
+               (presumably [batch, height, width] -- confirm against caller).
+    :param height: used with width only to compute max_dist = sqrt(h^2 + w^2).
+    :param width: see height.
+    :param temper: temperature; selection threshold is exp(-temper).
+    :param status: 'train' makes the constant terms leaves with requires_grad=True,
+                   so calling backward() on a constant-only loss does not fail.
+    :return: scalar Tensor, mean(term 1) + mean(term 2).
+    """
+    max_dist = math.sqrt(height ** 2 + width ** 2)
+    device = prob_vec.device
+    trainable = status == 'train'
+    # Loop-invariant: weights at or below this value are ignored.
+    threshold = torch.exp(torch.tensor((-1) * temper, device=device))
+
+    term_1 = []
+    term_2 = []
+
+    for i in range(prob_loc.shape[0]):
+        weights = prob_vec[i, :, 0]
+        keep = weights > threshold
+        prob_vec_sele = weights[keep]
+        idx_sele = torch.stack((prob_loc[i, :, 0][keep], prob_loc[i, :, 1][keep]), 1)
+
+        gt_is_empty = gt[i, :, :].sum() == 0
+        no_prediction = prob_vec_sele.sum() < 1e-3
+
+        if gt_is_empty:
+            # Nothing to match: only penalise (with max_dist) if something was predicted.
+            term_1.append(_hausdorff_constant(0.0, device, trainable))
+            term_2.append(_hausdorff_constant(0.0 if no_prediction else max_dist, device, trainable))
+        elif no_prediction:
+            # Target present but nothing predicted: worst-case distance.
+            term_1.append(_hausdorff_constant(max_dist, device, trainable))
+            term_2.append(_hausdorff_constant(0.0, device, trainable))
+        else:
+            # Coordinates of the non-zero target pixels, shape (M, 2).
+            idx_gt = torch.nonzero(gt[i, :, :])
+            d_matrix = cdist(idx_sele, idx_gt)  # (N, M)
+
+            term_1.append(
+                (1 / (prob_vec_sele.sum() + 1e-6)) * torch.sum(prob_vec_sele * torch.min(d_matrix, 1)[0]))
+            # (N, 1) column broadcasts against the (N, M) distance matrix.
+            p_column = prob_vec_sele.unsqueeze(1)
+            weighted_d_matrix = (1 - p_column) * max_dist + p_column * d_matrix
+            minn = generaliz_mean(weighted_d_matrix, p=-7, dim=0, keepdim=False)
+            term_2.append(torch.mean(minn))
+
+    term_1 = torch.stack(term_1)
+    term_2 = torch.stack(term_2)
+
+    return term_1.mean() + term_2.mean()
+
+
+
+def huber_loss_3d(x):
+    """
+    Smoothness penalty on a 5-D tensor built from squared forward differences.
+
+    Differences between neighbouring elements are taken along the last three
+    axes; each axis' sum of squares is divided by that axis' length, the three
+    are added, divided by the batch size, and the result is sqrt(0.01 + err).
+    The 0.01 keeps the square root (and its gradient) finite for constant input.
+
+    Uses plain slicing, so it works for tensors on any device.
+    :param x: Tensor of size (bsize, csize, depth, height, width).
+    :return: scalar Tensor.
+    """
+    bsize, _, depth, height, width = x.size()
+    d_x = x[:, :, :, :, 1:] - x[:, :, :, :, :-1]
+    d_y = x[:, :, :, 1:, :] - x[:, :, :, :-1, :]
+    d_z = x[:, :, 1:, :, :] - x[:, :, :-1, :, :]
+    err = (torch.sum(d_x * d_x) / width
+           + torch.sum(d_y * d_y) / height
+           + torch.sum(d_z * d_z) / depth)
+    err = err / bsize
+    return torch.sqrt(0.01 + err)
+
+
+
+
+def projection(voxels, z_target, temper):
+    """
+    Split vertices into in-plane coordinates and a soft plane-membership weight.
+
+    The weight of a vertex is exp(-temper * (z - z_target)^2): 1 when the vertex
+    lies exactly on the target plane, decaying with squared distance from it.
+    :param voxels: Tensor [bs, number_of_vertices, >=3]; columns 0-1 are x/y,
+                   column 2 is z. Per the original note, these are mesh vertices
+                   already transformed with the affine of the target plane.
+    :param z_target: z coordinate of the target plane (original note: e.g. SAX
+                     uses 12,17,...,52; 2CH and 4CH use 0).
+    :param temper: sharpness of the weight fall-off.
+    :return: (v_idx [bs, number_of_vertices, 2],
+              v_probability [bs, number_of_vertices, 1])
+    """
+    in_plane_xy = voxels[:, :, 0:2]
+    plane_offset = voxels[:, :, 2:3] - z_target
+    plane_weight = torch.exp(-temper * plane_offset ** 2)
+    return in_plane_xy, plane_weight
+
+
+
+def distance_metric(pts_A, pts_B, dx):
+    """
+    Contour-to-contour distance errors between two 2D point sets.
+
+    Computes the symmetric mean contour distance and the Hausdorff distance
+    from the full pairwise distance matrix, both scaled by dx.
+
+    Points are converted to float64 before subtraction so unsigned or narrow
+    integer coordinates cannot wrap around, and the distance matrix is built
+    by broadcasting instead of a Python double loop.
+    :param pts_A: N*2 array of point coordinates.
+    :param pts_B: M*2 array of point coordinates.
+    :param dx: scale factor applied to both distances (e.g. pixel spacing).
+    :return: (md, hd); (None, None) when either point set is empty.
+    """
+    pts_A = np.asarray(pts_A)
+    pts_B = np.asarray(pts_B)
+    if pts_A.shape[0] == 0 or pts_B.shape[0] == 0:
+        return None, None
+
+    a = pts_A.astype(np.float64)
+    b = pts_B.astype(np.float64)
+    # M[i, j] = ||pts_A[i] - pts_B[j]||, shape (N, M).
+    M = np.linalg.norm(a[:, None, :] - b[None, :, :], axis=-1)
+
+    nearest_A_for_each_B = M.min(axis=0)
+    nearest_B_for_each_A = M.min(axis=1)
+
+    # Mean distance and hausdorff distance
+    md = 0.5 * (np.mean(nearest_A_for_each_B) + np.mean(nearest_B_for_each_A)) * dx
+    hd = max(np.max(nearest_A_for_each_B), np.max(nearest_B_for_each_A)) * dx
+    return md, hd
+
+
+def slice_2D(v_hat_es_cp, slice_num):
+    """
+    Pick the vertices of the first batch element that lie near a given z plane.
+
+    A vertex is kept when |z - slice_num| < 0.3. Its x/y coordinates are
+    rounded to the nearest integer and returned as int16.
+    Only batch index 0 is read.
+    :param v_hat_es_cp: Tensor indexed as [batch, vertex, 0/1/2] = x/y/z.
+    :param slice_num: z coordinate of the plane.
+    :return: K*2 int16 numpy array of (x, y) pairs.
+    """
+    vertices = v_hat_es_cp[0]
+    near_plane = torch.abs(vertices[:, 2] - slice_num) < 0.3
+
+    def _to_pixels(coords):
+        return np.round(coords.detach().cpu().numpy()).astype(np.int16)
+
+    cols_x = _to_pixels(vertices[:, 0][near_plane])
+    cols_y = _to_pixels(vertices[:, 1][near_plane])
+    return np.stack((cols_x, cols_y), 1)
+
+
+def compute_sa_mcd_hd(v_sa_hat_es_cp, contour_sa_es, sliceall):
+    """
+    Mean contour distance and Hausdorff distance averaged over three slices.
+
+    Slices at positions 1, 4 and 7 of `sliceall` are evaluated. For each one
+    the predicted points near that z plane are compared with the non-zero
+    pixels of the matching ground-truth contour, using a fixed scale of 1.25.
+    Slices where either point set is empty are left out of the averages.
+    The selected z values are printed to stdout.
+    :param v_sa_hat_es_cp: vertex Tensor passed on to slice_2D.
+    :param contour_sa_es: array indexed as [slice, :, :]; non-zero = contour pixel.
+    :param sliceall: sequence of z coordinates, one per slice (needs >= 8 entries).
+    :return: (mean mcd, mean hd); each is None when no slice could be evaluated.
+    """
+    slice_number = [1, 4, 7]
+    threeslice = [sliceall[k] for k in slice_number]
+    print (threeslice)
+
+    mcd_values = []
+    hd_values = []
+    for k, z in zip(slice_number, threeslice):
+        pred_pts = slice_2D(v_sa_hat_es_cp, z)
+        gt_pts = np.stack(np.nonzero(contour_sa_es[k, :, :]), 1)
+
+        mcd_sa, hd_sa = distance_metric(pred_pts, gt_pts, 1.25)
+        if mcd_sa is None or hd_sa is None:
+            continue
+        mcd_values.append(mcd_sa)
+        hd_values.append(hd_sa)
+
+    mean_mcd = np.mean(mcd_values) if mcd_values else None
+    mean_hd = np.mean(hd_values) if hd_values else None
+    return mean_mcd, mean_hd
+
+
+
+def FBoundary(pred_contour, gt_contour, bound_th=2):
+    """
+    Boundary F-measure between a predicted and a ground-truth contour map.
+
+    Each contour is dilated with a disk of radius `bound_pix`; a contour pixel
+    counts as matched when it falls inside the other contour's dilation.
+    Precision = matched predicted pixels / predicted pixels,
+    recall = matched ground-truth pixels / ground-truth pixels.
+    :param pred_contour: 2D array, non-zero on predicted contour pixels.
+    :param gt_contour: 2D array, non-zero on ground-truth contour pixels.
+    :param bound_th: tolerance. Values >= 1 are used directly as the disk
+                     radius; values < 1 are a fraction of the norm of the
+                     image shape, rounded up.
+    :return: F score, or None when precision + recall == 0.
+    """
+    if bound_th >= 1:
+        bound_pix = bound_th
+    else:
+        bound_pix = np.ceil(bound_th * np.linalg.norm(pred_contour.shape))
+
+    footprint = disk(bound_pix)
+    pred_dil = binary_dilation(pred_contour, footprint)
+    gt_dil = binary_dilation(gt_contour, footprint)
+
+    # Contour pixels that land inside the other contour's tolerance band.
+    gt_match = gt_contour * pred_dil
+    pred_match = pred_contour * gt_dil
+
+    # Number of contour pixels in each map.
+    n_pred = np.sum(pred_contour)
+    n_gt = np.sum(gt_contour)
+
+    # Precision and recall, with fixed values when one or both maps are empty.
+    if n_pred == 0 and n_gt == 0:
+        precision, recall = 1, 1
+    elif n_pred == 0 and n_gt > 0:
+        precision, recall = 1, 0
+    elif n_pred > 0 and n_gt == 0:
+        precision, recall = 0, 1
+    else:
+        precision = np.sum(pred_match) / float(n_pred)
+        recall = np.sum(gt_match) / float(n_gt)
+
+    total = precision + recall
+    return None if total == 0 else 2 * precision * recall / total
+
+def compute_sa_Fboundary(v_sa_hat_es_cp, contour_sa_es, sliceall, height, width):
+    """
+    Mean boundary F-score over all slices, averaged over tolerances 1..5.
+
+    For every slice the predicted points near its z plane are rasterised into
+    a (height, width) binary image, with the first point column used as the
+    row index, and scored against the ground-truth contour at five tolerances.
+    A slice is skipped when its score at tolerance 1 is None.
+    :param v_sa_hat_es_cp: vertex Tensor passed on to slice_2D.
+    :param contour_sa_es: array indexed as [slice, :, :].
+    :param sliceall: sequence of z coordinates, one per slice.
+    :param height: number of rows of the rasterised prediction.
+    :param width: number of columns of the rasterised prediction.
+    :return: mean score over the evaluated slices, or None if there were none.
+    """
+    tolerances = (1, 2, 3, 4, 5)
+    per_slice_scores = []
+
+    for k, z in enumerate(sliceall):
+        pts = slice_2D(v_sa_hat_es_cp, z)
+        sa_pred = np.zeros(shape=(height, width))
+        sa_pred[pts[:, 0], pts[:, 1]] = 1
+
+        gt_slice = contour_sa_es[k, :, :]
+        scores = [FBoundary(sa_pred, gt_slice, tol) for tol in tolerances]
+
+        # Only the tolerance-1 score decides whether the slice is usable.
+        if scores[0] is not None:
+            per_slice_scores.append(sum(scores) / 5.0)
+
+    return np.mean(per_slice_scores) if per_slice_scores else None
+
+def compute_la_Fboundary(pred_contour, gt_contour):
+    """
+    Boundary F-score of one contour pair, averaged over tolerances 1..5.
+    :param pred_contour: 2D predicted contour map.
+    :param gt_contour: 2D ground-truth contour map.
+    :return: mean of the five FBoundary scores, or None when the score at
+             tolerance 1 is None.
+    """
+    scores = [FBoundary(pred_contour, gt_contour, tol) for tol in (1, 2, 3, 4, 5)]
+
+    # Only the tolerance-1 score decides whether a result is reported.
+    if scores[0] is None:
+        return None
+    return sum(scores) / 5.0
+
+
+
+
+def projection_weightHD_loss_SA(v_sa_hat_ed_cp, temper, height, width, depth, gt_mesh2seg_sa, status):
+    """
+    Weighted-Hausdorff loss averaged over the z slices of a volume.
+
+    For each slice index i the vertices are projected onto plane z = i and the
+    resulting points/weights are compared with gt_mesh2seg_sa[:, :, :, i].
+    :param v_sa_hat_ed_cp: vertex Tensor passed to projection.
+    :param temper: temperature shared by projection and the Hausdorff loss.
+    :param height: passed to weightedHausdorff_batch.
+    :param width: passed to weightedHausdorff_batch.
+    :param depth: slice count; indices 0 .. depth-2 are evaluated.
+    :param gt_mesh2seg_sa: Tensor indexed as [batch, :, :, slice].
+    :param status: 'train' or anything else, passed to weightedHausdorff_batch.
+    :return: scalar Tensor, the mean of the per-slice losses.
+    """
+
+    def _slice_loss(z):
+        points, weights = projection(v_sa_hat_ed_cp, z, temper)
+        return weightedHausdorff_batch(
+            points, weights, gt_mesh2seg_sa[:, :, :, z], height, width, temper, status)
+
+    # NOTE(review): the last slice (index depth-1) is never visited -- confirm
+    # this is intentional.
+    per_slice = [_slice_loss(z) for z in range(depth - 1)]
+    return torch.stack(per_slice).mean()
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+