Diff of /main_TrainValTest.py [000000] .. [56a69a]

"""
Copyright (C) 2022 King Saud University, Saudi Arabia
SPDX-License-Identifier: Apache-2.0

Licensed under the Apache License, Version 2.0 (the "License"); you may not use
this file except in compliance with the License. You may obtain a copy of the
License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed
under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

Author:  Hamdi Altaheri
"""

#%%
import os
import sys
import shutil
import time
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf

from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import CategoricalCrossentropy
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau
from sklearn.metrics import confusion_matrix, accuracy_score, ConfusionMatrixDisplay
from sklearn.metrics import cohen_kappa_score
from sklearn.model_selection import train_test_split

import models
from preprocess import get_data
# from keras.utils.vis_utils import plot_model


#%%
def draw_learning_curves(history, sub):
    plt.plot(history.history['accuracy'])
    plt.plot(history.history['val_accuracy'])
    plt.title('Model accuracy - subject: ' + str(sub))
    plt.ylabel('Accuracy')
    plt.xlabel('Epoch')
    plt.legend(['Train', 'val'], loc='upper left')
    plt.show()
    plt.plot(history.history['loss'])
    plt.plot(history.history['val_loss'])
    plt.title('Model loss - subject: ' + str(sub))
    plt.ylabel('Loss')
    plt.xlabel('Epoch')
    plt.legend(['Train', 'val'], loc='upper left')
    plt.show()
    plt.close()

def draw_confusion_matrix(cf_matrix, sub, results_path, classes_labels):
    # Generate confusion matrix plot
    display_labels = classes_labels
    disp = ConfusionMatrixDisplay(confusion_matrix=cf_matrix,
                                  display_labels=display_labels)
    disp.plot()
    disp.ax_.set_xticklabels(display_labels, rotation=12)
    plt.title('Confusion Matrix of Subject: ' + sub)
    plt.savefig(results_path + '/subject_' + sub + '.png')
    plt.show()

def draw_performance_barChart(num_sub, metric, label):
    fig, ax = plt.subplots()
    x = list(range(1, num_sub+1))
    ax.bar(x, metric, 0.5, label=label)
    ax.set_ylabel(label)
    ax.set_xlabel("Subject")
    ax.set_xticks(x)
    ax.set_title('Model ' + label + ' per subject')
    ax.set_ylim([0, 1])


#%% Training
def train(dataset_conf, train_conf, results_path):

    # Remove the previous 'results' folder (if any), then create a fresh one
    if os.path.exists(results_path):
        # Remove the folder and its contents
        shutil.rmtree(results_path)
    os.makedirs(results_path)

    # Get the current 'IN' time to calculate the overall training time
    in_exp = time.time()
    # Create a file to store the path of the best model among several runs
    best_models = open(results_path + "/best models.txt", "w")
    # Create a file to store performance during training
    log_write = open(results_path + "/log.txt", "w")

    # Get dataset parameters
    dataset = dataset_conf.get('name')
    n_sub = dataset_conf.get('n_sub')
    data_path = dataset_conf.get('data_path')
    isStandard = dataset_conf.get('isStandard')
    LOSO = dataset_conf.get('LOSO')
    # Get training hyperparameters
    batch_size = train_conf.get('batch_size')
    epochs = train_conf.get('epochs')
    patience = train_conf.get('patience')
    lr = train_conf.get('lr')
    LearnCurves = train_conf.get('LearnCurves') # Plot Learning Curves?
    n_train = train_conf.get('n_train')
    model_name = train_conf.get('model')
    from_logits = train_conf.get('from_logits')

    # Initialize variables
    acc = np.zeros((n_sub, n_train))
    kappa = np.zeros((n_sub, n_train))

    # Iteration over subjects
    # for sub in range(n_sub-1, n_sub): # (num_sub): for all subjects, (i-1,i): for the ith subject.
    for sub in range(n_sub): # (num_sub): for all subjects, (i-1,i): for the ith subject.

        print('\nTraining on subject ', sub+1)
        log_write.write('\nTraining on subject '+ str(sub+1) +'\n')
        # Initialize variables to track the best accuracy for this subject across runs.
        BestSubjAcc = 0
        bestTrainingHistory = []

        # Get training and validation data
        X_train, _, y_train_onehot, _, _, _ = get_data(
            data_path, sub, dataset, LOSO = LOSO, isStandard = isStandard)

        # Divide the training data into training and validation
        X_train, X_val, y_train_onehot, y_val_onehot = train_test_split(X_train, y_train_onehot, test_size=0.2, random_state=42)
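        # (80/20 split; the fixed random_state keeps the validation subset
        # identical across the training runs below and across executions)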

        # Iteration over multiple runs
        for train in range(n_train): # How many repetitions of training for subject i.
            # Set the random seeds for the TensorFlow and NumPy random number generators
            # to ensure reproducibility of the random operations in each run.
            tf.random.set_seed(train+1)
            np.random.seed(train+1)

            # Get the current 'IN' time to calculate the 'run' training time
            in_run = time.time()

            # Create folders and files to save trained models for all runs
            filepath = results_path + '/saved models/run-{}'.format(train+1)
            if not os.path.exists(filepath):
                os.makedirs(filepath)
            filepath = filepath + '/subject-{}.h5'.format(sub+1)

            # Create the model
            model = getModel(model_name, dataset_conf, from_logits)
            # Compile and train the model
            model.compile(loss=CategoricalCrossentropy(from_logits=from_logits), optimizer=Adam(learning_rate=lr), metrics=['accuracy'])

            # model.summary()
            # plot_model(model, to_file='plot_model.png', show_shapes=True, show_layer_names=True)

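            # ModelCheckpoint keeps only the weights of the epoch with the lowest
            # validation loss; ReduceLROnPlateau multiplies the learning rate by 0.9
            # whenever val_loss has not improved for 20 epochs (floor: 1e-4).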
            callbacks = [
                ModelCheckpoint(filepath, monitor='val_loss', verbose=0,
                                save_best_only=True, save_weights_only=True, mode='min'),
                ReduceLROnPlateau(monitor="val_loss", factor=0.90, patience=20, verbose=0, min_lr=0.0001),
                # EarlyStopping(monitor='val_loss', verbose=1, mode='min', patience=patience)
            ]
            history = model.fit(X_train, y_train_onehot, validation_data=(X_val, y_val_onehot),
                                epochs=epochs, batch_size=batch_size, callbacks=callbacks, verbose=0)

            # Evaluate the performance of the trained model on the validation data.
            # Here we load the trained weights from the file saved on disk, which
            # should be the same as the weights of the current model.
            model.load_weights(filepath)
            y_pred = model.predict(X_val)

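            # Softmax is monotonic, so the argmax over logits equals the argmax over
            # softmax outputs; the explicit softmax below only mirrors the
            # from_logits setting used in the loss.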
            if from_logits:
                y_pred = tf.nn.softmax(y_pred).numpy().argmax(axis=-1)
            else:
                y_pred = y_pred.argmax(axis=-1)

            labels = y_val_onehot.argmax(axis=-1)
            acc[sub, train]  = accuracy_score(labels, y_pred)
            kappa[sub, train] = cohen_kappa_score(labels, y_pred)

            # Get the current 'OUT' time to calculate the 'run' training time
            out_run = time.time()
            # Print & write performance measures for each run
            info = 'Subject: {}   seed {}   time: {:.1f} m   '.format(sub+1, train+1, ((out_run-in_run)/60))
            info = info + 'valid_acc: {:.4f}   valid_loss: {:.3f}'.format(acc[sub, train], min(history.history['val_loss']))
            print(info)
            log_write.write(info +'\n')
            # If current training run is better than previous runs, save the history.
            if(BestSubjAcc < acc[sub, train]):
                 BestSubjAcc = acc[sub, train]
                 bestTrainingHistory = history

        # Store the path of the best model among several runs
        best_run = np.argmax(acc[sub,:])
        filepath = '/saved models/run-{}/subject-{}.h5'.format(best_run+1, sub+1)+'\n'
        best_models.write(filepath)

        # Plot Learning curves
        if (LearnCurves == True):
            print('Plot Learning Curves ....... ')
            draw_learning_curves(bestTrainingHistory, sub+1)

    # Get the current 'OUT' time to calculate the overall training time
    out_exp = time.time()

    # Print & write the validation performance using all seeds
    head1 = head2 = '         '
    for sub in range(n_sub):
        head1 = head1 + 'sub_{}   '.format(sub+1)
        head2 = head2 + '-----   '
    head1 = head1 + '  average'
    head2 = head2 + '  -------'
    info = '\n---------------------------------\nValidation performance (acc %):'
    info = info + '\n---------------------------------\n' + head1 +'\n'+ head2
    for run in range(n_train):
        info = info + '\nSeed {}:  '.format(run+1)
        for sub in range(n_sub):
            info = info + '{:.2f}   '.format(acc[sub, run]*100)
        info = info + '  {:.2f}   '.format(np.average(acc[:, run])*100)
    info = info + '\n---------------------------------\nAverage acc - all seeds: '
    info = info + '{:.2f} %\n\nTrain Time  - all seeds: {:.1f}'.format(np.average(acc)*100, (out_exp-in_exp)/(60))
    info = info + ' min\n---------------------------------\n'
    print(info)
    log_write.write(info+'\n')

    # Close open files
    best_models.close()
    log_write.close()


#%% Evaluation
def test(model, dataset_conf, results_path, allRuns = True):
    # Open the "Log" file to write the evaluation results
    log_write = open(results_path + "/log.txt", "a")

    # Get dataset parameters
    dataset = dataset_conf.get('name')
    n_classes = dataset_conf.get('n_classes')
    n_sub = dataset_conf.get('n_sub')
    data_path = dataset_conf.get('data_path')
    isStandard = dataset_conf.get('isStandard')
    LOSO = dataset_conf.get('LOSO')
    classes_labels = dataset_conf.get('cl_labels')

    # Test the performance based on several runs (seeds)
    runs = os.listdir(results_path + "/saved models")
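    # Each entry in 'runs' is a 'run-<seed>' folder created by train(), holding
    # one 'subject-<i>.h5' weights file per subject.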
    # Initialize variables
    acc = np.zeros((n_sub, len(runs)))
    kappa = np.zeros((n_sub, len(runs)))
    cf_matrix = np.zeros([n_sub, len(runs), n_classes, n_classes])

    # Iteration over subjects
    # for sub in range(n_sub-1, n_sub): # (num_sub): for all subjects, (i-1,i): for the ith subject.
    inference_time = 0 # inference_time: classification time for one trial
    for sub in range(n_sub): # (num_sub): for all subjects, (i-1,i): for the ith subject.
        # Load data
        _, _, _, X_test, _, y_test_onehot = get_data(data_path, sub, dataset, LOSO = LOSO, isStandard = isStandard)

        # Iteration over runs (seeds)
        for seed in range(len(runs)):
            # Load the model of the seed.
            model.load_weights('{}/saved models/{}/subject-{}.h5'.format(results_path, runs[seed], sub+1))

            inference_time = time.time()
            # Predict MI task
            y_pred = model.predict(X_test).argmax(axis=-1)
            inference_time = (time.time() - inference_time)/X_test.shape[0]
            # Calculate accuracy and K-score
            labels = y_test_onehot.argmax(axis=-1)
            acc[sub, seed]  = accuracy_score(labels, y_pred)
            kappa[sub, seed] = cohen_kappa_score(labels, y_pred)
            # Calculate and draw confusion matrix
            cf_matrix[sub, seed, :, :] = confusion_matrix(labels, y_pred, normalize='true')
            # draw_confusion_matrix(cf_matrix[sub, seed, :, :], str(sub+1), results_path, classes_labels)

    # Print & write the average performance measures for all subjects
    head1 = head2 = '                  '
    for sub in range(n_sub):
        head1 = head1 + 'sub_{}   '.format(sub+1)
        head2 = head2 + '-----   '
    head1 = head1 + '  average'
    head2 = head2 + '  -------'
    info = '\n---------------------------------\nTest performance (acc & k-score):\n'
    info = info + '---------------------------------\n' + head1 +'\n'+ head2
    for run in range(len(runs)):
        info = info + '\nSeed {}: '.format(run+1)
        info_acc = '(acc %)   '
        info_k = '        (k-sco)   '
        for sub in range(n_sub):
            info_acc = info_acc + '{:.2f}   '.format(acc[sub, run]*100)
            info_k = info_k + '{:.3f}   '.format(kappa[sub, run])
        info_acc = info_acc + '  {:.2f}   '.format(np.average(acc[:, run])*100)
        info_k = info_k + '  {:.3f}   '.format(np.average(kappa[:, run]))
        info = info + info_acc + '\n' + info_k
    info = info + '\n----------------------------------\nAverage - all seeds (acc %): '
    info = info + '{:.2f}\n                    (k-sco): '.format(np.average(acc)*100)
    info = info + '{:.3f}\n\nInference time: {:.2f}'.format(np.average(kappa), inference_time * 1000)
    info = info + ' ms per trial\n----------------------------------\n'
    print(info)
    log_write.write(info+'\n')

    # Draw a performance bar chart for all subjects
    draw_performance_barChart(n_sub, acc.mean(1), 'Accuracy')
    draw_performance_barChart(n_sub, kappa.mean(1), 'k-score')
    # Draw confusion matrix for all subjects (average)
    draw_confusion_matrix(cf_matrix.mean((0,1)), 'All', results_path, classes_labels)
    # Close opened file
    log_write.close()


#%%
def getModel(model_name, dataset_conf, from_logits = False):

    n_classes = dataset_conf.get('n_classes')
    n_channels = dataset_conf.get('n_channels')
    in_samples = dataset_conf.get('in_samples')

    # Select the model
    if(model_name == 'ATCNet'):
        # Train using the proposed ATCNet model: https://ieeexplore.ieee.org/document/9852687
        model = models.ATCNet_(
            # Dataset parameters
            n_classes = n_classes,
            in_chans = n_channels,
            in_samples = in_samples,
            # Sliding window (SW) parameter
            n_windows = 5,
            # Attention (AT) block parameter
            attention = 'mha', # Options: None, 'mha','mhla', 'cbam', 'se'
            # Convolutional (CV) block parameters
            eegn_F1 = 16,
            eegn_D = 2,
            eegn_kernelSize = 64,
            eegn_poolSize = 7,
            eegn_dropout = 0.3,
            # Temporal convolutional (TC) block parameters
            tcn_depth = 2,
            tcn_kernelSize = 4,
            tcn_filters = 32,
            tcn_dropout = 0.3,
            tcn_activation='elu',
            )
    elif(model_name == 'TCNet_Fusion'):
        # Train using TCNet_Fusion: https://doi.org/10.1016/j.bspc.2021.102826
        model = models.TCNet_Fusion(n_classes = n_classes, Chans=n_channels, Samples=in_samples)
    elif(model_name == 'EEGTCNet'):
        # Train using EEGTCNet: https://arxiv.org/abs/2006.00622
        model = models.EEGTCNet(n_classes = n_classes, Chans=n_channels, Samples=in_samples)
    elif(model_name == 'EEGNet'):
        # Train using EEGNet: https://arxiv.org/abs/1611.08024
        model = models.EEGNet_classifier(n_classes = n_classes, Chans=n_channels, Samples=in_samples)
    elif(model_name == 'EEGNeX'):
        # Train using EEGNeX: https://arxiv.org/abs/2207.12369
        model = models.EEGNeX_8_32(n_timesteps = in_samples, n_features = n_channels, n_outputs = n_classes)
    elif(model_name == 'DeepConvNet'):
        # Train using DeepConvNet: https://doi.org/10.1002/hbm.23730
        model = models.DeepConvNet(nb_classes = n_classes, Chans = n_channels, Samples = in_samples)
    elif(model_name == 'ShallowConvNet'):
        # Train using ShallowConvNet: https://doi.org/10.1002/hbm.23730
        model = models.ShallowConvNet(nb_classes = n_classes, Chans = n_channels, Samples = in_samples)
    elif(model_name == 'MBEEG_SENet'):
        # Train using MBEEG_SENet: https://www.mdpi.com/2075-4418/12/4/995
        model = models.MBEEG_SENet(nb_classes = n_classes, Chans = n_channels, Samples = in_samples)

    else:
        raise Exception("'{}' model is not supported yet!".format(model_name))

    return model

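# A minimal usage sketch (hypothetical BCI2a-like values, not part of the pipeline):
#   conf = {'n_classes': 4, 'n_channels': 22, 'in_samples': 1125}
#   model = getModel('ATCNet', conf)
#   model.summary()
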
#%%
def run():
    # Define dataset parameters
    dataset = 'HGD' # Options: 'BCI2a','HGD', 'CS2R'

    if dataset == 'BCI2a':
        in_samples = 1125
        n_channels = 22
        n_sub = 9
        n_classes = 4
        classes_labels = ['Left hand', 'Right hand','Foot','Tongue']
        data_path = os.path.expanduser('~') + '/BCI Competition IV/BCI Competition IV-2a/BCI Competition IV 2a mat/'
    elif dataset == 'HGD':
        in_samples = 1125
        n_channels = 44
        n_sub = 14
        n_classes = 4
        classes_labels = ['Right Hand', 'Left Hand','Rest','Feet']
        data_path = os.path.expanduser('~') + '/mne_data/MNE-schirrmeister2017-data/robintibor/high-gamma-dataset/raw/master/data/'
    elif dataset == 'CS2R':
        in_samples = 1125
        # in_samples = 576
        n_channels = 32
        n_sub = 18
        n_classes = 3
        # classes_labels = ['Fingers', 'Wrist','Elbow','Rest']
        classes_labels = ['Fingers', 'Wrist','Elbow']
        # classes_labels = ['Fingers', 'Elbow']
        data_path = os.path.expanduser('~') + '/CS2R MI EEG dataset/all/EDF - Cleaned - phase one (remove extra runs)/two sessions/'
    else:
        raise Exception("'{}' dataset is not supported yet!".format(dataset))

    # Create a folder to store the results of the experiment
    results_path = os.getcwd() + "/results"
    if not os.path.exists(results_path):
        os.makedirs(results_path)   # Create a new directory if it does not exist

    # Set dataset parameters
    dataset_conf = { 'name': dataset, 'n_classes': n_classes, 'cl_labels': classes_labels,
                    'n_sub': n_sub, 'n_channels': n_channels, 'in_samples': in_samples,
                    'data_path': data_path, 'isStandard': True, 'LOSO': False}
    # Set training hyperparameters
    train_conf = { 'batch_size': 64, 'epochs': 500, 'patience': 100, 'lr': 0.001,'n_train': 1,
                  'LearnCurves': True, 'from_logits': False, 'model':'ATCNet'}
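    # Note: 'patience' is only consumed by the EarlyStopping callback, which is
    # currently commented out in train(), so it has no effect with these settings.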

    # Train the model
    # train(dataset_conf, train_conf, results_path)

    # Evaluate the model based on the weights saved in the '/results' folder
    model = getModel(train_conf.get('model'), dataset_conf)
    test(model, dataset_conf, results_path)

#%%
if __name__ == "__main__":
    run()