Diff of /examples/ERP.py [000000] .. [195f5e]

--- a
+++ b/examples/ERP.py
@@ -0,0 +1,243 @@
+"""
+ Sample script using EEGNet to classify Event-Related Potential (ERP) EEG data
+ from a four-class classification task, using the sample dataset provided in
+ the MNE [1, 2] package:
+     https://martinos.org/mne/stable/manual/sample_dataset.html#ch-sample-data
+   
+ The four classes used from this dataset are:
+     LA: Left-ear auditory stimulation
+     RA: Right-ear auditory stimulation
+     LV: Left visual field stimulation
+     RV: Right visual field stimulation
+
+ The code to process, filter, and epoch the data is originally from Alexandre
+ Barachant's PyRiemann [3] package, released under the BSD 3-clause license.
+ A copy of the BSD 3-clause license has been provided together with this
+ software to comply with software licensing requirements.
+ 
+ When you first run this script, MNE will download the dataset and prompt you
+ to confirm the download location (defaults to ~/mne_data). Follow the prompts
+ to continue. The download is approximately 1.5 GB.
+ 
+ For comparison, this script also classifies the same data with xDAWN spatial
+ filtering followed by Riemannian geometry approaches [4-8], using PyRiemann
+ (code provided below).
+
+ [1] A. Gramfort, M. Luessi, E. Larson, D. Engemann, D. Strohmeier, C. Brodbeck,
+     L. Parkkonen, M. Hämäläinen, MNE software for processing MEG and EEG data, 
+     NeuroImage, Volume 86, 1 February 2014, Pages 446-460, ISSN 1053-8119.
+
+ [2] A. Gramfort, M. Luessi, E. Larson, D. Engemann, D. Strohmeier, C. Brodbeck, 
+     R. Goj, M. Jas, T. Brooks, L. Parkkonen, M. Hämäläinen, MEG and EEG data 
+     analysis with MNE-Python, Frontiers in Neuroscience, Volume 7, 2013.
+
+ [3] https://github.com/alexandrebarachant/pyRiemann. 
+
+ [4] A. Barachant, M. Congedo, "A Plug&Play P300 BCI Using Information 
+     Geometry", arXiv:1409.0107.
+
+ [5] M. Congedo, A. Barachant, A. Andreev, "A New Generation of Brain-Computer 
+     Interface Based on Riemannian Geometry", arXiv:1310.8115.
+
+ [6] A. Barachant and S. Bonnet, "Channel selection procedure using Riemannian 
+     distance for BCI applications," in 2011 5th International IEEE/EMBS 
+     Conference on Neural Engineering (NER), 2011, pp. 348-351.
+
+ [7] A. Barachant, S. Bonnet, M. Congedo and C. Jutten, "Multiclass 
+     Brain-Computer Interface Classification by Riemannian Geometry," in IEEE 
+     Transactions on Biomedical Engineering, vol. 59, no. 4, pp. 920-928, 2012.
+
+ [8] A. Barachant, S. Bonnet, M. Congedo and C. Jutten, "Classification of 
+     covariance matrices using a Riemannian-based kernel for BCI applications", 
+     in Neurocomputing, vol. 112, pp. 172-178, 2013.
+
+
+ Portions of this project are works of the United States Government and are not
+ subject to domestic copyright protection under 17 USC Sec. 105.  Those 
+ portions are released world-wide under the terms of the Creative Commons Zero 
+ 1.0 (CC0) license.  
+ 
+ Other portions of this project are subject to domestic copyright protection 
+ under 17 USC Sec. 105.  Those portions are licensed under the Apache 2.0 
+ license.  The complete text of the license governing this material is in 
+ the file labeled LICENSE.TXT that is a part of this project's official 
+ distribution. 
+"""
+
+import numpy as np
+
+# mne imports
+import mne
+from mne import io
+from mne.datasets import sample
+
+# EEGNet-specific imports
+from EEGModels import EEGNet
+from tensorflow.keras import utils as np_utils
+from tensorflow.keras.callbacks import ModelCheckpoint
+from tensorflow.keras import backend as K
+
+# PyRiemann imports
+from pyriemann.estimation import XdawnCovariances
+from pyriemann.tangentspace import TangentSpace
+from pyriemann.utils.viz import plot_confusion_matrix
+from sklearn.pipeline import make_pipeline
+from sklearn.linear_model import LogisticRegression
+
+# tools for plotting confusion matrices
+from matplotlib import pyplot as plt
+
+# while the default tensorflow ordering is 'channels_last', we set it here
+# explicitly in case the user has changed the default ordering
+K.set_image_data_format('channels_last')
+
+##################### Process, filter and epoch the data ######################
+data_path = sample.data_path()
+
+# Set parameters and read data. str() keeps this working with newer MNE
+# versions, where data_path() returns a pathlib.Path instead of a string.
+raw_fname = str(data_path) + '/MEG/sample/sample_audvis_filt-0-40_raw.fif'
+event_fname = str(data_path) + '/MEG/sample/sample_audvis_filt-0-40_raw-eve.fif'
+tmin, tmax = 0., 1.
+event_id = dict(aud_l=1, aud_r=2, vis_l=3, vis_r=4)
+
+# Setup for reading the raw data
+raw = io.read_raw_fif(raw_fname, preload=True, verbose=False)
+raw.filter(2, None, method='iir')  # replace baselining with high-pass
+events = mne.read_events(event_fname)
+
+raw.info['bads'] = ['MEG 2443']  # set bad channels
+picks = mne.pick_types(raw.info, meg=False, eeg=True, stim=False, eog=False,
+                       exclude='bads')
+
+# Read epochs
+epochs = mne.Epochs(raw, events, event_id, tmin, tmax, proj=False,
+                    picks=picks, baseline=None, preload=True, verbose=False)
+labels = epochs.events[:, -1]
+
+# extract raw data and scale by 1000 due to scaling sensitivity in deep
+# learning (MNE returns EEG data in volts)
+X = epochs.get_data()*1000 # format is (trials, channels, samples)
+y = labels
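+
+# (Optional sketch, not used here) instead of a fixed scale factor, you could
+# standardize each channel using statistics from the training trials only
+# (the first 144 trials; see the split below):
+# mu = X[0:144].mean(axis=(0, 2), keepdims=True)
+# sd = X[0:144].std(axis=(0, 2), keepdims=True)
+# X  = (X - mu) / sd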
+
+kernels, chans, samples = 1, 60, 151
+
+# take 50/25/25 percent of the data to train/validate/test
+X_train      = X[0:144]
+Y_train      = y[0:144]
+X_validate   = X[144:216]
+Y_validate   = y[144:216]
+X_test       = X[216:]
+Y_test       = y[216:]
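+
+# (Optional sketch) a shuffled, label-stratified split is also common; sklearn
+# is already a dependency of this script (see imports above):
+# from sklearn.model_selection import train_test_split
+# X_train, X_rest, Y_train, Y_rest = train_test_split(
+#     X, y, train_size=0.5, stratify=y, random_state=42)
+# X_validate, X_test, Y_validate, Y_test = train_test_split(
+#     X_rest, Y_rest, train_size=0.5, stratify=Y_rest, random_state=42)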
+
+############################# EEGNet portion ##################################
+
+# convert labels to one-hot encodings.
+Y_train      = np_utils.to_categorical(Y_train-1)
+Y_validate   = np_utils.to_categorical(Y_validate-1)
+Y_test       = np_utils.to_categorical(Y_test-1)
+
+# convert data to NHWC (trials, channels, samples, kernels) format. Data 
+# contains 60 channels and 151 time-points. Set the number of kernels to 1.
+X_train      = X_train.reshape(X_train.shape[0], chans, samples, kernels)
+X_validate   = X_validate.reshape(X_validate.shape[0], chans, samples, kernels)
+X_test       = X_test.reshape(X_test.shape[0], chans, samples, kernels)
+   
+print('X_train shape:', X_train.shape)
+print(X_train.shape[0], 'train samples')
+print(X_test.shape[0], 'test samples')
+
+# configure the EEGNet-8,2,16 model with kernel length of 32 samples (other 
+# model configurations may do better, but this is a good starting point)
+model = EEGNet(nb_classes = 4, Chans = chans, Samples = samples, 
+               dropoutRate = 0.5, kernLength = 32, F1 = 8, D = 2, F2 = 16, 
+               dropoutType = 'Dropout')
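+
+# (Optional sketch) a smaller EEGNet-4,2,8 variant, following the paper's
+# F2 = F1 * D convention; worth trying if the 8,2,16 model overfits:
+# model = EEGNet(nb_classes = 4, Chans = chans, Samples = samples, 
+#                dropoutRate = 0.5, kernLength = 32, F1 = 4, D = 2, F2 = 8, 
+#                dropoutType = 'Dropout')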
+
+# compile the model and set the optimizers
+model.compile(loss='categorical_crossentropy', optimizer='adam', 
+              metrics = ['accuracy'])
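+
+# (Optional sketch) to control the learning rate explicitly, pass an optimizer
+# instance instead of the string; 1e-3 is the tf.keras Adam default:
+# from tensorflow.keras.optimizers import Adam
+# model.compile(loss='categorical_crossentropy',
+#               optimizer=Adam(learning_rate=1e-3), metrics=['accuracy'])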
+
+# count number of parameters in the model
+numParams    = model.count_params()    
+
+# set a valid path for your system to record model checkpoints
+checkpointer = ModelCheckpoint(filepath='/tmp/checkpoint.h5', verbose=1,
+                               save_best_only=True)
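+
+# (Optional sketch) an EarlyStopping callback could be passed alongside the
+# checkpointer in the fit() call below, stopping once val_loss plateaus:
+# from tensorflow.keras.callbacks import EarlyStopping
+# early_stop = EarlyStopping(monitor='val_loss', patience=50, verbose=1)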
+
+###############################################################################
+# if the classification task were imbalanced (significantly more trials in one
+# class versus the others) you could assign a weight to each class during 
+# optimization to balance it out. This data is approximately balanced, so we 
+# don't need to do this, but it is shown here for illustration/completeness. 
+###############################################################################
+
+# the syntax is {class_1:weight_1, class_2:weight_2,...}. Here just setting
+# the weights all to be 1
+class_weights = {0:1, 1:1, 2:1, 3:1}
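+
+# (Optional sketch) for a genuinely imbalanced task, a common heuristic is to
+# weight each class by its inverse frequency (sklearn's 'balanced' formula):
+# counts        = np.bincount(Y_train.argmax(axis=-1))
+# class_weights = {i: len(Y_train) / (len(counts) * c)
+#                  for i, c in enumerate(counts)}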
+
+################################################################################
+# fit the model. Due to the very small sample size, performance can be fairly
+# noisy run-to-run, but most runs should be comparable to the xDAWN + 
+# Riemannian geometry classification (below)
+################################################################################
+fittedModel = model.fit(X_train, Y_train, batch_size = 16, epochs = 300, 
+                        verbose = 2, validation_data=(X_validate, Y_validate),
+                        callbacks=[checkpointer], class_weight = class_weights)
+
+# load optimal weights
+model.load_weights('/tmp/checkpoint.h5')
+
+###############################################################################
+# you can alternatively use the pre-trained weights provided in the repo.
+# Doing so should give you about 93% accuracy. Change the WEIGHTS_PATH 
+# variable to wherever the file lives on your system.
+###############################################################################
+
+# WEIGHTS_PATH = '/path/to/EEGNet-8-2-weights.h5'
+# model.load_weights(WEIGHTS_PATH)
+
+###############################################################################
+# make predictions on the test set.
+###############################################################################
+
+probs       = model.predict(X_test)
+preds       = probs.argmax(axis = -1)  
+acc         = np.mean(preds == Y_test.argmax(axis=-1))
+print("Classification accuracy: %f " % (acc))
+
+
+############################# PyRiemann Portion ##############################
+
+# code adapted from PyRiemann's ERP sample script, which decodes in the 
+# tangent space with logistic regression
+
+n_components = 2  # number of xDAWN spatial filters per class
+
+# set up sklearn pipeline
+clf = make_pipeline(XdawnCovariances(n_components),
+                    TangentSpace(metric='riemann'),
+                    LogisticRegression())
+
+# reshape back to (trials, channels, samples)
+X_train      = X_train.reshape(X_train.shape[0], chans, samples)
+X_test       = X_test.reshape(X_test.shape[0], chans, samples)
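+
+# (Optional sketch) the pipeline could also be cross-validated on the training
+# trials, now that they are back in (trials, channels, samples) format:
+# from sklearn.model_selection import cross_val_score
+# scores = cross_val_score(clf, X_train, Y_train.argmax(axis = -1), cv = 5)
+# print("CV accuracy: %f +/- %f" % (scores.mean(), scores.std()))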
+
+# train a classifier with xDAWN spatial filtering + Riemannian Geometry (RG)
+# labels need to be back in single-column format
+clf.fit(X_train, Y_train.argmax(axis = -1))
+preds_rg     = clf.predict(X_test)
+
+# Printing the results
+acc2         = np.mean(preds_rg == Y_test.argmax(axis = -1))
+print("Classification accuracy: %f " % (acc2))
+
+# plot the confusion matrices for both classifiers
+names        = ['audio left', 'audio right', 'vis left', 'vis right']
+plt.figure(0)
+plot_confusion_matrix(preds, Y_test.argmax(axis = -1), names, title = 'EEGNet-8,2')
+
+plt.figure(1)
+plot_confusion_matrix(preds_rg, Y_test.argmax(axis = -1), names, title = 'xDAWN + RG')
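+
+# (Optional sketch) if your pyriemann version no longer ships
+# plot_confusion_matrix, sklearn provides an equivalent:
+# from sklearn.metrics import ConfusionMatrixDisplay
+# ConfusionMatrixDisplay.from_predictions(Y_test.argmax(axis = -1), preds,
+#                                         display_labels = names)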
+
+# show the figures when running as a non-interactive script
+plt.show()