diff options
Diffstat (limited to 'src/armnn/SubgraphView.cpp')
-rw-r--r-- | src/armnn/SubgraphView.cpp | 41 |
1 files changed, 39 insertions, 2 deletions
diff --git a/src/armnn/SubgraphView.cpp b/src/armnn/SubgraphView.cpp index 804ff731fb..b48529c523 100644 --- a/src/armnn/SubgraphView.cpp +++ b/src/armnn/SubgraphView.cpp @@ -525,6 +525,44 @@ void SubgraphView::SubstituteSubgraph(SubgraphView& subgraph, IConnectableLayer* SubstituteSubgraph(subgraph, substituteSubgraph); } +void SubgraphView::UpdateSubgraphViewSlotPointers(SubgraphView& patternSubgraph, + const SubgraphView& substituteSubgraph) +{ + std::vector<IInputSlot*>::iterator inputSlotPosition; + // search for and erase any InputSlots that appear in the WorkingCopy that match those in the PatternSubgraph + for (auto slot : patternSubgraph.GetIInputSlots()) + { + inputSlotPosition = std::find(m_IInputSlots.begin(), m_IInputSlots.end(), slot); + if (inputSlotPosition != m_IInputSlots.end()) + { + m_IInputSlots.erase(inputSlotPosition); + } + } + + std::vector<IOutputSlot*>::iterator outputSlotPosition; + // search for and erase any OutputSlots that appear in the WorkingCopy that match those in the PatternSubgraph + for (auto slot : patternSubgraph.GetIOutputSlots()) + { + outputSlotPosition = std::find(m_IOutputSlots.begin(), m_IOutputSlots.end(), slot); + if (outputSlotPosition != m_IOutputSlots.end()) + { + m_IOutputSlots.erase(outputSlotPosition); + } + } + + // append InputSlots from the SubstituteSubgraph to the WorkingCopy m_IInputSlots vector variable + // at the position in the vector where PatternSubgraph InputSlots were last removed + m_IInputSlots.insert(inputSlotPosition, + std::make_move_iterator(substituteSubgraph.m_IInputSlots.begin()), + std::make_move_iterator(substituteSubgraph.m_IInputSlots.end())); + + // append OutputSlots from the SubstituteSubgraph to the WorkingCopy m_IOutputSlots vector variable + // at the position in the vector where PatternSubgraph OutputSlots were last removed + m_IOutputSlots.insert(outputSlotPosition, + std::make_move_iterator(substituteSubgraph.m_OutputSlots.begin()), + std::make_move_iterator(substituteSubgraph.m_OutputSlots.end())); +} + void SubgraphView::SubstituteSubgraph(SubgraphView& patternSubgraph, const SubgraphView& substituteSubgraph) { if (!p_WorkingCopyImpl) @@ -556,8 +594,7 @@ void SubgraphView::SubstituteSubgraph(SubgraphView& patternSubgraph, const Subgr workingCopyGraph->ReplaceSubgraphConnections(patternSubgraph, substituteSubgraph); // Update input/outputSlot pointers - m_IInputSlots = std::move(substituteSubgraph.m_IInputSlots); - m_IOutputSlots = std::move(substituteSubgraph.m_IOutputSlots); + UpdateSubgraphViewSlotPointers(patternSubgraph, substituteSubgraph); // Delete the old layers. workingCopyGraph->EraseSubgraphLayers(patternSubgraph); |