|
a |
|
b/functions/findLayersToReplace.m |
|
|
1 |
% findLayersToReplace(lgraph) finds the single classification layer and the |
|
|
2 |
% preceding learnable (fully connected or convolutional) layer of the layer |
|
|
3 |
% graph lgraph. |
|
|
4 |
function [learnableLayer,classLayer] = findLayersToReplace(lgraph) |
|
|
5 |
|
|
|
6 |
if ~isa(lgraph,'nnet.cnn.LayerGraph') |
|
|
7 |
error('Argument must be a LayerGraph object.') |
|
|
8 |
end |
|
|
9 |
|
|
|
10 |
% Get source, destination, and layer names. |
|
|
11 |
src = string(lgraph.Connections.Source); |
|
|
12 |
dst = string(lgraph.Connections.Destination); |
|
|
13 |
layerNames = string({lgraph.Layers.Name}'); |
|
|
14 |
|
|
|
15 |
% Find the classification layer. The layer graph must have a single |
|
|
16 |
% classification layer. |
|
|
17 |
isClassificationLayer = arrayfun(@(l) ... |
|
|
18 |
(isa(l,'nnet.cnn.layer.ClassificationOutputLayer')|isa(l,'nnet.layer.ClassificationLayer')), ... |
|
|
19 |
lgraph.Layers); |
|
|
20 |
|
|
|
21 |
if sum(isClassificationLayer) ~= 1 |
|
|
22 |
error('Layer graph must have a single classification layer.') |
|
|
23 |
end |
|
|
24 |
classLayer = lgraph.Layers(isClassificationLayer); |
|
|
25 |
|
|
|
26 |
|
|
|
27 |
% Traverse the layer graph in reverse starting from the classification |
|
|
28 |
% layer. If the network branches, throw an error. |
|
|
29 |
currentLayerIdx = find(isClassificationLayer); |
|
|
30 |
while true |
|
|
31 |
|
|
|
32 |
if numel(currentLayerIdx) ~= 1 |
|
|
33 |
error('Layer graph must have a single learnable layer preceding the classification layer.') |
|
|
34 |
end |
|
|
35 |
|
|
|
36 |
currentLayerType = class(lgraph.Layers(currentLayerIdx)); |
|
|
37 |
isLearnableLayer = ismember(currentLayerType, ... |
|
|
38 |
['nnet.cnn.layer.FullyConnectedLayer','nnet.cnn.layer.Convolution2DLayer']); |
|
|
39 |
|
|
|
40 |
if isLearnableLayer |
|
|
41 |
learnableLayer = lgraph.Layers(currentLayerIdx); |
|
|
42 |
return |
|
|
43 |
end |
|
|
44 |
|
|
|
45 |
currentDstIdx = find(layerNames(currentLayerIdx) == dst); |
|
|
46 |
currentLayerIdx = find(src(currentDstIdx) == layerNames); |
|
|
47 |
|
|
|
48 |
end |
|
|
49 |
|
|
|
50 |
end |
|
|
51 |
|