% test_cnn.m
% Transfer-learning script: fine-tunes a pretrained GoogLeNet to classify
% ECG images into the classes listed in `categories`. Images are read from
% one sub-folder per class under rootFolder; the folder name is the label.

close all

% NOTE(review): absolute, machine-specific path -- adjust before running.
rootFolder = fullfile('C:\Users\tripats\Documents\Biomedical Signal Analysis\Grad Single Project\ECG');

% One sub-folder per class; these strings must match the folder names on disk.
categories = {'SinusTachycardia', ...
    'SinusRhythm', 'SinusBradycardia', ...
    'AtrialFlutter', 'AtrialFibrilation'};

imds = imageDatastore(fullfile(rootFolder, categories), 'LabelSource', 'foldernames');
% dsresize (local function below) forces 3 channels and a fixed size.
imds.ReadFcn = @dsresize;

label_count = countEachLabel(imds)

% Balance the classes: keep minCount randomly chosen images per label.
minCount = min(label_count{:, 2});
imds = splitEachLabel(imds, minCount, "randomized")

net = googlenet();
inputSize = net.Layers(1).InputSize;

% 80% of each label for training, the remaining 20% for validation.
% (The previous value 0.2 put only 20% of the data into trainingData.)
[trainingData, testData] = splitEachLabel(imds, 0.8, 'randomized');

% Resize on the fly to the pretrained network's input size so the original
% input layer, including its normalization statistics, is reused.
augTrain = augmentedImageDatastore(inputSize(1:2), trainingData);
augTest  = augmentedImageDatastore(inputSize(1:2), testData);

% GoogLeNet is a DAG network (parallel inception branches), so its Layers
% array cannot be re-stacked as a plain series network. Edit the layer
% graph instead and replace only the last fully connected layer (end-2)
% and the classification output layer (end) for the new class count.
lgraph = layerGraph(net);
numClasses = numel(categories);
lgraph = replaceLayer(lgraph, net.Layers(end-2).Name, ...
    fullyConnectedLayer(numClasses, 'Name', 'fc_ecg'));
lgraph = replaceLayer(lgraph, net.Layers(end).Name, ...
    classificationLayer('Name', 'output_ecg'));

plot(lgraph);

options = trainingOptions("sgdm", ...
    "ExecutionEnvironment", "parallel", ...
    "InitialLearnRate", 1e-3, ...
    "MaxEpochs", 40, ...
    "Shuffle", "every-epoch", ...
    "Plots", "training-progress", ...
    "ValidationData", augTest);

% Pass `options` (the original referenced an undefined variable `opts`).
[net, traininfo] = trainNetwork(augTrain, lgraph, options);
41
function data = dsresize(filename, targetSize)
%DSRESIZE Read an image file as a 3-channel image resized to a fixed size.
%   DATA = DSRESIZE(FILENAME) reads FILENAME with imread and returns the
%   image resized to 256-by-256 with 3 channels. Suitable as an
%   imageDatastore ReadFcn, which calls it with the file name only.
%   DATA = DSRESIZE(FILENAME, TARGETSIZE) resizes to TARGETSIZE
%   ([rows cols]) instead.
%
%   Indexed (palette) images are converted to RGB through their colormap
%   and logical images to uint8 before resizing. Single-channel images are
%   replicated to 3 channels; images with more than 3 channels keep only
%   the first 3.
    if nargin < 2
        targetSize = [256 256];
    end

    [data, map] = imread(filename);
    if ~isempty(map)
        % Indexed image: raw values are colormap indices, not intensities.
        data = im2uint8(ind2rgb(data, map));
    elseif islogical(data)
        data = im2uint8(data);
    end

    % min(1:3, end) is [1 1 1] for a single channel (replicate grayscale)
    % and [1 2 3] for three or more channels (drop any extra channels).
    data = data(:, :, min(1:3, end));
    data = imresize(data, targetSize);
end