Diff of /main_TrainTest.py [000000] .. [56a69a]

--- a
+++ b/main_TrainTest.py
@@ -0,0 +1,405 @@
+""" 
+Copyright (C) 2022 King Saud University, Saudi Arabia 
+SPDX-License-Identifier: Apache-2.0 
+
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use
+this file except in compliance with the License. You may obtain a copy of the 
+License at
+
+http://www.apache.org/licenses/LICENSE-2.0  
+
+Unless required by applicable law or agreed to in writing, software distributed
+under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR 
+CONDITIONS OF ANY KIND, either express or implied. See the License for the
+specific language governing permissions and limitations under the License. 
+
+Author:  Hamdi Altaheri 
+"""
+
+#%%
+import os
+import time
+import numpy as np
+import matplotlib.pyplot as plt
+import tensorflow as tf
+
+from tensorflow.keras.optimizers import Adam
+from tensorflow.keras.losses import categorical_crossentropy
+from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau
+from sklearn.metrics import confusion_matrix, accuracy_score, ConfusionMatrixDisplay
+from sklearn.metrics import cohen_kappa_score
+
+import models 
+from preprocess import get_data
+# from keras.utils.vis_utils import plot_model
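+# Note: plot_model (commented out above) additionally requires the pydot and
+# graphviz packages to be installed.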
+
+
+#%%
+def draw_learning_curves(history):
+    plt.plot(history.history['accuracy'])
+    plt.plot(history.history['val_accuracy'])
+    plt.title('Model accuracy')
+    plt.ylabel('Accuracy')
+    plt.xlabel('Epoch')
+    plt.legend(['Train', 'Validation'], loc='upper left')
+    plt.show()
+    plt.plot(history.history['loss'])
+    plt.plot(history.history['val_loss'])
+    plt.title('Model loss')
+    plt.ylabel('Loss')
+    plt.xlabel('Epoch')
+    plt.legend(['Train', 'Validation'], loc='upper left')
+    plt.show()
+    plt.close()
+
+def draw_confusion_matrix(cf_matrix, sub, results_path, classes_labels):
+    # Generate confusion matrix plot
+    display_labels = classes_labels
+    disp = ConfusionMatrixDisplay(confusion_matrix=cf_matrix, 
+                                display_labels=display_labels)
+    disp.plot()
+    disp.ax_.set_xticklabels(display_labels, rotation=12)
+    plt.title('Confusion Matrix of Subject: ' + sub )
+    plt.savefig(results_path + '/subject_' + sub + '.png')
+    plt.show()
+
+def draw_performance_barChart(num_sub, metric, label):
+    fig, ax = plt.subplots()
+    x = list(range(1, num_sub+1))
+    ax.bar(x, metric, 0.5, label=label)
+    ax.set_ylabel(label)
+    ax.set_xlabel("Subject")
+    ax.set_xticks(x)
+    ax.set_title('Model '+ label + ' per subject')
+    ax.set_ylim([0,1])
+    
+    
+#%% Training 
+def train(dataset_conf, train_conf, results_path):
+    # Get the current 'IN' time to calculate the overall training time
+    in_exp = time.time()
+    # Create a file to store the path of the best model among several runs
+    best_models = open(results_path + "/best models.txt", "w")
+    # Create a file to store performance during training
+    log_write = open(results_path + "/log.txt", "w")
+    # Create a .npz file (zipped archive) to store the accuracy and kappa metrics 
+    # for all runs (to calculate average accuracy/kappa over all runs)
+    perf_allRuns = open(results_path + "/perf_allRuns.npz", 'wb')
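+    # (The handle stays open for the whole training session; np.savez writes
+    # the arrays into it after all subjects and runs have finished, see below.)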
+    
+    # Get dataset parameters
+    dataset = dataset_conf.get('name')
+    n_sub = dataset_conf.get('n_sub')
+    data_path = dataset_conf.get('data_path')
+    isStandard = dataset_conf.get('isStandard')
+    LOSO = dataset_conf.get('LOSO')
+    # Get training hyperparameters
+    batch_size = train_conf.get('batch_size')
+    epochs = train_conf.get('epochs')
+    patience = train_conf.get('patience')
+    lr = train_conf.get('lr')
+    LearnCurves = train_conf.get('LearnCurves') # Plot Learning Curves?
+    n_train = train_conf.get('n_train')
+    model_name = train_conf.get('model')
+
+    # Initialize variables
+    acc = np.zeros((n_sub, n_train))
+    kappa = np.zeros((n_sub, n_train))
+    
+    # Iteration over subjects 
+    # for sub in range(n_sub-1, n_sub): # range(i-1, i): train the i-th subject only
+    for sub in range(n_sub): # range(n_sub): train all subjects
+        # Get the current 'IN' time to calculate the subject training time
+        in_sub = time.time()
+        print('\nTraining on subject ', sub+1)
+        log_write.write( '\nTraining on subject '+ str(sub+1) +'\n')
+        # Initialize variables to save the best subject accuracy among multiple runs.
+        BestSubjAcc = 0 
+        bestTrainingHistory = [] 
+        # Get training and test data
+        X_train, _, y_train_onehot, X_test, _, y_test_onehot = get_data(
+            data_path, sub, dataset, LOSO = LOSO, isStandard = isStandard)  
+        
+        # Iteration over multiple runs 
+        for run_no in range(n_train): # How many repetitions of training for subject i.
+            # Seed the random generators for reproducibility of each run
+            tf.random.set_seed(run_no+1)
+            np.random.seed(run_no+1)
+
+            # Get the current 'IN' time to calculate the 'run' training time
+            in_run = time.time()
+            # Create folders and files to save trained models for all runs
+            filepath = results_path + '/saved models/run-{}'.format(run_no+1)
+            if not os.path.exists(filepath):
+                os.makedirs(filepath)
+            filepath = filepath + '/subject-{}.h5'.format(sub+1)
+            
+            # Create the model
+            model = getModel(model_name, dataset_conf)
+            # Compile and train the model
+            model.compile(loss=categorical_crossentropy, optimizer=Adam(learning_rate=lr), metrics=['accuracy'])          
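+            # categorical_crossentropy expects one-hot encoded targets, hence
+            # the *_onehot labels returned by get_data.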
+            # model.summary()
+            # plot_model(model, to_file='plot_model.png', show_shapes=True, show_layer_names=True)
+            
+            callbacks = [
+                ModelCheckpoint(filepath, monitor='val_accuracy', verbose=0, 
+                                save_best_only=True, save_weights_only=True, mode='max'),
+                
+                ReduceLROnPlateau(monitor="val_loss", factor=0.90, patience=20, verbose=1, min_lr=0.0001),  
+                
+                EarlyStopping(monitor='val_accuracy', verbose=1, mode='max', patience=patience)
+            ]
+            history = model.fit(X_train, y_train_onehot, validation_data=(X_test, y_test_onehot), 
+                                epochs=epochs, batch_size=batch_size, callbacks=callbacks, verbose=0)
+
+            # Evaluate the performance of the trained model by reloading the
+            # best checkpoint from disk: the in-memory weights belong to the
+            # last epoch, not necessarily to the best-performing one.
+            model.load_weights(filepath)
+            y_pred = model.predict(X_test).argmax(axis=-1)
+            labels = y_test_onehot.argmax(axis=-1)
+            acc[sub, run_no] = accuracy_score(labels, y_pred)
+            kappa[sub, run_no] = cohen_kappa_score(labels, y_pred)
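+            # Cohen's kappa corrects accuracy for chance agreement: 0 means
+            # chance-level predictions and 1 means perfect agreement.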
+              
+            # Get the current 'OUT' time to calculate the 'run' training time
+            out_run = time.time()
+            # Print & write performance measures for each run
+            info = 'Subject: {}   Train no. {}   Time: {:.1f} m   '.format(sub+1, run_no+1, ((out_run-in_run)/60))
+            info = info + 'Test_acc: {:.4f}   Test_kappa: {:.4f}'.format(acc[sub, run_no], kappa[sub, run_no])
+            print(info)
+            log_write.write(info +'\n')
+            # If the current training run is better than the previous runs, save the history.
+            if BestSubjAcc < acc[sub, run_no]:
+                BestSubjAcc = acc[sub, run_no]
+                bestTrainingHistory = history
+        
+        # Store the path of the best model among several runs
+        best_run = np.argmax(acc[sub,:])
+        filepath = '/saved models/run-{}/subject-{}.h5'.format(best_run+1, sub+1)+'\n'
+        best_models.write(filepath)
+        # Get the current 'OUT' time to calculate the subject training time
+        out_sub = time.time()
+        # Print & write the best subject performance among multiple runs
+        info = '----------\n'
+        info = info + 'Subject: {}   best_run: {}   Time: {:.1f} m   '.format(sub+1, best_run+1, ((out_sub-in_sub)/60))
+        info = info + 'acc: {:.4f}   avg_acc: {:.4f} +- {:.4f}   '.format(acc[sub, best_run], np.average(acc[sub, :]), acc[sub,:].std() )
+        info = info + 'kappa: {:.4f}   avg_kappa: {:.4f} +- {:.4f}'.format(kappa[sub, best_run], np.average(kappa[sub, :]), kappa[sub,:].std())
+        info = info + '\n----------'
+        print(info)
+        log_write.write(info+'\n')
+        # Plot Learning curves 
+        if LearnCurves:
+            print('Plot Learning Curves ....... ')
+            draw_learning_curves(bestTrainingHistory)
+          
+    # Get the current 'OUT' time to calculate the overall training time
+    out_exp = time.time()
+    info = '\nTime: {:.1f} h   '.format( (out_exp-in_exp)/(60*60) )
+    print(info)
+    log_write.write(info+'\n')
+    
+    # Store the accuracy and kappa metrics as arrays for all runs into a .npz 
+    # file format, which is an uncompressed zipped archive, to calculate average
+    # accuracy/kappa over all runs.
+    np.savez(perf_allRuns, acc = acc, kappa = kappa)
+    
+    # Close open files 
+    best_models.close()   
+    log_write.close() 
+    perf_allRuns.close() 
+
+
+#%% Evaluation 
+def test(model, dataset_conf, results_path, allRuns = True):
+    # Open the log file to append the evaluation results
+    log_write = open(results_path + "/log.txt", "a")
+    # Open the file that stores the path of the best models among several random runs.
+    best_models = open(results_path + "/best models.txt", "r")   
+    
+    # Get dataset parameters
+    dataset = dataset_conf.get('name')
+    n_classes = dataset_conf.get('n_classes')
+    n_sub = dataset_conf.get('n_sub')
+    data_path = dataset_conf.get('data_path')
+    isStandard = dataset_conf.get('isStandard')
+    LOSO = dataset_conf.get('LOSO')
+    classes_labels = dataset_conf.get('cl_labels')
+    
+    # Initialize variables
+    acc_bestRun = np.zeros(n_sub)
+    kappa_bestRun = np.zeros(n_sub)  
+    cf_matrix = np.zeros([n_sub, n_classes, n_classes])
+
+    # Calculate the average performance (average accuracy and K-score) for 
+    # all runs (experiments) for each subject.
+    if allRuns:
+        # Load the test accuracy and kappa metrics of all runs from the .npz
+        # file (an uncompressed zipped archive) to calculate the average
+        # accuracy/kappa over all runs.
+        perf_arrays = np.load(results_path + "/perf_allRuns.npz")
+        acc_allRuns = perf_arrays['acc']
+        kappa_allRuns = perf_arrays['kappa']
+    
+    # Iteration over subjects 
+    # for sub in range(n_sub-1, n_sub): # range(i-1, i): evaluate the i-th subject only
+    for sub in range(n_sub): # range(n_sub): evaluate all subjects
+        # Load data
+        _, _, _, X_test, _, y_test_onehot = get_data(data_path, sub, dataset, LOSO = LOSO, isStandard = isStandard)
+        
+        # Load the best model out of multiple random runs (experiments).
+        filepath = best_models.readline().rstrip('\n')
+        model.load_weights(results_path + filepath)
+        # Predict MI task
+        y_pred = model.predict(X_test).argmax(axis=-1)
+        # Calculate accuracy and K-score
+        labels = y_test_onehot.argmax(axis=-1)
+        acc_bestRun[sub] = accuracy_score(labels, y_pred)
+        kappa_bestRun[sub] = cohen_kappa_score(labels, y_pred)
+        # Calculate and draw confusion matrix
+        cf_matrix[sub, :, :] = confusion_matrix(labels, y_pred, normalize='true')
+        draw_confusion_matrix(cf_matrix[sub, :, :], str(sub+1), results_path, classes_labels)
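+        # With normalize='true', each row of the confusion matrix sums to 1,
+        # i.e. its entries are per-class recall.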
+        
+        # Print & write performance measures for each subject
+        info = 'Subject: {}   best_run: {:2}  '.format(sub+1, (filepath[filepath.find('run-')+4:filepath.find('/sub')]) )
+        info = info + 'acc: {:.4f}   kappa: {:.4f}   '.format(acc_bestRun[sub], kappa_bestRun[sub] )
+        if allRuns:
+            info = info + 'avg_acc: {:.4f} +- {:.4f}   avg_kappa: {:.4f} +- {:.4f}'.format(
+                np.average(acc_allRuns[sub, :]), acc_allRuns[sub,:].std(),
+                np.average(kappa_allRuns[sub, :]), kappa_allRuns[sub,:].std() )
+        print(info)
+        log_write.write('\n'+info)
+      
+    # Print & write the average performance measures for all subjects     
+    info = '\nAverage of {} subjects - best runs:\nAccuracy = {:.4f}   Kappa = {:.4f}\n'.format(
+        n_sub, np.average(acc_bestRun), np.average(kappa_bestRun)) 
+    if allRuns:
+        info = info + '\nAverage of {} subjects x {} runs (average of {} experiments):\nAccuracy = {:.4f}   Kappa = {:.4f}'.format(
+            n_sub, acc_allRuns.shape[1], (n_sub * acc_allRuns.shape[1]),
+            np.average(acc_allRuns), np.average(kappa_allRuns)) 
+    print(info)
+    log_write.write(info)
+    
+    # Draw a performance bar chart for all subjects 
+    draw_performance_barChart(n_sub, acc_bestRun, 'Accuracy')
+    draw_performance_barChart(n_sub, kappa_bestRun, 'K-score')
+    # Draw confusion matrix for all subjects (average)
+    draw_confusion_matrix(cf_matrix.mean(0), 'All', results_path, classes_labels)
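+    # Averaging the per-subject normalized matrices weights every subject
+    # equally, regardless of the number of test trials per subject.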
+    # Close open files     
+    log_write.close() 
+    
+    
+#%%
+def getModel(model_name, dataset_conf):
+    
+    n_classes = dataset_conf.get('n_classes')
+    n_channels = dataset_conf.get('n_channels')
+    in_samples = dataset_conf.get('in_samples')
+
+    # Select the model
+    if model_name == 'ATCNet':
+        # Train using the proposed ATCNet model: https://doi.org/10.1109/TII.2022.3197419
+        model = models.ATCNet_( 
+            # Dataset parameters
+            n_classes = n_classes, 
+            in_chans = n_channels, 
+            in_samples = in_samples, 
+            # Sliding window (SW) parameter
+            n_windows = 5, 
+            # Attention (AT) block parameter
+            attention = 'mha', # Options: None, 'mha','mhla', 'cbam', 'se'
+            # Convolutional (CV) block parameters
+            eegn_F1 = 16,
+            eegn_D = 2, 
+            eegn_kernelSize = 64,
+            eegn_poolSize = 7,
+            eegn_dropout = 0.3,
+            # Temporal convolutional (TC) block parameters
+            tcn_depth = 2, 
+            tcn_kernelSize = 4,
+            tcn_filters = 32,
+            tcn_dropout = 0.3, 
+            tcn_activation='elu'
+            )
+    elif model_name == 'TCNet_Fusion':
+        # Train using TCNet_Fusion: https://doi.org/10.1016/j.bspc.2021.102826
+        model = models.TCNet_Fusion(n_classes = n_classes, Chans=n_channels, Samples=in_samples)
+    elif model_name == 'EEGTCNet':
+        # Train using EEGTCNet: https://arxiv.org/abs/2006.00622
+        model = models.EEGTCNet(n_classes = n_classes, Chans=n_channels, Samples=in_samples)
+    elif model_name == 'EEGNet':
+        # Train using EEGNet: https://arxiv.org/abs/1611.08024
+        model = models.EEGNet_classifier(n_classes = n_classes, Chans=n_channels, Samples=in_samples)
+    elif model_name == 'EEGNeX':
+        # Train using EEGNeX: https://arxiv.org/abs/2207.12369
+        model = models.EEGNeX_8_32(n_timesteps = in_samples, n_features = n_channels, n_outputs = n_classes)
+    elif model_name == 'DeepConvNet':
+        # Train using DeepConvNet: https://doi.org/10.1002/hbm.23730
+        model = models.DeepConvNet(nb_classes = n_classes, Chans = n_channels, Samples = in_samples)
+    elif model_name == 'ShallowConvNet':
+        # Train using ShallowConvNet: https://doi.org/10.1002/hbm.23730
+        model = models.ShallowConvNet(nb_classes = n_classes, Chans = n_channels, Samples = in_samples)
+    elif model_name == 'MBEEG_SENet':
+        # Train using MBEEG_SENet: https://www.mdpi.com/2075-4418/12/4/995
+        model = models.MBEEG_SENet(nb_classes = n_classes, Chans = n_channels, Samples = in_samples)
+    else:
+        raise Exception("'{}' model is not supported yet!".format(model_name))
+
+    return model
+    
+    
+#%%
+def run():
+    # Define dataset parameters
+    dataset = 'BCI2a' # Options: 'BCI2a','HGD', 'CS2R'
+    
+    if dataset == 'BCI2a': 
+        in_samples = 1125
+        n_channels = 22
+        n_sub = 9
+        n_classes = 4
+        classes_labels = ['Left hand', 'Right hand','Foot','Tongue']
+        data_path = os.path.expanduser('~') + '/BCI Competition IV/BCI Competition IV-2a/BCI Competition IV 2a mat/'
+    elif dataset == 'HGD': 
+        in_samples = 1125
+        n_channels = 44
+        n_sub = 14
+        n_classes = 4
+        classes_labels = ['Right Hand', 'Left Hand','Rest','Feet']     
+        data_path = os.path.expanduser('~') + '/mne_data/MNE-schirrmeister2017-data/robintibor/high-gamma-dataset/raw/master/data/'
+    elif dataset == 'CS2R': 
+        in_samples = 1125
+        # in_samples = 576
+        n_channels = 32
+        n_sub = 18
+        n_classes = 3
+        # classes_labels = ['Fingers', 'Wrist','Elbow','Rest']     
+        classes_labels = ['Fingers', 'Wrist','Elbow']     
+        # classes_labels = ['Fingers', 'Elbow']     
+        data_path = os.path.expanduser('~') + '/CS2R MI EEG dataset/all/EDF - Cleaned - phase one (remove extra runs)/two sessions/'
+    else:
+        raise Exception("'{}' dataset is not supported yet!".format(dataset))
+        
+    # Create a folder to store the results of the experiment
+    results_path = os.getcwd() + "/results"
+    if not os.path.exists(results_path):
+        os.makedirs(results_path)   # Create a new directory if it does not exist
+      
+    # Set dataset parameters
+    dataset_conf = { 'name': dataset, 'n_classes': n_classes, 'cl_labels': classes_labels,
+                    'n_sub': n_sub, 'n_channels': n_channels, 'in_samples': in_samples,
+                    'data_path': data_path, 'isStandard': True, 'LOSO': False}
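+    # LOSO: leave-one-subject-out evaluation; isStandard: whether to
+    # standardize the data (both handled inside preprocess.get_data).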
+    # Set training hyperparameters
+    train_conf = { 'batch_size': 64, 'epochs': 1000, 'patience': 300, 'lr': 0.001,
+                  'LearnCurves': True, 'n_train': 10, 'model':'ATCNet'}
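+    # n_train: independent training runs per subject (seeded 1..n_train);
+    # patience: EarlyStopping patience on val_accuracy (see train() above).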
+           
+    # Train the model
+    # train(dataset_conf, train_conf, results_path)
+
+    # Evaluate the model based on the weights saved in the '/results' folder
+    model = getModel(train_conf.get('model'), dataset_conf)
+    test(model, dataset_conf, results_path)    
+    
+#%%
+if __name__ == "__main__":
+    run()
+