--- a +++ b/CNN.m @@ -0,0 +1,54 @@ +%% Setup of CNN based classifier +% Fist option direct input to 2D CNN +close all + +rootFolder = fullfile('C:\Users\tripats\Documents\Biomedical Signal Analysis\Grad Single Project\ECG'); + +categories = {'SinusTachycardia',... + 'SinusRhythm','SinusBradycardia',... + 'AtrialFibrilation',}; +% categories = {'SupraventricularTachycardia','SinusTachycardia',... +% 'SinusRhythm','SinusIrregularity','SinusBradycardia',... +% 'SinusAtriumtoAtrialWanderingRhythm', 'AtrioventricularReentrantTachycardia',... +% 'AtrioventricularNodeReentrantTachycardia','AtrialTachycardia',... +% 'AtrialFlutter','AtrialFibrilation',}; + + +imds = imageDatastore(fullfile(rootFolder,categories),'LabelSource','foldernames'); +imds.ReadFcn = @dsresize; + +label_count = countEachLabel(imds) +minCount = min(label_count{:,2}) + +imds = splitEachLabel(imds, minCount, "randomized") + +net = alexnet(); + +[trainingData,testData] = splitEachLabel(imds, 0.2, 'randomized'); + +layersTransfer = net.Layers(2:end-3); +layers = [ + imageInputLayer([256 256 3]); + layersTransfer + fullyConnectedLayer(numel(categories)) + softmaxLayer + classificationLayer]; + +plot(layerGraph(layers)); + +options = trainingOptions("sgdm",... + "ExecutionEnvironment","parallel",... + "InitialLearnRate",1e-3,... + "MaxEpochs", 40,... + "Shuffle","every-epoch",... + "Plots","training-progress",... + "ValidationData",testData,... + "ValidationFrequency",5); + +[net, traininfo] = trainNetwork(trainingData, layers, options); + +function data = dsresize(filename) + data = imread(filename); + data = data(:,:,min(1:3, end)); + data = imresize(data, [256 256]); +end \ No newline at end of file