--- a
+++ b/utils.py
@@ -0,0 +1,474 @@
+""" Utility functions. """
+import numpy as np
+import os
+import random
+import logging
+import tensorflow as tf
+
+from tensorflow.contrib.layers.python import layers as tf_layers
+from tensorflow.python.platform import flags
+import SimpleITK as sitk
+from scipy import ndimage
+import itertools
+from tensorflow.contrib import slim
+from scipy.ndimage import _ni_support
+from scipy.ndimage.morphology import distance_transform_edt, binary_erosion,\
+    generate_binary_structure
+FLAGS = flags.FLAGS
+
+## Image reader
+def get_images(paths, labels, nb_samples=None, shuffle=True):
+    if nb_samples is not None:
+        sampler = lambda x: random.sample(x, nb_samples)
+    else:
+        sampler = lambda x: x
+    images = [(i, os.path.join(path, image)) \
+        for i, path in zip(labels, paths) \
+        for image in sampler(os.listdir(path))]
+    if shuffle:
+        random.shuffle(images)
+    return images
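+
+# Illustrative usage (hypothetical folder layout, not part of the original code):
+#   images = get_images(['data/class_a', 'data/class_b'], labels=[0, 1], nb_samples=5)
+#   # -> shuffled list of (label, filepath) tuples, 5 files sampled per folder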
+ 
+## Loss functions
+def mse(pred, label):
+    pred = tf.reshape(pred, [-1])
+    label = tf.reshape(label, [-1])
+    return tf.reduce_mean(tf.square(pred-label))
+
+def xent(pred, label):
+    return tf.nn.softmax_cross_entropy_with_logits_v2(logits=pred, labels=label)
+
+def kd(data1, label1, data2, label2, bool_indicator, n_class=7, temperature=2.0):
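+    """
+    Symmetric-KL alignment loss between two domains: for each class, the logits of
+    the samples carrying that class label are averaged, softened by `temperature`,
+    and the two resulting class-probability vectors are compared with a symmetric
+    KL divergence; `bool_indicator[cls]` weights (e.g. masks out) each class term.
+    """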
+
+    kd_loss = 0.0
+    eps = 1e-16
+
+    prob1s = []
+    prob2s = []
+
+    for cls in range(n_class):
+        mask1 = tf.tile(tf.expand_dims(label1[:, cls], -1), [1, n_class])
+        logits_sum1 = tf.reduce_sum(tf.multiply(data1, mask1), axis=0)
+        num1 = tf.reduce_sum(label1[:, cls])
+        activations1 = logits_sum1 * 1.0 / (num1 + eps)  # eps prevents an un-sampled class from producing NaN
+        prob1 = tf.nn.softmax(activations1 / temperature)
+        prob1 = tf.clip_by_value(prob1, clip_value_min=1e-8, clip_value_max=1.0)  # clipping prevents prob=0 from producing NaN in the log
+
+        mask2 = tf.tile(tf.expand_dims(label2[:, cls], -1), [1, n_class])
+        logits_sum2 = tf.reduce_sum(tf.multiply(data2, mask2), axis=0)
+        num2 = tf.reduce_sum(label2[:, cls])
+        activations2 = logits_sum2 * 1.0 / (num2 + eps)
+        prob2 = tf.nn.softmax(activations2 / temperature)
+        prob2 = tf.clip_by_value(prob2, clip_value_min=1e-8, clip_value_max=1.0)
+
+        KL_div = (tf.reduce_sum(prob1 * tf.log(prob1 / prob2)) + tf.reduce_sum(prob2 * tf.log(prob2 / prob1))) / 2.0
+        kd_loss += KL_div * bool_indicator[cls]
+
+        prob1s.append(prob1)
+        prob2s.append(prob2)
+
+    kd_loss = kd_loss / n_class
+
+    return kd_loss, prob1s, prob2s
+
+def JS(data1, label1, data2, label2, bool_indicator, n_class=7, temperature=2.0):
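+    """
+    Jensen-Shannon variant of kd(): each class-averaged soft distribution is
+    compared against the mean of the two distributions rather than directly
+    against the other.
+    """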
+
+    kd_loss = 0.0
+    eps = 1e-16
+
+    prob1s = []
+    prob2s = []
+
+    for cls in range(n_class):
+        mask1 = tf.tile(tf.expand_dims(label1[:, cls], -1), [1, n_class])
+        logits_sum1 = tf.reduce_sum(tf.multiply(data1, mask1), axis=0)
+        num1 = tf.reduce_sum(label1[:, cls])
+        activations1 = logits_sum1 * 1.0 / (num1 + eps)  # eps prevents an un-sampled class from producing NaN
+        prob1 = tf.nn.softmax(activations1 / temperature)
+        prob1 = tf.clip_by_value(prob1, clip_value_min=1e-8, clip_value_max=1.0)  # clipping prevents prob=0 from producing NaN in the log
+
+        mask2 = tf.tile(tf.expand_dims(label2[:, cls], -1), [1, n_class])
+        logits_sum2 = tf.reduce_sum(tf.multiply(data2, mask2), axis=0)
+        num2 = tf.reduce_sum(label2[:, cls])
+        activations2 = logits_sum2 * 1.0 / (num2 + eps)
+        prob2 = tf.nn.softmax(activations2 / temperature)
+        prob2 = tf.clip_by_value(prob2, clip_value_min=1e-8, clip_value_max=1.0)
+
+        mean_prob = (prob1 + prob2) / 2
+
+        JS_div = (tf.reduce_sum(prob1 * tf.log(prob1 / mean_prob)) + tf.reduce_sum(prob2 * tf.log(prob2 / mean_prob))) / 2.0
+        kd_loss += JS_div * bool_indicator[cls]
+
+        prob1s.append(prob1)
+        prob2s.append(prob2)
+
+    kd_loss = kd_loss / n_class
+
+    return kd_loss, prob1s, prob2s
+
+def contrastive(feature1, label1, feature2, label2, bool_indicator=None, margin=50):
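+    """
+    Margin-based contrastive loss on feature pairs: samples whose argmax labels
+    match are pulled together (squared distance), mismatched pairs are pushed
+    apart up to `margin`; also returns the mean positive/negative pair distances
+    for debugging.
+    """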
+
+    l1 = tf.argmax(label1, axis=1)
+    l2 = tf.argmax(label2, axis=1)
+    pair = tf.to_float(tf.equal(l1,l2))
+
+    delta = tf.reduce_sum(tf.square(feature1-feature2), 1) + 1e-10
+    match_loss = delta
+
+    delta_sqrt = tf.sqrt(delta + 1e-10)
+    mismatch_loss = tf.square(tf.nn.relu(margin - delta_sqrt))
+
+    if bool_indicator is None:
+        loss = tf.reduce_mean(0.5 * (pair * match_loss + (1-pair) * mismatch_loss))
+    else:
+        loss = 0.5 * tf.reduce_sum(match_loss*pair)/tf.reduce_sum(pair)
+
+    debug_dist_positive = tf.reduce_sum(delta_sqrt * pair)/tf.reduce_sum(pair)
+    debug_dist_negative = tf.reduce_sum(delta_sqrt * (1-pair))/tf.reduce_sum(1-pair)
+
+    return loss, pair, delta, debug_dist_positive, debug_dist_negative
+
+def compute_distance(feature1, label1, feature2, label2):
+    l1 = tf.argmax(label1, axis=1)
+    l2 = tf.argmax(label2, axis=1)
+    pair = tf.to_float(tf.equal(l1,l2))
+
+    delta = tf.reduce_sum(tf.square(feature1-feature2), 1)
+    delta_sqrt = tf.sqrt(delta + 1e-16)
+
+    dist_positive_pair = tf.reduce_sum(delta_sqrt * pair)/tf.reduce_sum(pair)
+    dist_negative_pair = tf.reduce_sum(delta_sqrt * (1-pair))/tf.reduce_sum(1-pair)
+
+    return dist_positive_pair, dist_negative_pair
+
+def _get_segmentation_cost(softmaxpred, seg_gt, n_class=2):
+    """
+    calculate the loss for segmentation prediction
+    :param seg_logits: probability segmentation from the segmentation network
+    :param seg_gt: ground truth segmentaiton mask
+    :return: segmentation loss, according to the cost_kwards setting, cross-entropy weighted loss and dice loss
+    """
+    dice = 0
+
+    for i in range(n_class):
+        inse = tf.reduce_sum(softmaxpred[:, :, :, i]*seg_gt[:, :, :, i])
+        l = tf.reduce_sum(softmaxpred[:, :, :, i])
+        r = tf.reduce_sum(seg_gt[:, :, :, i])
+        dice += 2.0 * inse/(l+r+1e-7) # here 1e-7 is relaxation eps
+    dice_loss = 1 - 1.0 * dice / n_class
+
+    # ce_weighted = 0
+    # for i in xrange(n_class):
+    #     gti = seg_gt[:,:,:,i]
+    #     predi = softmaxpred[:,:,:,i]
+    #     ce_weighted += -1.0 * gti * tf.log(tf.clip_by_value(predi, 0.005, 1))
+    # ce_weighted_loss = tf.reduce_mean(ce_weighted)
+
+    # total_loss =  dice_loss 
+
+
+    return dice_loss
+
+def _get_compactness_cost(y_pred, y_true): 
+
+    """
+    y_pred: BxHxWxC
+    """
+    """
+    lenth term
+    """
+
+    # y_pred = tf.one_hot(y_pred, depth=2)
+    # print (y_true.shape)
+    # print (y_pred.shape)
+    y_pred = y_pred[..., 1]
+    y_true = y_pred[..., 1]
+
+    x = y_pred[:,1:,:] - y_pred[:,:-1,:] # horizontal and vertical directions 
+    y = y_pred[:,:,1:] - y_pred[:,:,:-1]
+
+    delta_x = x[:,:,1:]**2
+    delta_y = y[:,1:,:]**2
+
+    delta_u = tf.abs(delta_x + delta_y) 
+
+    epsilon = 1e-8  # small constant so the square root never sees exactly zero
+    w = 0.01
+    length = w * tf.reduce_sum(tf.sqrt(delta_u + epsilon), [1, 2])
+
+    area = tf.reduce_sum(y_pred, [1,2])
+
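+    # Up to the weight w, length**2 / (4 * pi * area) is the inverse isoperimetric
+    # quotient: it is minimal for a disc, so this term favours round, compact masks.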
+    compactness_loss = tf.reduce_sum(length ** 2 / (area * 4 * 3.1415926))
+
+    return compactness_loss, tf.reduce_sum(length), tf.reduce_sum(area), delta_u
+
+# def _get_sample_masf(y_true):
+#     """
+#     y_pred: BxHxWx2
+#     """
+#     positive_mask = np.expand_dims(y_true[..., 1], axis=3)
+#     metrix_label_group = np.expand_dims(np.array([1, 0, 1, 1, 0]), axis = 1)
+#     # print (positive_mask.shape)
+#     coutour_group = np.zeros(positive_mask.shape)
+
+#     for i in range(positive_mask.shape[0]):
+#         slice_i = positive_mask[i]
+        
+#         if metrix_label_group[i] == 1:
+#             sample = (slice_i == 1)
+#         elif metrix_label_group[i] == 0:
+#             sample = (slice_i == 0)
+
+#         coutour_group[i] = sample
+
+#     return coutour_group, metrix_label_group
+
+def _get_coutour_sample(y_true):
+    """
+    y_true: BxHxWx2
+    """
+    positive_mask = np.expand_dims(y_true[..., 1], axis=3)
+    # hard-coded grouping for a batch of 5 slices: 1 -> extract the contour of the
+    # foreground, 0 -> extract a background ring around the foreground
+    metrix_label_group = np.expand_dims(np.array([1, 0, 1, 1, 0]), axis=1)
+    coutour_group = np.zeros(positive_mask.shape)
+
+    for i in range(positive_mask.shape[0]):
+        slice_i = positive_mask[i]
+
+        if metrix_label_group[i] == 1:
+            # generate contour mask (one-pixel boundary of the foreground)
+            erosion = ndimage.binary_erosion(slice_i[..., 0], iterations=1).astype(slice_i.dtype)
+            sample = np.expand_dims(slice_i[..., 0] - erosion, axis=2)
+
+        elif metrix_label_group[i] == 0:
+            # generate background mask (ring just outside the foreground)
+            dilation = ndimage.binary_dilation(slice_i, iterations=5).astype(slice_i.dtype)
+            sample = dilation - slice_i
+
+        coutour_group[i] = sample
+    return coutour_group, metrix_label_group
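+
+# Illustrative shapes (not from the original training code): for a one-hot label
+# batch y of shape (5, H, W, 2),
+#   contour_masks, groups = _get_coutour_sample(y)
+# returns contour_masks of shape (5, H, W, 1) and groups of shape (5, 1) in {0, 1}.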
+
+def _get_boundary_cost(y_pred, y_true): 
+
+    """
+    y_pred: BxHxWxC
+    """
+    """
+    lenth term
+    """
+
+    # y_pred = tf.one_hot(y_pred, depth=2)
+    # print (y_true.shape)
+    # print (y_pred.shape)
+    y_pred = y_pred[..., 1]
+    y_true = y_pred[..., 1]
+
+    x = y_pred[:,1:,:] - y_pred[:,:-1,:] # horizontal and vertical directions 
+    y = y_pred[:,:,1:] - y_pred[:,:,:-1]
+
+    delta_x = x[:,:,1:]**2
+    delta_y = y[:,1:,:]**2
+
+    delta_u = tf.abs(delta_x + delta_y) 
+
+    epsilon = 1e-8  # small constant so the square root never sees exactly zero
+    w = 0.01
+    length = w * tf.reduce_sum(tf.sqrt(delta_u + epsilon), [1, 2]) # equ.(11) in the paper
+
+    area = tf.reduce_sum(y_pred, [1,2])
+
+    compactness_loss = tf.reduce_sum(length ** 2 / (area * 4 * 3.1415926))
+
+    return compactness_loss, tf.reduce_sum(length), tf.reduce_sum(area)
+
+def check_folder(log_dir):
+    if not os.path.exists(log_dir):
+        print ("Allocating '{:}'".format(log_dir))
+        os.makedirs(log_dir)
+    return log_dir
+
+def _eval_dice(gt_y, pred_y, detail=False):
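+    """
+    Per-class Dice: 2*sum(gt*pred) / (sum(gt)+sum(pred)); identical masks give 1.0,
+    disjoint masks give 0.0. Only the foreground class (label 1) is evaluated below.
+    """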
+
+    class_map = {  # a map used for mapping label value to its name, used for output
+        "0": "bg",
+        "1": "CZ",
+        "2": "prostate"
+    }
+
+    dice = []
+
+    for cls in range(1, 2):
+
+        gt = np.zeros(gt_y.shape)
+        pred = np.zeros(pred_y.shape)
+
+        gt[gt_y == cls] = 1
+        pred[pred_y == cls] = 1
+
+        dice_this = 2*np.sum(gt*pred)/(np.sum(gt)+np.sum(pred))
+        dice.append(dice_this)
+
+        if detail:
+            logging.info("class {}, dice is {:.2f}".format(class_map[str(cls)], dice_this))
+    return dice
+
+def __surface_distances(result, reference, voxelspacing=None, connectivity=1):
+    """
+    The distances between the surface voxel of binary objects in result and their
+    nearest partner surface voxel of a binary object in reference.
+    """
+    result = np.atleast_1d(result.astype(bool))
+    reference = np.atleast_1d(reference.astype(bool))
+    if voxelspacing is not None:
+        voxelspacing = _ni_support._normalize_sequence(voxelspacing, result.ndim)
+        voxelspacing = np.asarray(voxelspacing, dtype=np.float64)
+        if not voxelspacing.flags.contiguous:
+            voxelspacing = voxelspacing.copy()
+            
+    # binary structure
+    footprint = generate_binary_structure(result.ndim, connectivity)
+    
+    # test for emptiness
+    if 0 == np.count_nonzero(result): 
+        raise RuntimeError('The first supplied array does not contain any binary object.')
+    if 0 == np.count_nonzero(reference): 
+        raise RuntimeError('The second supplied array does not contain any binary object.')    
+            
+    # extract only 1-pixel border line of objects
+    result_border = result ^ binary_erosion(result, structure=footprint, iterations=1)
+    reference_border = reference ^ binary_erosion(reference, structure=footprint, iterations=1)
+    
+    # compute average surface distance        
+    # Note: scipy's distance transform is calculated only inside the borders of the
+    #       foreground objects, therefore the reference border has to be inverted
+    dt = distance_transform_edt(~reference_border, sampling=voxelspacing)
+    sds = dt[result_border]
+    
+    return sds
+
+def asd(result, reference, voxelspacing=None, connectivity=1):
+  
+    sds = __surface_distances(result, reference, voxelspacing, connectivity)
+    asd = sds.mean()
+    return asd
+
+def calculate_hausdorff(lP,lT,spacing):
+
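+    # Note: despite its name, this wraps asd() and returns the average surface
+    # distance, not the Hausdorff distance.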
+    return asd(lP, lT, spacing)
+
+def _eval_haus(pred, gt, spacing, detail=False):
+    '''
+    :param pred: whole-volume prediction
+    :param gt: whole-volume ground truth
+    :param detail: log per-class results if True
+    :return: a list, indicating the surface distance of each class for one case
+    '''
+    haus = []
+
+    for cls in range(1,2):
+        pred_i = np.zeros(pred.shape)
+        pred_i[pred == cls] = 1
+        gt_i = np.zeros(gt.shape)
+        gt_i[gt == cls] = 1
+
+        # hausdorff_distance_filter = sitk.HausdorffDistanceImageFilter()
+        # hausdorff_distance_filter.Execute(gt_i, pred_i)
+
+        haus_cls = calculate_hausdorff(gt_i, pred_i, spacing)
+
+        haus.append(haus_cls)
+
+        if detail:
+            logging.info("class {}, haus is {:.4f}".format(cls, haus_cls))
+    # logging.info("4 class average haus is {:4f}".format(np.mean(haus)))
+
+    return haus
+
+def _connectivity_region_analysis(mask):
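+    """
+    Keep only the largest connected component of a binary mask: label the
+    components, sum the mask over each label, and zero out every other label.
+    """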
+    s = [[0,1,0],
+         [1,1,1],
+         [0,1,0]]
+    label_im, nb_labels = ndimage.label(mask)#, structure=s)
+
+    sizes = ndimage.sum(mask, label_im, range(nb_labels + 1))
+
+    # plt.imshow(label_im)        
+    label_im[label_im != np.argmax(sizes)] = 0
+    label_im[label_im == np.argmax(sizes)] = 1
+
+    return label_im
+
+def _crop_object_region(mask, prediction):
+
+    limX, limY, limZ = np.where(mask > 0)
+    min_z = np.min(limZ)
+    max_z = np.max(limZ)
+
+    # zero out predicted slices outside the axial extent of the mask
+    prediction[..., :min_z] = 0
+    prediction[..., max_z+1:] = 0
+
+    return prediction
+
+def parse_fn(data_path):
+    '''
+    :param data_path: "image_path,label_path" string for one patient
+    :return: normalized entire image with its corresponding label
+    Image-level normalization (z-score over the region selected by binary_mask) is done here.
+    '''
+    path = data_path.split(",")
+    image_path = path[0]
+    label_path = path[1]
+    #itk_image = zoom2shape(image_path, [512,512])#os.path.join(image_path, 'T1_unbiased_brain_rigid_to_mni.nii.gz'))
+    #itk_mask = zoom2shape(label_path, [512,512], label=True)#os.path.join(image_path, 'T1_brain_seg_rigid_to_mni.nii.gz'))
+    itk_image = sitk.ReadImage(image_path)#os.path.join(image_path, 'T1_unbiased_brain_rigid_to_mni.nii.gz'))
+    itk_mask = sitk.ReadImage(label_path)#os.path.join(image_path, 'T1_brain_seg_rigid_to_mni.nii.gz'))
+    # itk_image = sitk.ReadImage(os.path.join(image_path, 'T2_FLAIR_unbiased_brain_rigid_to_mni.nii.gz'))
+
+    image = sitk.GetArrayFromImage(itk_image)
+    mask = sitk.GetArrayFromImage(itk_mask)
+    #image[image >= 1000] = 1000
+    binary_mask = np.ones(mask.shape)  # use the whole image for the statistics
+    mean = np.sum(image * binary_mask) / np.sum(binary_mask)
+    std = np.sqrt(np.sum(np.square(image - mean) * binary_mask) / np.sum(binary_mask))
+    image = (image - mean) / std  # z-score normalize per image over the masked region
+
+    mask[mask==2] = 1
+
+    return image.transpose([1,2,0]), mask.transpose([1,2,0]) # transpose from (D, H, W) to (H, W, D)
+
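+# Illustrative call (hypothetical file names, not part of the original code):
+#   image, mask = parse_fn("case01_img.nii.gz,case01_seg.nii.gz")
+#   # image, mask: (H, W, D) arrays; image is z-score normalized, mask is binary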
+
+def parse_fn_haus(data_path):
+    '''
+    :param data_path: "image_path,label_path" string for one patient
+    :return: normalized entire image, its corresponding label, and the voxel spacing
+    Image-level normalization (z-score over the region selected by binary_mask) is done here.
+    '''
+    path = data_path.split(",")
+    image_path = path[0]
+    label_path = path[1]
+    #itk_image = zoom2shape(image_path, [512,512])#os.path.join(image_path, 'T1_unbiased_brain_rigid_to_mni.nii.gz'))
+    #itk_mask = zoom2shape(label_path, [512,512], label=True)#os.path.join(image_path, 'T1_brain_seg_rigid_to_mni.nii.gz'))
+    itk_image = sitk.ReadImage(image_path)#os.path.join(image_path, 'T1_unbiased_brain_rigid_to_mni.nii.gz'))
+    itk_mask = sitk.ReadImage(label_path)#os.path.join(image_path, 'T1_brain_seg_rigid_to_mni.nii.gz'))
+    # itk_image = sitk.ReadImage(os.path.join(image_path, 'T2_FLAIR_unbiased_brain_rigid_to_mni.nii.gz'))
+    spacing = itk_mask.GetSpacing()
+
+    image = sitk.GetArrayFromImage(itk_image)
+    mask = sitk.GetArrayFromImage(itk_mask)
+    #image[image >= 1000] = 1000
+    binary_mask = np.ones(mask.shape)  # use the whole image for the statistics
+    mean = np.sum(image * binary_mask) / np.sum(binary_mask)
+    std = np.sqrt(np.sum(np.square(image - mean) * binary_mask) / np.sum(binary_mask))
+    image = (image - mean) / std  # z-score normalize per image over the masked region
+
+    mask[mask==2] = 1
+
+    return image.transpose([1,2,0]), mask.transpose([1,2,0]), spacing
+
+def show_all_variables():
+    model_vars = tf.trainable_variables()
+    slim.model_analyzer.analyze_vars(model_vars, print_info=True)
+