diff options
author | Derek Lamberti <derek.lamberti@arm.com> | 2019-05-07 21:33:30 +0100 |
---|---|---|
committer | Derek Lamberti <derek.lamberti@arm.com> | 2019-05-09 10:41:54 +0100 |
commit | 8106b7cc67b0fe2c91eaf2ff129200529ae8c31f (patch) | |
tree | 961a0d49e96fc9688233d9f9a4166f57069c60d0 | |
parent | f92dfced4498f12b9315c0fa377ba7be8998b607 (diff) | |
download | armnn-8106b7cc67b0fe2c91eaf2ff129200529ae8c31f.tar.gz |
IVGCVSW-3031 Reparent layer to new graph
Change-Id: Ic4423b8d21d794f44ddae291853e0e3b89d11bc0
Signed-off-by: Derek Lamberti <derek.lamberti@arm.com>
-rw-r--r-- | src/armnn/Graph.hpp | 38 | ||||
-rw-r--r-- | src/armnn/Layer.hpp | 1 | ||||
-rw-r--r-- | src/backends/backendsCommon/IBackendInternal.hpp | 2 | ||||
-rw-r--r-- | src/backends/backendsCommon/OptimizationViews.hpp | 4 |
4 files changed, 36 insertions, 9 deletions
diff --git a/src/armnn/Graph.hpp b/src/armnn/Graph.hpp index 88d2002112..062d727fd1 100644 --- a/src/armnn/Graph.hpp +++ b/src/armnn/Graph.hpp @@ -235,18 +235,40 @@ class Graph::LayerInGraphBase : public LayerT protected: template <typename... Args> LayerInGraphBase(Graph& graph, Iterator insertBefore, Args&&... args) - : LayerT(std::forward<Args>(args)...), m_Graph(graph) + : LayerT(std::forward<Args>(args)...), m_Graph(&graph) { - m_Graph.m_PosInGraphMap.emplace(this, m_Graph.m_Layers.emplace(insertBefore, this)); + Insert(*m_Graph, insertBefore); } ~LayerInGraphBase() { - const size_t numErased = m_Graph.m_PosInGraphMap.erase(this); + Remove(*m_Graph); + } + + void Reparent(Graph& destGraph, Iterator insertBefore) override + { + Insert(destGraph, insertBefore); + Remove(*m_Graph); + m_Graph->m_Layers.erase(m_Graph->GetPosInGraph(*this)); + + m_Graph = &destGraph; + } + +private: + void Insert(Graph& graph, Iterator insertBefore) + { + graph.m_PosInGraphMap.emplace(this, graph.m_Layers.emplace(insertBefore, this)); + } + + void Remove(Graph& graph) + { + const size_t numErased = graph.m_PosInGraphMap.erase(this); boost::ignore_unused(numErased); BOOST_ASSERT(numErased == 1); } - Graph& m_Graph; +protected: + Graph* m_Graph; + }; /// Input/Output layers specialize this template. @@ -284,7 +306,7 @@ public: std::next(graph.begin(), IteratorDifference(graph.GetNumInputs())), std::forward<Args>(args)...) { - const bool isNewId = m_Graph.m_InputIds.emplace(GetBindingId()).second; + const bool isNewId = m_Graph->m_InputIds.emplace(GetBindingId()).second; if (!isNewId) { throw InvalidArgumentException("A layer already exists with the specified id"); @@ -298,7 +320,7 @@ public: } ~LayerInGraph() override { - const size_t numErased = m_Graph.m_InputIds.erase(GetBindingId()); + const size_t numErased = m_Graph->m_InputIds.erase(GetBindingId()); boost::ignore_unused(numErased); BOOST_ASSERT(numErased == 1); } @@ -316,7 +338,7 @@ public: graph.end(), std::forward<Args>(args)...) { - const bool isNewId = m_Graph.m_OutputIds.emplace(GetBindingId()).second; + const bool isNewId = m_Graph->m_OutputIds.emplace(GetBindingId()).second; if (!isNewId) { throw InvalidArgumentException("A layer already exists with the specified id"); @@ -324,7 +346,7 @@ public: } ~LayerInGraph() override { - const size_t numErased = m_Graph.m_OutputIds.erase(GetBindingId()); + const size_t numErased = m_Graph->m_OutputIds.erase(GetBindingId()); boost::ignore_unused(numErased); BOOST_ASSERT(numErased == 1); } diff --git a/src/armnn/Layer.hpp b/src/armnn/Layer.hpp index 507b37bf95..cbb1771668 100644 --- a/src/armnn/Layer.hpp +++ b/src/armnn/Layer.hpp @@ -298,6 +298,7 @@ public: const std::list<std::string>& GetRelatedLayerNames() { return m_RelatedLayerNames; } + virtual void Reparent(Graph& dest, std::list<Layer*>::const_iterator iterator) = 0; protected: // Graph needs access to the virtual destructor. friend class Graph; diff --git a/src/backends/backendsCommon/IBackendInternal.hpp b/src/backends/backendsCommon/IBackendInternal.hpp index f49a210988..5316f68009 100644 --- a/src/backends/backendsCommon/IBackendInternal.hpp +++ b/src/backends/backendsCommon/IBackendInternal.hpp @@ -65,7 +65,7 @@ public: // Default implementation of OptimizeSubgraphView for backward compatibility with old API. // Override this method with a custom optimization implementation. - virtual OptimizationViews OptimizeSubgraphView(const SubgraphView& subgraph) + virtual OptimizationViews OptimizeSubgraphView(const SubgraphView& subgraph) const { bool attempted=false; SubgraphViewUniquePtr optSubgraph = OptimizeSubgraphView(subgraph, attempted); diff --git a/src/backends/backendsCommon/OptimizationViews.hpp b/src/backends/backendsCommon/OptimizationViews.hpp index cf7051d887..e1b59ed633 100644 --- a/src/backends/backendsCommon/OptimizationViews.hpp +++ b/src/backends/backendsCommon/OptimizationViews.hpp @@ -45,9 +45,13 @@ public: Subgraphs GetUntouchedSubgraphs() const { return m_UntouchedSubgraphs; } bool Validate(const SubgraphView& originalSubgraph) const; + Graph& GetGraph() { return m_Graph; }; + private: Substitutions m_SuccesfulOptimizations; ///< Proposed substitutions from successful optimizations Subgraphs m_FailedOptimizations; ///< Subgraphs from the original subgraph which cannot be supported Subgraphs m_UntouchedSubgraphs; ///< Subgraphs from the original subgraph which remain unmodified + + Graph m_Graph; }; } //namespace armnn
\ No newline at end of file |