--- a +++ b/main_TrainValTest.py @@ -0,0 +1,426 @@ +""" +Copyright (C) 2022 King Saud University, Saudi Arabia +SPDX-License-Identifier: Apache-2.0 + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use +this file except in compliance with the License. You may obtain a copy of the +License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software distributed +under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +CONDITIONS OF ANY KIND, either express or implied. See the License for the +specific language governing permissions and limitations under the License. + +Author: Hamdi Altaheri +""" + +#%% +import os +import sys +import shutil +import time +import numpy as np +import matplotlib.pyplot as plt +import tensorflow as tf + +from tensorflow.keras.optimizers import Adam +from tensorflow.keras.losses import CategoricalCrossentropy +from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau +from sklearn.metrics import confusion_matrix, accuracy_score, ConfusionMatrixDisplay +from sklearn.metrics import cohen_kappa_score +from sklearn.model_selection import train_test_split + +import models +from preprocess import get_data +# from keras.utils.vis_utils import plot_model + + +#%% +def draw_learning_curves(history, sub): + plt.plot(history.history['accuracy']) + plt.plot(history.history['val_accuracy']) + plt.title('Model accuracy - subject: ' + str(sub)) + plt.ylabel('Accuracy') + plt.xlabel('Epoch') + plt.legend(['Train', 'val'], loc='upper left') + plt.show() + plt.plot(history.history['loss']) + plt.plot(history.history['val_loss']) + plt.title('Model loss - subject: ' + str(sub)) + plt.ylabel('Loss') + plt.xlabel('Epoch') + plt.legend(['Train', 'val'], loc='upper left') + plt.show() + plt.close() + +def draw_confusion_matrix(cf_matrix, sub, results_path, classes_labels): + # Generate confusion matrix plot + display_labels = classes_labels + disp = ConfusionMatrixDisplay(confusion_matrix=cf_matrix, + display_labels=display_labels) + disp.plot() + disp.ax_.set_xticklabels(display_labels, rotation=12) + plt.title('Confusion Matrix of Subject: ' + sub ) + plt.savefig(results_path + '/subject_' + sub + '.png') + plt.show() + +def draw_performance_barChart(num_sub, metric, label): + fig, ax = plt.subplots() + x = list(range(1, num_sub+1)) + ax.bar(x, metric, 0.5, label=label) + ax.set_ylabel(label) + ax.set_xlabel("Subject") + ax.set_xticks(x) + ax.set_title('Model '+ label + ' per subject') + ax.set_ylim([0,1]) + + +#%% Training +def train(dataset_conf, train_conf, results_path): + + # remove the 'result' folder before training + if os.path.exists(results_path): + # Remove the folder and its contents + shutil.rmtree(results_path) + os.makedirs(results_path) + + # Get the current 'IN' time to calculate the overall training time + in_exp = time.time() + # Create a file to store the path of the best model among several runs + best_models = open(results_path + "/best models.txt", "w") + # Create a file to store performance during training + log_write = open(results_path + "/log.txt", "w") + + # Get dataset parameters + dataset = dataset_conf.get('name') + n_sub = dataset_conf.get('n_sub') + data_path = dataset_conf.get('data_path') + isStandard = dataset_conf.get('isStandard') + LOSO = dataset_conf.get('LOSO') + # Get training hyperparamters + batch_size = train_conf.get('batch_size') + epochs = train_conf.get('epochs') + patience = train_conf.get('patience') + lr = train_conf.get('lr') + LearnCurves = train_conf.get('LearnCurves') # Plot Learning Curves? + n_train = train_conf.get('n_train') + model_name = train_conf.get('model') + from_logits = train_conf.get('from_logits') + + # Initialize variables + acc = np.zeros((n_sub, n_train)) + kappa = np.zeros((n_sub, n_train)) + + # Iteration over subjects + # for sub in range(n_sub-1, n_sub): # (num_sub): for all subjects, (i-1,i): for the ith subject. + for sub in range(n_sub): # (num_sub): for all subjects, (i-1,i): for the ith subject. + + print('\nTraining on subject ', sub+1) + log_write.write( '\nTraining on subject '+ str(sub+1) +'\n') + # Initiating variables to save the best subject accuracy among multiple runs. + BestSubjAcc = 0 + bestTrainingHistory = [] + + # Get training and validation data + X_train, _, y_train_onehot, _, _, _ = get_data( + data_path, sub, dataset, LOSO = LOSO, isStandard = isStandard) + + # Divide the training data into training and validation + X_train, X_val, y_train_onehot, y_val_onehot = train_test_split(X_train, y_train_onehot, test_size=0.2, random_state=42) + + # Iteration over multiple runs + for train in range(n_train): # How many repetitions of training for subject i. + # Set the random seed for TensorFlow and NumPy random number generator. + # The purpose of setting a seed is to ensure reproducibility in random operations. + tf.random.set_seed(train+1) + np.random.seed(train+1) + + # Get the current 'IN' time to calculate the 'run' training time + in_run = time.time() + + # Create folders and files to save trained models for all runs + filepath = results_path + '/saved models/run-{}'.format(train+1) + if not os.path.exists(filepath): + os.makedirs(filepath) + filepath = filepath + '/subject-{}.h5'.format(sub+1) + + # Create the model + model = getModel(model_name, dataset_conf, from_logits) + # Compile and train the model + model.compile(loss=CategoricalCrossentropy(from_logits=from_logits), optimizer=Adam(learning_rate=lr), metrics=['accuracy']) + + # model.summary() + # plot_model(model, to_file='plot_model.png', show_shapes=True, show_layer_names=True) + + callbacks = [ + ModelCheckpoint(filepath, monitor='val_loss', verbose=0, + save_best_only=True, save_weights_only=True, mode='min'), + ReduceLROnPlateau(monitor="val_loss", factor=0.90, patience=20, verbose=0, min_lr=0.0001), + # EarlyStopping(monitor='val_loss', verbose=1, mode='min', patience=patience) + ] + history = model.fit(X_train, y_train_onehot, validation_data=(X_val, y_val_onehot), + epochs=epochs, batch_size=batch_size, callbacks=callbacks, verbose=0) + + # Evaluate the performance of the trained model based on the validation data + # Here we load the Trained weights from the file saved in the hard + # disk, which should be the same as the weights of the current model. + model.load_weights(filepath) + y_pred = model.predict(X_val) + + if from_logits: + y_pred = tf.nn.softmax(y_pred).numpy().argmax(axis=-1) + else: + y_pred = y_pred.argmax(axis=-1) + + labels = y_val_onehot.argmax(axis=-1) + acc[sub, train] = accuracy_score(labels, y_pred) + kappa[sub, train] = cohen_kappa_score(labels, y_pred) + + # Get the current 'OUT' time to calculate the 'run' training time + out_run = time.time() + # Print & write performance measures for each run + info = 'Subject: {} seed {} time: {:.1f} m '.format(sub+1, train+1, ((out_run-in_run)/60)) + info = info + 'valid_acc: {:.4f} valid_loss: {:.3f}'.format(acc[sub, train], min(history.history['val_loss'])) + print(info) + log_write.write(info +'\n') + # If current training run is better than previous runs, save the history. + if(BestSubjAcc < acc[sub, train]): + BestSubjAcc = acc[sub, train] + bestTrainingHistory = history + + # Store the path of the best model among several runs + best_run = np.argmax(acc[sub,:]) + filepath = '/saved models/run-{}/subject-{}.h5'.format(best_run+1, sub+1)+'\n' + best_models.write(filepath) + + # Plot Learning curves + if (LearnCurves == True): + print('Plot Learning Curves ....... ') + draw_learning_curves(bestTrainingHistory, sub+1) + + # Get the current 'OUT' time to calculate the overall training time + out_exp = time.time() + + # Print & write the validation performance using all seeds + head1 = head2 = ' ' + for sub in range(n_sub): + head1 = head1 + 'sub_{} '.format(sub+1) + head2 = head2 + '----- ' + head1 = head1 + ' average' + head2 = head2 + ' -------' + info = '\n---------------------------------\nValidation performance (acc %):' + info = info + '\n---------------------------------\n' + head1 +'\n'+ head2 + for run in range(n_train): + info = info + '\nSeed {}: '.format(run+1) + for sub in range(n_sub): + info = info + '{:.2f} '.format(acc[sub, run]*100) + info = info + ' {:.2f} '.format(np.average(acc[:, run])*100) + info = info + '\n---------------------------------\nAverage acc - all seeds: ' + info = info + '{:.2f} %\n\nTrain Time - all seeds: {:.1f}'.format(np.average(acc)*100, (out_exp-in_exp)/(60)) + info = info + ' min\n---------------------------------\n' + print(info) + log_write.write(info+'\n') + + # Close open files + best_models.close() + log_write.close() + + +#%% Evaluation +def test(model, dataset_conf, results_path, allRuns = True): + # Open the "Log" file to write the evaluation results + log_write = open(results_path + "/log.txt", "a") + + # Get dataset paramters + dataset = dataset_conf.get('name') + n_classes = dataset_conf.get('n_classes') + n_sub = dataset_conf.get('n_sub') + data_path = dataset_conf.get('data_path') + isStandard = dataset_conf.get('isStandard') + LOSO = dataset_conf.get('LOSO') + classes_labels = dataset_conf.get('cl_labels') + + # Test the performance based on several runs (seeds) + runs = os.listdir(results_path+"/saved models") + # Initialize variables + acc = np.zeros((n_sub, len(runs))) + kappa = np.zeros((n_sub, len(runs))) + cf_matrix = np.zeros([n_sub, len(runs), n_classes, n_classes]) + + # Iteration over subjects + # for sub in range(n_sub-1, n_sub): # (num_sub): for all subjects, (i-1,i): for the ith subject. + inference_time = 0 # inference_time: classification time for one trial + for sub in range(n_sub): # (num_sub): for all subjects, (i-1,i): for the ith subject. + # Load data + _, _, _, X_test, _, y_test_onehot = get_data(data_path, sub, dataset, LOSO = LOSO, isStandard = isStandard) + + # Iteration over runs (seeds) + for seed in range(len(runs)): + # Load the model of the seed. + model.load_weights('{}/saved models/{}/subject-{}.h5'.format(results_path, runs[seed], sub+1)) + + inference_time = time.time() + # Predict MI task + y_pred = model.predict(X_test).argmax(axis=-1) + inference_time = (time.time() - inference_time)/X_test.shape[0] + # Calculate accuracy and K-score + labels = y_test_onehot.argmax(axis=-1) + acc[sub, seed] = accuracy_score(labels, y_pred) + kappa[sub, seed] = cohen_kappa_score(labels, y_pred) + # Calculate and draw confusion matrix + cf_matrix[sub, seed, :, :] = confusion_matrix(labels, y_pred, normalize='true') + # draw_confusion_matrix(cf_matrix[sub, seed, :, :], str(sub+1), results_path, classes_labels) + + # Print & write the average performance measures for all subjects + head1 = head2 = ' ' + for sub in range(n_sub): + head1 = head1 + 'sub_{} '.format(sub+1) + head2 = head2 + '----- ' + head1 = head1 + ' average' + head2 = head2 + ' -------' + info = '\n' + head1 +'\n'+ head2 + info = '\n---------------------------------\nTest performance (acc & k-score):\n' + info = info + '---------------------------------\n' + head1 +'\n'+ head2 + for run in range(len(runs)): + info = info + '\nSeed {}: '.format(run+1) + info_acc = '(acc %) ' + info_k = ' (k-sco) ' + for sub in range(n_sub): + info_acc = info_acc + '{:.2f} '.format(acc[sub, run]*100) + info_k = info_k + '{:.3f} '.format(kappa[sub, run]) + info_acc = info_acc + ' {:.2f} '.format(np.average(acc[:, run])*100) + info_k = info_k + ' {:.3f} '.format(np.average(kappa[:, run])) + info = info + info_acc + '\n' + info_k + info = info + '\n----------------------------------\nAverage - all seeds (acc %): ' + info = info + '{:.2f}\n (k-sco): '.format(np.average(acc)*100) + info = info + '{:.3f}\n\nInference time: {:.2f}'.format(np.average(kappa), inference_time * 1000) + info = info + ' ms per trial\n----------------------------------\n' + print(info) + log_write.write(info+'\n') + + # Draw a performance bar chart for all subjects + draw_performance_barChart(n_sub, acc.mean(1), 'Accuracy') + draw_performance_barChart(n_sub, kappa.mean(1), 'k-score') + # Draw confusion matrix for all subjects (average) + draw_confusion_matrix(cf_matrix.mean((0,1)), 'All', results_path, classes_labels) + # Close opened file + log_write.close() + + +#%% +def getModel(model_name, dataset_conf, from_logits = False): + + n_classes = dataset_conf.get('n_classes') + n_channels = dataset_conf.get('n_channels') + in_samples = dataset_conf.get('in_samples') + + # Select the model + if(model_name == 'ATCNet'): + # Train using the proposed ATCNet model: https://ieeexplore.ieee.org/document/9852687 + model = models.ATCNet_( + # Dataset parameters + n_classes = n_classes, + in_chans = n_channels, + in_samples = in_samples, + # Sliding window (SW) parameter + n_windows = 5, + # Attention (AT) block parameter + attention = 'mha', # Options: None, 'mha','mhla', 'cbam', 'se' + # Convolutional (CV) block parameters + eegn_F1 = 16, + eegn_D = 2, + eegn_kernelSize = 64, + eegn_poolSize = 7, + eegn_dropout = 0.3, + # Temporal convolutional (TC) block parameters + tcn_depth = 2, + tcn_kernelSize = 4, + tcn_filters = 32, + tcn_dropout = 0.3, + tcn_activation='elu', + ) + elif(model_name == 'TCNet_Fusion'): + # Train using TCNet_Fusion: https://doi.org/10.1016/j.bspc.2021.102826 + model = models.TCNet_Fusion(n_classes = n_classes, Chans=n_channels, Samples=in_samples) + elif(model_name == 'EEGTCNet'): + # Train using EEGTCNet: https://arxiv.org/abs/2006.00622 + model = models.EEGTCNet(n_classes = n_classes, Chans=n_channels, Samples=in_samples) + elif(model_name == 'EEGNet'): + # Train using EEGNet: https://arxiv.org/abs/1611.08024 + model = models.EEGNet_classifier(n_classes = n_classes, Chans=n_channels, Samples=in_samples) + elif(model_name == 'EEGNeX'): + # Train using EEGNeX: https://arxiv.org/abs/2207.12369 + model = models.EEGNeX_8_32(n_timesteps = in_samples , n_features = n_channels, n_outputs = n_classes) + elif(model_name == 'DeepConvNet'): + # Train using DeepConvNet: https://doi.org/10.1002/hbm.23730 + model = models.DeepConvNet(nb_classes = n_classes , Chans = n_channels, Samples = in_samples) + elif(model_name == 'ShallowConvNet'): + # Train using ShallowConvNet: https://doi.org/10.1002/hbm.23730 + model = models.ShallowConvNet(nb_classes = n_classes , Chans = n_channels, Samples = in_samples) + elif(model_name == 'MBEEG_SENet'): + # Train using MBEEG_SENet: https://www.mdpi.com/2075-4418/12/4/995 + model = models.MBEEG_SENet(nb_classes = n_classes , Chans = n_channels, Samples = in_samples) + + else: + raise Exception("'{}' model is not supported yet!".format(model_name)) + + return model + +#%% +def run(): + # Define dataset parameters + dataset = 'HGD' # Options: 'BCI2a','HGD', 'CS2R' + + if dataset == 'BCI2a': + in_samples = 1125 + n_channels = 22 + n_sub = 9 + n_classes = 4 + classes_labels = ['Left hand', 'Right hand','Foot','Tongue'] + data_path = os.path.expanduser('~') + '/BCI Competition IV/BCI Competition IV-2a/BCI Competition IV 2a mat/' + elif dataset == 'HGD': + in_samples = 1125 + n_channels = 44 + n_sub = 14 + n_classes = 4 + classes_labels = ['Right Hand', 'Left Hand','Rest','Feet'] + data_path = os.path.expanduser('~') + '/mne_data/MNE-schirrmeister2017-data/robintibor/high-gamma-dataset/raw/master/data/' + elif dataset == 'CS2R': + in_samples = 1125 + # in_samples = 576 + n_channels = 32 + n_sub = 18 + n_classes = 3 + # classes_labels = ['Fingers', 'Wrist','Elbow','Rest'] + classes_labels = ['Fingers', 'Wrist','Elbow'] + # classes_labels = ['Fingers', 'Elbow'] + data_path = os.path.expanduser('~') + '/CS2R MI EEG dataset/all/EDF - Cleaned - phase one (remove extra runs)/two sessions/' + else: + raise Exception("'{}' dataset is not supported yet!".format(dataset)) + + # Create a folder to store the results of the experiment + results_path = os.getcwd() + "/results" + if not os.path.exists(results_path): + os.makedirs(results_path) # Create a new directory if it does not exist + + # Set dataset paramters + dataset_conf = { 'name': dataset, 'n_classes': n_classes, 'cl_labels': classes_labels, + 'n_sub': n_sub, 'n_channels': n_channels, 'in_samples': in_samples, + 'data_path': data_path, 'isStandard': True, 'LOSO': False} + # Set training hyperparamters + train_conf = { 'batch_size': 64, 'epochs': 500, 'patience': 100, 'lr': 0.001,'n_train': 1, + 'LearnCurves': True, 'from_logits': False, 'model':'ATCNet'} + + # Train the model + # train(dataset_conf, train_conf, results_path) + + # Evaluate the model based on the weights saved in the '/results' folder + model = getModel(train_conf.get('model'), dataset_conf) + test(model, dataset_conf, results_path) + +#%% +if __name__ == "__main__": + run() +