aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/Graph.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/Graph.hpp')
-rw-r--r--src/armnn/Graph.hpp87
1 files changed, 70 insertions, 17 deletions
diff --git a/src/armnn/Graph.hpp b/src/armnn/Graph.hpp
index 06b6fd32ae..fd81e51b7b 100644
--- a/src/armnn/Graph.hpp
+++ b/src/armnn/Graph.hpp
@@ -5,6 +5,7 @@
#pragma once
#include "LayersFwd.hpp"
+#include "IGraphObservable.hpp"
#include <armnn/Types.hpp>
#include <armnn/TensorFwd.hpp>
@@ -12,6 +13,7 @@
#include <armnn/Exceptions.hpp>
#include <list>
+#include <map>
#include <unordered_map>
#include <unordered_set>
#include <vector>
@@ -21,6 +23,7 @@
namespace armnn
{
+
class Graph
{
public:
@@ -31,7 +34,7 @@ public:
}
using LayersList = std::list<Layer*>;
- using Iterator = LayersList::const_iterator; // const so pointers in the list can't be modified externally
+ using Iterator = LayersList::const_iterator; // Const so pointers in the list can't be modified externally.
using ConstIterator = boost::transform_iterator<decltype(&PtrCast<const Layer>), Iterator>;
using IteratorDifference = Iterator::difference_type;
@@ -94,7 +97,7 @@ public:
Status SerializeToDot(std::ostream& stream);
- /// Adds a new layer of type LaterType to the graph constructed with the arguments passed.
+ /// Adds a new layer, of type LayerType, to the graph constructed with the arguments passed.
template <typename LayerT, typename... Args>
LayerT* AddLayer(Args&&... args);
@@ -103,6 +106,10 @@ public:
template <typename LayerT, typename... Args>
LayerT* InsertNewLayer(InputSlot& insertBefore, Args&&... args);
+ /// Inserts a new layer between insertAfter and the input slot(s) currently connected to it
+ template <typename LayerT, typename... Args>
+ LayerT* InsertNewLayer(OutputSlot& insertAfter, Args&&... args);
+
/// Deletes the layer at the specified position and returns an iterator pointing
/// to the next element after the one being deleted.
Iterator EraseLayer(Iterator pos);
@@ -113,22 +120,22 @@ public:
template <typename LayerT>
Iterator EraseLayer(LayerT*& layer);
- /// Return iterator pointing to begin of list. Lowercase for range-based for loops.
+ /// Returns iterator pointing to the beginning of the list. Lowercase for range-based for loops.
Iterator begin() { return m_Layers.begin(); }
- /// Return iterator pointing to end of list. Lowercase for range-based for loops.
+ /// Returns iterator pointing to the end of the list. Lowercase for range-based for loops.
Iterator end() { return m_Layers.end(); }
- /// Return const iterator pointing to begin of list. Lowercase for range-based for loops.
+ /// Returns const iterator pointing to the beginning of the list. Lowercase for range-based for loops.
ConstIterator begin() const { return {m_Layers.begin(), &PtrCast<const Layer>}; }
- /// Return const iterator pointing to end of list. Lowercase for range-based for loops.
+ /// Returns const iterator pointing to the end of the list. Lowercase for range-based for loops.
ConstIterator end() const { return {m_Layers.end(), &PtrCast<const Layer>}; }
- /// Return const iterator pointing to begin of list. Lowercase for range-based for loops.
+ /// Returns const iterator pointing to the beginning of the list. Lowercase for range-based for loops.
ConstIterator cbegin() const { return begin(); }
- /// Return const iterator pointing to end of list. Lowercase for range-based for loops.
+ /// Returns const iterator pointing to the end of the list. Lowercase for range-based for loops.
ConstIterator cend() const { return end(); }
- /// Sort layers in topological order and return this.
+ /// Sorts layers in topological order and return this.
Graph& TopologicalSort() { const_cast<const Graph*>(this)->TopologicalSort(); return *this; }
const Graph& TopologicalSort() const;
@@ -136,16 +143,16 @@ public:
size_t GetNumOutputs() const { return m_OutputIds.size(); }
/// Returns a wrapper object with begin(), end() methods to iterate over the input layers
- /// in a range-based for loop
+ /// in a range-based for loop.
InputLayersAccessor GetInputLayers() const { return InputLayersAccessor(*this); }
/// Returns a wrapper object with begin(), end() methods to iterate over the output layers
- /// in a range-based for loop
+ /// in a range-based for loop.
OutputLayersAccessor GetOutputLayers() const { return OutputLayersAccessor(*this); }
size_t GetNumLayers() const { return m_Layers.size(); }
- /// Allocate memory for all tensors under output tensor handers of each layer
+ /// Allocates memory for all tensors under output tensor handers of each layer.
Status AllocateDynamicBuffers();
/// Modifies the graph in-place, removing edges connecting layers using different compute devices,
@@ -154,6 +161,14 @@ public:
void InferTensorInfos();
+ void AttachObservable(IGraphObservable* const observable, GraphEvent notifyOnEvent) {
+ m_Views[notifyOnEvent].emplace_back(observable);
+ }
+
+ void DetachObservable(IGraphObservable* const observable, GraphEvent notifyOnEvent) {
+ m_Views[notifyOnEvent].remove(observable);
+ }
+
private:
template <typename LayerT>
class LayerInGraphBase;
@@ -179,9 +194,18 @@ private:
return it;
}
- /// Get the position of a layer in the graph.
+ /// Gets the position of a layer in the graph.
Iterator GetPosInGraph(Layer& layer);
+ void NotifyObservables(GraphEvent event, Layer* graphState)
+ {
+ // Iterate over all observables observing this event
+ for (auto& observable : m_Views[event])
+ {
+ observable->Update(graphState);
+ }
+ }
+
std::unordered_set<LayerBindingId> m_InputIds;
std::unordered_set<LayerBindingId> m_OutputIds;
std::unordered_map<const Layer*, Iterator> m_PosInGraphMap;
@@ -189,9 +213,11 @@ private:
/// Mutable to allow sorting on const object.
mutable LayersList m_Layers;
mutable bool m_LayersInOrder;
+
+ std::map<const GraphEvent, std::list<IGraphObservable*>> m_Views;
};
-/// Common base class for layers in the graph
+/// Common base class for layers in the graph.
template <typename LayerT>
class Graph::LayerInGraphBase : public LayerT
{
@@ -212,7 +238,7 @@ protected:
Graph& m_Graph;
};
-/// Input/Output layers specialize this template
+/// Input/Output layers specialize this template.
template <typename LayerT>
class Graph::LayerInGraph final : public LayerInGraphBase<LayerT>
{
@@ -305,24 +331,51 @@ inline LayerT* Graph::AddLayer(Args&&... args)
{
m_LayersInOrder = m_LayersInOrder &&
((LayerEnumOf<LayerT>() == LayerType::Input) || (LayerEnumOf<LayerT>() == LayerType::Output));
- return new LayerInGraph<LayerT>(*this, std::forward<Args>(args)...);
+ LayerT* const layer = new LayerInGraph<LayerT>(*this, std::forward<Args>(args)...);
+
+ NotifyObservables(GraphEvent::LayerAdded, layer);
+
+ return layer;
}
template <typename LayerT, typename... Args>
inline LayerT* Graph::InsertNewLayer(InputSlot& insertBefore, Args&&... args)
{
- // Insert after the parent if any, or before the child otherwise, so topological order is kept.
+ // Insert after the parent if any, or before the child otherwise, so the topological order is kept.
OutputSlot* parentOut = insertBefore.GetConnectedOutputSlot();
const Iterator pos = (parentOut != nullptr)
? std::next(GetPosInGraph(parentOut->GetOwningLayer()))
: GetPosInGraph(insertBefore.GetOwningLayer());
LayerT* const layer = new LayerInGraph<LayerT>(*this, pos, std::forward<Args>(args)...);
insertBefore.Insert(*layer);
+
+ NotifyObservables(GraphEvent::LayerAdded, layer);
+
+ return layer;
+}
+
+template <typename LayerT, typename... Args>
+inline LayerT* Graph::InsertNewLayer(OutputSlot& insertAfter, Args&&... args)
+{
+ Layer& owningLayer = insertAfter.GetOwningLayer();
+
+ const Iterator pos = std::next(GetPosInGraph(owningLayer));
+ LayerT* const layer = new LayerInGraph<LayerT>(*this, pos, std::forward<Args>(args)...);
+
+ BOOST_ASSERT(layer->GetNumInputSlots() == 1);
+
+ insertAfter.MoveAllConnections(layer->GetOutputSlot());
+ insertAfter.Connect(layer->GetInputSlot(0));
+
+ NotifyObservables(GraphEvent::LayerAdded, layer);
+
return layer;
}
inline Graph::Iterator Graph::EraseLayer(Iterator pos)
{
+ NotifyObservables(GraphEvent::LayerErased, *pos);
+
delete *pos;
return m_Layers.erase(pos);
}