|
a |
|
b/MI-DESS_IWTSE/Eval_OAI_MI.py |
|
|
1 |
# ============================================================================== |
|
|
2 |
# Copyright (C) 2023 Haresh Rengaraj Rajamohan, Tianyu Wang, Kevin Leung, |
|
|
3 |
# Gregory Chang, Kyunghyun Cho, Richard Kijowski & Cem M. Deniz |
|
|
4 |
# |
|
|
5 |
# This file is part of OAI-MRI-TKR |
|
|
6 |
# |
|
|
7 |
# This program is free software: you can redistribute it and/or modify |
|
|
8 |
# it under the terms of the GNU Affero General Public License as published |
|
|
9 |
# by the Free Software Foundation, either version 3 of the License, or |
|
|
10 |
# (at your option) any later version. |
|
|
11 |
|
|
|
12 |
# This program is distributed in the hope that it will be useful, |
|
|
13 |
# but WITHOUT ANY WARRANTY; without even the implied warranty of |
|
|
14 |
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the |
|
|
15 |
# GNU Affero General Public License for more details. |
|
|
16 |
|
|
|
17 |
# You should have received a copy of the GNU Affero General Public License |
|
|
18 |
# along with this program. If not, see <https://www.gnu.org/licenses/>. |
|
|
19 |
# ============================================================================== |
|
|
20 |
import numpy as np |
|
|
21 |
import pandas as pd |
|
|
22 |
import h5py |
|
|
23 |
import nibabel as nib |
|
|
24 |
import keras |
|
|
25 |
from mpl_toolkits.mplot3d import Axes3D |
|
|
26 |
import numpy as np |
|
|
27 |
import matplotlib.pyplot as plt |
|
|
28 |
import matplotlib.cm |
|
|
29 |
import matplotlib.colorbar |
|
|
30 |
import matplotlib.colors |
|
|
31 |
import pandas as pd |
|
|
32 |
import numpy as np |
|
|
33 |
import os |
|
|
34 |
|
|
|
35 |
|
|
|
36 |
from sklearn import metrics |
|
|
37 |
|
|
|
38 |
|
|
|
39 |
from keras.models import load_model |
|
|
40 |
|
|
|
41 |
from Augmentation import RandomCrop, CenterCrop, RandomFlip |
|
|
42 |
|
|
|
43 |
|
|
|
44 |
|
|
|
45 |
from sklearn.metrics import roc_auc_score,auc,roc_curve,average_precision_score |
|
|
46 |
|
|
|
47 |
import tensorflow as tf |
|
|
48 |
|
|
|
49 |
|
|
|
50 |
from DataGenerator import DataGenerator as DG |
|
|
51 |
|
|
|
52 |
class DataGenerator(keras.utils.Sequence):
    """Keras ``Sequence`` that yields paired IW-TSE / DESS volumes and labels.

    Two CSV files describe the samples. Row *k* of the first CSV (IW TSE) and
    row *k* of the second CSV (DESS) are treated as the same sample: both are
    accessed with the same positional index. Each row's ``h5Name`` column
    names an HDF5 file holding the volume under the ``data`` key; the label is
    read from the ``Label`` column of the first CSV.

    Each batch is ``([X1, X2], y)`` where

    * ``X1`` has shape ``(batch_size, *dim1, n_channels)``,
    * ``X2`` has shape ``(batch_size, 384, 384, 144, n_channels)``: the DESS
      volume is cropped to ``dim2`` and centred inside a zero canvas,
    * ``y`` is an integer vector of length ``batch_size``.

    Parameters
    ----------
    directory1, directory2 : str
        Paths to the IW-TSE and DESS CSV files (must contain ``h5Name``;
        the first one must also contain ``ID`` and ``Label``).
    file_folder1, file_folder2 : str
        Base folders of the HDF5 files; ``"00m/"`` is appended to each.
    batch_size : int
        Number of samples per batch.
    dim1, dim2 : tuple of int
        Crop sizes for the IW-TSE and DESS volumes respectively.
    n_channels : int
        Size of the trailing channel axis; the volume is broadcast over it.
    n_classes : int
        Stored on the instance; not used by the generator itself.
    shuffle : bool
        Shuffle the sample order at construction and after every epoch.
    normalize : bool
        Apply ``normalize_MRIs`` to each volume.
    randomCrop : bool
        Use ``RandomCrop`` when true, ``CenterCrop`` otherwise.
    randomFlip : bool
        Apply ``RandomFlip(...).horizontal_flip`` to each volume.
    flipProbability : float
        Forwarded as ``p`` to ``horizontal_flip``.
        NOTE(review): the meaning of the default ``-1`` is defined by the
        ``Augmentation`` module -- confirm there.
    """

    def __init__(self, directory1, directory2, file_folder1, file_folder2,
                 batch_size=6, dim1=(384, 384, 36), dim2=(352, 352, 144),
                 n_channels=1, n_classes=10, shuffle=True, normalize=True,
                 randomCrop=True, randomFlip=True, flipProbability=-1):
        """Load both sample tables and prepare the initial sample order."""
        self.dim1 = dim1
        self.dim2 = dim2
        # Canvas the cropped DESS volume is embedded into.
        self.dim3 = (384, 384, 144)
        self.batch_size = batch_size

        # Parse the IW-TSE table once; `dataset` stays an independent copy so
        # that mutating one frame never affects the other.
        self.IWdataset = pd.read_csv(directory1)
        self.dataset = self.IWdataset.copy()
        self.DESSdataset = pd.read_csv(directory2)
        self.list_IDs = self.IWdataset['ID']

        self.n_channels = n_channels
        self.n_classes = n_classes
        self.shuffle = shuffle
        self.file_folder1 = file_folder1 + "00m/"
        self.file_folder2 = file_folder2 + "00m/"
        self.normalize = normalize
        self.randomCrop = randomCrop
        self.randomFlip = randomFlip
        self.flipProbability = flipProbability

        # Builds self.indexes (needs list_IDs and shuffle to be set).
        self.on_epoch_end()

    def __len__(self):
        """Return the number of full batches per epoch (remainder dropped)."""
        return len(self.list_IDs) // self.batch_size

    def __getitem__(self, index):
        """Return batch number ``index`` as ``([X1, X2], y)``."""
        first = index * self.batch_size
        batch_rows = self.indexes[first:first + self.batch_size]
        return self.__data_generation(batch_rows)

    def on_epoch_end(self):
        """Reset the sample order, shuffling it when ``shuffle`` is set."""
        self.indexes = np.arange(len(self.list_IDs))
        if self.shuffle:
            np.random.shuffle(self.indexes)

    @staticmethod
    def _load_volume(path):
        """Read the ``data`` dataset of an HDF5 file as a float64 array.

        The file is opened in a ``with`` block so the handle is always closed,
        and ``[()]`` is used instead of the ``.value`` accessor that h5py 3.0
        removed.
        """
        with h5py.File(path, "r") as h5_file:
            return h5_file['data/'][()].astype('float64')

    def __data_generation(self, indexes):
        """Build one batch from the positional row numbers in ``indexes``."""
        X1 = np.empty((self.batch_size, *self.dim1, self.n_channels))
        # Zero-initialised: only the central dim2 window is written below.
        X2 = np.zeros((self.batch_size, *self.dim3, self.n_channels))
        y = np.empty((self.batch_size), dtype=int)

        # Centred placement of the dim2 crop inside the dim3 canvas
        # (16:368, 16:368, 0:144 for the default sizes).
        window = tuple(
            slice((full - crop) // 2, (full - crop) // 2 + crop)
            for full, crop in zip(self.dim3, self.dim2)
        )

        for i, row in enumerate(indexes):
            iw_row = self.IWdataset.iloc[row]
            dess_row = self.DESSdataset.iloc[row]
            pre_image1 = self._load_volume(self.file_folder1 + iw_row['h5Name'])
            pre_image2 = self._load_volume(self.file_folder2 + dess_row['h5Name'])

            # Volumes with too few slices are zero-padded along the last axis.
            if pre_image1.shape[2] < 36:
                pre_image1 = padding_image(data=pre_image1)
            if pre_image2.shape[2] < 144:
                pre_image2 = padding_image2(data=pre_image2)

            if self.normalize:
                pre_image1 = normalize_MRIs(pre_image1)
                pre_image2 = normalize_MRIs(pre_image2)

            # Augmentation. Each modality gets its own RandomFlip instance.
            if self.randomFlip:
                pre_image1 = RandomFlip(image=pre_image1, p=0.5).horizontal_flip(p=self.flipProbability)
                pre_image2 = RandomFlip(image=pre_image2, p=0.5).horizontal_flip(p=self.flipProbability)
            if self.randomCrop:
                pre_image1 = RandomCrop(pre_image1).crop_along_hieght_width_depth(self.dim1)
                pre_image2 = RandomCrop(pre_image2).crop_along_hieght_width_depth(self.dim2)
            else:
                pre_image1 = CenterCrop(image=pre_image1).crop(size=self.dim1)
                pre_image2 = CenterCrop(image=pre_image2).crop(size=self.dim2)

            # The new trailing axis broadcasts the volume over all channels.
            X1[i] = pre_image1[..., np.newaxis]
            X2[(i, *window)] = pre_image2[..., np.newaxis]

            y[i] = iw_row.Label

        return [X1, X2], y

    def getXvalue(self, index):
        """Alias of ``__getitem__`` kept for existing callers."""
        return self.__getitem__(index)
|
|
149 |
|
|
|
150 |
def padding_image(data, depth=36):
    """Zero-pad a 3-D volume along its last axis to ``depth`` slices.

    The input is centred in the output; when the amount of padding is odd,
    the extra zero slice goes in front of the data.

    Parameters
    ----------
    data : numpy.ndarray
        3-D array of shape ``(rows, cols, slices)``.
    depth : int, optional
        Number of slices in the result (default 36).

    Returns
    -------
    numpy.ndarray
        New float64 array of shape ``(rows, cols, depth)``.

    Raises
    ------
    ValueError
        If ``data`` is not 3-D or already has more than ``depth`` slices.
    """
    rows, cols, slices = data.shape
    if slices > depth:
        raise ValueError(
            "cannot pad a volume with %d slices to a depth of %d" % (slices, depth)
        )
    padded = np.zeros((rows, cols, depth))
    start = int(np.ceil((depth - slices) / 2))
    padded[:, :, start:start + slices] = data
    return padded
|
|
156 |
|
|
|
157 |
def padding_image2(data, depth=144):
    """Zero-pad a 3-D volume along its last axis to ``depth`` slices.

    The input is centred in the output; when the amount of padding is odd,
    the extra zero slice goes in front of the data.

    Parameters
    ----------
    data : numpy.ndarray
        3-D array of shape ``(rows, cols, slices)``.
    depth : int, optional
        Number of slices in the result (default 144).

    Returns
    -------
    numpy.ndarray
        New float64 array of shape ``(rows, cols, depth)``.

    Raises
    ------
    ValueError
        If ``data`` is not 3-D or already has more than ``depth`` slices.
    """
    rows, cols, slices = data.shape
    if slices > depth:
        raise ValueError(
            "cannot pad a volume with %d slices to a depth of %d" % (slices, depth)
        )
    padded = np.zeros((rows, cols, depth))
    start = int(np.ceil((depth - slices) / 2))
    padded[:, :, start:start + slices] = data
    return padded
|
|
163 |
def normalize_MRIs(image):
    """Standardise a volume in place to zero mean and unit variance.

    The mean and standard deviation are computed over the whole array. The
    array is modified in place and also returned, so it must have a floating
    point dtype.

    A constant volume has zero standard deviation; dividing would fill it
    with NaNs, so in that case the volume is only mean-centred (all zeros).

    Parameters
    ----------
    image : numpy.ndarray
        Floating point array of any shape.

    Returns
    -------
    numpy.ndarray
        The same array object, normalised.
    """
    mean = np.mean(image)
    std = np.std(image)
    image -= mean
    if std > 0:
        image /= std
    return image
|
|
171 |
|
|
|
172 |
# ---------------------------------------------------------------------------
# Command-line flags (TF1 ``tf.app.flags``). Each call registers one flag as
# (name, default value, help text).
# ---------------------------------------------------------------------------

# Trained models and the per-fold validation splits.
tf.app.flags.DEFINE_string(
    'model_path',
    '/gpfs/data/denizlab/Users/hrr288/Radiology_test/TCBmodelv1_400_add_final_arch/Dnetv1/add_ch32/',
    'Folder with the models',
)
tf.app.flags.DEFINE_string(
    'val_csv_path',
    '/gpfs/data/denizlab/Users/hrr288/Tianyu_dat/TestSets/',
    'Folder with the fold splits',
)

# Test-set sample tables, one per MRI sequence.
tf.app.flags.DEFINE_string(
    'test_csv_path1',
    '/gpfs/data/denizlab/Users/hrr288/data/OAI_SAG_TSE_test.csv',
    'Folder with IW TSE test csv',
)
tf.app.flags.DEFINE_string(
    'test_csv_path2',
    '/gpfs/data/denizlab/Users/hrr288/data/OAI_SAG_DESS_test.csv',
    'Folder with DESS test csv',
)

# Output location and ensembling mode.
tf.app.flags.DEFINE_string(
    'result_path',
    './',
    'Folder to save output csv with preds',
)
tf.app.flags.DEFINE_bool(
    'vote',
    False,
    'Choice to generate binary predictions for each model to compute final sensitivity/specificity',
)

# HDF5 image folders and cohort tables, one per MRI sequence.
tf.app.flags.DEFINE_string(
    'file_folder1',
    '/gpfs/data/denizlab/Datasets/OAI/SAG_IW_TSE/',
    'Path to IW TSE HDF5 radiographs of test set',
)
tf.app.flags.DEFINE_string(
    'file_folder2',
    '/gpfs/data/denizlab/Datasets/OAI/SAG_3D_DESS/',
    'Path to DESS HDF5 radiographs of test set',
)
tf.app.flags.DEFINE_string(
    'IWdataset_csv',
    '/gpfs/data/denizlab/Datasets/OAI/SAG_IW_TSE/HDF5_00_cohort_2_prime.csv',
    'Path to HDF5_00_cohort_2_prime.csv',
)
tf.app.flags.DEFINE_string(
    'DESSdataset_csv',
    '/gpfs/data/denizlab/Datasets/OAI/SAG_3D_DESS/HDF5_00_SAG_3D_DESScohort_2_prime.csv',
    'Path to HDF5_00_SAG_3D_DESScohort_2_prime.csv',
)

# Handle through which the parsed flag values are read.
FLAGS = tf.app.flags.FLAGS
|
|
189 |
def _best_checkpoint(model_dir):
    """Return the path of the lowest-scoring ``weights`` file in ``model_dir``.

    Only file names containing ``'weights'`` are considered. The score is the
    third ``'-'``-separated field of the file name parsed as a float, and the
    smallest score wins (first one on ties).
    NOTE(review): presumably that field is the validation loss written by the
    training checkpoint callback -- confirm against the training script.

    Raises
    ------
    ValueError
        If the directory contains no ``weights`` file.
    """
    candidates = [fname for fname in os.listdir(model_dir) if 'weights' in fname]
    if not candidates:
        raise ValueError("no 'weights' checkpoint found in %s" % model_dir)
    best = min(candidates, key=lambda fname: float(fname.split('-')[2]))
    return os.path.join(model_dir, best)


def _youden_threshold(labels, scores):
    """Return the ROC threshold that maximises ``tpr - fpr``."""
    fpr, tpr, thresholds = metrics.roc_curve(labels, scores)
    return thresholds[int(np.argmax(tpr - fpr))]


def main(argv=None):
    """Evaluate the Fold/CV model ensemble on the test set and save a CSV.

    For every ``Fold_<f>/CV_<c>`` directory under ``FLAGS.model_path``
    (7 folds x 6 CV splits) the best checkpoint is selected and used to
    predict the test set. The per-model outputs are averaged and written as
    the ``Preds`` column of the DESS test table.

    With ``--vote`` each model's scores are first binarised at the threshold
    that maximises ``tpr - fpr`` on that model's own validation split, so the
    saved value is the fraction of models voting positive.

    Parameters
    ----------
    argv : list of str, optional
        Accepted so the function can be called by ``tf.app.run``; unused.
    """
    n_folds = 7
    n_cv = 6

    val_params = {
        'dim1': (384, 384, 36),
        'dim2': (352, 352, 144),
        'batch_size': 1,
        'n_classes': 1,
        'n_channels': 1,
        'shuffle': False,
        'normalize': True,
        'randomCrop': False,
        'randomFlip': False,
        'flipProbability': -1,
    }

    validation_generator = DataGenerator(
        directory1=FLAGS.test_csv_path1,
        directory2=FLAGS.test_csv_path2,
        file_folder1=FLAGS.file_folder1,
        file_folder2=FLAGS.file_folder2,
        **val_params
    )
    df = pd.read_csv(FLAGS.test_csv_path2, index_col=0)

    # Resolve every checkpoint up front so a missing directory fails before
    # any (slow) model is loaded.
    checkpoints = [
        (fold, cv, _best_checkpoint(
            os.path.join(FLAGS.model_path, 'Fold_%d' % fold, 'CV_%d' % cv)))
        for fold in range(1, n_folds + 1)
        for cv in range(1, n_cv + 1)
    ]

    pred_arr = np.zeros(df.shape[0])
    for fold, cv, checkpoint in checkpoints:
        model = load_model(checkpoint)
        test_scores = np.squeeze(model.predict_generator(validation_generator))

        if FLAGS.vote:
            val_csv = os.path.join(
                FLAGS.val_csv_path, 'Fold_%d' % fold, 'CV_%d_val.csv' % cv)
            val_df = pd.read_csv(val_csv)
            val_generator = DG(
                directory=val_csv,
                file_folder1=FLAGS.file_folder1,
                file_folder2=FLAGS.file_folder2,
                IWdataset_csv=FLAGS.IWdataset_csv,
                DESSdataset_csv=FLAGS.DESSdataset_csv,
                **val_params
            )
            val_df["Pred"] = np.ravel(model.predict_generator(val_generator))
            opt_thresh = _youden_threshold(val_df["Label"], val_df["Pred"])
            pred_arr += (test_scores >= opt_thresh)
        else:
            pred_arr += test_scores

        # Release the graph/session of this model before loading the next
        # one; otherwise memory grows with every load_model call.
        del model
        keras.backend.clear_session()

    # Average over the models actually evaluated (7 * 6 = 42).
    pred_arr = pred_arr / len(checkpoints)

    df["Preds"] = pred_arr
    out_name = "OAI_results_vote.csv" if FLAGS.vote else "OAI_results.csv"
    df.to_csv(os.path.join(FLAGS.result_path, out_name))
|
|
277 |
# Script entry point. Per the TF1 API, tf.app.run() parses the command-line
# flags registered in this module and then calls main().
if __name__ == "__main__":
    tf.app.run()
|
|
279 |
|