Switch to unified view

a b/functions/findLayersToReplace.m
1
% findLayersToReplace(lgraph) finds the single classification layer and the
2
% preceding learnable (fully connected or convolutional) layer of the layer
3
% graph lgraph.
4
function [learnableLayer,classLayer] = findLayersToReplace(lgraph)
5
6
if ~isa(lgraph,'nnet.cnn.LayerGraph')
7
    error('Argument must be a LayerGraph object.')
8
end
9
10
% Get source, destination, and layer names.
11
src = string(lgraph.Connections.Source);
12
dst = string(lgraph.Connections.Destination);
13
layerNames = string({lgraph.Layers.Name}');
14
15
% Find the classification layer. The layer graph must have a single
16
% classification layer.
17
isClassificationLayer = arrayfun(@(l) ...
18
    (isa(l,'nnet.cnn.layer.ClassificationOutputLayer')|isa(l,'nnet.layer.ClassificationLayer')), ...
19
    lgraph.Layers);
20
21
if sum(isClassificationLayer) ~= 1
22
    error('Layer graph must have a single classification layer.')
23
end
24
classLayer = lgraph.Layers(isClassificationLayer);
25
26
27
% Traverse the layer graph in reverse starting from the classification
28
% layer. If the network branches, throw an error.
29
currentLayerIdx = find(isClassificationLayer);
30
while true
31
    
32
    if numel(currentLayerIdx) ~= 1
33
        error('Layer graph must have a single learnable layer preceding the classification layer.')
34
    end
35
    
36
    currentLayerType = class(lgraph.Layers(currentLayerIdx));
37
    isLearnableLayer = ismember(currentLayerType, ...
38
        ['nnet.cnn.layer.FullyConnectedLayer','nnet.cnn.layer.Convolution2DLayer']);
39
    
40
    if isLearnableLayer
41
        learnableLayer =  lgraph.Layers(currentLayerIdx);
42
        return
43
    end
44
    
45
    currentDstIdx = find(layerNames(currentLayerIdx) == dst);
46
    currentLayerIdx = find(src(currentDstIdx) == layerNames);
47
    
48
end
49
50
end
51