Diff of /IW-TSE/evaluate.py [000000] .. [6a4082]

Switch to unified view

a b/IW-TSE/evaluate.py
1
# ==============================================================================
2
# Copyright (C) 2023 Haresh Rengaraj Rajamohan, Tianyu Wang, Kevin Leung, 
3
# Gregory Chang, Kyunghyun Cho, Richard Kijowski & Cem M. Deniz 
4
#
5
# This file is part of OAI-MRI-TKR
6
#
7
# This program is free software: you can redistribute it and/or modify
8
# it under the terms of the GNU Affero General Public License as published
9
# by the Free Software Foundation, either version 3 of the License, or
10
# (at your option) any later version.
11
12
# This program is distributed in the hope that it will be useful,
13
# but WITHOUT ANY WARRANTY; without even the implied warranty of
14
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15
# GNU Affero General Public License for more details.
16
17
# You should have received a copy of the GNU Affero General Public License
18
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
19
# ==============================================================================
20
import numpy as np
21
import pandas as pd
22
import h5py
23
import nibabel as nib
24
import keras
25
from mpl_toolkits.mplot3d import Axes3D
26
import numpy as np
27
import matplotlib.pyplot as plt
28
import matplotlib.cm
29
import matplotlib.colorbar
30
import matplotlib.colors
31
import pandas as pd
32
import numpy as np
33
from sklearn import metrics
34
35
36
import os
37
import tensorflow as tf
38
39
40
from keras.models import load_model
41
42
43
44
45
from sklearn.metrics import roc_auc_score,auc,roc_curve,average_precision_score
46
47
48
49
50
from Augmentation import RandomCrop, CenterCrop, RandomFlip
51
52
from DataGenerator import DataGenerator
53
tf.app.flags.DEFINE_string('model_path', '/gpfs/data/denizlab/Users/hrr288/Radiology_test/Tnetres_Best/lr24ch32kerne773773_strde222_new_arch/', 'Folder with the models')
54
tf.app.flags.DEFINE_string('csv_path', '/gpfs/data/denizlab/Users/hrr288/TSE_dataset/', 'Folder with the fold splits')
55
tf.app.flags.DEFINE_string('result_path', './', 'Folder to save output csv with preds')
56
tf.app.flags.DEFINE_string('file_folder','/gpfs/data/denizlab/Datasets/OAI/SAG_IW_TSE/', 'Path to IW TSE HDF5 radiographs')
57
58
59
60
FLAGS = tf.app.flags.FLAGS
61
62
def main(argv=None):
63
64
65
    base_path = FLAGS.model_path
66
    csv_path = FLAGS.csv_path
67
68
    # Choosing the model in each folder with lowest val loss
69
70
    models= {'fold_1':[],'fold_2':[],'fold_3':[],'fold_4':[],'fold_5':[],'fold_6':[],'fold_7':[]}
71
    for fold in np.arange(1,8):
72
        tmp_mod_list = []
73
        for cv in np.arange(1,7):
74
            dir_1 = 'Fold_'+str(fold)+'/CV_'+str(cv)+'/'
75
            files_avai =  os.listdir(base_path+dir_1)
76
            cands = []
77
            cands_score = []
78
            for fs in files_avai:
79
                if 'weights' not in fs:
80
                    continue
81
                else:
82
                    
83
                    cands_score.append(float(fs.split('-')[2]))
84
                    cands.append(dir_1+fs)
85
            ind_c = int(np.argmin(cands_score))
86
            
87
            tmp_mod_list.append(cands[ind_c])
88
        models['fold_'+str(fold)]=tmp_mod_list
89
90
    val_params = {'dim': (384,384,36),
91
              'batch_size': 1,
92
              'n_classes': 2,
93
              'n_channels': 1,
94
              'shuffle': False,
95
              'normalize' : True,
96
              'randomCrop' : False,
97
              'randomFlip' : False,
98
              'flipProbability' : -1}
99
    
100
101
102
    dfs = []
103
    
104
    for i in np.arange(1,8):
105
        print("Fold_"+str(i))
106
        validation_generator = DataGenerator(directory = csv_path+'Fold_'+str(i)+'/Fold_'+str(i)+'_test.csv', file_folder=FLAGS.file_folder, **val_params)
107
        df = pd.read_csv(csv_path+'Fold_'+str(i)+'/Fold_'+str(i)+'_test.csv')
108
        pred_arr = np.zeros(df.shape[0])
109
        
110
        for j in np.arange(1,7):
111
            model = load_model(base_path+'/'+models['fold_'+str(i)][j-1])
112
            
113
            s = model.predict_generator(validation_generator)
114
            
115
            pred_arr += np.squeeze(s)
116
        pred_arr = pred_arr/6
117
        df["Preds"] = pred_arr
118
        dfs.append(df)
119
120
        
121
        
122
        
123
            
124
        
125
126
    full_df = pd.concat(dfs)
127
    full_df.to_csv(FLAGS.result_path+"OAI_DESS_results.csv")
128
if __name__ == "__main__":
129
    tf.app.run()
130
131