function TrainClassifier(feature_file)
% Trains and cross-validates feature-based classifiers (bagged decision
% trees and a shallow neural network) for ECG rhythm classification,
% saving the best-performing models and per-fold F1 scores to disk.
%
%  Input:
%       - feature_file:         file containing table with extracted
%                               features in different records
%       
% --
% ECG classification from single-lead segments using Deep Convolutional Neural 
% Networks and Feature-Based Approaches - December 2017
% 
% Released under the GNU General Public License
%
% Copyright (C) 2017  Fernando Andreotti, Oliver Carr
% University of Oxford, Institute of Biomedical Engineering, CIBIM Lab - Oxford 2017
% fernando.andreotti@eng.ox.ac.uk
%
% 
% For more information visit: https://github.com/fernandoandreotti/cinc-challenge2017
% 
% Referencing this work
%
% Andreotti, F., Carr, O., Pimentel, M.A.F., Mahdi, A., & De Vos, M. (2017). 
% Comparing Feature Based Classifiers and Convolutional Neural Networks to Detect 
% Arrhythmia from Short Segments of ECG. In Computing in Cardiology. Rennes (France).
%
% Last updated : December 2017
% 
% This program is free software: you can redistribute it and/or modify
% it under the terms of the GNU General Public License as published by
% the Free Software Foundation, either version 3 of the License, or
% (at your option) any later version.
% 
% This program is distributed in the hope that it will be useful,
% but WITHOUT ANY WARRANTY; without even the implied warranty of
% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
% GNU General Public License for more details.
% 
% You should have received a copy of the GNU General Public License
% along with this program.  If not, see <http://www.gnu.org/licenses/>.

% Expects the file to provide 'allfeats' (table: column 1 = record number,
% column 2 = segment index, columns 3:end = features) and 'reference_tab'
% (table whose second column holds the reference rhythm labels).
load(feature_file)
NFEAT = size(allfeats,2);
NFEAT = NFEAT - 2;   % discount the two bookkeeping columns (record/segment)

% Get summary statistics on the distribution of the features in each
% signal, using following:
% - mean / standard deviation
% - first two PCA coefficient vectors
% - median
% - inter quartile range
% - range
% - min value
% - max value
% - 25% percentile
% - 50% percentile
% - 75% percentile
% - Real coefficients of Hilbert transform 
% - Absolute values of Hilbert transform
% - Skewness
% - Kurtosis
% i.e. 16*NFEAT aggregated features per record.
feat = zeros(max(allfeats.rec_number),16*NFEAT);
for i = 1:max(allfeats.rec_number)
    fprintf('Processing record %d .. \n',i)
    ind = find(table2array(allfeats(:,1)) == i);
    feat(i,1:NFEAT) = nanmean(table2array(allfeats(ind,3:end)));
    feat(i,1*NFEAT+1:2*NFEAT) = nanstd(table2array(allfeats(ind,3:end)));
    % PCA needs more than two observations to yield two component vectors
    if length(ind) > 2
        PCAn = pca(table2array(allfeats(ind,3:end)));
        feat(i,2*NFEAT+1:3*NFEAT) = PCAn(:,1);
        feat(i,3*NFEAT+1:4*NFEAT) = PCAn(:,2);
    else
        feat(i,2*NFEAT+1:3*NFEAT) = NaN;
        feat(i,3*NFEAT+1:4*NFEAT) = NaN;
    end
    feat(i,4*NFEAT+1:5*NFEAT) = nanmedian(table2array(allfeats(ind,3:end)));
    feat(i,5*NFEAT+1:6*NFEAT) = iqr(table2array(allfeats(ind,3:end)));
    feat(i,6*NFEAT+1:7*NFEAT) = range(table2array(allfeats(ind,3:end)));
    feat(i,7*NFEAT+1:8*NFEAT) = min(table2array(allfeats(ind,3:end)));
    feat(i,8*NFEAT+1:9*NFEAT) = max(table2array(allfeats(ind,3:end)));
    feat(i,9*NFEAT+1:10*NFEAT) = prctile(table2array(allfeats(ind,3:end)),25);
    feat(i,10*NFEAT+1:11*NFEAT) = prctile(table2array(allfeats(ind,3:end)),50);
    feat(i,11*NFEAT+1:12*NFEAT) = prctile(table2array(allfeats(ind,3:end)),75);
    % hilbert() operates column-wise; only the first row (first segment) is kept
    HIL = hilbert(table2array(allfeats(ind,3:end)));
    feat(i,12*NFEAT+1:13*NFEAT) = real(HIL(1,:));
    feat(i,13*NFEAT+1:14*NFEAT) = abs(HIL(1,:));
    feat(i,14*NFEAT+1:15*NFEAT) = skewness(table2array(allfeats(ind,3:end)));
    feat(i,15*NFEAT+1:16*NFEAT) = kurtosis(table2array(allfeats(ind,3:end)));
end

In = feat;
In(isnan(In)) = 0;
% Standardizing input (z-score). Guard against constant columns whose
% standard deviation is zero, which would otherwise produce NaN features
% and poison both classifiers downstream.
In = In - mean(In);
sdIn = std(In);
sdIn(sdIn == 0) = 1;
In = In./sdIn;

% Encode reference labels: A -> 1, N -> 2, O -> 3, ~ -> 4
labels = {'A' 'N' 'O' '~'};
Out = reference_tab{:,2};
Outbi = cell2mat(cellfun(@(x) strcmp(x,labels),Out,'UniformOutput',0));
Outde = bi2de(Outbi);
Outde(Outde == 4) = 3;
Outde(Outde == 8) = 4;
clear Out
rng(1); % For reproducibility
%% Perform cross-validation
%== Subset sampling (stratified k-fold on the class labels)
k = 5;
cv = cvpartition(Outde,'kfold',k);
confusion = zeros(4,4,k);
F1save = zeros(k,4);
F1_best = 0;
for i = 1:k
    fprintf('Cross-validation loop %d \n',i)
    trainidx = find(training(cv,i));
    trainidx = trainidx(randperm(length(trainidx)));
    testidx  = find(test(cv,i));
    %% Bagged trees (oversampled)
    ens = fitensemble(In(trainidx,:),Outde(trainidx),'Bag',50,'Tree','type','classification');
    [~,probTree] = predict(ens,In(testidx,:));
    
    %% Neural networks
    net = patternnet(10);
    net = train(net,In(trainidx,:)',Outbi(trainidx,:)');            
    probNN = net(In(testidx,:)')';    
    
    %% Combining methods: average the class posteriors of both models,
    % then predict the class with the highest combined score.
    C = cat(3,probTree,probNN);
    C = mean(C,3);
    estimate = zeros(size(C,1),1);
    for r = 1:size(C,1)
        [~,estimate(r)] = max(C(r,:));
    end
    confmat = confusionmat(Outde(testidx),estimate);
    confusion(:,:,i) = confmat;    
    F1 = zeros(1,4);
    for j = 1:4
        F1(j) = 2*confmat(j,j)/(sum(confmat(j,:))+sum(confmat(:,j)));
        fprintf('F1 measure for %s rhythm: %1.4f \n',labels{j},F1(j))
    end
    F1save(i,:) = F1;
    
    % Keep the models from the fold with the best MEAN F1 score.
    % NOTE(fix): the previous 'if F1 > F1_best' compared a 1x4 vector to a
    % scalar — a vector condition is only true when ALL elements hold, so
    % the best model was almost never updated, and F1_best became a vector.
    if mean(F1) > F1_best
        F1_best = mean(F1);
        ensTree_best = compact(ens);
        nnet_best = net;
    end
end
%% Producing statistics (pooled over all folds)
confusion = sum(confusion,3);
F1 = zeros(1,4);
for i = 1:4
    F1(i) = 2*confusion(i,i)/(sum(confusion(i,:))+sum(confusion(:,i)));
    fprintf('F1 measure for %s rhythm: %1.4f \n',labels{i},F1(i))
end
fprintf('Final F1 measure:  %1.4f\n',mean(F1))

%% Save output
save('results_allfeat.mat','F1save','F1_best')
save('ensTree.mat','ensTree_best')
save('nNets.mat','nnet_best')