Diff of /T1-TSE/evaluate.py [000000] .. [6a4082]

Switch to unified view

a b/T1-TSE/evaluate.py
1
# ==============================================================================
2
# Copyright (C) 2023 Haresh Rengaraj Rajamohan, Tianyu Wang, Kevin Leung, 
3
# Gregory Chang, Kyunghyun Cho, Richard Kijowski & Cem M. Deniz 
4
#
5
# This file is part of OAI-MRI-TKR
6
#
7
# This program is free software: you can redistribute it and/or modify
8
# it under the terms of the GNU Affero General Public License as published
9
# by the Free Software Foundation, either version 3 of the License, or
10
# (at your option) any later version.
11
12
# This program is distributed in the hope that it will be useful,
13
# but WITHOUT ANY WARRANTY; without even the implied warranty of
14
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15
# GNU Affero General Public License for more details.
16
17
# You should have received a copy of the GNU Affero General Public License
18
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
19
# ==============================================================================
20
import numpy as np
21
import pandas as pd
22
import h5py
23
import nibabel as nib
24
import keras
25
from mpl_toolkits.mplot3d import Axes3D
26
import numpy as np
27
import matplotlib.pyplot as plt
28
import matplotlib.cm
29
import matplotlib.colorbar
30
import matplotlib.colors
31
import pandas as pd
32
import numpy as np
33
from sklearn import metrics
34
35
36
import os
37
import tensorflow as tf
38
39
40
from keras.models import load_model
41
42
43
44
45
from sklearn.metrics import roc_auc_score,auc,roc_curve,average_precision_score
46
47
48
49
50
from Augmentation import RandomCrop, CenterCrop, RandomFlip
51
52
from DataGenerator import DataGenerator
53
tf.app.flags.DEFINE_string('model_path', '/gpfs/data/denizlab/Users/hrr288/Radiology_test/COR_IW_TSE/', 'Folder with the model')
54
tf.app.flags.DEFINE_string('csv_path', '/gpfs/data/denizlab/Users/hrr288/Tianyu_dat/COR_TSE/', 'Folder with the fold splits')
55
tf.app.flags.DEFINE_string('result_path', './', 'Folder to save output csv with preds')
56
tf.app.flags.DEFINE_string('file_folder','/gpfs/data/denizlab/Datasets/OAI/COR_IW_TSE/', 'Path to HDF5 radiographs of test set')
57
58
59
FLAGS = tf.app.flags.FLAGS
60
61
def main(argv=None):
62
63
64
    base_path = FLAGS.model_path
65
    csv_path = FLAGS.csv_path
66
67
    # Choosing the model in each folder with lowest val loss
68
69
    models= {'fold_1':[],'fold_2':[],'fold_3':[],'fold_4':[],'fold_5':[],'fold_6':[],'fold_7':[]}
70
    for fold in np.arange(1,8):
71
        tmp_mod_list = []
72
        for cv in np.arange(1,7):
73
            dir_1 = 'Fold_'+str(fold)+'/CV_'+str(cv)+'/'
74
            files_avai =  os.listdir(base_path+dir_1)
75
            cands = []
76
            cands_score = []
77
            for fs in files_avai:
78
                if 'weights' not in fs:
79
                    continue
80
                else:
81
                    
82
                    cands_score.append(float(fs.split('-')[2]))
83
                    cands.append(dir_1+fs)
84
            ind_c = int(np.argmin(cands_score))
85
            
86
            tmp_mod_list.append(cands[ind_c])
87
        models['fold_'+str(fold)]=tmp_mod_list
88
89
    val_params = {'dim': (352,352,35),
90
              'batch_size': 1,
91
              'n_classes': 2,
92
              'n_channels': 1,
93
              'shuffle': False,
94
              'normalize' : True,
95
              'randomCrop' : False,
96
              'randomFlip' : False,
97
              'flipProbability' : -1,
98
              'cropDim' : (352,352,35)}
99
100
    
101
102
103
    dfs = []
104
    
105
    for i in np.arange(1,8):
106
        print("Fold_"+str(i))
107
        validation_generator = DataGenerator(directory = csv_path+'Fold_'+str(i)+'/Fold_'+str(i)+'_test.csv',file_folder=FLAGS.file_folder,  **val_params)
108
        df = pd.read_csv(csv_path+'Fold_'+str(i)+'/Fold_'+str(i)+'_test.csv')
109
        pred_arr = np.zeros(df.shape[0])
110
        
111
        for j in np.arange(1,7):
112
            model = load_model(base_path+'/'+models['fold_'+str(i)][j-1])
113
            
114
            s = model.predict_generator(validation_generator)
115
            
116
            pred_arr += np.squeeze(s)
117
        pred_arr = pred_arr/6
118
        df["Preds"] = pred_arr
119
        dfs.append(df)
120
121
        
122
        
123
        
124
            
125
        
126
127
    full_df = pd.concat(dfs)
128
    full_df.to_csv(FLAGS.result_path+"OAI_T1TSE_results.csv")
129
if __name__ == "__main__":
130
    tf.app.run()
131
132