--- a
+++ b/common_spatial_pattern.py
@@ -0,0 +1,76 @@
+"""
+Compute one-vs-rest common spatial pattern (CSP) filters for four-class classification.
+"""
+
+
+import numpy as np
+import matplotlib.pyplot as plt
+from numpy.linalg import eig
+
+
+def csp(data_train, label_train, n_classes=4, n_filters=4):
+    """Compute one-vs-rest CSP spatial filters and stack them column-wise.
+
+    For each class k in ``range(n_classes)``, the trials labelled k are
+    contrasted with the pooled trials of every other class in that range;
+    trials carrying any other label are ignored.
+
+    Args:
+        data_train: array of shape (n_trials, n_samples, n_channels). The
+            channel axis is last: each trial E contributes the
+            (n_channels, n_channels) matrix E.T @ E.
+        label_train: one class label per trial (any shape; it is flattened).
+        n_classes: number of classes, labelled 0 .. n_classes - 1.
+        n_filters: number of filters kept per class.
+
+    Returns:
+        Array of shape (n_channels, n_classes * min(n_filters, n_channels)).
+        Each class contributes a contiguous group of columns (class 0 first),
+        ordered by decreasing ratio of class-k variance to composite
+        variance. Filter signs are arbitrary.
+
+    Raises:
+        ValueError: if data_train is not 3-D, if there is not exactly one
+            label per trial, or if a class or its complement has no trials.
+    """
+    data = np.asarray(data_train, dtype=float)
+    labels = np.asarray(label_train).ravel()
+    if data.ndim != 3:
+        raise ValueError("data_train must have shape (n_trials, n_samples, n_channels)")
+    if labels.shape[0] != data.shape[0]:
+        raise ValueError("label_train must contain exactly one label per trial")
+
+    # Trace-normalised covariance of every trial, shape
+    # (n_trials, n_channels, n_channels); computed once, reused per class.
+    covs = np.einsum("tsc,tsd->tcd", data, data)
+    covs /= np.trace(covs, axis1=1, axis2=2)[:, None, None]
+
+    filters = []
+    for k in range(n_classes):
+        in_class = labels == k
+        # "Rest" means the trials of the *other* classes 0 .. n_classes - 1.
+        in_rest = (labels != k) & np.isin(labels, np.arange(n_classes))
+        if not in_class.any() or not in_rest.any():
+            raise ValueError(f"class {k} or its complement has no trials")
+        cov_k = covs[in_class].mean(axis=0)
+        cov_rest = covs[in_rest].mean(axis=0)
+
+        # Whitening transform P = diag(lam)^(-1/2) U^T of the composite
+        # covariance; eigh suits these symmetric matrices (real, orthonormal).
+        # NOTE(review): if the composite covariance is singular, lam holds 0
+        # (or tiny negative round-off) and the filters become inf/NaN.
+        lam, U = np.linalg.eigh(cov_k + cov_rest)
+        P = (U / np.sqrt(lam)).T
+
+        # After whitening, the class-k and rest parts sum to the identity,
+        # so the smallest eigenvalues of the rest part (eigh sorts them
+        # ascending) mark the directions of largest class-k variance.
+        _, B = np.linalg.eigh(P @ cov_rest @ P.T)
+        filters.append((P.T @ B)[:, :n_filters])
+
+    # Classic CSP keeps the first and last r filters; here the first 2r
+    # (= n_filters) are kept per class instead. The author reports no
+    # significant difference, with 2r possibly slightly better.
+    return np.concatenate(filters, axis=1)
+
+