a b/preprocessing/erp.py
1
# -*- coding: utf-8 -*-
2
"""
Created on Wed Jul  8 22:00:08 2015.

@author: rc, alexandre
"""
7
import numpy as np
8
from sklearn.base import BaseEstimator, TransformerMixin
9
from mne.io import RawArray
10
from mne.channels import read_montage
11
from mne import create_info
12
from mne import find_events, Epochs
13
from mne.preprocessing import Xdawn
14
from mne import compute_raw_data_covariance
15
16
from pyriemann.utils.covariance import _lwf
17
from pyriemann.classification import MDM
18
19
from preprocessing.aux import getChannelNames, getEventNames, sliding_window
20
21
22
def toMNE(X, y=None):
    """Transform an array into an MNE Raw object for epoching.

    Parameters
    ----------
    X : ndarray, shape (n_samples, n_channels)
        EEG data; transposed to (n_channels, n_samples) as RawArray expects.
    y : ndarray, shape (n_samples, n_events) | None
        Optional event labels. When given, each column is appended as an
        extra 'stim' channel named after the matching getEventNames() entry.

    Returns
    -------
    raw : mne.io.RawArray
        Raw object at 500 Hz carrying the 'standard_1005' montage.
    """
    # Copy: the list is extended below and must never alias (and grow)
    # whatever getChannelNames() hands back.
    ch_names = list(getChannelNames())
    montage = read_montage('standard_1005', ch_names)
    ch_type = ['eeg'] * len(ch_names)
    data = X.T
    if y is not None:
        event_names = getEventNames()
        # One stim channel per event name; derive the count from the names
        # so the type list and the name list cannot drift apart.
        ch_type.extend(['stim'] * len(event_names))
        ch_names.extend(event_names)
        # Stack the event rows below the EEG rows.
        data = np.concatenate((data, y.transpose()))
    info = create_info(ch_names, sfreq=500.0, ch_types=ch_type,
                       montage=montage)
    return RawArray(data, info, verbose=False)
39
40
41
def get_epochs_and_cov(X, y, window=500):
    """Return event epochs and the raw-data covariance of a recording.

    Parameters
    ----------
    X : ndarray, shape (n_samples, n_channels)
        EEG data (see toMNE).
    y : ndarray, shape (n_samples, n_events)
        Event labels; each column becomes a stim channel in toMNE.
    window : int
        Epoch length in samples. Each epoch ends 150 ms after the event
        found by find_events; with tmin/tmax inclusive it spans exactly
        ``window`` samples.

    Returns
    -------
    epochs : mne.Epochs
        EEG-channel epochs, labelled 1..n_events in getEventNames() order.
    cov_signal : mne.Covariance
        Covariance of the whole raw recording.
    """
    raw_train = toMNE(X, y)
    # Read the sampling rate from the Raw object instead of repeating the
    # 500 Hz literal used in toMNE.
    sfreq = raw_train.info['sfreq']
    # EEG channels only: toMNE appends the stim channels after them.
    picks = np.arange(len(getChannelNames()))

    events = []
    events_id = {}
    for code, eid in enumerate(getEventNames(), start=1):
        found = find_events(raw_train, stim_channel=eid, verbose=False)
        # Relabel with the event's 1-based position in getEventNames().
        found[:, -1] = code
        events.append(found)
        events_id[eid] = code

    # Concatenate and sort by sample index. The stable sort keeps events
    # falling on the same sample in getEventNames() order.
    events = np.concatenate(events, axis=0)
    events = events[np.argsort(events[:, 0], kind='mergesort')]

    epochs = Epochs(raw_train, events, events_id,
                    tmin=-(window / sfreq) + 1 / sfreq + 0.150,
                    tmax=0.150, proj=False, picks=picks, baseline=None,
                    preload=True, add_eeg_ref=False, verbose=False)

    cov_signal = compute_raw_data_covariance(raw_train, verbose=False)
    return epochs, cov_signal
66
67
68
class ERP(BaseEstimator, TransformerMixin):

    """ERP covariance estimator.

    This transformer estimates special-form covariance matrices for ERP
    detection [1, 2]. For each class, the ERP is estimated by averaging all
    epochs of that class. Its dimensionality is reduced with the XDAWN
    algorithm, and the resulting prototypes are concatenated with each
    epoch before the covariance matrix is estimated.

    References :
    [1] A. Barachant, M. Congedo ,"A Plug&Play P300 BCI Using Information
    Geometry", arXiv:1409.0107
    [2] M. Congedo, A. Barachant, A. Andreev ,"A New generation of
    Brain-Computer Interface Based on Riemannian Geometry", arXiv: 1310.8115.
    """

    def __init__(self, window=500, nfilters=3, subsample=1):
        """Store the hyper-parameters.

        window : epoch / sliding-window length in samples.
        nfilters : number of XDAWN filters kept per event class.
        subsample : forwarded to sliding_window in transform.
        """
        self.window = window
        self.nfilters = nfilters
        self.subsample = subsample

    def fit(self, X, y):
        """Fit the XDAWN prototypes and return self."""
        self._fit(X, y)
        return self

    def _fit(self, X, y):
        """Fit XDAWN on the training epochs and return those epochs.

        Sets ``P`` (the stacked, spatially filtered class averages) and
        ``labels_train`` (the event code of every training epoch).
        """
        epochs, cov_signal = get_epochs_and_cov(X, y, self.window)

        xdawn = Xdawn(n_components=self.nfilters, signal_cov=cov_signal,
                      correct_overlap=False)
        xdawn.fit(epochs)

        k = self.nfilters
        # One block per event: the first k filters applied to that
        # event's averaged response.
        prototypes = [np.dot(xdawn.filters_[name][:, :k].T,
                             xdawn.evokeds_[name].data)
                      for name in getEventNames()]
        self.P = np.concatenate(prototypes, axis=0)
        self.labels_train = epochs.events[:, -1]
        return epochs

    def transform(self, X, y=None):
        """Return ERP covariances of sliding windows over X (y is unused)."""
        return sliding_window(X.T, window=self.window,
                              subsample=self.subsample,
                              estimator=self.erp_cov)

    def fit_transform(self, X, y):
        """Fit, then return one ERP covariance per training epoch."""
        epochs = self._fit(X, y)
        return np.array(list(map(self.erp_cov, epochs.get_data())))

    def erp_cov(self, X):
        """Return the covariance of the prototypes ``P`` stacked on ``X``.

        ``P`` rows are placed above ``X`` rows (axis 0), and the result is
        passed to pyriemann's ``_lwf`` estimator.
        """
        stacked = np.concatenate((self.P, X), axis=0)
        return _lwf(stacked)

    def update_subsample(self, old_sub, new_sub):
        """Switch the sliding-window subsampling to ``new_sub``.

        ``old_sub`` is accepted for interface compatibility and is unused.
        """
        self.subsample = new_sub
134
135
class ERPDistance(BaseEstimator, TransformerMixin):

    """ERP distance cov estimator.

    This transformer estimates Riemannian distances for ERP covariance
    matrices. After special-form ERP covariance matrices are estimated with
    the ERP transformer, an MDM [1] classifier is used to compute Riemannian
    distances to each class centroid. Besides the event classes, a "rest"
    class (label 0) is built from windows sampled before the onsets in the
    first label column. The output is each event distance minus the rest
    distance.

    References:
    [1] A. Barachant, S. Bonnet, M. Congedo and C. Jutten, "Multiclass
    Brain-Computer Interface Classification by Riemannian Geometry," in IEEE
    Transactions on Biomedical Engineering, vol. 59, no. 4, p. 920-928, 2012
    """

    def __init__(self, window=500, nfilters=3, subsample=1, metric='riemann',
                 n_jobs=1):
        """Init.

        window, nfilters, subsample : forwarded to ERP.
        metric, n_jobs : forwarded to pyriemann's MDM.
        """
        self.window = window
        self.nfilters = nfilters
        self.subsample = subsample
        self.metric = metric
        self.n_jobs = n_jobs
        self._fitted = False

    def fit(self, X, y):
        """Fit the ERP transformer and an MDM on event and rest covariances.

        Raises
        ------
        ValueError
            If no rest window can be sampled, i.e. no onset in ``y[:, 0]``
            is preceded by at least ``2 * window`` samples.
        """
        # Create ERP and get one covariance per training epoch.
        self.ERP = ERP(self.window, self.nfilters, self.subsample)
        train_cov = self.ERP.fit_transform(X, y)
        labels_train = self.ERP.labels_train

        # Add rest epochs, labelled 0.
        rest_cov = self._get_rest_cov(X, y)
        if len(rest_cov) == 0:
            # Without class 0, transform() would silently take an event
            # class as its reference column.
            raise ValueError('ERPDistance: no rest epochs could be sampled; '
                             'need an onset in y[:, 0] preceded by at least '
                             '2 * window samples.')
        train_cov = np.concatenate((train_cov, rest_cov), axis=0)
        labels_train = np.concatenate((labels_train, [0] * len(rest_cov)))

        # Fit MDM.
        self.MDM = MDM(metric=self.metric, n_jobs=self.n_jobs)
        self.MDM.fit(train_cov, labels_train)
        self._fitted = True
        return self

    def transform(self, X, y=None):
        """Return, per ERP covariance of X, event distances minus rest.

        Column j is the distance to the (j + 1)-th MDM class minus the
        distance to the first one (the rest class).
        """
        test_cov = self.ERP.transform(X)
        dist = self.MDM.transform(test_cov)
        # Assumes MDM orders centroids by sorted label, so the rest class
        # (label 0) is column 0. Slicing with :1 keeps it 2-D for broadcasting.
        return dist[:, 1:] - dist[:, :1]

    def update_subsample(self, old_sub, new_sub):
        """Forward a subsampling change to the fitted ERP transformer.

        Does nothing before fit. ``self.subsample`` itself is not changed.
        """
        if self._fitted:
            self.ERP.update_subsample(old_sub, new_sub)

    def _get_rest_cov(self, X, y):
        """Sample rest windows before event onsets and return their ERP covs.

        For every rising edge (0 -> 1) in ``y[:, 0]``, with ``i`` the index
        just before the onset, the window ``X[i - 2*window : i - window]``
        is used. Windows that would start before the beginning of X are
        skipped.
        """
        # Cast before differencing: np.diff on a boolean array computes
        # "not equal", which would flag falling edges as well.
        first_event = np.asarray(y[:, 0]).astype(int)
        onsets = np.where(np.diff(first_event) == 1)[0]
        rest = []
        for i in onsets:
            start = i - 2 * self.window
            stop = i - self.window
            if start < 0:
                # A negative start would index from the end of X, picking
                # unrelated samples or an empty slice.
                continue
            rest.append(self.ERP.erp_cov(X[start:stop].T))
        return np.array(rest)