aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/Graph.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/Graph.hpp')
-rw-r--r--src/armnn/Graph.hpp38
1 files changed, 30 insertions, 8 deletions
diff --git a/src/armnn/Graph.hpp b/src/armnn/Graph.hpp
index 88d2002112..062d727fd1 100644
--- a/src/armnn/Graph.hpp
+++ b/src/armnn/Graph.hpp
@@ -235,18 +235,40 @@ class Graph::LayerInGraphBase : public LayerT
protected:
template <typename... Args>
LayerInGraphBase(Graph& graph, Iterator insertBefore, Args&&... args)
- : LayerT(std::forward<Args>(args)...), m_Graph(graph)
+ : LayerT(std::forward<Args>(args)...), m_Graph(&graph)
{
- m_Graph.m_PosInGraphMap.emplace(this, m_Graph.m_Layers.emplace(insertBefore, this));
+ Insert(*m_Graph, insertBefore);
}
~LayerInGraphBase()
{
- const size_t numErased = m_Graph.m_PosInGraphMap.erase(this);
+ Remove(*m_Graph);
+ }
+
+ void Reparent(Graph& destGraph, Iterator insertBefore) override
+ {
+ Insert(destGraph, insertBefore);
+ Remove(*m_Graph);
+ m_Graph->m_Layers.erase(m_Graph->GetPosInGraph(*this));
+
+ m_Graph = &destGraph;
+ }
+
+private:
+ void Insert(Graph& graph, Iterator insertBefore)
+ {
+ graph.m_PosInGraphMap.emplace(this, graph.m_Layers.emplace(insertBefore, this));
+ }
+
+ void Remove(Graph& graph)
+ {
+ const size_t numErased = graph.m_PosInGraphMap.erase(this);
boost::ignore_unused(numErased);
BOOST_ASSERT(numErased == 1);
}
- Graph& m_Graph;
+protected:
+ Graph* m_Graph;
+
};
/// Input/Output layers specialize this template.
@@ -284,7 +306,7 @@ public:
std::next(graph.begin(), IteratorDifference(graph.GetNumInputs())),
std::forward<Args>(args)...)
{
- const bool isNewId = m_Graph.m_InputIds.emplace(GetBindingId()).second;
+ const bool isNewId = m_Graph->m_InputIds.emplace(GetBindingId()).second;
if (!isNewId)
{
throw InvalidArgumentException("A layer already exists with the specified id");
@@ -298,7 +320,7 @@ public:
}
~LayerInGraph() override
{
- const size_t numErased = m_Graph.m_InputIds.erase(GetBindingId());
+ const size_t numErased = m_Graph->m_InputIds.erase(GetBindingId());
boost::ignore_unused(numErased);
BOOST_ASSERT(numErased == 1);
}
@@ -316,7 +338,7 @@ public:
graph.end(),
std::forward<Args>(args)...)
{
- const bool isNewId = m_Graph.m_OutputIds.emplace(GetBindingId()).second;
+ const bool isNewId = m_Graph->m_OutputIds.emplace(GetBindingId()).second;
if (!isNewId)
{
throw InvalidArgumentException("A layer already exists with the specified id");
@@ -324,7 +346,7 @@ public:
}
~LayerInGraph() override
{
- const size_t numErased = m_Graph.m_OutputIds.erase(GetBindingId());
+ const size_t numErased = m_Graph->m_OutputIds.erase(GetBindingId());
boost::ignore_unused(numErased);
BOOST_ASSERT(numErased == 1);
}