aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/SubgraphView.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/SubgraphView.cpp')
-rw-r--r--src/armnn/SubgraphView.cpp41
1 files changed, 39 insertions, 2 deletions
diff --git a/src/armnn/SubgraphView.cpp b/src/armnn/SubgraphView.cpp
index 804ff731fb..b48529c523 100644
--- a/src/armnn/SubgraphView.cpp
+++ b/src/armnn/SubgraphView.cpp
@@ -525,6 +525,44 @@ void SubgraphView::SubstituteSubgraph(SubgraphView& subgraph, IConnectableLayer*
SubstituteSubgraph(subgraph, substituteSubgraph);
}
+void SubgraphView::UpdateSubgraphViewSlotPointers(SubgraphView& patternSubgraph,
+ const SubgraphView& substituteSubgraph)
+{
+ std::vector<IInputSlot*>::iterator inputSlotPosition;
+ // search for and erase any InputSlots that appear in the WorkingCopy that match those in the PatternSubgraph
+ for (auto slot : patternSubgraph.GetIInputSlots())
+ {
+ inputSlotPosition = std::find(m_IInputSlots.begin(), m_IInputSlots.end(), slot);
+ if (inputSlotPosition != m_IInputSlots.end())
+ {
+ m_IInputSlots.erase(inputSlotPosition);
+ }
+ }
+
+ std::vector<IOutputSlot*>::iterator outputSlotPosition;
+ // search for and erase any OutputSlots that appear in the WorkingCopy that match those in the PatternSubgraph
+ for (auto slot : patternSubgraph.GetIOutputSlots())
+ {
+ outputSlotPosition = std::find(m_IOutputSlots.begin(), m_IOutputSlots.end(), slot);
+ if (outputSlotPosition != m_IOutputSlots.end())
+ {
+ m_IOutputSlots.erase(outputSlotPosition);
+ }
+ }
+
+ // append InputSlots from the SubstituteSubgraph to the WorkingCopy m_IInputSlots vector variable
+ // at the position in the vector where PatternSubgraph InputSlots were last removed
+ m_IInputSlots.insert(inputSlotPosition,
+ std::make_move_iterator(substituteSubgraph.m_IInputSlots.begin()),
+ std::make_move_iterator(substituteSubgraph.m_IInputSlots.end()));
+
+ // append OutputSlots from the SubstituteSubgraph to the WorkingCopy m_IOutputSlots vector variable
+ // at the position in the vector where PatternSubgraph OutputSlots were last removed
+ m_IOutputSlots.insert(outputSlotPosition,
+ std::make_move_iterator(substituteSubgraph.m_OutputSlots.begin()),
+ std::make_move_iterator(substituteSubgraph.m_OutputSlots.end()));
+}
+
void SubgraphView::SubstituteSubgraph(SubgraphView& patternSubgraph, const SubgraphView& substituteSubgraph)
{
if (!p_WorkingCopyImpl)
@@ -556,8 +594,7 @@ void SubgraphView::SubstituteSubgraph(SubgraphView& patternSubgraph, const Subgr
workingCopyGraph->ReplaceSubgraphConnections(patternSubgraph, substituteSubgraph);
// Update input/outputSlot pointers
- m_IInputSlots = std::move(substituteSubgraph.m_IInputSlots);
- m_IOutputSlots = std::move(substituteSubgraph.m_IOutputSlots);
+ UpdateSubgraphViewSlotPointers(patternSubgraph, substituteSubgraph);
// Delete the old layers.
workingCopyGraph->EraseSubgraphLayers(patternSubgraph);