Diff of /T1-TSE/train.py [000000] .. [6a4082]

Switch to unified view

a b/T1-TSE/train.py
1
# ==============================================================================
2
# Copyright (C) 2023 Haresh Rengaraj Rajamohan, Tianyu Wang, Kevin Leung, 
3
# Gregory Chang, Kyunghyun Cho, Richard Kijowski & Cem M. Deniz 
4
#
5
# This file is part of OAI-MRI-TKR
6
#
7
# This program is free software: you can redistribute it and/or modify
8
# it under the terms of the GNU Affero General Public License as published
9
# by the Free Software Foundation, either version 3 of the License, or
10
# (at your option) any later version.
11
12
# This program is distributed in the hope that it will be useful,
13
# but WITHOUT ANY WARRANTY; without even the implied warranty of
14
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15
# GNU Affero General Public License for more details.
16
17
# You should have received a copy of the GNU Affero General Public License
18
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
19
# ==============================================================================
20
#!/usr/bin/env python3
21
import h5py
22
import os.path
23
import numpy as np
24
import pandas as pd
25
import math
26
import matplotlib
27
matplotlib.use('Agg')
28
29
import matplotlib.pyplot as plt
30
import tensorflow as tf
31
#from sklearn.model_selection import StratifiedKFold
32
from ModelResnet3D import generate_model
33
from DataGenerator import DataGenerator
34
35
#from keras.models import Sequential
36
#from keras.optimizers import SGD, Adam
37
#from keras.layers import Dropout, Dense, Conv3D, MaxPooling3D, GlobalAveragePooling3D, Activation, BatchNormalization,Flatten
38
from keras.callbacks import LearningRateScheduler, TensorBoard, EarlyStopping, ModelCheckpoint, Callback
39
from sklearn.metrics import roc_auc_score
40
41
42
# Command-line flags for this script; parsed values are read through FLAGS.
# NOTE(review): within this file 'batch_norm' is never read, and 'lr' /
# 'filters_in_last' are forwarded to cross_validation(), which does not use
# them -- confirm whether they are meant to reach generate_model().
tf.app.flags.DEFINE_boolean('batch_norm', True, 'Use BN or not')
tf.app.flags.DEFINE_float('lr', 0.0001, 'Initial learning rate.')
tf.app.flags.DEFINE_integer('filters_in_last', 128, 'Number of filters on the last layer')
tf.app.flags.DEFINE_string('file_path', '/gpfs/data/denizlab/Users/hrr288/Radiology_test/', 'Main Folder to Save outputs')
# Help text typo fixed ("Fold fo" -> "Fold for").
tf.app.flags.DEFINE_integer('val_fold', 1, 'Fold for cross-validation')
tf.app.flags.DEFINE_string('file_folder','/gpfs/data/denizlab/Datasets/OAI/COR_IW_TSE/', 'Path to HDF5 radiographs of test set')
tf.app.flags.DEFINE_string('csv_path', '/gpfs/data/denizlab/Users/hrr288/Tianyu_dat/COR_TSE/', 'Folder with the fold splits')

FLAGS = tf.app.flags.FLAGS
51
52
53
class roc_callback(Callback):
    """Keras callback that tracks loss and ROC-AUC after every epoch.

    At the end of each epoch the model is run over un-augmented, un-shuffled
    generators for the training and validation CSVs of one CV split, and the
    resulting ROC-AUC values are appended to ``auc`` / ``val_auc``. The
    epoch's 'loss' and 'val_loss' log entries are appended to ``losses`` /
    ``val_losses``.

    Params: index = number of the CV split (used in 'CV_<index>_train.csv'
                    and 'CV_<index>_val.csv')
            val_fold = number of the fold (used in the 'Fold_<val_fold>/'
                       directory under FLAGS.csv_path)
    """

    def __init__(self, index, val_fold):
        # Let the Keras base class set up its own state before adding ours.
        super().__init__()

        # Evaluation settings: no shuffling and no random crop/flip, so the
        # prediction order matches the row order of the CSV labels below.
        eval_params = {'dim': (352,352,35),
                       'batch_size': 4,
                       'n_classes': 2,
                       'n_channels': 1,
                       'shuffle': False,
                       'normalize' : True,
                       'randomCrop' : False,
                       'randomFlip' : False,
                       'flipProbability' : -1,
                       'cropDim' : (352,352,35)}

        split_prefix = FLAGS.csv_path+'Fold_'+str(val_fold)+'/CV_'+str(index)
        train_csv = split_prefix+'_train.csv'
        val_csv = split_prefix+'_val.csv'

        self.x = DataGenerator(directory = train_csv, file_folder=FLAGS.file_folder, **eval_params)
        self.x_val = DataGenerator(directory = val_csv, file_folder=FLAGS.file_folder, **eval_params)
        self.y = pd.read_csv(train_csv).Label
        self.y_val = pd.read_csv(val_csv).Label

        # Per-epoch histories, one entry appended per on_epoch_end call.
        self.auc = []
        self.val_auc = []
        self.losses = []
        self.val_losses = []

    def on_train_begin(self, logs=None):
        return

    def on_train_end(self, logs=None):
        return

    def on_epoch_begin(self, epoch, logs=None):
        return

    def on_epoch_end(self, epoch, logs=None):
        """Record the epoch's losses and compute train/validation ROC-AUC."""
        logs = logs or {}
        self.losses.append(logs.get('loss'))
        self.val_losses.append(logs.get('val_loss'))

        # The label column is truncated to the number of predictions returned
        # (presumably the generator yields only whole batches -- TODO confirm).
        y_pred = self.model.predict_generator(self.x)
        y_true = self.y[:len(y_pred)]
        roc = roc_auc_score(y_true, y_pred)

        y_pred_val = self.model.predict_generator(self.x_val)
        y_true_val = self.y_val[:len(y_pred_val)]
        roc_val = roc_auc_score(y_true_val, y_pred_val)

        self.auc.append(roc)
        self.val_auc.append(roc_val)
        # '\r' plus trailing padding overwrites whatever is on the current
        # console line before the newline is emitted.
        print('\rroc-auc: %s - roc-auc_val: %s' % (str(round(roc,4)),str(round(roc_val,4))),end=100*' '+'\n')
        return

    def on_batch_begin(self, batch, logs=None):
        return

    def on_batch_end(self, batch, logs=None):
        return
103
104
105
def plot_loss_curves(history, loss_path):
    '''
    Def: Code to plot loss curves (training and validation loss per epoch)
    Params: history = keras output from training; history.history must hold
                      the 'loss' and 'val_loss' series
            loss_path = path to save curve
    '''
    f = plt.figure()
    try:
        plt.plot(history.history['loss'])
        plt.plot(history.history['val_loss'])
        plt.title('model loss')
        plt.ylabel('loss')
        plt.xlabel('epoch')
        plt.legend(['train', 'validation'], loc='upper left')
        f.savefig(loss_path)
    finally:
        # pyplot keeps every figure it creates until it is explicitly closed;
        # this function runs once per CV split, so release the figure here.
        plt.close(f)
121
122
123
def plot_accuracy_curves(history, acc_path):
    '''
    Def: Code to plot accuracy curves (training and validation accuracy per epoch)
    Params: history = keras output from training; history.history must hold
                      the 'acc' and 'val_acc' series
            acc_path = path to save curve
    '''
    f = plt.figure()
    try:
        plt.plot(history.history['acc'])
        plt.plot(history.history['val_acc'])
        plt.title('model accuracy')
        plt.ylabel('accuracy')
        plt.xlabel('epoch')
        plt.legend(['train', 'validation'], loc='upper left')
        f.savefig(acc_path)
    finally:
        # pyplot keeps every figure it creates until it is explicitly closed;
        # this function runs once per CV split, so release the figure here.
        plt.close(f)
139
140
def plot_auc_curves(auc_history, acc_path):
    '''
    Def: Code to plot ROC-AUC curves (training and validation AUC per epoch)
    Params: auc_history = object exposing per-epoch lists `auc` and `val_auc`
            acc_path = path to save curve
    '''
    f = plt.figure()
    try:
        plt.plot(auc_history.auc)
        plt.plot(auc_history.val_auc)
        plt.title('model AUC')
        plt.ylabel('auc')
        plt.xlabel('epoch')
        plt.legend(['train', 'validation'], loc='upper left')
        f.savefig(acc_path)
    finally:
        # pyplot keeps every figure it creates until it is explicitly closed;
        # this function runs once per CV split, so release the figure here.
        plt.close(f)
151
152
153
def train_model(model, train_data, val_data, path, index, val_fold, epochs=10):
    '''
    Def: Fit the model on one CV split, checkpoint the best weights and save
         loss / accuracy / AUC curves
    Params: model = compiled keras model to train
            train_data = generator passed to fit_generator for training
            val_data = generator passed to fit_generator as validation_data
            path = output prefix for weights, TensorBoard logs and curves;
                   file names are appended directly, so it should end with a
                   path separator
            index = number of the CV split (forwarded to roc_callback)
            val_fold = number of the fold (forwarded to roc_callback)
            epochs = number of training epochs (default 10, the value that
                     was previously hard-coded)
    '''
    # A checkpoint is written only when val_loss improves on its best value.
    weights_path = path + 'weights-{epoch:02d}-{val_loss:.2f}-{val_acc:.2f}-{loss:.2f}-{acc:.2f}.hdf5'
    model_checkpoint = ModelCheckpoint(weights_path, monitor = 'val_loss', save_best_only = True,
                                       verbose=1)

    # Save loss and accuracy curves using Tensorboard
    tensorboard_callback = TensorBoard(log_dir = path,
                                       histogram_freq = 0,
                                       write_graph = False,
                                       write_grads = False,
                                       write_images = False)

    # Computes train/validation ROC-AUC at the end of every epoch.
    auc_history = roc_callback(index,val_fold)

    callbacks_list = [model_checkpoint, tensorboard_callback, auc_history]
    history = model.fit_generator(generator = train_data, validation_data = val_data, epochs=epochs,
                                  callbacks = callbacks_list)

    val_auc = auc_history.val_auc
    print('*****************************')
    print('best auc:',np.max(val_auc))
    print('average auc:',np.mean(val_auc))
    print('*****************************')

    val_accuracy = history.history['val_acc']
    print('*****************************')
    print('best accuracy:', np.max(val_accuracy))
    print('average accuracy:', np.mean(val_accuracy))
    print('*****************************')

    loss_path = path + 'loss_curve.jpeg'
    acc_path = path + 'acc_curve.jpeg'
    auc_path = path + 'auc_curve.jpeg'
    plot_loss_curves(history, loss_path)
    plot_accuracy_curves(history, acc_path)
    plot_auc_curves(auc_history, auc_path)
198
   
199
    
200
def cross_validation(val_fold, lr, filters_in_last, file_path):
    '''
    Def: Code to train one network per cross-validation split of a fold.
         For each split CV_1..CV_6 a fresh model is generated and trained on
         'Fold_<val_fold>/CV_<i>_train.csv', validated on
         'Fold_<val_fold>/CV_<i>_val.csv' (both under FLAGS.csv_path), and its
         outputs are written to '<file_path>COR_IW_TSE/Fold_<val_fold>/CV_<i>/'.
    Params: val_fold = number of the fold whose CV splits are trained
            lr = learning rate (NOTE(review): currently unused; the model is
                 always generated with learning_rate = 2 * 10 ** (-4))
            filters_in_last = number of filters in last convolutional layer
                 (NOTE(review): currently unused; not passed to generate_model)
            file_path = path to save network weights, curves, and tensorboard
                 callbacks; sub-directories are appended directly, so it
                 should end with a path separator
    '''
    volume_dim = (352,352,35)

    # Training generator: shuffled, with random crop and flip enabled.
    train_params = {'dim': volume_dim,
                    'batch_size': 4,
                    'n_classes': 2,
                    'n_channels': 1,
                    'shuffle': True,
                    'normalize' : True,
                    'randomCrop' : True,
                    'randomFlip' : True,
                    'flipProbability' : -1,
                    'cropDim' : volume_dim}

    # Validation generator: same settings without shuffling or augmentation.
    val_params = dict(train_params, shuffle=False, randomCrop=False, randomFlip=False)

    model_path = file_path + 'COR_IW_TSE/'
    os.makedirs(model_path, exist_ok=True)

    num_of_folds = 6
    for cv_index in range(1, num_of_folds + 1):
        # A new model per split so no weights carry over between splits.
        model = generate_model(learning_rate = 2 * 10 **(-4))
        model.summary()
        print(train_params)
        print('Running Fold', cv_index, '/', num_of_folds)
        fold_path = model_path + 'Fold_' + str(val_fold) + '/CV_'+str(cv_index)+'/'
        print(fold_path)

        os.makedirs(fold_path, exist_ok=True)

        split_prefix = FLAGS.csv_path+'Fold_'+str(val_fold)+'/CV_'+str(cv_index)
        training_generator = DataGenerator(directory = split_prefix+'_train.csv',file_folder=FLAGS.file_folder,  **train_params)
        validation_generator = DataGenerator(directory = split_prefix+'_val.csv',file_folder=FLAGS.file_folder,  **val_params)

        train_model(model=model,
                    train_data = training_generator,
                    val_data = validation_generator,
                    path = fold_path, index = cv_index,val_fold=val_fold)
262
263
264
def main(argv=None):
    '''
    Def: Script entry point; runs cross validation for the fold given by the
         val_fold flag, using the lr, filters_in_last and file_path flags
    Params: argv = command-line arguments (not used here)
    '''
    fold = FLAGS.val_fold
    print('Begin training for fold ', fold)
    cross_validation(
        val_fold=fold,
        lr=FLAGS.lr,
        filters_in_last=FLAGS.filters_in_last,
        file_path=FLAGS.file_path,
    )
269
270
if __name__ == "__main__":
    # Hand control to the TensorFlow app runner, naming the entry point
    # explicitly instead of relying on its default lookup of __main__.main.
    tf.app.run(main=main)