--- a
+++ b/Segmentation/model_keras.py
@@ -0,0 +1,1427 @@
+"""
+A set of functions defining segmentation models and code for evaluating networks.
+"""
+import os
+import numpy as np
+import tensorflow as tf
+import keras
+import nibabel as nib
+import nrrd
+import itertools
+import json
+from skimage.transform import resize
+from skimage import measure
+from scipy.ndimage import binary_closing
+import matplotlib.pyplot as plt
+from shutil import copy
+from keras.wrappers.scikit_learn import KerasClassifier
+from sklearn.metrics import precision_score, recall_score, precision_recall_curve, \
+    f1_score, roc_curve, roc_auc_score, jaccard_similarity_score
+
+
+def cnn_model_3D_3lyr_relu_dice(pretrained_weights=None, input_size=(25, 25, 25, 4), lr=1e-5):
+    """3D CNN for patch-based segmentation (3 conv layers, ReLU), trained with a Dice loss."""
+    padding = 'valid'
+    # Constants
+    chan_in = 4 #features.shape[4]
+    in_shape = input_size #(25, 25, 25, 4)
+    f1 = 24
+    f2 = 48
+    f3 = 64
+    f4 = 64
+    f5 = 48
+    f6 = 24
+    kernel_size = (3, 3, 3)
+    pool_size = (2, 2, 2)
+    strides = (1, 1, 1)
+
+
+    with tf.device('/GPU:0'):
+        # Input layer
+        input_layer = keras.layers.Input(shape=in_shape)
+        drop1 = keras.layers.Dropout(0.1)(input_layer)
+
+        # Convolutional layers
+        conv1 = keras.layers.Conv3D(filters=f1, kernel_size=kernel_size, strides=strides, padding=padding, activation='relu', name='Conv1')(drop1)
+        pool1 = keras.layers.MaxPool3D(pool_size=pool_size, strides=pool_size, data_format='channels_last', name='Pool1')(conv1)
+
+
+        conv2 = keras.layers.Conv3D(filters=f2, kernel_size=kernel_size, strides=strides, padding=padding, activation='relu', name='Conv2')(pool1)
+        pool2 = keras.layers.MaxPool3D(pool_size=pool_size, strides=pool_size, data_format='channels_last', name='Pool2')(conv2)
+
+        conv3 = keras.layers.Conv3D(filters=f3, kernel_size=kernel_size, strides=strides, padding=padding, activation='relu', name='Conv3')(pool2)
+
+        # Deconvolutional layers
+        dconv1 = keras.layers.Deconvolution3D(filters=f4, kernel_size=kernel_size, strides=strides, padding=padding, activation='relu', name='Dconv1')(conv3)
+        unpool1 = keras.layers.UpSampling3D(size=pool_size)(dconv1)
+
+        dconv2 = keras.layers.Conv3DTranspose(filters=f5, kernel_size=kernel_size, strides=strides, padding=padding, activation='relu', name='Dconv2')(unpool1)
+        unpool2 = keras.layers.UpSampling3D(size=pool_size)(dconv2)
+
+        dconv3 = keras.layers.Conv3DTranspose(filters=f6, kernel_size=kernel_size, strides=strides, padding=padding, activation='relu', name='Dconv3')(unpool2)
+
+        output = keras.layers.Conv3D(filters=1, kernel_size=kernel_size, padding=padding, activation='sigmoid')(dconv3)
+
+        model = keras.models.Model(inputs=input_layer, outputs=output)
+
+        opt = keras.optimizers.Adam(lr=lr)
+        model.compile(optimizer=opt, loss=dice_loss, metrics=[keras.metrics.binary_accuracy, dice_metric])
+
+    model.summary()
+
+    if pretrained_weights:
+        model.load_weights(pretrained_weights)
+
+    return model, opt
+
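+# Illustrative usage sketch for the model constructors in this module (the
+# learning rate, batch size, and random data below are placeholder assumptions;
+# because the convolutions use 'valid' padding the output patch is smaller than
+# the input, so the label shape is read from the compiled model rather than
+# hard-coded):
+#
+#   model, opt = cnn_model_3D_3lyr_relu_dice(lr=1e-4)
+#   X = np.random.rand(8, 25, 25, 25, 4).astype('float32')
+#   Y = (np.random.rand(8, *model.output_shape[1:]) > 0.5).astype('float32')
+#   model.fit(X, Y, batch_size=4, epochs=1)
+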
+def cnn_model_3D_3lyr_do_relu_dice(pretrained_weights=None, input_size=(25, 25, 25, 4), lr=1e-5):
+    """3D CNN with spatial dropout (3 conv layers, ReLU), trained with a Dice loss."""
+    padding = 'valid'
+    # Constants
+    chan_in = 4 #features.shape[4]
+    in_shape = input_size #(25, 25, 25, 4)
+    f1 = 24
+    f2 = 48
+    f3 = 64
+    f4 = 64
+    f5 = 48
+    f6 = 24
+    kernel_size = (3, 3, 3)
+    pool_size = (2, 2, 2)
+    strides = (1, 1, 1)
+
+
+    with tf.device('/GPU:0'):
+        # Input layer
+        input_layer = keras.layers.Input(shape=in_shape)
+        drop0 = keras.layers.SpatialDropout3D(0.1)(input_layer)
+
+        # Convolutional layers
+        conv1 = keras.layers.Conv3D(filters=f1, kernel_size=kernel_size, strides=strides, padding=padding, activation='relu', name='Conv1')(drop0)
+        drop1 = keras.layers.SpatialDropout3D(0.1)(conv1)
+        pool1 = keras.layers.MaxPool3D(pool_size=pool_size, strides=pool_size, data_format='channels_last', name='Pool1')(drop1)
+
+
+        conv2 = keras.layers.Conv3D(filters=f2, kernel_size=kernel_size, strides=strides, padding=padding, activation='relu', name='Conv2')(pool1)
+        drop2 = keras.layers.SpatialDropout3D(0.1)(conv2)
+        pool2 = keras.layers.MaxPool3D(pool_size=pool_size, strides=pool_size, data_format='channels_last', name='Pool2')(drop2)
+
+        conv3 = keras.layers.Conv3D(filters=f3, kernel_size=kernel_size, strides=strides, padding=padding, activation='relu', name='Conv3')(pool2)
+        drop3 = keras.layers.SpatialDropout3D(0.1)(conv3)
+
+        # Deconvolutional layers
+        dconv1 = keras.layers.Deconvolution3D(filters=f4, kernel_size=kernel_size, strides=strides, padding=padding, activation='relu', name='Dconv1')(drop3)
+        drop3 = keras.layers.SpatialDropout3D(0.1)(dconv1)
+        unpool1 = keras.layers.UpSampling3D(size=pool_size)(drop3)
+
+        dconv2 = keras.layers.Conv3DTranspose(filters=f5, kernel_size=kernel_size, strides=strides, padding=padding, activation='relu', name='Dconv2')(unpool1)
+        drop4 = keras.layers.SpatialDropout3D(0.1)(dconv2)
+        unpool2 = keras.layers.UpSampling3D(size=pool_size)(drop4)
+
+        dconv3 = keras.layers.Conv3DTranspose(filters=f6, kernel_size=kernel_size, strides=strides, padding=padding, activation='relu', name='Dconv3')(unpool2)
+
+        output = keras.layers.Conv3D(filters=1, kernel_size=kernel_size, padding=padding, activation='sigmoid')(dconv3)
+
+        model = keras.models.Model(inputs=input_layer, outputs=output)
+
+        opt = keras.optimizers.Adam(lr=lr)
+        model.compile(optimizer=opt, loss=dice_loss, metrics=['accuracy'])
+
+    model.summary()
+
+    if pretrained_weights:
+        model.load_weights(pretrained_weights)
+
+    return model, opt
+
+def cnn_model_3D_3lyr_do_relu_dice_skip(pretrained_weights=None, input_size=(25, 25, 25, 4), lr=1e-5):
+    """3D CNN with spatial dropout and skip connections (3 conv layers, ReLU), Dice loss.
+
+    Note: with 'valid' convolutions the skip concatenations require patch sizes whose
+    encoder/decoder feature maps match (e.g. a 26**3 input lines up; the 25**3 default
+    produces mismatched shapes).
+    """
+    padding = 'valid'
+    # Constants
+    chan_in = 4 #features.shape[4]
+    in_shape = input_size #(25, 25, 25, 4)
+    f1 = 24
+    f2 = 48
+    f3 = 64
+    f4 = 64
+    f5 = 48
+    f6 = 24
+    kernel_size = (3, 3, 3)
+    pool_size = (2, 2, 2)
+    strides = (1, 1, 1)
+
+
+    with tf.device('/GPU:0'):
+        # Input layer
+        input_layer = keras.layers.Input(shape=in_shape)
+        drop0 = keras.layers.SpatialDropout3D(0.1)(input_layer)
+
+        # Convolutional layers
+        conv1 = keras.layers.Conv3D(filters=f1, kernel_size=kernel_size, strides=strides, padding=padding, activation='relu', name='Conv1')(drop0)
+        drop1 = keras.layers.SpatialDropout3D(0.1)(conv1)
+        pool1 = keras.layers.MaxPool3D(pool_size=pool_size, strides=pool_size, data_format='channels_last', name='Pool1')(drop1)
+
+
+        conv2 = keras.layers.Conv3D(filters=f2, kernel_size=kernel_size, strides=strides, padding=padding, activation='relu', name='Conv2')(pool1)
+        drop2 = keras.layers.SpatialDropout3D(0.1)(conv2)
+        pool2 = keras.layers.MaxPool3D(pool_size=pool_size, strides=pool_size, data_format='channels_last', name='Pool2')(drop2)
+
+        conv3 = keras.layers.Conv3D(filters=f3, kernel_size=kernel_size, strides=strides, padding=padding, activation='relu', name='Conv3')(pool2)
+        drop3 = keras.layers.SpatialDropout3D(0.1)(conv3)
+
+        # Deconvolutional layers
+        dconv1 = keras.layers.Deconvolution3D(filters=f4, kernel_size=kernel_size, strides=strides, padding=padding, activation='relu', name='Dconv1')(drop3)
+        drop3 = keras.layers.SpatialDropout3D(0.1)(dconv1)
+        unpool1 = keras.layers.UpSampling3D(size=pool_size)(drop3)
+
+        cat1 = keras.layers.concatenate([drop2, unpool1])
+        dconv2 = keras.layers.Conv3DTranspose(filters=f5,
+                                              kernel_size=kernel_size,
+                                              strides=strides,
+                                              padding=padding,
+                                              activation='relu',
+                                              name='Dconv2')(cat1)
+        drop4 = keras.layers.SpatialDropout3D(0.1)(dconv2)
+        unpool2 = keras.layers.UpSampling3D(size=pool_size)(drop4)
+
+        cat2 = keras.layers.concatenate([drop1, unpool2])
+        dconv3 = keras.layers.Conv3DTranspose(filters=f6,
+                                              kernel_size=kernel_size,
+                                              strides=strides,
+                                              padding=padding,
+                                              activation='relu',
+                                              name='Dconv3')(cat2)
+
+        output = keras.layers.Conv3D(filters=1, kernel_size=kernel_size, padding=padding, activation='sigmoid')(dconv3)
+
+        model = keras.models.Model(inputs=input_layer, outputs=output)
+
+        opt = keras.optimizers.Adam(lr=lr)
+        model.compile(optimizer=opt, loss=dice_loss, metrics=[keras.metrics.binary_accuracy, dice_metric])
+
+    model.summary()
+
+    if pretrained_weights:
+        model.load_weights(pretrained_weights)
+
+    return model, opt
+
+def cnn_model_3D_3lyr_do_relu_xentropy_skip(pretrained_weights=None, input_size=(25, 25, 25, 4), lr=1e-5):
+    """3D CNN with spatial dropout and skip connections (3 conv layers, ReLU), binary cross-entropy loss.
+
+    Note: with 'valid' convolutions the skip concatenations require patch sizes whose
+    encoder/decoder feature maps match (e.g. a 26**3 input lines up; the 25**3 default
+    produces mismatched shapes).
+    """
+    padding = 'valid'
+    # Constants
+    chan_in = 4 #features.shape[4]
+    in_shape = input_size #(25, 25, 25, 4)
+    f1 = 24
+    f2 = 48
+    f3 = 64
+    f4 = 64
+    f5 = 48
+    f6 = 24
+    kernel_size = (3, 3, 3)
+    pool_size = (2, 2, 2)
+    strides = (1, 1, 1)
+
+
+    with tf.device('/GPU:0'):
+        # Input layer
+        input_layer = keras.layers.Input(shape=in_shape)
+        drop0 = keras.layers.SpatialDropout3D(0.1)(input_layer)
+
+        # Convolutional layers
+        conv1 = keras.layers.Conv3D(filters=f1, kernel_size=kernel_size, strides=strides, padding=padding, activation='relu', name='Conv1')(drop0)
+        drop1 = keras.layers.SpatialDropout3D(0.1)(conv1)
+        pool1 = keras.layers.MaxPool3D(pool_size=pool_size, strides=pool_size, data_format='channels_last', name='Pool1')(drop1)
+
+
+        conv2 = keras.layers.Conv3D(filters=f2, kernel_size=kernel_size, strides=strides, padding=padding, activation='relu', name='Conv2')(pool1)
+        drop2 = keras.layers.SpatialDropout3D(0.1)(conv2)
+        pool2 = keras.layers.MaxPool3D(pool_size=pool_size, strides=pool_size, data_format='channels_last', name='Pool2')(drop2)
+
+        conv3 = keras.layers.Conv3D(filters=f3, kernel_size=kernel_size, strides=strides, padding=padding, activation='relu', name='Conv3')(pool2)
+        drop3 = keras.layers.SpatialDropout3D(0.1)(conv3)
+
+        # Deconvolutional layers
+        dconv1 = keras.layers.Deconvolution3D(filters=f4, kernel_size=kernel_size, strides=strides, padding=padding, activation='relu', name='Dconv1')(drop3)
+        drop3 = keras.layers.SpatialDropout3D(0.1)(dconv1)
+        unpool1 = keras.layers.UpSampling3D(size=pool_size)(drop3)
+
+        cat1 = keras.layers.concatenate([drop2, unpool1])
+        dconv2 = keras.layers.Conv3DTranspose(filters=f5,
+                                              kernel_size=kernel_size,
+                                              strides=strides,
+                                              padding=padding,
+                                              activation='relu',
+                                              name='Dconv2')(cat1)
+        drop4 = keras.layers.SpatialDropout3D(0.1)(dconv2)
+        unpool2 = keras.layers.UpSampling3D(size=pool_size)(drop4)
+
+        cat2 = keras.layers.concatenate([drop1, unpool2])
+        dconv3 = keras.layers.Conv3DTranspose(filters=f6,
+                                              kernel_size=kernel_size,
+                                              strides=strides,
+                                              padding=padding,
+                                              activation='relu',
+                                              name='Dconv3')(cat2)
+
+        output = keras.layers.Conv3D(filters=1, kernel_size=kernel_size, padding=padding, activation='sigmoid')(dconv3)
+
+        model = keras.models.Model(inputs=input_layer, outputs=output)
+
+        opt = keras.optimizers.Adam(lr=lr)
+        model.compile(optimizer=opt, loss='binary_crossentropy', metrics=[keras.metrics.binary_accuracy, dice_metric])
+
+    model.summary()
+
+    if pretrained_weights:
+        model.load_weights(pretrained_weights)
+
+    return model, opt
+
+def cnn_model_3D_3lyr_do_relu_xentropy(pretrained_weights=None, input_size=(25, 25, 25, 4), lr=1e-5):
+    """3D CNN with dropout (3 conv layers, ReLU), trained with binary cross-entropy loss."""
+    padding = 'valid'
+    # Constants
+    chan_in = 4 #features.shape[4]
+    in_shape = input_size #(25, 25, 25, 4)
+    f1 = 24
+    f2 = 48
+    f3 = 64
+    f4 = 64
+    f5 = 48
+    f6 = 24
+    kernel_size = (3, 3, 3)
+    pool_size = (2, 2, 2)
+    strides = (1, 1, 1)
+
+
+
+    with tf.device('/GPU:0'):
+        # Input layer
+        input_layer = keras.layers.Input(shape=in_shape)
+        drop0 = keras.layers.Dropout(0.1)(input_layer)
+
+        # Convolutional layers
+        conv1 = keras.layers.Conv3D(filters=f1, kernel_size=kernel_size, strides=strides, padding=padding, activation='relu', name='Conv1')(drop0)
+        drop1 = keras.layers.Dropout(0.1)(conv1)
+        pool1 = keras.layers.MaxPool3D(pool_size=pool_size, strides=pool_size, data_format='channels_last', name='Pool1')(drop1)
+
+
+        conv2 = keras.layers.Conv3D(filters=f2, kernel_size=kernel_size, strides=strides, padding=padding, activation='relu', name='Conv2')(pool1)
+        drop2 = keras.layers.Dropout(0.1)(conv2)
+        pool2 = keras.layers.MaxPool3D(pool_size=pool_size, strides=pool_size, data_format='channels_last', name='Pool2')(drop2)
+
+        conv3 = keras.layers.Conv3D(filters=f3, kernel_size=kernel_size, strides=strides, padding=padding, activation='relu', name='Conv3')(pool2)
+        drop3 = keras.layers.Dropout(0.1)(conv3)
+
+        # Deconvolutional layers
+        dconv1 = keras.layers.Deconvolution3D(filters=f4, kernel_size=kernel_size, strides=strides, padding=padding, activation='relu', name='Dconv1')(drop3)
+        drop3 = keras.layers.Dropout(0.1)(dconv1)
+        unpool1 = keras.layers.UpSampling3D(size=pool_size)(drop3)
+
+        dconv2 = keras.layers.Conv3DTranspose(filters=f5, kernel_size=kernel_size, strides=strides, padding=padding, activation='relu', name='Dconv2')(unpool1)
+        drop4 = keras.layers.Dropout(0.1)(dconv2)
+        unpool2 = keras.layers.UpSampling3D(size=pool_size)(drop4)
+
+        dconv3 = keras.layers.Conv3DTranspose(filters=f6, kernel_size=kernel_size, strides=strides, padding=padding, activation='relu', name='Dconv3')(unpool2)
+
+        output = keras.layers.Conv3D(filters=1, kernel_size=kernel_size, padding=padding, activation='sigmoid')(dconv3)
+
+        model = keras.models.Model(inputs=input_layer, outputs=output)
+
+        opt = keras.optimizers.Adam(lr=lr)
+        model.compile(optimizer=opt, loss='binary_crossentropy', metrics=[keras.metrics.binary_accuracy, dice_metric])
+
+    model.summary()
+
+    if pretrained_weights:
+        model.load_weights(pretrained_weights)
+
+    return model, opt
+
+
+def multi_gpu_model(model, opt, loss, gpus):
+    parallel_model = keras.utils.multi_gpu_model(model, gpus=gpus)
+    parallel_model.compile(loss=loss, optimizer=opt, metrics=['accuracy'])
+    return parallel_model
+
+
+def load_data_3D(files):
+    images = np.array([nib.load(file).get_data().astype('float32').squeeze() for file in files])
+
+    images = np.swapaxes(images, 0, 3)
+
+    # Normalize images
+    sz = images.shape
+    for z in range(sz[3]):
+        mn = images[:, :, :, z].mean()
+        std = images[:, :, :, z].std()
+        images[:, :, :, z] -= mn
+        images[:, :, :, z] /= std
+
+    return images
+
+
+def load_label_3D(files, sz):
+    # Load image
+    if '.nii' in files:
+        # If the image is Nifti format
+        label = nib.load(files).get_data().astype('float32').squeeze()
+
+        # Resize image
+        label /= label.max()
+        label = resize(label, (sz[1], sz[2]))
+
+        # Make binary
+        label[label < 0.5] = 0
+        label[label >= 0.5] = 1
+
+        # Close the mask
+        label = binary_closing(label, structure=np.ones(shape=(3, 3, 3)))
+
+    else:
+        # If the image is nrrd format
+        _nrrd = nrrd.read(files)
+        tmpy = _nrrd[0][0, :, :, :]
+        info = _nrrd[1]
+
+        # Swap axes to match images
+        # tmpy = np.rollaxis(tmpy, axis=2)
+        tmpy = np.swapaxes(tmpy, 0, 2)
+
+        # Pad label to match images
+        offsets = info['space origin'].astype(int)
+        offsets_inds = [2, 1, 0]
+
+        # Z direction
+        offset = offsets[offsets_inds[0]]
+        pad = sz[0] - tmpy.shape[0] - offset
+        tmpy = np.concatenate((np.zeros((offset, tmpy.shape[1], tmpy.shape[2])), tmpy), axis=0)
+        tmpy = np.concatenate((tmpy, np.zeros((pad, tmpy.shape[1], tmpy.shape[2]))), axis=0)
+
+        # X direction
+        offset = offsets[offsets_inds[1]]
+        pad = sz[1] - tmpy.shape[1] - offset
+        tmpy = np.concatenate((np.zeros((tmpy.shape[0], offset, tmpy.shape[2])), tmpy), axis=1)
+        tmpy = np.concatenate((tmpy, np.zeros((tmpy.shape[0], pad, tmpy.shape[2]))), axis=1)
+
+        # Y direction
+        offset = offsets[offsets_inds[2]]
+        pad = sz[2] - tmpy.shape[2] - offset
+        tmpy = np.concatenate((np.zeros((tmpy.shape[0], tmpy.shape[1], offset)), tmpy), axis=2)
+        tmpy = np.concatenate((tmpy, np.zeros((tmpy.shape[0], tmpy.shape[1], pad))), axis=2)
+
+        # Assign to label
+        label = np.swapaxes(tmpy, 0, 2).astype(bool)
+
+    # Reorganize data to put slices first
+    label = np.swapaxes(label[np.newaxis, :, :, :], 0, 3)
+
+    return label
+
+
+def make_train_3D(X, Y, block_size, oversamp=1, lab_trun=4):
+    # print('Making volume patches')
+    sz = X.shape  # z, x, y, channels, vols
+
+    if len(sz) == 4:
+        sz = sz + (1,)
+        X = X[:, :, :, :, np.newaxis]
+        Y = Y[:, :, :, :, np.newaxis]
+
+    # Number of volumes in each dimension
+    num_vols = [int((sz[i] * oversamp) // (block_size[i] - lab_trun) + 1) for i in range(len(block_size))]
+
+    # Get starting indices of each block (zind[0]:zind[0] + block_size[0])
+    zind = np.linspace(0, sz[0] - block_size[0], num_vols[0], dtype=int)
+    xind = np.linspace(0, sz[1] - block_size[1], num_vols[1], dtype=int)
+    yind = np.linspace(0, sz[2] - block_size[2], num_vols[2], dtype=int)
+
+    # Preallocate volumes - samples, block, block, block, channels
+    in_patches = np.zeros(shape=(np.prod(num_vols) * sz[4], block_size[0], block_size[1], block_size[2], sz[3]))
+    lab_patches = np.zeros(shape=(np.prod(num_vols) * sz[4], block_size[0] - lab_trun, block_size[1] - lab_trun, block_size[2] - lab_trun, 1))
+
+    # Make volumes
+    n = 0
+    for vol in range(sz[4]):
+        for z in itertools.product(zind, xind, yind):
+            # print(z)
+            in_patches[n, :, :, :, :] = X[z[0]:z[0] + block_size[0], z[1]:z[1] + block_size[1], z[2]:z[2] + block_size[2], :, vol]
+
+            lab_patches[n, :, :, :, :] = Y[z[0] + lab_trun//2:z[0] + block_size[0] - lab_trun//2, z[1] + lab_trun//2:z[1] + block_size[1] - lab_trun//2, z[2] + lab_trun//2:z[2] + block_size[2] - lab_trun//2, :, vol]
+
+            n += 1
+
+    return in_patches, lab_patches
+
+
+def recon_test_3D(X, Y, orig_size, block_size, oversamp=1, lab_trun=4):
+    print('Reconstructing volume from patches')
+
+    # Number of volumes in each dimension
+    num_vols = [int((orig_size[i] * oversamp) // (block_size[i] - lab_trun) + 1) for i in range(len(block_size))]
+
+    # Get input dimensions
+    szy = Y.shape
+    sz = X.shape
+
+    # Add volume dimension if it does not exist
+    total_vols = int(X.shape[0] / np.prod(num_vols))
+    if len(orig_size) == 4:
+        orig_size = orig_size + (total_vols,)
+
+    if total_vols > 1:
+
+        split = int(szy[0] / total_vols)
+        y_vols = np.zeros(shape=(split, szy[1], szy[2], szy[3], szy[4], total_vols))
+        x_vols = np.zeros(shape=(split, sz[1], sz[2], sz[3], sz[4], total_vols))
+
+        for i in range(total_vols):
+            inds = list(range(i * split, (i + 1) * split))
+            y_vols[:, :, :, :, :, i] = Y[inds, :, :, :, :]
+            x_vols[:, :, :, :, :, i] = X[inds, :, :, :, :]
+
+        # Update volumes
+        Y = y_vols
+        X = x_vols
+
+        # Clear temp variables
+        y_vols, x_vols = None, None
+
+    # Add volume axis if it does not exist
+    szy = Y.shape
+    sz = X.shape
+    if len(sz) < 6:
+        X = X[:, :, :, :, :, np.newaxis]
+        Y = Y[:, :, :, :, :, np.newaxis]
+
+    # Get starting indices of each block (zind[0]:zind[0] + block_size[0])
+    zind = np.linspace(0, orig_size[0] - block_size[0], num_vols[0], dtype=int)
+    xind = np.linspace(0, orig_size[1] - block_size[1], num_vols[1], dtype=int)
+    yind = np.linspace(0, orig_size[2] - block_size[2], num_vols[2], dtype=int)
+
+    # Preallocate arrays - z, x, y, channels, vols
+    in_recon = np.zeros(shape=(orig_size))
+    lab_recon = np.zeros(shape=(orig_size[0], orig_size[1], orig_size[2], 1, orig_size[4]))
+    inds_in = np.zeros_like(in_recon, dtype=np.int8)
+    inds_lab = np.zeros_like(lab_recon, dtype=np.int8)
+
+    # Reconstruct images
+    for vol in range(orig_size[4]):
+        n = 0
+        for z in itertools.product(zind, xind, yind):
+            # print(z)
+            # Update images
+            in_recon[z[0]:z[0] + block_size[0], z[1]:z[1] + block_size[1], z[2]:z[2] + block_size[2], :, vol] += X[n, :, :, :, :, vol]
+            lab_recon[z[0] + lab_trun//2:z[0] + block_size[0] - lab_trun//2, z[1] + lab_trun//2:z[1] + block_size[1] - lab_trun//2, z[2] + lab_trun//2:z[2] + block_size[2] - lab_trun//2, :, vol] += Y[n, :, :, :, :, vol]
+
+            # Keep track of duplicate values
+            inds_in[z[0]:z[0] + block_size[0], z[1]:z[1] + block_size[1], z[2]:z[2] + block_size[2], :, vol] += 1
+            inds_lab[z[0] + lab_trun//2:z[0] + block_size[0] - lab_trun//2, z[1] + lab_trun//2:z[1] + block_size[1] - lab_trun//2, z[2] + lab_trun//2:z[2] + block_size[2] - lab_trun//2, :, vol] += 1
+
+            n += 1
+
+    in_recon /= inds_in
+    lab_recon[inds_lab > 0] /= inds_lab[inds_lab > 0]
+
+    return in_recon, lab_recon
+
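+
+# The helper below is an illustrative sketch (not part of the training pipeline)
+# showing how make_train_3D and recon_test_3D fit together; the synthetic volume
+# shapes are arbitrary assumptions chosen only to exercise the two functions.
+def _patch_roundtrip_example():
+    """Extract overlapping patches from a synthetic volume and stitch them back."""
+    block_size = (25, 25, 25)
+    lab_trun = 4
+
+    # Synthetic multi-channel volume (z, x, y, channels) and binary label
+    X = np.random.rand(60, 80, 80, 4).astype('float32')
+    Y = (np.random.rand(60, 80, 80, 1) > 0.5).astype('float32')
+
+    # Patch extraction: inputs keep the full block, labels are truncated by lab_trun
+    in_patches, lab_patches = make_train_3D(X, Y, block_size=block_size,
+                                            oversamp=1, lab_trun=lab_trun)
+
+    # Reconstruction back to full volumes (z, x, y, channels, vols)
+    orig_size = X.shape + (1,)
+    X_recon, Y_recon = recon_test_3D(X=in_patches, Y=lab_patches,
+                                     orig_size=orig_size, block_size=block_size,
+                                     oversamp=1, lab_trun=lab_trun)
+    return X_recon, Y_recon
+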
+
+def training_eval(model, features, spath, orig_size, oversamp, lab_trun):
+    # Eval training data
+    train_test = features['X']
+    train_test_lab = features['Y']
+
+    # Get block size
+    block_size = train_test.shape
+    block_size = block_size[1:4]
+
+    train_ev = model.evaluate(x=train_test, y=train_test_lab)
+    print('Training evaluation:')
+    print('\t', train_ev)
+    train_pred = model.predict(train_test)
+
+    # Compute ideal threshold
+    precisions, recalls, thresholds = precision_recall_curve(y_true=train_test_lab.reshape(-1),
+                                                             probas_pred=train_pred.reshape(-1), pos_label=None)
+    sub = np.abs(recalls[:-1] - precisions[:-1])
+    # val, idx = min((val, idx) for (idx, val) in enumerate(sub))
+    # threshold = thresholds[idx]
+    inds = sub < 0.025
+    ind = np.argmax(inds)
+    threshold = thresholds[ind]
+    # threshold = threshold[0]
+
+
+    # Use threshold
+    train_pred = train_pred > threshold
+
+    # Convert back to images
+    _, Y = recon_test_3D(X=train_test, Y=train_test_lab, orig_size=orig_size, block_size=block_size, oversamp=oversamp,
+                         lab_trun=lab_trun)
+    X, train_pred = recon_test_3D(X=train_test, Y=train_pred, orig_size=orig_size, block_size=block_size,
+                                  oversamp=oversamp, lab_trun=lab_trun)
+
+    # Plot images
+    Y_mask = np.ma.masked_where(Y.astype(bool) == 0, Y.astype(bool))
+    train_pred_mask = np.ma.masked_where(train_pred.astype(bool) == 0, train_pred.astype(bool))
+    # train_pred_mask[train_pred_mask<0.9] = 0
+
+    slices = range(20, 40, 4)
+    resh = (Y.shape[1] * len(slices), Y.shape[2])
+    fig, ax = plt.subplots(nrows=3, ncols=1)
+
+    ax[0].imshow(X[slices, :, :, 2, 0].reshape(resh).T, cmap='gray')
+    ax[1].imshow(X[slices, :, :, 2, 0].reshape(resh).T, cmap='gray')
+    ax[1].imshow(train_pred_mask[slices, :, :, 0, 0].reshape(resh).T, cmap='summer', alpha=0.5)
+    ax[2].imshow(X[slices, :, :, 2, 0].reshape(resh).T, cmap='gray')
+    ax[2].imshow(Y_mask[slices, :, :, 0, 0].reshape(resh).T, cmap='summer', alpha=0.5)
+
+    ax[0].set_title('Input slice (T2)')
+    ax[1].set_title('Output slice')
+    ax[2].set_title('Label')
+
+    ax[2].set_xticklabels(slices)
+
+    for a in ax.ravel():
+        a.axis('off')
+
+    fig.savefig(os.path.join(spath, 'Train_set.pdf'), dpi=300, format='pdf', frameon=False)
+
+
+def test_set_3D(model, X, Y, spath, orig_size, block_size, oversamp, lab_trun, batch_size, threshold=None, vols=None, continuous=False):
+
+    sz = X.shape
+    szy = Y.shape
+    Y_pred = np.zeros_like(Y)
+
+    # Test the segmentation
+    if len(sz) > 5:
+        vols = sz[5]
+        for vol in range(vols):
+            temp = X[:, :, :, :, :, vol]
+            Y_pred[:, :, :, :, :, vol] = model.predict(x=temp, batch_size=batch_size)
+    else:
+
+        Y_pred = model.predict(x=X, batch_size=batch_size)
+
+    # Evaluate metrics
+    if threshold:
+        print('\tUsing threshold from training set:\t%0.3f' %threshold)
+
+    else:
+        threshold = 0.9
+        print('\tUsing fixed threshold:\t%0.3f' %threshold)
+
+    Y_pred_thresh = Y_pred > threshold
+
+    # convert volume patches to images
+    _, Y             = recon_test_3D(X=X, Y=Y, orig_size=orig_size, block_size=block_size, oversamp=oversamp, lab_trun=lab_trun)
+    _, Y_pred        = recon_test_3D(X=X, Y=Y_pred, orig_size=orig_size, block_size=block_size, oversamp=oversamp, lab_trun=lab_trun)
+    X, Y_pred_thresh = recon_test_3D(X=X, Y=Y_pred_thresh, orig_size=orig_size, block_size=block_size, oversamp=oversamp, lab_trun=lab_trun)
+
+    # Account for partial volume effects
+    Y_pred_thresh = Y_pred_thresh > 0.1
+
+    _ = model_evaluation(spath=spath, Y=Y, Y_pred=Y_pred, epoch='End', images=True, threshold=threshold, continous=continuous)
+
+    # Plot images
+    y_plot = Y[:, :, :, :, 0]
+    y_plot_pred = Y_pred_thresh[:, :, :, :, 0]
+    Y_mask = np.ma.masked_where(y_plot.astype(bool) == 0, y_plot.astype(bool))
+    test_out_mask = np.ma.masked_where(y_plot_pred.astype(bool) == 0, y_plot_pred.astype(bool))
+
+    slices = range(20, 40, 4)
+    resh = (y_plot.shape[1] * len(slices), y_plot.shape[2])
+    fig, ax = plt.subplots(nrows=3, ncols=1)
+
+    ax[0].imshow(X[slices, :, :, -1, 0].reshape(resh).T, cmap='gray')
+    ax[1].imshow(X[slices, :, :, -1, 0].reshape(resh).T, cmap='gray')
+    ax[1].imshow(test_out_mask[slices, :, :, 0].reshape(resh).T, cmap='summer', alpha=0.5)
+    ax[2].imshow(X[slices, :, :, -1, 0].reshape(resh).T, cmap='gray')
+    ax[2].imshow(Y_mask[slices, :, :, 0].reshape(resh).T, cmap='summer', alpha=0.5)
+
+    ax[0].set_title('Input slice (T2)')
+    ax[1].set_title('Output slice')
+    ax[2].set_title('Label')
+
+    ax[2].set_xticklabels(slices)
+
+    for a in ax.ravel():
+        a.axis('off')
+
+    backgrnd = np.concatenate((X[slices, :, :, -1, 0].reshape(resh).T,
+                               X[slices, :, :, -1, 0].reshape(resh).T,
+                               X[slices, :, :, -1, 0].reshape(resh).T))
+
+    fig.savefig(os.path.join(spath, 'Test_set.png'), dpi=300, format='png', frameon=False)
+
+    # Rotate image axis for viewing
+    X = X.swapaxes(0, 2).swapaxes(0, 1)
+    Y = Y.swapaxes(0, 2).swapaxes(0, 1).astype(np.int16)
+    Y_pred = Y_pred.swapaxes(0, 2).swapaxes(0, 1)
+    Y_pred_thresh = Y_pred_thresh.swapaxes(0, 2).swapaxes(0, 1).astype(np.int16)
+
+    # Save volumes
+    nib.save(nib.Nifti1Image(X, np.eye(4)), os.path.join(spath, 'X.nii'))
+    nib.save(nib.Nifti1Image(Y, np.eye(4)), os.path.join(spath, 'Y_lab.nii.gz'))
+    nib.save(nib.Nifti1Image(Y_pred_thresh, np.eye(4)), os.path.join(spath, 'Y_pred.nii.gz'))
+    nib.save(nib.Nifti1Image(Y_pred, np.eye(4)), os.path.join(spath, 'Y_pred_prob.nii.gz'))
+
+
+def remove_extra_labels(y_pred):
+    """
+    Finds the largest continuous region in the tumor label. All other regions are discarded.
+    Args:
+        y_pred (numpy array): thresholded network output, shaped (z, x, y, 1, vols)
+
+    Returns:
+        (numpy array): cleaned segmentation containing only the largest region per volume
+    """
+
+    # Get sizes
+    sz = y_pred.shape
+    vols = sz[-1]
+
+    # Initialize output array
+    y_out = np.zeros_like(y_pred)
+
+    for z in range(vols):
+
+        # Convert predictions to integer mask
+        y_mask = y_pred[:,:,:,:,z].astype(np.uint8).squeeze()
+
+        # Find continuous (connected) regions
+        labels = measure.label(y_mask, connectivity=3)
+
+        # Find the number of counts for each region
+        vals, counts = np.unique(labels, return_counts=True)
+
+        # Remove background
+        bg = labels[0, 0, 0]
+        counts = counts[vals != bg]
+        vals = vals[vals != bg]
+
+        # Get the largest labels
+        ind = np.argmax(counts)
+        tumor_val = vals[ind]
+
+        # Return clean predictions
+        y_mask = (labels == tumor_val).astype(float)
+
+        y_out[:,:,:,0,z] = y_mask
+
+    return y_out
+
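+# Illustrative sketch for remove_extra_labels (the shapes and indices below are
+# assumptions): a synthetic prediction with two disconnected blobs keeps only
+# the larger one.
+#
+#   y = np.zeros((20, 20, 20, 1, 1))
+#   y[2:8, 2:8, 2:8, 0, 0] = 1          # large blob (kept)
+#   y[15:17, 15:17, 15:17, 0, 0] = 1    # small blob (removed)
+#   cleaned = remove_extra_labels(y)
+#   cleaned[3, 3, 3, 0, 0]     # -> 1.0
+#   cleaned[16, 16, 16, 0, 0]  # -> 0.0
+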
+
+def test_set_3D_kfold(model, X, Y, spath, spath_iter, orig_size, block_size, oversamp, lab_trun, batch_size, threshold=None):
+
+    sz = X.shape
+    Y_pred = np.zeros_like(Y)
+
+    # Test the segmentation
+    if len(sz) > 5:
+        vols = sz[5]
+        for vol in range(vols):
+            temp = X[:, :, :, :, :, vol]
+            Y_pred[:, :, :, :, :, vol] = model.predict(x=temp, batch_size=batch_size)
+    else:
+
+        Y_pred = model.predict(x=X, batch_size=batch_size)
+
+
+
+    # Evaluate metrics
+    if threshold:
+        print('\tUsing threshold from training set:\t%0.3f' %threshold)
+
+    else:
+        threshold = 0.9
+        print('\tUsing fixed threshold:\t%0.3f' %threshold)
+
+
+    # _ = model_evaluation(spath=spath_iter, Y=Y, Y_pred=Y_pred, epoch='End', images=True, threshold=threshold)
+    _ = model_evaluation_kfold(spath=spath, Y=Y, Y_pred=Y_pred, epoch='End', images=False, threshold=threshold)
+
+
+    Y_pred_thresh = Y_pred > threshold
+
+    # convert volume patches to images
+    _, Y = recon_test_3D(X=X, Y=Y, orig_size=orig_size, block_size=block_size, oversamp=oversamp, lab_trun=lab_trun)
+    _, Y_pred_thresh = recon_test_3D(X=X, Y=Y_pred_thresh, orig_size=orig_size, block_size=block_size, oversamp=oversamp,
+                                       lab_trun=lab_trun)
+    X, Y_pred = recon_test_3D(X=X, Y=Y_pred, orig_size=orig_size, block_size=block_size, oversamp=oversamp,
+                                       lab_trun=lab_trun)
+
+    # Convert to boolean
+    Y = Y.astype('bool')
+    Y_pred_thresh = Y_pred_thresh.astype('bool')
+
+    # Plot images
+    Y_mask = np.ma.masked_where(Y == 0, Y)
+    test_out_mask = np.ma.masked_where(Y_pred_thresh == 0, Y_pred_thresh)
+
+    slices = range(20, 40, 4)
+    resh = (Y.shape[1] * len(slices), Y.shape[2])
+    fig, ax = plt.subplots(nrows=3, ncols=1)
+
+    ax[0].imshow(X[slices, :, :, 2, 0].reshape(resh).T, cmap='gray')
+    ax[1].imshow(X[slices, :, :, 2, 0].reshape(resh).T, cmap='gray')
+    ax[1].imshow(test_out_mask[slices, :, :, 0, 0].reshape(resh).T, cmap='summer', alpha=0.5)
+    ax[2].imshow(X[slices, :, :, 2, 0].reshape(resh).T, cmap='gray')
+    ax[2].imshow(Y_mask[slices, :, :, 0, 0].reshape(resh).T, cmap='summer', alpha=0.5)
+
+    ax[0].set_title('Input slice (T2)')
+    ax[1].set_title('Output slice')
+    ax[2].set_title('Label')
+
+    ax[2].set_xticklabels(slices)
+
+    for a in ax.ravel():
+        a.axis('off')
+
+    f1 = f1_score(y_true=Y.reshape(-1), y_pred=Y_pred_thresh.reshape(-1))
+    fid = open(os.path.join(spath, 'k-fold-metrics'), 'a')
+    fid.write('Dice:\t%0.4f\n' % f1)
+    fid.close()
+
+    backgrnd = np.concatenate((X[slices, :, :, 2, 0].reshape(resh).T, X[slices, :, :, 2, 0].reshape(resh).T, X[slices, :, :, 2, 0].reshape(resh).T))
+
+    fig.savefig(os.path.join(spath, 'Test_set.png'), dpi=300, format='png', frameon=False)
+
+    # Save volumes
+    nib.save(nib.Nifti1Image(X, np.eye(4)), os.path.join(spath_iter, 'X.nii'))
+    nib.save(nib.Nifti1Image(Y.astype('float32'), np.eye(4)), os.path.join(spath_iter, 'Y_lab.nii'))
+    nib.save(nib.Nifti1Image(Y_pred_thresh.astype('float32'), np.eye(4)), os.path.join(spath_iter, 'Y_pred.nii'))
+    nib.save(nib.Nifti1Image(Y_pred.astype('float32'), np.eye(4)), os.path.join(spath_iter, 'Y_pred_prob.nii'))
+
+
+def test_set_3D_T2(model, X, Y, spath, orig_size, block_size, oversamp, lab_trun, batch_size, threshold=None):
+    print('Evaluating the test set')
+
+    sz = X.shape
+    test_out = np.zeros_like(Y)
+
+    # Test the segmentation
+    print('\tMaking model predictions')
+    if len(sz) > 5:
+        vols = sz[5]
+        for vol in range(vols):
+            temp = X[:, :, :, :, :, vol]
+            test_out[:, :, :, :, :, vol] = model.predict(x=temp, batch_size=batch_size)
+    else:
+
+        test_out = model.predict(x=X, batch_size=batch_size)
+
+
+    if threshold:
+        print('\tUsing threshold from training set:\t%0.3f' %threshold)
+
+    else:
+        threshold = 0.9
+        print('\tUsing fixed threshold:\t%0.3f' %threshold)
+
+    # Evaluate metrics
+    # test_ev = model.evaluate(x=X, y=Y)
+    # print('Test evaluation\n\t')
+    # print(test_ev)
+    _ = model_evaluation(spath=spath, Y=Y, Y_pred=test_out, epoch='End', images=True, threshold=threshold)
+
+    test_out_thresh = test_out > threshold
+
+    # convert volume patches to images
+    _, Y = recon_test_3D(X=X, Y=Y, orig_size=orig_size, block_size=block_size, oversamp=oversamp, lab_trun=lab_trun)
+    X, test_out_thresh = recon_test_3D(X=X, Y=test_out_thresh, orig_size=orig_size, block_size=block_size,
+                                       oversamp=oversamp,
+                                       lab_trun=lab_trun)
+
+    # Plot images
+    Y_mask = np.ma.masked_where(Y.astype(bool) == 0, Y.astype(bool))
+    test_out_mask = np.ma.masked_where(test_out_thresh.astype(bool) == 0, test_out_thresh.astype(bool))
+
+    slices = range(20, 40, 4)
+    resh = (Y.shape[1] * len(slices), Y.shape[2])
+    fig, ax = plt.subplots(nrows=3, ncols=1)
+
+    ax[0].imshow(X[slices, :, :, 0, 0].reshape(resh).T, cmap='gray')
+    ax[1].imshow(X[slices, :, :, 0, 0].reshape(resh).T, cmap='gray')
+    ax[1].imshow(test_out_mask[slices, :, :, 0, 0].reshape(resh).T, cmap='summer', alpha=0.5)
+    ax[2].imshow(X[slices, :, :, 0, 0].reshape(resh).T, cmap='gray')
+    ax[2].imshow(Y_mask[slices, :, :, 0, 0].reshape(resh).T, cmap='summer', alpha=0.5)
+
+    ax[0].set_title('Input slice (T2)')
+    ax[1].set_title('Output slice')
+    ax[2].set_title('Label')
+
+    ax[2].set_xticklabels(slices)
+
+    for a in ax.ravel():
+        a.axis('off')
+
+    fig.savefig(os.path.join(spath, 'Test_set.png'), dpi=300, format='png', frameon=False)
+
+
+def load_test_data(block_size, oversamp, lab_trun):
+    # Load test data
+    # Data path
+    dpath = 'E:/MR Data/SarcomaSegmentations/Mouse VI/Bias Corrections/'
+    files = ['Mouse VI T1FLASH New Bias Correction.nii', 'Mouse VI T1FLASH w Con New Bias Correction.nii',
+             'Mouse VI T2TurboRARE New Bias Correction.nii']
+    files = [os.path.join(dpath, file) for file in files]
+
+    label_file = 'E:/MR Data/SarcomaSegmentations/Mouse VI/50002 VI ROI Black and White Volume.nii'
+
+    # Load data
+    X_test = load_data_3D(files)
+    sz = X_test.shape
+    Y_test = load_label_3D(label_file, sz)
+
+    # Convert to volumes
+    X_test, Y_test = make_train_3D(X_test, Y_test, block_size=block_size, oversamp=oversamp, lab_trun=lab_trun)
+
+    return X_test, Y_test, sz
+
+
+def model_evaluation(spath, Y, Y_pred, epoch, images, threshold=None, continous=False):
+    # Evaluate predictions: PR/ROC curves, an ideal threshold, and precision/recall/Dice/VOE written to metrics2.txt
+
+    print('\tEvaluating predictions\n')
+
+    # Make sure epoch is a string
+    if not isinstance(epoch, str):
+        epoch = str(epoch)
+
+    metric_file = open(os.path.join(spath, 'metrics2.txt'), 'w')
+
+    # Precision, recall curve
+    print('\t\tCalculating precision and recall')
+    precisions, recalls, thresholds = precision_recall_curve(y_true=Y.reshape(-1), probas_pred=Y_pred.reshape(-1), pos_label=None)
+
+    if images:
+        fig = plt.figure(11)
+        ax1 = fig.add_subplot(111)
+        ax1.plot(thresholds, precisions[:-1], 'b--', label='Precision')
+        ax1.plot(thresholds, recalls[:-1], 'g-', label='Recall')
+        plt.xlabel('Threshold')
+        plt.legend()
+        plt.grid()
+        plt.ylim([0, 1])
+
+        fig.savefig(os.path.join(spath, 'Precision_recall_%s.svg' %epoch), dpi=300, format='svg', frameon=False)
+
+        # Compute ROC curve
+        fpr, tpr, rcthresh = roc_curve(y_true=Y.reshape(-1), y_score=Y_pred.reshape(-1))
+        roc_score = roc_auc_score(y_true=Y.reshape(-1), y_score=Y_pred.reshape(-1))
+        print('\t\tROC score:\t%0.4f' %roc_score)
+
+        fig = plt.figure(21)
+        plt.plot(fpr, tpr, linewidth=2)
+        plt.plot([0, 1], [0, 1], 'k--')
+        plt.axis([0, 1, 0, 1])
+        plt.xlabel('False positive rate')
+        plt.ylabel('True positive rate')
+
+        fig.savefig(os.path.join(spath, 'ROC_curve_%s.svg' %epoch), dpi=300, format='svg', frameon=False)
+
+        plt.close('all')
+
+    # Remove non-continuous regions - keep only the largest region
+    if continous:
+        Y_pred_inds = Y_pred > 0.15
+        Y_pred_inds = remove_extra_labels(Y_pred_inds)
+        Y_pred[~Y_pred_inds] = 0
+
+    # Compute ideal threshold
+    print('\t\tCalculating ideal threshold')
+    if not threshold:
+        sub = recalls[:-1] - precisions[:-1]
+        # val, idx = min((val, idx) for (idx, val) in enumerate(sub))
+        # threshold = thresholds[idx]
+        if np.sum(sub < 0) == 0:
+            ind = np.argmin(sub) - 1
+        else:
+            inds = sub < 0
+            ind = np.argmax(inds)
+        threshold = thresholds[ind]
+
+    # Create and save a dictionary for metrics
+    # df = {'false_positive': fpr.tolist(),
+    #       'true_postive': tpr.tolist(),
+    #       'roc_thresholds': rcthresh.tolist(),
+    #       'precision': precisions.tolist(),
+    #       'recall': recalls.tolist(),
+    #       'thresholds': thresholds.tolist()}
+    # metric_data = os.path.join(spath, 'metric_data.json')
+    # log = json.dumps(df)
+    #
+    # with open(metric_data, 'w') as f:
+    #     f.write(log)
+
+    # Compute precision and recall
+    print('\t\tCalculating ROC and DICE scores')
+    test_out_thresh = (Y_pred > threshold)
+    precision = precision_score(y_true=Y.reshape(-1), y_pred=test_out_thresh.reshape(-1))
+    recall = recall_score(y_true=Y.reshape(-1), y_pred=test_out_thresh.reshape(-1))
+    f1 = f1_score(y_true=Y.reshape(-1), y_pred=test_out_thresh.reshape(-1))
+    voe = jaccard_similarity_score(y_true=Y.reshape(-1), y_pred=test_out_thresh.reshape(-1))
+    roc = roc_auc_score(y_true=Y.reshape(-1), y_score=Y_pred.reshape(-1))
+    print('Ideal threshold:\t%0.4f' % threshold)
+    print('Precision score:\t%0.4f' %precision)
+    print('Recall score:\t\t%0.4f' %recall)
+    print('DICE score:\t\t\t%0.4f' % f1)
+    print('VOE score:\t\t\t%0.4f' % voe)
+
+    metric_file.write('Epoch %s\n' %epoch)
+    metric_file.write('\tBest threshold:\t\t%0.4f\n' %threshold)
+    metric_file.write('\tROC score:\t\t%0.4f\n' % roc)
+    metric_file.write('\tPrecision score:\t%0.4f\n' %precision)
+    metric_file.write('\tRecall score:\t\t%0.4f\n' %recall)
+    metric_file.write('\tDICE score:\t\t%0.4f\n' % f1)
+    metric_file.write('\tVOE score:\t\t%0.4f\n\n' % voe)
+
+    # Compute DICE similarity coefficient
+    # tp = sum((Y.reshape(-1) == 1) * (test_out_thresh.reshape(-1) == 1))
+    # fp = sum((Y.reshape(-1) == 0) * (test_out_thresh.reshape(-1) == 1))
+    # fn = sum((Y.reshape(-1) == 1) * (test_out_thresh.reshape(-1) == 0))
+    # tn = sum((Y.reshape(-1) == 0) * (test_out_thresh.reshape(-1) == 0))
+
+    # DSC = 2 * tp / (2*tp + fp + fn) # Dice
+    # VDR = abs(fp - fn) / (tp + fn)  # volume difference rate
+
+
+    # print('Volume difference score:\t%0.4f\n' %VDR)
+
+
+    # metric_file.write('Volume difference score:\t%0.4f\n' %VDR)
+
+    metric_file.close()
+
+    return threshold
+
+def model_evaluation_kfold(spath, Y, Y_pred, epoch, images, threshold=None):
+    # Model evaluation
+
+    print('\tEvaluating predictions\n')
+
+    # Make sure epoch is a string
+    if not isinstance(epoch, str):
+        epoch = str(epoch)
+
+    metric_file = open(os.path.join(spath, 'metrics.txt'), 'a')
+
+    # Precision, recall curve
+    print('\t\tCalculating precision and recall')
+    precisions, recalls, thresholds = precision_recall_curve(y_true=Y.reshape(-1), probas_pred=Y_pred.reshape(-1), pos_label=None)
+
+    if images:
+        fig = plt.figure(11)
+        ax1 = fig.add_subplot(111)
+        ax1.plot(thresholds, precisions[:-1], 'b--', label='Precision')
+        ax1.plot(thresholds, recalls[:-1], 'g-', label='Recall')
+        plt.xlabel('Threshold')
+        plt.legend()
+        plt.grid()
+        plt.ylim([0, 1])
+
+        fig.savefig(os.path.join(spath, 'Precision_recall_%s.png' %epoch), dpi=300, format='png', frameon=False)
+
+        # Compute ROC curve
+        fpr, tpr, _ = roc_curve(y_true=Y.reshape(-1), y_score=Y_pred.reshape(-1))
+        roc_score = roc_auc_score(y_true=Y.reshape(-1), y_score=Y_pred.reshape(-1))
+        print('\t\tROC score:\t%0.4f' %roc_score)
+
+        fig = plt.figure(21)
+        plt.plot(fpr, tpr, linewidth=2)
+        plt.plot([0, 1], [0, 1], 'k--')
+        plt.axis([0, 1, 0, 1])
+        plt.xlabel('False positive rate')
+        plt.ylabel('True positive rate')
+
+        fig.savefig(os.path.join(spath, 'ROC_curve_%s.png' %epoch), dpi=300, format='png', frameon=False)
+
+        plt.close('all')
+
+    # Compute ideal threshold
+    print('\t\tCalculating ideal threshold')
+    if not threshold:
+        sub = recalls[:-1] - precisions[:-1]
+        # val, idx = min((val, idx) for (idx, val) in enumerate(sub))
+        # threshold = thresholds[idx]
+        inds = sub < 0
+        ind = np.argmax(inds)
+        threshold = thresholds[ind]
+
+    # threshold = threshold[0]
+
+    # Compute precision and recall
+    print('\t\tCalculating ROC and DICE scores')
+    test_out_thresh = (Y_pred > threshold)
+    precision = precision_score(y_true=Y.reshape(-1), y_pred=test_out_thresh.reshape(-1))
+    recall = recall_score(y_true=Y.reshape(-1), y_pred=test_out_thresh.reshape(-1))
+    f1 = f1_score(y_true=Y.reshape(-1), y_pred=test_out_thresh.reshape(-1))
+    roc = roc_auc_score(y_true=Y.reshape(-1), y_score=Y_pred.reshape(-1))
+    print('Ideal threshold:\t%0.4f' % threshold)
+    print('Precision score:\t%0.4f' %precision)
+    print('Recall score:\t\t%0.4f' %recall)
+    print('DICE score:\t\t\t%0.4f' % f1)
+
+    metric_file.write('Epoch %s\n' %epoch)
+    metric_file.write('\tBest threshold:\t\t%0.4f\n' %threshold)
+    if images:
+        metric_file.write('\tROC score:\t\t%0.4f\n' % roc_score)
+    metric_file.write('\tPrecision score:\t%0.4f\n' %precision)
+    metric_file.write('\tRecall score:\t\t%0.4f\n' %recall)
+    metric_file.write('\tDICE score:\t\t%0.4f\n\n' % f1)
+
+    # Compute DICE similarity coefficient
+    # tp = sum((Y.reshape(-1) == 1) * (test_out_thresh.reshape(-1) == 1))
+    # fp = sum((Y.reshape(-1) == 0) * (test_out_thresh.reshape(-1) == 1))
+    # fn = sum((Y.reshape(-1) == 1) * (test_out_thresh.reshape(-1) == 0))
+    # tn = sum((Y.reshape(-1) == 0) * (test_out_thresh.reshape(-1) == 0))
+
+    # DSC = 2 * tp / (2*tp + fp + fn) # Dice
+    # VDR = abs(fp - fn) / (tp + fn)  # volume difference rate
+
+
+    # print('Volume difference score:\t%0.4f\n' %VDR)
+
+
+    # metric_file.write('Volume difference score:\t%0.4f\n' %VDR)
+
+    metric_file.close()
+
+    return threshold
+
+def training_threshold(model, spath, X, Y):
+    print('Predicting best threshold based on training data')
+
+    # Prediction based on training data
+    print('\tRunning training data through model')
+    Y_pred = model.predict(X, batch_size=10)
+
+    # Calculating precision/recall curves
+    print('\tCalculating precision/recall curves')
+    precisions, recalls, thresholds = precision_recall_curve(
+        y_true=Y.reshape(-1), probas_pred=Y_pred.reshape(-1), pos_label=None)
+
+    if True:
+        print('\t\tSaving precision/recall plot')
+        fig = plt.figure(11)
+        ax1 = fig.add_subplot(111)
+        ax1.plot(thresholds, precisions[:-1], 'b--', label='Precision')
+        ax1.plot(thresholds, recalls[:-1], 'g-', label='Recall')
+        plt.xlabel('Threshold')
+        plt.legend()
+        plt.grid()
+        plt.ylim([0, 1])
+
+        fig.savefig(os.path.join(spath, 'Precision_recall_training.png'),
+                    dpi=300, format='png', frameon=False)
+
+        # Compute ROC curve
+        print('\tCalculating ROC curve')
+        fpr, tpr, _ = roc_curve(y_true=Y.reshape(-1), y_score=Y_pred.reshape(-1))
+
+        fig = plt.figure(21)
+        plt.plot(fpr, tpr, linewidth=2)
+        plt.plot([0, 1], [0, 1], 'k--')
+        plt.axis([0, 1, 0, 1])
+        plt.xlabel('False positive rate')
+        plt.ylabel('True positive rate')
+
+        fig.savefig(os.path.join(spath, 'ROC_curve_training.png'), dpi=300,
+                    format='png', frameon=False)
+
+        plt.close('all')
+
+    # Determining optimal threshold (precision == recall)
+    print('\tDetermining optimal threshold (precision == recall)')
+    sub = recalls[:-1] - precisions[:-1]
+    inds = sub < 0
+    if np.sum(inds) == 0:  # if recall and precision do not cross, take the closest point
+        ind = np.argmin(sub)
+    else:
+        ind = np.argmax(inds)
+    threshold = thresholds[ind]
+
+    fid = open(os.path.join(spath, 'metrics.txt'), 'a')
+    fid.write('Training threshold:\t%0.3f\n\n' %threshold)
+    fid.close()
+
+    return threshold
+
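+# Worked example of the crossing rule above (illustrative numbers): if
+# sub = recalls[:-1] - precisions[:-1] = [0.4, 0.1, -0.2, -0.5], the first
+# negative entry is at index 2, so np.argmax(sub < 0) == 2 and thresholds[2]
+# is taken as the point where recall first drops below precision.
+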
+
+def dice_loss(y_true, y_pred):
+    threshold = 0.5
+    smooth = 1e-5
+    # # Compute thresholds
+    # thresholds = tf.range(start=0.4, limit=0.7, delta=0.05)
+    # thresholds =  np.arange(0.4, 0.8, step=0.05).astype('float32')
+    # tp = tf.metrics.true_positives_at_thresholds(tf.cast(y_true, tf.float32), tf.cast(y_pred, tf.float32), thresholds)
+    # fn = tf.metrics.false_negatives_at_thresholds(tf.cast(y_true, tf.float32), tf.cast(y_pred, tf.float32), thresholds)
+    # thresholds = tf.convert_to_tensor_or_sparse_tensor(thresholds.reshape(-1))
+    # threshold = thresholds[ tf.cast( tf.argmax( tf.multiply(tp, fn) ), dtype=tf.int32 ) ]
+
+    mask = y_pred > threshold
+    mask = tf.cast(mask, dtype=tf.float32)
+    y_pred = tf.multiply(y_pred, mask)
+    mask = y_true > threshold
+    mask = tf.cast(mask, dtype=tf.float32)
+    y_true = tf.multiply(y_true, mask)
+    # y_pred = tf.cast(y_pred > threshold, dtype=tf.float32)
+    # y_true = tf.cast(y_true > threshold, dtype=tf.float32)
+
+    inse = tf.reduce_sum(tf.multiply(y_pred, y_true))
+    l = tf.reduce_sum(y_pred)
+    r = tf.reduce_sum(y_true)
+
+    # Hard (thresholded) Dice with smoothing: (2*intersection + smooth) / (|pred| + |true| + smooth)
+    hard_dice = (2. * inse + smooth) / (l + r + smooth)
+
+    hard_dice = 1 - tf.reduce_mean(hard_dice)
+
+    return hard_dice
+
+def dice_metric(y_true, y_pred):
+
+    threshold = 0.5
+
+    mask = y_pred > threshold
+    mask = tf.cast(mask, dtype=tf.float32)
+    y_pred = tf.multiply(y_pred, mask)
+    mask = y_true > threshold
+    mask = tf.cast(mask, dtype=tf.float32)
+    y_true = tf.multiply(y_true, mask)
+    # y_pred = tf.cast(y_pred > threshold, dtype=tf.float32)
+    # y_true = tf.cast(y_true > threshold, dtype=tf.float32)
+
+    inse = tf.reduce_sum(tf.multiply(y_pred, y_true))
+    l = tf.reduce_sum(y_pred)
+    r = tf.reduce_sum(y_true)
+
+    # Hard (thresholded) Dice: 2*intersection / (|pred| + |true|)
+    hard_dice = (2. * inse) / (l + r)
+
+    hard_dice = tf.reduce_mean(hard_dice)
+
+    return hard_dice
+
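+# Quick numeric check for dice_metric (illustrative; dice_loss behaves the same
+# way apart from the smoothing term and the final 1 - Dice): two 4-voxel masks
+# overlapping in 2 voxels give Dice = 2*2 / (4 + 4) = 0.5. With a graph-mode
+# backend the returned tensor can be evaluated with keras.backend.eval:
+#
+#   y_true = tf.constant([1., 1., 1., 1., 0., 0., 0., 0.])
+#   y_pred = tf.constant([1., 1., 0., 0., 1., 1., 0., 0.])
+#   keras.backend.eval(dice_metric(y_true, y_pred))   # -> 0.5
+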
+
+def save_code(spath, calling_script):
+    # Get current file
+    current_file = os.path.realpath(__file__)
+
+    # Copy current file to save directory
+    copy(current_file, spath)
+
+    # Copy calling file to save directory
+    copy(calling_script, spath)
+
+
+class image_callback_T2(keras.callbacks.Callback):
+    def __init__(self, X, Y, spath, orig_size, block_size=(25, 25, 25), oversamp=1, lab_trun=4, im_freq=10, batch_size=10):
+        self.X = X
+        self.Y = Y
+        self.spath = spath
+        self.orig_size = orig_size
+        self.block_size = block_size
+        self.oversamp = oversamp
+        self.lab_trun = lab_trun
+        self.im_freq = im_freq
+        self.batch_size = batch_size
+
+    def on_epoch_begin(self, epoch, logs=None):
+
+        if (epoch % self.im_freq == 0):
+            print('Starting evaluation')
+
+            # Test the segmentation
+            sz = self.X.shape
+            test_out = np.zeros_like(self.Y)
+            print('\tMaking predictions')
+            if len(sz) > 5:
+                vols = sz[5]
+                for vol in range(vols):
+                    temp = self.X[:, :, :, :, :, vol]
+                    test_out[:, :, :, :, :, vol] = self.model.predict(x=temp, batch_size=self.batch_size)
+            else:
+
+                test_out = self.model.predict(x=self.X, batch_size=self.batch_size)
+
+            # Evaluate metrics
+            threshold = model_evaluation(spath=self.spath, Y=self.Y, Y_pred=test_out, epoch=epoch, images=True)
+            test_out_thresh = test_out > threshold
+            fid = open(os.path.join(self.spath, 'validation_thresholds.txt'), 'w')
+            fid.write(str(threshold))
+            fid.close()
+
+            # convert volume patches to images
+            _, Y = recon_test_3D(X=self.X, Y=self.Y, orig_size=self.orig_size, block_size=self.block_size, oversamp=self.oversamp, lab_trun=self.lab_trun)
+            X, test_out_thresh = recon_test_3D(X=self.X, Y=test_out_thresh, orig_size=self.orig_size, block_size=self.block_size, oversamp=self.oversamp, lab_trun=self.lab_trun)
+
+            # Plot images
+
+            Y_mask = np.ma.masked_where(Y.astype(bool) == 0, Y.astype(bool))
+            test_out_mask = np.ma.masked_where(test_out_thresh.astype(bool) == 0, test_out_thresh.astype(bool))
+
+            slices = range(20, 40, 4)
+            resh = (Y.shape[1] * len(slices), Y.shape[2])
+            fig, ax = plt.subplots(nrows=3, ncols=1)
+
+            ax[0].imshow(X[slices, :, :, 0, 0].reshape(resh).T, cmap='gray')
+            ax[1].imshow(X[slices, :, :, 0, 0].reshape(resh).T, cmap='gray')
+            ax[1].imshow(test_out_mask[slices, :, :, 0, 0].reshape(resh).T, cmap='summer', alpha=0.5)
+            ax[2].imshow(X[slices, :, :, 0, 0].reshape(resh).T, cmap='gray')
+            ax[2].imshow(Y_mask[slices, :, :, 0, 0].reshape(resh).T, cmap='summer', alpha=0.5)
+
+            ax[0].set_title('Input slice (T2). Epoch %d' %epoch)
+            ax[1].set_title('Output slice')
+            ax[2].set_title('Label')
+
+            ax[2].set_xticklabels(slices)
+
+            for a in ax.ravel():
+                a.axis('off')
+
+            fig.savefig(os.path.join(self.spath, 'Val_set_%d.png' %epoch), dpi=300, format='png', frameon=False)
+
+class image_callback_val(keras.callbacks.Callback):
+    def __init__(self, X, Y, spath, orig_size, block_size=(25, 25, 25), oversamp=1, lab_trun=4, im_freq=10, batch_size=10):
+        self.X = X
+        self.Y = Y
+        self.spath = spath
+        self.orig_size = orig_size
+        self.block_size = block_size
+        self.oversamp = oversamp
+        self.lab_trun = lab_trun
+        self.im_freq = im_freq
+        self.batch_size = batch_size
+
+    def on_epoch_begin(self, epoch, logs=None):
+
+        if (epoch % self.im_freq == 0) and (epoch != 0):
+            print('Starting evaluation')
+
+            # Test the segmentation
+            sz = self.X.shape
+            test_out = np.zeros_like(self.Y)
+            print('\tMaking predictions')
+            if len(sz) > 5:
+                vols = sz[5]
+                for vol in range(vols):
+                    temp = self.X[:, :, :, :, :, vol]
+                    test_out[:, :, :, :, :, vol] = self.model.predict(x=temp, batch_size=self.batch_size)
+            else:
+
+                test_out = self.model.predict(x=self.X, batch_size=self.batch_size)
+
+            # Evaluate metrics
+            threshold = model_evaluation(spath=self.spath, Y=self.Y, Y_pred=test_out, epoch=epoch, images=True)
+            test_out_thresh = test_out > threshold
+            fid = open(os.path.join(self.spath, 'validation_thresholds.txt'), 'w')
+            fid.write(str(threshold))
+            fid.close()
+
+
+            # convert volume patches to images
+            _, Y = recon_test_3D(X=self.X, Y=self.Y, orig_size=self.orig_size, block_size=self.block_size, oversamp=self.oversamp, lab_trun=self.lab_trun)
+            X, test_out_thresh = recon_test_3D(X=self.X, Y=test_out_thresh, orig_size=self.orig_size, block_size=self.block_size, oversamp=self.oversamp, lab_trun=self.lab_trun)
+
+            # Plot images
+
+            Y_mask = np.ma.masked_where(Y.astype(bool) == 0, Y.astype(bool))
+            test_out_mask = np.ma.masked_where(test_out_thresh.astype(bool) == 0, test_out_thresh.astype(bool))
+
+            slices = range(20, 40, 4)
+            resh = (Y.shape[1] * len(slices), Y.shape[2])
+            fig, ax = plt.subplots(nrows=3, ncols=1)
+
+            ax[0].imshow(X[slices, :, :, -1, 0].reshape(resh).T, cmap='gray')
+            ax[1].imshow(X[slices, :, :, -1, 0].reshape(resh).T, cmap='gray')
+            ax[1].imshow(test_out_mask[slices, :, :, 0, 0].reshape(resh).T, cmap='summer', alpha=0.5)
+            ax[2].imshow(X[slices, :, :, -1, 0].reshape(resh).T, cmap='gray')
+            ax[2].imshow(Y_mask[slices, :, :, 0, 0].reshape(resh).T, cmap='summer', alpha=0.5)
+
+            ax[0].set_title('Input slice (T2). Epoch %d' %epoch)
+            ax[1].set_title('Output slice')
+            ax[2].set_title('Label')
+
+            ax[2].set_xticklabels(slices)
+
+            for a in ax.ravel():
+                a.axis('off')
+
+            fig.savefig(os.path.join(self.spath, 'Val_set_%d.png' %epoch), dpi=300, format='png', frameon=False)
+
+class image_callback_1px(keras.callbacks.Callback):
+    def __init__(self, gen, spath, orig_size, block_size=(25, 25, 25), oversamp=1, lab_trun=4, im_freq=10, batch_size=10):
+        self.gen = gen
+        # on_epoch_begin below uses self.X and self.Y; it is assumed here that
+        # `gen` yields (X, Y) patch arrays, so one batch is drawn up front and
+        # the same data is evaluated every im_freq epochs.
+        self.X, self.Y = next(iter(gen))
+        self.spath = spath
+        self.orig_size = orig_size
+        self.block_size = block_size
+        self.oversamp = oversamp
+        self.lab_trun = lab_trun
+        self.im_freq = im_freq
+        self.batch_size = batch_size
+
+    def on_epoch_begin(self, epoch, logs=None):
+
+        if (epoch % self.im_freq == 0):  # and not (epoch == 0):
+            print('Starting evaluation')
+
+
+            # Test the segmentation
+            sz = self.X.shape
+            test_out = np.zeros_like(self.Y)
+            print('\tMaking predictions')
+            if len(sz) > 5:
+                vols = sz[5]
+                for vol in range(vols):
+                    temp = self.X[:, :, :, :, :, vol]
+                    test_out[:, :, :, :, :, vol] = self.model.predict(x=temp, batch_size=self.batch_size)
+            else:
+
+                test_out = self.model.predict(x=self.X, batch_size=self.batch_size)
+
+            # Evaluate metrics
+            threshold = model_evaluation(spath=self.spath, Y=self.Y, Y_pred=test_out, epoch=epoch, images=True)
+            test_out_thresh = test_out > threshold
+            fid = open(os.path.join(self.spath, 'validation_thresholds.txt'), 'w')
+            fid.write(str(threshold))
+            fid.close()
+
+
+            # convert volume patches to images
+            # _, Y = recon_test_3D(X=self.X[:int(sz[0]/2 + 1), :, :, :, :], Y=self.Y[:int(sz[0]/2 + 1), :, :, :, :], orig_size=self.orig_size, block_size=self.block_size, oversamp=self.oversamp, lab_trun=self.lab_trun)
+            # X, test_out_thresh = recon_test_3D(X=self.X[:int(sz[0]/2 + 1), :, :, :, :], Y=test_out_thresh[:int(sz[0]/2 + 1), :, :, :, :], orig_size=self.orig_size, block_size=self.block_size, oversamp=self.oversamp, lab_trun=self.lab_trun)
+            _, Y = recon_test_3D(X=self.X, Y=self.Y, orig_size=self.orig_size, block_size=self.block_size, oversamp=self.oversamp, lab_trun=self.lab_trun)
+            X, test_out_thresh = recon_test_3D(X=self.X, Y=test_out_thresh, orig_size=self.orig_size, block_size=self.block_size, oversamp=self.oversamp, lab_trun=self.lab_trun)
+
+            # Plot images
+
+            Y_mask = np.ma.masked_where(Y.astype(bool) == 0, Y.astype(bool))
+            test_out_mask = np.ma.masked_where(test_out_thresh.astype(bool) == 0, test_out_thresh.astype(bool))
+
+            slices = range(20, 40, 4)
+            resh = (Y.shape[1] * len(slices), Y.shape[2])
+            fig, ax = plt.subplots(nrows=3, ncols=1)
+
+            ax[0].imshow(X[slices, :, :, 2, 0].reshape(resh).T, cmap='gray')
+            ax[1].imshow(X[slices, :, :, 2, 0].reshape(resh).T, cmap='gray')
+            ax[1].imshow(test_out_mask[slices, :, :, 0, 0].reshape(resh).T, cmap='summer', alpha=0.5)
+            ax[2].imshow(X[slices, :, :, 2, 0].reshape(resh).T, cmap='gray')
+            ax[2].imshow(Y_mask[slices, :, :, 0, 0].reshape(resh).T, cmap='summer', alpha=0.5)
+
+            ax[0].set_title('Input slice (T2). Epoch %d' %epoch)
+            ax[1].set_title('Output slice')
+            ax[2].set_title('Label')
+
+            ax[2].set_xticklabels(slices)
+
+            for a in ax.ravel():
+                a.axis('off')
+
+            fig.savefig(os.path.join(self.spath, 'Val_set_%d.png' %epoch), dpi=300, format='png', frameon=False)