aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/SubgraphView.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/SubgraphView.cpp')
-rw-r--r--src/armnn/SubgraphView.cpp47
1 files changed, 33 insertions, 14 deletions
diff --git a/src/armnn/SubgraphView.cpp b/src/armnn/SubgraphView.cpp
index fef6390bf2..e16c5c7b3a 100644
--- a/src/armnn/SubgraphView.cpp
+++ b/src/armnn/SubgraphView.cpp
@@ -11,6 +11,7 @@
#include <armnn/utility/NumericCast.hpp>
#include <armnn/utility/PolymorphicDowncast.hpp>
+#include <fmt/format.h>
#include <utility>
namespace armnn
@@ -539,37 +540,33 @@ void SubgraphView::UpdateSubgraphViewSlotPointers(SubgraphView& patternSubgraph,
{
std::vector<IInputSlot*>::iterator inputSlotPosition;
// search for and erase any InputSlots that appear in the WorkingCopy that match those in the PatternSubgraph
- for (auto slot : patternSubgraph.GetIInputSlots())
+ for (unsigned long idx = 0; idx < patternSubgraph.GetIInputSlots().size(); idx++)
{
+ IInputSlot *slot = patternSubgraph.GetIInputSlots()[idx];
inputSlotPosition = std::find(m_IInputSlots.begin(), m_IInputSlots.end(), slot);
if (inputSlotPosition != m_IInputSlots.end())
{
m_IInputSlots.erase(inputSlotPosition);
+
+ // while here, with correct position, add in replacement InputSlot from the substituteSubgraph
+ m_IInputSlots.insert(inputSlotPosition, substituteSubgraph.GetIInputSlots()[idx]);
}
}
std::vector<IOutputSlot*>::iterator outputSlotPosition;
// search for and erase any OutputSlots that appear in the WorkingCopy that match those in the PatternSubgraph
- for (auto slot : patternSubgraph.GetIOutputSlots())
+ for (unsigned long idx = 0; idx < patternSubgraph.GetIOutputSlots().size(); idx++)
{
+ IOutputSlot *slot = patternSubgraph.GetIOutputSlots()[idx];
outputSlotPosition = std::find(m_IOutputSlots.begin(), m_IOutputSlots.end(), slot);
if (outputSlotPosition != m_IOutputSlots.end())
{
m_IOutputSlots.erase(outputSlotPosition);
+
+ // while here, with correct position, add in replacement OutputSlot from the substituteSubgraph
+ m_IOutputSlots.insert(outputSlotPosition, substituteSubgraph.GetIOutputSlots()[idx]);
}
}
-
- // append InputSlots from the SubstituteSubgraph to the WorkingCopy m_IInputSlots vector variable
- // at the position in the vector where PatternSubgraph InputSlots were last removed
- m_IInputSlots.insert(inputSlotPosition,
- std::make_move_iterator(substituteSubgraph.m_IInputSlots.begin()),
- std::make_move_iterator(substituteSubgraph.m_IInputSlots.end()));
-
- // append OutputSlots from the SubstituteSubgraph to the WorkingCopy m_IOutputSlots vector variable
- // at the position in the vector where PatternSubgraph OutputSlots were last removed
- m_IOutputSlots.insert(outputSlotPosition,
- std::make_move_iterator(substituteSubgraph.m_OutputSlots.begin()),
- std::make_move_iterator(substituteSubgraph.m_OutputSlots.end()));
}
void SubgraphView::SubstituteSubgraph(SubgraphView& patternSubgraph, const SubgraphView& substituteSubgraph)
@@ -580,6 +577,28 @@ void SubgraphView::SubstituteSubgraph(SubgraphView& patternSubgraph, const Subgr
"Call this function on SubgraphView returned from SubgraphView::GetWorkingCopy()");
}
+ auto numPatternInputs = patternSubgraph.GetIInputSlots().size();
+ auto numSubInputs = substituteSubgraph.GetIInputSlots().size();
+ if (numPatternInputs != numSubInputs)
+ {
+ throw armnn::InvalidArgumentException(
+ fmt::format("Number of InputSlots on substitute SubgraphView ({}) must equal the number of"
+ " InputSlots on pattern SubgraphView ({})",
+ numSubInputs,
+ numPatternInputs));
+ }
+
+ auto numPatternOutputs = patternSubgraph.GetIOutputSlots().size();
+ auto numSubOutputs = substituteSubgraph.GetIOutputSlots().size();
+ if (numPatternOutputs != numSubOutputs)
+ {
+ throw armnn::InvalidArgumentException(
+ fmt::format("Number of OutputSlots on substitute SubgraphView ({}) must equal the number of"
+ " OutputSlots on pattern SubgraphView ({})",
+ numSubOutputs,
+ numPatternOutputs));
+ }
+
// Add substitute layer to the Main graph i.e. graph in p_WorkingCopyImpl
auto workingCopyGraph = &p_WorkingCopyImpl->m_Graph;
substituteSubgraph.ForEachIConnectableLayer([workingCopyGraph](IConnectableLayer* iConnectableLayer)