Diff of /MI-DESS_IWTSE/train.py [000000] .. [6a4082]

Switch to unified view

a b/MI-DESS_IWTSE/train.py
1
# ==============================================================================
2
# Copyright (C) 2023 Haresh Rengaraj Rajamohan, Tianyu Wang, Kevin Leung, 
3
# Gregory Chang, Kyunghyun Cho, Richard Kijowski & Cem M. Deniz 
4
#
5
# This file is part of OAI-MRI-TKR
6
#
7
# This program is free software: you can redistribute it and/or modify
8
# it under the terms of the GNU Affero General Public License as published
9
# by the Free Software Foundation, either version 3 of the License, or
10
# (at your option) any later version.
11
12
# This program is distributed in the hope that it will be useful,
13
# but WITHOUT ANY WARRANTY; without even the implied warranty of
14
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15
# GNU Affero General Public License for more details.
16
17
# You should have received a copy of the GNU Affero General Public License
18
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
19
# ==============================================================================
20
#!/usr/bin/env python3
21
import h5py
22
import os.path
23
import numpy as np
24
import pandas as pd
25
import math
26
import matplotlib
27
matplotlib.use('Agg')
28
29
import matplotlib.pyplot as plt
30
import tensorflow as tf
31
#from sklearn.model_selection import StratifiedKFold
32
#from ModelCB_onelayer import generate_model
33
from ModelDCB import generate_model
34
from DataGenerator import DataGenerator
35
36
#from keras.models import Sequential
37
#from keras.optimizers import SGD, Adam
38
#from keras.layers import Dropout, Dense, Conv3D, MaxPooling3D, GlobalAveragePooling3D, Activation, BatchNormalization,Flatten
39
from keras.callbacks import LearningRateScheduler, TensorBoard, EarlyStopping, ModelCheckpoint, Callback
40
from sklearn.metrics import roc_auc_score
41
42
43
# Command-line configuration, registered through the TensorFlow 1.x flags API.
_flags = tf.app.flags

# Model / optimisation options.
# NOTE(review): 'batch_norm' is not read anywhere in the visible part of this
# file -- confirm whether it is still needed.
_flags.DEFINE_boolean('batch_norm', True, 'Use BN or not')
_flags.DEFINE_float('lr', 0.0002, 'Initial learning rate.')
_flags.DEFINE_integer(
    'filters_in_last', 128, 'Number of filters on the last layer')

# Where outputs are written and which outer fold is trained.
_flags.DEFINE_string(
    'file_path',
    '/gpfs/data/denizlab/Users/hrr288/Radiology_test/',
    'Main Folder to Save outputs')
_flags.DEFINE_integer('val_fold', 1, 'Fold for cv')

# Input data locations, forwarded unchanged to DataGenerator.
# NOTE(review): both folder flags share the same help text, which mentions
# "radiographs of test set" although the flags are used for the train and
# validation generators below -- the wording looks stale; confirm and update.
_flags.DEFINE_string(
    'file_folder1',
    '/gpfs/data/denizlab/Datasets/OAI/SAG_IW_TSE/',
    'Path to HDF5 radiographs of test set')
_flags.DEFINE_string(
    'file_folder2',
    '/gpfs/data/denizlab/Datasets/OAI/SAG_3D_DESS/',
    'Path to HDF5 radiographs of test set')
_flags.DEFINE_string(
    'IWdataset_csv',
    '/gpfs/data/denizlab/Datasets/OAI/SAG_IW_TSE/HDF5_00_cohort_2_prime.csv',
    'Path to HDF5_00_cohort_2_prime.csv')
_flags.DEFINE_string(
    'DESSdataset_csv',
    '/gpfs/data/denizlab/Datasets/OAI/SAG_3D_DESS/HDF5_00_SAG_3D_DESScohort_2_prime.csv',
    'Path to HDF5_00_SAG_3D_DESScohort_2_prime.csv')
_flags.DEFINE_string(
    'csv_path',
    '/gpfs/data/denizlab/Users/hrr288/Tianyu_dat/TestSets/',
    'Folder with the fold splits')

# Parsed flag values; read as FLAGS.<name> throughout this file.
FLAGS = _flags.FLAGS
57
58
59
class roc_callback(Callback):
    """Keras callback that tracks ROC-AUC on the train and validation splits.

    At the end of every epoch the current model is run over both CSV splits
    of the selected cross-validation fold. The resulting AUCs are appended to
    ``self.auc`` / ``self.val_auc`` and the epoch's ``loss`` / ``val_loss``
    (taken from the Keras logs) to ``self.losses`` / ``self.val_losses``.

    Params: index = CV split number, selects ``CV_<index>_train.csv`` and
                    ``CV_<index>_val.csv``
            val_fold = outer fold number, selects the ``Fold_<val_fold>``
                       folder under ``FLAGS.csv_path``
    """

    def __init__(self, index, val_fold):
        # Let the Keras base class run its own initialisation first.
        super().__init__()

        # Evaluation-time generator settings: no shuffling, random cropping
        # or random flipping, so predictions line up with the CSV row order.
        _params = {'dim1': (384,384,36),
          'dim2': (352,352,144),
          'batch_size': 4,
          'n_classes': 1,
          'n_channels': 1,
          'shuffle': False,
          'normalize' : True,
          'randomCrop' : False,
          'randomFlip' : False,
          'flipProbability' : -1,
             }

        split_prefix = FLAGS.csv_path + 'Fold_' + str(val_fold) + '/CV_' + str(index)
        train_csv = split_prefix + '_train.csv'
        val_csv = split_prefix + '_val.csv'

        # Everything except the split CSV is identical for both generators.
        generator_kwargs = dict(file_folder1=FLAGS.file_folder1,
                                file_folder2=FLAGS.file_folder2,
                                IWdataset_csv=FLAGS.IWdataset_csv,
                                DESSdataset_csv=FLAGS.DESSdataset_csv,
                                **_params)

        self.x = DataGenerator(directory=train_csv, **generator_kwargs)
        self.x_val = DataGenerator(directory=val_csv, **generator_kwargs)
        self.y = pd.read_csv(train_csv).Label
        self.y_val = pd.read_csv(val_csv).Label
        self.auc = []
        self.val_auc = []
        self.losses = []
        self.val_losses = []

    def _auc(self, generator, labels):
        """Return the ROC-AUC of the current model on *generator*."""
        y_pred = self.model.predict_generator(generator)
        # Labels are truncated to the number of predictions; presumably the
        # generator drops an incomplete final batch -- TODO confirm. This
        # pairing relies on the generator not shuffling.
        y_true = labels[:len(y_pred)]
        return roc_auc_score(y_true, y_pred)

    def on_train_begin(self, logs=None):
        return

    def on_train_end(self, logs=None):
        return

    def on_epoch_begin(self, epoch, logs=None):
        return

    def on_epoch_end(self, epoch, logs=None):
        """Record the epoch losses and the train/validation AUC, then print them."""
        logs = logs or {}
        self.losses.append(logs.get('loss'))
        self.val_losses.append(logs.get('val_loss'))

        roc = self._auc(self.x, self.y)
        roc_val = self._auc(self.x_val, self.y_val)
        self.auc.append(roc)
        self.val_auc.append(roc_val)
        # Leading '\r' and trailing padding overwrite whatever is currently
        # on the console line (e.g. the Keras progress bar).
        print('\rroc-auc: %s - roc-auc_val: %s' % (str(round(roc,4)),str(round(roc_val,4))),end=100*' '+'\n')
        return

    def on_batch_begin(self, batch, logs=None):
        return

    def on_batch_end(self, batch, logs=None):
        return
111
112
113
def plot_loss_curves(history, loss_path):
    """
    Def: Plot the training and validation loss per epoch and save the figure
    Params: history = keras output from training; its ``history`` dict must
                      contain the 'loss' and 'val_loss' series
            loss_path = path to save curve
    """
    f = plt.figure()
    try:
        plt.plot(history.history['loss'])
        plt.plot(history.history['val_loss'])
        plt.title('model loss')
        plt.ylabel('loss')
        plt.xlabel('epoch')
        plt.legend(['train', 'validation'], loc='upper left')
        f.savefig(loss_path)
    finally:
        # pyplot keeps every figure it creates until it is closed explicitly;
        # without this each call would leave one more figure in memory.
        plt.close(f)
129
130
131
def plot_accuracy_curves(history, acc_path):
    """
    Def: Plot the training and validation accuracy per epoch and save the figure
    Params: history = keras output from training; its ``history`` dict must
                      contain the 'acc' and 'val_acc' series
            acc_path = path to save curve
    """
    f = plt.figure()
    try:
        plt.plot(history.history['acc'])
        plt.plot(history.history['val_acc'])
        plt.title('model accuracy')
        plt.ylabel('accuracy')
        plt.xlabel('epoch')
        plt.legend(['train', 'validation'], loc='upper left')
        f.savefig(acc_path)
    finally:
        # pyplot keeps every figure it creates until it is closed explicitly;
        # without this each call would leave one more figure in memory.
        plt.close(f)
147
148
def plot_auc_curves(auc_history, acc_path):
    """
    Def: Plot the training and validation AUC per epoch and save the figure
    Params: auc_history = object exposing per-epoch ``auc`` and ``val_auc``
                          sequences (the roc_callback instance used in training)
            acc_path = path to save curve (name kept for compatibility; it
                       receives the AUC figure path)
    """
    f = plt.figure()
    try:
        plt.plot(auc_history.auc)
        plt.plot(auc_history.val_auc)
        plt.title('model AUC')
        plt.ylabel('auc')
        plt.xlabel('epoch')
        plt.legend(['train', 'validation'], loc='upper left')
        f.savefig(acc_path)
    finally:
        # pyplot keeps every figure it creates until it is closed explicitly;
        # without this each call would leave one more figure in memory.
        plt.close(f)
159
160
161
def train_model(model, train_data, val_data, path, index, val_fold, epochs=10):
    """
    Def: Fit the model on one CV split, then report metrics and save curves
    Params: model = compiled keras model to train
            train_data = generator passed to fit_generator for training
            val_data = generator passed to fit_generator for validation
            path = output folder for weights, TensorBoard logs and curves
            index = CV split number, forwarded to roc_callback
            val_fold = outer fold number, forwarded to roc_callback
            epochs = number of training epochs (default 10)
    """
    # A checkpoint is written after every epoch (save_best_only is off); the
    # file name embeds the epoch number and that epoch's metrics.
    weights_path = os.path.join(
        path,
        'weights-{epoch:02d}-{val_loss:.2f}-{val_acc:.2f}-{loss:.2f}-{acc:.2f}.hdf5')
    model_checkpoint = ModelCheckpoint(weights_path, monitor = 'val_loss', save_best_only = False,
                                       verbose=0, period=1)

    # Save loss and accuracy curves using Tensorboard
    tensorboard_callback = TensorBoard(log_dir = path,
                                       histogram_freq = 0,
                                       write_graph = False,
                                       write_grads = False,
                                       write_images = False)

    # Computes train/validation ROC-AUC at the end of each epoch.
    auc_history = roc_callback(index,val_fold)
    callbacks_list = [model_checkpoint, tensorboard_callback, auc_history]
    history = model.fit_generator(generator = train_data, validation_data = val_data, epochs=epochs,
                        callbacks = callbacks_list)

    val_auc = auc_history.val_auc
    print('*****************************')
    print('best auc:',np.max(val_auc))
    print('average auc:',np.mean(val_auc))
    print('*****************************')

    val_accuracy = history.history['val_acc']
    print('*****************************')
    print('best accuracy:', np.max(val_accuracy))
    print('average accuracy:', np.mean(val_accuracy))
    print('*****************************')

    # os.path.join yields the same paths as before when `path` ends with a
    # separator, and still lands inside `path` when it does not.
    loss_path = os.path.join(path, 'loss_curve.jpeg')
    acc_path = os.path.join(path, 'acc_curve.jpeg')
    auc_path = os.path.join(path, 'auc_curve.jpeg')
    plot_loss_curves(history, loss_path)
    plot_accuracy_curves(history, acc_path)
    plot_auc_curves(auc_history, auc_path)
206
   
207
    
208
def cross_validation(val_fold, lr, filters_in_last, file_path):
    """
    Def: Train one freshly generated model per CV split of the given outer fold
    Params: val_fold = outer fold number; selects the ``Fold_<val_fold>``
                       folder under FLAGS.csv_path and in the output tree
            lr = learning rate handed to generate_model
            filters_in_last = accepted for interface compatibility; not used
                              by this function
            file_path = path to save network weights, curves, and tensorboard
                        callbacks (expected to end with a path separator)
    """
    model_path = file_path + 'TCBmodelv1_400_add_final_arch/Dnetv1/add_ch32/'
    os.makedirs(model_path, exist_ok=True)

    num_of_folds = 6

    # Generator settings do not change between splits, so build them once.
    train_params = {'dim1': (384,384,36),
          'dim2': (352,352,144),
          'batch_size': 4,
          'n_classes': 1,
          'n_channels': 1,
          'shuffle': True,
          'normalize' : True,
          'randomCrop' : True,
          'randomFlip' : True,
          'flipProbability' : -1,
             }
    # Validation uses the same settings with shuffling and the random
    # crop/flip switches turned off.
    val_params = dict(train_params, shuffle=False, randomCrop=False, randomFlip=False)

    # Arguments shared by every DataGenerator created below.
    data_sources = dict(file_folder1=FLAGS.file_folder1,
                        file_folder2=FLAGS.file_folder2,
                        IWdataset_csv=FLAGS.IWdataset_csv,
                        DESSdataset_csv=FLAGS.DESSdataset_csv)

    for split in range(1, num_of_folds + 1):
        # Honour the caller's learning rate. The FLAGS.lr default (0.0002)
        # equals the value that used to be hard-coded here, so default runs
        # behave as before.
        model = generate_model(learning_rate = lr)

        model.summary()
        print(train_params)
        print('Running Fold', split, '/', num_of_folds)
        fold_path = model_path + 'Fold_' + str(val_fold) + '/CV_'+str(split)+'/'
        print(fold_path)

        os.makedirs(fold_path, exist_ok=True)

        split_prefix = FLAGS.csv_path+'Fold_'+str(val_fold)+'/CV_'+str(split)
        training_generator = DataGenerator(directory = split_prefix+'_train.csv',
                                           **data_sources, **train_params)
        validation_generator = DataGenerator(directory = split_prefix+'_val.csv',
                                             **data_sources, **val_params)

        train_model(model=model,
                    train_data = training_generator,
                    val_data = validation_generator,
                    path = fold_path, index = split,val_fold=val_fold)
270
271
272
def main(argv=None):
    """
    Def: Entry point; runs cross validation with the values from FLAGS
    Params: argv = not used here; presumably supplied by tf.app.run
    """
    fold = FLAGS.val_fold
    print('Begin training for fold ', fold)
    cross_validation(
        val_fold=fold,
        lr=FLAGS.lr,
        filters_in_last=FLAGS.filters_in_last,
        file_path=FLAGS.file_path,
    )
277
278
# Only start training when executed as a script, not on import. Control is
# handed to TensorFlow's app runner, which presumably parses the flags defined
# above before calling main() -- confirm against the tf.app.run documentation.
if __name__ == '__main__':
    tf.app.run()