Switch to side-by-side view

--- a
+++ b/MI-DESS_IWTSE/Eval_OAI_MI.py
@@ -0,0 +1,279 @@
+# ==============================================================================
+# Copyright (C) 2023 Haresh Rengaraj Rajamohan, Tianyu Wang, Kevin Leung, 
+# Gregory Chang, Kyunghyun Cho, Richard Kijowski & Cem M. Deniz 
+#
+# This file is part of OAI-MRI-TKR
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published
+# by the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program.  If not, see <https://www.gnu.org/licenses/>.
+# ==============================================================================
+import numpy as np
+import pandas as pd
+import h5py
+import nibabel as nib
+import keras
+from mpl_toolkits.mplot3d import Axes3D
+import numpy as np
+import matplotlib.pyplot as plt
+import matplotlib.cm
+import matplotlib.colorbar
+import matplotlib.colors
+import pandas as pd
+import numpy as np
+import os
+
+
+from sklearn import metrics
+
+
+from keras.models import load_model
+
+from Augmentation import RandomCrop, CenterCrop, RandomFlip
+
+
+
+from sklearn.metrics import roc_auc_score,auc,roc_curve,average_precision_score
+
+import tensorflow as tf
+
+
+from DataGenerator import DataGenerator as DG
+
+class DataGenerator(keras.utils.Sequence):
+    'Generates data for Keras'
+    def __init__(self, directory1,directory2,file_folder1,file_folder2, batch_size=6, dim1=(384,384,36), dim2= (352,352,144), n_channels=1, n_classes=10, shuffle=True,normalize = True, randomCrop = True, randomFlip = True,flipProbability = -1):
+        'Initialization'
+        self.dim1 = dim1
+        self.dim2 = dim2
+        self.dim3 = (384,384,144)
+        self.batch_size = batch_size
+        self.dataset = pd.read_csv(directory1)
+        self.IWdataset = pd.read_csv(directory1)
+        self.DESSdataset = pd.read_csv(directory2)
+        #self.list_IDs = list_IDs
+        self.list_IDs = pd.read_csv(directory1)['ID']
+        self.n_channels = n_channels
+        self.n_classes = n_classes
+        self.shuffle = shuffle
+        self.on_epoch_end()
+        self.file_folder1 = file_folder1+"00m/"
+        self.file_folder2 = file_folder2+"00m/"
+        self.normalize = normalize
+        self.randomCrop = randomCrop
+        self.randomFlip = randomFlip
+        self.flipProbability = flipProbability
+
+    def __len__(self):
+        'Denotes the number of batches per epoch'
+        return int(np.floor(len(self.list_IDs) / self.batch_size))
+
+    def __getitem__(self, index):
+        'Generate one batch of data'
+        # Generate indexes of the batch
+        indexes = self.indexes[index*self.batch_size:(index+1)*self.batch_size]
+        # Find list of IDs
+        list_IDs_temp = [self.list_IDs[k] for k in indexes]
+        # Generate data
+        X, y = self.__data_generation(indexes)
+        return X, y
+
+    def on_epoch_end(self):
+        'Updates indexes after each epoch'
+        self.indexes = np.arange(len(self.list_IDs))
+        if self.shuffle == True:
+            np.random.shuffle(self.indexes)
+    
+    def __data_generation(self, indexes):
+        'Generates data containing batch_size samples' # X : (n_samples, *dim, n_channels)
+        # Initialization
+        X1 = np.empty((self.batch_size, *self.dim1, self.n_channels))
+        X2 = np.empty((self.batch_size, *self.dim3, self.n_channels))
+        #X2 = np.empty((self.batch_size, 6))
+        y = np.empty((self.batch_size), dtype=int)
+        for i in range(len(indexes)):
+            # Store sample
+            #print(i,ID)
+            filename1 = self.IWdataset.iloc[indexes[i]]['h5Name']
+            filename2 = self.DESSdataset.iloc[indexes[i]]['h5Name'] 
+            pre_image1 = h5py.File(self.file_folder1 + filename1, "r")['data/'].value.astype('float64')
+            pre_image2 = h5py.File(self.file_folder2 + filename2, "r")['data/'].value.astype('float64')
+            #pre_image = padding_image(data = image,shape = [448,448,48])
+            #pre_image = np.zeros(image.shape)
+            #pre_image = image
+            if pre_image1.shape[2] < 36:
+                pre_image1 = padding_image(data = pre_image1)
+            if pre_image2.shape[2] < 144:
+                pre_image2 = padding_image2(data = pre_image2)
+            # normalize
+            if self.normalize:
+                pre_image1 = normalize_MRIs(pre_image1)
+                pre_image2 = normalize_MRIs(pre_image2)
+            # Augmentation
+            if self.randomFlip:
+                pre_image1 = RandomFlip(image=pre_image1,p=0.5).horizontal_flip(p=self.flipProbability)
+                pre_image2 = RandomFlip(image=pre_image2,p=0.5).horizontal_flip(p=self.flipProbability)
+            if self.randomCrop:
+                pre_image1 = RandomCrop(pre_image1).crop_along_hieght_width_depth(self.dim1)
+                pre_image2 = RandomCrop(pre_image2).crop_along_hieght_width_depth(self.dim2)
+            else:
+                pre_image1 = CenterCrop(image=pre_image1).crop(size = self.dim1)
+                pre_image2 = CenterCrop(image=pre_image2).crop(size = self.dim2)
+
+            tempx = np.zeros([1,384,384,36,1])
+            tempx[0,:,:,:,0] = pre_image1
+            X1[i] = tempx
+            tempx = np.zeros([1,384,384,144,1])
+            tempx[0,16:368,16:368,:,0] = pre_image2
+            X2[i] = tempx
+            #X1[i,:,:,:,0] = pre_image1
+            #X2[i,:,:,:,0] = pre_image2
+            #X2[i] = self.dataset[self.dataset.FileName == ID].iloc[:,-6:]
+            # Store class
+            #print(self.dataset[self.dataset.FileName == ID].Label)
+            y[i] = self.IWdataset.iloc[indexes[i]].Label
+
+        return [X1,X2], y
+
+    def getXvalue(self,index):
+        return self.__getitem__(index)
+
+def padding_image(data):
+    l,w,h = data.shape
+    images = np.zeros((l,w,36))
+    zstart = int(np.ceil((36-data.shape[2])/2))
+    images[:,:,zstart:zstart + h] = data
+    return images 
+
+def padding_image2(data):
+    l,w,h = data.shape
+    images = np.zeros((l,w,144))
+    zstart = int(np.ceil((144-data.shape[2])/2))
+    images[:,:,zstart:zstart + h] = data
+    return images 
+def normalize_MRIs(image):
+    mean = np.mean(image)
+    std = np.std(image)
+    image -= mean
+    #image -= 95.09
+    image /= std
+    #image /= 86.38
+    return image
+
+tf.app.flags.DEFINE_string('model_path','/gpfs/data/denizlab/Users/hrr288/Radiology_test/TCBmodelv1_400_add_final_arch/Dnetv1/add_ch32/', 'Folder with the models')
+tf.app.flags.DEFINE_string('val_csv_path', '/gpfs/data/denizlab/Users/hrr288/Tianyu_dat/TestSets/', 'Folder with the fold splits')
+
+tf.app.flags.DEFINE_string('test_csv_path1', '/gpfs/data/denizlab/Users/hrr288/data/OAI_SAG_TSE_test.csv', 'Folder with IW TSE test csv')
+tf.app.flags.DEFINE_string('test_csv_path2', '/gpfs/data/denizlab/Users/hrr288/data/OAI_SAG_DESS_test.csv', 'Folder with DESS test csv')
+
+tf.app.flags.DEFINE_string('result_path', './', 'Folder to save output csv with preds')
+tf.app.flags.DEFINE_bool('vote', False, 'Choice to generate binary predictions for each model to compute final sensitivity/specificity')
+tf.app.flags.DEFINE_string('file_folder1','/gpfs/data/denizlab/Datasets/OAI/SAG_IW_TSE/', 'Path to IW TSE HDF5 radiographs of test set')
+tf.app.flags.DEFINE_string('file_folder2','/gpfs/data/denizlab/Datasets/OAI/SAG_3D_DESS/', 'Path to DESS  HDF5 radiographs of test set')
+tf.app.flags.DEFINE_string('IWdataset_csv','/gpfs/data/denizlab/Datasets/OAI/SAG_IW_TSE/HDF5_00_cohort_2_prime.csv', 'Path to HDF5_00_cohort_2_prime.csv')
+
+tf.app.flags.DEFINE_string('DESSdataset_csv','/gpfs/data/denizlab/Datasets/OAI/SAG_3D_DESS/HDF5_00_SAG_3D_DESScohort_2_prime.csv', 'Path to HDF5_00_SAG_3D_DESScohort_2_prime.csv')
+
+
+
+FLAGS = tf.app.flags.FLAGS
+def main(argv=None):
+
+
+
+
+    val_params = {'dim1': (384,384,36),
+                  'dim2': (352,352,144),
+                  'batch_size': 1,
+                  'n_classes': 1,
+                  'n_channels': 1,
+                  'shuffle': False,
+                  'normalize' : True,
+                  'randomCrop' : False,
+                  'randomFlip' : False,
+                  'flipProbability' : -1,
+                     }
+
+
+
+    validation_generator = DataGenerator(directory1 = FLAGS.test_csv_path1,directory2 = FLAGS.test_csv_path2,file_folder1=FLAGS.file_folder1,file_folder2=FLAGS.file_folder2,  **val_params)
+    df = pd.read_csv(FLAGS.test_csv_path2,index_col=0)
+
+    base_path = FLAGS.model_path
+
+    models= {'fold_1':[],'fold_2':[],'fold_3':[],'fold_4':[],'fold_5':[],'fold_6':[],'fold_7':[]}
+    for fold in np.arange(1,8):
+        tmp_mod_list = []
+        for cv in np.arange(1,7):
+            dir_1 = 'Fold_'+str(fold)+'/CV_'+str(cv)+'/'
+            files_avai =  os.listdir(base_path+dir_1)
+            cands = []
+            cands_score = []
+            for fs in files_avai:
+                if 'weights' not in fs:
+                    continue
+                else:
+                    
+                    cands_score.append(float(fs.split('-')[2]))
+                    cands.append(dir_1+fs)
+            ind_c = int(np.argmin(cands_score))
+            
+            tmp_mod_list.append(cands[ind_c])
+        models['fold_'+str(fold)]=tmp_mod_list
+    AUCS = []
+    preds = []
+    dfs = []
+    pred_arr = np.zeros(df.shape[0])
+    for i in np.arange(1,8):
+        
+        for j in np.arange(1,7):
+            model = load_model(base_path+'/'+models['fold_'+str(i)][j-1])
+            if FLAGS.vote:
+                test_df = pd.read_csv(FLAGS.val_csv_path+'Fold_'+str(i)+'/CV_'+str(j)+'_val.csv')
+                test_generator = DG(directory = FLAGS.val_csv_path+'Fold_'+str(i)+'/CV_'+str(j)+'_val.csv',file_folder1=FLAGS.file_folder1,file_folder2=FLAGS.file_folder2,IWdataset_csv=FLAGS.IWdataset_csv,DESSdataset_csv=FLAGS.DESSdataset_csv,  **val_params)
+
+                test_pred = model.predict_generator(test_generator)
+                test_df["Pred"] = test_pred
+                fpr, tpr, thresholds = metrics.roc_curve(test_df["Label"], test_df["Pred"])
+                opt_ind = np.argmax(tpr-fpr)
+                opt_thresh = thresholds[int(opt_ind)]
+                s = model.predict_generator(validation_generator)
+                
+                pred_arr += (np.squeeze(s)>=opt_thresh)
+            else:
+                s = model.predict_generator(validation_generator)
+                
+                pred_arr += np.squeeze(s)
+
+
+        
+        #AUCS.append(roc_auc_score(df['Label'],pred_arr))
+        
+        #preds.extend(list(pred_arr))
+        
+            
+    pred_arr = pred_arr/42
+
+
+
+
+    # In[ ]:
+
+
+    df["Preds"] = pred_arr
+    if  FLAGS.vote:
+        df.to_csv(FLAGS.result_path+"OAI_results_vote.csv")
+    else:
+        df.to_csv(FLAGS.result_path+"OAI_results.csv")
+if __name__ == "__main__":
+    tf.app.run()
+