% test_cnn.m
% Fine-tune a pretrained GoogLeNet to classify ECG images into the classes
% listed in `categories` (one image sub-folder per class under rootFolder).
close all

% Dataset root. Each class is a sub-folder whose name matches an entry of
% `categories`; the folder name becomes the label ('LabelSource').
% NOTE(review): machine-specific absolute path -- adjust before running.
rootFolder = fullfile('C:\Users\tripats\Documents\Biomedical Signal Analysis\Grad Single Project\ECG');

% NOTE(review): this variable shadows MATLAB's categories() function
% (for categorical arrays) inside this script's workspace.
categories = {'SinusTachycardia',...
    'SinusRhythm','SinusBradycardia',...
    'AtrialFlutter','AtrialFibrilation'};

imds = imageDatastore(fullfile(rootFolder,categories),'LabelSource','foldernames');
% Every image is passed through dsresize on read (local function below),
% which returns a 256x256x3 array to match the input layer.
imds.ReadFcn = @dsresize;

% Balance the classes: randomly keep, from every class, only as many images
% as the smallest class contains. Both results are displayed on purpose.
label_count = countEachLabel(imds)
minCount = min(label_count.Count);
imds = splitEachLabel(imds, minCount, "randomized")

% Pretrained GoogLeNet (requires the Deep Learning Toolbox Model for
% GoogLeNet Network support package).
net = googlenet();

% splitEachLabel gives the FIRST output the requested fraction of each
% label, so most of the data must go to training.
% NOTE(review): testData is also used as ValidationData below, so it is not
% an independent test set.
trainFraction = 0.8;
[trainingData,testData] = splitEachLabel(imds, trainFraction, 'randomized');

% GoogLeNet is a DAG (inception modules have parallel branches), so its
% layers cannot be flattened into a plain layer array -- that would drop
% the branch connections. Keep the graph and replace only the input layer,
% the final fully connected layer and the classification output, i.e. the
% same layers the transfer set-up is meant to swap out.
layers = layerGraph(net);
inputName = net.Layers(1).Name;
fcName    = net.Layers(end-2).Name;
outName   = net.Layers(end).Name;
layers = replaceLayer(layers, inputName, ...
    imageInputLayer([256 256 3], 'Name', 'input_256'));
layers = replaceLayer(layers, fcName, ...
    fullyConnectedLayer(numel(categories), 'Name', 'fc_ecg'));
layers = replaceLayer(layers, outName, ...
    classificationLayer('Name', 'output_ecg'));

plot(layers);

options = trainingOptions("sgdm",...
    "ExecutionEnvironment","parallel",...
    "InitialLearnRate",1e-3,...
    "MaxEpochs", 40,...
    "Shuffle","every-epoch",...
    "Plots","training-progress",...
    "ValidationData",testData);

% Pass the options object defined above (the variable is `options`).
[net, traininfo] = trainNetwork(trainingData, layers, options);

function data = dsresize(filename, targetSize)
% DSRESIZE  Read an image file as a fixed-size, 3-channel array.
%   DATA = DSRESIZE(FILENAME) reads FILENAME with imread, forces exactly
%   three channels and resizes the result to 256x256. Intended as an
%   imageDatastore ReadFcn, which calls it with the file name only.
%
%   DATA = DSRESIZE(FILENAME, TARGETSIZE) resizes to TARGETSIZE
%   ([rows cols]) instead of [256 256].
%
%   Channel handling:
%     - indexed images are mapped through their colormap to RGB and
%       converted to uint8;
%     - images with fewer than 3 channels (grayscale, or a 2-channel
%       array) replicate their first channel three times;
%     - images with more than 3 channels keep only the first three.
if nargin < 2
    targetSize = [256 256];
end

[data, map] = imread(filename);
if ~isempty(map)
    % Indexed image: the stored values are colormap indices, not
    % intensities, so expand them to true colour first.
    data = im2uint8(ind2rgb(data, map));
end

nChannels = size(data, 3);
if nChannels < 3
    data = repmat(data(:,:,1), [1 1 3]);
elseif nChannels > 3
    data = data(:,:,1:3);
end

data = imresize(data, targetSize);
end