From 57f13d5905a1fbdc89b53d68b2bdc6b753b9e8d5 Mon Sep 17 00:00:00 2001 From: Francis Murtagh Date: Mon, 24 Jun 2019 14:24:36 +0100 Subject: IVGCVSW-3334 Refactor BatchToSpace tests to be generic * Generify and reformat test BatchToSpace tests * Add missing RefCreateWorkload test Change-Id: I08af018c07ee41df5b9d1e578d99bc03f38090ac Signed-off-by: Francis Murtagh --- src/armnn/test/CreateWorkload.hpp | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) (limited to 'src/armnn/test') diff --git a/src/armnn/test/CreateWorkload.hpp b/src/armnn/test/CreateWorkload.hpp index 47af4a89b5..00257ea090 100644 --- a/src/armnn/test/CreateWorkload.hpp +++ b/src/armnn/test/CreateWorkload.hpp @@ -914,6 +914,35 @@ std::unique_ptr CreateRsqrtWorkloadTest(armnn::IWorkloadFactory& return workload; } +template +std::unique_ptr CreateBatchToSpaceNdWorkloadTest(armnn::IWorkloadFactory& factory, + armnn::Graph& graph) +{ + BatchToSpaceNdDescriptor desc; + Layer* const layer = graph.AddLayer(desc, "batchToSpace"); + + // Creates extra layers. + Layer* const input = graph.AddLayer(0, "input"); + Layer* const output = graph.AddLayer(0, "output"); + + // Connects up. + armnn::TensorInfo tensorInfo({1, 1}, DataType); + + Connect(input, layer, tensorInfo); + Connect(layer, output, tensorInfo); + + CreateTensorHandles(graph, factory); + + // Makes the workload and checks it. + auto workload = MakeAndCheckWorkload(*layer, graph, factory); + + BatchToSpaceNdQueueDescriptor queueDescriptor = workload->GetData(); + BOOST_TEST(queueDescriptor.m_Inputs.size() == 1); + BOOST_TEST(queueDescriptor.m_Outputs.size() == 1); + + return workload; +} + template std::unique_ptr CreateL2NormalizationWorkloadTest(armnn::IWorkloadFactory& factory, armnn::Graph& graph, DataLayout dataLayout = DataLayout::NCHW) -- cgit v1.2.1