Diff of /CNN.m [000000] .. [24d692]

Switch to unified view

a b/CNN.m
1
%% Setup of CNN based classifier 
2
% Fist option direct input to 2D CNN
3
close all
4
5
rootFolder = fullfile('C:\Users\tripats\Documents\Biomedical Signal Analysis\Grad Single Project\ECG');
6
7
categories = {'SinusTachycardia',...
8
    'SinusRhythm','SinusBradycardia',...
9
     'AtrialFibrilation',};
10
% categories = {'SupraventricularTachycardia','SinusTachycardia',...
11
%     'SinusRhythm','SinusIrregularity','SinusBradycardia',...
12
%     'SinusAtriumtoAtrialWanderingRhythm', 'AtrioventricularReentrantTachycardia',...
13
%     'AtrioventricularNodeReentrantTachycardia','AtrialTachycardia',...
14
%      'AtrialFlutter','AtrialFibrilation',};
15
16
17
imds = imageDatastore(fullfile(rootFolder,categories),'LabelSource','foldernames');
18
imds.ReadFcn = @dsresize;
19
20
label_count = countEachLabel(imds)
21
minCount = min(label_count{:,2})
22
23
imds = splitEachLabel(imds, minCount, "randomized")
24
25
net = alexnet();
26
27
[trainingData,testData] = splitEachLabel(imds, 0.2, 'randomized');
28
29
layersTransfer = net.Layers(2:end-3);
30
layers = [
31
    imageInputLayer([256 256 3]);
32
    layersTransfer
33
    fullyConnectedLayer(numel(categories))
34
    softmaxLayer
35
    classificationLayer];
36
37
plot(layerGraph(layers));
38
39
options = trainingOptions("sgdm",...
40
    "ExecutionEnvironment","parallel",...
41
    "InitialLearnRate",1e-3,...
42
    "MaxEpochs", 40,...
43
    "Shuffle","every-epoch",...
44
    "Plots","training-progress",...
45
    "ValidationData",testData,...
46
    "ValidationFrequency",5);
47
48
[net, traininfo] = trainNetwork(trainingData, layers, options);
49
50
function data = dsresize(filename)
51
    data = imread(filename);
52
    data = data(:,:,min(1:3, end));
53
    data = imresize(data, [256 256]);
54
end