diff options
Diffstat (limited to 'src/armnn/Graph.hpp')
-rw-r--r-- | src/armnn/Graph.hpp | 87 |
1 files changed, 70 insertions, 17 deletions
diff --git a/src/armnn/Graph.hpp b/src/armnn/Graph.hpp index 06b6fd32ae..fd81e51b7b 100644 --- a/src/armnn/Graph.hpp +++ b/src/armnn/Graph.hpp @@ -5,6 +5,7 @@ #pragma once #include "LayersFwd.hpp" +#include "IGraphObservable.hpp" #include <armnn/Types.hpp> #include <armnn/TensorFwd.hpp> @@ -12,6 +13,7 @@ #include <armnn/Exceptions.hpp> #include <list> +#include <map> #include <unordered_map> #include <unordered_set> #include <vector> @@ -21,6 +23,7 @@ namespace armnn { + class Graph { public: @@ -31,7 +34,7 @@ public: } using LayersList = std::list<Layer*>; - using Iterator = LayersList::const_iterator; // const so pointers in the list can't be modified externally + using Iterator = LayersList::const_iterator; // Const so pointers in the list can't be modified externally. using ConstIterator = boost::transform_iterator<decltype(&PtrCast<const Layer>), Iterator>; using IteratorDifference = Iterator::difference_type; @@ -94,7 +97,7 @@ public: Status SerializeToDot(std::ostream& stream); - /// Adds a new layer of type LaterType to the graph constructed with the arguments passed. + /// Adds a new layer, of type LayerType, to the graph constructed with the arguments passed. template <typename LayerT, typename... Args> LayerT* AddLayer(Args&&... args); @@ -103,6 +106,10 @@ public: template <typename LayerT, typename... Args> LayerT* InsertNewLayer(InputSlot& insertBefore, Args&&... args); + /// Inserts a new layer between insertAfter and the input slot(s) currently connected to it + template <typename LayerT, typename... Args> + LayerT* InsertNewLayer(OutputSlot& insertAfter, Args&&... args); + /// Deletes the layer at the specified position and returns an iterator pointing /// to the next element after the one being deleted. Iterator EraseLayer(Iterator pos); @@ -113,22 +120,22 @@ public: template <typename LayerT> Iterator EraseLayer(LayerT*& layer); - /// Return iterator pointing to begin of list. Lowercase for range-based for loops. + /// Returns iterator pointing to the beginning of the list. Lowercase for range-based for loops. Iterator begin() { return m_Layers.begin(); } - /// Return iterator pointing to end of list. Lowercase for range-based for loops. + /// Returns iterator pointing to the end of the list. Lowercase for range-based for loops. Iterator end() { return m_Layers.end(); } - /// Return const iterator pointing to begin of list. Lowercase for range-based for loops. + /// Returns const iterator pointing to the beginning of the list. Lowercase for range-based for loops. ConstIterator begin() const { return {m_Layers.begin(), &PtrCast<const Layer>}; } - /// Return const iterator pointing to end of list. Lowercase for range-based for loops. + /// Returns const iterator pointing to the end of the list. Lowercase for range-based for loops. ConstIterator end() const { return {m_Layers.end(), &PtrCast<const Layer>}; } - /// Return const iterator pointing to begin of list. Lowercase for range-based for loops. + /// Returns const iterator pointing to the beginning of the list. Lowercase for range-based for loops. ConstIterator cbegin() const { return begin(); } - /// Return const iterator pointing to end of list. Lowercase for range-based for loops. + /// Returns const iterator pointing to the end of the list. Lowercase for range-based for loops. ConstIterator cend() const { return end(); } - /// Sort layers in topological order and return this. + /// Sorts layers in topological order and return this. Graph& TopologicalSort() { const_cast<const Graph*>(this)->TopologicalSort(); return *this; } const Graph& TopologicalSort() const; @@ -136,16 +143,16 @@ public: size_t GetNumOutputs() const { return m_OutputIds.size(); } /// Returns a wrapper object with begin(), end() methods to iterate over the input layers - /// in a range-based for loop + /// in a range-based for loop. InputLayersAccessor GetInputLayers() const { return InputLayersAccessor(*this); } /// Returns a wrapper object with begin(), end() methods to iterate over the output layers - /// in a range-based for loop + /// in a range-based for loop. OutputLayersAccessor GetOutputLayers() const { return OutputLayersAccessor(*this); } size_t GetNumLayers() const { return m_Layers.size(); } - /// Allocate memory for all tensors under output tensor handers of each layer + /// Allocates memory for all tensors under output tensor handers of each layer. Status AllocateDynamicBuffers(); /// Modifies the graph in-place, removing edges connecting layers using different compute devices, @@ -154,6 +161,14 @@ public: void InferTensorInfos(); + void AttachObservable(IGraphObservable* const observable, GraphEvent notifyOnEvent) { + m_Views[notifyOnEvent].emplace_back(observable); + } + + void DetachObservable(IGraphObservable* const observable, GraphEvent notifyOnEvent) { + m_Views[notifyOnEvent].remove(observable); + } + private: template <typename LayerT> class LayerInGraphBase; @@ -179,9 +194,18 @@ private: return it; } - /// Get the position of a layer in the graph. + /// Gets the position of a layer in the graph. Iterator GetPosInGraph(Layer& layer); + void NotifyObservables(GraphEvent event, Layer* graphState) + { + // Iterate over all observables observing this event + for (auto& observable : m_Views[event]) + { + observable->Update(graphState); + } + } + std::unordered_set<LayerBindingId> m_InputIds; std::unordered_set<LayerBindingId> m_OutputIds; std::unordered_map<const Layer*, Iterator> m_PosInGraphMap; @@ -189,9 +213,11 @@ private: /// Mutable to allow sorting on const object. mutable LayersList m_Layers; mutable bool m_LayersInOrder; + + std::map<const GraphEvent, std::list<IGraphObservable*>> m_Views; }; -/// Common base class for layers in the graph +/// Common base class for layers in the graph. template <typename LayerT> class Graph::LayerInGraphBase : public LayerT { @@ -212,7 +238,7 @@ protected: Graph& m_Graph; }; -/// Input/Output layers specialize this template +/// Input/Output layers specialize this template. template <typename LayerT> class Graph::LayerInGraph final : public LayerInGraphBase<LayerT> { @@ -305,24 +331,51 @@ inline LayerT* Graph::AddLayer(Args&&... args) { m_LayersInOrder = m_LayersInOrder && ((LayerEnumOf<LayerT>() == LayerType::Input) || (LayerEnumOf<LayerT>() == LayerType::Output)); - return new LayerInGraph<LayerT>(*this, std::forward<Args>(args)...); + LayerT* const layer = new LayerInGraph<LayerT>(*this, std::forward<Args>(args)...); + + NotifyObservables(GraphEvent::LayerAdded, layer); + + return layer; } template <typename LayerT, typename... Args> inline LayerT* Graph::InsertNewLayer(InputSlot& insertBefore, Args&&... args) { - // Insert after the parent if any, or before the child otherwise, so topological order is kept. + // Insert after the parent if any, or before the child otherwise, so the topological order is kept. OutputSlot* parentOut = insertBefore.GetConnectedOutputSlot(); const Iterator pos = (parentOut != nullptr) ? std::next(GetPosInGraph(parentOut->GetOwningLayer())) : GetPosInGraph(insertBefore.GetOwningLayer()); LayerT* const layer = new LayerInGraph<LayerT>(*this, pos, std::forward<Args>(args)...); insertBefore.Insert(*layer); + + NotifyObservables(GraphEvent::LayerAdded, layer); + + return layer; +} + +template <typename LayerT, typename... Args> +inline LayerT* Graph::InsertNewLayer(OutputSlot& insertAfter, Args&&... args) +{ + Layer& owningLayer = insertAfter.GetOwningLayer(); + + const Iterator pos = std::next(GetPosInGraph(owningLayer)); + LayerT* const layer = new LayerInGraph<LayerT>(*this, pos, std::forward<Args>(args)...); + + BOOST_ASSERT(layer->GetNumInputSlots() == 1); + + insertAfter.MoveAllConnections(layer->GetOutputSlot()); + insertAfter.Connect(layer->GetInputSlot(0)); + + NotifyObservables(GraphEvent::LayerAdded, layer); + return layer; } inline Graph::Iterator Graph::EraseLayer(Iterator pos) { + NotifyObservables(GraphEvent::LayerErased, *pos); + delete *pos; return m_Layers.erase(pos); } |