From 8106b7cc67b0fe2c91eaf2ff129200529ae8c31f Mon Sep 17 00:00:00 2001 From: Derek Lamberti Date: Tue, 7 May 2019 21:33:30 +0100 Subject: IVGCVSW-3031 Reparent layer to new graph Change-Id: Ic4423b8d21d794f44ddae291853e0e3b89d11bc0 Signed-off-by: Derek Lamberti --- src/armnn/Graph.hpp | 38 ++++++++++++++++++----- src/armnn/Layer.hpp | 1 + src/backends/backendsCommon/IBackendInternal.hpp | 2 +- src/backends/backendsCommon/OptimizationViews.hpp | 4 +++ 4 files changed, 36 insertions(+), 9 deletions(-) diff --git a/src/armnn/Graph.hpp b/src/armnn/Graph.hpp index 88d2002112..062d727fd1 100644 --- a/src/armnn/Graph.hpp +++ b/src/armnn/Graph.hpp @@ -235,18 +235,40 @@ class Graph::LayerInGraphBase : public LayerT protected: template LayerInGraphBase(Graph& graph, Iterator insertBefore, Args&&... args) - : LayerT(std::forward(args)...), m_Graph(graph) + : LayerT(std::forward(args)...), m_Graph(&graph) { - m_Graph.m_PosInGraphMap.emplace(this, m_Graph.m_Layers.emplace(insertBefore, this)); + Insert(*m_Graph, insertBefore); } ~LayerInGraphBase() { - const size_t numErased = m_Graph.m_PosInGraphMap.erase(this); + Remove(*m_Graph); + } + + void Reparent(Graph& destGraph, Iterator insertBefore) override + { + Insert(destGraph, insertBefore); + Remove(*m_Graph); + m_Graph->m_Layers.erase(m_Graph->GetPosInGraph(*this)); + + m_Graph = &destGraph; + } + +private: + void Insert(Graph& graph, Iterator insertBefore) + { + graph.m_PosInGraphMap.emplace(this, graph.m_Layers.emplace(insertBefore, this)); + } + + void Remove(Graph& graph) + { + const size_t numErased = graph.m_PosInGraphMap.erase(this); boost::ignore_unused(numErased); BOOST_ASSERT(numErased == 1); } - Graph& m_Graph; +protected: + Graph* m_Graph; + }; /// Input/Output layers specialize this template. @@ -284,7 +306,7 @@ public: std::next(graph.begin(), IteratorDifference(graph.GetNumInputs())), std::forward(args)...) { - const bool isNewId = m_Graph.m_InputIds.emplace(GetBindingId()).second; + const bool isNewId = m_Graph->m_InputIds.emplace(GetBindingId()).second; if (!isNewId) { throw InvalidArgumentException("A layer already exists with the specified id"); @@ -298,7 +320,7 @@ public: } ~LayerInGraph() override { - const size_t numErased = m_Graph.m_InputIds.erase(GetBindingId()); + const size_t numErased = m_Graph->m_InputIds.erase(GetBindingId()); boost::ignore_unused(numErased); BOOST_ASSERT(numErased == 1); } @@ -316,7 +338,7 @@ public: graph.end(), std::forward(args)...) { - const bool isNewId = m_Graph.m_OutputIds.emplace(GetBindingId()).second; + const bool isNewId = m_Graph->m_OutputIds.emplace(GetBindingId()).second; if (!isNewId) { throw InvalidArgumentException("A layer already exists with the specified id"); @@ -324,7 +346,7 @@ public: } ~LayerInGraph() override { - const size_t numErased = m_Graph.m_OutputIds.erase(GetBindingId()); + const size_t numErased = m_Graph->m_OutputIds.erase(GetBindingId()); boost::ignore_unused(numErased); BOOST_ASSERT(numErased == 1); } diff --git a/src/armnn/Layer.hpp b/src/armnn/Layer.hpp index 507b37bf95..cbb1771668 100644 --- a/src/armnn/Layer.hpp +++ b/src/armnn/Layer.hpp @@ -298,6 +298,7 @@ public: const std::list& GetRelatedLayerNames() { return m_RelatedLayerNames; } + virtual void Reparent(Graph& dest, std::list::const_iterator iterator) = 0; protected: // Graph needs access to the virtual destructor. friend class Graph; diff --git a/src/backends/backendsCommon/IBackendInternal.hpp b/src/backends/backendsCommon/IBackendInternal.hpp index f49a210988..5316f68009 100644 --- a/src/backends/backendsCommon/IBackendInternal.hpp +++ b/src/backends/backendsCommon/IBackendInternal.hpp @@ -65,7 +65,7 @@ public: // Default implementation of OptimizeSubgraphView for backward compatibility with old API. // Override this method with a custom optimization implementation. - virtual OptimizationViews OptimizeSubgraphView(const SubgraphView& subgraph) + virtual OptimizationViews OptimizeSubgraphView(const SubgraphView& subgraph) const { bool attempted=false; SubgraphViewUniquePtr optSubgraph = OptimizeSubgraphView(subgraph, attempted); diff --git a/src/backends/backendsCommon/OptimizationViews.hpp b/src/backends/backendsCommon/OptimizationViews.hpp index cf7051d887..e1b59ed633 100644 --- a/src/backends/backendsCommon/OptimizationViews.hpp +++ b/src/backends/backendsCommon/OptimizationViews.hpp @@ -45,9 +45,13 @@ public: Subgraphs GetUntouchedSubgraphs() const { return m_UntouchedSubgraphs; } bool Validate(const SubgraphView& originalSubgraph) const; + Graph& GetGraph() { return m_Graph; }; + private: Substitutions m_SuccesfulOptimizations; ///< Proposed substitutions from successful optimizations Subgraphs m_FailedOptimizations; ///< Subgraphs from the original subgraph which cannot be supported Subgraphs m_UntouchedSubgraphs; ///< Subgraphs from the original subgraph which remain unmodified + + Graph m_Graph; }; } //namespace armnn \ No newline at end of file -- cgit v1.2.1