Diff of /main_TrainTest.py [000000] .. [56a69a]

Switch to unified view

a b/main_TrainTest.py
1
""" 
2
Copyright (C) 2022 King Saud University, Saudi Arabia 
3
SPDX-License-Identifier: Apache-2.0 
4
5
Licensed under the Apache License, Version 2.0 (the "License"); you may not use
6
this file except in compliance with the License. You may obtain a copy of the 
7
License at
8
9
http://www.apache.org/licenses/LICENSE-2.0  
10
11
Unless required by applicable law or agreed to in writing, software distributed
12
under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR 
13
CONDITIONS OF ANY KIND, either express or implied. See the License for the
14
specific language governing permissions and limitations under the License. 
15
16
Author:  Hamdi Altaheri 
17
"""
18
19
#%%
20
import os
21
import time
22
import numpy as np
23
import matplotlib.pyplot as plt
24
import tensorflow as tf
25
26
from tensorflow.keras.optimizers import Adam
27
from tensorflow.keras.losses import categorical_crossentropy
28
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau
29
from sklearn.metrics import confusion_matrix, accuracy_score, ConfusionMatrixDisplay
30
from sklearn.metrics import cohen_kappa_score
31
32
import models 
33
from preprocess import get_data
34
# from keras.utils.vis_utils import plot_model
35
36
37
#%%
38
def draw_learning_curves(history):
39
    plt.plot(history.history['accuracy'])
40
    plt.plot(history.history['val_accuracy'])
41
    plt.title('Model accuracy')
42
    plt.ylabel('Accuracy')
43
    plt.xlabel('Epoch')
44
    plt.legend(['Train', 'val'], loc='upper left')
45
    plt.show()
46
    plt.plot(history.history['loss'])
47
    plt.plot(history.history['val_loss'])
48
    plt.title('Model loss')
49
    plt.ylabel('Loss')
50
    plt.xlabel('Epoch')
51
    plt.legend(['Train', 'val'], loc='upper left')
52
    plt.show()
53
    plt.close()
54
55
def draw_confusion_matrix(cf_matrix, sub, results_path, classes_labels):
56
    # Generate confusion matrix plot
57
    display_labels = classes_labels
58
    disp = ConfusionMatrixDisplay(confusion_matrix=cf_matrix, 
59
                                display_labels=display_labels)
60
    disp.plot()
61
    disp.ax_.set_xticklabels(display_labels, rotation=12)
62
    plt.title('Confusion Matrix of Subject: ' + sub )
63
    plt.savefig(results_path + '/subject_' + sub + '.png')
64
    plt.show()
65
66
def draw_performance_barChart(num_sub, metric, label):
67
    fig, ax = plt.subplots()
68
    x = list(range(1, num_sub+1))
69
    ax.bar(x, metric, 0.5, label=label)
70
    ax.set_ylabel(label)
71
    ax.set_xlabel("Subject")
72
    ax.set_xticks(x)
73
    ax.set_title('Model '+ label + ' per subject')
74
    ax.set_ylim([0,1])
75
    
76
    
77
#%% Training 
78
def train(dataset_conf, train_conf, results_path):
79
    # Get the current 'IN' time to calculate the overall training time
80
    in_exp = time.time()
81
    # Create a file to store the path of the best model among several runs
82
    best_models = open(results_path + "/best models.txt", "w")
83
    # Create a file to store performance during training
84
    log_write = open(results_path + "/log.txt", "w")
85
    # Create a .npz file (zipped archive) to store the accuracy and kappa metrics 
86
    # for all runs (to calculate average accuracy/kappa over all runs)
87
    perf_allRuns = open(results_path + "/perf_allRuns.npz", 'wb')
88
    
89
    # Get dataset paramters
90
    dataset = dataset_conf.get('name')
91
    n_sub = dataset_conf.get('n_sub')
92
    data_path = dataset_conf.get('data_path')
93
    isStandard = dataset_conf.get('isStandard')
94
    LOSO = dataset_conf.get('LOSO')
95
    # Get training hyperparamters
96
    batch_size = train_conf.get('batch_size')
97
    epochs = train_conf.get('epochs')
98
    patience = train_conf.get('patience')
99
    lr = train_conf.get('lr')
100
    LearnCurves = train_conf.get('LearnCurves') # Plot Learning Curves?
101
    n_train = train_conf.get('n_train')
102
    model_name = train_conf.get('model')
103
104
    # Initialize variables
105
    acc = np.zeros((n_sub, n_train))
106
    kappa = np.zeros((n_sub, n_train))
107
    
108
    # Iteration over subjects 
109
    # for sub in range(n_sub-1, n_sub): # (num_sub): for all subjects, (i-1,i): for the ith subject.
110
    for sub in range(n_sub): # (num_sub): for all subjects, (i-1,i): for the ith subject.
111
        # Get the current 'IN' time to calculate the subject training time
112
        in_sub = time.time()
113
        print('\nTraining on subject ', sub+1)
114
        log_write.write( '\nTraining on subject '+ str(sub+1) +'\n')
115
        # Initiating variables to save the best subject accuracy among multiple runs.
116
        BestSubjAcc = 0 
117
        bestTrainingHistory = [] 
118
        # Get training and test data
119
        X_train, _, y_train_onehot, X_test, _, y_test_onehot = get_data(
120
            data_path, sub, dataset, LOSO = LOSO, isStandard = isStandard)  
121
        
122
        # Iteration over multiple runs 
123
        for train in range(n_train): # How many repetitions of training for subject i.
124
            # Get the current 'IN' time to calculate the 'run' training time
125
            tf.random.set_seed(train+1)
126
            np.random.seed(train+1)
127
128
            in_run = time.time()
129
            # Create folders and files to save trained models for all runs
130
            filepath = results_path + '/saved models/run-{}'.format(train+1)
131
            if not os.path.exists(filepath):
132
                os.makedirs(filepath)        
133
            filepath = filepath + '/subject-{}.h5'.format(sub+1)
134
            
135
            # Create the model
136
            model = getModel(model_name, dataset_conf)
137
            # Compile and train the model
138
            model.compile(loss=categorical_crossentropy, optimizer=Adam(learning_rate=lr), metrics=['accuracy'])          
139
            # model.summary()
140
            # plot_model(model, to_file='plot_model.png', show_shapes=True, show_layer_names=True)
141
            
142
            callbacks = [
143
                ModelCheckpoint(filepath, monitor='val_accuracy', verbose=0, 
144
                                save_best_only=True, save_weights_only=True, mode='max'),
145
                
146
                ReduceLROnPlateau(monitor="val_loss", factor=0.90, patience=20, verbose=1, min_lr=0.0001),  
147
                
148
                EarlyStopping(monitor='val_accuracy', verbose=1, mode='max', patience=patience)
149
            ]
150
            history = model.fit(X_train, y_train_onehot, validation_data=(X_test, y_test_onehot), 
151
                                epochs=epochs, batch_size=batch_size, callbacks=callbacks, verbose=0)
152
153
            # Evaluate the performance of the trained model. 
154
            # Here we load the Trained weights from the file saved in the hard 
155
            # disk, which should be the same as the weights of the current model.
156
            model.load_weights(filepath)
157
            y_pred = model.predict(X_test).argmax(axis=-1)
158
            labels = y_test_onehot.argmax(axis=-1)
159
            acc[sub, train]  = accuracy_score(labels, y_pred)
160
            kappa[sub, train] = cohen_kappa_score(labels, y_pred)
161
              
162
            # Get the current 'OUT' time to calculate the 'run' training time
163
            out_run = time.time()
164
            # Print & write performance measures for each run
165
            info = 'Subject: {}   Train no. {}   Time: {:.1f} m   '.format(sub+1, train+1, ((out_run-in_run)/60))
166
            info = info + 'Test_acc: {:.4f}   Test_kappa: {:.4f}'.format(acc[sub, train], kappa[sub, train])
167
            print(info)
168
            log_write.write(info +'\n')
169
            # If current training run is better than previous runs, save the history.
170
            if(BestSubjAcc < acc[sub, train]):
171
                 BestSubjAcc = acc[sub, train]
172
                 bestTrainingHistory = history
173
        
174
        # Store the path of the best model among several runs
175
        best_run = np.argmax(acc[sub,:])
176
        filepath = '/saved models/run-{}/subject-{}.h5'.format(best_run+1, sub+1)+'\n'
177
        best_models.write(filepath)
178
        # Get the current 'OUT' time to calculate the subject training time
179
        out_sub = time.time()
180
        # Print & write the best subject performance among multiple runs
181
        info = '----------\n'
182
        info = info + 'Subject: {}   best_run: {}   Time: {:.1f} m   '.format(sub+1, best_run+1, ((out_sub-in_sub)/60))
183
        info = info + 'acc: {:.4f}   avg_acc: {:.4f} +- {:.4f}   '.format(acc[sub, best_run], np.average(acc[sub, :]), acc[sub,:].std() )
184
        info = info + 'kappa: {:.4f}   avg_kappa: {:.4f} +- {:.4f}'.format(kappa[sub, best_run], np.average(kappa[sub, :]), kappa[sub,:].std())
185
        info = info + '\n----------'
186
        print(info)
187
        log_write.write(info+'\n')
188
        # Plot Learning curves 
189
        if (LearnCurves == True):
190
            print('Plot Learning Curves ....... ')
191
            draw_learning_curves(bestTrainingHistory)
192
          
193
    # Get the current 'OUT' time to calculate the overall training time
194
    out_exp = time.time()
195
    info = '\nTime: {:.1f} h   '.format( (out_exp-in_exp)/(60*60) )
196
    print(info)
197
    log_write.write(info+'\n')
198
    
199
    # Store the accuracy and kappa metrics as arrays for all runs into a .npz 
200
    # file format, which is an uncompressed zipped archive, to calculate average
201
    # accuracy/kappa over all runs.
202
    np.savez(perf_allRuns, acc = acc, kappa = kappa)
203
    
204
    # Close open files 
205
    best_models.close()   
206
    log_write.close() 
207
    perf_allRuns.close() 
208
209
210
#%% Evaluation 
211
def test(model, dataset_conf, results_path, allRuns = True):
212
    # Open the  "Log" file to write the evaluation results 
213
    log_write = open(results_path + "/log.txt", "a")
214
    # Open the file that stores the path of the best models among several random runs.
215
    best_models = open(results_path + "/best models.txt", "r")   
216
    
217
    # Get dataset paramters
218
    dataset = dataset_conf.get('name')
219
    n_classes = dataset_conf.get('n_classes')
220
    n_sub = dataset_conf.get('n_sub')
221
    data_path = dataset_conf.get('data_path')
222
    isStandard = dataset_conf.get('isStandard')
223
    LOSO = dataset_conf.get('LOSO')
224
    classes_labels = dataset_conf.get('cl_labels')
225
    
226
    # Initialize variables
227
    acc_bestRun = np.zeros(n_sub)
228
    kappa_bestRun = np.zeros(n_sub)  
229
    cf_matrix = np.zeros([n_sub, n_classes, n_classes])
230
231
    # Calculate the average performance (average accuracy and K-score) for 
232
    # all runs (experiments) for each subject.
233
    if(allRuns): 
234
        # Load the test accuracy and kappa metrics as arrays for all runs from a .npz 
235
        # file format, which is an uncompressed zipped archive, to calculate average
236
        # accuracy/kappa over all runs.
237
        perf_allRuns = open(results_path + "/perf_allRuns.npz", 'rb')
238
        perf_arrays = np.load(perf_allRuns)
239
        acc_allRuns = perf_arrays['acc']
240
        kappa_allRuns = perf_arrays['kappa']
241
    
242
    # Iteration over subjects 
243
    # for sub in range(n_sub-1, n_sub): # (num_sub): for all subjects, (i-1,i): for the ith subject.
244
    for sub in range(n_sub): # (num_sub): for all subjects, (i-1,i): for the ith subject.
245
        # Load data
246
        _, _, _, X_test, _, y_test_onehot = get_data(data_path, sub, dataset, LOSO, isStandard)      
247
        
248
        # Load the best model out of multiple random runs (experiments).
249
        filepath = best_models.readline()
250
        model.load_weights(results_path + filepath[:-1])
251
        # Predict MI task
252
        y_pred = model.predict(X_test).argmax(axis=-1)
253
        # Calculate accuracy and K-score
254
        labels = y_test_onehot.argmax(axis=-1)
255
        acc_bestRun[sub] = accuracy_score(labels, y_pred)
256
        kappa_bestRun[sub] = cohen_kappa_score(labels, y_pred)
257
        # Calculate and draw confusion matrix
258
        cf_matrix[sub, :, :] = confusion_matrix(labels, y_pred, normalize='true')
259
        draw_confusion_matrix(cf_matrix[sub, :, :], str(sub+1), results_path, classes_labels)
260
        
261
        # Print & write performance measures for each subject
262
        info = 'Subject: {}   best_run: {:2}  '.format(sub+1, (filepath[filepath.find('run-')+4:filepath.find('/sub')]) )
263
        info = info + 'acc: {:.4f}   kappa: {:.4f}   '.format(acc_bestRun[sub], kappa_bestRun[sub] )
264
        if(allRuns): 
265
            info = info + 'avg_acc: {:.4f} +- {:.4f}   avg_kappa: {:.4f} +- {:.4f}'.format(
266
                np.average(acc_allRuns[sub, :]), acc_allRuns[sub,:].std(),
267
                np.average(kappa_allRuns[sub, :]), kappa_allRuns[sub,:].std() )
268
        print(info)
269
        log_write.write('\n'+info)
270
      
271
    # Print & write the average performance measures for all subjects     
272
    info = '\nAverage of {} subjects - best runs:\nAccuracy = {:.4f}   Kappa = {:.4f}\n'.format(
273
        n_sub, np.average(acc_bestRun), np.average(kappa_bestRun)) 
274
    if(allRuns): 
275
        info = info + '\nAverage of {} subjects x {} runs (average of {} experiments):\nAccuracy = {:.4f}   Kappa = {:.4f}'.format(
276
            n_sub, acc_allRuns.shape[1], (n_sub * acc_allRuns.shape[1]),
277
            np.average(acc_allRuns), np.average(kappa_allRuns)) 
278
    print(info)
279
    log_write.write(info)
280
    
281
    # Draw a performance bar chart for all subjects 
282
    draw_performance_barChart(n_sub, acc_bestRun, 'Accuracy')
283
    draw_performance_barChart(n_sub, kappa_bestRun, 'K-score')
284
    # Draw confusion matrix for all subjects (average)
285
    draw_confusion_matrix(cf_matrix.mean(0), 'All', results_path, classes_labels)
286
    # Close open files     
287
    log_write.close() 
288
    
289
    
290
#%%
291
def getModel(model_name, dataset_conf):
292
    
293
    n_classes = dataset_conf.get('n_classes')
294
    n_channels = dataset_conf.get('n_channels')
295
    in_samples = dataset_conf.get('in_samples')
296
297
    # Select the model
298
    if(model_name == 'ATCNet'):
299
        # Train using the proposed ATCNet model: https://doi.org/10.1109/TII.2022.3197419
300
        model = models.ATCNet_( 
301
            # Dataset parameters
302
            n_classes = n_classes, 
303
            in_chans = n_channels, 
304
            in_samples = in_samples, 
305
            # Sliding window (SW) parameter
306
            n_windows = 5, 
307
            # Attention (AT) block parameter
308
            attention = 'mha', # Options: None, 'mha','mhla', 'cbam', 'se'
309
            # Convolutional (CV) block parameters
310
            eegn_F1 = 16,
311
            eegn_D = 2, 
312
            eegn_kernelSize = 64,
313
            eegn_poolSize = 7,
314
            eegn_dropout = 0.3,
315
            # Temporal convolutional (TC) block parameters
316
            tcn_depth = 2, 
317
            tcn_kernelSize = 4,
318
            tcn_filters = 32,
319
            tcn_dropout = 0.3, 
320
            tcn_activation='elu'
321
            )     
322
    elif(model_name == 'TCNet_Fusion'):
323
        # Train using TCNet_Fusion: https://doi.org/10.1016/j.bspc.2021.102826
324
        model = models.TCNet_Fusion(n_classes = n_classes, Chans=n_channels, Samples=in_samples)      
325
    elif(model_name == 'EEGTCNet'):
326
        # Train using EEGTCNet: https://arxiv.org/abs/2006.00622
327
        model = models.EEGTCNet(n_classes = n_classes, Chans=n_channels, Samples=in_samples)          
328
    elif(model_name == 'EEGNet'):
329
        # Train using EEGNet: https://arxiv.org/abs/1611.08024
330
        model = models.EEGNet_classifier(n_classes = n_classes, Chans=n_channels, Samples=in_samples) 
331
    elif(model_name == 'EEGNeX'):
332
        # Train using EEGNeX: https://arxiv.org/abs/2207.12369
333
        model = models.EEGNeX_8_32(n_timesteps = in_samples , n_features = n_channels, n_outputs = n_classes)
334
    elif(model_name == 'DeepConvNet'):
335
        # Train using DeepConvNet: https://doi.org/10.1002/hbm.23730
336
        model = models.DeepConvNet(nb_classes = n_classes , Chans = n_channels, Samples = in_samples)
337
    elif(model_name == 'ShallowConvNet'):
338
        # Train using ShallowConvNet: https://doi.org/10.1002/hbm.23730
339
        model = models.ShallowConvNet(nb_classes = n_classes , Chans = n_channels, Samples = in_samples)
340
    elif(model_name == 'MBEEG_SENet'):
341
        # Train using MBEEG_SENet: https://www.mdpi.com/2075-4418/12/4/995
342
        model = models.MBEEG_SENet(nb_classes = n_classes , Chans = n_channels, Samples = in_samples)
343
344
    else:
345
        raise Exception("'{}' model is not supported yet!".format(model_name))
346
347
    return model
348
    
349
    
350
#%%
351
def run():
352
    # Define dataset parameters
353
    dataset = 'BCI2a' # Options: 'BCI2a','HGD', 'CS2R'
354
    
355
    if dataset == 'BCI2a': 
356
        in_samples = 1125
357
        n_channels = 22
358
        n_sub = 9
359
        n_classes = 4
360
        classes_labels = ['Left hand', 'Right hand','Foot','Tongue']
361
        data_path = os.path.expanduser('~') + '/BCI Competition IV/BCI Competition IV-2a/BCI Competition IV 2a mat/'
362
    elif dataset == 'HGD': 
363
        in_samples = 1125
364
        n_channels = 44
365
        n_sub = 14
366
        n_classes = 4
367
        classes_labels = ['Right Hand', 'Left Hand','Rest','Feet']     
368
        data_path = os.path.expanduser('~') + '/mne_data/MNE-schirrmeister2017-data/robintibor/high-gamma-dataset/raw/master/data/'
369
    elif dataset == 'CS2R': 
370
        in_samples = 1125
371
        # in_samples = 576
372
        n_channels = 32
373
        n_sub = 18
374
        n_classes = 3
375
        # classes_labels = ['Fingers', 'Wrist','Elbow','Rest']     
376
        classes_labels = ['Fingers', 'Wrist','Elbow']     
377
        # classes_labels = ['Fingers', 'Elbow']     
378
        data_path = os.path.expanduser('~') + '/CS2R MI EEG dataset/all/EDF - Cleaned - phase one (remove extra runs)/two sessions/'
379
    else:
380
        raise Exception("'{}' dataset is not supported yet!".format(dataset))
381
        
382
    # Create a folder to store the results of the experiment
383
    results_path = os.getcwd() + "/results"
384
    if not  os.path.exists(results_path):
385
      os.makedirs(results_path)   # Create a new directory if it does not exist 
386
      
387
    # Set dataset paramters 
388
    dataset_conf = { 'name': dataset, 'n_classes': n_classes, 'cl_labels': classes_labels,
389
                    'n_sub': n_sub, 'n_channels': n_channels, 'in_samples': in_samples,
390
                    'data_path': data_path, 'isStandard': True, 'LOSO': False}
391
    # Set training hyperparamters
392
    train_conf = { 'batch_size': 64, 'epochs': 1000, 'patience': 300, 'lr': 0.001,
393
                  'LearnCurves': True, 'n_train': 10, 'model':'ATCNet'}
394
           
395
    # Train the model
396
    # train(dataset_conf, train_conf, results_path)
397
398
    # Evaluate the model based on the weights saved in the '/results' folder
399
    model = getModel(train_conf.get('model'), dataset_conf)
400
    test(model, dataset_conf, results_path)    
401
    
402
#%%
403
if __name__ == "__main__":
404
    run()
405