Diff of /examples/ERP.py [000000] .. [195f5e]

"""
 Sample script using EEGNet to classify Event-Related Potential (ERP) EEG data
 from a four-class classification task, using the sample dataset provided in
 the MNE [1, 2] package:
     https://martinos.org/mne/stable/manual/sample_dataset.html#ch-sample-data

 The four classes used from this dataset are:
     LA: Left-ear auditory stimulation
     RA: Right-ear auditory stimulation
     LV: Left visual field stimulation
     RV: Right visual field stimulation

 The code to process, filter and epoch the data is originally from Alexandre
 Barachant's PyRiemann [3] package, released under the BSD 3-clause license.
 A copy of the BSD 3-clause license has been provided together with this
 software to comply with software licensing requirements.

 When you first run this script, MNE will download the dataset and prompt you
 to confirm the download location (defaults to ~/mne_data). Follow the prompts
 to continue. The dataset is approximately a 1.5 GB download.

 For comparison, you can also evaluate EEGNet against a Riemannian geometry
 approach with xDAWN spatial filtering [4-8] using PyRiemann (code provided
 below).

 [1] A. Gramfort, M. Luessi, E. Larson, D. Engemann, D. Strohmeier, C. Brodbeck,
     L. Parkkonen, M. Hämäläinen, MNE software for processing MEG and EEG data,
     NeuroImage, Volume 86, 1 February 2014, Pages 446-460, ISSN 1053-8119.

 [2] A. Gramfort, M. Luessi, E. Larson, D. Engemann, D. Strohmeier, C. Brodbeck,
     R. Goj, M. Jas, T. Brooks, L. Parkkonen, M. Hämäläinen, MEG and EEG data
     analysis with MNE-Python, Frontiers in Neuroscience, Volume 7, 2013.

 [3] https://github.com/alexandrebarachant/pyRiemann

 [4] A. Barachant, M. Congedo, "A Plug&Play P300 BCI Using Information
     Geometry", arXiv:1409.0107.

 [5] M. Congedo, A. Barachant, A. Andreev, "A New Generation of Brain-Computer
     Interface Based on Riemannian Geometry", arXiv:1310.8115.

 [6] A. Barachant and S. Bonnet, "Channel selection procedure using Riemannian
     distance for BCI applications," in 2011 5th International IEEE/EMBS
     Conference on Neural Engineering (NER), 2011, 348-351.

 [7] A. Barachant, S. Bonnet, M. Congedo and C. Jutten, "Multiclass
     Brain-Computer Interface Classification by Riemannian Geometry," in IEEE
     Transactions on Biomedical Engineering, vol. 59, no. 4, p. 920-928, 2012.

 [8] A. Barachant, S. Bonnet, M. Congedo and C. Jutten, "Classification of
     covariance matrices using a Riemannian-based kernel for BCI applications",
     in NeuroComputing, vol. 112, p. 172-178, 2013.


 Portions of this project are works of the United States Government and are not
 subject to domestic copyright protection under 17 USC Sec. 105. Those
 portions are released world-wide under the terms of the Creative Commons Zero
 1.0 (CC0) license.

 Other portions of this project are subject to domestic copyright protection
 under 17 USC Sec. 105. Those portions are licensed under the Apache 2.0
 license. The complete text of the license governing this material is in
 the file labeled LICENSE.TXT that is a part of this project's official
 distribution.
"""

import numpy as np

# mne imports
import mne
from mne import io
from mne.datasets import sample

# EEGNet-specific imports
from EEGModels import EEGNet
from tensorflow.keras import utils as np_utils
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras import backend as K

# PyRiemann imports
from pyriemann.estimation import XdawnCovariances
from pyriemann.tangentspace import TangentSpace
from pyriemann.utils.viz import plot_confusion_matrix
from sklearn.pipeline import make_pipeline
from sklearn.linear_model import LogisticRegression

# tools for plotting confusion matrices
from matplotlib import pyplot as plt

# while the default tensorflow ordering is 'channels_last', we set it here
# to be explicit in case the user has changed the default ordering
K.set_image_data_format('channels_last')

##################### Process, filter and epoch the data ######################
# recent MNE versions return a pathlib.Path here; cast to str so the path
# concatenations below work with either a str or a Path
data_path = str(sample.data_path())

# Set parameters and read data
raw_fname = data_path + '/MEG/sample/sample_audvis_filt-0-40_raw.fif'
event_fname = data_path + '/MEG/sample/sample_audvis_filt-0-40_raw-eve.fif'
tmin, tmax = -0., 1
event_id = dict(aud_l=1, aud_r=2, vis_l=3, vis_r=4)

# Setup for reading the raw data (read_raw_fif is the current FIF reader)
raw = io.read_raw_fif(raw_fname, preload=True, verbose=False)
raw.filter(2, None, method='iir')  # replace baselining with high-pass
events = mne.read_events(event_fname)

raw.info['bads'] = ['MEG 2443']  # set bad channels
picks = mne.pick_types(raw.info, meg=False, eeg=True, stim=False, eog=False,
                       exclude='bads')

# Read epochs
epochs = mne.Epochs(raw, events, event_id, tmin, tmax, proj=False,
                    picks=picks, baseline=None, preload=True, verbose=False)
labels = epochs.events[:, -1]

# extract raw data. scale by 1000 due to scaling sensitivity in deep learning
X = epochs.get_data()*1000  # format is (trials, channels, samples)
y = labels

kernels, chans, samples = 1, 60, 151

# take 50/25/25 percent of the data to train/validate/test
X_train      = X[0:144,]
Y_train      = y[0:144]
X_validate   = X[144:216,]
Y_validate   = y[144:216]
X_test       = X[216:,]
Y_test       = y[216:]
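
# (illustrative addition, not in the original script) sanity-check the split:
# the hard-coded indices above assume roughly 288 epochs, so make sure the
# test set is non-empty and report the resulting split sizes
assert X.shape[0] > 216, 'expected more than 216 epochs for a non-empty test set'
print('train/validate/test trials:',
      X_train.shape[0], X_validate.shape[0], X_test.shape[0])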

############################# EEGNet portion ##################################

# convert labels to one-hot encodings.
Y_train      = np_utils.to_categorical(Y_train-1)
Y_validate   = np_utils.to_categorical(Y_validate-1)
Y_test       = np_utils.to_categorical(Y_test-1)

# convert data to NHWC (trials, channels, samples, kernels) format. Data
# contains 60 channels and 151 time-points. Set the number of kernels to 1.
X_train      = X_train.reshape(X_train.shape[0], chans, samples, kernels)
X_validate   = X_validate.reshape(X_validate.shape[0], chans, samples, kernels)
X_test       = X_test.reshape(X_test.shape[0], chans, samples, kernels)

print('X_train shape:', X_train.shape)
print(X_train.shape[0], 'train samples')
print(X_test.shape[0], 'test samples')

# configure the EEGNet-8,2,16 model with kernel length of 32 samples (other
# model configurations may do better, but this is a good starting point)
model = EEGNet(nb_classes = 4, Chans = chans, Samples = samples,
               dropoutRate = 0.5, kernLength = 32, F1 = 8, D = 2, F2 = 16,
               dropoutType = 'Dropout')

# compile the model and set the optimizer
model.compile(loss='categorical_crossentropy', optimizer='adam',
              metrics = ['accuracy'])

# count number of parameters in the model
numParams    = model.count_params()
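
# (optional, illustrative) print the total parameter count and a layer-by-layer
# summary; handy when comparing against other model configurations
# print('EEGNet-8,2 parameter count: %d' % numParams)
# model.summary()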

# set a valid path for your system to record model checkpoints
checkpointer = ModelCheckpoint(filepath='/tmp/checkpoint.h5', verbose=1,
                               save_best_only=True)

###############################################################################
# if the classification task were imbalanced (significantly more trials in one
# class versus the others) you could assign a weight to each class during
# optimization to balance it out. This data is approximately balanced, so we
# don't need to do this, but it is shown here for illustration/completeness.
###############################################################################

# the syntax is {class_1:weight_1, class_2:weight_2,...}. Here we just set
# all of the weights to 1
class_weights = {0:1, 1:1, 2:1, 3:1}
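
# (illustrative sketch, not needed for this roughly balanced dataset) for an
# imbalanced problem, one common choice is to weight each class by the inverse
# of its frequency in the training labels; uncomment to use it
# train_labels  = Y_train.argmax(axis=-1)
# counts        = np.bincount(train_labels, minlength=4)
# class_weights = {c: len(train_labels) / (4.0 * counts[c]) for c in range(4)}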

###############################################################################
# fit the model. Due to the very small sample size this can get
# pretty noisy run-to-run, but most runs should be comparable to xDAWN +
# Riemannian geometry classification (below)
###############################################################################
fittedModel = model.fit(X_train, Y_train, batch_size = 16, epochs = 300,
                        verbose = 2, validation_data=(X_validate, Y_validate),
                        callbacks=[checkpointer], class_weight = class_weights)

# load optimal weights
model.load_weights('/tmp/checkpoint.h5')

###############################################################################
# you can alternatively use the weights provided in the repo. If so, it should
# get you 93% accuracy. Change the WEIGHTS_PATH variable to wherever the file
# is on your system.
###############################################################################

# WEIGHTS_PATH = '/path/to/EEGNet-8-2-weights.h5'
# model.load_weights(WEIGHTS_PATH)

###############################################################################
# make predictions on the test set.
###############################################################################

probs       = model.predict(X_test)
preds       = probs.argmax(axis = -1)
acc         = np.mean(preds == Y_test.argmax(axis=-1))
print("Classification accuracy: %f " % (acc))
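
# (optional, illustrative) a per-class breakdown is often more informative than
# overall accuracy on a four-class problem; sklearn's classification_report
# prints precision, recall and F1 for each class
# from sklearn.metrics import classification_report
# print(classification_report(Y_test.argmax(axis=-1), preds,
#                             target_names=['LA', 'RA', 'LV', 'RV']))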


############################# PyRiemann portion ###############################

# this code is taken from PyRiemann's ERP sample script, which decodes in
# the tangent space with a logistic regression

n_components = 2  # pick some components

# set up sklearn pipeline
clf = make_pipeline(XdawnCovariances(n_components),
                    TangentSpace(metric='riemann'),
                    LogisticRegression())
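
# (illustrative alternative, not in the original script) PyRiemann also offers
# a minimum-distance-to-mean (MDM) classifier that operates directly on the
# xDAWN covariance matrices and can replace the tangent-space + logistic
# regression stage
# from pyriemann.classification import MDM
# clf = make_pipeline(XdawnCovariances(n_components), MDM())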

preds_rg     = np.zeros(len(Y_test))

# reshape back to (trials, channels, samples)
X_train      = X_train.reshape(X_train.shape[0], chans, samples)
X_test       = X_test.reshape(X_test.shape[0], chans, samples)

# train a classifier with xDAWN spatial filtering + Riemannian Geometry (RG)
# labels need to be back in single-column format
clf.fit(X_train, Y_train.argmax(axis = -1))
preds_rg     = clf.predict(X_test)

# print the results
acc2         = np.mean(preds_rg == Y_test.argmax(axis = -1))
print("Classification accuracy: %f " % (acc2))

# plot the confusion matrices for both classifiers
names        = ['audio left', 'audio right', 'vis left', 'vis right']
plt.figure(0)
plot_confusion_matrix(preds, Y_test.argmax(axis = -1), names, title = 'EEGNet-8,2')

plt.figure(1)
plot_confusion_matrix(preds_rg, Y_test.argmax(axis = -1), names, title = 'xDAWN + RG')
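
# (illustrative addition) when running this as a plain script with a
# non-interactive matplotlib backend, an explicit call to show() is needed for
# the confusion-matrix figures to actually appear
plt.show()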