diff options
author | Matteo Martincigh <matteo.martincigh@arm.com> | 2019-01-24 14:06:23 +0000 |
---|---|---|
committer | Matteo Martincigh <matteo.martincigh@arm.com> | 2019-01-30 14:03:28 +0000 |
commit | adddddb6cbcb777d92a8c464c9ad0cb9aecc76a3 (patch) | |
tree | b15de32bf9f8612f66e1ae23d2f8009e80e7d0e6 /src/armnn/Graph.cpp | |
parent | d089b74bebbcc8518fb0f4eacb7e6569ae170199 (diff) | |
download | armnn-adddddb6cbcb777d92a8c464c9ad0cb9aecc76a3.tar.gz |
IVGCVSW-2458 Refactor the Optimize function (Network.cpp) so that
subgraphs are optimized by the backends
* Added a new method OptimizeSubGraph to the backend interface
* Refactored the Optimize function so that the backend-specific
optimization is performed by the backend itself (through the new
OptimizeSubGraph interface method)
* Added a new ApplyBackendOptimizations function to apply the new
changes
* Added some new convenient constructors to the SubGraph class
* Added AddLayer method and a pointer to the parent graph to the
SubGraph class
* Updated the sub-graph unit tests to match the changes
* Added SelectSubGraphs and ReplaceSubGraphConnections overloads
that work with sub-graphs
* Removed unused code and minor refactoring where necessary
Change-Id: I46181794c6a9e3b10558944f804e06a8f693a6d0
Diffstat (limited to 'src/armnn/Graph.cpp')
-rw-r--r-- | src/armnn/Graph.cpp | 63 |
1 files changed, 46 insertions, 17 deletions
diff --git a/src/armnn/Graph.cpp b/src/armnn/Graph.cpp index 831d85e404..1bd4fbd85a 100644 --- a/src/armnn/Graph.cpp +++ b/src/armnn/Graph.cpp @@ -2,7 +2,9 @@ // Copyright © 2017 Arm Ltd. All rights reserved. // SPDX-License-Identifier: MIT // + #include "Graph.hpp" +#include "SubGraph.hpp" #include "LayersFwd.hpp" #include <armnn/Utils.hpp> @@ -17,7 +19,6 @@ #include <DotSerializer.hpp> #include <sstream> - namespace armnn { @@ -238,7 +239,7 @@ const Graph& Graph::TopologicalSort() const it->ResetPriority(); } - auto compareLayerPriority = [](const LayersList::value_type& layerA, const LayersList::value_type& layerB) + auto compareLayerPriority = [](const LayerList::value_type& layerA, const LayerList::value_type& layerB) { return layerA->GetPriority() < layerB->GetPriority(); }; @@ -306,44 +307,72 @@ void Graph::SubstituteSubGraph(std::unique_ptr<SubGraph> subGraph, IConnectableL EraseSubGraphLayers(*subGraph); } +void Graph::SubstituteSubGraph(std::unique_ptr<SubGraph> subGraph, const SubGraph& substituteSubGraph) +{ + BOOST_ASSERT(subGraph); + + ReplaceSubGraphConnections(*subGraph, substituteSubGraph); + EraseSubGraphLayers(*subGraph); +} + void Graph::ReplaceSubGraphConnections(const SubGraph& subGraph, IConnectableLayer* substituteLayer) { BOOST_ASSERT(substituteLayer != nullptr); BOOST_ASSERT_MSG(std::find(m_Layers.begin(), m_Layers.end(), substituteLayer) != m_Layers.end(), - "Substitue layer is not a member of graph"); + "Substitute layer is not a member of graph"); + + SubGraph substituteSubGraph(subGraph, substituteLayer); + ReplaceSubGraphConnections(subGraph, substituteSubGraph); +} + +void Graph::ReplaceSubGraphConnections(const SubGraph& subGraph, const SubGraph& substituteSubGraph) +{ + BOOST_ASSERT_MSG(!substituteSubGraph.GetLayers().empty(), "New sub-graph used for substitution must not be empty"); + + const SubGraph::Layers& substituteSubGraphLayers = substituteSubGraph.GetLayers(); + std::for_each(substituteSubGraphLayers.begin(), substituteSubGraphLayers.end(), [&](Layer* layer) + { + BOOST_ASSERT_MSG(std::find(m_Layers.begin(), m_Layers.end(), layer) != m_Layers.end(), + "Substitute layer is not a member of graph"); + }); const SubGraph::InputSlots& subGraphInputSlots = subGraph.GetInputSlots(); const SubGraph::OutputSlots& subGraphOutputSlots = subGraph.GetOutputSlots(); - const unsigned int numInputSlots = boost::numeric_cast<unsigned int>(subGraphInputSlots.size()); - const unsigned int numOutputSlots = boost::numeric_cast<unsigned int>(subGraphOutputSlots.size()); + unsigned int subGraphNumInputSlots = boost::numeric_cast<unsigned int>(subGraphInputSlots.size()); + unsigned int subGraphNumOutputSlots = boost::numeric_cast<unsigned int>(subGraphOutputSlots.size()); + + const SubGraph::InputSlots& substituteSubGraphInputSlots = substituteSubGraph.GetInputSlots(); + const SubGraph::OutputSlots& substituteSubGraphOutputSlots = substituteSubGraph.GetOutputSlots(); + + BOOST_ASSERT(subGraphNumInputSlots == substituteSubGraphInputSlots.size()); + BOOST_ASSERT(subGraphNumOutputSlots == substituteSubGraphOutputSlots.size()); - BOOST_ASSERT(numInputSlots == substituteLayer->GetNumInputSlots()); - BOOST_ASSERT(numOutputSlots == substituteLayer->GetNumOutputSlots()); + // Disconnect the sub-graph and replace it with the substitute sub-graph - // Disconnect the sub-graph and replace it with the substitute layer // Step 1: process input slots - for(unsigned int inputSlotIdx = 0u; inputSlotIdx < numInputSlots; ++inputSlotIdx) + for (unsigned int inputSlotIdx = 0; inputSlotIdx < subGraphNumInputSlots; ++inputSlotIdx) { InputSlot* subGraphInputSlot = subGraphInputSlots.at(inputSlotIdx); - BOOST_ASSERT(subGraphInputSlot != nullptr); + BOOST_ASSERT(subGraphInputSlot); IOutputSlot* connectedOutputSlot = subGraphInputSlot->GetConnection(); - BOOST_ASSERT(connectedOutputSlot != nullptr); + BOOST_ASSERT(connectedOutputSlot); connectedOutputSlot->Disconnect(*subGraphInputSlot); - IInputSlot& substituteInputSlot = substituteLayer->GetInputSlot(inputSlotIdx); - connectedOutputSlot->Connect(substituteInputSlot); + IInputSlot* substituteInputSlot = substituteSubGraphInputSlots.at(inputSlotIdx); + BOOST_ASSERT(substituteInputSlot); + connectedOutputSlot->Connect(*substituteInputSlot); } // Step 2: process output slots - for(unsigned int outputSlotIdx = 0u; outputSlotIdx < numOutputSlots; ++outputSlotIdx) + for(unsigned int outputSlotIdx = 0; outputSlotIdx < subGraphNumOutputSlots; ++outputSlotIdx) { OutputSlot* subGraphOutputSlot = subGraphOutputSlots.at(outputSlotIdx); - BOOST_ASSERT(subGraphOutputSlot != nullptr); + BOOST_ASSERT(subGraphOutputSlot); - OutputSlot* substituteOutputSlot = boost::polymorphic_downcast<OutputSlot*>( - &substituteLayer->GetOutputSlot(outputSlotIdx)); + OutputSlot* substituteOutputSlot = substituteSubGraphOutputSlots.at(outputSlotIdx); + BOOST_ASSERT(substituteOutputSlot); subGraphOutputSlot->MoveAllConnections(*substituteOutputSlot); } } |