diff options
Diffstat (limited to 'src/armnn/Graph.hpp')
-rw-r--r-- | src/armnn/Graph.hpp | 33 |
1 files changed, 18 insertions, 15 deletions
diff --git a/src/armnn/Graph.hpp b/src/armnn/Graph.hpp index 8f93f56b4a..8046977411 100644 --- a/src/armnn/Graph.hpp +++ b/src/armnn/Graph.hpp @@ -6,7 +6,6 @@ #include "LayersFwd.hpp" #include "IGraphObservable.hpp" -#include "SubGraph.hpp" #include <armnn/Types.hpp> #include <armnn/TensorFwd.hpp> @@ -25,21 +24,23 @@ namespace armnn { +class SubGraph; + class Graph { public: - template <typename CVLayerT> - static CVLayerT* PtrCast(Layer* const layer) + template <typename LayerType> + static LayerType* PtrCast(Layer* const layer) { - return boost::polymorphic_downcast<CVLayerT*>(layer); + return boost::polymorphic_downcast<LayerType*>(layer); } - using LayersList = std::list<Layer*>; - using Iterator = LayersList::const_iterator; // Const so pointers in the list can't be modified externally. - using ConstIterator = boost::transform_iterator<decltype(&PtrCast<const Layer>), Iterator>; + using LayerList = std::list<Layer*>; + using Iterator = LayerList::const_iterator; // Const so pointers in the list can't be modified externally. using IteratorDifference = Iterator::difference_type; - using ConstIteratorInputs = boost::transform_iterator<decltype(&PtrCast<const InputLayer>), Iterator>; + using ConstIterator = boost::transform_iterator<decltype(&PtrCast<const Layer>), Iterator>; + using ConstIteratorInputs = boost::transform_iterator<decltype(&PtrCast<const InputLayer>), Iterator>; using ConstIteratorOutputs = boost::transform_iterator<decltype(&PtrCast<const OutputLayer>), Iterator>; /// Wrapper class returned by Graph::GetInputLayers() @@ -49,13 +50,13 @@ public: ConstIteratorInputs begin() const { - return { m_Graph.m_Layers.begin(), &PtrCast<const InputLayer> }; + return { m_Graph.m_Layers.begin(), &(PtrCast<const InputLayer>) }; } ConstIteratorInputs end() const { return { std::next(m_Graph.m_Layers.begin(), static_cast<IteratorDifference>(m_Graph.GetNumInputs())), - &PtrCast<const InputLayer> }; + &(PtrCast<const InputLayer>) }; } const Graph& m_Graph; @@ -69,12 +70,12 @@ public: ConstIteratorOutputs begin() const { return { std::prev(m_Graph.m_Layers.end(), static_cast<IteratorDifference>(m_Graph.GetNumOutputs())), - &PtrCast<const OutputLayer> }; + &(PtrCast<const OutputLayer>) }; } ConstIteratorOutputs end() const { - return { m_Graph.m_Layers.end(), &PtrCast<const OutputLayer> }; + return { m_Graph.m_Layers.end(), &(PtrCast<const OutputLayer>) }; } const Graph& m_Graph; @@ -127,9 +128,9 @@ public: Iterator end() { return m_Layers.end(); } /// Returns const iterator pointing to the beginning of the list. Lowercase for range-based for loops. - ConstIterator begin() const { return {m_Layers.begin(), &PtrCast<const Layer>}; } + ConstIterator begin() const { return {m_Layers.begin(), &(PtrCast<const Layer>)}; } /// Returns const iterator pointing to the end of the list. Lowercase for range-based for loops. - ConstIterator end() const { return {m_Layers.end(), &PtrCast<const Layer>}; } + ConstIterator end() const { return {m_Layers.end(), &(PtrCast<const Layer>)}; } /// Returns const iterator pointing to the beginning of the list. Lowercase for range-based for loops. ConstIterator cbegin() const { return begin(); } @@ -161,6 +162,7 @@ public: void AddCopyLayers(); void SubstituteSubGraph(std::unique_ptr<SubGraph> subGraph, IConnectableLayer* substituteLayer); + void SubstituteSubGraph(std::unique_ptr<SubGraph> subGraph, const SubGraph& substituteSubGraph); void InferTensorInfos(); @@ -214,10 +216,11 @@ private: std::unordered_map<const Layer*, Iterator> m_PosInGraphMap; void ReplaceSubGraphConnections(const SubGraph& subGraph, IConnectableLayer* substituteLayer); + void ReplaceSubGraphConnections(const SubGraph& subGraph, const SubGraph& substituteSubGraph); void EraseSubGraphLayers(const SubGraph &subGraph); /// Mutable to allow sorting on const object. - mutable LayersList m_Layers; + mutable LayerList m_Layers; mutable bool m_LayersInOrder; std::map<const GraphEvent, std::list<IGraphObservable*>> m_Views; |