diff options
author | Matteo Martincigh <matteo.martincigh@arm.com> | 2019-05-09 19:06:22 +0100 |
---|---|---|
committer | Matteo Martincigh <matteo.martincigh@arm.com> | 2019-05-10 11:04:10 +0100 |
commit | f3d102114a6f837f40400c4de50915abc488f3a5 (patch) | |
tree | f524cdd3cd40f5d8df321a935988ed80c9f4a66b /src/armnn | |
parent | 724e48013142562b7f09c9c819f57c314c4ee3d4 (diff) | |
download | armnn-f3d102114a6f837f40400c4de50915abc488f3a5.tar.gz |
IVGCVSW-3030 Added move operators to the Graph class
* Updated the LayerInGraph class to properly support
the new Reparent operation
* Improved the Graph class destruction process to take into
account eventual reparent layer operations
* Added new ForEachLayerInGraph utility function to safely
loop through all the layers in the graph
Change-Id: Ie67cbdee0c3c8625662ebfa00f860ae0d2fac59c
Signed-off-by: Matteo Martincigh <matteo.martincigh@arm.com>
Diffstat (limited to 'src/armnn')
-rw-r--r-- | src/armnn/Graph.hpp | 68 | ||||
-rw-r--r-- | src/armnn/Optimizer.cpp | 4 |
2 files changed, 51 insertions, 21 deletions
diff --git a/src/armnn/Graph.hpp b/src/armnn/Graph.hpp index 5c71ccef2b..c5b1b045a7 100644 --- a/src/armnn/Graph.hpp +++ b/src/armnn/Graph.hpp @@ -35,6 +35,17 @@ public: return boost::polymorphic_downcast<LayerType*>(layer); } + template <typename Func> + void ForEachLayerInGraph(Func func) + { + for (auto it = m_Layers.begin(); it != m_Layers.end(); ) + { + auto next = std::next(it); + func(*it); + it = next; + } + } + using LayerList = std::list<Layer*>; using Iterator = LayerList::const_iterator; // Const so pointers in the list can't be modified externally. using IteratorDifference = Iterator::difference_type; @@ -87,15 +98,35 @@ public: Graph& operator=(const Graph& other) = delete; - Graph(Graph&&) = default; - Graph& operator=(Graph&&) = default; + Graph(Graph&& other) + { + *this = std::move(other); + } + + Graph& operator=(Graph&& other) + { + m_InputIds = std::move(other.m_InputIds); + m_OutputIds = std::move(other.m_OutputIds); + m_LayersInOrder = std::move(other.m_LayersInOrder); + m_Views = std::move(other.m_Views); + + other.ForEachLayerInGraph([this](Layer* otherLayer) + { + otherLayer->Reparent(*this, m_Layers.end()); + }); + + BOOST_ASSERT(other.m_PosInGraphMap.empty()); + BOOST_ASSERT(other.m_Layers.empty()); + + return *this; + } ~Graph() { - for (auto&& layer : m_Layers) + ForEachLayerInGraph([](Layer* layer) { delete layer; - } + }); } Status Print() const; @@ -115,15 +146,13 @@ public: template <typename LayerT, typename... Args> LayerT* InsertNewLayer(OutputSlot& insertAfter, Args&&... args); - /// Deletes the layer at the specified position and returns an iterator pointing - /// to the next element after the one being deleted. - Iterator EraseLayer(Iterator pos); + /// Deletes the layer at the specified position. + void EraseLayer(Iterator pos); - /// Deletes the layer and returns an iterator pointing to the next layer in the graph - /// (next in the list, after the one being deleted). Sets @a layer to nullptr on return. + /// Deletes the layer. Sets @a layer to nullptr on return. /// Templated to support pointers to any layer type. template <typename LayerT> - Iterator EraseLayer(LayerT*& layer); + void EraseLayer(LayerT*& layer); /// Returns iterator pointing to the beginning of the list. Lowercase for range-based for loops. Iterator begin() { return m_Layers.begin(); } @@ -179,6 +208,9 @@ public: m_Views[notifyOnEvent].remove(observable); } + /// Gets the position of a layer in the graph. + Iterator GetPosInGraph(Layer& layer); + private: template <typename LayerT> class LayerInGraphBase; @@ -204,9 +236,6 @@ private: return it; } - /// Gets the position of a layer in the graph. - Iterator GetPosInGraph(Layer& layer); - void NotifyObservables(GraphEvent event, Layer* graphState) { // Iterate over all observables observing this event @@ -251,7 +280,6 @@ protected: { Insert(destGraph, insertBefore); Remove(*m_Graph); - m_Graph->m_Layers.erase(m_Graph->GetPosInGraph(*this)); m_Graph = &destGraph; } @@ -264,6 +292,9 @@ private: void Remove(Graph& graph) { + auto layerIt = graph.GetPosInGraph(*this); + graph.m_Layers.erase(layerIt); + const size_t numErased = graph.m_PosInGraphMap.erase(this); boost::ignore_unused(numErased); BOOST_ASSERT(numErased == 1); @@ -271,7 +302,6 @@ private: protected: Graph* m_Graph; - }; /// Input/Output layers specialize this template. @@ -408,21 +438,19 @@ inline LayerT* Graph::InsertNewLayer(OutputSlot& insertAfter, Args&&... args) return layer; } -inline Graph::Iterator Graph::EraseLayer(Iterator pos) +inline void Graph::EraseLayer(Iterator pos) { NotifyObservables(GraphEvent::LayerErased, *pos); delete *pos; - return m_Layers.erase(pos); } template <typename LayerT> -inline Graph::Iterator Graph::EraseLayer(LayerT*& layer) +inline void Graph::EraseLayer(LayerT*& layer) { BOOST_ASSERT(layer != nullptr); - Iterator next = EraseLayer(GetPosInGraph(*layer)); + EraseLayer(GetPosInGraph(*layer)); layer = nullptr; - return next; } } // namespace armnn diff --git a/src/armnn/Optimizer.cpp b/src/armnn/Optimizer.cpp index 5e50c01c09..4d0aae8491 100644 --- a/src/armnn/Optimizer.cpp +++ b/src/armnn/Optimizer.cpp @@ -32,7 +32,9 @@ void Optimizer::Pass(Graph& graph, const Optimizations& optimizations) if ((*it)->IsOutputUnconnected()) { - it = graph.EraseLayer(it); + auto next = std::next(graph.GetPosInGraph(**it)); + graph.EraseLayer(it); + it = next; graphNeedsSorting = true; } |