--- a +++ b/functions/replaceLayers.m @@ -0,0 +1,30 @@ +function lgraph = replaceLayers(net, numClasses) + +%https://it.mathworks.com/help/deeplearning/ug/train-deep-learning-network-to-classify-new-images.html + +if isa(net,'SeriesNetwork') + lgraph = layerGraph(net.Layers); +else + lgraph = layerGraph(net); +end + +[learnableLayer, classLayer] = findLayersToReplace(lgraph); + +if isa(learnableLayer, 'nnet.cnn.layer.FullyConnectedLayer') + newLearnableLayer = fullyConnectedLayer(numClasses, ... + 'Name', 'new_fc', ... + 'WeightLearnRateFactor', 20, ... + 'BiasLearnRateFactor', 20); + +elseif isa(learnableLayer, 'nnet.cnn.layer.Convolution2DLayer') + newLearnableLayer = convolution2dLayer(1, numClasses, ... + 'Name', 'new_conv', ... + 'WeightLearnRateFactor', 20, ... + 'BiasLearnRateFactor', 20); +end + +lgraph = replaceLayer(lgraph, learnableLayer.Name, newLearnableLayer); + +newClassLayer = classificationLayer('Name', 'new_classoutput'); +lgraph = replaceLayer(lgraph, classLayer.Name, newClassLayer); +