% featurebased-approach/TrainClassifier.m
function TrainClassifier(feature_file)
% TrainClassifier  Trains and cross-validates feature-based ECG classifiers.
%
% Loads a table of per-segment features, aggregates them into per-record
% summary statistics, then trains a bagged-tree ensemble and a shallow
% pattern-recognition neural network whose class probabilities are averaged.
% Performance is assessed with 5-fold cross-validation and the best-scoring
% models are saved to disk.
%
% Input:
%  - feature_file: MAT-file providing the variables
%       allfeats      - table of extracted features; column 1 is the record
%                       number, columns 3:end are the features
%                       (column 2 is assumed to be segment metadata — not used)
%       reference_tab - table whose 2nd column holds the reference label
%                       'A', 'N', 'O' or '~' for each record
%
% Files written:
%  - results_allfeat.mat : per-fold F1 scores (F1save) and best F1 (F1_best)
%  - ensTree.mat         : best bagged-tree ensemble (compact form)
%  - nNets.mat           : best neural network
%
% --
% ECG classification from single-lead segments using Deep Convolutional Neural
% Networks and Feature-Based Approaches - December 2017
%
% Released under the GNU General Public License
%
% Copyright (C) 2017 Fernando Andreotti, Oliver Carr
% University of Oxford, Institute of Biomedical Engineering, CIBIM Lab - Oxford 2017
% fernando.andreotti@eng.ox.ac.uk
%
%
% For more information visit: https://github.com/fernandoandreotti/cinc-challenge2017
%
% Referencing this work
%
% Andreotti, F., Carr, O., Pimentel, M.A.F., Mahdi, A., & De Vos, M. (2017).
% Comparing Feature Based Classifiers and Convolutional Neural Networks to Detect
% Arrhythmia from Short Segments of ECG. In Computing in Cardiology. Rennes (France).
%
% Last updated : December 2017
%
% This program is free software: you can redistribute it and/or modify
% it under the terms of the GNU General Public License as published by
% the Free Software Foundation, either version 3 of the License, or
% (at your option) any later version.
%
% This program is distributed in the hope that it will be useful,
% but WITHOUT ANY WARRANTY; without even the implied warranty of
% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
% GNU General Public License for more details.
%
% You should have received a copy of the GNU General Public License
% along with this program.  If not, see <http://www.gnu.org/licenses/>.

load(feature_file)              % provides allfeats and reference_tab
NFEAT = size(allfeats,2) - 2;   % first two columns are not features

% Aggregate the per-segment features of each record into 16 summary
% statistics per feature: mean, std, first two PCA coefficients, median,
% inter-quartile range, range, min, max, 25/50/75th percentiles, real and
% absolute parts of the Hilbert transform, skewness and kurtosis.
Nrec = max(allfeats.rec_number);
feat = zeros(Nrec,16*NFEAT);
for i = 1:Nrec
    fprintf('Processing record %d .. \n',i)
    ind = find(table2array(allfeats(:,1))==i);
    % Hoisted: the original recomputed this conversion for every statistic.
    M = table2array(allfeats(ind,3:end));   % segments-by-features matrix
    feat(i,1:NFEAT)         = nanmean(M);
    feat(i,NFEAT+1:2*NFEAT) = nanstd(M);
    if length(ind) > 2
        PCAn = pca(M);
        % ROBUSTNESS: pca can return fewer than 2 components for
        % rank-deficient input; the original would then error on PCAn(:,2).
        % Fall back to NaN, mirroring the too-few-segments branch below.
        if size(PCAn,2) >= 2
            feat(i,2*NFEAT+1:3*NFEAT) = PCAn(:,1);
            feat(i,3*NFEAT+1:4*NFEAT) = PCAn(:,2);
        else
            feat(i,2*NFEAT+1:4*NFEAT) = NaN;
        end
    else
        % Not enough segments for a meaningful PCA
        feat(i,2*NFEAT+1:4*NFEAT) = NaN;
    end
    feat(i,4*NFEAT+1:5*NFEAT)   = nanmedian(M);
    feat(i,5*NFEAT+1:6*NFEAT)   = iqr(M);
    feat(i,6*NFEAT+1:7*NFEAT)   = range(M);
    feat(i,7*NFEAT+1:8*NFEAT)   = min(M);
    feat(i,8*NFEAT+1:9*NFEAT)   = max(M);
    feat(i,9*NFEAT+1:10*NFEAT)  = prctile(M,25);
    feat(i,10*NFEAT+1:11*NFEAT) = prctile(M,50);
    feat(i,11*NFEAT+1:12*NFEAT) = prctile(M,75);
    HIL = hilbert(M);
    % NOTE(review): only the first row of the analytic signal is kept,
    % as in the original — i.e. the first segment's Hilbert coefficients.
    feat(i,12*NFEAT+1:13*NFEAT) = real(HIL(1,:));
    feat(i,13*NFEAT+1:14*NFEAT) = abs(HIL(1,:));
    feat(i,14*NFEAT+1:15*NFEAT) = skewness(M);
    feat(i,15*NFEAT+1:16*NFEAT) = kurtosis(M);
end

In = feat;
In(isnan(In)) = 0;              % classifiers below cannot handle NaNs
% Standardize each column to zero mean, unit variance
In = (In - mean(In))./std(In);

% Encode labels: one-hot rows against {'A','N','O','~'}, then to integers.
% bi2de maps the one-hot rows to 1 ('A'), 2 ('N'), 4 ('O'), 8 ('~'),
% which are remapped to the contiguous classes 1..4.
labels = {'A' 'N' 'O' '~'};
Out = reference_tab{:,2};
Outbi = cell2mat(cellfun(@(x) strcmp(x,labels),Out,'UniformOutput',0));
Outde = bi2de(Outbi);
Outde(Outde == 4) = 3;
Outde(Outde == 8) = 4;
clear Out
rng(1); % For reproducibility
%% Perform cross-validation
%== Subset sampling
k = 5;
cv = cvpartition(Outde,'kfold',k);          % stratified k-fold partition
confusion = zeros(4,4,k);
F1save = zeros(k,4);
F1_best = 0;
for i = 1:k
    fprintf('Cross-validation loop %d \n',i)
    trainidx = find(training(cv,i));
    trainidx = trainidx(randperm(length(trainidx)));  % shuffle training order
    testidx = find(test(cv,i));
    %% Bagged trees (oversampled)
    ens = fitensemble(In(trainidx,:),Outde(trainidx),'Bag',50,'Tree','type','classification');
    [~,probTree] = predict(ens,In(testidx,:));

    %% Neural networks (pattern net, 10 hidden units)
    net = patternnet(10);
    net = train(net,In(trainidx,:)',Outbi(trainidx,:)');
    probNN = net(In(testidx,:)')';

    %% Combining methods: average both probability estimates, take argmax
    C = mean(cat(3,probTree,probNN),3);
    [~,estimate] = max(C,[],2);             % vectorized argmax over classes
    confmat = confusionmat(Outde(testidx),estimate);
    confusion(:,:,i) = confmat;
    F1 = zeros(1,4);
    for j = 1:4
        F1(j)=2*confmat(j,j)/(sum(confmat(j,:))+sum(confmat(:,j)));
        fprintf('F1 measure for %s rhythm: %1.4f \n',labels{j},F1(j))
    end
    F1save(i,:) = F1;

    % BUGFIX: the original tested 'if F1 > F1_best' with F1 a 1x4 vector.
    % In MATLAB that is true only when ALL four per-class F1 values exceed
    % the stored best, so the "best" model was almost never updated after
    % the first fold. Compare the mean F1 instead (F1_best starts at 0,
    % whose mean is 0, so the first valid fold always wins).
    if mean(F1) > mean(F1_best)
        F1_best = F1;
        ensTree_best = compact(ens);
        nnet_best = net;
    end
end
%% Producing statistics (confusion matrices summed over folds)
confusion = sum(confusion,3);
F1 = zeros(1,4);
for i = 1:4
    F1(i)=2*confusion(i,i)/(sum(confusion(i,:))+sum(confusion(:,i)));
    fprintf('F1 measure for %s rhythm: %1.4f \n',labels{i},F1(i))
end
fprintf('Final F1 measure: %1.4f\n',mean(F1))

%% Save output
save('results_allfeat.mat','F1save','F1_best')
save('ensTree.mat','ensTree_best')
save('nNets.mat','nnet_best')