diff options
Diffstat (limited to 'src/armnn/Graph.hpp')
-rw-r--r-- | src/armnn/Graph.hpp | 38 |
1 files changed, 30 insertions, 8 deletions
diff --git a/src/armnn/Graph.hpp b/src/armnn/Graph.hpp index 88d2002112..062d727fd1 100644 --- a/src/armnn/Graph.hpp +++ b/src/armnn/Graph.hpp @@ -235,18 +235,40 @@ class Graph::LayerInGraphBase : public LayerT protected: template <typename... Args> LayerInGraphBase(Graph& graph, Iterator insertBefore, Args&&... args) - : LayerT(std::forward<Args>(args)...), m_Graph(graph) + : LayerT(std::forward<Args>(args)...), m_Graph(&graph) { - m_Graph.m_PosInGraphMap.emplace(this, m_Graph.m_Layers.emplace(insertBefore, this)); + Insert(*m_Graph, insertBefore); } ~LayerInGraphBase() { - const size_t numErased = m_Graph.m_PosInGraphMap.erase(this); + Remove(*m_Graph); + } + + void Reparent(Graph& destGraph, Iterator insertBefore) override + { + Insert(destGraph, insertBefore); + Remove(*m_Graph); + m_Graph->m_Layers.erase(m_Graph->GetPosInGraph(*this)); + + m_Graph = &destGraph; + } + +private: + void Insert(Graph& graph, Iterator insertBefore) + { + graph.m_PosInGraphMap.emplace(this, graph.m_Layers.emplace(insertBefore, this)); + } + + void Remove(Graph& graph) + { + const size_t numErased = graph.m_PosInGraphMap.erase(this); boost::ignore_unused(numErased); BOOST_ASSERT(numErased == 1); } - Graph& m_Graph; +protected: + Graph* m_Graph; + }; /// Input/Output layers specialize this template. @@ -284,7 +306,7 @@ public: std::next(graph.begin(), IteratorDifference(graph.GetNumInputs())), std::forward<Args>(args)...) { - const bool isNewId = m_Graph.m_InputIds.emplace(GetBindingId()).second; + const bool isNewId = m_Graph->m_InputIds.emplace(GetBindingId()).second; if (!isNewId) { throw InvalidArgumentException("A layer already exists with the specified id"); @@ -298,7 +320,7 @@ public: } ~LayerInGraph() override { - const size_t numErased = m_Graph.m_InputIds.erase(GetBindingId()); + const size_t numErased = m_Graph->m_InputIds.erase(GetBindingId()); boost::ignore_unused(numErased); BOOST_ASSERT(numErased == 1); } @@ -316,7 +338,7 @@ public: graph.end(), std::forward<Args>(args)...) { - const bool isNewId = m_Graph.m_OutputIds.emplace(GetBindingId()).second; + const bool isNewId = m_Graph->m_OutputIds.emplace(GetBindingId()).second; if (!isNewId) { throw InvalidArgumentException("A layer already exists with the specified id"); @@ -324,7 +346,7 @@ public: } ~LayerInGraph() override { - const size_t numErased = m_Graph.m_OutputIds.erase(GetBindingId()); + const size_t numErased = m_Graph->m_OutputIds.erase(GetBindingId()); boost::ignore_unused(numErased); BOOST_ASSERT(numErased == 1); } |