a b/preprocessing/covs_alex.py
1
# -*- coding: utf-8 -*-
2
"""
3
Covariance model by alex.
4
5
@author: Alexandre Barachant
6
"""
7
import numpy as np
8
from pyriemann.classification import MDM
9
from pyriemann.utils.mean import mean_covariance
10
from sklearn.base import BaseEstimator, TransformerMixin
11
from multiprocessing import Pool
12
from functools import partial
13
from mne.event import _find_stim_steps
14
15
16
def _step_samples(channel):
    """Return the sample indices at which an event channel changes value.

    Index ``k`` is reported when ``channel[k] != channel[k - 1]``. For a
    channel that starts at rest, even positions of the result are onsets and
    odd positions are offsets.
    """
    changed = np.diff(np.atleast_2d(channel), axis=1) != 0
    return (np.flatnonzero(np.all(changed, axis=0)) + 1).astype(np.int64)


def create_sequence(events, margin=250):
    """Create a sequence of non-overlapping states from event labels.

    Parameters
    ----------
    events : ndarray, shape (n_events, n_samples)
        One row per event channel. Rows 0, 1, 3, 4 and 5 are read; row 2 is
        not used. Each channel is assumed to start at its resting value so
        that every other step is an onset -- TODO confirm against the caller.
    margin : int, default 250
        Number of samples labeled before the first event onset (state 1) and
        after the last event offset (state 6) of every trial.

    Returns
    -------
    sequence : ndarray, shape (n_samples, 1)
        State index per sample: 0 outside of any trial, then 1..6 for the
        successive phases of a trial.

    Raises
    ------
    IndexError
        If a channel holds fewer steps than there are onsets in row 0.
    """
    sequence = np.zeros((events.shape[1], 1))

    # Onsets are the even steps, offsets the odd steps of each channel.
    hand_start = _step_samples(events[0])[::2]
    lift_on = _step_samples(events[1])[::2]
    lift_off = _step_samples(events[3])[1::2]
    replace_on = _step_samples(events[4])[::2]
    replace_off = _step_samples(events[5])[1::2]

    for i, start in enumerate(hand_start):
        # Clamp at 0: a negative slice start would wrap around to the end of
        # the array and leave the pre-movement window unlabeled.
        sequence[max(start - margin, 0):start] = 1
        sequence[start:lift_on[i]] = 2
        sequence[lift_on[i]:lift_off[i]] = 3
        sequence[lift_off[i]:replace_on[i]] = 4
        sequence[replace_on[i]:replace_off[i]] = 5
        # Slicing past the end of the array is clipped automatically.
        sequence[replace_off[i]:replace_off[i] + margin] = 6

    return sequence
53
54
55
class DistanceCalculatorAlex(BaseEstimator, TransformerMixin):

    """Distance calculator based on MDM.

    An ``MDM`` classifier is fitted on state labels produced by
    ``create_sequence``; ``transform`` then returns, for every requested
    distance metric, the MDM output of each class relative to the first
    class (column 0).

    Parameters
    ----------
    metric_mean : str, default 'logeuclid'
        Metric handed to ``MDM(metric=...)``.
    metric_dist : sequence of str or str, default ('riemann',)
        Distance metric(s) applied in ``transform``; the feature blocks of
        all metrics are concatenated column-wise.
    n_jobs : int, default 7
        Forwarded to ``MDM(n_jobs=...)``.
    subsample : int, default 10
        Step used to decimate the label sequence in ``fit``.
        NOTE(review): X is presumably already decimated by the same factor
        upstream -- confirm against the caller.
    """

    def __init__(self, metric_mean='logeuclid', metric_dist=('riemann',),
                 n_jobs=7, subsample=10):
        """Store the hyper-parameters unchanged."""
        # An immutable default avoids sharing one list between instances.
        self.metric_mean = metric_mean
        self.metric_dist = metric_dist
        self.n_jobs = n_jobs
        self.subsample = subsample

    def fit(self, X, y):
        """Fit the MDM on ``X`` using the state sequence derived from ``y``.

        ``y`` is transposed before being passed to ``create_sequence``, so
        it is expected as (n_samples, n_events). Returns ``self``.
        """
        self.mdm = MDM(metric=self.metric_mean, n_jobs=self.n_jobs)
        labels = np.squeeze(create_sequence(y.T)[::self.subsample])
        self.mdm.fit(X, labels)
        return self

    def transform(self, X, y=None):
        """Return the MDM distance features of ``X``.

        For each metric, column 0 of the MDM output is subtracted from the
        remaining columns. NaN entries are replaced by 0.
        """
        metrics = self.metric_dist
        if isinstance(metrics, str):
            # A bare string would otherwise be iterated character-wise.
            metrics = (metrics,)

        blocks = []
        for metric in metrics:
            self.mdm.metric_dist = metric
            dist = self.mdm.transform(X)
            # subtract the distance to class 0 from every other class
            blocks.append(dist[:, 1:] - dist[:, :1])

        features = np.concatenate(blocks, axis=1)
        features[np.isnan(features)] = 0
        return features

    def fit_transform(self, X, y):
        """Fit on ``(X, y)`` and return the transformed ``X``."""
        self.fit(X, y)
        return self.transform(X)
91
92
93
class DistanceCalculatorRafal(BaseEstimator, TransformerMixin):

    """Distance calculator based on MDM, Rafal style.

    For each of the first ``_N_EVENTS`` label columns, two mean covariance
    matrices are estimated: one over the samples where the label is 1 and
    one over the samples where it is 0. ``transform`` returns, per event,
    the MDM output for the "1" mean minus the one for the "0" mean.

    Parameters
    ----------
    metric_mean : str, default 'logeuclid'
        Metric used for the mean covariance estimation and for the MDM.
    metric_dist : sequence of str or str, default ('riemann',)
        Distance metric(s) applied in ``transform``; the feature blocks of
        all metrics are concatenated column-wise.
    n_jobs : int, default 12
        Forwarded to ``MDM(n_jobs=...)``.
    subsample : int, default 10
        Step used to decimate ``y`` in ``fit``.
        NOTE(review): X is presumably already decimated by the same factor
        upstream -- confirm against the caller.
    """

    # Number of binary event columns read from ``y``.
    _N_EVENTS = 6

    def __init__(self, metric_mean='logeuclid', metric_dist=('riemann',),
                 n_jobs=12, subsample=10):
        """Store the hyper-parameters unchanged."""
        # An immutable default avoids sharing one list between instances.
        self.metric_mean = metric_mean
        self.metric_dist = metric_dist
        self.n_jobs = n_jobs
        self.subsample = subsample

    def fit(self, X, y):
        """Estimate the per-event "on" and "off" mean covariances.

        ``y`` is indexed as ``y[::subsample][:, i]``, i.e. it must be a 2-D
        array with at least ``_N_EVENTS`` columns. Returns ``self``.
        """
        n_events = self._N_EVENTS
        self.mdm = MDM(metric=self.metric_mean, n_jobs=self.n_jobs)
        labels = y[::self.subsample]
        calc_mean = partial(mean_covariance, metric=self.metric_mean)

        # "on" subsets first, then "off" subsets, so the resulting list has
        # the layout [on_0..on_5, off_0..off_5] that transform relies on.
        subsets = [X[labels[:, i] == 1] for i in range(n_events)]
        subsets += [X[labels[:, i] == 0] for i in range(n_events)]

        pool = Pool(processes=n_events)
        try:
            means = pool.map(calc_mean, subsets)
        finally:
            # Always release the workers, even when the estimation fails.
            pool.close()
            pool.join()

        # NOTE(review): attribute name depends on the installed pyriemann
        # version (later releases use ``covmeans_``) -- confirm.
        self.mdm.covmeans = means
        return self

    def transform(self, X, y=None):
        """Return the MDM distance features of ``X``.

        For each metric, the columns belonging to the "off" means are
        subtracted from those of the "on" means. NaN entries are replaced
        by 0.
        """
        n_events = self._N_EVENTS
        metrics = self.metric_dist
        if isinstance(metrics, str):
            # A bare string would otherwise be iterated character-wise.
            metrics = (metrics,)

        blocks = []
        for metric in metrics:
            self.mdm.metric_dist = metric
            dist = self.mdm.transform(X)
            # subtract the distance to the "off" mean of the same event
            blocks.append(dist[:, 0:n_events] - dist[:, n_events:])

        features = np.concatenate(blocks, axis=1)
        features[np.isnan(features)] = 0
        return features

    def fit_transform(self, X, y):
        """Fit on ``(X, y)`` and return the transformed ``X``."""
        self.fit(X, y)
        return self.transform(X)