aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorColm Donelan <Colm.Donelan@arm.com>2019-10-23 14:18:56 +0100
committerColm Donelan <Colm.Donelan@arm.com>2019-10-24 10:07:13 +0100
commita87698211b6aaab38424865d200534d96f55dcf2 (patch)
tree7234cf42b4784f1388055961e3aac51049dd86a0
parent6e0d962f1c18996c131817ef9420ed0c2c2a3126 (diff)
downloadarmnn-a87698211b6aaab38424865d200534d96f55dcf2.tar.gz
IVGCVSW-4011 Add Unit tests for StandInLayer
* Added network level unit tests with different number of inputs and outputs. Signed-off-by: Colm Donelan <Colm.Donelan@arm.com> Change-Id: I251296ca98a34f459181fed32343e7c579938eab
-rw-r--r--src/armnn/test/NetworkTests.cpp66
-rw-r--r--src/armnn/test/TestNameAndDescriptorLayerVisitor.hpp1
2 files changed, 67 insertions, 0 deletions
diff --git a/src/armnn/test/NetworkTests.cpp b/src/armnn/test/NetworkTests.cpp
index 14b67a1f4a..d8b4e17a3c 100644
--- a/src/armnn/test/NetworkTests.cpp
+++ b/src/armnn/test/NetworkTests.cpp
@@ -473,4 +473,70 @@ BOOST_AUTO_TEST_CASE(Network_AddMerge)
BOOST_TEST(testMerge.m_Visited == true);
}
+BOOST_AUTO_TEST_CASE(StandInLayerNetworkTest)
+{
+ // Create a simple network with a StandIn some place in it.
+ armnn::Network net;
+ auto input = net.AddInputLayer(0);
+
+ // Add some valid layer.
+ auto floor = net.AddFloorLayer("Floor");
+
+ // Add a standin layer
+ armnn::StandInDescriptor standInDescriptor;
+ standInDescriptor.m_NumInputs = 1;
+ standInDescriptor.m_NumOutputs = 1;
+ auto standIn = net.AddStandInLayer(standInDescriptor, "StandIn");
+
+ // Finally the output.
+ auto output = net.AddOutputLayer(0);
+
+ // Connect up the layers
+ input->GetOutputSlot(0).Connect(floor->GetInputSlot(0));
+
+ floor->GetOutputSlot(0).Connect(standIn->GetInputSlot(0));
+
+ standIn->GetOutputSlot(0).Connect(output->GetInputSlot(0));
+
+ // Check that the layer is there.
+ BOOST_TEST(GraphHasNamedLayer(net.GetGraph(), "StandIn"));
+ // Check that it is connected as expected.
+ BOOST_TEST(input->GetOutputSlot(0).GetConnection(0) == &floor->GetInputSlot(0));
+ BOOST_TEST(floor->GetOutputSlot(0).GetConnection(0) == &standIn->GetInputSlot(0));
+ BOOST_TEST(standIn->GetOutputSlot(0).GetConnection(0) == &output->GetInputSlot(0));
+}
+
+BOOST_AUTO_TEST_CASE(StandInLayerSingleInputMultipleOutputsNetworkTest)
+{
+ // Another test with one input and two outputs on the StandIn layer.
+ armnn::Network net;
+
+ // Create the input.
+ auto input = net.AddInputLayer(0);
+
+ // Add a standin layer
+ armnn::StandInDescriptor standInDescriptor;
+ standInDescriptor.m_NumInputs = 1;
+ standInDescriptor.m_NumOutputs = 2;
+ auto standIn = net.AddStandInLayer(standInDescriptor, "StandIn");
+
+ // Add two outputs.
+ auto output0 = net.AddOutputLayer(0);
+ auto output1 = net.AddOutputLayer(1);
+
+ // Connect up the layers
+ input->GetOutputSlot(0).Connect(standIn->GetInputSlot(0));
+
+ // Connect the two outputs of the Standin to the two outputs.
+ standIn->GetOutputSlot(0).Connect(output0->GetInputSlot(0));
+ standIn->GetOutputSlot(1).Connect(output1->GetInputSlot(0));
+
+ // Check that the layer is there.
+ BOOST_TEST(GraphHasNamedLayer(net.GetGraph(), "StandIn"));
+ // Check that it is connected as expected.
+ BOOST_TEST(input->GetOutputSlot(0).GetConnection(0) == &standIn->GetInputSlot(0));
+ BOOST_TEST(standIn->GetOutputSlot(0).GetConnection(0) == &output0->GetInputSlot(0));
+ BOOST_TEST(standIn->GetOutputSlot(1).GetConnection(0) == &output1->GetInputSlot(0));
+}
+
BOOST_AUTO_TEST_SUITE_END()
diff --git a/src/armnn/test/TestNameAndDescriptorLayerVisitor.hpp b/src/armnn/test/TestNameAndDescriptorLayerVisitor.hpp
index b1f7f57075..9f4efa91f8 100644
--- a/src/armnn/test/TestNameAndDescriptorLayerVisitor.hpp
+++ b/src/armnn/test/TestNameAndDescriptorLayerVisitor.hpp
@@ -65,4 +65,5 @@ DECLARE_TEST_NAME_AND_DESCRIPTOR_LAYER_VISITOR_CLASS(SpaceToBatchNd)
DECLARE_TEST_NAME_AND_DESCRIPTOR_LAYER_VISITOR_CLASS(SpaceToDepth)
DECLARE_TEST_NAME_AND_DESCRIPTOR_LAYER_VISITOR_CLASS(Splitter)
DECLARE_TEST_NAME_AND_DESCRIPTOR_LAYER_VISITOR_CLASS(Stack)
+DECLARE_TEST_NAME_AND_DESCRIPTOR_LAYER_VISITOR_CLASS(StandIn)
DECLARE_TEST_NAME_AND_DESCRIPTOR_LAYER_VISITOR_CLASS(StridedSlice)