diff options
Diffstat (limited to 'src/armnn/test')
-rw-r--r-- | src/armnn/test/CreateWorkload.hpp | 47 |
1 files changed, 47 insertions, 0 deletions
diff --git a/src/armnn/test/CreateWorkload.hpp b/src/armnn/test/CreateWorkload.hpp index 834aa0e620..4181db26d0 100644 --- a/src/armnn/test/CreateWorkload.hpp +++ b/src/armnn/test/CreateWorkload.hpp @@ -1376,4 +1376,51 @@ std::unique_ptr<SpaceToDepthWorkload> CreateSpaceToDepthWorkloadTest(armnn::IWor return workload; } +template <typename StackWorkload> +std::unique_ptr<StackWorkload> CreateStackWorkloadTest(armnn::IWorkloadFactory& factory, + armnn::Graph& graph, + const armnn::TensorShape& inputShape, + const armnn::TensorShape& outputShape, + unsigned int axis, + unsigned int numInputs, + armnn::DataType dataType) +{ + armnn::TensorInfo inputTensorInfo(inputShape, dataType); + armnn::TensorInfo outputTensorInfo(outputShape, dataType); + + // Constructs the Stack layer. + armnn::StackDescriptor descriptor(axis, numInputs, inputShape); + Layer* const stackLayer = graph.AddLayer<StackLayer>(descriptor, "stack"); + BOOST_CHECK(stackLayer != nullptr); + + // Constructs layer inputs and output. + std::vector<Layer*> inputs; + for (unsigned int i=0; i<numInputs; ++i) + { + inputs.push_back(graph.AddLayer<InputLayer>( + static_cast<int>(i), + ("input" + std::to_string(i)).c_str() + )); + BOOST_CHECK(inputs[i] != nullptr); + } + Layer* const output = graph.AddLayer<OutputLayer>(0, "output"); + BOOST_CHECK(output != nullptr); + + // Adds connections. + for (unsigned int i=0; i<numInputs; ++i) + { + Connect(inputs[i], stackLayer, inputTensorInfo, 0, i); + } + Connect(stackLayer, output, outputTensorInfo, 0, 0); + + CreateTensorHandles(graph, factory); + + auto stackWorkload = MakeAndCheckWorkload<StackWorkload>(*stackLayer, graph, factory); + StackQueueDescriptor queueDescriptor = stackWorkload->GetData(); + BOOST_TEST(queueDescriptor.m_Inputs.size() == numInputs); + BOOST_TEST(queueDescriptor.m_Outputs.size() == 1); + + return stackWorkload; +} + } // Anonymous namespace |