Switch to unified view

a b/Feature Extraction/learnCSP.m
1
function CSPMatrix = learnCSP(EEGSignals,classLabels)
2
%
3
%Input:
4
%EEGSignals: the training EEG signals, composed of 2 classes. These signals
5
%are a structure such that:
6
%   EEGSignals.x: the EEG signals as a [Ns * Nc * Nt] Matrix where
7
%       Ns: number of EEG samples per trial
8
%       Nc: number of channels (EEG electrodes)
9
%       nT: number of trials
10
%   EEGSignals.y: a [1 * Nt] vector containing the class labels for each trial
11
%   EEGSignals.s: the sampling frequency (in Hz)
12
%
13
%Output:
14
%CSPMatrix: the learnt CSP filters (a [Nc*Nc] matrix with the filters as rows)
15
%
16
%See also: extractCSPFeatures
17
18
%check and initializations
19
nbChannels = size(EEGSignals.x,2);
20
nbTrials = size(EEGSignals.x,3);
21
nbClasses = length(classLabels);
22
23
if nbClasses ~= 2
24
    disp('ERROR! CSP can only be used for two classes');
25
    return;
26
end
27
28
covMatrices = cell(nbClasses,1); %the covariance matrices for each class
29
30
%% Computing the normalized covariance matrices for each trial
31
trialCov = zeros(nbChannels,nbChannels,nbTrials);
32
for t=1:nbTrials
33
    E = EEGSignals.x(:,:,t)';                       %note the transpose
34
    EE = E * E';
35
    trialCov(:,:,t) = EE ./ trace(EE);
36
end
37
clear E;
38
clear EE;
39
40
%computing the covariance matrix for each class
41
for c=1:nbClasses      
42
    covMatrices{c} = mean(trialCov(:,:,EEGSignals.y == classLabels(c)),3); %EEGSignals.y==classLabels(c) returns the indeces corresponding to the class labels  
43
end
44
45
%the total covariance matrix
46
covTotal = covMatrices{1} + covMatrices{2};
47
48
%whitening transform of total covariance matrix
49
[Ut Dt] = eig(covTotal); %caution: the eigenvalues are initially in increasing order
50
eigenvalues = diag(Dt);
51
[eigenvalues egIndex] = sort(eigenvalues, 'descend');
52
Ut = Ut(:,egIndex);
53
P = diag(sqrt(1./eigenvalues)) * Ut';
54
55
%transforming covariance matrix of first class using P
56
transformedCov1 =  P * covMatrices{1} * P';
57
58
%EVD of the transformed covariance matrix
59
[U1 D1] = eig(transformedCov1);
60
eigenvalues = diag(D1);
61
[eigenvalues egIndex] = sort(eigenvalues, 'descend');
62
U1 = U1(:, egIndex);
63
CSPMatrix = U1' * P;