diff options
Diffstat (limited to 'src/armnn/test/SubgraphViewTests.cpp')
-rw-r--r-- | src/armnn/test/SubgraphViewTests.cpp | 85 |
1 files changed, 85 insertions, 0 deletions
diff --git a/src/armnn/test/SubgraphViewTests.cpp b/src/armnn/test/SubgraphViewTests.cpp index d270787968..4e509be78b 100644 --- a/src/armnn/test/SubgraphViewTests.cpp +++ b/src/armnn/test/SubgraphViewTests.cpp @@ -168,6 +168,91 @@ TEST_CASE("SingleInputSingleOutput") CHECK_EQ(preCompiledLayer->GetOutputSlot(0).GetConnection(0), subgraphOutputConn); } +TEST_CASE("SingleInputSingleOutputAddPrecompiledLayerSubstituteSubgraph1") +{ + // Construct graph. + Graph graph; + + Layer* const inputLayer = graph.AddLayer<InputLayer>(0, "input"); + + Convolution2dDescriptor convDescriptor; + Layer* const convLayer1 = graph.AddLayer<Convolution2dLayer>(convDescriptor, "conv1"); + Layer* const convLayer2 = graph.AddLayer<Convolution2dLayer>(convDescriptor, "conv2"); + + Layer* const outputLayer = graph.AddLayer<OutputLayer>(0, "output"); + + inputLayer->GetOutputSlot(0).Connect(convLayer1->GetInputSlot(0)); + convLayer1->GetOutputSlot(0).Connect(convLayer2->GetInputSlot(0)); + convLayer2->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0)); + + // Construct sub-graph + SubgraphViewSelector::SubgraphViewPtr subgraph = CreateSubgraphViewFrom(CreateInputsFrom({convLayer1}), + CreateOutputsFrom({convLayer2}), + {}); + + // Save sub-graph connections for comparison after substitution + IOutputSlot* subgraphInputConn = subgraph->GetInputSlot(0)->GetConnection(); + IInputSlot* subgraphOutputConn = subgraph->GetOutputSlot(0)->GetConnection(0); + + PreCompiledDescriptor preCompiledDescriptor(1, 1); + CompiledBlobPtr compiledBlobPtr; + BackendId backend = Compute::CpuRef; + + // Construct dummy pre-compiled layer + INetworkPtr network = INetwork::Create(); + IConnectableLayer* preCompiledLayer = network->AddPrecompiledLayer(preCompiledDescriptor, compiledBlobPtr, backend); + + // Substitute sub-graph with pre-compiled layer + graph.SubstituteSubgraph(*subgraph, preCompiledLayer); + + // Check that connections are correct after substitution + CHECK_EQ(preCompiledLayer->GetInputSlot(0).GetConnection(), subgraphInputConn); + CHECK_EQ(preCompiledLayer->GetOutputSlot(0).GetConnection(0), subgraphOutputConn); +} + +TEST_CASE("SingleInputSingleOutputAddPrecompiledLayerSubstituteSubgraph2") +{ + // Construct graph. + Graph graph; + + Layer* const inputLayer = graph.AddLayer<InputLayer>(0, "input"); + + Convolution2dDescriptor convDescriptor; + Layer* const convLayer1 = graph.AddLayer<Convolution2dLayer>(convDescriptor, "conv1"); + Layer* const convLayer2 = graph.AddLayer<Convolution2dLayer>(convDescriptor, "conv2"); + + Layer* const outputLayer = graph.AddLayer<OutputLayer>(0, "output"); + + inputLayer->GetOutputSlot(0).Connect(convLayer1->GetInputSlot(0)); + convLayer1->GetOutputSlot(0).Connect(convLayer2->GetInputSlot(0)); + convLayer2->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0)); + + // Construct sub-graph + SubgraphViewSelector::SubgraphViewPtr subgraph = CreateSubgraphViewFrom(CreateInputsFrom({convLayer1}), + CreateOutputsFrom({convLayer2}), + {}); + + // Save sub-graph connections for comparison after substitution + IOutputSlot* subgraphInputConn = subgraph->GetInputSlot(0)->GetConnection(); + IInputSlot* subgraphOutputConn = subgraph->GetOutputSlot(0)->GetConnection(0); + + PreCompiledDescriptor preCompiledDescriptor(1, 1); + CompiledBlobPtr compiledBlobPtr; + BackendId backend = Compute::CpuRef; + + // Construct dummy pre-compiled layer + INetworkPtr network = INetwork::Create(); + IConnectableLayer* preCompiledLayer = network->AddPrecompiledLayer(preCompiledDescriptor, compiledBlobPtr, backend); + SubgraphView substituteSubgraph(preCompiledLayer); + + // Substitute sub-graph with pre-compiled layer + graph.SubstituteSubgraph(*subgraph, substituteSubgraph); + + // Check that connections are correct after substitution + CHECK_EQ(preCompiledLayer->GetInputSlot(0).GetConnection(), subgraphInputConn); + CHECK_EQ(preCompiledLayer->GetOutputSlot(0).GetConnection(0), subgraphOutputConn); +} + TEST_CASE("SingleInputSingleOutputSubstituteGraph") { // Construct graph |