Diff of /inference.py [000000] .. [1b6491]

# ==============================================================================
# Copyright (C) 2020 Vladimir Juras, Ravinder Regatte and Cem M. Deniz
#
# This file is part of 2019_IWOAI_Challenge
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
# ==============================================================================
import tensorflow as tf
import tf_utilities as tfut
import tf_layers as tflay
import models
import nibabel as nib
import numpy as np
import re
import time
import os
from functools import partial
from pathlib import Path
import h5py
from scipy import ndimage
from scipy.spatial import distance

import glob
from sklearn.model_selection import StratifiedKFold
from sys import platform
from sklearn.preprocessing import label_binarize

tf.app.flags.DEFINE_string('model_path', "./InferenceModel", 'Folder containing the trained model to restore.')
tf.app.flags.DEFINE_string('data_folder', './data', 'Data folder.')
tf.app.flags.DEFINE_integer('cv', -1, 'Which fold to run.')
tf.app.flags.DEFINE_integer('feature', 16, 'Number of feature maps in the first layer (features_root).')
tf.app.flags.DEFINE_string('model', '4atrous248', 'Model name.')
tf.app.flags.DEFINE_integer('reso', 384, 'Image size.')
tf.app.flags.DEFINE_integer('slices', 160, 'Number of slices.')
tf.app.flags.DEFINE_integer('seed', 1234, 'Graph-level random seed.')
tf.app.flags.DEFINE_boolean('resnet', False, 'Whether to use resnet shortcut.')
tf.app.flags.DEFINE_integer('noImages', -1, 'How many images to infer (-1 = all).')
FLAGS = tf.app.flags.FLAGS

num_classes = 7
num_CV = 1
num_channels = 1


def main(argv=None):

    print('OUT:: ', FLAGS.feature, FLAGS.seed, FLAGS.resnet, FLAGS.model, FLAGS.reso)

    tf.set_random_seed(FLAGS.seed)
    np.random.seed(FLAGS.seed)

    batch_x = tf.placeholder(tf.float32, shape=(None, FLAGS.reso, FLAGS.reso, FLAGS.slices, 1))
    batch_y = tf.placeholder(tf.float32, shape=(None, FLAGS.reso, FLAGS.reso, FLAGS.slices, num_classes))

    keep_prob = tf.placeholder(tf.float32, shape=[])
    class_weights = tf.placeholder(tf.float32, shape=(num_classes))

    # choose the model
    inference_raw = {'4unet': partial(models.inference_unet4),  # the original architecture with 4 layers
                     # '4atrous248' replaces the convolutions between the down-convolution and
                     # up-convolution layers with atrous (dilated) convolutions
                     '4atrous248': partial(models.inference_atrous4, n_class=num_classes, dilation_rates=[2,4,8])}[FLAGS.model]

    inference = partial(inference_raw, resnet=FLAGS.resnet)

    score = inference(batch_x, features_root=FLAGS.feature, keep_prob=keep_prob, n_class=num_classes)
    logits = tf.nn.softmax(score)
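    # note: despite its name, `logits` holds the softmax output, i.e. per-voxel
    # class probabilities over the num_classes channels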

    # load dataset from folder
    dataFolder = FLAGS.data_folder + '/test'
    pathNifti = Path(dataFolder)

    X = sorted(pathNifti.glob('**/*.im'))

    if FLAGS.noImages == -1:
        noOfFiles = len(X)
    else:
        noOfFiles = FLAGS.noImages
    list_X = X[:noOfFiles]
    n_samples = len(list_X)
    X_test, train_info = tfut.loadData_list_h5_image(list_X, num_channels)

    X_test = tfut.zeroMeanUnitVariance(X_test)
    X_test = X_test[...,np.newaxis]
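    # after zero-mean/unit-variance normalisation and adding the trailing channel axis,
    # X_test is expected to match the batch_x placeholder shape (None, reso, reso, slices, 1)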
    cv = 1

    sample_size = X_test.shape[0]
    output_path = FLAGS.model_path

    # find the model checkpoint (.meta file) to read
    fdr = Path(output_path)
    cpktFile = sorted(fdr.glob('**/*.meta'))
    read_file = str(cpktFile[-1])
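    # Saver.restore expects the checkpoint prefix, i.e. the .meta path with its
    # '.meta' suffix stripped; that is what read_file[:-5] does below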

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver(max_to_keep=0)

        saver.restore(sess, read_file[:-5])
        print('Model restored from file: %s' % read_file[:-5])

        start = time.time()
        y_out = np.zeros((FLAGS.reso, FLAGS.reso, FLAGS.slices, num_classes))
        for xi in range(sample_size):
            prob = sess.run(logits,
                                feed_dict={batch_x: X_test[xi:xi+1],
                                            keep_prob:1})
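            # keep_prob is fed as 1 so any dropout in the network is effectively disabled at inference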
            y_out = prob

            winOut = np.zeros(y_out.shape)
            winOut[y_out[...,1]>0.5,...,1] = 1
            winOut[y_out[...,2]>0.5,...,2] = 2
            winOut[y_out[...,3]>0.5,...,3] = 3
            winOut[y_out[...,4]>0.5,...,4] = 4
            winOut[y_out[...,5]>0.5,...,5] = 5
            winOut[y_out[...,6]>0.1,...,6] = 6
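            # each foreground channel is filled with its own integer label (1-6) so that the
            # np.sum over the class axis below collapses the channels into a single label map
            # (assuming a voxel exceeds the threshold for at most one class);
            # note that class 6 is thresholded at 0.1 rather than 0.5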

            # keep only the largest connected volume for each class
            if 1:
                for iii in range(1,7):
                    inn = winOut[...,iii]
                    all_labels, num_features = ndimage.label(inn)
                    print('Label #:', iii, num_features, 'Number of Connected Volumes')
                    if num_features > 1:
                        volume = ndimage.sum(inn, all_labels, index=range(num_features+1))
                        print("Volume:", volume)
                        cem = all_labels == np.argmax(volume)
                        winOut[...,iii] = winOut[...,iii] * cem

            winOut = np.sum(winOut, axis=4)
            winOut = winOut.astype(int)
            seg = np.pad(np.squeeze(winOut), ((1,1), (1,1), (1,1)), 'edge')
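            # the one-voxel 'edge' padding presumably restores the border that the network
            # output loses relative to the 384x384x160 input, so that seg matches the
            # full-resolution grid indexed below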

            # For the class dimension of the saved array, the order of the 4 classes is:
            # 0 = femoral cartilage, 1 = tibial cartilage, 2 = patellar cartilage, and 3 = meniscus.
            # save as a numpy array
            saveNumpy = np.zeros((384,384,160,4))
            saveNumpy[seg==1,...,0] = 1
            saveNumpy[seg==2,...,1] = 1
            saveNumpy[seg==3,...,1] = 1
            saveNumpy[seg==4,...,2] = 1
            saveNumpy[seg==5,...,3] = 1
            saveNumpy[seg==6,...,3] = 1
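            # labels 2-3 are merged into the tibial-cartilage channel and labels 5-6 into the
            # meniscus channel, collapsing the 7-class prediction into the 4-channel output format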
            savename = str(X[xi])
            fdr = Path('./InferenceResults/%s.npy' % (savename[-15:-3]))

            # ensure the output folder exists, then save the 4-channel segmentation
            os.makedirs('./InferenceResults', exist_ok=True)
            np.save(fdr, saveNumpy.astype(int), allow_pickle=False)


if __name__ == '__main__':
    tf.app.run()
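
# Example invocation (the values shown are the flag defaults defined above; adjust paths as needed):
#   python inference.py --model_path ./InferenceModel --data_folder ./data --model 4atrous248 --feature 16 --reso 384 --slices 160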