aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/Graph.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/Graph.cpp')
-rw-r--r--src/armnn/Graph.cpp59
1 files changed, 59 insertions, 0 deletions
diff --git a/src/armnn/Graph.cpp b/src/armnn/Graph.cpp
index 83d82a5ffe..831d85e404 100644
--- a/src/armnn/Graph.cpp
+++ b/src/armnn/Graph.cpp
@@ -297,6 +297,65 @@ void Graph::AddCopyLayers()
}
}
+void Graph::SubstituteSubGraph(std::unique_ptr<SubGraph> subGraph, IConnectableLayer* substituteLayer)
+{
+ BOOST_ASSERT(subGraph != nullptr);
+ BOOST_ASSERT(substituteLayer != nullptr);
+
+ ReplaceSubGraphConnections(*subGraph, substituteLayer);
+ EraseSubGraphLayers(*subGraph);
+}
+
+void Graph::ReplaceSubGraphConnections(const SubGraph& subGraph, IConnectableLayer* substituteLayer)
+{
+ BOOST_ASSERT(substituteLayer != nullptr);
+ BOOST_ASSERT_MSG(std::find(m_Layers.begin(), m_Layers.end(), substituteLayer) != m_Layers.end(),
+ "Substitue layer is not a member of graph");
+
+ const SubGraph::InputSlots& subGraphInputSlots = subGraph.GetInputSlots();
+ const SubGraph::OutputSlots& subGraphOutputSlots = subGraph.GetOutputSlots();
+
+ const unsigned int numInputSlots = boost::numeric_cast<unsigned int>(subGraphInputSlots.size());
+ const unsigned int numOutputSlots = boost::numeric_cast<unsigned int>(subGraphOutputSlots.size());
+
+ BOOST_ASSERT(numInputSlots == substituteLayer->GetNumInputSlots());
+ BOOST_ASSERT(numOutputSlots == substituteLayer->GetNumOutputSlots());
+
+ // Disconnect the sub-graph and replace it with the substitute layer
+ // Step 1: process input slots
+ for(unsigned int inputSlotIdx = 0u; inputSlotIdx < numInputSlots; ++inputSlotIdx)
+ {
+ InputSlot* subGraphInputSlot = subGraphInputSlots.at(inputSlotIdx);
+ BOOST_ASSERT(subGraphInputSlot != nullptr);
+
+ IOutputSlot* connectedOutputSlot = subGraphInputSlot->GetConnection();
+ BOOST_ASSERT(connectedOutputSlot != nullptr);
+ connectedOutputSlot->Disconnect(*subGraphInputSlot);
+
+ IInputSlot& substituteInputSlot = substituteLayer->GetInputSlot(inputSlotIdx);
+ connectedOutputSlot->Connect(substituteInputSlot);
+ }
+
+ // Step 2: process output slots
+ for(unsigned int outputSlotIdx = 0u; outputSlotIdx < numOutputSlots; ++outputSlotIdx)
+ {
+ OutputSlot* subGraphOutputSlot = subGraphOutputSlots.at(outputSlotIdx);
+ BOOST_ASSERT(subGraphOutputSlot != nullptr);
+
+ OutputSlot* substituteOutputSlot = boost::polymorphic_downcast<OutputSlot*>(
+ &substituteLayer->GetOutputSlot(outputSlotIdx));
+ subGraphOutputSlot->MoveAllConnections(*substituteOutputSlot);
+ }
+}
+
+void Graph::EraseSubGraphLayers(const SubGraph &subGraph)
+{
+ for (auto layer : subGraph.GetLayers())
+ {
+ EraseLayer(layer);
+ }
+}
+
void Graph::InferTensorInfos()
{
for (auto&& layer : TopologicalSort())