--- a +++ b/Segmentation/predict_seg.py @@ -0,0 +1,137 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +Created on Wed Nov 14 21:47:22 2018 + +@author: Josefine +""" + +import tensorflow as tf +import numpy as np +import glob +import re +from skimage.transform import resize + +imgDim = 256 +labelDim = 256 + +############################################################################## +### Data functions ###### +############################################################################## +def natural_sort(l): + convert = lambda text: int(text) if text.isdigit() else text.lower() + alphanum_key = lambda key: [ convert(c) for c in re.split('([0-9]+)', key) ] + return sorted(l, key = alphanum_key) + +def create_data(filename_img,direction): + images = [] + file = np.load(filename_img) + a = file['images'] + # Normalize: + #a2 = np.clip(a,-1000,1000) + #a3 = np.interp(a2, (a2.min(), a2.max()), (-1, +1)) + im = resize(a,(labelDim,labelDim,labelDim),order=0) + if direction == 'axial': + for i in range(im.shape[0]): + images.append((im[i,:,:])) + if direction == 'sag': + for i in range(im.shape[1]): + images.append((im[:,i,:])) + if direction == 'cor': + for i in range(im.shape[2]): + images.append((im[:,:,i])) + images = np.asarray(images) + images = images.reshape(-1, imgDim,imgDim,1) + return images + +# Load test data +filelist_test = natural_sort(glob.glob('WHS/Data/test_segments_*.npz')) # list of file names + +############################################################################# +## Reload network and predict ###### +############################################################################# +# +## ============================================================================= +print("====================== LOAD AXIAL NETWORK: ===========================") +# Doing predictions with the model +tf.reset_default_graph() + +new_saver = tf.train.import_meta_graph('WHS/Results/segmentation/model_axial/model.ckpt.meta') + +prediction = np.zeros([1,256,256,9]) +with tf.Session() as sess: + new_saver.restore(sess, tf.train.latest_checkpoint('WHS/Results/segmentation/model_axial/')) + graph = tf.get_default_graph() + x = graph.get_tensor_by_name("x_train:0") + op_to_restore = graph.get_tensor_by_name("output/Softmax:0") + keep_rate = graph.get_tensor_by_name("Placeholder:0") + context = graph.get_tensor_by_name("concat_5:0") + x_contextual = graph.get_tensor_by_name("x_train_context:0") + for i in range(30,len(filelist_test)): + print('Processing test image', (i+1),'out of',(np.max(range(len(filelist_test)))+1)) + # Find renderings corresponding to the given name + prob_maps = [] + x_test = create_data(filelist_test[i],'axial') + for k in range(x_test.shape[0]): + x_test_image = np.expand_dims(x_test[k,:,:,:], axis=0) + y_output,out_context = sess.run([tf.nn.softmax(op_to_restore),context], feed_dict={x: x_test_image, x_contextual: prediction,keep_rate: 1.0}) + prediction[0,:,:,:] = out_context + prob_maps.append(y_output[0,:,:,:]) + np.savez('WHS/Results/Predictions/segment/train_prob_maps_axial_{}'.format(i),prob_maps=prob_maps) +print("================ DONE WITH AXIAL PREDICTIONS! ==================") +# +# ============================================================================= +#print("====================== LOAD SAGITTAL NETWORK: ===========================") +## Doing predictions with the model +#tf.reset_default_graph() +# +#new_saver = tf.train.import_meta_graph('WHS/Results/segmentation/model_sag/model.ckpt.meta') +#prediction = np.zeros([1,256,256,9]) +#with tf.Session() as sess: +# new_saver.restore(sess, tf.train.latest_checkpoint('WHS/Results/segmentation/model_sag/')) +# graph = tf.get_default_graph() +# x = graph.get_tensor_by_name("x_train:0") +# keep_rate = graph.get_tensor_by_name("Placeholder:0") +# op_to_restore = graph.get_tensor_by_name("output/Softmax:0") +# context = graph.get_tensor_by_name("concat_5:0") +# x_contextual = graph.get_tensor_by_name("x_train_context:0") +# for i in range(30,len(filelist_test)): +# print('Processing test image', (i+1),'out of',(np.max(range(len(filelist_test)))+1)) +# # Find renderings corresponding to the given name +# prob_maps = [] +# x_test = create_data(filelist_test[i],'sag') +# for k in range(x_test.shape[0]): +# x_test_image = np.expand_dims(x_test[k,:,:,:], axis=0) +# y_output,out_context = sess.run([tf.nn.softmax(op_to_restore),context], feed_dict={x: x_test_image, x_contextual: prediction,keep_rate: 1.0}) +# prediction[0,:,:,:] = out_context +# prob_maps.append(y_output[0,:,:,:]) +# np.savez('WHS/Results/Predictions/segment/train_prob_maps_sag_{}'.format(i),prob_maps=prob_maps) +#print("================ DONE WITH SAGITTAL PREDICTIONS! ==================") +## +#print("====================== LOAD CORONAL NETWORK: ===========================") +## Doing predictions with the model +#tf.reset_default_graph() +# +#new_saver = tf.train.import_meta_graph('WHS/Results/segmentation/model_cor/model.ckpt.meta') +#prediction = np.zeros([1,256,256,9]) +#with tf.Session() as sess: +# new_saver.restore(sess, tf.train.latest_checkpoint('WHS/Results/segmentation/model_cor/')) +# graph = tf.get_default_graph() +# x = graph.get_tensor_by_name("x_train:0") +# keep_rate = graph.get_tensor_by_name("Placeholder:0") +# op_to_restore = graph.get_tensor_by_name("output/Softmax:0") +# context = graph.get_tensor_by_name("concat_5:0") +# x_contextual = graph.get_tensor_by_name("x_train_context:0") +# for i in range(30,len(filelist_test)): +# print('Processing test image', (i+1),'out of',(np.max(range(len(filelist_test)))+1)) +# # Find renderings corresponding to the given name +# prob_maps = [] +# x_test = create_data(filelist_test[i],'cor') +# for k in range(x_test.shape[0]): +# x_test_image = np.expand_dims(x_test[k,:,:,:], axis=0) +# y_output,out_context = sess.run([tf.nn.softmax(op_to_restore),context], feed_dict={x: x_test_image, x_contextual: prediction,keep_rate: 1.0}) +# prediction[0,:,:,:] = out_context +# prob_maps.append(y_output[0,:,:,:]) +# np.savez('WHS/Results/Predictions/segment/train_prob_maps_cor_{}'.format(i),prob_maps=prob_maps) +#print("================ DONE WITH CORONAL PREDICTONS! ==================") +#