From 81e601c5a5ebf3de3dd6418942708158de50252a Mon Sep 17 00:00:00 2001 From: Matthew Jackson Date: Thu, 11 Jul 2019 12:07:09 +0100 Subject: IVGCVSW-3419 Add reference workload support for the new Stack layer * Added reference workload for the Stack layer * Added factory methods * Added validation support * Added unit tests Signed-off-by: Matthew Jackson Change-Id: Ib14b72c15f53a2a2ca152afc357ce2aa405ccc88 --- src/armnn/test/CreateWorkload.hpp | 47 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 47 insertions(+) (limited to 'src/armnn/test/CreateWorkload.hpp') diff --git a/src/armnn/test/CreateWorkload.hpp b/src/armnn/test/CreateWorkload.hpp index 834aa0e620..4181db26d0 100644 --- a/src/armnn/test/CreateWorkload.hpp +++ b/src/armnn/test/CreateWorkload.hpp @@ -1376,4 +1376,51 @@ std::unique_ptr CreateSpaceToDepthWorkloadTest(armnn::IWor return workload; } +template +std::unique_ptr CreateStackWorkloadTest(armnn::IWorkloadFactory& factory, + armnn::Graph& graph, + const armnn::TensorShape& inputShape, + const armnn::TensorShape& outputShape, + unsigned int axis, + unsigned int numInputs, + armnn::DataType dataType) +{ + armnn::TensorInfo inputTensorInfo(inputShape, dataType); + armnn::TensorInfo outputTensorInfo(outputShape, dataType); + + // Constructs the Stack layer. + armnn::StackDescriptor descriptor(axis, numInputs, inputShape); + Layer* const stackLayer = graph.AddLayer(descriptor, "stack"); + BOOST_CHECK(stackLayer != nullptr); + + // Constructs layer inputs and output. + std::vector inputs; + for (unsigned int i=0; i( + static_cast(i), + ("input" + std::to_string(i)).c_str() + )); + BOOST_CHECK(inputs[i] != nullptr); + } + Layer* const output = graph.AddLayer(0, "output"); + BOOST_CHECK(output != nullptr); + + // Adds connections. + for (unsigned int i=0; i(*stackLayer, graph, factory); + StackQueueDescriptor queueDescriptor = stackWorkload->GetData(); + BOOST_TEST(queueDescriptor.m_Inputs.size() == numInputs); + BOOST_TEST(queueDescriptor.m_Outputs.size() == 1); + + return stackWorkload; +} + } // Anonymous namespace -- cgit v1.2.1