#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed Nov 14 21:47:22 2018

@author: Josefine
"""

import tensorflow as tf
import numpy as np
import glob
import re
from skimage.transform import resize

imgDim = 256
labelDim = 256

##############################################################################
###                              Data functions                         ######
##############################################################################
def natural_sort(l):
    """Return *l* sorted so embedded integers compare numerically.

    E.g. 'file2' sorts before 'file10' (plain lexicographic sort would not).
    Non-digit chunks compare case-insensitively.
    """
    def sort_key(name):
        chunks = re.split(r'([0-9]+)', name)
        return [int(chunk) if chunk.isdigit() else chunk.lower()
                for chunk in chunks]
    return sorted(l, key=sort_key)
def create_data(filename_img, direction):
    """Load a .npz volume and slice it into a stack of 2D images.

    Parameters
    ----------
    filename_img : str
        Path to a .npz file containing an 'images' entry (a 3D volume).
    direction : str
        'axial', 'sag' or 'cor' -- slice along axis 0, 1 or 2 respectively.

    Returns
    -------
    np.ndarray
        Array of shape (labelDim, imgDim, imgDim, 1).

    Raises
    ------
    ValueError
        If *direction* is not one of the three recognised values.
        (The original silently returned an empty array in that case.)
    """
    volume = np.load(filename_img)['images']
    # NOTE(review): intensity normalization is disabled -- slices go to the
    # network in raw units. Confirm this matches how the model was trained.
    # vol = np.clip(volume, -1000, 1000)
    # vol = np.interp(vol, (vol.min(), vol.max()), (-1, +1))
    # Nearest-neighbour (order=0) resample to a labelDim^3 cube.
    cube = resize(volume, (labelDim, labelDim, labelDim), order=0)
    axis = {'axial': 0, 'sag': 1, 'cor': 2}.get(direction)
    if axis is None:
        raise ValueError(
            "direction must be 'axial', 'sag' or 'cor', got %r" % (direction,))
    # Moving the chosen axis to the front stacks slices in exactly the same
    # order as the original per-axis append loops (cube[i,:,:], cube[:,i,:]
    # or cube[:,:,i]); reshape then adds the trailing channel dimension.
    slices = np.moveaxis(cube, axis, 0)
    return slices.reshape(-1, imgDim, imgDim, 1)
# Load test data
filelist_test = natural_sort(glob.glob('WHS/Data/test_segments_*.npz'))  # list of file names

#############################################################################
##                  Reload network and predict                         ######
#############################################################################
#
## =============================================================================
print("====================== LOAD AXIAL NETWORK: ===========================")
# Doing predictions with the model
tf.reset_default_graph()

new_saver = tf.train.import_meta_graph('WHS/Results/segmentation/model_axial/model.ckpt.meta')

# Rolling one-slice context buffer: the previous slice's 'concat_5' output is
# fed back in as x_train_context for the next slice.
prediction = np.zeros([1, 256, 256, 9])
with tf.Session() as sess:
    new_saver.restore(sess, tf.train.latest_checkpoint('WHS/Results/segmentation/model_axial/'))
    graph = tf.get_default_graph()
    x = graph.get_tensor_by_name("x_train:0")
    op_to_restore = graph.get_tensor_by_name("output/Softmax:0")
    keep_rate = graph.get_tensor_by_name("Placeholder:0")
    context = graph.get_tensor_by_name("concat_5:0")
    x_contextual = graph.get_tensor_by_name("x_train_context:0")
    # BUG FIX: the original ran tf.nn.softmax(op_to_restore) inside the slice
    # loop. That (a) applied softmax twice -- the tensor is already the
    # "output/Softmax:0" op -- flattening the probability maps, and (b) built
    # a new graph node on every iteration, growing the graph without bound.
    # Running op_to_restore directly fixes both.
    # NOTE(review): starting at index 30 looks like a manual resume point
    # after a partial run -- confirm this is intentional.
    n_files = len(filelist_test)
    for i in range(30, n_files):
        print('Processing test image', i + 1, 'out of', n_files)
        # Collect one probability map per axial slice of this volume.
        prob_maps = []
        x_test = create_data(filelist_test[i], 'axial')
        for k in range(x_test.shape[0]):
            x_test_image = np.expand_dims(x_test[k, :, :, :], axis=0)
            y_output, out_context = sess.run(
                [op_to_restore, context],
                feed_dict={x: x_test_image, x_contextual: prediction, keep_rate: 1.0})
            prediction[0, :, :, :] = out_context
            prob_maps.append(y_output[0, :, :, :])
        np.savez('WHS/Results/Predictions/segment/train_prob_maps_axial_{}'.format(i),
                 prob_maps=prob_maps)
print("================ DONE WITH AXIAL PREDICTIONS! ==================")
#
# =============================================================================
#print("====================== LOAD SAGITTAL NETWORK: ===========================")
## Doing predictions with the model 
#tf.reset_default_graph()      
#
#new_saver = tf.train.import_meta_graph('WHS/Results/segmentation/model_sag/model.ckpt.meta')
#prediction = np.zeros([1,256,256,9])
#with tf.Session() as sess:
#    new_saver.restore(sess, tf.train.latest_checkpoint('WHS/Results/segmentation/model_sag/'))
#    graph = tf.get_default_graph()       
#    x = graph.get_tensor_by_name("x_train:0")
#    keep_rate = graph.get_tensor_by_name("Placeholder:0")
#    op_to_restore = graph.get_tensor_by_name("output/Softmax:0")
#    context = graph.get_tensor_by_name("concat_5:0")
#    x_contextual = graph.get_tensor_by_name("x_train_context:0")
#    for i in range(30,len(filelist_test)):
#        print('Processing test image', (i+1),'out of',(np.max(range(len(filelist_test)))+1))
#        # Find renderings corresponding to the given name
#        prob_maps = []
#        x_test = create_data(filelist_test[i],'sag')
#        for k in range(x_test.shape[0]):
#            x_test_image = np.expand_dims(x_test[k,:,:,:], axis=0)
#            y_output,out_context = sess.run([tf.nn.softmax(op_to_restore),context], feed_dict={x: x_test_image, x_contextual: prediction,keep_rate: 1.0})
#            prediction[0,:,:,:] = out_context
#            prob_maps.append(y_output[0,:,:,:])
#        np.savez('WHS/Results/Predictions/segment/train_prob_maps_sag_{}'.format(i),prob_maps=prob_maps)                            
#print("================ DONE WITH SAGITTAL PREDICTIONS! ==================")  
##
#print("====================== LOAD CORONAL NETWORK: ===========================")
## Doing predictions with the model 
#tf.reset_default_graph()      
#
#new_saver = tf.train.import_meta_graph('WHS/Results/segmentation/model_cor/model.ckpt.meta')
#prediction = np.zeros([1,256,256,9])
#with tf.Session() as sess:
#    new_saver.restore(sess, tf.train.latest_checkpoint('WHS/Results/segmentation/model_cor/'))
#    graph = tf.get_default_graph()       
#    x = graph.get_tensor_by_name("x_train:0")
#    keep_rate = graph.get_tensor_by_name("Placeholder:0")
#    op_to_restore = graph.get_tensor_by_name("output/Softmax:0")
#    context = graph.get_tensor_by_name("concat_5:0")
#    x_contextual = graph.get_tensor_by_name("x_train_context:0")
#    for i in range(30,len(filelist_test)):
#        print('Processing test image', (i+1),'out of',(np.max(range(len(filelist_test)))+1))
#        # Find renderings corresponding to the given name
#        prob_maps = []
#        x_test = create_data(filelist_test[i],'cor')
#        for k in range(x_test.shape[0]):
#            x_test_image = np.expand_dims(x_test[k,:,:,:], axis=0)
#            y_output,out_context = sess.run([tf.nn.softmax(op_to_restore),context], feed_dict={x: x_test_image, x_contextual: prediction,keep_rate: 1.0})
#            prediction[0,:,:,:] = out_context
#            prob_maps.append(y_output[0,:,:,:])
#        np.savez('WHS/Results/Predictions/segment/train_prob_maps_cor_{}'.format(i),prob_maps=prob_maps)                            
#print("================ DONE WITH CORONAL PREDICTONS! ==================")  
#