|
a |
|
b/functions/replaceLayers.m |
|
|
1 |
function lgraph = replaceLayers(net, numClasses) |
|
|
2 |
|
|
|
3 |
%https://it.mathworks.com/help/deeplearning/ug/train-deep-learning-network-to-classify-new-images.html |
|
|
4 |
|
|
|
5 |
if isa(net,'SeriesNetwork') |
|
|
6 |
lgraph = layerGraph(net.Layers); |
|
|
7 |
else |
|
|
8 |
lgraph = layerGraph(net); |
|
|
9 |
end |
|
|
10 |
|
|
|
11 |
[learnableLayer, classLayer] = findLayersToReplace(lgraph); |
|
|
12 |
|
|
|
13 |
if isa(learnableLayer, 'nnet.cnn.layer.FullyConnectedLayer') |
|
|
14 |
newLearnableLayer = fullyConnectedLayer(numClasses, ... |
|
|
15 |
'Name', 'new_fc', ... |
|
|
16 |
'WeightLearnRateFactor', 20, ... |
|
|
17 |
'BiasLearnRateFactor', 20); |
|
|
18 |
|
|
|
19 |
elseif isa(learnableLayer, 'nnet.cnn.layer.Convolution2DLayer') |
|
|
20 |
newLearnableLayer = convolution2dLayer(1, numClasses, ... |
|
|
21 |
'Name', 'new_conv', ... |
|
|
22 |
'WeightLearnRateFactor', 20, ... |
|
|
23 |
'BiasLearnRateFactor', 20); |
|
|
24 |
end |
|
|
25 |
|
|
|
26 |
lgraph = replaceLayer(lgraph, learnableLayer.Name, newLearnableLayer); |
|
|
27 |
|
|
|
28 |
newClassLayer = classificationLayer('Name', 'new_classoutput'); |
|
|
29 |
lgraph = replaceLayer(lgraph, classLayer.Name, newClassLayer); |
|
|
30 |
|