--- /dev/null
+++ b/Segmentation/predict_seg.py
@@ -0,0 +1,137 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+Created on Wed Nov 14 21:47:22 2018
+
+@author: Josefine
+"""
+
+import tensorflow as tf
+import numpy as np
+import glob
+import re
+from skimage.transform import resize
+
+imgDim = 256    # in-plane size of the slice images fed to the network
+labelDim = 256  # side length the input volume is resampled to before slicing
+
+##############################################################################
+###                              Data functions                         ######
+##############################################################################
+def natural_sort(l):
+    """Sort strings so that embedded numbers compare numerically, not lexically."""
+    convert = lambda text: int(text) if text.isdigit() else text.lower()
+    alphanum_key = lambda key: [convert(c) for c in re.split('([0-9]+)', key)]
+    return sorted(l, key=alphanum_key)
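+# Example (hypothetical file names for illustration):
+# natural_sort(['seg_10.npz', 'seg_2.npz']) gives ['seg_2.npz', 'seg_10.npz'],
+# whereas plain sorted() would put 'seg_10.npz' first.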
+
+def create_data(filename_img, direction):
+    """Load a .npz volume, resample it to labelDim^3 and slice it along the given axis."""
+    images = []
+    file = np.load(filename_img)
+    a = file['images']
+    # Optional intensity normalization (clip, then rescale to [-1, 1]) is left disabled:
+    #a2 = np.clip(a,-1000,1000)
+    #a3 = np.interp(a2, (a2.min(), a2.max()), (-1, +1))
+    im = resize(a, (labelDim, labelDim, labelDim), order=0)  # nearest-neighbour resampling
+    if direction == 'axial':
+        for i in range(im.shape[0]):
+            images.append(im[i, :, :])
+    elif direction == 'sag':
+        for i in range(im.shape[1]):
+            images.append(im[:, i, :])
+    elif direction == 'cor':
+        for i in range(im.shape[2]):
+            images.append(im[:, :, i])
+    images = np.asarray(images)
+    images = images.reshape(-1, imgDim, imgDim, 1)  # add a channel axis
+    return images
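+# Quick sanity check of the slicing (hypothetical file name; assumes the .npz
+# holds an 'images' array as above):
+#   slices = create_data('WHS/Data/test_segments_0.npz', 'axial')
+#   assert slices.shape == (labelDim, imgDim, imgDim, 1)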
+
+# Load test data
+filelist_test = natural_sort(glob.glob('WHS/Data/test_segments_*.npz')) # list of file names
+
+#############################################################################
+##                  Reload networks and predict                        ######
+#############################################################################
+
+def predict_direction(direction, model_dir, out_name):
+    """Restore the trained network for one slice direction and save per-slice
+    probability maps for every test volume."""
+    tf.reset_default_graph()
+    new_saver = tf.train.import_meta_graph(model_dir + 'model.ckpt.meta')
+    # Context volume fed back into the network; it carries the previous
+    # slice's 9-channel output forward through the volume.
+    prediction = np.zeros([1, imgDim, imgDim, 9])
+    with tf.Session() as sess:
+        new_saver.restore(sess, tf.train.latest_checkpoint(model_dir))
+        graph = tf.get_default_graph()
+        x = graph.get_tensor_by_name("x_train:0")
+        # "output/Softmax:0" is already the softmax output, so it is run
+        # directly below; wrapping it in tf.nn.softmax again would squash the
+        # probabilities twice and add a new op to the graph on every call.
+        op_to_restore = graph.get_tensor_by_name("output/Softmax:0")
+        keep_rate = graph.get_tensor_by_name("Placeholder:0")
+        context = graph.get_tensor_by_name("concat_5:0")
+        x_contextual = graph.get_tensor_by_name("x_train_context:0")
+        # Starting at index 30 resumes a previously interrupted run; set the
+        # start back to 0 to process every test volume.
+        for i in range(30, len(filelist_test)):
+            print('Processing test image', i + 1, 'out of', len(filelist_test))
+            prob_maps = []
+            x_test = create_data(filelist_test[i], direction)
+            for k in range(x_test.shape[0]):
+                x_test_image = np.expand_dims(x_test[k, :, :, :], axis=0)
+                y_output, out_context = sess.run(
+                    [op_to_restore, context],
+                    feed_dict={x: x_test_image, x_contextual: prediction,
+                               keep_rate: 1.0})  # dropout off at inference
+                prediction[0, :, :, :] = out_context  # context for the next slice
+                prob_maps.append(y_output[0, :, :, :])
+            np.savez('WHS/Results/Predictions/segment/{}_{}'.format(out_name, i),
+                     prob_maps=prob_maps)
+
+print("====================== LOAD AXIAL NETWORK: ===========================")
+predict_direction('axial', 'WHS/Results/segmentation/model_axial/', 'train_prob_maps_axial')
+print("================ DONE WITH AXIAL PREDICTIONS! ==================")
+
+# The sagittal and coronal networks are run the same way; uncomment to use:
+#print("====================== LOAD SAGITTAL NETWORK: ===========================")
+#predict_direction('sag', 'WHS/Results/segmentation/model_sag/', 'train_prob_maps_sag')
+#print("================ DONE WITH SAGITTAL PREDICTIONS! ==================")
+#
+#print("====================== LOAD CORONAL NETWORK: ===========================")
+#predict_direction('cor', 'WHS/Results/segmentation/model_cor/', 'train_prob_maps_cor')
+#print("================ DONE WITH CORONAL PREDICTIONS! ==================")
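+# Downstream fusion is not part of this script, but as a minimal sketch the
+# three per-direction probability maps could be re-oriented to a common axial
+# order and averaged before taking the argmax label (volume index 0 used as a
+# hypothetical example; the sag loop stacks along axis 1, the cor loop along
+# axis 2, hence the transposes):
+#   ax = np.load('WHS/Results/Predictions/segment/train_prob_maps_axial_0.npz')['prob_maps']
+#   sa = np.load('WHS/Results/Predictions/segment/train_prob_maps_sag_0.npz')['prob_maps']
+#   co = np.load('WHS/Results/Predictions/segment/train_prob_maps_cor_0.npz')['prob_maps']
+#   fused = (ax + np.transpose(sa, (1, 0, 2, 3)) + np.transpose(co, (1, 2, 0, 3))) / 3.0
+#   labels = np.argmax(fused, axis=-1)   # (256, 256, 256) label volume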