In [1]:
"""
Format dataset, we read the file for the desired subject, and parse the data to extract:
- samplingRate
- trialLength
- X, a M x N x K matrix, which stands for trial x chan x samples
                         the actual values are 160 x 15 x 2560
- y, a M vector containing the labels {0,1}

ref:
Dataset description: https://lampx.tugraz.at/~bci/database/002-2014/description.pdf
"""

import scipy.io as sio
import numpy as np
import sys 
import mne

# prepare data containers
y = []
X = []

trainingFileList = ['../../BBCIData/S14T.mat', 
                    '../../BBCIData/S13T.mat', 
                    '../../BBCIData/S12T.mat', 
                    '../../BBCIData/S11T.mat', 
                    '../../BBCIData/S10T.mat', 
                    '../../BBCIData/S09T.mat', 
                    '../../BBCIData/S08T.mat', 
                    '../../BBCIData/S07T.mat', 
                    '../../BBCIData/S06T.mat', 
                    '../../BBCIData/S05T.mat', 
                    '../../BBCIData/S04T.mat', 
                    '../../BBCIData/S03T.mat', 
                    '../../BBCIData/S02T.mat', 
                    '../../BBCIData/S01T.mat']

validationFileList = ['../../BBCIData/S14E.mat', 
                      '../../BBCIData/S13E.mat', 
                      '../../BBCIData/S12E.mat', 
                      '../../BBCIData/S11E.mat', 
                      '../../BBCIData/S10E.mat', 
                      '../../BBCIData/S09E.mat', 
                      '../../BBCIData/S08E.mat', 
                      '../../BBCIData/S07E.mat', 
                      '../../BBCIData/S06E.mat', 
                      '../../BBCIData/S05E.mat', 
                      '../../BBCIData/S04E.mat', 
                      '../../BBCIData/S03E.mat', 
                      '../../BBCIData/S02E.mat', 
                      '../../BBCIData/S01E.mat']



#trainingFileList = ['BBCIData/S14T.mat']

#validationFileList = ['BBCIData/S14E.mat']

def filterData(rawData, samplingRate):
    print(rawData.shape)
    
    #filteredData = mne.filter.filter_data()
    
    filterHandle = mne.filter.create_filter(rawData[0,:], samplingRate, l_freq=7., h_freq=35.)  
    filterHalfLenght = len(filterHandle) // 2

    filteredData = np.zeros([rawData.shape[0],rawData.shape[1]+filterHalfLenght])

    # for each channel
    for i in range(rawData.shape[0]):
        #filteredData[i,:] = mne.filter.filter_data(rawData[i,:],sfreq=samplingRate ,l_freq=7., h_freq=35.,method='iir')
        filteredData[i,:] = np.convolve(filterHandle, rawData[i,:])[len(filterHandle) // 2:]

    filteredData = filteredData[:,filterHalfLenght:-filterHalfLenght]
    return filteredData
    


tStart = 3.5
tStop = 5.

for i in range(len(trainingFileList)):
    # read file
    d1T = sio.loadmat(trainingFileList[i])
    d1E = sio.loadmat(validationFileList[i])
    
    samplingRate = d1T['data'][0][0][0][0][3][0][0]
    trialLength = (int)(tStop*samplingRate)


    # run through all training runs
    for run in range(5):
        y.append(d1T['data'][0][run][0][0][2][0]) # labels
        timestamps = d1T['data'][0][run][0][0][1][0] # timestamps
        rawData = d1T['data'][0][run][0][0][0].transpose() # chan x data
        
        rawData = filterData(rawData,samplingRate)
        
        # parse out data based on timestamps
        for start in timestamps:
            dstart = (int)(start + tStart*samplingRate)
            end = start + trialLength
            X.append(rawData[:,dstart:end]) #15 x 2560


    # run through all validation runs (we do not discriminate at this point)
    for run in range(3):
        y.append(d1E['data'][0][run][0][0][2][0]) # labels
        timestamps = d1E['data'][0][run][0][0][1][0] # timestamps
        rawData = d1E['data'][0][run][0][0][0].transpose() # chan x data

        filterData(rawData,samplingRate)
        
        # parse out data based on timestamps
        for start in timestamps:
            dstart = (int)(start + tStart*samplingRate)
            end = start + trialLength
            X.append(rawData[:,dstart:end]) #15 x 2557

    del rawData
    del d1T
    del d1E

# arrange data into numpy arrays
# also torch expect float32 for samples
# and int64 for labels {0,1}
X = np.array(X).astype(np.float32)
y = (np.array(y).flatten()-1).astype(np.int64)
print(X.shape)
print(y.shape)

# erase unused references
d1T = []
d1E = []



(15, 114176)
Setting up band-pass filter from 7 - 35 Hz
l_trans_bandwidth chosen to be 2.0 Hz
h_trans_bandwidth chosen to be 8.8 Hz
Filter length of 845 samples (1.650 sec) selected
(15, 112640)
Setting up band-pass filter from 7 - 35 Hz
l_trans_bandwidth chosen to be 2.0 Hz
h_trans_bandwidth chosen to be 8.8 Hz
Filter length of 845 samples (1.650 sec) selected
(15, 114176)
Setting up band-pass filter from 7 - 35 Hz
l_trans_bandwidth chosen to be 2.0 Hz
h_trans_bandwidth chosen to be 8.8 Hz
Filter length of 845 samples (1.650 sec) selected
(15, 113152)
Setting up band-pass filter from 7 - 35 Hz
l_trans_bandwidth chosen to be 2.0 Hz
h_trans_bandwidth chosen to be 8.8 Hz
Filter length of 845 samples (1.650 sec) selected
(15, 112128)
Setting up band-pass filter from 7 - 35 Hz
l_trans_bandwidth chosen to be 2.0 Hz
h_trans_bandwidth chosen to be 8.8 Hz
Filter length of 845 samples (1.650 sec) selected
(15, 112640)
Setting up band-pass filter from 7 - 35 Hz
l_trans_bandwidth chosen to be 2.0

l_trans_bandwidth chosen to be 2.0 Hz
h_trans_bandwidth chosen to be 8.8 Hz
Filter length of 845 samples (1.650 sec) selected
(15, 112640)
Setting up band-pass filter from 7 - 35 Hz
l_trans_bandwidth chosen to be 2.0 Hz
h_trans_bandwidth chosen to be 8.8 Hz
Filter length of 845 samples (1.650 sec) selected
(15, 112128)
Setting up band-pass filter from 7 - 35 Hz
l_trans_bandwidth chosen to be 2.0 Hz
h_trans_bandwidth chosen to be 8.8 Hz
Filter length of 845 samples (1.650 sec) selected
(15, 113664)
Setting up band-pass filter from 7 - 35 Hz
l_trans_bandwidth chosen to be 2.0 Hz
h_trans_bandwidth chosen to be 8.8 Hz
Filter length of 845 samples (1.650 sec) selected
(15, 113152)
Setting up band-pass filter from 7 - 35 Hz
l_trans_bandwidth chosen to be 2.0 Hz
h_trans_bandwidth chosen to be 8.8 Hz
Filter length of 845 samples (1.650 sec) selected
(15, 112640)
Setting up band-pass filter from 7 - 35 Hz
l_trans_bandwidth chosen to be 2.0 Hz
h_trans_bandwidth chosen to be 8.8 Hz
Filter length 

h_trans_bandwidth chosen to be 8.8 Hz
Filter length of 845 samples (1.650 sec) selected
(15, 113664)
Setting up band-pass filter from 7 - 35 Hz
l_trans_bandwidth chosen to be 2.0 Hz
h_trans_bandwidth chosen to be 8.8 Hz
Filter length of 845 samples (1.650 sec) selected
(15, 113664)
Setting up band-pass filter from 7 - 35 Hz
l_trans_bandwidth chosen to be 2.0 Hz
h_trans_bandwidth chosen to be 8.8 Hz
Filter length of 845 samples (1.650 sec) selected
(15, 112640)
Setting up band-pass filter from 7 - 35 Hz
l_trans_bandwidth chosen to be 2.0 Hz
h_trans_bandwidth chosen to be 8.8 Hz
Filter length of 845 samples (1.650 sec) selected
(15, 113664)
Setting up band-pass filter from 7 - 35 Hz
l_trans_bandwidth chosen to be 2.0 Hz
h_trans_bandwidth chosen to be 8.8 Hz
Filter length of 845 samples (1.650 sec) selected
(15, 112640)
Setting up band-pass filter from 7 - 35 Hz
l_trans_bandwidth chosen to be 2.0 Hz
h_trans_bandwidth chosen to be 8.8 Hz
Filter length of 845 samples (1.650 sec) selected
(1

In [4]:
"""
====================================================================
ERP EEG decoding in Tangent space.
====================================================================
Decoding applied to EEG data in sensor space decomposed using Xdawn.
After spatial filtering, covariances matrices are estimated, then projected in
the tangent space and classified with a logistic regression.
"""
# Authors: Alexandre Barachant <alexandre.barachant@gmail.com>
#
# License: BSD (3-clause)

import numpy as np

from pyriemann.estimation import XdawnCovariances
#from pyriemann.tangentspace import TangentSpace
from pyriemann.utils.viz import plot_confusion_matrix
from pyriemann.estimation import Covariances
from pyriemann.classification import TSclassifier, MDM

from sklearn.pipeline import make_pipeline
from sklearn.cross_validation import KFold
from sklearn.linear_model import LogisticRegression
from random import randint
from matplotlib import pyplot as plt

###############################################################################
# Decoding in tangent space with a logistic regression

n_components = 2  # pick some components

labels = y
epochs_data = X


# Define a monte-carlo cross-validation generator (reduce variance):
cv = KFold(len(labels), 10, shuffle=True, random_state=randint(1,5000))

print("epoch data:")
print(epochs_data.shape)




clf = make_pipeline(Covariances('oas'), TSclassifier())

preds = np.zeros(len(labels))

print("labels:")
print(labels.shape)

for train_idx, test_idx in cv:
    y_train, y_test = labels[train_idx], labels[test_idx]

    clf.fit(epochs_data[train_idx], y_train)
    preds[test_idx] = clf.predict(epochs_data[test_idx])

# Printing the results
acc = np.mean(preds == labels)
print("Classification accuracy: %f " % (acc))

#names = ['audio left', 'audio right', 'vis left', 'vis right']
#plot_confusion_matrix(preds, labels, names)
#plt.show()



epoch data:
(2240, 15, 768)
labels:
(2240,)
Classification accuracy: 0.676339 


In [3]:

# subject 1
# Classification accuracy: 0.618750

# subject 2
# Classification accuracy: 0.768750 

# subject 3
# Classification accuracy: 0.950000 

# subject 4
# Classification accuracy: 0.787500 

# subject 5
# Classification accuracy: 0.725000

# subject 6
# Classification accuracy:  0.675000   

# subject 7
# Classification accuracy: 0.856250 

# subject 8
# Classification accuracy: 0.762500 

# subject 9
# Classification accuracy: 0.931250 
-
# subject 10
# Classification accuracy: 0.681250

# subject 11
# Classification accuracy: 0.812500 

# subject 12
# Classification accuracy: 0.581250 

# subject 13
# Classification accuracy: 0.506250 

# subject 14
# Classification accuracy: 0.500000 

# subject 1-14
# Classification accuracy: 0.695536 

SyntaxError: invalid syntax (<ipython-input-3-7c5a99a2478d>, line 28)