a b/MI-DESS_IWTSE/Eval_OAI_MI.py
1
# ==============================================================================
2
# Copyright (C) 2023 Haresh Rengaraj Rajamohan, Tianyu Wang, Kevin Leung, 
3
# Gregory Chang, Kyunghyun Cho, Richard Kijowski & Cem M. Deniz 
4
#
5
# This file is part of OAI-MRI-TKR
6
#
7
# This program is free software: you can redistribute it and/or modify
8
# it under the terms of the GNU Affero General Public License as published
9
# by the Free Software Foundation, either version 3 of the License, or
10
# (at your option) any later version.
11
12
# This program is distributed in the hope that it will be useful,
13
# but WITHOUT ANY WARRANTY; without even the implied warranty of
14
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15
# GNU Affero General Public License for more details.
16
17
# You should have received a copy of the GNU Affero General Public License
18
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
19
# ==============================================================================
20
import numpy as np
21
import pandas as pd
22
import h5py
23
import nibabel as nib
24
import keras
25
from mpl_toolkits.mplot3d import Axes3D
26
import numpy as np
27
import matplotlib.pyplot as plt
28
import matplotlib.cm
29
import matplotlib.colorbar
30
import matplotlib.colors
31
import pandas as pd
32
import numpy as np
33
import os
34
35
36
from sklearn import metrics
37
38
39
from keras.models import load_model
40
41
from Augmentation import RandomCrop, CenterCrop, RandomFlip
42
43
44
45
from sklearn.metrics import roc_auc_score,auc,roc_curve,average_precision_score
46
47
import tensorflow as tf
48
49
50
from DataGenerator import DataGenerator as DG
51
52
class DataGenerator(keras.utils.Sequence):
    """Keras ``Sequence`` that yields paired IW-TSE / DESS knee volumes.

    Row ``k`` of the IW-TSE CSV (``directory1``) is paired with row ``k`` of
    the DESS CSV (``directory2``), so both CSVs must list the same knees in
    the same order. Each batch is ``([X1, X2], y)`` with

    * ``X1``: IW-TSE volumes, shape ``(batch_size, *dim1, n_channels)``;
    * ``X2``: DESS volumes cropped to ``dim2`` and centred in a zero canvas of
      shape ``(batch_size, *dim3, n_channels)`` (``dim3`` is fixed to
      ``(384, 384, 144)``);
    * ``y``: integer ``Label`` values read from the IW-TSE CSV.

    Args:
        directory1: path to the IW-TSE CSV (columns ``ID``, ``h5Name``, ``Label``).
        directory2: path to the DESS CSV (column ``h5Name``).
        file_folder1, file_folder2: HDF5 root folders; ``"00m/"`` is appended
            by plain string concatenation, so they should end with ``/``.
        batch_size: samples per batch; a trailing partial batch is dropped.
        dim1: IW-TSE crop size (the crop must produce exactly this shape).
        dim2: DESS crop size (must fit inside ``dim3``).
        n_channels: size of the channel axis; every channel gets the same volume.
        n_classes: unused here; kept for interface compatibility.
        shuffle: reshuffle the sample order at every epoch end.
        normalize: apply ``normalize_MRIs`` to each volume.
        randomCrop: crop with ``RandomCrop`` when True, ``CenterCrop`` otherwise.
        randomFlip: apply ``RandomFlip(...).horizontal_flip`` to each volume.
        flipProbability: probability forwarded to ``horizontal_flip``.
    """

    def __init__(self, directory1, directory2, file_folder1, file_folder2,
                 batch_size=6, dim1=(384, 384, 36), dim2=(352, 352, 144),
                 n_channels=1, n_classes=10, shuffle=True, normalize=True,
                 randomCrop=True, randomFlip=True, flipProbability=-1):
        'Initialization'
        self.dim1 = dim1
        self.dim2 = dim2
        # Canvas in which the dim2-sized DESS crop is centred before batching.
        self.dim3 = (384, 384, 144)
        self.batch_size = batch_size
        # Read the IW-TSE CSV once; ``dataset`` is kept as an alias so code
        # that still reads ``self.dataset`` keeps working.
        self.IWdataset = pd.read_csv(directory1)
        self.dataset = self.IWdataset
        self.DESSdataset = pd.read_csv(directory2)
        self.list_IDs = self.IWdataset['ID']
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.shuffle = shuffle
        self.on_epoch_end()
        self.file_folder1 = file_folder1 + "00m/"
        self.file_folder2 = file_folder2 + "00m/"
        self.normalize = normalize
        self.randomCrop = randomCrop
        self.randomFlip = randomFlip
        self.flipProbability = flipProbability

    def __len__(self):
        'Denotes the number of batches per epoch'
        return int(np.floor(len(self.list_IDs) / self.batch_size))

    def __getitem__(self, index):
        'Generate one batch of data'
        indexes = self.indexes[index * self.batch_size:(index + 1) * self.batch_size]
        X, y = self.__data_generation(indexes)
        return X, y

    def on_epoch_end(self):
        'Updates indexes after each epoch'
        self.indexes = np.arange(len(self.list_IDs))
        if self.shuffle == True:
            np.random.shuffle(self.indexes)

    @staticmethod
    def _load_volume(path):
        """Read the ``data`` dataset of an HDF5 file as float64, closing the file.

        ``dset[()]`` replaces the ``Dataset.value`` attribute, which was
        removed in h5py 3.0.
        """
        with h5py.File(path, "r") as h5:
            return h5['data/'][()].astype('float64')

    def __data_generation(self, indexes):
        'Generates data containing batch_size samples'  # X : (n_samples, *dim, n_channels)
        X1 = np.empty((self.batch_size, *self.dim1, self.n_channels))
        # zeros (not empty): the border around each centred DESS crop must be 0.
        X2 = np.zeros((self.batch_size, *self.dim3, self.n_channels))
        y = np.empty((self.batch_size), dtype=int)
        # Slices that centre a dim2-sized crop inside the dim3 canvas
        # (with the defaults: 16:368 in-plane and 0:144 along depth).
        dess_region = tuple(
            slice((outer - inner) // 2, (outer - inner) // 2 + inner)
            for outer, inner in zip(self.dim3, self.dim2)
        )
        for i, row in enumerate(indexes):
            filename1 = self.IWdataset.iloc[row]['h5Name']
            filename2 = self.DESSdataset.iloc[row]['h5Name']
            pre_image1 = self._load_volume(self.file_folder1 + filename1)
            pre_image2 = self._load_volume(self.file_folder2 + filename2)
            # Zero-pad volumes that have too few slices along the last axis.
            if pre_image1.shape[2] < 36:
                pre_image1 = padding_image(data=pre_image1)
            if pre_image2.shape[2] < 144:
                pre_image2 = padding_image2(data=pre_image2)
            if self.normalize:
                pre_image1 = normalize_MRIs(pre_image1)
                pre_image2 = normalize_MRIs(pre_image2)
            # Augmentation
            if self.randomFlip:
                pre_image1 = RandomFlip(image=pre_image1, p=0.5).horizontal_flip(p=self.flipProbability)
                pre_image2 = RandomFlip(image=pre_image2, p=0.5).horizontal_flip(p=self.flipProbability)
            if self.randomCrop:
                pre_image1 = RandomCrop(pre_image1).crop_along_hieght_width_depth(self.dim1)
                pre_image2 = RandomCrop(pre_image2).crop_along_hieght_width_depth(self.dim2)
            else:
                pre_image1 = CenterCrop(image=pre_image1).crop(size=self.dim1)
                pre_image2 = CenterCrop(image=pre_image2).crop(size=self.dim2)

            # [..., np.newaxis] broadcasts the volume into every channel.
            X1[i] = pre_image1[..., np.newaxis]
            X2[(i,) + dess_region] = pre_image2[..., np.newaxis]
            y[i] = self.IWdataset.iloc[row].Label

        return [X1, X2], y

    def getXvalue(self, index):
        """Return batch ``index`` (same as ``self[index]``)."""
        return self.__getitem__(index)
149
150
def padding_image(data, depth=36):
    """Zero-pad a 3-D volume along its last axis to ``depth`` slices.

    The input is centred in the output. When the amount of padding is odd,
    the extra slice goes in front (``zstart = ceil((depth - h) / 2)``).

    Args:
        data: array of shape ``(l, w, h)`` with ``h <= depth``.
        depth: target number of slices (default 36).

    Returns:
        float64 array of shape ``(l, w, depth)``.

    Raises:
        ValueError: if ``data`` is not 3-D or has more than ``depth`` slices.
    """
    l, w, h = data.shape
    if h > depth:
        raise ValueError("cannot pad {} slices to a depth of {}".format(h, depth))
    images = np.zeros((l, w, depth))
    zstart = int(np.ceil((depth - h) / 2))
    images[:, :, zstart:zstart + h] = data
    return images
156
157
def padding_image2(data, depth=144):
    """Zero-pad a 3-D volume along its last axis to ``depth`` slices.

    Same as ``padding_image`` but with a default target depth of 144. The
    input is centred; when the padding is odd the extra slice goes in front.

    Args:
        data: array of shape ``(l, w, h)`` with ``h <= depth``.
        depth: target number of slices (default 144).

    Returns:
        float64 array of shape ``(l, w, depth)``.

    Raises:
        ValueError: if ``data`` is not 3-D or has more than ``depth`` slices.
    """
    l, w, h = data.shape
    if h > depth:
        raise ValueError("cannot pad {} slices to a depth of {}".format(h, depth))
    images = np.zeros((l, w, depth))
    zstart = int(np.ceil((depth - h) / 2))
    images[:, :, zstart:zstart + h] = data
    return images
163
def normalize_MRIs(image):
    """Standardise ``image`` to zero mean and unit standard deviation.

    Floating-point arrays are normalised in place, and the same object is
    returned. Non-floating input is first converted to float64, because the
    in-place arithmetic would otherwise fail on integer dtypes. A constant
    image (std == 0) is returned mean-centred (all zeros) rather than being
    divided by zero, which would produce NaN/inf.

    Args:
        image: array-like intensity volume.

    Returns:
        The standardised array.
    """
    image = np.asarray(image)
    if not np.issubdtype(image.dtype, np.floating):
        image = image.astype(np.float64)
    mean = np.mean(image)
    std = np.std(image)
    image -= mean
    if std > 0:
        image /= std
    return image
171
172
# Command-line flags (TensorFlow 1.x ``tf.app.flags``). Folder flags are
# combined with file names by string concatenation in this script, so they
# should end with '/'.
tf.app.flags.DEFINE_string('model_path', '/gpfs/data/denizlab/Users/hrr288/Radiology_test/TCBmodelv1_400_add_final_arch/Dnetv1/add_ch32/',
                           'Folder with the trained models, laid out as Fold_<k>/CV_<j>/ with *weights* files (end with "/")')
tf.app.flags.DEFINE_string('val_csv_path', '/gpfs/data/denizlab/Users/hrr288/Tianyu_dat/TestSets/',
                           'Folder with the fold splits, Fold_<k>/CV_<j>_val.csv; used only with --vote (end with "/")')

tf.app.flags.DEFINE_string('test_csv_path1', '/gpfs/data/denizlab/Users/hrr288/data/OAI_SAG_TSE_test.csv',
                           'Path to the IW TSE test-set CSV file')
tf.app.flags.DEFINE_string('test_csv_path2', '/gpfs/data/denizlab/Users/hrr288/data/OAI_SAG_DESS_test.csv',
                           'Path to the DESS test-set CSV file (row order must match test_csv_path1)')

tf.app.flags.DEFINE_string('result_path', './', 'Folder to save the output CSV with predictions (end with "/")')
tf.app.flags.DEFINE_bool('vote', False, 'Choice to generate binary predictions for each model to compute final sensitivity/specificity')
tf.app.flags.DEFINE_string('file_folder1', '/gpfs/data/denizlab/Datasets/OAI/SAG_IW_TSE/',
                           'Root folder of the IW TSE HDF5 MRI volumes of the test set (end with "/")')
tf.app.flags.DEFINE_string('file_folder2', '/gpfs/data/denizlab/Datasets/OAI/SAG_3D_DESS/',
                           'Root folder of the DESS HDF5 MRI volumes of the test set (end with "/")')
tf.app.flags.DEFINE_string('IWdataset_csv', '/gpfs/data/denizlab/Datasets/OAI/SAG_IW_TSE/HDF5_00_cohort_2_prime.csv',
                           'Path to HDF5_00_cohort_2_prime.csv')

tf.app.flags.DEFINE_string('DESSdataset_csv', '/gpfs/data/denizlab/Datasets/OAI/SAG_3D_DESS/HDF5_00_SAG_3D_DESScohort_2_prime.csv',
                           'Path to HDF5_00_SAG_3D_DESScohort_2_prime.csv')

FLAGS = tf.app.flags.FLAGS
189
def _select_best_weights(base_path, n_folds=7, n_cvs=6):
    """Pick one weights file per (fold, CV) split under ``base_path``.

    In each ``Fold_<k>/CV_<j>/`` directory, files whose name contains
    ``'weights'`` are candidates. The third ``'-'``-separated field of each
    name is parsed as a float score, and the lowest-scoring file is kept
    (presumably the validation loss written by the checkpoint callback;
    confirm against the training script). Ties keep the first file in
    ``os.listdir`` order.

    Returns:
        dict mapping ``'fold_<k>'`` to a list of ``n_cvs`` paths relative to
        ``base_path``.

    Raises:
        FileNotFoundError: if a split directory contains no weights file.
    """
    models = {}
    for fold in range(1, n_folds + 1):
        chosen = []
        for cv in range(1, n_cvs + 1):
            rel_dir = 'Fold_' + str(fold) + '/CV_' + str(cv) + '/'
            scored = [
                (float(name.split('-')[2]), rel_dir + name)
                for name in os.listdir(os.path.join(base_path, rel_dir))
                if 'weights' in name
            ]
            if not scored:
                raise FileNotFoundError(
                    "no weights file found in {}".format(os.path.join(base_path, rel_dir)))
            chosen.append(min(scored, key=lambda item: item[0])[1])
        models['fold_' + str(fold)] = chosen
    return models


def _youden_threshold(labels, scores):
    """Return the ROC threshold that maximises Youden's J statistic (TPR - FPR)."""
    fpr, tpr, thresholds = metrics.roc_curve(labels, scores)
    return thresholds[int(np.argmax(tpr - fpr))]


def main(argv=None):
    """Ensemble-evaluate the cross-validation models on the OAI test set.

    For each of the 7 folds x 6 CV splits, the best checkpoint (see
    ``_select_best_weights``) predicts the paired IW-TSE/DESS test set. The
    predictions are averaged and written to ``OAI_results.csv`` in
    ``FLAGS.result_path``. With ``--vote``, each model's scores are first
    binarised at the Youden-optimal threshold of its own validation split.
    The resulting vote fraction is written to ``OAI_results_vote.csv``.
    """
    val_params = {'dim1': (384, 384, 36),
                  'dim2': (352, 352, 144),
                  'batch_size': 1,
                  'n_classes': 1,
                  'n_channels': 1,
                  'shuffle': False,
                  'normalize': True,
                  'randomCrop': False,
                  'randomFlip': False,
                  'flipProbability': -1,
                  }
    n_folds, n_cvs = 7, 6

    validation_generator = DataGenerator(directory1=FLAGS.test_csv_path1, directory2=FLAGS.test_csv_path2,
                                         file_folder1=FLAGS.file_folder1, file_folder2=FLAGS.file_folder2,
                                         **val_params)
    df = pd.read_csv(FLAGS.test_csv_path2, index_col=0)

    base_path = FLAGS.model_path
    models = _select_best_weights(base_path, n_folds, n_cvs)

    pred_arr = np.zeros(df.shape[0])
    n_models = 0
    for fold in range(1, n_folds + 1):
        for cv, weights in enumerate(models['fold_' + str(fold)], start=1):
            model = load_model(os.path.join(base_path, weights))
            s = np.squeeze(model.predict_generator(validation_generator))
            if FLAGS.vote:
                val_csv = os.path.join(FLAGS.val_csv_path, 'Fold_' + str(fold), 'CV_' + str(cv) + '_val.csv')
                test_df = pd.read_csv(val_csv)
                test_generator = DG(directory=val_csv, file_folder1=FLAGS.file_folder1,
                                    file_folder2=FLAGS.file_folder2, IWdataset_csv=FLAGS.IWdataset_csv,
                                    DESSdataset_csv=FLAGS.DESSdataset_csv, **val_params)
                # squeeze: Keras returns (n, 1); a DataFrame column needs 1-D data.
                test_df["Pred"] = np.squeeze(model.predict_generator(test_generator))
                opt_thresh = _youden_threshold(test_df["Label"], test_df["Pred"])
                pred_arr += (s >= opt_thresh)
            else:
                pred_arr += s
            n_models += 1
            # Release the model's graph before loading the next one; otherwise
            # every loaded model stays resident in the same Keras session.
            del model
            keras.backend.clear_session()

    # Average over the models actually evaluated (7 x 6 = 42 by default).
    pred_arr = pred_arr / n_models

    df["Preds"] = pred_arr
    out_name = "OAI_results_vote.csv" if FLAGS.vote else "OAI_results.csv"
    df.to_csv(os.path.join(FLAGS.result_path, out_name))
277
if __name__ == "__main__":
    # tf.app.run parses the command-line flags, then calls main(argv).
    # Passing main explicitly is equivalent to its default lookup of
    # __main__.main.
    tf.app.run(main=main)
279