aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/SubGraph.hpp
diff options
context:
space:
mode:
authorMatteo Martincigh <matteo.martincigh@arm.com>2019-01-24 14:06:23 +0000
committerMatteo Martincigh <matteo.martincigh@arm.com>2019-01-30 14:03:28 +0000
commitadddddb6cbcb777d92a8c464c9ad0cb9aecc76a3 (patch)
treeb15de32bf9f8612f66e1ae23d2f8009e80e7d0e6 /src/armnn/SubGraph.hpp
parentd089b74bebbcc8518fb0f4eacb7e6569ae170199 (diff)
downloadarmnn-adddddb6cbcb777d92a8c464c9ad0cb9aecc76a3.tar.gz
IVGCVSW-2458 Refactor the Optimize function (Network.cpp) so that
subgraphs are optimized by the backends * Added a new method OptimizeSubGraph to the backend interface * Refactored the Optimize function so that the backend-specific optimization is performed by the backend itself (through the new OptimizeSubGraph interface method) * Added a new ApplyBackendOptimizations function to apply the new changes * Added some new convenient constructors to the SubGraph class * Added AddLayer method and a pointer to the parent graph to the SubGraph class * Updated the sub-graph unit tests to match the changes * Added SelectSubGraphs and ReplaceSubGraphConnections overloads that work with sub-graphs * Removed unused code and minor refactoring where necessary Change-Id: I46181794c6a9e3b10558944f804e06a8f693a6d0
Diffstat (limited to 'src/armnn/SubGraph.hpp')
-rw-r--r--src/armnn/SubGraph.hpp72
1 files changed, 57 insertions, 15 deletions
diff --git a/src/armnn/SubGraph.hpp b/src/armnn/SubGraph.hpp
index d22377daff..81166f1285 100644
--- a/src/armnn/SubGraph.hpp
+++ b/src/armnn/SubGraph.hpp
@@ -6,6 +6,7 @@
#pragma once
#include "Layer.hpp"
+#include "Graph.hpp"
#include <vector>
#include <list>
@@ -22,18 +23,46 @@ namespace armnn
class SubGraph final
{
public:
- using InputSlots = std::vector<InputSlot *>;
- using OutputSlots = std::vector<OutputSlot *>;
+ using SubGraphPtr = std::unique_ptr<SubGraph>;
+ using InputSlots = std::vector<InputSlot*>;
+ using OutputSlots = std::vector<OutputSlot*>;
using Layers = std::list<Layer*>;
+ using Iterator = Layers::iterator;
+ using ConstIterator = Layers::const_iterator;
- SubGraph();
- SubGraph(InputSlots && inputs,
- OutputSlots && outputs,
- Layers && layers);
+ /// Empty subgraphs are not allowed, they must at least have a parent graph.
+ SubGraph() = delete;
- const InputSlots & GetInputSlots() const;
- const OutputSlots & GetOutputSlots() const;
- const Layers & GetLayers() const;
+ /// Constructs a sub-graph from the entire given graph.
+ SubGraph(Graph& graph);
+
+ /// Constructs a sub-graph with the given arguments and binds it to the specified parent graph.
+ SubGraph(Graph* parentGraph, InputSlots&& inputs, OutputSlots&& outputs, Layers&& layers);
+
+ /// Constructs a sub-graph with the given arguments and uses the specified sub-graph to get a reference
+ /// to the parent graph.
+ SubGraph(const SubGraph& referenceSubGraph, InputSlots&& inputs, OutputSlots&& outputs, Layers&& layers);
+
+ /// Copy-constructor.
+ SubGraph(const SubGraph& subGraph);
+
+ /// Move-constructor.
+ SubGraph(SubGraph&& subGraph);
+
+ /// Constructs a sub-graph with only the given layer and uses the specified sub-graph to get a reference
+ /// to the parent graph.
+ SubGraph(const SubGraph& referenceSubGraph, IConnectableLayer* layer);
+
+ /// Updates this sub-graph with the contents of the whole given graph.
+ void Update(Graph& graph);
+
+ /// Adds a new layer, of type LayerType, to the graph this sub-graph is a view of.
+ template <typename LayerT, typename... Args>
+ LayerT* AddLayer(Args&&... args) const;
+
+ const InputSlots& GetInputSlots() const;
+ const OutputSlots& GetOutputSlots() const;
+ const Layers& GetLayers() const;
const InputSlot* GetInputSlot(unsigned int index) const;
InputSlot* GetInputSlot(unsigned int index);
@@ -44,19 +73,32 @@ public:
unsigned int GetNumInputSlots() const;
unsigned int GetNumOutputSlots() const;
- Layers::iterator begin();
- Layers::iterator end();
+ Iterator begin();
+ Iterator end();
- Layers::const_iterator begin() const;
- Layers::const_iterator end() const;
+ ConstIterator begin() const;
+ ConstIterator end() const;
- Layers::const_iterator cbegin() const;
- Layers::const_iterator cend() const;
+ ConstIterator cbegin() const;
+ ConstIterator cend() const;
private:
+ void CheckSubGraph();
+
InputSlots m_InputSlots;
OutputSlots m_OutputSlots;
Layers m_Layers;
+
+ /// Pointer to the graph this sub-graph is a view of.
+ Graph* m_ParentGraph;
};
+template <typename LayerT, typename... Args>
+LayerT* SubGraph::AddLayer(Args&&... args) const
+{
+ BOOST_ASSERT(m_ParentGraph);
+
+ return m_ParentGraph->AddLayer<LayerT>(args...);
+}
+
} // namespace armnn