Diff of /functions/replaceLayers.m [000000] .. [1422d3]

Switch to side-by-side view

--- a
+++ b/functions/replaceLayers.m
@@ -0,0 +1,30 @@
+function lgraph = replaceLayers(net, numClasses)
+
+%https://it.mathworks.com/help/deeplearning/ug/train-deep-learning-network-to-classify-new-images.html
+
+if isa(net,'SeriesNetwork') 
+  lgraph = layerGraph(net.Layers); 
+else
+  lgraph = layerGraph(net);
+end 
+
+[learnableLayer, classLayer] = findLayersToReplace(lgraph);
+
+if isa(learnableLayer, 'nnet.cnn.layer.FullyConnectedLayer')
+    newLearnableLayer = fullyConnectedLayer(numClasses, ...
+        'Name', 'new_fc', ...
+        'WeightLearnRateFactor', 20, ...
+        'BiasLearnRateFactor', 20);
+    
+elseif isa(learnableLayer, 'nnet.cnn.layer.Convolution2DLayer')
+    newLearnableLayer = convolution2dLayer(1, numClasses, ...
+        'Name', 'new_conv', ...
+        'WeightLearnRateFactor', 20, ...
+        'BiasLearnRateFactor', 20);
+end
+
+lgraph = replaceLayer(lgraph, learnableLayer.Name, newLearnableLayer);
+
+newClassLayer = classificationLayer('Name', 'new_classoutput');
+lgraph = replaceLayer(lgraph, classLayer.Name, newClassLayer);
+