|
a |
|
b/examples/ERP.py |
|
|
1 |
""" |
|
|
2 |
Sample script using EEGNet to classify Event-Related Potential (ERP) EEG data |
|
|
3 |
from a four-class classification task, using the sample dataset provided in |
|
|
4 |
the MNE [1, 2] package: |
|
|
5 |
https://martinos.org/mne/stable/manual/sample_dataset.html#ch-sample-data |
|
|
6 |
|
|
|
7 |
The four classes used from this dataset are: |
|
|
8 |
LA: Left-ear auditory stimulation |
|
|
9 |
RA: Right-ear auditory stimulation |
|
|
10 |
LV: Left visual field stimulation |
|
|
11 |
RV: Right visual field stimulation |
|
|
12 |
|
|
|
13 |
The code to process, filter and epoch the data is originally from Alexandre |
|
|
14 |
Barachant's PyRiemann [3] package, released under the BSD 3-clause. A copy of |
|
|
15 |
the BSD 3-clause license has been provided together with this software to |
|
|
16 |
comply with software licensing requirements. |
|
|
17 |
|
|
|
18 |
When you first run this script, MNE will download the dataset and prompt you |
|
|
19 |
to confirm the download location (defaults to ~/mne_data). Follow the prompts |
|
|
20 |
to continue. The dataset size is approx. 1.5GB download. |
|
|
21 |
|
|
|
22 |
For comparative purposes you can also compare EEGNet performance to using |
|
|
23 |
Riemannian geometric approaches with xDAWN spatial filtering [4-8] using |
|
|
24 |
PyRiemann (code provided below). |
|
|
25 |
|
|
|
26 |
[1] A. Gramfort, M. Luessi, E. Larson, D. Engemann, D. Strohmeier, C. Brodbeck, |
|
|
27 |
L. Parkkonen, M. Hämäläinen, MNE software for processing MEG and EEG data, |
|
|
28 |
NeuroImage, Volume 86, 1 February 2014, Pages 446-460, ISSN 1053-8119. |
|
|
29 |
|
|
|
30 |
[2] A. Gramfort, M. Luessi, E. Larson, D. Engemann, D. Strohmeier, C. Brodbeck, |
|
|
31 |
R. Goj, M. Jas, T. Brooks, L. Parkkonen, M. Hämäläinen, MEG and EEG data |
|
|
32 |
analysis with MNE-Python, Frontiers in Neuroscience, Volume 7, 2013. |
|
|
33 |
|
|
|
34 |
[3] https://github.com/alexandrebarachant/pyRiemann. |
|
|
35 |
|
|
|
36 |
[4] A. Barachant, M. Congedo ,"A Plug&Play P300 BCI Using Information Geometry" |
|
|
37 |
arXiv:1409.0107. |
|
|
38 |
|
|
|
39 |
[5] M. Congedo, A. Barachant, A. Andreev ,"A New generation of Brain-Computer |
|
|
40 |
Interface Based on Riemannian Geometry", arXiv: 1310.8115. |
|
|
41 |
|
|
|
42 |
[6] A. Barachant and S. Bonnet, "Channel selection procedure using riemannian |
|
|
43 |
distance for BCI applications," in 2011 5th International IEEE/EMBS |
|
|
44 |
Conference on Neural Engineering (NER), 2011, 348-351. |
|
|
45 |
|
|
|
46 |
[7] A. Barachant, S. Bonnet, M. Congedo and C. Jutten, “Multiclass |
|
|
47 |
Brain-Computer Interface Classification by Riemannian Geometry,” in IEEE |
|
|
48 |
Transactions on Biomedical Engineering, vol. 59, no. 4, p. 920-928, 2012. |
|
|
49 |
|
|
|
50 |
[8] A. Barachant, S. Bonnet, M. Congedo and C. Jutten, “Classification of |
|
|
51 |
covariance matrices using a Riemannian-based kernel for BCI applications”, |
|
|
52 |
in NeuroComputing, vol. 112, p. 172-178, 2013. |
|
|
53 |
|
|
|
54 |
|
|
|
55 |
Portions of this project are works of the United States Government and are not |
|
|
56 |
subject to domestic copyright protection under 17 USC Sec. 105. Those |
|
|
57 |
portions are released world-wide under the terms of the Creative Commons Zero |
|
|
58 |
1.0 (CC0) license. |
|
|
59 |
|
|
|
60 |
Other portions of this project are subject to domestic copyright protection |
|
|
61 |
under 17 USC Sec. 105. Those portions are licensed under the Apache 2.0 |
|
|
62 |
license. The complete text of the license governing this material is in |
|
|
63 |
the file labeled LICENSE.TXT that is a part of this project's official |
|
|
64 |
distribution. |
|
|
65 |
""" |
|
|
66 |
|
|
|
67 |
import numpy as np |
|
|
68 |
|
|
|
69 |
# mne imports |
|
|
70 |
import mne |
|
|
71 |
from mne import io |
|
|
72 |
from mne.datasets import sample |
|
|
73 |
|
|
|
74 |
# EEGNet-specific imports |
|
|
75 |
from EEGModels import EEGNet |
|
|
76 |
from tensorflow.keras import utils as np_utils |
|
|
77 |
from tensorflow.keras.callbacks import ModelCheckpoint |
|
|
78 |
from tensorflow.keras import backend as K |
|
|
79 |
|
|
|
80 |
# PyRiemann imports |
|
|
81 |
from pyriemann.estimation import XdawnCovariances |
|
|
82 |
from pyriemann.tangentspace import TangentSpace |
|
|
83 |
from pyriemann.utils.viz import plot_confusion_matrix |
|
|
84 |
from sklearn.pipeline import make_pipeline |
|
|
85 |
from sklearn.linear_model import LogisticRegression |
|
|
86 |
|
|
|
87 |
# tools for plotting confusion matrices |
|
|
88 |
from matplotlib import pyplot as plt |
|
|
89 |
|
|
|
90 |
# while the default tensorflow ordering is 'channels_last' we set it here |
|
|
91 |
# to be explicit in case if the user has changed the default ordering |
|
|
92 |
K.set_image_data_format('channels_last') |
|
|
93 |
|
|
|
94 |
##################### Process, filter and epoch the data ######################
# Locate the MNE "sample" dataset (MNE prompts for a download on first use).
data_path = sample.data_path()

# Set parameters and read data.
# Depending on the MNE release, data_path() yields either a str or a
# pathlib.Path; a Path cannot be concatenated with '+', so convert to str
# first. This is a no-op when a str is returned.
sample_dir = str(data_path) + '/MEG/sample'
raw_fname = sample_dir + '/sample_audvis_filt-0-40_raw.fif'
event_fname = sample_dir + '/sample_audvis_filt-0-40_raw-eve.fif'
tmin, tmax = -0., 1
event_id = dict(aud_l=1, aud_r=2, vis_l=3, vis_r=4)

# Setup for reading the raw data
raw = io.Raw(raw_fname, preload=True, verbose=False)
raw.filter(2, None, method='iir')  # replace baselining with high-pass
events = mne.read_events(event_fname)

raw.info['bads'] = ['MEG 2443']  # set bad channels
picks = mne.pick_types(raw.info, meg=False, eeg=True, stim=False, eog=False,
                       exclude='bads')

# Read epochs
epochs = mne.Epochs(raw, events, event_id, tmin, tmax, proj=False,
                    picks=picks, baseline=None, preload=True, verbose=False)
labels = epochs.events[:, -1]

# extract raw data. scale by 1000 due to scaling sensitivity in deep learning
X = epochs.get_data() * 1000  # format is in (trials, channels, samples)
y = labels

# Take the layout from the data itself instead of hard-coding it; for the
# sample dataset this is 60 channels by 151 time-points. A single kernel
# (trailing axis of size 1) is used for the network input.
kernels = 1
n_trials, chans, samples = X.shape

# take 50/25/25 percent of the data to train/validate/test
# (for the 288 trials of the sample dataset the boundaries are 144 and 216)
train_end = n_trials // 2
validate_end = train_end + n_trials // 4

X_train = X[:train_end]
Y_train = y[:train_end]
X_validate = X[train_end:validate_end]
Y_validate = y[train_end:validate_end]
X_test = X[validate_end:]
Y_test = y[validate_end:]
|
|
130 |
|
|
|
131 |
############################# EEGNet portion ##################################

# number of target classes (aud_l, aud_r, vis_l, vis_r)
nb_classes = 4

# convert labels to one-hot encodings. The event codes are 1..4, so subtract 1
# to get 0-based class indices. num_classes is given explicitly so that every
# split is encoded with nb_classes columns even if the highest label happens
# to be absent from it (to_categorical would otherwise infer a narrower width).
Y_train = np_utils.to_categorical(Y_train - 1, num_classes=nb_classes)
Y_validate = np_utils.to_categorical(Y_validate - 1, num_classes=nb_classes)
Y_test = np_utils.to_categorical(Y_test - 1, num_classes=nb_classes)

# convert data to NHWC (trials, channels, samples, kernels) format. Data
# contains 60 channels and 151 time-points. Set the number of kernels to 1.
X_train = X_train.reshape(X_train.shape[0], chans, samples, kernels)
X_validate = X_validate.reshape(X_validate.shape[0], chans, samples, kernels)
X_test = X_test.reshape(X_test.shape[0], chans, samples, kernels)

print('X_train shape:', X_train.shape)
print(X_train.shape[0], 'train samples')
print(X_test.shape[0], 'test samples')

# configure the EEGNet-8,2,16 model with kernel length of 32 samples (other
# model configurations may do better, but this is a good starting point)
model = EEGNet(nb_classes=nb_classes, Chans=chans, Samples=samples,
               dropoutRate=0.5, kernLength=32, F1=8, D=2, F2=16,
               dropoutType='Dropout')

# compile the model and set the optimizers
model.compile(loss='categorical_crossentropy', optimizer='adam',
              metrics=['accuracy'])

# count number of parameters in the model
numParams = model.count_params()

# set a valid path for your system to record model checkpoints. The same path
# is used below to reload the saved weights, so it is defined only once.
checkpoint_path = '/tmp/checkpoint.h5'
checkpointer = ModelCheckpoint(filepath=checkpoint_path, verbose=1,
                               save_best_only=True)

###############################################################################
# if the classification task was imbalanced (significantly more trials in one
# class versus the others) you can assign a weight to each class during
# optimization to balance it out. This data is approximately balanced so we
# don't need to do this, but is shown here for illustration/completeness.
###############################################################################

# the syntax is {class_1:weight_1, class_2:weight_2,...}. Here just setting
# the weights all to be 1
class_weights = {0: 1, 1: 1, 2: 1, 3: 1}

################################################################################
# fit the model. Due to very small sample sizes this can get
# pretty noisy run-to-run, but most runs should be comparable to xDAWN +
# Riemannian geometry classification (below)
################################################################################
fittedModel = model.fit(X_train, Y_train, batch_size=16, epochs=300,
                        verbose=2, validation_data=(X_validate, Y_validate),
                        callbacks=[checkpointer], class_weight=class_weights)

# load optimal weights (the checkpoint written by the callback above)
model.load_weights(checkpoint_path)

###############################################################################
# you can alternatively use the weights provided in the repo. If so it should
# get you 93% accuracy. Change the WEIGHTS_PATH variable to wherever it is on
# your system.
###############################################################################

# WEIGHTS_PATH = /path/to/EEGNet-8-2-weights.h5
# model.load_weights(WEIGHTS_PATH)

###############################################################################
# make prediction on test set.
###############################################################################

probs = model.predict(X_test)
preds = probs.argmax(axis=-1)
acc = np.mean(preds == Y_test.argmax(axis=-1))
print("Classification accuracy: %f " % (acc))
|
|
205 |
|
|
|
206 |
|
|
|
207 |
############################# PyRiemann Portion ##############################

# code is taken from PyRiemann's ERP sample script, which is decoding in
# the tangent space with a logistic regression

n_components = 2  # pick some components

# set up sklearn pipeline
clf = make_pipeline(XdawnCovariances(n_components),
                    TangentSpace(metric='riemann'),
                    LogisticRegression())

# reshape back to (trials, channels, samples)
X_train = X_train.reshape(X_train.shape[0], chans, samples)
X_test = X_test.reshape(X_test.shape[0], chans, samples)

# labels need to be back in single-column format (they are one-hot encoded
# at this point), so recover the class index of each trial
true_labels = Y_test.argmax(axis=-1)

# train a classifier with xDAWN spatial filtering + Riemannian Geometry (RG)
clf.fit(X_train, Y_train.argmax(axis=-1))
preds_rg = clf.predict(X_test)

# Printing the results
acc2 = np.mean(preds_rg == true_labels)
print("Classification accuracy: %f " % (acc2))

# plot the confusion matrices for both classifiers.
# NOTE(review): pyRiemann's plot_confusion_matrix is understood to take
# (targets, predictions, target_names), so the true labels are passed first
# and the rows are normalised per true class - confirm against the installed
# pyRiemann version.
names = ['audio left', 'audio right', 'vis left', 'vis right']
plt.figure(0)
plot_confusion_matrix(true_labels, preds, names, title='EEGNet-8,2')

plt.figure(1)
plot_confusion_matrix(true_labels, preds_rg, names, title='xDAWN + RG')
|
|
241 |
|
|
|
242 |
|
|
|
243 |
|