diff options
author | Matteo Martincigh <matteo.martincigh@arm.com> | 2019-01-24 14:06:23 +0000 |
---|---|---|
committer | Matteo Martincigh <matteo.martincigh@arm.com> | 2019-01-30 14:03:28 +0000 |
commit | adddddb6cbcb777d92a8c464c9ad0cb9aecc76a3 (patch) | |
tree | b15de32bf9f8612f66e1ae23d2f8009e80e7d0e6 /src/armnn/Graph.hpp | |
parent | d089b74bebbcc8518fb0f4eacb7e6569ae170199 (diff) | |
download | armnn-adddddb6cbcb777d92a8c464c9ad0cb9aecc76a3.tar.gz |
IVGCVSW-2458 Refactor the Optimize function (Network.cpp) so that
subgraphs are optimized by the backends
* Added a new method OptimizeSubGraph to the backend interface
* Refactored the Optimize function so that the backend-specific
optimization is performed by the backend itself (through the new
OptimizeSubGraph interface method)
* Added a new ApplyBackendOptimizations function to apply the new
changes
* Added some new convenient constructors to the SubGraph class
* Added AddLayer method and a pointer to the parent graph to the
SubGraph class
* Updated the sub-graph unit tests to match the changes
* Added SelectSubGraphs and ReplaceSubGraphConnections overloads
that work with sub-graphs
* Removed unused code and minor refactoring where necessary
Change-Id: I46181794c6a9e3b10558944f804e06a8f693a6d0
Diffstat (limited to 'src/armnn/Graph.hpp')
-rw-r--r-- | src/armnn/Graph.hpp | 33 |
1 files changed, 18 insertions, 15 deletions
diff --git a/src/armnn/Graph.hpp b/src/armnn/Graph.hpp index 8f93f56b4a..8046977411 100644 --- a/src/armnn/Graph.hpp +++ b/src/armnn/Graph.hpp @@ -6,7 +6,6 @@ #include "LayersFwd.hpp" #include "IGraphObservable.hpp" -#include "SubGraph.hpp" #include <armnn/Types.hpp> #include <armnn/TensorFwd.hpp> @@ -25,21 +24,23 @@ namespace armnn { +class SubGraph; + class Graph { public: - template <typename CVLayerT> - static CVLayerT* PtrCast(Layer* const layer) + template <typename LayerType> + static LayerType* PtrCast(Layer* const layer) { - return boost::polymorphic_downcast<CVLayerT*>(layer); + return boost::polymorphic_downcast<LayerType*>(layer); } - using LayersList = std::list<Layer*>; - using Iterator = LayersList::const_iterator; // Const so pointers in the list can't be modified externally. - using ConstIterator = boost::transform_iterator<decltype(&PtrCast<const Layer>), Iterator>; + using LayerList = std::list<Layer*>; + using Iterator = LayerList::const_iterator; // Const so pointers in the list can't be modified externally. using IteratorDifference = Iterator::difference_type; - using ConstIteratorInputs = boost::transform_iterator<decltype(&PtrCast<const InputLayer>), Iterator>; + using ConstIterator = boost::transform_iterator<decltype(&PtrCast<const Layer>), Iterator>; + using ConstIteratorInputs = boost::transform_iterator<decltype(&PtrCast<const InputLayer>), Iterator>; using ConstIteratorOutputs = boost::transform_iterator<decltype(&PtrCast<const OutputLayer>), Iterator>; /// Wrapper class returned by Graph::GetInputLayers() @@ -49,13 +50,13 @@ public: ConstIteratorInputs begin() const { - return { m_Graph.m_Layers.begin(), &PtrCast<const InputLayer> }; + return { m_Graph.m_Layers.begin(), &(PtrCast<const InputLayer>) }; } ConstIteratorInputs end() const { return { std::next(m_Graph.m_Layers.begin(), static_cast<IteratorDifference>(m_Graph.GetNumInputs())), - &PtrCast<const InputLayer> }; + &(PtrCast<const InputLayer>) }; } const Graph& m_Graph; @@ -69,12 +70,12 @@ public: ConstIteratorOutputs begin() const { return { std::prev(m_Graph.m_Layers.end(), static_cast<IteratorDifference>(m_Graph.GetNumOutputs())), - &PtrCast<const OutputLayer> }; + &(PtrCast<const OutputLayer>) }; } ConstIteratorOutputs end() const { - return { m_Graph.m_Layers.end(), &PtrCast<const OutputLayer> }; + return { m_Graph.m_Layers.end(), &(PtrCast<const OutputLayer>) }; } const Graph& m_Graph; @@ -127,9 +128,9 @@ public: Iterator end() { return m_Layers.end(); } /// Returns const iterator pointing to the beginning of the list. Lowercase for range-based for loops. - ConstIterator begin() const { return {m_Layers.begin(), &PtrCast<const Layer>}; } + ConstIterator begin() const { return {m_Layers.begin(), &(PtrCast<const Layer>)}; } /// Returns const iterator pointing to the end of the list. Lowercase for range-based for loops. - ConstIterator end() const { return {m_Layers.end(), &PtrCast<const Layer>}; } + ConstIterator end() const { return {m_Layers.end(), &(PtrCast<const Layer>)}; } /// Returns const iterator pointing to the beginning of the list. Lowercase for range-based for loops. ConstIterator cbegin() const { return begin(); } @@ -161,6 +162,7 @@ public: void AddCopyLayers(); void SubstituteSubGraph(std::unique_ptr<SubGraph> subGraph, IConnectableLayer* substituteLayer); + void SubstituteSubGraph(std::unique_ptr<SubGraph> subGraph, const SubGraph& substituteSubGraph); void InferTensorInfos(); @@ -214,10 +216,11 @@ private: std::unordered_map<const Layer*, Iterator> m_PosInGraphMap; void ReplaceSubGraphConnections(const SubGraph& subGraph, IConnectableLayer* substituteLayer); + void ReplaceSubGraphConnections(const SubGraph& subGraph, const SubGraph& substituteSubGraph); void EraseSubGraphLayers(const SubGraph &subGraph); /// Mutable to allow sorting on const object. - mutable LayersList m_Layers; + mutable LayerList m_Layers; mutable bool m_LayersInOrder; std::map<const GraphEvent, std::list<IGraphObservable*>> m_Views; |