diff options
Diffstat (limited to 'src/armnn/Graph.cpp')
-rw-r--r-- | src/armnn/Graph.cpp | 63 |
1 files changed, 46 insertions, 17 deletions
diff --git a/src/armnn/Graph.cpp b/src/armnn/Graph.cpp index 831d85e404..1bd4fbd85a 100644 --- a/src/armnn/Graph.cpp +++ b/src/armnn/Graph.cpp @@ -2,7 +2,9 @@ // Copyright © 2017 Arm Ltd. All rights reserved. // SPDX-License-Identifier: MIT // + #include "Graph.hpp" +#include "SubGraph.hpp" #include "LayersFwd.hpp" #include <armnn/Utils.hpp> @@ -17,7 +19,6 @@ #include <DotSerializer.hpp> #include <sstream> - namespace armnn { @@ -238,7 +239,7 @@ const Graph& Graph::TopologicalSort() const it->ResetPriority(); } - auto compareLayerPriority = [](const LayersList::value_type& layerA, const LayersList::value_type& layerB) + auto compareLayerPriority = [](const LayerList::value_type& layerA, const LayerList::value_type& layerB) { return layerA->GetPriority() < layerB->GetPriority(); }; @@ -306,44 +307,72 @@ void Graph::SubstituteSubGraph(std::unique_ptr<SubGraph> subGraph, IConnectableL EraseSubGraphLayers(*subGraph); } +void Graph::SubstituteSubGraph(std::unique_ptr<SubGraph> subGraph, const SubGraph& substituteSubGraph) +{ + BOOST_ASSERT(subGraph); + + ReplaceSubGraphConnections(*subGraph, substituteSubGraph); + EraseSubGraphLayers(*subGraph); +} + void Graph::ReplaceSubGraphConnections(const SubGraph& subGraph, IConnectableLayer* substituteLayer) { BOOST_ASSERT(substituteLayer != nullptr); BOOST_ASSERT_MSG(std::find(m_Layers.begin(), m_Layers.end(), substituteLayer) != m_Layers.end(), - "Substitue layer is not a member of graph"); + "Substitute layer is not a member of graph"); + + SubGraph substituteSubGraph(subGraph, substituteLayer); + ReplaceSubGraphConnections(subGraph, substituteSubGraph); +} + +void Graph::ReplaceSubGraphConnections(const SubGraph& subGraph, const SubGraph& substituteSubGraph) +{ + BOOST_ASSERT_MSG(!substituteSubGraph.GetLayers().empty(), "New sub-graph used for substitution must not be empty"); + + const SubGraph::Layers& substituteSubGraphLayers = substituteSubGraph.GetLayers(); + std::for_each(substituteSubGraphLayers.begin(), substituteSubGraphLayers.end(), [&](Layer* layer) + { + BOOST_ASSERT_MSG(std::find(m_Layers.begin(), m_Layers.end(), layer) != m_Layers.end(), + "Substitute layer is not a member of graph"); + }); const SubGraph::InputSlots& subGraphInputSlots = subGraph.GetInputSlots(); const SubGraph::OutputSlots& subGraphOutputSlots = subGraph.GetOutputSlots(); - const unsigned int numInputSlots = boost::numeric_cast<unsigned int>(subGraphInputSlots.size()); - const unsigned int numOutputSlots = boost::numeric_cast<unsigned int>(subGraphOutputSlots.size()); + unsigned int subGraphNumInputSlots = boost::numeric_cast<unsigned int>(subGraphInputSlots.size()); + unsigned int subGraphNumOutputSlots = boost::numeric_cast<unsigned int>(subGraphOutputSlots.size()); + + const SubGraph::InputSlots& substituteSubGraphInputSlots = substituteSubGraph.GetInputSlots(); + const SubGraph::OutputSlots& substituteSubGraphOutputSlots = substituteSubGraph.GetOutputSlots(); + + BOOST_ASSERT(subGraphNumInputSlots == substituteSubGraphInputSlots.size()); + BOOST_ASSERT(subGraphNumOutputSlots == substituteSubGraphOutputSlots.size()); - BOOST_ASSERT(numInputSlots == substituteLayer->GetNumInputSlots()); - BOOST_ASSERT(numOutputSlots == substituteLayer->GetNumOutputSlots()); + // Disconnect the sub-graph and replace it with the substitute sub-graph - // Disconnect the sub-graph and replace it with the substitute layer // Step 1: process input slots - for(unsigned int inputSlotIdx = 0u; inputSlotIdx < numInputSlots; ++inputSlotIdx) + for (unsigned int inputSlotIdx = 0; inputSlotIdx < subGraphNumInputSlots; ++inputSlotIdx) { InputSlot* subGraphInputSlot = subGraphInputSlots.at(inputSlotIdx); - BOOST_ASSERT(subGraphInputSlot != nullptr); + BOOST_ASSERT(subGraphInputSlot); IOutputSlot* connectedOutputSlot = subGraphInputSlot->GetConnection(); - BOOST_ASSERT(connectedOutputSlot != nullptr); + BOOST_ASSERT(connectedOutputSlot); connectedOutputSlot->Disconnect(*subGraphInputSlot); - IInputSlot& substituteInputSlot = substituteLayer->GetInputSlot(inputSlotIdx); - connectedOutputSlot->Connect(substituteInputSlot); + IInputSlot* substituteInputSlot = substituteSubGraphInputSlots.at(inputSlotIdx); + BOOST_ASSERT(substituteInputSlot); + connectedOutputSlot->Connect(*substituteInputSlot); } // Step 2: process output slots - for(unsigned int outputSlotIdx = 0u; outputSlotIdx < numOutputSlots; ++outputSlotIdx) + for(unsigned int outputSlotIdx = 0; outputSlotIdx < subGraphNumOutputSlots; ++outputSlotIdx) { OutputSlot* subGraphOutputSlot = subGraphOutputSlots.at(outputSlotIdx); - BOOST_ASSERT(subGraphOutputSlot != nullptr); + BOOST_ASSERT(subGraphOutputSlot); - OutputSlot* substituteOutputSlot = boost::polymorphic_downcast<OutputSlot*>( - &substituteLayer->GetOutputSlot(outputSlotIdx)); + OutputSlot* substituteOutputSlot = substituteSubGraphOutputSlots.at(outputSlotIdx); + BOOST_ASSERT(substituteOutputSlot); subGraphOutputSlot->MoveAllConnections(*substituteOutputSlot); } } |