aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/armnn/Graph.cpp11
-rw-r--r--src/armnn/Graph.hpp6
-rw-r--r--src/armnn/SubgraphView.hpp11
3 files changed, 25 insertions, 3 deletions
diff --git a/src/armnn/Graph.cpp b/src/armnn/Graph.cpp
index 8c2b232ead..31ca55cb9d 100644
--- a/src/armnn/Graph.cpp
+++ b/src/armnn/Graph.cpp
@@ -308,6 +308,17 @@ void Graph::SubstituteSubgraph(SubgraphView& subgraph, IConnectableLayer* substi
void Graph::SubstituteSubgraph(SubgraphView& subgraph, const SubgraphView& substituteSubgraph)
{
+ // Look through each layer in the new subgraph and add any that are not already a member of this graph
+ substituteSubgraph.ForEachLayer([this](Layer* layer)
+ {
+ if (std::find(std::begin(m_Layers), std::end(m_Layers), layer) == std::end(m_Layers))
+ {
+ layer->Reparent(*this, m_Layers.end());
+ m_LayersInOrder = false;
+ }
+ });
+
+ TopologicalSort();
ReplaceSubgraphConnections(subgraph, substituteSubgraph);
EraseSubgraphLayers(subgraph);
}
diff --git a/src/armnn/Graph.hpp b/src/armnn/Graph.hpp
index c5b1b045a7..47e0e3b317 100644
--- a/src/armnn/Graph.hpp
+++ b/src/armnn/Graph.hpp
@@ -36,7 +36,7 @@ public:
}
template <typename Func>
- void ForEachLayerInGraph(Func func)
+ void ForEachLayer(Func func) const
{
for (auto it = m_Layers.begin(); it != m_Layers.end(); )
{
@@ -110,7 +110,7 @@ public:
m_LayersInOrder = std::move(other.m_LayersInOrder);
m_Views = std::move(other.m_Views);
- other.ForEachLayerInGraph([this](Layer* otherLayer)
+ other.ForEachLayer([this](Layer* otherLayer)
{
otherLayer->Reparent(*this, m_Layers.end());
});
@@ -123,7 +123,7 @@ public:
~Graph()
{
- ForEachLayerInGraph([](Layer* layer)
+ ForEachLayer([](Layer* layer)
{
delete layer;
});
diff --git a/src/armnn/SubgraphView.hpp b/src/armnn/SubgraphView.hpp
index d86f1c1c93..f29e0a18ae 100644
--- a/src/armnn/SubgraphView.hpp
+++ b/src/armnn/SubgraphView.hpp
@@ -23,6 +23,17 @@ namespace armnn
class SubgraphView final
{
public:
+ template <typename Func>
+ void ForEachLayer(Func func) const
+ {
+ for (auto it = m_Layers.begin(); it != m_Layers.end(); )
+ {
+ auto next = std::next(it);
+ func(*it);
+ it = next;
+ }
+ }
+
using SubgraphViewPtr = std::unique_ptr<SubgraphView>;
using InputSlots = std::vector<InputSlot*>;
using OutputSlots = std::vector<OutputSlot*>;