|
a |
|
b/preprocessing/covs_alex.py |
|
|
1 |
# -*- coding: utf-8 -*- |
|
|
2 |
""" |
|
|
3 |
Covariance model by alex. |
|
|
4 |
|
|
|
5 |
@author: Alexandre Barachant |
|
|
6 |
""" |
|
|
7 |
import numpy as np |
|
|
8 |
from pyriemann.classification import MDM |
|
|
9 |
from pyriemann.utils.mean import mean_covariance |
|
|
10 |
from sklearn.base import BaseEstimator, TransformerMixin |
|
|
11 |
from multiprocessing import Pool |
|
|
12 |
from functools import partial |
|
|
13 |
from mne.event import _find_stim_steps |
|
|
14 |
|
|
|
15 |
|
|
|
16 |
def create_sequence(events): |
|
|
17 |
"""create sequence from events. |
|
|
18 |
|
|
|
19 |
Create a sequence of non-overlapped States from labels. |
|
|
20 |
""" |
|
|
21 |
# init variable |
|
|
22 |
sequence = np.zeros((events.shape[1], 1)) |
|
|
23 |
|
|
|
24 |
# get hand start |
|
|
25 |
handStart = np.int64(_find_stim_steps(np.atleast_2d(events[0]), 0)[::2, 0]) |
|
|
26 |
|
|
|
27 |
# lift |
|
|
28 |
lift_on = np.int64(_find_stim_steps(np.atleast_2d(events[1]), 0)[::2, 0]) |
|
|
29 |
lift_off = np.int64(_find_stim_steps(np.atleast_2d(events[3]), 0)[1::2, 0]) |
|
|
30 |
|
|
|
31 |
# replace |
|
|
32 |
replace_on = np.int64(_find_stim_steps(np.atleast_2d(events[4]), 0) |
|
|
33 |
[::2, 0]) |
|
|
34 |
replace_off = np.int64(_find_stim_steps(np.atleast_2d(events[5]), 0) |
|
|
35 |
[1::2, 0]) |
|
|
36 |
|
|
|
37 |
for i in range(len(handStart)): |
|
|
38 |
j = 1 |
|
|
39 |
sequence[(handStart[i] - 250):handStart[i]] = j |
|
|
40 |
j += 1 |
|
|
41 |
sequence[handStart[i]:lift_on[i]] = j |
|
|
42 |
j += 1 |
|
|
43 |
sequence[lift_on[i]:lift_off[i]] = j |
|
|
44 |
j += 1 |
|
|
45 |
sequence[lift_off[i]:replace_on[i]] = j |
|
|
46 |
j += 1 |
|
|
47 |
sequence[replace_on[i]:replace_off[i]] = j |
|
|
48 |
j += 1 |
|
|
49 |
sequence[replace_off[i]:(replace_off[i] + 250)] = j |
|
|
50 |
j += 1 |
|
|
51 |
|
|
|
52 |
return sequence |
|
|
53 |
|
|
|
54 |
|
|
|
55 |
class DistanceCalculatorAlex(BaseEstimator, TransformerMixin): |
|
|
56 |
|
|
|
57 |
"""Distance Calulator Based on MDM.""" |
|
|
58 |
|
|
|
59 |
def __init__(self, metric_mean='logeuclid', metric_dist=['riemann'], |
|
|
60 |
n_jobs=7, subsample=10): |
|
|
61 |
"""Init.""" |
|
|
62 |
self.metric_mean = metric_mean |
|
|
63 |
self.metric_dist = metric_dist |
|
|
64 |
self.n_jobs = n_jobs |
|
|
65 |
self.subsample = subsample |
|
|
66 |
|
|
|
67 |
def fit(self, X, y): |
|
|
68 |
"""Fit.""" |
|
|
69 |
self.mdm = MDM(metric=self.metric_mean, n_jobs=self.n_jobs) |
|
|
70 |
labels = np.squeeze(create_sequence(y.T)[::self.subsample]) |
|
|
71 |
self.mdm.fit(X, labels) |
|
|
72 |
return self |
|
|
73 |
|
|
|
74 |
def transform(self, X, y=None): |
|
|
75 |
"""Transform.""" |
|
|
76 |
feattr = [] |
|
|
77 |
for metric in self.metric_dist: |
|
|
78 |
self.mdm.metric_dist = metric |
|
|
79 |
feat = self.mdm.transform(X) |
|
|
80 |
# substract distance of the class 0 |
|
|
81 |
feat = feat[:, 1:] - np.atleast_2d(feat[:, 0]).T |
|
|
82 |
feattr.append(feat) |
|
|
83 |
feattr = np.concatenate(feattr, axis=1) |
|
|
84 |
feattr[np.isnan(feattr)] = 0 |
|
|
85 |
return feattr |
|
|
86 |
|
|
|
87 |
def fit_transform(self, X, y): |
|
|
88 |
"""Fit and transform.""" |
|
|
89 |
self.fit(X, y) |
|
|
90 |
return self.transform(X) |
|
|
91 |
|
|
|
92 |
|
|
|
93 |
class DistanceCalculatorRafal(BaseEstimator, TransformerMixin): |
|
|
94 |
|
|
|
95 |
"""Distance Calulator Based on MDM Rafal style.""" |
|
|
96 |
|
|
|
97 |
def __init__(self, metric_mean='logeuclid', metric_dist=['riemann'], |
|
|
98 |
n_jobs=12, subsample=10): |
|
|
99 |
"""Init.""" |
|
|
100 |
self.metric_mean = metric_mean |
|
|
101 |
self.metric_dist = metric_dist |
|
|
102 |
self.n_jobs = n_jobs |
|
|
103 |
self.subsample = subsample |
|
|
104 |
|
|
|
105 |
def fit(self, X, y): |
|
|
106 |
"""Fit.""" |
|
|
107 |
self.mdm = MDM(metric=self.metric_mean, n_jobs=self.n_jobs) |
|
|
108 |
labels = y[::self.subsample] |
|
|
109 |
pCalcMeans = partial(mean_covariance, metric=self.metric_mean) |
|
|
110 |
pool = Pool(processes=6) |
|
|
111 |
mc1 = pool.map(pCalcMeans, [X[labels[:, i] == 1] for i in range(6)]) |
|
|
112 |
pool.close() |
|
|
113 |
pool = Pool(processes=6) |
|
|
114 |
mc0 = pool.map(pCalcMeans, [X[labels[:, i] == 0] for i in range(6)]) |
|
|
115 |
pool.close() |
|
|
116 |
self.mdm.covmeans = mc1 + mc0 |
|
|
117 |
return self |
|
|
118 |
|
|
|
119 |
def transform(self, X, y=None): |
|
|
120 |
"""Transform.""" |
|
|
121 |
feattr = [] |
|
|
122 |
for metric in self.metric_dist: |
|
|
123 |
self.mdm.metric_dist = metric |
|
|
124 |
feat = self.mdm.transform(X) |
|
|
125 |
# substract distance of the class 0 |
|
|
126 |
feat = feat[:, 0:6] - feat[:, 6:] |
|
|
127 |
feattr.append(feat) |
|
|
128 |
feattr = np.concatenate(feattr, axis=1) |
|
|
129 |
feattr[np.isnan(feattr)] = 0 |
|
|
130 |
return feattr |
|
|
131 |
|
|
|
132 |
def fit_transform(self, X, y): |
|
|
133 |
"""Fit and transform.""" |
|
|
134 |
self.fit(X, y) |
|
|
135 |
return self.transform(X) |