aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/SubGraph.cpp
diff options
context:
space:
mode:
authorMatteo Martincigh <matteo.martincigh@arm.com>2019-01-24 14:06:23 +0000
committerMatteo Martincigh <matteo.martincigh@arm.com>2019-01-30 14:03:28 +0000
commitadddddb6cbcb777d92a8c464c9ad0cb9aecc76a3 (patch)
treeb15de32bf9f8612f66e1ae23d2f8009e80e7d0e6 /src/armnn/SubGraph.cpp
parentd089b74bebbcc8518fb0f4eacb7e6569ae170199 (diff)
downloadarmnn-adddddb6cbcb777d92a8c464c9ad0cb9aecc76a3.tar.gz
IVGCVSW-2458 Refactor the Optimize function (Network.cpp) so that
subgraphs are optimized by the backends * Added a new method OptimizeSubGraph to the backend interface * Refactored the Optimize function so that the backend-specific optimization is performed by the backend itself (through the new OptimizeSubGraph interface method) * Added a new ApplyBackendOptimizations function to apply the new changes * Added some new convenient constructors to the SubGraph class * Added AddLayer method and a pointer to the parent graph to the SubGraph class * Updated the sub-graph unit tests to match the changes * Added SelectSubGraphs and ReplaceSubGraphConnections overloads that work with sub-graphs * Removed unused code and minor refactoring where necessary Change-Id: I46181794c6a9e3b10558944f804e06a8f693a6d0
Diffstat (limited to 'src/armnn/SubGraph.cpp')
-rw-r--r--src/armnn/SubGraph.cpp144
1 files changed, 129 insertions, 15 deletions
diff --git a/src/armnn/SubGraph.cpp b/src/armnn/SubGraph.cpp
index 74a1838ef0..d0fc760c15 100644
--- a/src/armnn/SubGraph.cpp
+++ b/src/armnn/SubGraph.cpp
@@ -3,33 +3,147 @@
// SPDX-License-Identifier: MIT
//
-#include "Layer.hpp"
#include "SubGraph.hpp"
+#include "Graph.hpp"
#include <boost/numeric/conversion/cast.hpp>
+#include <utility>
+
namespace armnn
{
-SubGraph::SubGraph()
+namespace
+{
+
+template <class C>
+void AssertIfNullsOrDuplicates(const C& container, const std::string& errorMessage)
+{
+ using T = typename C::value_type;
+ std::unordered_set<T> duplicateSet;
+ std::for_each(container.begin(), container.end(), [&duplicateSet, &errorMessage](const T& i)
+ {
+ // Ignore unused for release builds
+ boost::ignore_unused(errorMessage);
+
+ // Check if the item is valid
+ BOOST_ASSERT_MSG(i, errorMessage.c_str());
+
+ // Check if a duplicate has been found
+ BOOST_ASSERT_MSG(duplicateSet.find(i) == duplicateSet.end(), errorMessage.c_str());
+
+ duplicateSet.insert(i);
+ });
+}
+
+} // anonymous namespace
+
+SubGraph::SubGraph(Graph& graph)
+ : m_InputSlots{}
+ , m_OutputSlots{}
+ , m_Layers(graph.begin(), graph.end())
+ , m_ParentGraph(&graph)
+{
+ CheckSubGraph();
+}
+
+SubGraph::SubGraph(Graph* parentGraph, InputSlots&& inputs, OutputSlots&& outputs, Layers&& layers)
+ : m_InputSlots{inputs}
+ , m_OutputSlots{outputs}
+ , m_Layers{layers}
+ , m_ParentGraph(parentGraph)
{
+ CheckSubGraph();
}
-SubGraph::SubGraph(InputSlots && inputs,
- OutputSlots && outputs,
- Layers && layers)
-: m_InputSlots{inputs}
-, m_OutputSlots{outputs}
-, m_Layers{layers}
+SubGraph::SubGraph(const SubGraph& referenceSubGraph, InputSlots&& inputs, OutputSlots&& outputs, Layers&& layers)
+ : m_InputSlots{inputs}
+ , m_OutputSlots{outputs}
+ , m_Layers{layers}
+ , m_ParentGraph(referenceSubGraph.m_ParentGraph)
{
+ CheckSubGraph();
+}
+
+SubGraph::SubGraph(const SubGraph& subGraph)
+ : m_InputSlots(subGraph.m_InputSlots.begin(), subGraph.m_InputSlots.end())
+ , m_OutputSlots(subGraph.m_OutputSlots.begin(), subGraph.m_OutputSlots.end())
+ , m_Layers(subGraph.m_Layers.begin(), subGraph.m_Layers.end())
+ , m_ParentGraph(subGraph.m_ParentGraph)
+{
+ CheckSubGraph();
+}
+
+SubGraph::SubGraph(SubGraph&& subGraph)
+ : m_InputSlots(std::move(subGraph.m_InputSlots))
+ , m_OutputSlots(std::move(subGraph.m_OutputSlots))
+ , m_Layers(std::move(subGraph.m_Layers))
+ , m_ParentGraph(std::exchange(subGraph.m_ParentGraph, nullptr))
+{
+ CheckSubGraph();
+}
+
+SubGraph::SubGraph(const SubGraph& referenceSubGraph, IConnectableLayer* layer)
+ : m_InputSlots{}
+ , m_OutputSlots{}
+ , m_Layers{boost::polymorphic_downcast<Layer*>(layer)}
+ , m_ParentGraph(referenceSubGraph.m_ParentGraph)
+{
+ unsigned int numInputSlots = layer->GetNumInputSlots();
+ m_InputSlots.resize(numInputSlots);
+ for (unsigned int i = 0; i < numInputSlots; i++)
+ {
+ m_InputSlots.at(i) = boost::polymorphic_downcast<InputSlot*>(&(layer->GetInputSlot(i)));
+ }
+
+ unsigned int numOutputSlots = layer->GetNumOutputSlots();
+ m_OutputSlots.resize(numOutputSlots);
+ for (unsigned int i = 0; i < numOutputSlots; i++)
+ {
+ m_OutputSlots.at(i) = boost::polymorphic_downcast<OutputSlot*>(&(layer->GetOutputSlot(i)));
+ }
+
+ CheckSubGraph();
+}
+
+void SubGraph::CheckSubGraph()
+{
+ // Check that the sub-graph has a valid parent graph
+ BOOST_ASSERT_MSG(m_ParentGraph, "Sub-graphs must have a parent graph");
+
+ // Check for invalid or duplicate input slots
+ AssertIfNullsOrDuplicates(m_InputSlots, "Sub-graphs cannot contain null or duplicate input slots");
+
+ // Check for invalid or duplicate output slots
+ AssertIfNullsOrDuplicates(m_OutputSlots, "Sub-graphs cannot contain null or duplicate output slots");
+
+ // Check for invalid or duplicate layers
+ AssertIfNullsOrDuplicates(m_Layers, "Sub-graphs cannot contain null or duplicate layers");
+
+ // Check that all the layers of the sub-graph belong to the parent graph
+ std::for_each(m_Layers.begin(), m_Layers.end(), [&](const Layer* l)
+ {
+ BOOST_ASSERT_MSG(std::find(m_ParentGraph->begin(), m_ParentGraph->end(), l) != m_ParentGraph->end(),
+ "Sub-graph layer is not a member of the parent graph");
+ });
+}
+
+void SubGraph::Update(Graph &graph)
+{
+ m_InputSlots.clear();
+ m_OutputSlots.clear();
+ m_Layers.assign(graph.begin(), graph.end());
+ m_ParentGraph = &graph;
+
+ CheckSubGraph();
}
-const SubGraph::InputSlots & SubGraph::GetInputSlots() const
+const SubGraph::InputSlots& SubGraph::GetInputSlots() const
{
return m_InputSlots;
}
-const SubGraph::OutputSlots & SubGraph::GetOutputSlots() const
+const SubGraph::OutputSlots& SubGraph::GetOutputSlots() const
{
return m_OutputSlots;
}
@@ -74,27 +188,27 @@ SubGraph::Layers::iterator SubGraph::begin()
return m_Layers.begin();
}
-SubGraph::Layers::iterator SubGraph::end()
+SubGraph::Iterator SubGraph::end()
{
return m_Layers.end();
}
-SubGraph::Layers::const_iterator SubGraph::begin() const
+SubGraph::ConstIterator SubGraph::begin() const
{
return m_Layers.begin();
}
-SubGraph::Layers::const_iterator SubGraph::end() const
+SubGraph::ConstIterator SubGraph::end() const
{
return m_Layers.end();
}
-SubGraph::Layers::const_iterator SubGraph::cbegin() const
+SubGraph::ConstIterator SubGraph::cbegin() const
{
return begin();
}
-SubGraph::Layers::const_iterator SubGraph::cend() const
+SubGraph::ConstIterator SubGraph::cend() const
{
return end();
}