[24d692]: / CNN.m

Download this file

54 lines (41 with data), 1.7 kB

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
%% Setup of CNN based classifier
% First option: direct input of ECG images to a 2D CNN, reusing the
% pretrained AlexNet layers (transfer learning).
%
% Workflow:
%   1. Build an imageDatastore labelled by the sub-folder names listed in
%      `categories` under `rootFolder`; every image is read through
%      dsresize (local function at the end of this file).
%   2. Balance the classes by keeping `minCount` randomly chosen images per
%      label, then split each label into training and validation sets.
%   3. Put a new 256x256x3 input layer in front of AlexNet's layers
%      2..end-3, append a new fully connected / softmax / classification
%      head sized to the number of categories, and train with SGDM.
close all

% NOTE(review): machine-specific absolute path; point this at the local ECG
% image folder (one sub-folder per category).
rootFolder = fullfile('C:\Users\tripats\Documents\Biomedical Signal Analysis\Grad Single Project\ECG');
% Each entry must match a sub-folder name on disk exactly (including the
% 'AtrialFibrilation' spelling).
categories = {'SinusTachycardia', ...
    'SinusRhythm', 'SinusBradycardia', ...
    'AtrialFibrilation'};
% categories = {'SupraventricularTachycardia','SinusTachycardia',...
% 'SinusRhythm','SinusIrregularity','SinusBradycardia',...
% 'SinusAtriumtoAtrialWanderingRhythm', 'AtrioventricularReentrantTachycardia',...
% 'AtrioventricularNodeReentrantTachycardia','AtrialTachycardia',...
% 'AtrialFlutter','AtrialFibrilation',};
imds = imageDatastore(fullfile(rootFolder, categories), 'LabelSource', 'foldernames');
% Every image is read (and resized to 256x256x3) by dsresize.
imds.ReadFcn = @dsresize;

% Balance the classes: keep the smallest per-label count for every label.
% These lines are deliberately left unterminated so the results are echoed.
label_count = countEachLabel(imds)
minCount = min(label_count{:,2})
imds = splitEachLabel(imds, minCount, "randomized")

net = alexnet();

% splitEachLabel assigns the given proportion of each label to its FIRST
% output. Passing 0.2 (as before) trained on only 20% of the images and
% validated on the other 80%; train on 80% and hold out 20% instead.
trainFraction = 0.8;
[trainingData, testData] = splitEachLabel(imds, trainFraction, 'randomized');

% Drop AlexNet's first layer (its own input layer) and its last three layers
% (presumably the 1000-class fc8 / softmax / output head -- see the alexnet
% documentation); the remaining layers are reused as-is.
layersTransfer = net.Layers(2:end-3);

% Must match the image size produced by dsresize.
inputSize = [256 256 3];
layers = [
    imageInputLayer(inputSize)
    layersTransfer
    fullyConnectedLayer(numel(categories))
    softmaxLayer
    classificationLayer];
plot(layerGraph(layers));

% NOTE(review): "parallel" needs Parallel Computing Toolbox / a parallel pool.
options = trainingOptions("sgdm", ...
    "ExecutionEnvironment", "parallel", ...
    "InitialLearnRate", 1e-3, ...
    "MaxEpochs", 40, ...
    "Shuffle", "every-epoch", ...
    "Plots", "training-progress", ...
    "ValidationData", testData, ...
    "ValidationFrequency", 5);
[net, traininfo] = trainNetwork(trainingData, layers, options);
function data = dsresize(filename)
%DSRESIZE Read an image file and return it as a 256x256x3 array.
%   DATA = DSRESIZE(FILENAME) is used as the ReadFcn of the image datastore.
%   Indexed (colormap) images are first converted to truecolor; without
%   this, imread returns colormap indices, which would be misinterpreted as
%   grayscale intensities. One-channel images are replicated into three
%   channels, and channels beyond the third are dropped, before resizing to
%   256x256 with imresize.
[data, map] = imread(filename);
if ~isempty(map)
    % ind2rgb returns doubles in [0,1]; convert back to uint8 so indexed
    % files yield the same class as ordinary 8-bit images.
    data = im2uint8(ind2rgb(data, map));
end
% min(1:3, end) is [1 1 1] for a single channel (grayscale -> RGB) and
% [1 2 3] when there are three or more channels (extra channels dropped).
data = data(:,:,min(1:3, end));
data = imresize(data, [256 256]);
end