Diff of /CNN.m [000000] .. [24d692]

Switch to side-by-side view

--- a
+++ b/CNN.m
@@ -0,0 +1,54 @@
+%% Setup of CNN based classifier 
+% Fist option direct input to 2D CNN
+close all
+
+rootFolder = fullfile('C:\Users\tripats\Documents\Biomedical Signal Analysis\Grad Single Project\ECG');
+
+categories = {'SinusTachycardia',...
+    'SinusRhythm','SinusBradycardia',...
+     'AtrialFibrilation',};
+% categories = {'SupraventricularTachycardia','SinusTachycardia',...
+%     'SinusRhythm','SinusIrregularity','SinusBradycardia',...
+%     'SinusAtriumtoAtrialWanderingRhythm', 'AtrioventricularReentrantTachycardia',...
+%     'AtrioventricularNodeReentrantTachycardia','AtrialTachycardia',...
+%      'AtrialFlutter','AtrialFibrilation',};
+
+
+imds = imageDatastore(fullfile(rootFolder,categories),'LabelSource','foldernames');
+imds.ReadFcn = @dsresize;
+
+label_count = countEachLabel(imds)
+minCount = min(label_count{:,2})
+
+imds = splitEachLabel(imds, minCount, "randomized")
+
+net = alexnet();
+
+[trainingData,testData] = splitEachLabel(imds, 0.2, 'randomized');
+
+layersTransfer = net.Layers(2:end-3);
+layers = [
+    imageInputLayer([256 256 3]);
+    layersTransfer
+    fullyConnectedLayer(numel(categories))
+    softmaxLayer
+    classificationLayer];
+
+plot(layerGraph(layers));
+
+options = trainingOptions("sgdm",...
+    "ExecutionEnvironment","parallel",...
+    "InitialLearnRate",1e-3,...
+    "MaxEpochs", 40,...
+    "Shuffle","every-epoch",...
+    "Plots","training-progress",...
+    "ValidationData",testData,...
+    "ValidationFrequency",5);
+
+[net, traininfo] = trainNetwork(trainingData, layers, options);
+
+function data = dsresize(filename)
+    data = imread(filename);
+    data = data(:,:,min(1:3, end));
+    data = imresize(data, [256 256]);
+end
\ No newline at end of file