aboutsummaryrefslogtreecommitdiff
path: root/src/backends/backendsCommon/test/OptimizationViewsTests.cpp
diff options
context:
space:
mode:
authorCathal Corbett <cathal.corbett@arm.com>2021-12-15 17:12:59 +0000
committerCathal Corbett <cathal.corbett@arm.com>2021-12-23 13:21:22 +0000
commitcbfd718464b8ac41f0338ae6565d8213d24c0a2a (patch)
treef26da835108a0ed52ac0ffc8f7ebec64827b5033 /src/backends/backendsCommon/test/OptimizationViewsTests.cpp
parent81edc6217f76953c0be4c47f3d005cf48772ccb7 (diff)
downloadarmnn-cbfd718464b8ac41f0338ae6565d8213d24c0a2a.tar.gz
IVGCVSW-6632 OptimizationViews: has INetwork rather than Graph for holding layers
* Deprecate the GetGraph() function in OptimizationViews & remove/fix occurances where OptimizationViews.GetGraph() is called. * OptimizationViews has member INetworkPtr. * OptimizationViews has GetINetwork() method. * Unit test added to OptimizationViewsTests.cpp. Signed-off-by: Cathal Corbett <cathal.corbett@arm.com> Change-Id: Ifc1e53f1c34d786502279631942f0472f401038e
Diffstat (limited to 'src/backends/backendsCommon/test/OptimizationViewsTests.cpp')
-rw-r--r--src/backends/backendsCommon/test/OptimizationViewsTests.cpp71
1 files changed, 69 insertions, 2 deletions
diff --git a/src/backends/backendsCommon/test/OptimizationViewsTests.cpp b/src/backends/backendsCommon/test/OptimizationViewsTests.cpp
index bbae229927..c40c5131a8 100644
--- a/src/backends/backendsCommon/test/OptimizationViewsTests.cpp
+++ b/src/backends/backendsCommon/test/OptimizationViewsTests.cpp
@@ -55,7 +55,7 @@ TEST_CASE("OptimizedViewsSubgraphLayerCount")
{
OptimizationViews view;
// Construct a graph with 3 layers
- Graph& baseGraph = view.GetGraph();
+ Graph baseGraph;
Layer* const inputLayer = baseGraph.AddLayer<InputLayer>(0, "input");
@@ -119,11 +119,78 @@ TEST_CASE("OptimizedViewsSubgraphLayerCount")
CHECK(view.Validate(*originalSubgraph));
}
+
+TEST_CASE("OptimizedViewsSubgraphLayerCountUsingGetINetwork")
+{
+ OptimizationViews view;
+
+ IConnectableLayer* const inputLayer = view.GetINetwork()->AddInputLayer(0, "input");
+
+ DepthwiseConvolution2dDescriptor convDescriptor;
+ PreCompiledDescriptor substitutionLayerDescriptor(1, 1);
+ CompiledBlobPtr blobPtr;
+ BackendId backend = Compute::CpuRef;
+
+ Layer* convLayer1 = PolymorphicDowncast<Layer*>(
+ view.GetINetwork()->AddDepthwiseConvolution2dLayer(convDescriptor,
+ ConstTensor(),
+ Optional<ConstTensor>(),
+ "conv1"));
+
+ Layer* convLayer2 = PolymorphicDowncast<Layer*>(
+ view.GetINetwork()->AddDepthwiseConvolution2dLayer(convDescriptor,
+ ConstTensor(),
+ Optional<ConstTensor>(),
+ "conv2"));
+
+ IConnectableLayer* const outputLayer = view.GetINetwork()->AddOutputLayer(0, "output");
+
+ inputLayer->GetOutputSlot(0).Connect(convLayer1->GetInputSlot(0));
+ convLayer1->GetOutputSlot(0).Connect(convLayer2->GetInputSlot(0));
+ convLayer2->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
+
+ // Subgraph for a failed layer
+ SubgraphViewSelector::SubgraphViewPtr failedSubgraph = CreateSubgraphViewFrom(CreateInputsFrom({convLayer1}),
+ CreateOutputsFrom({convLayer1}),
+ {convLayer1});
+ // Subgraph for an untouched layer
+ SubgraphViewSelector::SubgraphViewPtr untouchedSubgraph = CreateSubgraphViewFrom(CreateInputsFrom({convLayer2}),
+ CreateOutputsFrom({convLayer2}),
+ {convLayer2});
+
+ // Create a Network containing a layer to substitute in
+ NetworkImpl net;
+ Layer* substitutionpreCompiledLayer = PolymorphicDowncast<Layer*>(
+ net.AddPrecompiledLayer(substitutionLayerDescriptor, blobPtr, backend));
+
+ // Subgraph for a substitution layer
+ SubgraphViewSelector::SubgraphViewPtr substitutionSubgraph =
+ CreateSubgraphViewFrom(CreateInputsFrom({substitutionpreCompiledLayer}),
+ CreateOutputsFrom({substitutionpreCompiledLayer}),
+ {substitutionpreCompiledLayer});
+
+ view.AddFailedSubgraph(SubgraphView(*failedSubgraph));
+ view.AddUntouchedSubgraph(SubgraphView(*untouchedSubgraph));
+
+ SubgraphViewSelector::SubgraphViewPtr baseSubgraph = CreateSubgraphViewFrom(CreateInputsFrom({convLayer1}),
+ CreateOutputsFrom({convLayer2}),
+ {substitutionpreCompiledLayer});
+ view.AddSubstitution({*baseSubgraph, *substitutionSubgraph});
+
+ // Construct original subgraph to compare against
+ SubgraphViewSelector::SubgraphViewPtr originalSubgraph =
+ CreateSubgraphViewFrom(CreateInputsFrom({convLayer1}),
+ CreateOutputsFrom({convLayer2}),
+ {convLayer1, convLayer2, substitutionpreCompiledLayer});
+
+ CHECK(view.Validate(*originalSubgraph));
+}
+
TEST_CASE("OptimizedViewsSubgraphLayerCountFailValidate")
{
OptimizationViews view;
// Construct a graph with 3 layers
- Graph& baseGraph = view.GetGraph();
+ Graph baseGraph;
Layer* const inputLayer = baseGraph.AddLayer<InputLayer>(0, "input");