From b48e68674e600d68ca7059736d930ada6a3b4969 Mon Sep 17 00:00:00 2001 From: Nina Drozd Date: Tue, 9 Oct 2018 12:09:56 +0100 Subject: IVGCVSW-1982 - add create workload test for 2D Pooling (NHWC data layout) Change-Id: Ief0c91ba9abc2578944860ddbd3c19e2bad465bd --- src/armnn/test/CreateWorkload.hpp | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) (limited to 'src/armnn') diff --git a/src/armnn/test/CreateWorkload.hpp b/src/armnn/test/CreateWorkload.hpp index f2c8b5a20a..b63e95d4cb 100644 --- a/src/armnn/test/CreateWorkload.hpp +++ b/src/armnn/test/CreateWorkload.hpp @@ -162,7 +162,6 @@ std::unique_ptr CreateBatchNormalizationWorkl // Makes the workload and checks it. auto workload = MakeAndCheckWorkload(*layer, graph, factory); - BatchNormalizationQueueDescriptor queueDescriptor = workload->GetData(); BOOST_TEST(queueDescriptor.m_Parameters.m_Eps == 0.05f); BOOST_TEST(queueDescriptor.m_Inputs.size() == 1); @@ -532,7 +531,8 @@ std::unique_ptr CreateNormalizationWorkloadTest(armnn::IW template std::unique_ptr CreatePooling2dWorkloadTest(armnn::IWorkloadFactory& factory, - armnn::Graph& graph) + armnn::Graph& graph, + DataLayout dataLayout = DataLayout::NCHW) { // Creates the layer we're testing. Pooling2dDescriptor layerDesc; @@ -546,6 +546,7 @@ std::unique_ptr CreatePooling2dWorkloadTest(armnn::IWorkloadF layerDesc.m_StrideX = 2; layerDesc.m_StrideY = 3; layerDesc.m_OutputShapeRounding = OutputShapeRounding::Floor; + layerDesc.m_DataLayout = dataLayout; Pooling2dLayer* const layer = graph.AddLayer(layerDesc, "layer"); @@ -553,9 +554,12 @@ std::unique_ptr CreatePooling2dWorkloadTest(armnn::IWorkloadF Layer* const input = graph.AddLayer(0, "input"); Layer* const output = graph.AddLayer(0, "output"); + TensorShape inputShape = (dataLayout == DataLayout::NCHW) ? TensorShape{3, 2, 5, 5} : TensorShape{3, 5, 5, 2}; + TensorShape outputShape = (dataLayout == DataLayout::NCHW) ? TensorShape{3, 2, 2, 4} : TensorShape{3, 2, 4, 2}; + // Connect up - Connect(input, layer, TensorInfo({3, 2, 5, 5}, DataType)); - Connect(layer, output, TensorInfo({3, 2, 2, 4}, DataType)); + Connect(input, layer, TensorInfo(inputShape, DataType)); + Connect(layer, output, TensorInfo(outputShape, DataType)); CreateTensorHandles(graph, factory); // Make the workload and checks it -- cgit v1.2.1