diff options
Diffstat (limited to 'src/armnn')
-rw-r--r-- | src/armnn/test/CreateWorkload.hpp | 29 |
1 files changed, 29 insertions, 0 deletions
diff --git a/src/armnn/test/CreateWorkload.hpp b/src/armnn/test/CreateWorkload.hpp index 47af4a89b5..00257ea090 100644 --- a/src/armnn/test/CreateWorkload.hpp +++ b/src/armnn/test/CreateWorkload.hpp @@ -914,6 +914,35 @@ std::unique_ptr<RsqrtWorkload> CreateRsqrtWorkloadTest(armnn::IWorkloadFactory& return workload; } +template <typename BatchToSpaceNdWorkload, armnn::DataType DataType> +std::unique_ptr<BatchToSpaceNdWorkload> CreateBatchToSpaceNdWorkloadTest(armnn::IWorkloadFactory& factory, + armnn::Graph& graph) +{ + BatchToSpaceNdDescriptor desc; + Layer* const layer = graph.AddLayer<BatchToSpaceNdLayer>(desc, "batchToSpace"); + + // Creates extra layers. + Layer* const input = graph.AddLayer<InputLayer>(0, "input"); + Layer* const output = graph.AddLayer<OutputLayer>(0, "output"); + + // Connects up. + armnn::TensorInfo tensorInfo({1, 1}, DataType); + + Connect(input, layer, tensorInfo); + Connect(layer, output, tensorInfo); + + CreateTensorHandles(graph, factory); + + // Makes the workload and checks it. + auto workload = MakeAndCheckWorkload<BatchToSpaceNdWorkload>(*layer, graph, factory); + + BatchToSpaceNdQueueDescriptor queueDescriptor = workload->GetData(); + BOOST_TEST(queueDescriptor.m_Inputs.size() == 1); + BOOST_TEST(queueDescriptor.m_Outputs.size() == 1); + + return workload; +} + template <typename L2NormalizationWorkload, armnn::DataType DataType> std::unique_ptr<L2NormalizationWorkload> CreateL2NormalizationWorkloadTest(armnn::IWorkloadFactory& factory, armnn::Graph& graph, DataLayout dataLayout = DataLayout::NCHW) |