[9c6ad1]: / common_spatial_pattern.py

Download this file

77 lines (59 with data), 2.3 kB

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
"""
Compute one-vs-rest common spatial pattern (CSP) filters for four-class classification.
"""
import numpy as np
import matplotlib.pyplot as plt
from numpy.linalg import eig
def _mean_normalized_cov(trials):
    """Return the mean trace-normalised spatial covariance of ``trials``.

    Each trial ``E = trials[k]`` contributes ``E.T @ E / trace(E.T @ E)``;
    axis 2 of ``trials`` is therefore the axis the covariance is taken over.
    """
    covs = np.einsum("nti,ntj->nij", trials, trials)
    traces = np.trace(covs, axis1=1, axis2=2)
    return np.mean(covs / traces[:, None, None], axis=0)


def csp(data_train, label_train, n_classes=4, n_filters=4):
    """Compute one-vs-rest common spatial pattern (CSP) filters.

    Each class ``c`` in ``0 .. n_classes - 1`` is handled as follows:

    - The trials labelled ``c`` are contrasted against the trials labelled
      with any other class in that range. Labels outside the range are
      ignored.
    - The two groups' averaged, trace-normalised covariances ``C_c`` and
      ``C_rest`` are jointly diagonalised.
    - ``P`` whitens ``C_c + C_rest``.
    - The eigenvectors ``B`` of ``P C_rest P^T`` (ascending eigenvalues)
      give the filters ``P^T B``.

    Because the whitened covariances sum to the identity, the first
    columns maximise the variance of class ``c`` relative to the rest.

    The textbook two-class CSP keeps the first and last ``r`` filters. Here
    the first ``2r`` (``n_filters``) filters of each one-vs-rest problem are
    kept instead. The original author observed no significant difference,
    with ``2r`` possibly slightly better.

    Parameters
    ----------
    data_train : array_like, shape (n_trials, n_samples, n_channels)
        Trials. Axis 2 is the spatial (channel) axis the covariances are
        computed over; axis 1 is presumably time samples.
    label_train : array_like, shape (n_trials,) or (n_trials, 1)
        Class label of each trial.
    n_classes : int, default 4
        Classes ``0 .. n_classes - 1`` are processed.
    n_filters : int, default 4
        Number of filters kept per class.

    Returns
    -------
    ndarray, shape (n_channels, n_classes * n_filters)
        Real-valued spatial filters as columns, grouped by class. Columns
        ``c * n_filters`` to ``(c + 1) * n_filters - 1`` belong to class
        ``c``. This assumes ``n_filters <= n_channels``; otherwise every
        available filter of each class is returned.

    Raises
    ------
    ValueError
        If a class or its complement has no trials.
    """
    data = np.asarray(data_train, dtype=float)
    labels = np.ravel(label_train)
    in_range = np.isin(labels, np.arange(n_classes))

    filters = []
    for target in range(n_classes):
        is_target = labels == target
        # flatnonzero always yields a 1-D index array, even for a single
        # trial (np.squeeze(np.where(...)) collapsed that case to 0-d).
        idx_target = np.flatnonzero(is_target)
        idx_rest = np.flatnonzero(in_range & ~is_target)
        if idx_target.size == 0 or idx_rest.size == 0:
            raise ValueError(
                "class {} or its complement has no training trials".format(target)
            )

        cov_target = _mean_normalized_cov(data[idx_target])
        cov_rest = _mean_normalized_cov(data[idx_rest])

        # Whitening P = diag(lam)^(-1/2) U^T, so P (C_c + C_rest) P^T = I.
        # eigh suits these symmetric matrices: it returns real eigenvalues
        # and orthonormal eigenvectors, which eig does not guarantee.
        # The row order of P does not change the final filters, so the
        # eigenpairs are not re-sorted here.
        lam, u = np.linalg.eigh(cov_target + cov_rest)
        whitening = (u / np.sqrt(lam)).T

        s_rest = whitening.dot(cov_rest).dot(whitening.T)
        # eigh returns ascending eigenvalues. Smallest "rest" variance comes
        # first, which is the largest target variance (they sum to 1).
        _, b = np.linalg.eigh(s_rest)
        w = whitening.T.dot(b)
        filters.append(w[:, :n_filters])

    return np.concatenate(filters, axis=1)