# Diff of /IW-TSE/train.py [000000] .. [6a4082]

# Switch to unified view

# a b/IW-TSE/train.py
1
# ==============================================================================
2
# Copyright (C) 2023 Haresh Rengaraj Rajamohan, Tianyu Wang, Kevin Leung, 
3
# Gregory Chang, Kyunghyun Cho, Richard Kijowski & Cem M. Deniz 
4
#
5
# This file is part of OAI-MRI-TKR
6
#
7
# This program is free software: you can redistribute it and/or modify
8
# it under the terms of the GNU Affero General Public License as published
9
# by the Free Software Foundation, either version 3 of the License, or
10
# (at your option) any later version.
11
12
# This program is distributed in the hope that it will be useful,
13
# but WITHOUT ANY WARRANTY; without even the implied warranty of
14
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15
# GNU Affero General Public License for more details.
16
17
# You should have received a copy of the GNU Affero General Public License
18
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
19
# ==============================================================================
20
#!/usr/bin/env python3
21
import h5py
22
import os.path
23
import numpy as np
24
import pandas as pd
25
import math
26
import matplotlib
27
matplotlib.use('Agg')
28
29
import matplotlib.pyplot as plt
30
import tensorflow as tf
31
#from sklearn.model_selection import StratifiedKFold
32
from ModelResnet3D import generate_model
33
from DataGenerator import DataGenerator
34
35
#from keras.models import Sequential
36
#from keras.optimizers import SGD, Adam
37
#from keras.layers import Dropout, Dense, Conv3D, MaxPooling3D, GlobalAveragePooling3D, Activation, BatchNormalization,Flatten
38
from keras.callbacks import LearningRateScheduler, TensorBoard, EarlyStopping, ModelCheckpoint, Callback
39
from sklearn.metrics import roc_auc_score
40
41
# Command-line flags (TF1 ``tf.app.flags``). The values are read through FLAGS
# by main(), cross_validation() and roc_callback below.
# NOTE(review): 'batch_norm' and 'input_file' are not referenced anywhere else
# in this file as shown -- confirm whether they are still needed.
tf.app.flags.DEFINE_boolean('batch_norm', True, 'Use BN or not')
tf.app.flags.DEFINE_float('lr', 0.0001, 'Initial learning rate.')
tf.app.flags.DEFINE_integer('filters_in_last', 128, 'Number of filters on the last layer')
tf.app.flags.DEFINE_float('dr',0.0, 'Dropout rate when training.')
tf.app.flags.DEFINE_string('input_file','/gpfs/data/denizlab/Datasets/OAI/SAG_IW_TSE/dataBefore201807/data_SAG_IW_TSE.hdf5', 'File to read data')
tf.app.flags.DEFINE_string('file_path', '/gpfs/data/denizlab/Users/hrr288/Radiology_test/', 'Main Folder to Save outputs')
tf.app.flags.DEFINE_integer('val_fold', 1, 'Fold for cv')
tf.app.flags.DEFINE_string('file_folder','/gpfs/data/denizlab/Datasets/OAI/SAG_IW_TSE/', 'Path to HDF5 radiographs')
tf.app.flags.DEFINE_string('csv_path', '/gpfs/data/denizlab/Users/hrr288/TSE_dataset/', 'Folder with the fold splits')

# Shared accessor for the parsed flag values.
FLAGS = tf.app.flags.FLAGS
52
53
class auc_Histories(Callback):
    """Keras callback that records the training loss and validation ROC-AUC per epoch.

    After training, ``self.losses`` holds the ``'loss'`` entry of each epoch's
    logs and ``self.aucs`` holds the ROC-AUC computed on the validation data.

    NOTE(review): the validation data is read from ``self.model.validation_data``
    and indexed as ``[0]`` = inputs, ``[1]`` = labels. Whether that attribute is
    populated depends on the Keras version and on how ``fit`` is called --
    confirm before using this callback.
    """

    def on_train_begin(self, logs=None):
        # Reset the histories so the callback can be reused across fits.
        self.aucs = []
        self.losses = []

    def on_train_end(self, logs=None):
        return

    def on_epoch_begin(self, epoch, logs=None):
        return

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        self.losses.append(logs.get('loss'))
        validation_data = self.model.validation_data
        # Predict from the inputs only; the labels at index 1 are the AUC
        # ground truth and must not be fed to the model.
        y_pred = self.model.predict(validation_data[0])
        self.aucs.append(roc_auc_score(validation_data[1], y_pred))
        return

    def on_batch_begin(self, batch, logs=None):
        return

    def on_batch_end(self, batch, logs=None):
        return
75
76
77
78
class roc_callback(Callback):
    """Keras callback that tracks ROC-AUC on the train and validation splits.

    At the end of every epoch the model is run over two non-shuffled,
    non-augmented ``DataGenerator`` instances (train and validation CSVs of the
    given CV split) and the resulting ROC-AUC values are appended to
    ``self.auc`` / ``self.val_auc``. The epoch's ``'loss'`` and ``'val_loss'``
    log entries are appended to ``self.losses`` / ``self.val_losses``.

    Params: index = 1-based CV split index (selects ``CV_<index>_*.csv``)
            val_fold = outer fold number (selects the ``Fold_<val_fold>`` folder)
    """

    def __init__(self, index, val_fold):
        # Let the Keras base class initialise its own state.
        super().__init__()

        # Evaluation-mode generator settings: fixed order, no augmentation.
        _params = {'dim': (384,384,36),
                   'batch_size': 4,
                   'n_classes': 2,
                   'n_channels': 1,
                   'shuffle': False,
                   'normalize' : True,
                   'randomCrop' : False,
                   'randomFlip' : False,
                   'flipProbability' : -1}

        split_dir = os.path.join(FLAGS.csv_path, 'Fold_' + str(val_fold))
        train_csv = os.path.join(split_dir, 'CV_' + str(index) + '_train.csv')
        val_csv = os.path.join(split_dir, 'CV_' + str(index) + '_val.csv')

        self.x = DataGenerator(directory=train_csv, file_folder=FLAGS.file_folder, **_params)
        self.x_val = DataGenerator(directory=val_csv, file_folder=FLAGS.file_folder, **_params)
        # Ground-truth labels come from the 'Label' column of the same CSVs.
        self.y = pd.read_csv(train_csv).Label
        self.y_val = pd.read_csv(val_csv).Label

        self.auc = []
        self.val_auc = []
        self.losses = []
        self.val_losses = []

    def on_train_begin(self, logs=None):
        return

    def on_train_end(self, logs=None):
        return

    def on_epoch_begin(self, epoch, logs=None):
        return

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        self.losses.append(logs.get('loss'))
        self.val_losses.append(logs.get('val_loss'))

        # The labels are truncated to the number of predictions returned.
        # NOTE(review): presumably the generator drops a final incomplete
        # batch, so fewer predictions than CSV rows may come back -- confirm
        # against DataGenerator.
        y_pred = self.model.predict_generator(self.x)
        y_true = self.y[:len(y_pred)]
        roc = roc_auc_score(y_true, y_pred)

        y_pred_val = self.model.predict_generator(self.x_val)
        y_true_val = self.y_val[:len(y_pred_val)]
        roc_val = roc_auc_score(y_true_val, y_pred_val)

        self.auc.append(roc)
        self.val_auc.append(roc_val)
        # '\r' plus the padded line ending overwrites the Keras progress bar.
        print('\rroc-auc: %s - roc-auc_val: %s' % (str(round(roc,4)),str(round(roc_val,4))),end=100*' '+'\n')
        return

    def on_batch_begin(self, batch, logs=None):
        return

    def on_batch_end(self, batch, logs=None):
        return
128
    
129
def plot_loss_curves(history, loss_path):
    """Plot the training and validation loss per epoch and save the figure.

    Params: history = keras output from training; ``history.history`` must
                      contain the keys 'loss' and 'val_loss'
            loss_path = path to save curve
    """
    f = plt.figure()
    try:
        plt.plot(history.history['loss'])
        plt.plot(history.history['val_loss'])
        plt.title('model loss')
        plt.ylabel('loss')
        plt.xlabel('epoch')
        plt.legend(['train', 'validation'], loc='upper left')
        f.savefig(loss_path)
    finally:
        # Release the figure; otherwise one figure per call stays alive in
        # pyplot's registry for the whole cross-validation run.
        plt.close(f)
145
146
147
def plot_accuracy_curves(history, acc_path):
    """Plot the training and validation accuracy per epoch and save the figure.

    Params: history = keras output from training; ``history.history`` must
                      contain the keys 'acc' and 'val_acc'
            acc_path = path to save curve
    """
    f = plt.figure()
    try:
        plt.plot(history.history['acc'])
        plt.plot(history.history['val_acc'])
        plt.title('model accuracy')
        plt.ylabel('acc')
        plt.xlabel('epoch')
        plt.legend(['train', 'validation'], loc='upper left')
        f.savefig(acc_path)
    finally:
        # Release the figure; otherwise one figure per call stays alive in
        # pyplot's registry for the whole cross-validation run.
        plt.close(f)
163
164
165
def plot_auc_curves(auc_history, acc_path):
    """Plot the training and validation ROC-AUC per epoch and save the figure.

    Params: auc_history = object exposing per-epoch lists ``auc`` and
                          ``val_auc`` (e.g. a ``roc_callback`` instance)
            acc_path = path to save curve
    """
    f = plt.figure()
    try:
        plt.plot(auc_history.auc)
        plt.plot(auc_history.val_auc)
        plt.title('model AUC')
        plt.ylabel('auc')
        plt.xlabel('epoch')
        plt.legend(['train', 'validation'], loc='upper left')
        f.savefig(acc_path)
    finally:
        # Release the figure; otherwise one figure per call stays alive in
        # pyplot's registry for the whole cross-validation run.
        plt.close(f)
176
177
def train_model(model, train_data, val_data, path, index, val_fold, epochs=10):
    """Train ``model`` on one CV split, checkpoint it, and save summary curves.

    Params: model = compiled keras model to train
            train_data = generator passed to ``fit_generator`` for training
            val_data = generator passed to ``fit_generator`` as validation data
            path = output folder for weights, TensorBoard logs and curves
            index = 1-based CV split index (forwarded to ``roc_callback``)
            val_fold = outer fold number (forwarded to ``roc_callback``)
            epochs = number of training epochs (default 10)

    Prints the best/average validation AUC and accuracy and writes
    ``loss_curve.jpeg``, ``acc_curve.jpeg`` and ``auc_curve.jpeg`` into ``path``.
    """
    # The checkpoint file name is a template filled in by ModelCheckpoint; only
    # epochs that improve 'val_loss' are saved.
    weights_path = os.path.join(
        path,
        'weights-{epoch:02d}-{val_loss:.2f}-{val_acc:.2f}-{loss:.2f}-{acc:.2f}.hdf5')
    model_checkpoint = ModelCheckpoint(weights_path, monitor='val_loss',
                                       save_best_only=True, verbose=0, period=1)

    # Save loss and accuracy curves using Tensorboard
    tensorboard_callback = TensorBoard(log_dir=path,
                                       histogram_freq=0,
                                       write_graph=False,
                                       write_grads=False,
                                       write_images=False)

    # Computes train/validation ROC-AUC after every epoch.
    auc_history = roc_callback(index, val_fold)
    callbacks_list = [model_checkpoint, tensorboard_callback, auc_history]

    history = model.fit_generator(generator=train_data,
                                  validation_data=val_data,
                                  epochs=epochs,
                                  callbacks=callbacks_list)

    val_auc = auc_history.val_auc
    print('*****************************')
    print('best auc:',np.max(val_auc))
    print('average auc:',np.mean(val_auc))
    print('*****************************')

    val_acc = history.history['val_acc']
    print('*****************************')
    print('best acc:',np.max(val_acc))
    print('average acc:',np.mean(val_acc))
    print('*****************************')

    plot_loss_curves(history, os.path.join(path, 'loss_curve.jpeg'))
    plot_accuracy_curves(history, os.path.join(path, 'acc_curve.jpeg'))
    plot_auc_curves(auc_history, os.path.join(path, 'auc_curve.jpeg'))
220
   
221
    
222
def cross_validation(val_fold, lr, dr, filters_in_last, file_path, num_of_folds=6):
    """Train one freshly generated model per CV split of the given outer fold.

    For every split ``CV_1 .. CV_<num_of_folds>`` under
    ``<FLAGS.csv_path>/Fold_<val_fold>/`` a new model is built, trained with
    ``train_model`` and its outputs are written to
    ``<file_path>/Tnetres_Best/lr24ch32kerne773773_strde222_new_arch/Fold_<val_fold>/CV_<i>/``.

    Params: val_fold = outer fold number selecting the ``Fold_<val_fold>`` CSV folder
            lr = learning rate (currently unused, see NOTE below)
            dr = dropout rate (currently unused, see NOTE below)
            filters_in_last = number of filters in last convolutional layer
                              (currently unused, see NOTE below)
            file_path = path to save network weights, curves, and tensorboard callbacks
            num_of_folds = number of CV splits to train (default 6)
    """
    # Training generator: shuffled, with random crop/flip augmentation enabled.
    train_params = {'dim': (384,384,36),
                    'batch_size': 4,
                    'n_classes': 2,
                    'n_channels': 1,
                    'shuffle': True,
                    'normalize' : True,
                    'randomCrop' : True,
                    'randomFlip' : True,
                    'flipProbability' : -1}

    # Validation generator: identical settings, but fixed order and no augmentation.
    val_params = dict(train_params, shuffle=False, randomCrop=False, randomFlip=False)

    model_path = os.path.join(
        file_path, 'Tnetres_Best', 'lr24ch32kerne773773_strde222_new_arch', '')
    os.makedirs(model_path, exist_ok=True)

    split_dir = os.path.join(FLAGS.csv_path, 'Fold_' + str(val_fold))

    for cv_index in range(1, num_of_folds + 1):
        # NOTE(review): the learning rate is hard-coded here; the lr, dr and
        # filters_in_last arguments are never forwarded to generate_model.
        # Confirm whether the flags are meant to take effect.
        model = generate_model(learning_rate = 2 * 10 **(-4))
        model.summary()
        print(train_params)
        print('Running Fold', cv_index, '/', num_of_folds)

        # The trailing '' keeps the trailing separator on the folder path.
        fold_path = os.path.join(
            model_path, 'Fold_' + str(val_fold), 'CV_' + str(cv_index), '')
        print(fold_path)
        os.makedirs(fold_path, exist_ok=True)

        training_generator = DataGenerator(
            directory=os.path.join(split_dir, 'CV_' + str(cv_index) + '_train.csv'),
            file_folder=FLAGS.file_folder, **train_params)
        validation_generator = DataGenerator(
            directory=os.path.join(split_dir, 'CV_' + str(cv_index) + '_val.csv'),
            file_folder=FLAGS.file_folder, **val_params)

        train_model(model=model,
                    train_data=training_generator,
                    val_data=validation_generator,
                    path=fold_path, index=cv_index, val_fold=val_fold)
279
280
281
282
def main(argv=None):
    """Script entry point: run the cross-validation for the fold chosen by the flags.

    Params: argv = positional command-line arguments (unused)
    """
    print('Begin training for fold ',FLAGS.val_fold)
    cross_validation(val_fold=FLAGS.val_fold,
                     lr=FLAGS.lr,
                     dr=FLAGS.dr,
                     filters_in_last=FLAGS.filters_in_last,
                     file_path=FLAGS.file_path)
287
288
if __name__ == "__main__":
    # tf.app.run() parses the command-line flags and then calls main().
    tf.app.run()
290
291