Switch to side-by-side view

--- a
+++ b/functions/findLayersToReplace.m
@@ -0,0 +1,51 @@
+% findLayersToReplace(lgraph) finds the single classification layer and the
+% preceding learnable (fully connected or convolutional) layer of the layer
+% graph lgraph.
+function [learnableLayer,classLayer] = findLayersToReplace(lgraph)
+
+if ~isa(lgraph,'nnet.cnn.LayerGraph')
+    error('Argument must be a LayerGraph object.')
+end
+
+% Get source, destination, and layer names.
+src = string(lgraph.Connections.Source);
+dst = string(lgraph.Connections.Destination);
+layerNames = string({lgraph.Layers.Name}');
+
+% Find the classification layer. The layer graph must have a single
+% classification layer.
+isClassificationLayer = arrayfun(@(l) ...
+    (isa(l,'nnet.cnn.layer.ClassificationOutputLayer')|isa(l,'nnet.layer.ClassificationLayer')), ...
+    lgraph.Layers);
+
+if sum(isClassificationLayer) ~= 1
+    error('Layer graph must have a single classification layer.')
+end
+classLayer = lgraph.Layers(isClassificationLayer);
+
+
+% Traverse the layer graph in reverse starting from the classification
+% layer. If the network branches, throw an error.
+currentLayerIdx = find(isClassificationLayer);
+while true
+    
+    if numel(currentLayerIdx) ~= 1
+        error('Layer graph must have a single learnable layer preceding the classification layer.')
+    end
+    
+    currentLayerType = class(lgraph.Layers(currentLayerIdx));
+    isLearnableLayer = ismember(currentLayerType, ...
+        ['nnet.cnn.layer.FullyConnectedLayer','nnet.cnn.layer.Convolution2DLayer']);
+    
+    if isLearnableLayer
+        learnableLayer =  lgraph.Layers(currentLayerIdx);
+        return
+    end
+    
+    currentDstIdx = find(layerNames(currentLayerIdx) == dst);
+    currentLayerIdx = find(src(currentDstIdx) == layerNames);
+    
+end
+
+end
+