|
a |
|
b/preprocessing/erp.py |
|
|
1 |
# -*- coding: utf-8 -*- |
|
|
2 |
""" |
|
|
3 |
Created on Wed Jul 8 22:00:08 2015. |
|
|
4 |
|
|
|
5 |
@author: rc, alexandre |
|
|
6 |
""" |
|
|
7 |
import numpy as np |
|
|
8 |
from sklearn.base import BaseEstimator, TransformerMixin |
|
|
9 |
from mne.io import RawArray |
|
|
10 |
from mne.channels import read_montage |
|
|
11 |
from mne import create_info |
|
|
12 |
from mne import find_events, Epochs |
|
|
13 |
from mne.preprocessing import Xdawn |
|
|
14 |
from mne import compute_raw_data_covariance |
|
|
15 |
|
|
|
16 |
from pyriemann.utils.covariance import _lwf |
|
|
17 |
from pyriemann.classification import MDM |
|
|
18 |
|
|
|
19 |
from preprocessing.aux import getChannelNames, getEventNames, sliding_window |
|
|
20 |
|
|
|
21 |
|
|
|
22 |
def toMNE(X, y=None): |
|
|
23 |
"""Tranform array into MNE for epoching.""" |
|
|
24 |
ch_names = getChannelNames() |
|
|
25 |
montage = read_montage('standard_1005', ch_names) |
|
|
26 |
ch_type = ['eeg']*len(ch_names) |
|
|
27 |
data = X.T |
|
|
28 |
if y is not None: |
|
|
29 |
y = y.transpose() |
|
|
30 |
ch_type.extend(['stim']*6) |
|
|
31 |
event_names = getEventNames() |
|
|
32 |
ch_names.extend(event_names) |
|
|
33 |
# concatenate event file and data |
|
|
34 |
data = np.concatenate((data, y)) |
|
|
35 |
info = create_info(ch_names, sfreq=500.0, ch_types=ch_type, |
|
|
36 |
montage=montage) |
|
|
37 |
raw = RawArray(data, info, verbose=False) |
|
|
38 |
return raw |
|
|
39 |
|
|
|
40 |
|
|
|
41 |
def get_epochs_and_cov(X, y, window=500): |
|
|
42 |
"""return epochs from array.""" |
|
|
43 |
raw_train = toMNE(X, y) |
|
|
44 |
picks = range(len(getChannelNames())) |
|
|
45 |
|
|
|
46 |
events = list() |
|
|
47 |
events_id = dict() |
|
|
48 |
for j, eid in enumerate(getEventNames()): |
|
|
49 |
tmp = find_events(raw_train, stim_channel=eid, verbose=False) |
|
|
50 |
tmp[:, -1] = j + 1 |
|
|
51 |
events.append(tmp) |
|
|
52 |
events_id[eid] = j + 1 |
|
|
53 |
|
|
|
54 |
# concatenate and sort events |
|
|
55 |
events = np.concatenate(events, axis=0) |
|
|
56 |
order_ev = np.argsort(events[:, 0]) |
|
|
57 |
events = events[order_ev] |
|
|
58 |
|
|
|
59 |
epochs = Epochs(raw_train, events, events_id, |
|
|
60 |
tmin=-(window / 500.0) + 1 / 500.0 + 0.150, |
|
|
61 |
tmax=0.150, proj=False, picks=picks, baseline=None, |
|
|
62 |
preload=True, add_eeg_ref=False, verbose=False) |
|
|
63 |
|
|
|
64 |
cov_signal = compute_raw_data_covariance(raw_train, verbose=False) |
|
|
65 |
return epochs, cov_signal |
|
|
66 |
|
|
|
67 |
|
|
|
68 |
class ERP(BaseEstimator, TransformerMixin): |
|
|
69 |
|
|
|
70 |
"""ERP cov estimator. |
|
|
71 |
|
|
|
72 |
This is a transformer for estimating special form covariance matrices in |
|
|
73 |
the context of ERP detection [1,2]. For each class, the ERP is estimated by |
|
|
74 |
averaging all the epochs of a given class. The dimentionality of the ERP is |
|
|
75 |
reduced using a XDAWN algorithm and then concatenated with the epochs |
|
|
76 |
before estimation of the covariance matrices. |
|
|
77 |
|
|
|
78 |
References : |
|
|
79 |
[1] A. Barachant, M. Congedo ,"A Plug&Play P300 BCI Using Information |
|
|
80 |
Geometry", arXiv:1409.0107 |
|
|
81 |
[2] M. Congedo, A. Barachant, A. Andreev ,"A New generation of |
|
|
82 |
Brain-Computer Interface Based on Riemannian Geometry", arXiv: 1310.8115. |
|
|
83 |
""" |
|
|
84 |
|
|
|
85 |
def __init__(self, window=500, nfilters=3, subsample=1): |
|
|
86 |
"""Init.""" |
|
|
87 |
self.window = window |
|
|
88 |
self.nfilters = nfilters |
|
|
89 |
self.subsample = subsample |
|
|
90 |
|
|
|
91 |
def fit(self, X, y): |
|
|
92 |
"""fit.""" |
|
|
93 |
self._fit(X, y) |
|
|
94 |
return self |
|
|
95 |
|
|
|
96 |
def _fit(self, X, y): |
|
|
97 |
"""fit and return epochs.""" |
|
|
98 |
epochs, cov_signal = get_epochs_and_cov(X, y, self.window) |
|
|
99 |
|
|
|
100 |
xd = Xdawn(n_components=self.nfilters, signal_cov=cov_signal, |
|
|
101 |
correct_overlap=False) |
|
|
102 |
xd.fit(epochs) |
|
|
103 |
|
|
|
104 |
P = [] |
|
|
105 |
for eid in getEventNames(): |
|
|
106 |
P.append(np.dot(xd.filters_[eid][:, 0:self.nfilters].T, |
|
|
107 |
xd.evokeds_[eid].data)) |
|
|
108 |
self.P = np.concatenate(P, axis=0) |
|
|
109 |
self.labels_train = epochs.events[:, -1] |
|
|
110 |
return epochs |
|
|
111 |
|
|
|
112 |
def transform(self, X, y=None): |
|
|
113 |
"""Transform.""" |
|
|
114 |
test_cov = sliding_window(X.T, window=self.window, |
|
|
115 |
subsample=self.subsample, |
|
|
116 |
estimator=self.erp_cov) |
|
|
117 |
return test_cov |
|
|
118 |
|
|
|
119 |
def fit_transform(self, X, y): |
|
|
120 |
"""Fit and transform.""" |
|
|
121 |
epochs = self._fit(X, y) |
|
|
122 |
train_cov = np.array([self.erp_cov(ep) for ep in epochs.get_data()]) |
|
|
123 |
return train_cov |
|
|
124 |
|
|
|
125 |
def erp_cov(self, X): |
|
|
126 |
"""Compute ERP covariances.""" |
|
|
127 |
data = np.concatenate((self.P, X), axis=0) |
|
|
128 |
return _lwf(data) |
|
|
129 |
|
|
|
130 |
def update_subsample(self, old_sub, new_sub): |
|
|
131 |
"""update subsampling.""" |
|
|
132 |
self.subsample = new_sub |
|
|
133 |
|
|
|
134 |
|
|
|
135 |
class ERPDistance(BaseEstimator, TransformerMixin): |
|
|
136 |
|
|
|
137 |
"""ERP distance cov estimator. |
|
|
138 |
|
|
|
139 |
This transformer estimates Riemannian distance for ERP covariance matrices. |
|
|
140 |
After estimation of special form ERP covariance matrices using the ERP |
|
|
141 |
transformer, a MDM [1] algorithm is used to compute Riemannian distance. |
|
|
142 |
|
|
|
143 |
References: |
|
|
144 |
[1] A. Barachant, S. Bonnet, M. Congedo and C. Jutten, "Multiclass |
|
|
145 |
Brain-Computer Interface Classification by Riemannian Geometry," in IEEE |
|
|
146 |
Transactions on Biomedical Engineering, vol. 59, no. 4, p. 920-928, 2012 |
|
|
147 |
""" |
|
|
148 |
|
|
|
149 |
def __init__(self, window=500, nfilters=3, subsample=1, metric='riemann', |
|
|
150 |
n_jobs=1): |
|
|
151 |
"""Init.""" |
|
|
152 |
self.window = window |
|
|
153 |
self.nfilters = nfilters |
|
|
154 |
self.subsample = subsample |
|
|
155 |
self.metric = metric |
|
|
156 |
self.n_jobs = n_jobs |
|
|
157 |
self._fitted = False |
|
|
158 |
|
|
|
159 |
def fit(self, X, y): |
|
|
160 |
"""fit.""" |
|
|
161 |
# Create ERP and get cov mat |
|
|
162 |
self.ERP = ERP(self.window, self.nfilters, self.subsample) |
|
|
163 |
train_cov = self.ERP.fit_transform(X, y) |
|
|
164 |
labels_train = self.ERP.labels_train |
|
|
165 |
|
|
|
166 |
# Add rest epochs |
|
|
167 |
rest_cov = self._get_rest_cov(X, y) |
|
|
168 |
train_cov = np.concatenate((train_cov, rest_cov), axis=0) |
|
|
169 |
labels_train = np.concatenate((labels_train, [0] * len(rest_cov))) |
|
|
170 |
|
|
|
171 |
# fit MDM |
|
|
172 |
self.MDM = MDM(metric=self.metric, n_jobs=self.n_jobs) |
|
|
173 |
self.MDM.fit(train_cov, labels_train) |
|
|
174 |
self._fitted = True |
|
|
175 |
return self |
|
|
176 |
|
|
|
177 |
def transform(self, X, y=None): |
|
|
178 |
"""Transform.""" |
|
|
179 |
test_cov = self.ERP.transform(X) |
|
|
180 |
dist = self.MDM.transform(test_cov) |
|
|
181 |
dist = dist[:, 1:] - np.atleast_2d(dist[:, 0]).T |
|
|
182 |
return dist |
|
|
183 |
|
|
|
184 |
def update_subsample(self, old_sub, new_sub): |
|
|
185 |
"""update subsampling.""" |
|
|
186 |
if self._fitted: |
|
|
187 |
self.ERP.update_subsample(old_sub, new_sub) |
|
|
188 |
|
|
|
189 |
def _get_rest_cov(self, X, y): |
|
|
190 |
"""Sample rest epochs from data and compute the cov mat.""" |
|
|
191 |
ix = np.where(np.diff(y[:, 0]) == 1)[0] |
|
|
192 |
rest = [] |
|
|
193 |
offset = - self.window |
|
|
194 |
for i in ix: |
|
|
195 |
start = i + offset - self.window |
|
|
196 |
stop = i + offset |
|
|
197 |
rest.append(self.ERP.erp_cov(X[slice(start, stop)].T)) |
|
|
198 |
return np.array(rest) |