Diff of /functions/replaceLayers.m [000000] .. [1422d3]

Switch to unified view

a b/functions/replaceLayers.m
1
function lgraph = replaceLayers(net, numClasses)
2
3
%https://it.mathworks.com/help/deeplearning/ug/train-deep-learning-network-to-classify-new-images.html
4
5
if isa(net,'SeriesNetwork') 
6
  lgraph = layerGraph(net.Layers); 
7
else
8
  lgraph = layerGraph(net);
9
end 
10
11
[learnableLayer, classLayer] = findLayersToReplace(lgraph);
12
13
if isa(learnableLayer, 'nnet.cnn.layer.FullyConnectedLayer')
14
    newLearnableLayer = fullyConnectedLayer(numClasses, ...
15
        'Name', 'new_fc', ...
16
        'WeightLearnRateFactor', 20, ...
17
        'BiasLearnRateFactor', 20);
18
    
19
elseif isa(learnableLayer, 'nnet.cnn.layer.Convolution2DLayer')
20
    newLearnableLayer = convolution2dLayer(1, numClasses, ...
21
        'Name', 'new_conv', ...
22
        'WeightLearnRateFactor', 20, ...
23
        'BiasLearnRateFactor', 20);
24
end
25
26
lgraph = replaceLayer(lgraph, learnableLayer.Name, newLearnableLayer);
27
28
newClassLayer = classificationLayer('Name', 'new_classoutput');
29
lgraph = replaceLayer(lgraph, classLayer.Name, newClassLayer);
30