--- a
+++ b/DESS/Eval_OAI_DESS.py
@@ -0,0 +1,255 @@
+# ==============================================================================
+# Copyright (C) 2023 Haresh Rengaraj Rajamohan, Tianyu Wang, Kevin Leung, 
+# Gregory Chang, Kyunghyun Cho, Richard Kijowski & Cem M. Deniz 
+#
+# This file is part of OAI-MRI-TKR
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published
+# by the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program.  If not, see <https://www.gnu.org/licenses/>.
+# ==============================================================================
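+"""Evaluate the ensemble of SAG 3D DESS models on the OAI test set.
+
+For each of the 7 folds x 6 CV splits, the checkpoint with the lowest
+validation loss is loaded; the 42 models' test-set predictions are averaged
+(or, with --vote, thresholded and majority-voted) and written to a CSV.
+"""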
+import os
+
+import h5py
+import numpy as np
+import pandas as pd
+
+import keras
+from keras.models import load_model
+import tensorflow as tf
+
+from sklearn import metrics
+
+from Augmentation import RandomCrop, CenterCrop, RandomFlip
+from DataGenerator import DataGenerator as DG
+
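+# Test-time data generator: reads the 'h5Name' and 'Label' columns of a CSV
+# and yields batches of baseline ("00m") DESS volumes with their labels.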
+class DataGenerator(keras.utils.Sequence):
+    'Generates data for Keras'
+    def __init__(self, directory, file_folder, batch_size=8, dim=(384,384,160), n_channels=1,
+                 n_classes=10, shuffle=True, normalize=True, randomCrop=True, randomFlip=True,
+                 flipProbability=-1, cropDim=(384,384,160)):
+        'Initialization'
+        self.dim = dim
+        self.batch_size = batch_size
+        self.dataset = pd.read_csv(directory)
+        self.list_IDs = self.dataset['h5Name']
+        self.n_channels = n_channels
+        self.n_classes = n_classes
+        self.shuffle = shuffle
+        self.file_folder = file_folder + "00m/"  # baseline (00-month) visit scans
+        self.normalize = normalize
+        self.randomCrop = randomCrop
+        self.randomFlip = randomFlip
+        self.flipProbability = flipProbability
+        self.cropDim = cropDim
+        self.on_epoch_end()
+    
+    def __len__(self):
+        'Denotes the number of batches per epoch'
+        return int(np.floor(len(self.list_IDs) / self.batch_size))
+    
+    def __getitem__(self, index):
+        'Generate one batch of data'
+        # Generate indexes of the batch
+        indexes = self.indexes[index*self.batch_size:(index+1)*self.batch_size]
+        # Find list of IDs
+        list_IDs_temp = [self.list_IDs[k] for k in indexes]
+        # Generate data
+        X, y = self.__data_generation(list_IDs_temp)
+        return X, y
+    
+    def on_epoch_end(self):
+        'Updates indexes after each epoch'
+        self.indexes = np.arange(len(self.list_IDs))
+        if self.shuffle:
+            np.random.shuffle(self.indexes)
+    
+    def __data_generation(self, list_IDs_temp):
+        'Generates data containing batch_size samples'  # X : (n_samples, *dim, n_channels)
+        # Initialization
+        X = np.empty((self.batch_size, *self.dim, self.n_channels))
+        y = np.empty((self.batch_size), dtype=int)
+        # Generate data
+        for i, ID in enumerate(list_IDs_temp):
+            # Load one volume; Dataset.value is deprecated in h5py, so read with [()]
+            with h5py.File(self.file_folder + ID, "r") as hf:
+                pre_image = hf['data'][()].astype('float64')
+            # Zero-pad volumes shallower than 144 slices up to a uniform depth
+            if pre_image.shape[2] < 144:
+                pre_image = padding_image2(data=pre_image)
+            # Normalize to zero mean, unit variance
+            if self.normalize:
+                pre_image = normalize_MRIs(pre_image)
+            # Augmentation
+            if self.randomFlip:
+                pre_image = RandomFlip(image=pre_image, p=0.5).horizontal_flip(p=self.flipProbability)
+            if self.randomCrop:
+                pre_image = RandomCrop(pre_image).crop_along_hieght_width_depth(self.cropDim)
+            else:
+                pre_image = CenterCrop(image=pre_image).crop(size=self.cropDim)
+            X[i, :, :, :, 0] = pre_image
+            # Store class
+            y[i] = self.dataset[self.dataset.h5Name == ID].Label
+
+        return X, y
+    
+    def getXvalue(self,index):
+        return self.__getitem__(index)
+    
+def padding_image2(data):
+    'Zero-pads a volume along the depth axis to 144 slices, centered'
+    l, w, h = data.shape
+    images = np.zeros((l, w, 144))
+    zstart = int(np.ceil((144 - h) / 2))
+    images[:, :, zstart:zstart + h] = data
+    return images
+
+def padding_image(data, shape):
+    'Zero-pads a volume to the target shape, centered along all three axes'
+    images = np.zeros(shape)
+    candi_shape = data.shape
+
+    xstart = int(np.ceil((shape[0] - candi_shape[0]) / 2))
+    ystart = int(np.ceil((shape[1] - candi_shape[1]) / 2))
+    zstart = int(np.ceil((shape[2] - candi_shape[2]) / 2))
+
+    images[xstart:xstart + candi_shape[0], ystart:ystart + candi_shape[1], zstart:zstart + candi_shape[2]] = data
+    return images
+
+def normalize_MRIs(image):
+    'Standardizes a volume to zero mean and unit variance'
+    mean = np.mean(image)  # alternative fixed dataset-wide mean: 95.09
+    std = np.std(image)    # alternative fixed dataset-wide std: 86.38
+    image -= mean
+    image /= std
+    return image
+
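+# TF1-style command-line flags; the defaults point at the authors' cluster paths.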
+tf.app.flags.DEFINE_string('model_path', '/gpfs/data/denizlab/Users/hrr288/Radiology_test/SAG3D_lr24_18_stride221_kernel777773/', 'Folder with the trained models')
+tf.app.flags.DEFINE_string('val_csv_path', '/gpfs/data/denizlab/Users/hrr288/Tianyu_dat/TestSets/', 'Folder with the fold splits')
+tf.app.flags.DEFINE_string('test_csv_path', '/gpfs/data/denizlab/Users/hrr288/data/OAI_SAG_DESS_test.csv', 'Path to the test CSV')
+tf.app.flags.DEFINE_string('result_path', './', 'Folder to save the output CSV with predictions')
+tf.app.flags.DEFINE_bool('vote', False, "If True, binarize each model's predictions at its validation-optimal threshold (majority vote) to support sensitivity/specificity analysis")
+tf.app.flags.DEFINE_string('file_folder', '/gpfs/data/denizlab/Datasets/OAI/SAG_3D_DESS/', 'Path to the DESS HDF5 MRI volumes of the test set')
+tf.app.flags.DEFINE_string('train_file_folder', '/gpfs/data/denizlab/Datasets/OAI/SAG_3D_DESS/', 'Path to the DESS HDF5 MRI volumes of the OAI train/val set')
+
+
+FLAGS = tf.app.flags.FLAGS
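+
+# Pipeline: build the test-set generator, select the best checkpoint per
+# fold/CV split, ensemble the 42 models' predictions, and write them to CSV.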
+def main(argv=None):
+    # Test-time settings: batch size 1, no shuffling or augmentation;
+    # every volume is center-cropped to 352x352x144.
+    val_params = {'dim': (352,352,144),
+                  'batch_size': 1,
+                  'n_classes': 2,
+                  'n_channels': 1,
+                  'shuffle': False,
+                  'normalize': True,
+                  'randomCrop': False,
+                  'randomFlip': False,
+                  'flipProbability': -1,
+                  'cropDim': (352,352,144)}
+
+    validation_generator = DataGenerator(directory=FLAGS.test_csv_path, file_folder=FLAGS.file_folder, **val_params)
+    df = pd.read_csv(FLAGS.test_csv_path, index_col=0)
+
+    base_path = FLAGS.model_path
+
+    # For each of the 7 folds x 6 CV splits, keep the checkpoint with the
+    # lowest validation loss (parsed from the 'weights-*' filename).
+    models = {'fold_1': [], 'fold_2': [], 'fold_3': [], 'fold_4': [], 'fold_5': [], 'fold_6': [], 'fold_7': []}
+    for fold in np.arange(1, 8):
+        tmp_mod_list = []
+        for cv in np.arange(1, 7):
+            dir_1 = 'Fold_' + str(fold) + '/CV_' + str(cv) + '/'
+            files_avai = os.listdir(base_path + dir_1)
+            cands = []
+            cands_score = []
+            for fs in files_avai:
+                if 'weights' not in fs:
+                    continue
+                cands_score.append(float(fs.split('-')[2]))
+                cands.append(dir_1 + fs)
+            ind_c = int(np.argmin(cands_score))
+            tmp_mod_list.append(cands[ind_c])
+        models['fold_' + str(fold)] = tmp_mod_list
+    # Accumulate test-set predictions from all 7 x 6 = 42 models.
+    pred_arr = np.zeros(df.shape[0])
+    for i in np.arange(1, 8):
+        for j in np.arange(1, 7):
+            model = load_model(base_path + '/' + models['fold_' + str(i)][j - 1])
+            if FLAGS.vote:
+                # Pick the threshold maximizing Youden's J (tpr - fpr) on this
+                # model's own validation split, then add its binarized votes.
+                val_csv = FLAGS.val_csv_path + 'Fold_' + str(i) + '/CV_' + str(j) + '_val.csv'
+                test_df = pd.read_csv(val_csv)
+                test_generator = DG(directory=val_csv, file_folder=FLAGS.train_file_folder, **val_params)
+
+                test_pred = model.predict_generator(test_generator)
+                test_df["Pred"] = np.squeeze(test_pred)
+                fpr, tpr, thresholds = metrics.roc_curve(test_df["Label"], test_df["Pred"])
+                opt_ind = np.argmax(tpr - fpr)
+                opt_thresh = thresholds[int(opt_ind)]
+                s = model.predict_generator(validation_generator)
+                pred_arr += (np.squeeze(s) >= opt_thresh)
+            else:
+                # Soft ensemble: sum the raw probabilities.
+                s = model.predict_generator(validation_generator)
+                pred_arr += np.squeeze(s)
+
+    # Average over the 42 ensemble members (7 folds x 6 CV models each).
+    pred_arr = pred_arr / 42
+
+    df["Preds"] = pred_arr
+    if  FLAGS.vote:
+        df.to_csv(FLAGS.result_path+"OAI_results_vote.csv")
+    else:
+        df.to_csv(FLAGS.result_path+"OAI_results.csv")
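+
+
+# Example invocation (paths are illustrative, not prescribed by the authors):
+#   python Eval_OAI_DESS.py --model_path /path/to/checkpoints/ \
+#       --test_csv_path /path/to/OAI_SAG_DESS_test.csv --result_path ./results/ --vote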
+if __name__ == "__main__":
+    tf.app.run()
+