aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/Graph.hpp
diff options
context:
space:
mode:
authorMatteo Martincigh <matteo.martincigh@arm.com>2019-01-24 14:06:23 +0000
committerMatteo Martincigh <matteo.martincigh@arm.com>2019-01-30 14:03:28 +0000
commitadddddb6cbcb777d92a8c464c9ad0cb9aecc76a3 (patch)
treeb15de32bf9f8612f66e1ae23d2f8009e80e7d0e6 /src/armnn/Graph.hpp
parentd089b74bebbcc8518fb0f4eacb7e6569ae170199 (diff)
downloadarmnn-adddddb6cbcb777d92a8c464c9ad0cb9aecc76a3.tar.gz
IVGCVSW-2458 Refactor the Optimize function (Network.cpp) so that
subgraphs are optimized by the backends * Added a new method OptimizeSubGraph to the backend interface * Refactored the Optimize function so that the backend-specific optimization is performed by the backend itself (through the new OptimizeSubGraph interface method) * Added a new ApplyBackendOptimizations function to apply the new changes * Added some new convenient constructors to the SubGraph class * Added AddLayer method and a pointer to the parent graph to the SubGraph class * Updated the sub-graph unit tests to match the changes * Added SelectSubGraphs and ReplaceSubGraphConnections overloads that work with sub-graphs * Removed unused code and minor refactoring where necessary Change-Id: I46181794c6a9e3b10558944f804e06a8f693a6d0
Diffstat (limited to 'src/armnn/Graph.hpp')
-rw-r--r--src/armnn/Graph.hpp33
1 files changed, 18 insertions, 15 deletions
diff --git a/src/armnn/Graph.hpp b/src/armnn/Graph.hpp
index 8f93f56b4a..8046977411 100644
--- a/src/armnn/Graph.hpp
+++ b/src/armnn/Graph.hpp
@@ -6,7 +6,6 @@
#include "LayersFwd.hpp"
#include "IGraphObservable.hpp"
-#include "SubGraph.hpp"
#include <armnn/Types.hpp>
#include <armnn/TensorFwd.hpp>
@@ -25,21 +24,23 @@
namespace armnn
{
+class SubGraph;
+
class Graph
{
public:
- template <typename CVLayerT>
- static CVLayerT* PtrCast(Layer* const layer)
+ template <typename LayerType>
+ static LayerType* PtrCast(Layer* const layer)
{
- return boost::polymorphic_downcast<CVLayerT*>(layer);
+ return boost::polymorphic_downcast<LayerType*>(layer);
}
- using LayersList = std::list<Layer*>;
- using Iterator = LayersList::const_iterator; // Const so pointers in the list can't be modified externally.
- using ConstIterator = boost::transform_iterator<decltype(&PtrCast<const Layer>), Iterator>;
+ using LayerList = std::list<Layer*>;
+ using Iterator = LayerList::const_iterator; // Const so pointers in the list can't be modified externally.
using IteratorDifference = Iterator::difference_type;
- using ConstIteratorInputs = boost::transform_iterator<decltype(&PtrCast<const InputLayer>), Iterator>;
+ using ConstIterator = boost::transform_iterator<decltype(&PtrCast<const Layer>), Iterator>;
+ using ConstIteratorInputs = boost::transform_iterator<decltype(&PtrCast<const InputLayer>), Iterator>;
using ConstIteratorOutputs = boost::transform_iterator<decltype(&PtrCast<const OutputLayer>), Iterator>;
/// Wrapper class returned by Graph::GetInputLayers()
@@ -49,13 +50,13 @@ public:
ConstIteratorInputs begin() const
{
- return { m_Graph.m_Layers.begin(), &PtrCast<const InputLayer> };
+ return { m_Graph.m_Layers.begin(), &(PtrCast<const InputLayer>) };
}
ConstIteratorInputs end() const
{
return { std::next(m_Graph.m_Layers.begin(), static_cast<IteratorDifference>(m_Graph.GetNumInputs())),
- &PtrCast<const InputLayer> };
+ &(PtrCast<const InputLayer>) };
}
const Graph& m_Graph;
@@ -69,12 +70,12 @@ public:
ConstIteratorOutputs begin() const
{
return { std::prev(m_Graph.m_Layers.end(), static_cast<IteratorDifference>(m_Graph.GetNumOutputs())),
- &PtrCast<const OutputLayer> };
+ &(PtrCast<const OutputLayer>) };
}
ConstIteratorOutputs end() const
{
- return { m_Graph.m_Layers.end(), &PtrCast<const OutputLayer> };
+ return { m_Graph.m_Layers.end(), &(PtrCast<const OutputLayer>) };
}
const Graph& m_Graph;
@@ -127,9 +128,9 @@ public:
Iterator end() { return m_Layers.end(); }
/// Returns const iterator pointing to the beginning of the list. Lowercase for range-based for loops.
- ConstIterator begin() const { return {m_Layers.begin(), &PtrCast<const Layer>}; }
+ ConstIterator begin() const { return {m_Layers.begin(), &(PtrCast<const Layer>)}; }
/// Returns const iterator pointing to the end of the list. Lowercase for range-based for loops.
- ConstIterator end() const { return {m_Layers.end(), &PtrCast<const Layer>}; }
+ ConstIterator end() const { return {m_Layers.end(), &(PtrCast<const Layer>)}; }
/// Returns const iterator pointing to the beginning of the list. Lowercase for range-based for loops.
ConstIterator cbegin() const { return begin(); }
@@ -161,6 +162,7 @@ public:
void AddCopyLayers();
void SubstituteSubGraph(std::unique_ptr<SubGraph> subGraph, IConnectableLayer* substituteLayer);
+ void SubstituteSubGraph(std::unique_ptr<SubGraph> subGraph, const SubGraph& substituteSubGraph);
void InferTensorInfos();
@@ -214,10 +216,11 @@ private:
std::unordered_map<const Layer*, Iterator> m_PosInGraphMap;
void ReplaceSubGraphConnections(const SubGraph& subGraph, IConnectableLayer* substituteLayer);
+ void ReplaceSubGraphConnections(const SubGraph& subGraph, const SubGraph& substituteSubGraph);
void EraseSubGraphLayers(const SubGraph &subGraph);
/// Mutable to allow sorting on const object.
- mutable LayersList m_Layers;
+ mutable LayerList m_Layers;
mutable bool m_LayersInOrder;
std::map<const GraphEvent, std::list<IGraphObservable*>> m_Views;