--- a
+++ b/Region/unet_aug_sag.py
@@ -0,0 +1,362 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+Created on Tue Oct 30 14:28:59 2018
+
+@author: Josefine
+"""
+
+## Import libraries
+import numpy as np 
+import tensorflow as tf
+import re
+import glob
+import keras
+from time import time
+from sklearn.utils import shuffle
+import nibabel as nib
+from skimage.transform import resize
+
+# Define parameters:
+lr          = 1e-5    # learning rate (kept constant; no decay schedule is used here)
+nEpochs     = 50         # Number of epochs
+batch_size  = 1
+valid_size  = 1
+
+# Other network specific parameters
+n_classes = 2
+beta1 = 0.9
+beta2 = 0.999
+epsilon = 1e-8
+
+imgDim = 128
+######################################################################
+##                                                                  ##
+##                   Setting up the network                         ##
+##                                                                  ##
+######################################################################
+
+tf.reset_default_graph()
+
+# Define placeholders for input and output
+x = tf.placeholder(tf.float32,[None,imgDim,imgDim,1],name = 'x_train')        # input: imgDim x imgDim x 1 slices
+y = tf.placeholder(tf.float32,[None,imgDim,imgDim,n_classes],name='y_train')  # target: imgDim x imgDim x n_classes one-hot labels
+drop_rate = tf.placeholder(tf.float32, shape=())                              # keep probability for dropout (see dropout() below)
+
+######################################################################
+##                                                                  ##
+##                   Metrics and functions                          ##
+##                                                                  ##
+######################################################################
+
+def natural_sort(l): 
+    convert = lambda text: int(text) if text.isdigit() else text.lower() 
+    alphanum_key = lambda key: [ convert(c) for c in re.split('([0-9]+)', key) ] 
+    return sorted(l, key = alphanum_key)
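+# Example of the resulting order (hypothetical file names, shown as a comment
+# so nothing extra runs at import time):
+#   natural_sort(['p10_image.nii','p2_image.nii','p1_image.nii'])
+#   -> ['p1_image.nii','p2_image.nii','p10_image.nii']
+# whereas plain sorted() would place 'p10_image.nii' before 'p2_image.nii'.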
+
+def dice_coef(y_true, y_pred, smooth=1e-7): # smoothed to avoid division by zero on empty masks
+    # Operates on its arguments (dense label maps from tf.argmax), not on the
+    # global y/output tensors, so it can be reused for any pair of masks.
+    y_true_f = tf.cast(tf.reshape(y_true, [-1]), tf.float32)
+    y_pred_f = tf.cast(tf.reshape(y_pred, [-1]), tf.float32)
+    intersection = tf.reduce_sum(y_true_f * y_pred_f)
+    return (2. * intersection + smooth) / (tf.reduce_sum(y_true_f) + tf.reduce_sum(y_pred_f) + smooth)
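+# Worked example (smooth term omitted for readability):
+#   y_true = [0,1,1,0], y_pred = [0,1,1,1]  -> intersection = 2,
+#   dice = 2*2 / (2 + 3) = 0.8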
+
+######################################################################
+##                               Layers                             ##
+######################################################################
+def conv2d(inputs, filters, kernel, stride, pad, name):
+    """ Creates a 2D convolution (ReLU activation, Xavier init) with the following specs:
+    Args:
+        inputs:  (Tensor)            tensor to apply the convolution to
+        filters: (integer)           number of filters in the kernel
+        kernel:  (integer)           size of the (square) kernel
+        stride:  (integer)           stride in both spatial dimensions
+        pad:     ('valid' or 'same') type of padding
+        name:    (string)            name of the layer
+    """
+    with tf.name_scope(name):
+        conv = tf.layers.conv2d(inputs, filters, kernel_size=kernel,
+                                strides=[stride,stride], padding=pad,
+                                activation=tf.nn.relu,
+                                kernel_initializer=tf.contrib.layers.xavier_initializer())
+        return conv
+
+def max_pool(inputs,n,stride,pad):
+    maxpool = tf.nn.max_pool(inputs, ksize=[1,n,n,1], strides=[1,stride,stride,1], padding=pad)
+    return maxpool
+
+def crop2d(inputs,dim):
+    crop = tf.image.resize_image_with_crop_or_pad(inputs,dim,dim)
+    return crop
+
+def concat(input1,input2,axis):
+    combined = tf.concat([input1,input2],axis)
+    return combined
+
+def dropout(input1, keep_prob):
+    # tf.nn.dropout takes a *keep* probability, so the drop_rate placeholder
+    # fed here holds the probability of keeping a unit (0.5 while training,
+    # 1.0 at test time). noise_shape = [1, 1, 1, C] shares one Bernoulli draw
+    # across both spatial dimensions, i.e. whole feature maps are dropped
+    # (spatial dropout) rather than individual pixels.
+    input_shape = input1.get_shape().as_list()
+    noise_shape = tf.constant(value=[1, 1, 1, input_shape[3]])
+    drop = tf.nn.dropout(input1, keep_prob=keep_prob, noise_shape=noise_shape)
+    return drop
+
+def transpose(inputs,filters, kernel, stride, pad, name):
+    with tf.name_scope(name):
+        trans = tf.layers.conv2d_transpose(inputs,filters, kernel_size=[kernel,kernel],strides=[stride,stride],padding=pad,kernel_initializer=tf.contrib.layers.xavier_initializer())
+        return trans
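+
+# With kernel=2, stride=2 and 'same' padding, each transpose doubles the
+# spatial size, mirroring the 2x2/stride-2 max-pool steps on the way down
+# (e.g. 8x8 at the bottleneck -> 16x16).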
+
+######################################################################
+##                             Data                                 ##
+######################################################################
+
+def create_data(filename_img,filename_label,direction):
+    """Load NIfTI volumes, pad them to cubes, resample to imgDim^3 and slice
+    along the requested direction ('sag', 'cor' or 'axial')."""
+    images = []
+    for f in range(len(filename_img)):
+        a = nib.load(filename_img[f])
+        a = a.get_data()
+        # Normalize: clip to the [-1000, 1000] HU window, then rescale to [-1, 1]
+        a2 = np.clip(a,-1000,1000)
+        a3 = np.interp(a2, (a2.min(), a2.max()), (-1, +1))
+        # Reshape: pad the z-axis symmetrically so the volume becomes a 512^3 cube
+        img = np.zeros([512,512,512])+np.min(a3)
+        index1 = int(np.ceil((512-a.shape[2])/2))
+        index2 = int(512-np.floor((512-a.shape[2])/2))
+        img[:,:,index1:index2] = a3
+        im = resize(img,(imgDim,imgDim,imgDim),order=0)
+        if direction == 'sag':
+            for i in range(im.shape[0]):
+                images.append((im[i,:,:]))
+        if direction == 'cor':
+            for i in range(im.shape[1]):
+                images.append((im[:,i,:]))
+        if direction == 'axial':
+            for i in range(im.shape[2]):
+                images.append((im[:,:,i]))    
+    images = np.asarray(images)
+    images = images.reshape(-1, imgDim,imgDim,1)
+    
+    # Label creation: pad and resample like the images, then merge all
+    # foreground structures into a single binary mask
+    labels = []
+    for g in range(len(filename_label)):
+        b = nib.load(filename_label[g])
+        b = b.get_data()
+        img = np.zeros([b.shape[0],b.shape[0],b.shape[0]])
+        index1 = int(np.ceil((img.shape[2]-b.shape[2])/2))
+        index2 = int(img.shape[2]-np.floor((img.shape[2]-b.shape[2])/2))
+        img[:,:,index1:index2] = b
+        lab = resize(img,(imgDim,imgDim,imgDim),order=0)
+        lab[lab>1] = 1   # collapse all foreground labels to 1
+        if direction == 'sag':
+            for i in range(lab.shape[0]):
+                labels.append((lab[i,:,:]))
+        if direction == 'cor':
+            for i in range(lab.shape[1]):
+                labels.append((lab[:,i,:]))
+        if direction == 'axial':
+            for i in range(lab.shape[2]):
+                labels.append((lab[:,:,i]))            
+    labels = np.asarray(labels)
+    labels_onehot = np.stack((labels==0, labels==1), axis=3).astype('int32')
+    return images, labels_onehot
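+
+# For N input volumes, create_data returns (with the settings above):
+#   images:        (N*imgDim, imgDim, imgDim, 1)  slices rescaled to [-1, 1]
+#   labels_onehot: (N*imgDim, imgDim, imgDim, 2)  one-hot background/foreground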
+
+###############################################################################
+##                            Setup of network                               ##
+###############################################################################
+
+# -------------------------- Contracting path ---------------------------------
+conv1a = conv2d(x,filters=64,kernel=3,stride=1,pad='same',name = 'conv1a')
+conv1b = conv2d(conv1a,filters=64,kernel=3,stride=1,pad='same',name = 'conv1b')
+#drop1 = dropout(conv1b, drop_rate)
+pool1 = max_pool(conv1b,n=2,stride=2,pad='SAME')   # 128 -> 64
+
+conv2a = conv2d(pool1,filters=128,kernel=3,stride=1,pad='same',name = 'conv2a')
+conv2b = conv2d(conv2a,filters=128,kernel=3,stride=1,pad='same',name = 'conv2b')
+drop2 = dropout(conv2b, drop_rate)
+pool2 = max_pool(drop2,n=2,stride=2,pad='SAME')    # 64 -> 32
+
+conv3a = conv2d(pool2,filters=256,kernel=3,stride=1,pad='same',name = 'conv3a')
+conv3b = conv2d(conv3a,filters=256,kernel=3,stride=1,pad='same',name = 'conv3b')
+drop3 = dropout(conv3b, drop_rate)
+pool3 = max_pool(drop3,n=2,stride=2,pad='SAME')    # 32 -> 16
+
+conv4a = conv2d(pool3,filters=512,kernel=3,stride=1,pad='same',name = 'conv4a')
+conv4b = conv2d(conv4a,filters=512,kernel=3,stride=1,pad='same',name = 'conv4b')
+drop4 = dropout(conv4b, drop_rate)
+pool4 = max_pool(drop4,n=2,stride=2,pad='SAME')    # 16 -> 8
+
+conv5a = conv2d(pool4,filters=1024,kernel=3,stride=1,pad='same',name = 'conv5a')
+conv5b = conv2d(conv5a,filters=1024,kernel=3,stride=1,pad='same',name = 'conv5b')
+drop5 = dropout(conv5b, drop_rate)
+# ---------------------------- Expansive path ---------------------------------
+up6a = transpose(drop5,filters=512,kernel=2,stride=2,pad='same',name='up6a')   # 8 -> 16
+up6b = concat(up6a,conv4b,axis=3)
+
+conv7a = conv2d(up6b,filters=512,kernel=3,stride=1,pad='same',name = 'conv7a')
+conv7b = conv2d(conv7a,filters=512,kernel=3,stride=1,pad='same',name = 'conv7b')
+drop7 = dropout(conv7b, drop_rate)
+up7a = transpose(drop7,filters=256,kernel=2,stride=2,pad='same',name='up7a')   # 16 -> 32
+up7b = concat(up7a,conv3b,axis=3)
+
+conv8a = conv2d(up7b,filters=256,kernel=3,stride=1,pad='same',name = 'conv8a')
+conv8b = conv2d(conv8a,filters=256,kernel=3,stride=1,pad='same',name = 'conv8b')
+drop8 = dropout(conv8b, drop_rate)
+up8a = transpose(drop8,filters=128,kernel=2,stride=2,pad='same',name='up8a')   # 32 -> 64
+up8b = concat(up8a,conv2b,axis=3)
+
+conv9a = conv2d(up8b,filters=128,kernel=3,stride=1,pad='same',name = 'conv9a')
+conv9b = conv2d(conv9a,filters=128,kernel=3,stride=1,pad='same',name = 'conv9b')
+#drop9 = dropout(conv9b, drop_rate)
+up9a = transpose(conv9b,filters=64,kernel=2,stride=2,pad='same',name='up9a')   # 64 -> 128
+up9b = concat(up9a,conv1b,axis=3)
+
+conv10a = conv2d(up9b,filters=64,kernel=3,stride=1,pad='same',name = 'conv10a')
+conv10b = conv2d(conv10a,filters=64,kernel=3,stride=1,pad='same',name = 'conv10b')
+
+output = tf.layers.conv2d(conv10b, n_classes, 1, (1,1), padding='same', activation=tf.nn.softmax, kernel_initializer=tf.contrib.layers.xavier_initializer(), name = 'output')
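+# output: (batch, imgDim, imgDim, n_classes) per-pixel class probabilities
+# (softmax over the background/foreground channels).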
+
+######################################################################
+##                                                                  ##
+##                            Loading data                          ##
+##                                                                  ##
+######################################################################
+
+filelist_train = natural_sort(glob.glob('WHS/Augment_data/*_image.nii')) # list of file names
+filelist_train_label = natural_sort(glob.glob('WHS/Augment_data/*_label.nii')) # list of file names
+x_data, y_data = create_data(filelist_train,filelist_train_label,'sag')
+
+#filelist_val = natural_sort(glob.glob('WHS/validation/*_image.nii.gz')) # list of file names
+#filelist_val_label = natural_sort(glob.glob('WHS/validation/*_label.nii.gz')) # list of file names
+#x_val, y_val = create_data(filelist_val,filelist_val_label,'sag')
+
+######################################################################
+##                                                                  ##
+##                   Defining the training                          ##
+##                                                                  ##
+######################################################################
+
+# global_step counts optimizer updates; passing it to opt.minimize() below
+# makes it increment by one per training step.
+global_step = tf.Variable(0,trainable=False)
+
+###############################################################################
+##                               Loss                                        ##
+###############################################################################
+# Compare the output of the network (output: tensor) with the ground truth (y: tensor/placeholder).
+# The output layer already applies a softmax, so binary cross-entropy is
+# computed on probabilities here (not a with-logits loss).
+loss = tf.reduce_mean(keras.losses.binary_crossentropy(y_true = y, y_pred = output))
+
+# accuracy and dice 
+correct_prediction = tf.equal(tf.argmax(output, axis=-1), tf.argmax(y, axis=-1))
+accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
+dice = dice_coef(tf.argmax(y,axis=-1), tf.argmax(output,axis=-1))
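+# Note: with masks that are mostly background, pixel accuracy is inflated by
+# the background class, so the Dice score is the more informative metric here.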
+
+###############################################################################
+##                               Optimizer                                   ##
+###############################################################################
+opt = tf.train.AdamOptimizer(lr,beta1,beta2,epsilon)
+
+###############################################################################
+##                               Minimizer                                   ##
+###############################################################################
+train_adam = opt.minimize(loss, global_step)
+
+###############################################################################
+##                               Initializer                                 ##
+###############################################################################
+# Initializes all variables in the graph
+init = tf.global_variables_initializer()
+
+######################################################################
+##                                                                  ##
+##                   Start training                                 ## 
+##                                                                  ##
+######################################################################
+# Initialize saving of the network parameters:
+saver = tf.train.Saver()
+
+######################## Start training Session ###########################
+#valid_loss, valid_accuracy, valid_dice = [], [], []
+train_loss, train_accuracy, train_dice = [], [], []
+
+index_train = shuffle(np.arange(x_data.shape[0]))  # sklearn's shuffle needs an indexable array, not a bare range
+#valid_size = int(np.floor(len(index1)*0.1))
+#index_train = index1[valid_size:]
+#index_valid = index1[:valid_size]
+with tf.Session() as sess:
+    t_start = time()
+    # Initialize
+    sess.run(init)    
+  
+    # Trainingsloop
+    for epoch in range(nEpochs):
+        t_epoch_start = time()
+        print('========Training Epoch: ', (epoch + 1))
+        iter_by_epoch = len(index_train)
+        index_train_shuffle = shuffle(index_train)
+        for i in range(iter_by_epoch):
+            t_iter_start = time()
+            x_batch = np.expand_dims(x_data[index_train_shuffle[i],:,:,:], axis=0)
+            y_batch = np.expand_dims(y_data[index_train_shuffle[i],:,:,:], axis=0)
+            _,_loss,_acc,_dice = sess.run([train_adam, loss, accuracy, dice], feed_dict = {x: x_batch, y: y_batch, drop_rate: 0.5})  # keep 50% of units while training
+            
+            train_loss.append(_loss)
+            train_accuracy.append(_acc)
+            train_dice.append(_dice)
+                 
+#            # Validation-step:
+#            if i==np.max(range(iter_by_epoch)):
+#                valid_range = x_val.shape[0]
+#                for m in range(valid_range):
+#                    x_batch_val = np.expand_dims(x_val[m,:,:,:], axis=0)
+#                    y_batch_val = np.expand_dims(y_val[m,:,:,:], axis=0)
+#                    _loss_valid,_acc_valid,_dice_valid, = sess.run([loss,accuracy,dice], feed_dict= {x: x_batch_val,y: y_batch_val, drop_rate: 1.0})
+#                    valid_loss.append(_loss_valid)
+#                    valid_accuracy.append(_acc_valid)
+#                    valid_dice.append(_dice_valid)
+
+        t_epoch_finish = time()
+        # Average over this epoch's iterations only (the history lists keep
+        # growing across epochs):
+        print("Epoch:", (epoch + 1), '  avg_loss= ', "{:.9f}".format(np.mean(train_loss[-iter_by_epoch:])), 'avg_acc= ', "{:.9f}".format(np.mean(train_accuracy[-iter_by_epoch:])),'avg_dice= ', "{:.9f}".format(np.mean(train_dice[-iter_by_epoch:])),' time_epoch=', str(t_epoch_finish-t_epoch_start))
+#        print("Validation:", (epoch + 1), '  avg_loss= ', "{:.9f}".format(np.mean(valid_loss)), '  avg_acc= ', "{:.9f}".format(np.mean(valid_accuracy)),'avg_dice= ', "{:.9f}".format(np.mean(valid_dice)))
+
+    t_end = time()
+#     Save the model in the end
+    saver.save(sess,"WHS/Results/region/model_sag/model.ckpt")
+    np.save('WHS/Results/train_hist/region/train_loss_sag',train_loss)
+    np.save('WHS/Results/train_hist/region/train_acc_sag',train_accuracy)
+#    np.save('WHS/Results/train_hist/region/valid_loss_sag',valid_loss)
+#    np.save('WHS/Results/train_hist/region/valid_acc_sag',valid_accuracy)
+    print('Training Done! Total time:' + str(t_end - t_start))
\ No newline at end of file