[1422d3]: / functions / replaceLayers.m

Download this file

31 lines (22 with data), 992 Bytes

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
function lgraph = replaceLayers(net, numClasses)
%REPLACELAYERS Swap the final learnable and classification layers of a network.
%   LGRAPH = REPLACELAYERS(NET, NUMCLASSES) converts NET to a layer graph,
%   asks findLayersToReplace which learnable layer and which classification
%   layer to swap out, and replaces them with freshly created layers sized
%   for NUMCLASSES outputs.
%
%   Inputs:
%     net        - a SeriesNetwork (its Layers array is used), or any other
%                  object accepted by layerGraph.
%     numClasses - number of outputs of the new learnable layer.
%
%   Output:
%     lgraph     - layer graph in which the learnable layer is replaced by
%                  'new_fc' (fully connected) or 'new_conv' (1-by-1
%                  convolution), and the classification layer is replaced
%                  by 'new_classoutput'.
%
%   Raises replaceLayers:unsupportedLayer when the layer reported by
%   findLayersToReplace is neither a fully connected nor a 2-D convolution
%   layer.
%
%   NOTE(review): findLayersToReplace is not defined in this file; it is
%   presumably the helper shipped with the MathWorks example below and
%   must be on the MATLAB path -- confirm.
%
%   Reference:
%   https://it.mathworks.com/help/deeplearning/ug/train-deep-learning-network-to-classify-new-images.html

    % A SeriesNetwork is converted through its Layers array; anything else
    % is handed to layerGraph directly.
    if isa(net, 'SeriesNetwork')
        lgraph = layerGraph(net.Layers);
    else
        lgraph = layerGraph(net);
    end

    [learnableLayer, classLayer] = findLayersToReplace(lgraph);

    % Build a replacement of the same kind as the layer being removed. The
    % learn-rate factors of 20 make the new layer's weights and biases learn
    % faster than the layers that are kept.
    if isa(learnableLayer, 'nnet.cnn.layer.FullyConnectedLayer')
        newLearnableLayer = fullyConnectedLayer(numClasses, ...
            'Name', 'new_fc', ...
            'WeightLearnRateFactor', 20, ...
            'BiasLearnRateFactor', 20);
    elseif isa(learnableLayer, 'nnet.cnn.layer.Convolution2DLayer')
        newLearnableLayer = convolution2dLayer(1, numClasses, ...
            'Name', 'new_conv', ...
            'WeightLearnRateFactor', 20, ...
            'BiasLearnRateFactor', 20);
    else
        % Without this branch newLearnableLayer would be undefined and the
        % replaceLayer call below would fail with an unrelated-looking
        % "undefined variable" error.
        error('replaceLayers:unsupportedLayer', ...
            'Cannot replace learnable layer of class ''%s''; expected a fully connected or 2-D convolution layer.', ...
            class(learnableLayer));
    end

    lgraph = replaceLayer(lgraph, learnableLayer.Name, newLearnableLayer);

    newClassLayer = classificationLayer('Name', 'new_classoutput');
    lgraph = replaceLayer(lgraph, classLayer.Name, newClassLayer);