|
a |
|
b/Segmentation/predict_seg.py |
|
|
1 |
#!/usr/bin/env python3 |
|
|
2 |
# -*- coding: utf-8 -*- |
|
|
3 |
""" |
|
|
4 |
Created on Wed Nov 14 21:47:22 2018 |
|
|
5 |
|
|
|
6 |
@author: Josefine |
|
|
7 |
""" |
|
|
8 |
|
|
|
9 |
import tensorflow as tf |
|
|
10 |
import numpy as np |
|
|
11 |
import glob |
|
|
12 |
import re |
|
|
13 |
from skimage.transform import resize |
|
|
14 |
|
|
|
15 |
imgDim = 256 |
|
|
16 |
labelDim = 256 |
|
|
17 |
|
|
|
18 |
############################################################################## |
|
|
19 |
### Data functions ###### |
|
|
20 |
############################################################################## |
|
|
21 |
def natural_sort(l): |
|
|
22 |
convert = lambda text: int(text) if text.isdigit() else text.lower() |
|
|
23 |
alphanum_key = lambda key: [ convert(c) for c in re.split('([0-9]+)', key) ] |
|
|
24 |
return sorted(l, key = alphanum_key) |
|
|
25 |
|
|
|
26 |
def create_data(filename_img,direction): |
|
|
27 |
images = [] |
|
|
28 |
file = np.load(filename_img) |
|
|
29 |
a = file['images'] |
|
|
30 |
# Normalize: |
|
|
31 |
#a2 = np.clip(a,-1000,1000) |
|
|
32 |
#a3 = np.interp(a2, (a2.min(), a2.max()), (-1, +1)) |
|
|
33 |
im = resize(a,(labelDim,labelDim,labelDim),order=0) |
|
|
34 |
if direction == 'axial': |
|
|
35 |
for i in range(im.shape[0]): |
|
|
36 |
images.append((im[i,:,:])) |
|
|
37 |
if direction == 'sag': |
|
|
38 |
for i in range(im.shape[1]): |
|
|
39 |
images.append((im[:,i,:])) |
|
|
40 |
if direction == 'cor': |
|
|
41 |
for i in range(im.shape[2]): |
|
|
42 |
images.append((im[:,:,i])) |
|
|
43 |
images = np.asarray(images) |
|
|
44 |
images = images.reshape(-1, imgDim,imgDim,1) |
|
|
45 |
return images |
|
|
46 |
|
|
|
47 |
# Load test data |
|
|
48 |
filelist_test = natural_sort(glob.glob('WHS/Data/test_segments_*.npz')) # list of file names |
|
|
49 |
|
|
|
50 |
############################################################################# |
|
|
51 |
## Reload network and predict ###### |
|
|
52 |
############################################################################# |
|
|
53 |
# |
|
|
54 |
## ============================================================================= |
|
|
55 |
print("====================== LOAD AXIAL NETWORK: ===========================") |
|
|
56 |
# Doing predictions with the model |
|
|
57 |
tf.reset_default_graph() |
|
|
58 |
|
|
|
59 |
new_saver = tf.train.import_meta_graph('WHS/Results/segmentation/model_axial/model.ckpt.meta') |
|
|
60 |
|
|
|
61 |
prediction = np.zeros([1,256,256,9]) |
|
|
62 |
with tf.Session() as sess: |
|
|
63 |
new_saver.restore(sess, tf.train.latest_checkpoint('WHS/Results/segmentation/model_axial/')) |
|
|
64 |
graph = tf.get_default_graph() |
|
|
65 |
x = graph.get_tensor_by_name("x_train:0") |
|
|
66 |
op_to_restore = graph.get_tensor_by_name("output/Softmax:0") |
|
|
67 |
keep_rate = graph.get_tensor_by_name("Placeholder:0") |
|
|
68 |
context = graph.get_tensor_by_name("concat_5:0") |
|
|
69 |
x_contextual = graph.get_tensor_by_name("x_train_context:0") |
|
|
70 |
for i in range(30,len(filelist_test)): |
|
|
71 |
print('Processing test image', (i+1),'out of',(np.max(range(len(filelist_test)))+1)) |
|
|
72 |
# Find renderings corresponding to the given name |
|
|
73 |
prob_maps = [] |
|
|
74 |
x_test = create_data(filelist_test[i],'axial') |
|
|
75 |
for k in range(x_test.shape[0]): |
|
|
76 |
x_test_image = np.expand_dims(x_test[k,:,:,:], axis=0) |
|
|
77 |
y_output,out_context = sess.run([tf.nn.softmax(op_to_restore),context], feed_dict={x: x_test_image, x_contextual: prediction,keep_rate: 1.0}) |
|
|
78 |
prediction[0,:,:,:] = out_context |
|
|
79 |
prob_maps.append(y_output[0,:,:,:]) |
|
|
80 |
np.savez('WHS/Results/Predictions/segment/train_prob_maps_axial_{}'.format(i),prob_maps=prob_maps) |
|
|
81 |
print("================ DONE WITH AXIAL PREDICTIONS! ==================") |
|
|
82 |
# |
|
|
83 |
# ============================================================================= |
|
|
84 |
#print("====================== LOAD SAGITTAL NETWORK: ===========================") |
|
|
85 |
## Doing predictions with the model |
|
|
86 |
#tf.reset_default_graph() |
|
|
87 |
# |
|
|
88 |
#new_saver = tf.train.import_meta_graph('WHS/Results/segmentation/model_sag/model.ckpt.meta') |
|
|
89 |
#prediction = np.zeros([1,256,256,9]) |
|
|
90 |
#with tf.Session() as sess: |
|
|
91 |
# new_saver.restore(sess, tf.train.latest_checkpoint('WHS/Results/segmentation/model_sag/')) |
|
|
92 |
# graph = tf.get_default_graph() |
|
|
93 |
# x = graph.get_tensor_by_name("x_train:0") |
|
|
94 |
# keep_rate = graph.get_tensor_by_name("Placeholder:0") |
|
|
95 |
# op_to_restore = graph.get_tensor_by_name("output/Softmax:0") |
|
|
96 |
# context = graph.get_tensor_by_name("concat_5:0") |
|
|
97 |
# x_contextual = graph.get_tensor_by_name("x_train_context:0") |
|
|
98 |
# for i in range(30,len(filelist_test)): |
|
|
99 |
# print('Processing test image', (i+1),'out of',(np.max(range(len(filelist_test)))+1)) |
|
|
100 |
# # Find renderings corresponding to the given name |
|
|
101 |
# prob_maps = [] |
|
|
102 |
# x_test = create_data(filelist_test[i],'sag') |
|
|
103 |
# for k in range(x_test.shape[0]): |
|
|
104 |
# x_test_image = np.expand_dims(x_test[k,:,:,:], axis=0) |
|
|
105 |
# y_output,out_context = sess.run([tf.nn.softmax(op_to_restore),context], feed_dict={x: x_test_image, x_contextual: prediction,keep_rate: 1.0}) |
|
|
106 |
# prediction[0,:,:,:] = out_context |
|
|
107 |
# prob_maps.append(y_output[0,:,:,:]) |
|
|
108 |
# np.savez('WHS/Results/Predictions/segment/train_prob_maps_sag_{}'.format(i),prob_maps=prob_maps) |
|
|
109 |
#print("================ DONE WITH SAGITTAL PREDICTIONS! ==================") |
|
|
110 |
## |
|
|
111 |
#print("====================== LOAD CORONAL NETWORK: ===========================") |
|
|
112 |
## Doing predictions with the model |
|
|
113 |
#tf.reset_default_graph() |
|
|
114 |
# |
|
|
115 |
#new_saver = tf.train.import_meta_graph('WHS/Results/segmentation/model_cor/model.ckpt.meta') |
|
|
116 |
#prediction = np.zeros([1,256,256,9]) |
|
|
117 |
#with tf.Session() as sess: |
|
|
118 |
# new_saver.restore(sess, tf.train.latest_checkpoint('WHS/Results/segmentation/model_cor/')) |
|
|
119 |
# graph = tf.get_default_graph() |
|
|
120 |
# x = graph.get_tensor_by_name("x_train:0") |
|
|
121 |
# keep_rate = graph.get_tensor_by_name("Placeholder:0") |
|
|
122 |
# op_to_restore = graph.get_tensor_by_name("output/Softmax:0") |
|
|
123 |
# context = graph.get_tensor_by_name("concat_5:0") |
|
|
124 |
# x_contextual = graph.get_tensor_by_name("x_train_context:0") |
|
|
125 |
# for i in range(30,len(filelist_test)): |
|
|
126 |
# print('Processing test image', (i+1),'out of',(np.max(range(len(filelist_test)))+1)) |
|
|
127 |
# # Find renderings corresponding to the given name |
|
|
128 |
# prob_maps = [] |
|
|
129 |
# x_test = create_data(filelist_test[i],'cor') |
|
|
130 |
# for k in range(x_test.shape[0]): |
|
|
131 |
# x_test_image = np.expand_dims(x_test[k,:,:,:], axis=0) |
|
|
132 |
# y_output,out_context = sess.run([tf.nn.softmax(op_to_restore),context], feed_dict={x: x_test_image, x_contextual: prediction,keep_rate: 1.0}) |
|
|
133 |
# prediction[0,:,:,:] = out_context |
|
|
134 |
# prob_maps.append(y_output[0,:,:,:]) |
|
|
135 |
# np.savez('WHS/Results/Predictions/segment/train_prob_maps_cor_{}'.format(i),prob_maps=prob_maps) |
|
|
136 |
#print("================ DONE WITH CORONAL PREDICTONS! ==================") |
|
|
137 |
# |