--- a +++ b/functions/findLayersToReplace.m @@ -0,0 +1,51 @@ +% findLayersToReplace(lgraph) finds the single classification layer and the +% preceding learnable (fully connected or convolutional) layer of the layer +% graph lgraph. +function [learnableLayer,classLayer] = findLayersToReplace(lgraph) + +if ~isa(lgraph,'nnet.cnn.LayerGraph') + error('Argument must be a LayerGraph object.') +end + +% Get source, destination, and layer names. +src = string(lgraph.Connections.Source); +dst = string(lgraph.Connections.Destination); +layerNames = string({lgraph.Layers.Name}'); + +% Find the classification layer. The layer graph must have a single +% classification layer. +isClassificationLayer = arrayfun(@(l) ... + (isa(l,'nnet.cnn.layer.ClassificationOutputLayer')|isa(l,'nnet.layer.ClassificationLayer')), ... + lgraph.Layers); + +if sum(isClassificationLayer) ~= 1 + error('Layer graph must have a single classification layer.') +end +classLayer = lgraph.Layers(isClassificationLayer); + + +% Traverse the layer graph in reverse starting from the classification +% layer. If the network branches, throw an error. +currentLayerIdx = find(isClassificationLayer); +while true + + if numel(currentLayerIdx) ~= 1 + error('Layer graph must have a single learnable layer preceding the classification layer.') + end + + currentLayerType = class(lgraph.Layers(currentLayerIdx)); + isLearnableLayer = ismember(currentLayerType, ... + ['nnet.cnn.layer.FullyConnectedLayer','nnet.cnn.layer.Convolution2DLayer']); + + if isLearnableLayer + learnableLayer = lgraph.Layers(currentLayerIdx); + return + end + + currentDstIdx = find(layerNames(currentLayerIdx) == dst); + currentLayerIdx = find(src(currentDstIdx) == layerNames); + +end + +end +