aboutsummaryrefslogtreecommitdiff
path: root/src/backends/aclCommon
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/aclCommon')
-rw-r--r--src/backends/aclCommon/ArmComputeSubgraphUtils.hpp21
1 files changed, 21 insertions, 0 deletions
diff --git a/src/backends/aclCommon/ArmComputeSubgraphUtils.hpp b/src/backends/aclCommon/ArmComputeSubgraphUtils.hpp
index 90c0fd5890..a44acb0f54 100644
--- a/src/backends/aclCommon/ArmComputeSubgraphUtils.hpp
+++ b/src/backends/aclCommon/ArmComputeSubgraphUtils.hpp
@@ -356,4 +356,25 @@ void ReplaceLayers(OptimizationViews& optimizationViews,
optimizationViews.AddSubstitution({substitutionSubgraph, replacementSubgraph});
}
+//
+// Substitute a multi-layer subgraph with one new layer
+//
+template<typename LayerType>
+void ReplaceMultipleLayers(OptimizationViews& optimizationViews,
+ std::vector<IConnectableLayer*>& originalLayers,
+ LayerType* baseLayer,
+ const std::vector<SlotList> inputLayersSlotLists,
+ const std::vector<SlotList> outputLayersSlotLists)
+{
+ std::list<IConnectableLayer*> originalLayerList(originalLayers.begin(), originalLayers.end());
+
+ SubgraphView substitutionSubgraph(
+ std::move(originalLayerList),
+ CreateIInputsFromSlotLists<armnn::IConnectableLayer>(originalLayers, inputLayersSlotLists),
+ CreateIOutputsFromSlotLists<armnn::IConnectableLayer>(originalLayers, outputLayersSlotLists));
+ SubgraphView replacementSubgraph(baseLayer);
+
+ optimizationViews.AddSubstitution({substitutionSubgraph, replacementSubgraph});
+}
+
} // namespace armnn