diff options
Diffstat (limited to 'src/armnn/SubGraph.hpp')
-rw-r--r-- | src/armnn/SubGraph.hpp | 72 |
1 files changed, 57 insertions, 15 deletions
diff --git a/src/armnn/SubGraph.hpp b/src/armnn/SubGraph.hpp index d22377daff..81166f1285 100644 --- a/src/armnn/SubGraph.hpp +++ b/src/armnn/SubGraph.hpp @@ -6,6 +6,7 @@ #pragma once #include "Layer.hpp" +#include "Graph.hpp" #include <vector> #include <list> @@ -22,18 +23,46 @@ namespace armnn class SubGraph final { public: - using InputSlots = std::vector<InputSlot *>; - using OutputSlots = std::vector<OutputSlot *>; + using SubGraphPtr = std::unique_ptr<SubGraph>; + using InputSlots = std::vector<InputSlot*>; + using OutputSlots = std::vector<OutputSlot*>; using Layers = std::list<Layer*>; + using Iterator = Layers::iterator; + using ConstIterator = Layers::const_iterator; - SubGraph(); - SubGraph(InputSlots && inputs, - OutputSlots && outputs, - Layers && layers); + /// Empty subgraphs are not allowed, they must at least have a parent graph. + SubGraph() = delete; - const InputSlots & GetInputSlots() const; - const OutputSlots & GetOutputSlots() const; - const Layers & GetLayers() const; + /// Constructs a sub-graph from the entire given graph. + SubGraph(Graph& graph); + + /// Constructs a sub-graph with the given arguments and binds it to the specified parent graph. + SubGraph(Graph* parentGraph, InputSlots&& inputs, OutputSlots&& outputs, Layers&& layers); + + /// Constructs a sub-graph with the given arguments and uses the specified sub-graph to get a reference + /// to the parent graph. + SubGraph(const SubGraph& referenceSubGraph, InputSlots&& inputs, OutputSlots&& outputs, Layers&& layers); + + /// Copy-constructor. + SubGraph(const SubGraph& subGraph); + + /// Move-constructor. + SubGraph(SubGraph&& subGraph); + + /// Constructs a sub-graph with only the given layer and uses the specified sub-graph to get a reference + /// to the parent graph. + SubGraph(const SubGraph& referenceSubGraph, IConnectableLayer* layer); + + /// Updates this sub-graph with the contents of the whole given graph. + void Update(Graph& graph); + + /// Adds a new layer, of type LayerType, to the graph this sub-graph is a view of. + template <typename LayerT, typename... Args> + LayerT* AddLayer(Args&&... args) const; + + const InputSlots& GetInputSlots() const; + const OutputSlots& GetOutputSlots() const; + const Layers& GetLayers() const; const InputSlot* GetInputSlot(unsigned int index) const; InputSlot* GetInputSlot(unsigned int index); @@ -44,19 +73,32 @@ public: unsigned int GetNumInputSlots() const; unsigned int GetNumOutputSlots() const; - Layers::iterator begin(); - Layers::iterator end(); + Iterator begin(); + Iterator end(); - Layers::const_iterator begin() const; - Layers::const_iterator end() const; + ConstIterator begin() const; + ConstIterator end() const; - Layers::const_iterator cbegin() const; - Layers::const_iterator cend() const; + ConstIterator cbegin() const; + ConstIterator cend() const; private: + void CheckSubGraph(); + InputSlots m_InputSlots; OutputSlots m_OutputSlots; Layers m_Layers; + + /// Pointer to the graph this sub-graph is a view of. + Graph* m_ParentGraph; }; +template <typename LayerT, typename... Args> +LayerT* SubGraph::AddLayer(Args&&... args) const +{ + BOOST_ASSERT(m_ParentGraph); + + return m_ParentGraph->AddLayer<LayerT>(args...); +} + } // namespace armnn |