diff options
author | Matthew Bentham <matthew.bentham@arm.com> | 2022-12-21 09:10:04 +0000 |
---|---|---|
committer | Francis Murtagh <francis.murtagh@arm.com> | 2023-01-09 15:23:21 +0000 |
commit | 4512b921b3533d7476c941cf61edcf57418b17d4 (patch) | |
tree | 487093dc9c56ecb7576814cd253508ebc0747649 /src/armnn | |
parent | 94916a5c06065bca0b232106bd4ae68f9986b7b0 (diff) | |
download | armnn-4512b921b3533d7476c941cf61edcf57418b17d4.tar.gz |
IVGCVSW-7418: Can't call SubstituteSubgraph on working copy of subgraph in Optimize
* Add unit test for WorkingCopy of SubgraphView with Inputs and Outputs
* Added check to ensure InputSlot is connected before trying to disconnect
Signed-off-by: Matthew Bentham <matthew.bentham@arm.com>
Change-Id: I261d55e38c94687a9de64cdee726a7c7442ed537
Diffstat (limited to 'src/armnn')
-rw-r--r-- | src/armnn/Graph.cpp | 18 | ||||
-rw-r--r-- | src/armnn/test/SubgraphViewTests.cpp | 29 |
2 files changed, 41 insertions, 6 deletions
diff --git a/src/armnn/Graph.cpp b/src/armnn/Graph.cpp index b5769f75f3..e5d123830c 100644 --- a/src/armnn/Graph.cpp +++ b/src/armnn/Graph.cpp @@ -497,13 +497,19 @@ void Graph::ReplaceSubgraphConnections(const SubgraphView& subgraph, const Subgr IInputSlot* subgraphInputSlot = subgraphInputSlots.at(inputSlotIdx); ARMNN_ASSERT(subgraphInputSlot); - IOutputSlot* connectedOutputSlot = subgraphInputSlot->GetConnection(); - ARMNN_ASSERT(connectedOutputSlot); - connectedOutputSlot->Disconnect(*subgraphInputSlot); + // Only disconnect if the InputSlot has a connection, this might not be the case when + // dealing with working copies of SubgraphViews + // Note: we don't need this check for OutputSlot as it iterates over a vector of valid connections + if (subgraphInputSlot->GetConnection()) + { + IOutputSlot* connectedOutputSlot = subgraphInputSlot->GetConnection(); + ARMNN_ASSERT(connectedOutputSlot); + connectedOutputSlot->Disconnect(*subgraphInputSlot); - IInputSlot* substituteInputSlot = substituteSubgraphInputSlots.at(inputSlotIdx); - ARMNN_ASSERT(substituteInputSlot); - connectedOutputSlot->Connect(*substituteInputSlot); + IInputSlot* substituteInputSlot = substituteSubgraphInputSlots.at(inputSlotIdx); + ARMNN_ASSERT(substituteInputSlot); + connectedOutputSlot->Connect(*substituteInputSlot); + } } // Step 2: process output slots diff --git a/src/armnn/test/SubgraphViewTests.cpp b/src/armnn/test/SubgraphViewTests.cpp index 4ce67b0fec..9bb5e69bbb 100644 --- a/src/armnn/test/SubgraphViewTests.cpp +++ b/src/armnn/test/SubgraphViewTests.cpp @@ -2063,6 +2063,35 @@ TEST_CASE("SubgraphViewWorkingCopySubstituteSubgraph") CHECK_THROWS_AS(workingCopy.GetWorkingCopy(), Exception); } +TEST_CASE("SubgraphViewPartialWorkingCopySubstituteSubgraph") +{ + Graph graph; + + auto input = graph.AddLayer<InputLayer>(0, "Input"); + auto activation = graph.AddLayer<ActivationLayer>(ActivationDescriptor{}, "Activation"); + auto output = graph.AddLayer<OutputLayer>(1, "Output"); + + input->GetOutputSlot(0).Connect(activation->GetInputSlot(0)); + activation->GetOutputSlot(0).Connect(output->GetInputSlot(0)); + + //Add in out of order + auto view = CreateSubgraphViewFrom({activation}, + {&activation->GetInputSlot(0)}, + {&activation->GetOutputSlot(0)}); + + auto workingCopy = view->GetWorkingCopy(); + + // First (and only) layer in the subgraph is the Activation + CHECK(std::string((*workingCopy.beginIConnectable())->GetName()) == "Activation"); + + // Substitute the "Activation" layer for an equivalent layer + auto activation2 = graph.AddLayer<ActivationLayer>(ActivationDescriptor{}, "Activation2"); + SubgraphView pattern(*workingCopy.beginIConnectable()); + workingCopy.SubstituteSubgraph(pattern, activation2); + + CHECK(std::string((*workingCopy.beginIConnectable())->GetName()) == "Activation2"); +} + TEST_CASE("SubgraphViewWorkingCopyOptimizationViews") { Graph graph; |