aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMatteo Martincigh <matteo.martincigh@arm.com>2019-05-09 19:06:22 +0100
committerMatteo Martincigh <matteo.martincigh@arm.com>2019-05-10 11:04:10 +0100
commitf3d102114a6f837f40400c4de50915abc488f3a5 (patch)
treef524cdd3cd40f5d8df321a935988ed80c9f4a66b
parent724e48013142562b7f09c9c819f57c314c4ee3d4 (diff)
downloadarmnn-f3d102114a6f837f40400c4de50915abc488f3a5.tar.gz
IVGCVSW-3030 Added move operators to the Graph class
* Updated the LayerInGraph class to properly support the new Reparent operation * Improved the Graph class destruction process to take into account eventual reparent layer operations * Added new ForEachLayerInGraph utility function to safely loop through all the layers in the graph Change-Id: Ie67cbdee0c3c8625662ebfa00f860ae0d2fac59c Signed-off-by: Matteo Martincigh <matteo.martincigh@arm.com>
-rw-r--r--src/armnn/Graph.hpp68
-rw-r--r--src/armnn/Optimizer.cpp4
2 files changed, 51 insertions, 21 deletions
diff --git a/src/armnn/Graph.hpp b/src/armnn/Graph.hpp
index 5c71ccef2b..c5b1b045a7 100644
--- a/src/armnn/Graph.hpp
+++ b/src/armnn/Graph.hpp
@@ -35,6 +35,17 @@ public:
return boost::polymorphic_downcast<LayerType*>(layer);
}
+ template <typename Func>
+ void ForEachLayerInGraph(Func func)
+ {
+ for (auto it = m_Layers.begin(); it != m_Layers.end(); )
+ {
+ auto next = std::next(it);
+ func(*it);
+ it = next;
+ }
+ }
+
using LayerList = std::list<Layer*>;
using Iterator = LayerList::const_iterator; // Const so pointers in the list can't be modified externally.
using IteratorDifference = Iterator::difference_type;
@@ -87,15 +98,35 @@ public:
Graph& operator=(const Graph& other) = delete;
- Graph(Graph&&) = default;
- Graph& operator=(Graph&&) = default;
+ Graph(Graph&& other)
+ {
+ *this = std::move(other);
+ }
+
+ Graph& operator=(Graph&& other)
+ {
+ m_InputIds = std::move(other.m_InputIds);
+ m_OutputIds = std::move(other.m_OutputIds);
+ m_LayersInOrder = std::move(other.m_LayersInOrder);
+ m_Views = std::move(other.m_Views);
+
+ other.ForEachLayerInGraph([this](Layer* otherLayer)
+ {
+ otherLayer->Reparent(*this, m_Layers.end());
+ });
+
+ BOOST_ASSERT(other.m_PosInGraphMap.empty());
+ BOOST_ASSERT(other.m_Layers.empty());
+
+ return *this;
+ }
~Graph()
{
- for (auto&& layer : m_Layers)
+ ForEachLayerInGraph([](Layer* layer)
{
delete layer;
- }
+ });
}
Status Print() const;
@@ -115,15 +146,13 @@ public:
template <typename LayerT, typename... Args>
LayerT* InsertNewLayer(OutputSlot& insertAfter, Args&&... args);
- /// Deletes the layer at the specified position and returns an iterator pointing
- /// to the next element after the one being deleted.
- Iterator EraseLayer(Iterator pos);
+ /// Deletes the layer at the specified position.
+ void EraseLayer(Iterator pos);
- /// Deletes the layer and returns an iterator pointing to the next layer in the graph
- /// (next in the list, after the one being deleted). Sets @a layer to nullptr on return.
+ /// Deletes the layer. Sets @a layer to nullptr on return.
/// Templated to support pointers to any layer type.
template <typename LayerT>
- Iterator EraseLayer(LayerT*& layer);
+ void EraseLayer(LayerT*& layer);
/// Returns iterator pointing to the beginning of the list. Lowercase for range-based for loops.
Iterator begin() { return m_Layers.begin(); }
@@ -179,6 +208,9 @@ public:
m_Views[notifyOnEvent].remove(observable);
}
+ /// Gets the position of a layer in the graph.
+ Iterator GetPosInGraph(Layer& layer);
+
private:
template <typename LayerT>
class LayerInGraphBase;
@@ -204,9 +236,6 @@ private:
return it;
}
- /// Gets the position of a layer in the graph.
- Iterator GetPosInGraph(Layer& layer);
-
void NotifyObservables(GraphEvent event, Layer* graphState)
{
// Iterate over all observables observing this event
@@ -251,7 +280,6 @@ protected:
{
Insert(destGraph, insertBefore);
Remove(*m_Graph);
- m_Graph->m_Layers.erase(m_Graph->GetPosInGraph(*this));
m_Graph = &destGraph;
}
@@ -264,6 +292,9 @@ private:
void Remove(Graph& graph)
{
+ auto layerIt = graph.GetPosInGraph(*this);
+ graph.m_Layers.erase(layerIt);
+
const size_t numErased = graph.m_PosInGraphMap.erase(this);
boost::ignore_unused(numErased);
BOOST_ASSERT(numErased == 1);
@@ -271,7 +302,6 @@ private:
protected:
Graph* m_Graph;
-
};
/// Input/Output layers specialize this template.
@@ -408,21 +438,19 @@ inline LayerT* Graph::InsertNewLayer(OutputSlot& insertAfter, Args&&... args)
return layer;
}
-inline Graph::Iterator Graph::EraseLayer(Iterator pos)
+inline void Graph::EraseLayer(Iterator pos)
{
NotifyObservables(GraphEvent::LayerErased, *pos);
delete *pos;
- return m_Layers.erase(pos);
}
template <typename LayerT>
-inline Graph::Iterator Graph::EraseLayer(LayerT*& layer)
+inline void Graph::EraseLayer(LayerT*& layer)
{
BOOST_ASSERT(layer != nullptr);
- Iterator next = EraseLayer(GetPosInGraph(*layer));
+ EraseLayer(GetPosInGraph(*layer));
layer = nullptr;
- return next;
}
} // namespace armnn
diff --git a/src/armnn/Optimizer.cpp b/src/armnn/Optimizer.cpp
index 5e50c01c09..4d0aae8491 100644
--- a/src/armnn/Optimizer.cpp
+++ b/src/armnn/Optimizer.cpp
@@ -32,7 +32,9 @@ void Optimizer::Pass(Graph& graph, const Optimizations& optimizations)
if ((*it)->IsOutputUnconnected())
{
- it = graph.EraseLayer(it);
+ auto next = std::next(graph.GetPosInGraph(**it));
+ graph.EraseLayer(it);
+ it = next;
graphNeedsSorting = true;
}