aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDerek Lamberti <derek.lamberti@arm.com>2019-05-07 21:33:30 +0100
committerDerek Lamberti <derek.lamberti@arm.com>2019-05-09 10:41:54 +0100
commit8106b7cc67b0fe2c91eaf2ff129200529ae8c31f (patch)
tree961a0d49e96fc9688233d9f9a4166f57069c60d0
parentf92dfced4498f12b9315c0fa377ba7be8998b607 (diff)
downloadarmnn-8106b7cc67b0fe2c91eaf2ff129200529ae8c31f.tar.gz
IVGCVSW-3031 Reparent layer to new graph
Change-Id: Ic4423b8d21d794f44ddae291853e0e3b89d11bc0 Signed-off-by: Derek Lamberti <derek.lamberti@arm.com>
-rw-r--r--src/armnn/Graph.hpp38
-rw-r--r--src/armnn/Layer.hpp1
-rw-r--r--src/backends/backendsCommon/IBackendInternal.hpp2
-rw-r--r--src/backends/backendsCommon/OptimizationViews.hpp4
4 files changed, 36 insertions, 9 deletions
diff --git a/src/armnn/Graph.hpp b/src/armnn/Graph.hpp
index 88d2002112..062d727fd1 100644
--- a/src/armnn/Graph.hpp
+++ b/src/armnn/Graph.hpp
@@ -235,18 +235,40 @@ class Graph::LayerInGraphBase : public LayerT
protected:
template <typename... Args>
LayerInGraphBase(Graph& graph, Iterator insertBefore, Args&&... args)
- : LayerT(std::forward<Args>(args)...), m_Graph(graph)
+ : LayerT(std::forward<Args>(args)...), m_Graph(&graph)
{
- m_Graph.m_PosInGraphMap.emplace(this, m_Graph.m_Layers.emplace(insertBefore, this));
+ Insert(*m_Graph, insertBefore);
}
~LayerInGraphBase()
{
- const size_t numErased = m_Graph.m_PosInGraphMap.erase(this);
+ Remove(*m_Graph);
+ }
+
+ void Reparent(Graph& destGraph, Iterator insertBefore) override
+ {
+ Insert(destGraph, insertBefore);
+ Remove(*m_Graph);
+ m_Graph->m_Layers.erase(m_Graph->GetPosInGraph(*this));
+
+ m_Graph = &destGraph;
+ }
+
+private:
+ void Insert(Graph& graph, Iterator insertBefore)
+ {
+ graph.m_PosInGraphMap.emplace(this, graph.m_Layers.emplace(insertBefore, this));
+ }
+
+ void Remove(Graph& graph)
+ {
+ const size_t numErased = graph.m_PosInGraphMap.erase(this);
boost::ignore_unused(numErased);
BOOST_ASSERT(numErased == 1);
}
- Graph& m_Graph;
+protected:
+ Graph* m_Graph;
+
};
/// Input/Output layers specialize this template.
@@ -284,7 +306,7 @@ public:
std::next(graph.begin(), IteratorDifference(graph.GetNumInputs())),
std::forward<Args>(args)...)
{
- const bool isNewId = m_Graph.m_InputIds.emplace(GetBindingId()).second;
+ const bool isNewId = m_Graph->m_InputIds.emplace(GetBindingId()).second;
if (!isNewId)
{
throw InvalidArgumentException("A layer already exists with the specified id");
@@ -298,7 +320,7 @@ public:
}
~LayerInGraph() override
{
- const size_t numErased = m_Graph.m_InputIds.erase(GetBindingId());
+ const size_t numErased = m_Graph->m_InputIds.erase(GetBindingId());
boost::ignore_unused(numErased);
BOOST_ASSERT(numErased == 1);
}
@@ -316,7 +338,7 @@ public:
graph.end(),
std::forward<Args>(args)...)
{
- const bool isNewId = m_Graph.m_OutputIds.emplace(GetBindingId()).second;
+ const bool isNewId = m_Graph->m_OutputIds.emplace(GetBindingId()).second;
if (!isNewId)
{
throw InvalidArgumentException("A layer already exists with the specified id");
@@ -324,7 +346,7 @@ public:
}
~LayerInGraph() override
{
- const size_t numErased = m_Graph.m_OutputIds.erase(GetBindingId());
+ const size_t numErased = m_Graph->m_OutputIds.erase(GetBindingId());
boost::ignore_unused(numErased);
BOOST_ASSERT(numErased == 1);
}
diff --git a/src/armnn/Layer.hpp b/src/armnn/Layer.hpp
index 507b37bf95..cbb1771668 100644
--- a/src/armnn/Layer.hpp
+++ b/src/armnn/Layer.hpp
@@ -298,6 +298,7 @@ public:
const std::list<std::string>& GetRelatedLayerNames() { return m_RelatedLayerNames; }
+ virtual void Reparent(Graph& dest, std::list<Layer*>::const_iterator iterator) = 0;
protected:
// Graph needs access to the virtual destructor.
friend class Graph;
diff --git a/src/backends/backendsCommon/IBackendInternal.hpp b/src/backends/backendsCommon/IBackendInternal.hpp
index f49a210988..5316f68009 100644
--- a/src/backends/backendsCommon/IBackendInternal.hpp
+++ b/src/backends/backendsCommon/IBackendInternal.hpp
@@ -65,7 +65,7 @@ public:
// Default implementation of OptimizeSubgraphView for backward compatibility with old API.
// Override this method with a custom optimization implementation.
- virtual OptimizationViews OptimizeSubgraphView(const SubgraphView& subgraph)
+ virtual OptimizationViews OptimizeSubgraphView(const SubgraphView& subgraph) const
{
bool attempted=false;
SubgraphViewUniquePtr optSubgraph = OptimizeSubgraphView(subgraph, attempted);
diff --git a/src/backends/backendsCommon/OptimizationViews.hpp b/src/backends/backendsCommon/OptimizationViews.hpp
index cf7051d887..e1b59ed633 100644
--- a/src/backends/backendsCommon/OptimizationViews.hpp
+++ b/src/backends/backendsCommon/OptimizationViews.hpp
@@ -45,9 +45,13 @@ public:
Subgraphs GetUntouchedSubgraphs() const { return m_UntouchedSubgraphs; }
bool Validate(const SubgraphView& originalSubgraph) const;
+ Graph& GetGraph() { return m_Graph; };
+
private:
Substitutions m_SuccesfulOptimizations; ///< Proposed substitutions from successful optimizations
Subgraphs m_FailedOptimizations; ///< Subgraphs from the original subgraph which cannot be supported
Subgraphs m_UntouchedSubgraphs; ///< Subgraphs from the original subgraph which remain unmodified
+
+ Graph m_Graph;
};
} //namespace armnn \ No newline at end of file