# main_TrainTest.py
"""
Copyright (C) 2022 King Saud University, Saudi Arabia
SPDX-License-Identifier: Apache-2.0

Licensed under the Apache License, Version 2.0 (the "License"); you may not use
this file except in compliance with the License. You may obtain a copy of the
License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed
under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

Author: Hamdi Altaheri
"""

#%%
import os
import time
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf

from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import categorical_crossentropy
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau
from sklearn.metrics import confusion_matrix, accuracy_score, ConfusionMatrixDisplay
from sklearn.metrics import cohen_kappa_score

import models
from preprocess import get_data
# from keras.utils.vis_utils import plot_model


#%%
def draw_learning_curves(history):
    # Accuracy curves
    plt.plot(history.history['accuracy'])
    plt.plot(history.history['val_accuracy'])
    plt.title('Model accuracy')
    plt.ylabel('Accuracy')
    plt.xlabel('Epoch')
    plt.legend(['Train', 'Validation'], loc='upper left')
    plt.show()
    # Loss curves
    plt.plot(history.history['loss'])
    plt.plot(history.history['val_loss'])
    plt.title('Model loss')
    plt.ylabel('Loss')
    plt.xlabel('Epoch')
    plt.legend(['Train', 'Validation'], loc='upper left')
    plt.show()
    plt.close()

def draw_confusion_matrix(cf_matrix, sub, results_path, classes_labels):
    # Generate confusion matrix plot
    display_labels = classes_labels
    disp = ConfusionMatrixDisplay(confusion_matrix=cf_matrix,
                                  display_labels=display_labels)
    disp.plot()
    disp.ax_.set_xticklabels(display_labels, rotation=12)
    plt.title('Confusion Matrix of Subject: ' + sub)
    plt.savefig(results_path + '/subject_' + sub + '.png')
    plt.show()

def draw_performance_barChart(num_sub, metric, label):
    fig, ax = plt.subplots()
    x = list(range(1, num_sub+1))
    ax.bar(x, metric, 0.5, label=label)
    ax.set_ylabel(label)
    ax.set_xlabel("Subject")
    ax.set_xticks(x)
    ax.set_title('Model ' + label + ' per subject')
    ax.set_ylim([0, 1])

#%% Training
def train(dataset_conf, train_conf, results_path):
    # Get the current 'IN' time to calculate the overall training time
    in_exp = time.time()
    # Create a file to store the path of the best model among several runs
    best_models = open(results_path + "/best models.txt", "w")
    # Create a file to store performance during training
    log_write = open(results_path + "/log.txt", "w")
    # Create a .npz file (zipped archive) to store the accuracy and kappa metrics
    # for all runs (to calculate average accuracy/kappa over all runs)
    perf_allRuns = open(results_path + "/perf_allRuns.npz", 'wb')

    # Get dataset parameters
    dataset = dataset_conf.get('name')
    n_sub = dataset_conf.get('n_sub')
    data_path = dataset_conf.get('data_path')
    isStandard = dataset_conf.get('isStandard')
    LOSO = dataset_conf.get('LOSO')
    # Get training hyperparameters
    batch_size = train_conf.get('batch_size')
    epochs = train_conf.get('epochs')
    patience = train_conf.get('patience')
    lr = train_conf.get('lr')
    LearnCurves = train_conf.get('LearnCurves')  # Plot learning curves?
    n_train = train_conf.get('n_train')
    model_name = train_conf.get('model')

    # Initialize variables
    acc = np.zeros((n_sub, n_train))
    kappa = np.zeros((n_sub, n_train))

    # Iteration over subjects
    # for sub in range(n_sub-1, n_sub):  # use (i-1, i) to train the ith subject only
    for sub in range(n_sub):  # loop over all subjects
        # Get the current 'IN' time to calculate the subject training time
        in_sub = time.time()
        print('\nTraining on subject ', sub+1)
        log_write.write('\nTraining on subject ' + str(sub+1) + '\n')
        # Initialize variables that track the best accuracy for this subject
        # among multiple runs
        BestSubjAcc = 0
        bestTrainingHistory = None
        # Get training and test data
        X_train, _, y_train_onehot, X_test, _, y_test_onehot = get_data(
            data_path, sub, dataset, LOSO=LOSO, isStandard=isStandard)

        # Iteration over multiple runs
        for run_no in range(n_train):  # n_train training repetitions for subject i
            # Seed TF and NumPy so each run is repeatable
            tf.random.set_seed(run_no+1)
            np.random.seed(run_no+1)
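            # Note (editor's sketch): these seeds make runs repeatable on CPU;
            # some GPU kernels may still be nondeterministic unless TF's
            # op-determinism mode is also enabled (available in newer TF releases).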

            # Get the current 'IN' time to calculate the 'run' training time
            in_run = time.time()
            # Create folders and files to save trained models for all runs
            filepath = results_path + '/saved models/run-{}'.format(run_no+1)
            if not os.path.exists(filepath):
                os.makedirs(filepath)
            filepath = filepath + '/subject-{}.h5'.format(sub+1)

            # Create the model
            model = getModel(model_name, dataset_conf)
            # Compile and train the model
            model.compile(loss=categorical_crossentropy, optimizer=Adam(learning_rate=lr), metrics=['accuracy'])
            # model.summary()
            # plot_model(model, to_file='plot_model.png', show_shapes=True, show_layer_names=True)

            callbacks = [
                ModelCheckpoint(filepath, monitor='val_accuracy', verbose=0,
                                save_best_only=True, save_weights_only=True, mode='max'),
                ReduceLROnPlateau(monitor="val_loss", factor=0.90, patience=20, verbose=1, min_lr=0.0001),
                EarlyStopping(monitor='val_accuracy', verbose=1, mode='max', patience=patience)
            ]
            history = model.fit(X_train, y_train_onehot, validation_data=(X_test, y_test_onehot),
                                epochs=epochs, batch_size=batch_size, callbacks=callbacks, verbose=0)

            # Evaluate the trained model using the checkpoint saved on disk.
            # Because ModelCheckpoint keeps the best val_accuracy weights, these
            # may differ from the model's weights at the last trained epoch.
            model.load_weights(filepath)
            y_pred = model.predict(X_test).argmax(axis=-1)
            labels = y_test_onehot.argmax(axis=-1)
            acc[sub, run_no] = accuracy_score(labels, y_pred)
            kappa[sub, run_no] = cohen_kappa_score(labels, y_pred)

            # Get the current 'OUT' time to calculate the 'run' training time
            out_run = time.time()
            # Print & write performance measures for each run
            info = 'Subject: {} Train no. {} Time: {:.1f} m '.format(sub+1, run_no+1, ((out_run-in_run)/60))
            info = info + 'Test_acc: {:.4f} Test_kappa: {:.4f}'.format(acc[sub, run_no], kappa[sub, run_no])
            print(info)
            log_write.write(info + '\n')
            # If the current run is better than the previous runs, save its history
            if BestSubjAcc < acc[sub, run_no]:
                BestSubjAcc = acc[sub, run_no]
                bestTrainingHistory = history

        # Store the path of the best model among several runs
        best_run = np.argmax(acc[sub, :])
        filepath = '/saved models/run-{}/subject-{}.h5'.format(best_run+1, sub+1) + '\n'
        best_models.write(filepath)
        # Get the current 'OUT' time to calculate the subject training time
        out_sub = time.time()
        # Print & write the best subject performance among multiple runs
        info = '----------\n'
        info = info + 'Subject: {} best_run: {} Time: {:.1f} m '.format(sub+1, best_run+1, ((out_sub-in_sub)/60))
        info = info + 'acc: {:.4f} avg_acc: {:.4f} +- {:.4f} '.format(acc[sub, best_run], np.average(acc[sub, :]), acc[sub, :].std())
        info = info + 'kappa: {:.4f} avg_kappa: {:.4f} +- {:.4f}'.format(kappa[sub, best_run], np.average(kappa[sub, :]), kappa[sub, :].std())
        info = info + '\n----------'
        print(info)
        log_write.write(info + '\n')
        # Plot learning curves
        if LearnCurves:
            print('Plot Learning Curves ....... ')
            draw_learning_curves(bestTrainingHistory)

    # Get the current 'OUT' time to calculate the overall training time
    out_exp = time.time()
    info = '\nTime: {:.1f} h '.format((out_exp-in_exp)/(60*60))
    print(info)
    log_write.write(info + '\n')

    # Store the accuracy and kappa metrics as arrays for all runs in a .npz
    # file (an uncompressed zipped archive) to calculate the average
    # accuracy/kappa over all runs
    np.savez(perf_allRuns, acc=acc, kappa=kappa)
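
    # The saved arrays can be reloaded later, e.g. (editor's sketch):
    #   with np.load(results_path + "/perf_allRuns.npz") as f:
    #       acc_all, kappa_all = f['acc'], f['kappa']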

    # Close open files
    best_models.close()
    log_write.close()
    perf_allRuns.close()


#%% Evaluation
def test(model, dataset_conf, results_path, allRuns=True):
    # Open the log file to append the evaluation results
    log_write = open(results_path + "/log.txt", "a")
    # Open the file that stores the paths of the best models among several random runs
    best_models = open(results_path + "/best models.txt", "r")

    # Get dataset parameters
    dataset = dataset_conf.get('name')
    n_classes = dataset_conf.get('n_classes')
    n_sub = dataset_conf.get('n_sub')
    data_path = dataset_conf.get('data_path')
    isStandard = dataset_conf.get('isStandard')
    LOSO = dataset_conf.get('LOSO')
    classes_labels = dataset_conf.get('cl_labels')

    # Initialize variables
    acc_bestRun = np.zeros(n_sub)
    kappa_bestRun = np.zeros(n_sub)
    cf_matrix = np.zeros([n_sub, n_classes, n_classes])

    # Calculate the average performance (accuracy and kappa score) over
    # all runs (experiments) for each subject
    if allRuns:
        # Load the test accuracy and kappa arrays for all runs from the .npz
        # file (an uncompressed zipped archive) to calculate the average
        # accuracy/kappa over all runs
        perf_allRuns = open(results_path + "/perf_allRuns.npz", 'rb')
        perf_arrays = np.load(perf_allRuns)
        acc_allRuns = perf_arrays['acc']
        kappa_allRuns = perf_arrays['kappa']

    # Iteration over subjects
    # for sub in range(n_sub-1, n_sub):  # use (i-1, i) to evaluate the ith subject only
    for sub in range(n_sub):  # loop over all subjects
        # Load data
        _, _, _, X_test, _, y_test_onehot = get_data(data_path, sub, dataset, LOSO, isStandard)

        # Load the best model out of multiple random runs (experiments)
        filepath = best_models.readline()
        model.load_weights(results_path + filepath[:-1])  # [:-1] drops the trailing '\n'
        # Predict MI task
        y_pred = model.predict(X_test).argmax(axis=-1)
        # Calculate accuracy and kappa score
        labels = y_test_onehot.argmax(axis=-1)
        acc_bestRun[sub] = accuracy_score(labels, y_pred)
        kappa_bestRun[sub] = cohen_kappa_score(labels, y_pred)
        # Calculate and draw the confusion matrix (rows normalized over true labels)
        cf_matrix[sub, :, :] = confusion_matrix(labels, y_pred, normalize='true')
        draw_confusion_matrix(cf_matrix[sub, :, :], str(sub+1), results_path, classes_labels)

        # Print & write performance measures for each subject
        info = 'Subject: {} best_run: {:2} '.format(sub+1, (filepath[filepath.find('run-')+4:filepath.find('/sub')]))
        info = info + 'acc: {:.4f} kappa: {:.4f} '.format(acc_bestRun[sub], kappa_bestRun[sub])
        if allRuns:
            info = info + 'avg_acc: {:.4f} +- {:.4f} avg_kappa: {:.4f} +- {:.4f}'.format(
                np.average(acc_allRuns[sub, :]), acc_allRuns[sub, :].std(),
                np.average(kappa_allRuns[sub, :]), kappa_allRuns[sub, :].std())
        print(info)
        log_write.write('\n' + info)

    # Print & write the average performance measures over all subjects
    info = '\nAverage of {} subjects - best runs:\nAccuracy = {:.4f} Kappa = {:.4f}\n'.format(
        n_sub, np.average(acc_bestRun), np.average(kappa_bestRun))
    if allRuns:
        info = info + '\nAverage of {} subjects x {} runs (average of {} experiments):\nAccuracy = {:.4f} Kappa = {:.4f}'.format(
            n_sub, acc_allRuns.shape[1], (n_sub * acc_allRuns.shape[1]),
            np.average(acc_allRuns), np.average(kappa_allRuns))
    print(info)
    log_write.write(info)

    # Draw a performance bar chart for all subjects
    draw_performance_barChart(n_sub, acc_bestRun, 'Accuracy')
    draw_performance_barChart(n_sub, kappa_bestRun, 'K-score')
    # Draw the average confusion matrix over all subjects
    draw_confusion_matrix(cf_matrix.mean(0), 'All', results_path, classes_labels)
    # Close open files (perf_allRuns is only open when allRuns=True)
    best_models.close()
    log_write.close()
    if allRuns:
        perf_allRuns.close()


#%%
def getModel(model_name, dataset_conf):

    n_classes = dataset_conf.get('n_classes')
    n_channels = dataset_conf.get('n_channels')
    in_samples = dataset_conf.get('in_samples')

    # Select the model
    if model_name == 'ATCNet':
        # Train using the proposed ATCNet model: https://doi.org/10.1109/TII.2022.3197419
        model = models.ATCNet_(
            # Dataset parameters
            n_classes=n_classes,
            in_chans=n_channels,
            in_samples=in_samples,
            # Sliding window (SW) parameter
            n_windows=5,
            # Attention (AT) block parameter
            attention='mha',  # Options: None, 'mha', 'mhla', 'cbam', 'se'
            # Convolutional (CV) block parameters
            eegn_F1=16,
            eegn_D=2,
            eegn_kernelSize=64,
            eegn_poolSize=7,
            eegn_dropout=0.3,
            # Temporal convolutional (TC) block parameters
            tcn_depth=2,
            tcn_kernelSize=4,
            tcn_filters=32,
            tcn_dropout=0.3,
            tcn_activation='elu'
        )
    elif model_name == 'TCNet_Fusion':
        # Train using TCNet_Fusion: https://doi.org/10.1016/j.bspc.2021.102826
        model = models.TCNet_Fusion(n_classes=n_classes, Chans=n_channels, Samples=in_samples)
    elif model_name == 'EEGTCNet':
        # Train using EEGTCNet: https://arxiv.org/abs/2006.00622
        model = models.EEGTCNet(n_classes=n_classes, Chans=n_channels, Samples=in_samples)
    elif model_name == 'EEGNet':
        # Train using EEGNet: https://arxiv.org/abs/1611.08024
        model = models.EEGNet_classifier(n_classes=n_classes, Chans=n_channels, Samples=in_samples)
    elif model_name == 'EEGNeX':
        # Train using EEGNeX: https://arxiv.org/abs/2207.12369
        model = models.EEGNeX_8_32(n_timesteps=in_samples, n_features=n_channels, n_outputs=n_classes)
    elif model_name == 'DeepConvNet':
        # Train using DeepConvNet: https://doi.org/10.1002/hbm.23730
        model = models.DeepConvNet(nb_classes=n_classes, Chans=n_channels, Samples=in_samples)
    elif model_name == 'ShallowConvNet':
        # Train using ShallowConvNet: https://doi.org/10.1002/hbm.23730
        model = models.ShallowConvNet(nb_classes=n_classes, Chans=n_channels, Samples=in_samples)
    elif model_name == 'MBEEG_SENet':
        # Train using MBEEG_SENet: https://www.mdpi.com/2075-4418/12/4/995
        model = models.MBEEG_SENet(nb_classes=n_classes, Chans=n_channels, Samples=in_samples)
    else:
        raise Exception("'{}' model is not supported yet!".format(model_name))

    return model
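
# Example (editor's sketch): building a model directly, assuming a hypothetical
# BCI2a-like configuration; the dict keys match those read by getModel() above.
#   conf = {'n_classes': 4, 'n_channels': 22, 'in_samples': 1125}
#   model = getModel('ATCNet', conf)
#   model.summary()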


#%%
def run():
    # Define dataset parameters
    dataset = 'BCI2a'  # Options: 'BCI2a', 'HGD', 'CS2R'

    if dataset == 'BCI2a':
        in_samples = 1125
        n_channels = 22
        n_sub = 9
        n_classes = 4
        classes_labels = ['Left hand', 'Right hand', 'Foot', 'Tongue']
        data_path = os.path.expanduser('~') + '/BCI Competition IV/BCI Competition IV-2a/BCI Competition IV 2a mat/'
    elif dataset == 'HGD':
        in_samples = 1125
        n_channels = 44
        n_sub = 14
        n_classes = 4
        classes_labels = ['Right Hand', 'Left Hand', 'Rest', 'Feet']
        data_path = os.path.expanduser('~') + '/mne_data/MNE-schirrmeister2017-data/robintibor/high-gamma-dataset/raw/master/data/'
    elif dataset == 'CS2R':
        in_samples = 1125
        # in_samples = 576
        n_channels = 32
        n_sub = 18
        n_classes = 3
        # classes_labels = ['Fingers', 'Wrist', 'Elbow', 'Rest']
        classes_labels = ['Fingers', 'Wrist', 'Elbow']
        # classes_labels = ['Fingers', 'Elbow']
        data_path = os.path.expanduser('~') + '/CS2R MI EEG dataset/all/EDF - Cleaned - phase one (remove extra runs)/two sessions/'
    else:
        raise Exception("'{}' dataset is not supported yet!".format(dataset))

    # Create a folder to store the results of the experiment
    results_path = os.getcwd() + "/results"
    if not os.path.exists(results_path):
        os.makedirs(results_path)  # Create a new directory if it does not exist

    # Set dataset parameters
    dataset_conf = {'name': dataset, 'n_classes': n_classes, 'cl_labels': classes_labels,
                    'n_sub': n_sub, 'n_channels': n_channels, 'in_samples': in_samples,
                    'data_path': data_path, 'isStandard': True, 'LOSO': False}
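    # Note: 'LOSO' selects leave-one-subject-out (subject-independent) data
    # splitting in get_data(); 'isStandard' enables data standardization there.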
|
|
391 |
# Set training hyperparamters |
|
|
392 |
train_conf = { 'batch_size': 64, 'epochs': 1000, 'patience': 300, 'lr': 0.001, |
|
|
393 |
'LearnCurves': True, 'n_train': 10, 'model':'ATCNet'} |
|
|
394 |
|
|
|
395 |
# Train the model |
|
|
396 |
# train(dataset_conf, train_conf, results_path) |
|
|
397 |
|
|
|
398 |
# Evaluate the model based on the weights saved in the '/results' folder |
|
|
399 |
model = getModel(train_conf.get('model'), dataset_conf) |
|
|
400 |
test(model, dataset_conf, results_path) |
|
|
401 |
|
|
|
402 |
#%% |
|
|
403 |
if __name__ == "__main__": |
|
|
404 |
run() |
|
|
405 |
|