aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/Graph.cpp
diff options
context:
space:
mode:
authorMatteo Martincigh <matteo.martincigh@arm.com>2019-01-24 14:06:23 +0000
committerMatteo Martincigh <matteo.martincigh@arm.com>2019-01-30 14:03:28 +0000
commitadddddb6cbcb777d92a8c464c9ad0cb9aecc76a3 (patch)
treeb15de32bf9f8612f66e1ae23d2f8009e80e7d0e6 /src/armnn/Graph.cpp
parentd089b74bebbcc8518fb0f4eacb7e6569ae170199 (diff)
downloadarmnn-adddddb6cbcb777d92a8c464c9ad0cb9aecc76a3.tar.gz
IVGCVSW-2458 Refactor the Optimize function (Network.cpp) so that
subgraphs are optimized by the backends * Added a new method OptimizeSubGraph to the backend interface * Refactored the Optimize function so that the backend-specific optimization is performed by the backend itself (through the new OptimizeSubGraph interface method) * Added a new ApplyBackendOptimizations function to apply the new changes * Added some new convenient constructors to the SubGraph class * Added AddLayer method and a pointer to the parent graph to the SubGraph class * Updated the sub-graph unit tests to match the changes * Added SelectSubGraphs and ReplaceSubGraphConnections overloads that work with sub-graphs * Removed unused code and minor refactoring where necessary Change-Id: I46181794c6a9e3b10558944f804e06a8f693a6d0
Diffstat (limited to 'src/armnn/Graph.cpp')
-rw-r--r--src/armnn/Graph.cpp63
1 files changed, 46 insertions, 17 deletions
diff --git a/src/armnn/Graph.cpp b/src/armnn/Graph.cpp
index 831d85e404..1bd4fbd85a 100644
--- a/src/armnn/Graph.cpp
+++ b/src/armnn/Graph.cpp
@@ -2,7 +2,9 @@
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
+
#include "Graph.hpp"
+#include "SubGraph.hpp"
#include "LayersFwd.hpp"
#include <armnn/Utils.hpp>
@@ -17,7 +19,6 @@
#include <DotSerializer.hpp>
#include <sstream>
-
namespace armnn
{
@@ -238,7 +239,7 @@ const Graph& Graph::TopologicalSort() const
it->ResetPriority();
}
- auto compareLayerPriority = [](const LayersList::value_type& layerA, const LayersList::value_type& layerB)
+ auto compareLayerPriority = [](const LayerList::value_type& layerA, const LayerList::value_type& layerB)
{
return layerA->GetPriority() < layerB->GetPriority();
};
@@ -306,44 +307,72 @@ void Graph::SubstituteSubGraph(std::unique_ptr<SubGraph> subGraph, IConnectableL
EraseSubGraphLayers(*subGraph);
}
+void Graph::SubstituteSubGraph(std::unique_ptr<SubGraph> subGraph, const SubGraph& substituteSubGraph)
+{
+ BOOST_ASSERT(subGraph);
+
+ ReplaceSubGraphConnections(*subGraph, substituteSubGraph);
+ EraseSubGraphLayers(*subGraph);
+}
+
void Graph::ReplaceSubGraphConnections(const SubGraph& subGraph, IConnectableLayer* substituteLayer)
{
BOOST_ASSERT(substituteLayer != nullptr);
BOOST_ASSERT_MSG(std::find(m_Layers.begin(), m_Layers.end(), substituteLayer) != m_Layers.end(),
- "Substitue layer is not a member of graph");
+ "Substitute layer is not a member of graph");
+
+ SubGraph substituteSubGraph(subGraph, substituteLayer);
+ ReplaceSubGraphConnections(subGraph, substituteSubGraph);
+}
+
+void Graph::ReplaceSubGraphConnections(const SubGraph& subGraph, const SubGraph& substituteSubGraph)
+{
+ BOOST_ASSERT_MSG(!substituteSubGraph.GetLayers().empty(), "New sub-graph used for substitution must not be empty");
+
+ const SubGraph::Layers& substituteSubGraphLayers = substituteSubGraph.GetLayers();
+ std::for_each(substituteSubGraphLayers.begin(), substituteSubGraphLayers.end(), [&](Layer* layer)
+ {
+ BOOST_ASSERT_MSG(std::find(m_Layers.begin(), m_Layers.end(), layer) != m_Layers.end(),
+ "Substitute layer is not a member of graph");
+ });
const SubGraph::InputSlots& subGraphInputSlots = subGraph.GetInputSlots();
const SubGraph::OutputSlots& subGraphOutputSlots = subGraph.GetOutputSlots();
- const unsigned int numInputSlots = boost::numeric_cast<unsigned int>(subGraphInputSlots.size());
- const unsigned int numOutputSlots = boost::numeric_cast<unsigned int>(subGraphOutputSlots.size());
+ unsigned int subGraphNumInputSlots = boost::numeric_cast<unsigned int>(subGraphInputSlots.size());
+ unsigned int subGraphNumOutputSlots = boost::numeric_cast<unsigned int>(subGraphOutputSlots.size());
+
+ const SubGraph::InputSlots& substituteSubGraphInputSlots = substituteSubGraph.GetInputSlots();
+ const SubGraph::OutputSlots& substituteSubGraphOutputSlots = substituteSubGraph.GetOutputSlots();
+
+ BOOST_ASSERT(subGraphNumInputSlots == substituteSubGraphInputSlots.size());
+ BOOST_ASSERT(subGraphNumOutputSlots == substituteSubGraphOutputSlots.size());
- BOOST_ASSERT(numInputSlots == substituteLayer->GetNumInputSlots());
- BOOST_ASSERT(numOutputSlots == substituteLayer->GetNumOutputSlots());
+ // Disconnect the sub-graph and replace it with the substitute sub-graph
- // Disconnect the sub-graph and replace it with the substitute layer
// Step 1: process input slots
- for(unsigned int inputSlotIdx = 0u; inputSlotIdx < numInputSlots; ++inputSlotIdx)
+ for (unsigned int inputSlotIdx = 0; inputSlotIdx < subGraphNumInputSlots; ++inputSlotIdx)
{
InputSlot* subGraphInputSlot = subGraphInputSlots.at(inputSlotIdx);
- BOOST_ASSERT(subGraphInputSlot != nullptr);
+ BOOST_ASSERT(subGraphInputSlot);
IOutputSlot* connectedOutputSlot = subGraphInputSlot->GetConnection();
- BOOST_ASSERT(connectedOutputSlot != nullptr);
+ BOOST_ASSERT(connectedOutputSlot);
connectedOutputSlot->Disconnect(*subGraphInputSlot);
- IInputSlot& substituteInputSlot = substituteLayer->GetInputSlot(inputSlotIdx);
- connectedOutputSlot->Connect(substituteInputSlot);
+ IInputSlot* substituteInputSlot = substituteSubGraphInputSlots.at(inputSlotIdx);
+ BOOST_ASSERT(substituteInputSlot);
+ connectedOutputSlot->Connect(*substituteInputSlot);
}
// Step 2: process output slots
- for(unsigned int outputSlotIdx = 0u; outputSlotIdx < numOutputSlots; ++outputSlotIdx)
+ for(unsigned int outputSlotIdx = 0; outputSlotIdx < subGraphNumOutputSlots; ++outputSlotIdx)
{
OutputSlot* subGraphOutputSlot = subGraphOutputSlots.at(outputSlotIdx);
- BOOST_ASSERT(subGraphOutputSlot != nullptr);
+ BOOST_ASSERT(subGraphOutputSlot);
- OutputSlot* substituteOutputSlot = boost::polymorphic_downcast<OutputSlot*>(
- &substituteLayer->GetOutputSlot(outputSlotIdx));
+ OutputSlot* substituteOutputSlot = substituteSubGraphOutputSlots.at(outputSlotIdx);
+ BOOST_ASSERT(substituteOutputSlot);
subGraphOutputSlot->MoveAllConnections(*substituteOutputSlot);
}
}