From 87231bea67d95021f4e3674ae9a9f751dc3fd94b Mon Sep 17 00:00:00 2001 From: Francis Murtagh Date: Tue, 24 Jan 2023 15:49:01 +0000 Subject: IVGCVSW-7155: Fix Slot replacement during UpdateSubgraphViewSlotPointers * Only update boundary slots on actual subgraphview * Previously all slots from replacement subgraph added even if internal Signed-off-by: Francis Murtagh Signed-off-by: Matthew Bentham Change-Id: Ic9ef9fc41ad248838d1c019dd0368378c3119648 --- src/armnn/SubgraphView.cpp | 47 ++++++++++++++++++++++++++++++++-------------- 1 file changed, 33 insertions(+), 14 deletions(-) (limited to 'src/armnn/SubgraphView.cpp') diff --git a/src/armnn/SubgraphView.cpp b/src/armnn/SubgraphView.cpp index fef6390bf2..e16c5c7b3a 100644 --- a/src/armnn/SubgraphView.cpp +++ b/src/armnn/SubgraphView.cpp @@ -11,6 +11,7 @@ #include #include +#include #include namespace armnn @@ -539,37 +540,33 @@ void SubgraphView::UpdateSubgraphViewSlotPointers(SubgraphView& patternSubgraph, { std::vector::iterator inputSlotPosition; // search for and erase any InputSlots that appear in the WorkingCopy that match those in the PatternSubgraph - for (auto slot : patternSubgraph.GetIInputSlots()) + for (unsigned long idx = 0; idx < patternSubgraph.GetIInputSlots().size(); idx++) { + IInputSlot *slot = patternSubgraph.GetIInputSlots()[idx]; inputSlotPosition = std::find(m_IInputSlots.begin(), m_IInputSlots.end(), slot); if (inputSlotPosition != m_IInputSlots.end()) { m_IInputSlots.erase(inputSlotPosition); + + // while here, with correct position, add in replacement InputSlot from the substituteSubgraph + m_IInputSlots.insert(inputSlotPosition, substituteSubgraph.GetIInputSlots()[idx]); } } std::vector::iterator outputSlotPosition; // search for and erase any OutputSlots that appear in the WorkingCopy that match those in the PatternSubgraph - for (auto slot : patternSubgraph.GetIOutputSlots()) + for (unsigned long idx = 0; idx < patternSubgraph.GetIOutputSlots().size(); idx++) { + IOutputSlot *slot = patternSubgraph.GetIOutputSlots()[idx]; outputSlotPosition = std::find(m_IOutputSlots.begin(), m_IOutputSlots.end(), slot); if (outputSlotPosition != m_IOutputSlots.end()) { m_IOutputSlots.erase(outputSlotPosition); + + // while here, with correct position, add in replacement OutputSlot from the substituteSubgraph + m_IOutputSlots.insert(outputSlotPosition, substituteSubgraph.GetIOutputSlots()[idx]); } } - - // append InputSlots from the SubstituteSubgraph to the WorkingCopy m_IInputSlots vector variable - // at the position in the vector where PatternSubgraph InputSlots were last removed - m_IInputSlots.insert(inputSlotPosition, - std::make_move_iterator(substituteSubgraph.m_IInputSlots.begin()), - std::make_move_iterator(substituteSubgraph.m_IInputSlots.end())); - - // append OutputSlots from the SubstituteSubgraph to the WorkingCopy m_IOutputSlots vector variable - // at the position in the vector where PatternSubgraph OutputSlots were last removed - m_IOutputSlots.insert(outputSlotPosition, - std::make_move_iterator(substituteSubgraph.m_OutputSlots.begin()), - std::make_move_iterator(substituteSubgraph.m_OutputSlots.end())); } void SubgraphView::SubstituteSubgraph(SubgraphView& patternSubgraph, const SubgraphView& substituteSubgraph) @@ -580,6 +577,28 @@ void SubgraphView::SubstituteSubgraph(SubgraphView& patternSubgraph, const Subgr "Call this function on SubgraphView returned from SubgraphView::GetWorkingCopy()"); } + auto numPatternInputs = patternSubgraph.GetIInputSlots().size(); + auto numSubInputs = substituteSubgraph.GetIInputSlots().size(); + if (numPatternInputs != numSubInputs) + { + throw armnn::InvalidArgumentException( + fmt::format("Number of InputSlots on substitute SubgraphView ({}) must equal the number of" + " InputSlots on pattern SubgraphView ({})", + numSubInputs, + numPatternInputs)); + } + + auto numPatternOutputs = patternSubgraph.GetIOutputSlots().size(); + auto numSubOutputs = substituteSubgraph.GetIOutputSlots().size(); + if (numPatternOutputs != numSubOutputs) + { + throw armnn::InvalidArgumentException( + fmt::format("Number of OutputSlots on substitute SubgraphView ({}) must equal the number of" + " OutputSlots on pattern SubgraphView ({})", + numSubOutputs, + numPatternOutputs)); + } + // Add substitute layer to the Main graph i.e. graph in p_WorkingCopyImpl auto workingCopyGraph = &p_WorkingCopyImpl->m_Graph; substituteSubgraph.ForEachIConnectableLayer([workingCopyGraph](IConnectableLayer* iConnectableLayer) -- cgit v1.2.1