aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/Graph.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/Graph.hpp')
-rw-r--r--src/armnn/Graph.hpp33
1 files changed, 18 insertions, 15 deletions
diff --git a/src/armnn/Graph.hpp b/src/armnn/Graph.hpp
index 8f93f56b4a..8046977411 100644
--- a/src/armnn/Graph.hpp
+++ b/src/armnn/Graph.hpp
@@ -6,7 +6,6 @@
#include "LayersFwd.hpp"
#include "IGraphObservable.hpp"
-#include "SubGraph.hpp"
#include <armnn/Types.hpp>
#include <armnn/TensorFwd.hpp>
@@ -25,21 +24,23 @@
namespace armnn
{
+class SubGraph;
+
class Graph
{
public:
- template <typename CVLayerT>
- static CVLayerT* PtrCast(Layer* const layer)
+ template <typename LayerType>
+ static LayerType* PtrCast(Layer* const layer)
{
- return boost::polymorphic_downcast<CVLayerT*>(layer);
+ return boost::polymorphic_downcast<LayerType*>(layer);
}
- using LayersList = std::list<Layer*>;
- using Iterator = LayersList::const_iterator; // Const so pointers in the list can't be modified externally.
- using ConstIterator = boost::transform_iterator<decltype(&PtrCast<const Layer>), Iterator>;
+ using LayerList = std::list<Layer*>;
+ using Iterator = LayerList::const_iterator; // Const so pointers in the list can't be modified externally.
using IteratorDifference = Iterator::difference_type;
- using ConstIteratorInputs = boost::transform_iterator<decltype(&PtrCast<const InputLayer>), Iterator>;
+ using ConstIterator = boost::transform_iterator<decltype(&PtrCast<const Layer>), Iterator>;
+ using ConstIteratorInputs = boost::transform_iterator<decltype(&PtrCast<const InputLayer>), Iterator>;
using ConstIteratorOutputs = boost::transform_iterator<decltype(&PtrCast<const OutputLayer>), Iterator>;
/// Wrapper class returned by Graph::GetInputLayers()
@@ -49,13 +50,13 @@ public:
ConstIteratorInputs begin() const
{
- return { m_Graph.m_Layers.begin(), &PtrCast<const InputLayer> };
+ return { m_Graph.m_Layers.begin(), &(PtrCast<const InputLayer>) };
}
ConstIteratorInputs end() const
{
return { std::next(m_Graph.m_Layers.begin(), static_cast<IteratorDifference>(m_Graph.GetNumInputs())),
- &PtrCast<const InputLayer> };
+ &(PtrCast<const InputLayer>) };
}
const Graph& m_Graph;
@@ -69,12 +70,12 @@ public:
ConstIteratorOutputs begin() const
{
return { std::prev(m_Graph.m_Layers.end(), static_cast<IteratorDifference>(m_Graph.GetNumOutputs())),
- &PtrCast<const OutputLayer> };
+ &(PtrCast<const OutputLayer>) };
}
ConstIteratorOutputs end() const
{
- return { m_Graph.m_Layers.end(), &PtrCast<const OutputLayer> };
+ return { m_Graph.m_Layers.end(), &(PtrCast<const OutputLayer>) };
}
const Graph& m_Graph;
@@ -127,9 +128,9 @@ public:
Iterator end() { return m_Layers.end(); }
/// Returns const iterator pointing to the beginning of the list. Lowercase for range-based for loops.
- ConstIterator begin() const { return {m_Layers.begin(), &PtrCast<const Layer>}; }
+ ConstIterator begin() const { return {m_Layers.begin(), &(PtrCast<const Layer>)}; }
/// Returns const iterator pointing to the end of the list. Lowercase for range-based for loops.
- ConstIterator end() const { return {m_Layers.end(), &PtrCast<const Layer>}; }
+ ConstIterator end() const { return {m_Layers.end(), &(PtrCast<const Layer>)}; }
/// Returns const iterator pointing to the beginning of the list. Lowercase for range-based for loops.
ConstIterator cbegin() const { return begin(); }
@@ -161,6 +162,7 @@ public:
void AddCopyLayers();
void SubstituteSubGraph(std::unique_ptr<SubGraph> subGraph, IConnectableLayer* substituteLayer);
+ void SubstituteSubGraph(std::unique_ptr<SubGraph> subGraph, const SubGraph& substituteSubGraph);
void InferTensorInfos();
@@ -214,10 +216,11 @@ private:
std::unordered_map<const Layer*, Iterator> m_PosInGraphMap;
void ReplaceSubGraphConnections(const SubGraph& subGraph, IConnectableLayer* substituteLayer);
+ void ReplaceSubGraphConnections(const SubGraph& subGraph, const SubGraph& substituteSubGraph);
void EraseSubGraphLayers(const SubGraph &subGraph);
/// Mutable to allow sorting on const object.
- mutable LayersList m_Layers;
+ mutable LayerList m_Layers;
mutable bool m_LayersInOrder;
std::map<const GraphEvent, std::list<IGraphObservable*>> m_Views;