|
a |
|
b/CNN.m |
|
|
1 |
%% Setup of CNN based classifier |
|
|
2 |
% Fist option direct input to 2D CNN |
|
|
3 |
close all |
|
|
4 |
|
|
|
5 |
rootFolder = fullfile('C:\Users\tripats\Documents\Biomedical Signal Analysis\Grad Single Project\ECG'); |
|
|
6 |
|
|
|
7 |
categories = {'SinusTachycardia',... |
|
|
8 |
'SinusRhythm','SinusBradycardia',... |
|
|
9 |
'AtrialFibrilation',}; |
|
|
10 |
% categories = {'SupraventricularTachycardia','SinusTachycardia',... |
|
|
11 |
% 'SinusRhythm','SinusIrregularity','SinusBradycardia',... |
|
|
12 |
% 'SinusAtriumtoAtrialWanderingRhythm', 'AtrioventricularReentrantTachycardia',... |
|
|
13 |
% 'AtrioventricularNodeReentrantTachycardia','AtrialTachycardia',... |
|
|
14 |
% 'AtrialFlutter','AtrialFibrilation',}; |
|
|
15 |
|
|
|
16 |
|
|
|
17 |
imds = imageDatastore(fullfile(rootFolder,categories),'LabelSource','foldernames'); |
|
|
18 |
imds.ReadFcn = @dsresize; |
|
|
19 |
|
|
|
20 |
label_count = countEachLabel(imds) |
|
|
21 |
minCount = min(label_count{:,2}) |
|
|
22 |
|
|
|
23 |
imds = splitEachLabel(imds, minCount, "randomized") |
|
|
24 |
|
|
|
25 |
net = alexnet(); |
|
|
26 |
|
|
|
27 |
[trainingData,testData] = splitEachLabel(imds, 0.2, 'randomized'); |
|
|
28 |
|
|
|
29 |
layersTransfer = net.Layers(2:end-3); |
|
|
30 |
layers = [ |
|
|
31 |
imageInputLayer([256 256 3]); |
|
|
32 |
layersTransfer |
|
|
33 |
fullyConnectedLayer(numel(categories)) |
|
|
34 |
softmaxLayer |
|
|
35 |
classificationLayer]; |
|
|
36 |
|
|
|
37 |
plot(layerGraph(layers)); |
|
|
38 |
|
|
|
39 |
options = trainingOptions("sgdm",... |
|
|
40 |
"ExecutionEnvironment","parallel",... |
|
|
41 |
"InitialLearnRate",1e-3,... |
|
|
42 |
"MaxEpochs", 40,... |
|
|
43 |
"Shuffle","every-epoch",... |
|
|
44 |
"Plots","training-progress",... |
|
|
45 |
"ValidationData",testData,... |
|
|
46 |
"ValidationFrequency",5); |
|
|
47 |
|
|
|
48 |
[net, traininfo] = trainNetwork(trainingData, layers, options); |
|
|
49 |
|
|
|
50 |
function data = dsresize(filename) |
|
|
51 |
data = imread(filename); |
|
|
52 |
data = data(:,:,min(1:3, end)); |
|
|
53 |
data = imresize(data, [256 256]); |
|
|
54 |
end |