--- a +++ b/Region/predict_region_sag.py @@ -0,0 +1,98 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +Created on Wed Oct 31 09:05:57 2018 + +@author: Josefine +""" + +import tensorflow as tf +import numpy as np +import nibabel as nib +import glob +import re +from skimage.transform import resize + +imgDim = 128 + +############################################################################## +### Data functions ###### +############################################################################## +def natural_sort(l): + convert = lambda text: int(text) if text.isdigit() else text.lower() + alphanum_key = lambda key: [ convert(c) for c in re.split('([0-9]+)', key) ] + return sorted(l, key = alphanum_key) + +def create_image(filename_img,direction): + images = [] + a = nib.load(filename_img) + a = a.get_data() + # Normalize: + a2 = np.clip(a,-1000,1000) + a3 = np.interp(a2, (a2.min(), a2.max()), (-1, +1)) + # Reshape: + img = np.zeros([512,512,512])+np.min(a3) + index1 = int(np.ceil((512-a.shape[2])/2)) + index2 = int(512-np.floor((512-a.shape[2])/2)) + img[:,:,index1:index2] = a3 + im = resize(img,(imgDim,imgDim,imgDim),order=0) + if direction == 'sag': + for i in range(im.shape[0]): + images.append((im[i,:,:])) + if direction == 'cor': + for i in range(im.shape[1]): + images.append((im[:,i,:])) + if direction == 'axial': + for i in range(im.shape[2]): + images.append((im[:,:,i])) + images = np.asarray(images) + images = images.reshape(-1, imgDim,imgDim,1) + return images + +# Load test data +filelist_test = natural_sort(glob.glob('WHS/ct_train_test/ct_test/*_image.nii.gz')) # list of file names + +# Load train data for segmentation network +filelist_train = natural_sort(glob.glob('WHS/Augment_data/*_image.nii')) # list of file names + +############################################################################## +### Reload network and predict ###### +############################################################################## +print("====================== LOAD SAGITTAL NETWORK: ===========================") + +# Doing predictions with the model +tf.reset_default_graph() + +new_saver = tf.train.import_meta_graph('WHS/Results/region/model_sag/model.ckpt.meta') + +with tf.Session() as sess: + new_saver.restore(sess, tf.train.latest_checkpoint('WHS/Results/region/model_sag/')) + graph = tf.get_default_graph() + x = graph.get_tensor_by_name("x_train:0") + op_to_restore = graph.get_tensor_by_name("output/Softmax:0") #ME +# +# for i in range(len(filelist_test)): +# print('Processing test image', (i+1),'out of',(np.max(range(len(filelist_test)))+1)) +# # Find renderings corresponding to the given name +# prob_maps = [] +# x_test = create_image(filelist_test[i],'sag') +# for k in range(x_test.shape[0]): +# x_test_image = np.expand_dims(x_test[k,:,:,:], axis=0) +# y_output = sess.run(tf.nn.softmax(op_to_restore), feed_dict={x: x_test_image,'Placeholder:0':1.0}) +# prob_maps.append(y_output[0,:,:,:]) +# np.savez('WHS/Results/Predictions/region/test_prob_maps_sag_{}'.format(i),prob_maps=prob_maps) +# print("================ DONE WITH TEST PREDICTIONS! ==================") + + for i in range(30,len(filelist_train)): + print('Processing test image', (i+1),'out of',(np.max(range(len(filelist_train)))+1)) + # Find renderings corresponding to the given name + prob_maps = [] + x_test = create_image(filelist_train[i],'sag') + for k in range(x_test.shape[0]): + x_test_image = np.expand_dims(x_test[k,:,:,:], axis=0) + y_output = sess.run(tf.nn.softmax(op_to_restore), feed_dict={x: x_test_image,'Placeholder:0':1.0}) + prob_maps.append(y_output[0,:,:,:]) + np.savez('WHS/Results/Predictions/region/train_prob_maps_sag_{}'.format(i),prob_maps=prob_maps) + print("================ DONE WITH TRAIN PREDICTIONS! ==================") + +print("================ DONE WITH AXIAL PREDICTIONS! ==================") \ No newline at end of file