From 5200afa2742ad9cd1cda7fbce8604794c0616818 Mon Sep 17 00:00:00 2001 From: David Monahan Date: Fri, 10 May 2019 11:52:14 +0100 Subject: IVGCVSW-3034 Updates to SubstituteSubGraph and ReplaceSubgraphConnections to support Graphs instead of SubGraphViews * Added layer iteration function to SubgraphView similar to the Graph's one * Updated SubstituteSubgraph to reparent the layers to the calling graph Signed-off-by: David Monahan Change-Id: Ib2f8e70decca4a59c53ceb127e07ef5a430d1005 --- src/armnn/Graph.cpp | 11 +++++++++++ src/armnn/Graph.hpp | 6 +++--- src/armnn/SubgraphView.hpp | 11 +++++++++++ 3 files changed, 25 insertions(+), 3 deletions(-) diff --git a/src/armnn/Graph.cpp b/src/armnn/Graph.cpp index 8c2b232ead..31ca55cb9d 100644 --- a/src/armnn/Graph.cpp +++ b/src/armnn/Graph.cpp @@ -308,6 +308,17 @@ void Graph::SubstituteSubgraph(SubgraphView& subgraph, IConnectableLayer* substi void Graph::SubstituteSubgraph(SubgraphView& subgraph, const SubgraphView& substituteSubgraph) { + // Look through each layer in the new subgraph and add any that are not already a member of this graph + substituteSubgraph.ForEachLayer([this](Layer* layer) + { + if (std::find(std::begin(m_Layers), std::end(m_Layers), layer) == std::end(m_Layers)) + { + layer->Reparent(*this, m_Layers.end()); + m_LayersInOrder = false; + } + }); + + TopologicalSort(); ReplaceSubgraphConnections(subgraph, substituteSubgraph); EraseSubgraphLayers(subgraph); } diff --git a/src/armnn/Graph.hpp b/src/armnn/Graph.hpp index c5b1b045a7..47e0e3b317 100644 --- a/src/armnn/Graph.hpp +++ b/src/armnn/Graph.hpp @@ -36,7 +36,7 @@ public: } template - void ForEachLayerInGraph(Func func) + void ForEachLayer(Func func) const { for (auto it = m_Layers.begin(); it != m_Layers.end(); ) { @@ -110,7 +110,7 @@ public: m_LayersInOrder = std::move(other.m_LayersInOrder); m_Views = std::move(other.m_Views); - other.ForEachLayerInGraph([this](Layer* otherLayer) + other.ForEachLayer([this](Layer* otherLayer) { otherLayer->Reparent(*this, m_Layers.end()); }); @@ -123,7 +123,7 @@ public: ~Graph() { - ForEachLayerInGraph([](Layer* layer) + ForEachLayer([](Layer* layer) { delete layer; }); diff --git a/src/armnn/SubgraphView.hpp b/src/armnn/SubgraphView.hpp index d86f1c1c93..f29e0a18ae 100644 --- a/src/armnn/SubgraphView.hpp +++ b/src/armnn/SubgraphView.hpp @@ -23,6 +23,17 @@ namespace armnn class SubgraphView final { public: + template + void ForEachLayer(Func func) const + { + for (auto it = m_Layers.begin(); it != m_Layers.end(); ) + { + auto next = std::next(it); + func(*it); + it = next; + } + } + using SubgraphViewPtr = std::unique_ptr; using InputSlots = std::vector; using OutputSlots = std::vector; -- cgit v1.2.1