function TrainClassifier(feature_file)
% Trains and cross-validates feature-based classifiers for ECG rhythms.
%
% Loads previously extracted per-segment features, aggregates them into one
% summary-statistic vector per record, then trains an ensemble of bagged
% trees and a shallow neural network under 5-fold cross-validation. The
% best-performing models (by mean F1 across the four classes) are saved.
%
% Input:
%  - feature_file:   .mat file containing:
%                      'allfeats'      table of per-segment features
%                                      (col 1 = rec_number, features from
%                                      col 3 on; col 2 is metadata)
%                      'reference_tab' table whose 2nd column holds the
%                                      reference labels 'A','N','O','~'
%
% --
% ECG classification from single-lead segments using Deep Convolutional Neural
% Networks and Feature-Based Approaches - December 2017
%
% Released under the GNU General Public License
%
% Copyright (C) 2017 Fernando Andreotti, Oliver Carr
% University of Oxford, Institute of Biomedical Engineering, CIBIM Lab - Oxford 2017
% fernando.andreotti@eng.ox.ac.uk
%
%
% For more information visit: https://github.com/fernandoandreotti/cinc-challenge2017
%
% Referencing this work
%
%   Andreotti, F., Carr, O., Pimentel, M.A.F., Mahdi, A., & De Vos, M. (2017).
%   Comparing Feature Based Classifiers and Convolutional Neural Networks to Detect
%   Arrhythmia from Short Segments of ECG. In Computing in Cardiology. Rennes (France).
%
% Last updated : December 2017
%
% This program is free software: you can redistribute it and/or modify
% it under the terms of the GNU General Public License as published by
% the Free Software Foundation, either version 3 of the License, or
% (at your option) any later version.
%
% This program is distributed in the hope that it will be useful,
% but WITHOUT ANY WARRANTY; without even the implied warranty of
% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
% GNU General Public License for more details.
%
% You should have received a copy of the GNU General Public License
% along with this program.  If not, see <http://www.gnu.org/licenses/>.

load(feature_file)              % provides 'allfeats' and 'reference_tab'
NFEAT = size(allfeats,2);
NFEAT = NFEAT - 2;              % first two table columns are not features

% Get summary statistics on the distribution of the features in each
% signal, using following:
% - mean / standard deviation
% - first two PCA loading vectors
% - median
% - inter quartile range
% - range
% - min / max value
% - 25% / 50% / 75% percentile
% - Real coefficients of Hilbert transform
% - Absolute values of Hilbert transform
% - Skewness
% - Kurtosis
feat = zeros(max(allfeats.rec_number),16*NFEAT);
for i = 1:max(allfeats.rec_number)
    fprintf('Processing record %d .. \n',i)
    ind = find(table2array(allfeats(:,1)) == i);
    % Hoisted loop-invariant conversion: segments-by-features matrix of record i
    X = table2array(allfeats(ind,3:end));
    feat(i,1:NFEAT)           = nanmean(X);
    feat(i,1*NFEAT+1:2*NFEAT) = nanstd(X);
    if length(ind) > 2
        % PCA loadings of the first two principal components
        PCAn = pca(X);
        feat(i,2*NFEAT+1:3*NFEAT) = PCAn(:,1);
        feat(i,3*NFEAT+1:4*NFEAT) = PCAn(:,2);
    else
        % Too few segments for PCA; mark as missing (zeroed before training)
        feat(i,2*NFEAT+1:3*NFEAT) = NaN;
        feat(i,3*NFEAT+1:4*NFEAT) = NaN;
    end
    feat(i,4*NFEAT+1:5*NFEAT)   = nanmedian(X);
    feat(i,5*NFEAT+1:6*NFEAT)   = iqr(X);
    feat(i,6*NFEAT+1:7*NFEAT)   = range(X);
    feat(i,7*NFEAT+1:8*NFEAT)   = min(X);
    feat(i,8*NFEAT+1:9*NFEAT)   = max(X);
    feat(i,9*NFEAT+1:10*NFEAT)  = prctile(X,25);
    feat(i,10*NFEAT+1:11*NFEAT) = prctile(X,50);
    feat(i,11*NFEAT+1:12*NFEAT) = prctile(X,75);
    HIL = hilbert(X);
    % First row of the analytic signal (per-feature Hilbert coefficients)
    feat(i,12*NFEAT+1:13*NFEAT) = real(HIL(1,:));
    feat(i,13*NFEAT+1:14*NFEAT) = abs(HIL(1,:));
    feat(i,14*NFEAT+1:15*NFEAT) = skewness(X);
    feat(i,15*NFEAT+1:16*NFEAT) = kurtosis(X);
end

In = feat;
In(isnan(In)) = 0;
% Standardizing input (zero mean, unit variance per column)
In = In - mean(In);
sd = std(In);
sd(sd == 0) = 1;                % FIX: avoid 0/0 -> NaN on constant features
In = In./sd;

labels = {'A' 'N' 'O' '~'};
Out = reference_tab{:,2};
% One-hot encode labels, then map the binary rows (1,2,4,8) to classes 1..4
Outbi = cell2mat(cellfun(@(x) strcmp(x,labels),Out,'UniformOutput',0));
Outde = bi2de(Outbi);
Outde(Outde == 4) = 3;
Outde(Outde == 8) = 4;
clear Out
rng(1); % For reproducibility

%% Perform cross-validation
%== Subset sampling
k = 5;
cv = cvpartition(Outde,'kfold',k);
confusion = zeros(4,4,k);
F1save = zeros(k,4);
F1_best = 0;                    % best MEAN F1 over the four classes (scalar)
for i = 1:k
    fprintf('Cross-validation loop %d \n',i)
    trainidx = find(training(cv,i));
    trainidx = trainidx(randperm(length(trainidx)));
    testidx = find(test(cv,i));

    %% Bagged trees
    ens = fitensemble(In(trainidx,:),Outde(trainidx),'Bag',50,'Tree','type','classification');
    [~,probTree] = predict(ens,In(testidx,:));

    %% Neural networks
    net = patternnet(10);
    net = train(net,In(trainidx,:)',Outbi(trainidx,:)');
    probNN = net(In(testidx,:)')';

    %% Combining methods: average the class posteriors, pick the argmax
    C = mean(cat(3,probTree,probNN),3);
    [~,estimate] = max(C,[],2);

    confmat = confusionmat(Outde(testidx),estimate);
    confusion(:,:,i) = confmat;
    F1 = zeros(1,4);
    for j = 1:4
        F1(j) = 2*confmat(j,j)/(sum(confmat(j,:))+sum(confmat(:,j)));
        fprintf('F1 measure for %s rhythm: %1.4f \n',labels{j},F1(j))
    end
    F1save(i,:) = F1;

    % FIX: the original compared the 1x4 vector F1 against the scalar
    % F1_best; in MATLAB 'if' on a vector is true only when ALL elements
    % are true, so the best model was (almost) never updated. Track the
    % mean F1 instead.
    if mean(F1) > F1_best
        F1_best = mean(F1);
        ensTree_best = compact(ens);
        nnet_best = net;
    end
end

%% Producing statistics (pooled confusion matrix over all folds)
confusion = sum(confusion,3);
F1 = zeros(1,4);
for i = 1:4
    F1(i) = 2*confusion(i,i)/(sum(confusion(i,:))+sum(confusion(:,i)));
    fprintf('F1 measure for %s rhythm: %1.4f \n',labels{i},F1(i))
end
fprintf('Final F1 measure: %1.4f\n',mean(F1))

%% Save output
save('results_allfeat.mat','F1save','F1_best')
save('ensTree.mat','ensTree_best')
save('nNets.mat','nnet_best')