Diff of /IW-TSE/train.py [000000] .. [6a4082]

--- a
+++ b/IW-TSE/train.py
@@ -0,0 +1,291 @@
+# ==============================================================================
+# Copyright (C) 2023 Haresh Rengaraj Rajamohan, Tianyu Wang, Kevin Leung, 
+# Gregory Chang, Kyunghyun Cho, Richard Kijowski & Cem M. Deniz 
+#
+# This file is part of OAI-MRI-TKR
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published
+# by the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program.  If not, see <https://www.gnu.org/licenses/>.
+# ==============================================================================
+#!/usr/bin/env python3
+import h5py
+import os.path
+import numpy as np
+import pandas as pd
+import math
+import matplotlib
+matplotlib.use('Agg')
+
+import matplotlib.pyplot as plt
+import tensorflow as tf
+#from sklearn.model_selection import StratifiedKFold
+from ModelResnet3D import generate_model
+from DataGenerator import DataGenerator
+
+#from keras.models import Sequential
+#from keras.optimizers import SGD, Adam
+#from keras.layers import Dropout, Dense, Conv3D, MaxPooling3D, GlobalAveragePooling3D, Activation, BatchNormalization,Flatten
+from keras.callbacks import LearningRateScheduler, TensorBoard, EarlyStopping, ModelCheckpoint, Callback
+from sklearn.metrics import roc_auc_score
+
+tf.app.flags.DEFINE_boolean('batch_norm', True, 'Use BN or not')
+tf.app.flags.DEFINE_float('lr', 0.0001, 'Initial learning rate.')
+tf.app.flags.DEFINE_integer('filters_in_last', 128, 'Number of filters on the last layer')
+tf.app.flags.DEFINE_float('dr', 0.0, 'Dropout rate when training.')
+tf.app.flags.DEFINE_string('input_file','/gpfs/data/denizlab/Datasets/OAI/SAG_IW_TSE/dataBefore201807/data_SAG_IW_TSE.hdf5', 'File to read data')
+tf.app.flags.DEFINE_string('file_path', '/gpfs/data/denizlab/Users/hrr288/Radiology_test/', 'Main Folder to Save outputs')
+tf.app.flags.DEFINE_integer('val_fold', 1, 'Fold for cv')
+tf.app.flags.DEFINE_string('file_folder','/gpfs/data/denizlab/Datasets/OAI/SAG_IW_TSE/', 'Path to the HDF5 MR image volumes')
+tf.app.flags.DEFINE_string('csv_path', '/gpfs/data/denizlab/Users/hrr288/TSE_dataset/', 'Folder with the fold splits')
+
+FLAGS = tf.app.flags.FLAGS
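+# Example invocation (paths are placeholders for illustration, not the actual
+# cluster locations; omitted flags fall back to the defaults above):
+#
+#   python train.py --val_fold 1 \
+#       --file_folder /path/to/SAG_IW_TSE_hdf5/ \
+#       --csv_path /path/to/fold_splits/ \
+#       --file_path /path/to/output_dir/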
+
+class auc_Histories(Callback):
+    def on_train_begin(self, logs={}):
+        self.aucs = []
+        self.losses = []
+ 
+    def on_train_end(self, logs={}):
+        return
+ 
+    def on_epoch_begin(self, epoch, logs={}):
+        return
+ 
+    def on_epoch_end(self, epoch, logs={}):
+        self.losses.append(logs.get('loss'))
+        # Older Keras exposes the validation set to callbacks as
+        # self.validation_data = [x_val, y_val, ...]
+        y_pred = self.model.predict(self.validation_data[0])
+        self.aucs.append(roc_auc_score(self.validation_data[1], y_pred))
+        return
+ 
+    def on_batch_begin(self, batch, logs={}):
+        return
+ 
+    def on_batch_end(self, batch, logs={}):
+        return
+
+
+
+class roc_callback(Callback):
+    def __init__(self,index,val_fold):
+        _params = {'dim': (384,384,36),
+              'batch_size': 4,
+              'n_classes': 2,
+              'n_channels': 1,
+              'shuffle': False,
+              'normalize' : True,
+              'randomCrop' : False,
+              'randomFlip' : False,
+              'flipProbability' : -1}
+        self.x = DataGenerator(directory = FLAGS.csv_path+'Fold_'+str(val_fold)+'/CV_'+str(index)+'_train.csv',file_folder=FLAGS.file_folder,  **_params)
+        self.x_val = DataGenerator(directory = FLAGS.csv_path+'Fold_'+str(val_fold)+'/CV_'+str(index)+'_val.csv',file_folder=FLAGS.file_folder,  **_params)
+        self.y = pd.read_csv(FLAGS.csv_path+'Fold_'+str(val_fold)+'/CV_'+str(index)+'_train.csv').Label
+        self.y_val = pd.read_csv(FLAGS.csv_path+'Fold_'+str(val_fold)+'/CV_'+str(index)+'_val.csv').Label
+        self.auc = []
+        self.val_auc = []
+        self.losses = []
+        self.val_losses = []
+    
+    def on_train_begin(self, logs={}):
+        return
+ 
+    def on_train_end(self, logs={}):
+        return
+ 
+    def on_epoch_begin(self, epoch, logs={}):
+        return
+    
+    def on_epoch_end(self, epoch, logs={}):        
+        self.losses.append(logs.get('loss'))
+        self.val_losses.append(logs.get('val_loss'))
+        y_pred = self.model.predict_generator(self.x)
+        y_true = self.y[:len(y_pred)]
+        roc = roc_auc_score(y_true, y_pred)      
+        
+        y_pred_val = self.model.predict_generator(self.x_val)
+        y_true_val = self.y_val[:len(y_pred_val)]
+        roc_val = roc_auc_score(y_true_val, y_pred_val)      
+        self.auc.append(roc)
+        self.val_auc.append(roc_val)
+        #print(len(y_true),len(y_true_val))
+        print('\rroc-auc: %s - roc-auc_val: %s' % (str(round(roc,4)),str(round(roc_val,4))),end=100*' '+'\n')
+        return
+ 
+    def on_batch_begin(self, batch, logs={}):
+        return
+ 
+    def on_batch_end(self, batch, logs={}):
+        return   
+    
+'''
+    Def: Code to plot loss curves
+    Params: history = keras output from training
+            loss_path = path to save curve
+'''
+def plot_loss_curves(history, loss_path): #, i):
+    f = plt.figure()
+    plt.plot(history.history['loss'])
+    plt.plot(history.history['val_loss'])
+    plt.title('model loss')
+    plt.ylabel('loss')
+    plt.xlabel('epoch')
+    plt.legend(['train', 'validation'], loc='upper left')
+    #plt.show()    
+    #path = '/data/kl2596/curves/loss/' + loss_path + '.jpeg'
+    f.savefig(loss_path)
+
+
+'''
+    Def: Code to plot accuracy curves
+    Params: history = keras output from training
+            acc_path = path to save curve
+'''
+def plot_accuracy_curves(history, acc_path): #, i):
+    f = plt.figure()
+    plt.plot(history.history['acc'])
+    plt.plot(history.history['val_acc'])
+    plt.title('model accuracy')
+    plt.ylabel('acc')
+    plt.xlabel('epoch')
+    plt.legend(['train', 'validation'], loc='upper left')
+    #plt.show() 
+    #path = '/data/kl2596/curves/accuracy/' + acc_path + '.jpeg'
+    f.savefig(acc_path)
+
+
+def plot_auc_curves(auc_history, acc_path): #, i):
+    f = plt.figure()
+    plt.plot(auc_history.auc)
+    plt.plot(auc_history.val_auc)
+    plt.title('model AUC')
+    plt.ylabel('auc')
+    plt.xlabel('epoch')
+    plt.legend(['train', 'validation'], loc='upper left')
+    #plt.show() 
+    #path = '/data/kl2596/curves/accuracy/' + acc_path + '.jpeg'
+    f.savefig(acc_path)
+
+def train_model(model, train_data, val_data, path, index,val_fold):
+    #model.summary()
+    
+    # Early Stopping callback that can be found on Keras website
+    
+    # Create path to save weights with model checkpoint
+    weights_path = path + 'weights-{epoch:02d}-{val_loss:.2f}-{val_acc:.2f}-{loss:.2f}-{acc:.2f}.hdf5'
+    model_checkpoint = ModelCheckpoint(weights_path, monitor = 'val_loss', save_best_only = True, 
+                                       verbose=0, period=1)
+    
+    # Save loss and accuracy curves using Tensorboard
+    tensorboard_callback = TensorBoard(log_dir = path, 
+                                       histogram_freq = 0, 
+                                       write_graph = False, 
+                                       write_grads = False, 
+                                       write_images = False)
+    
+    auc_history = roc_callback(index,val_fold)
+    #es = EarlyStopping(monitor='val_auc', mode='max', verbose=1, patience=50)
+    callbacks_list = [model_checkpoint, tensorboard_callback, auc_history]
+    #es = EarlyStopping(monitor='val_auc', mode='max', verbose=1, patience=150)
+    history = model.fit_generator(generator = train_data, validation_data = val_data, epochs=10, 
+                        #use_multiprocessing=True, workers=6, 
+                        callbacks = callbacks_list)
+    
+    val_auc = auc_history.val_auc
+    print('*****************************')
+    print('best auc:', np.max(val_auc))
+    print('average auc:', np.mean(val_auc))
+    print('*****************************')
+
+    accuracy = history.history['val_acc']
+    print('*****************************')
+    print('best acc:',np.max(accuracy))
+    print('average acc:',np.mean(accuracy))
+    print('*****************************')
+    loss_path = path + 'loss_curve.jpeg'
+    acc_path = path + 'acc_curve.jpeg'
+    auc_path = path + 'auc_curve.jpeg'
+    plot_loss_curves(history, loss_path)
+    plot_accuracy_curves(history, acc_path)
+    plot_auc_curves(auc_history, auc_path)
+    #model.save_weights(weights_path)
+   
+    
+'''
+    Def: Run 6-fold cross-validation training of the network; the expected
+         fold-split CSV layout is sketched below.
+    Params: val_fold = outer validation fold whose splits are used
+            lr = learning rate
+            dr = dropout rate
+            filters_in_last = number of filters in last convolutional layer (we tested 64 and 128)
+            file_path = path to save network weights, curves, and tensorboard callbacks
+'''
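+# Expected fold-split layout (inferred from the path construction in this file;
+# shown for illustration only):
+#
+#   <csv_path>/Fold_<val_fold>/CV_<i>_train.csv
+#   <csv_path>/Fold_<val_fold>/CV_<i>_val.csv    (i = 1..6)
+#
+# Each CSV is read by DataGenerator and must contain a binary 'Label' column,
+# which roc_callback uses to compute the per-epoch ROC-AUC.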
+def cross_validation(val_fold, lr, dr, filters_in_last, file_path):
+    train_params = {'dim': (384,384,36),
+          'batch_size': 4,
+          'n_classes': 2,
+          'n_channels': 1,
+          'shuffle': True,
+          'normalize' : True,
+          'randomCrop' : True,
+          'randomFlip' : True,
+          'flipProbability' : -1}
+    
+    val_params = {'dim': (384,384,36),
+          'batch_size': 4,
+          'n_classes': 2,
+          'n_channels': 1,
+          'shuffle': False,
+          'normalize' : True,
+          'randomCrop' : False,
+          'randomFlip' : False,
+          'flipProbability' : -1} 
+    
+    model_path = file_path + 'Tnetres_Best/lr24ch32kerne773773_strde222_new_arch/'
+    if not os.path.exists(model_path):
+        os.makedirs(model_path)
+            
+
+    num_of_folds = 6
+    for i in range(num_of_folds):
+        # NOTE: the learning rate is fixed at 2e-4 here; the lr, dr and
+        # filters_in_last arguments are not currently forwarded to generate_model.
+        model = generate_model(learning_rate=2e-4)
+        model.summary()
+        print(train_params)
+        #print(train_index, test_index)
+        print('Running Fold', i+1, '/', num_of_folds)   
+        fold_path = model_path + 'Fold_' + str(val_fold) + '/CV_'+str(i+1)+'/'
+        print(fold_path)
+        
+        if not os.path.exists(fold_path):
+            os.makedirs(fold_path)    
+        
+        training_generator = DataGenerator(directory = FLAGS.csv_path+'Fold_'+str(val_fold)+'/CV_'+str(i+1)+'_train.csv',file_folder=FLAGS.file_folder,  **train_params)
+        validation_generator = DataGenerator(directory = FLAGS.csv_path+'Fold_'+str(val_fold)+'/CV_'+str(i+1)+'_val.csv',file_folder=FLAGS.file_folder,  **val_params)
+        
+        train_model(model=model, 
+                    train_data = training_generator,
+                    val_data = validation_generator,
+                    path = fold_path, index = i+1,val_fold=val_fold)
+
+
+
+def main(argv=None):
+    print('Begin training for fold', FLAGS.val_fold)
+    cross_validation(val_fold=FLAGS.val_fold, 
+                     lr=FLAGS.lr, dr=FLAGS.dr, filters_in_last=FLAGS.filters_in_last,   
+                     file_path = FLAGS.file_path)
+
+if __name__ == "__main__":
+    tf.app.run()
+
+