From d134093a271b60e248942af9757e8236e8f41ac1 Mon Sep 17 00:00:00 2001 From: Nikhil Raj Date: Thu, 18 Oct 2018 14:27:50 +0100 Subject: IVGCVSW-2023 CL and Neon implementation of BatchNorm with NHWC Change-Id: I962e986607e5d045cd97b9eaeaea2f5ae624db35 --- src/armnn/test/CreateWorkload.hpp | 26 ++++++++++++++++++-------- 1 file changed, 18 insertions(+), 8 deletions(-) (limited to 'src/armnn') diff --git a/src/armnn/test/CreateWorkload.hpp b/src/armnn/test/CreateWorkload.hpp index db1773a0ce..01c5e9f689 100644 --- a/src/armnn/test/CreateWorkload.hpp +++ b/src/armnn/test/CreateWorkload.hpp @@ -133,8 +133,20 @@ std::unique_ptr CreateArithmeticWorkloadTest(armnn::IWorkloadFacto template std::unique_ptr CreateBatchNormalizationWorkloadTest( - armnn::IWorkloadFactory& factory, armnn::Graph& graph, DataLayout dataLayout = DataLayout::NCHW) + armnn::IWorkloadFactory& factory, armnn::Graph& graph, DataLayoutIndexed dataLayout = DataLayout::NCHW) { + + TensorShape tensorShape; + switch (dataLayout.GetDataLayout()) + { + case DataLayout::NHWC: + tensorShape = { 2, 4, 4, 3 }; + break; + case DataLayout::NCHW: + default: + tensorShape = { 2, 3, 4, 4 }; + } + // Creates the layer we're testing. BatchNormalizationDescriptor layerDesc; layerDesc.m_Eps = 0.05f; @@ -156,25 +168,23 @@ std::unique_ptr CreateBatchNormalizationWorkl Layer* const input = graph.AddLayer(0, "input"); Layer* const output = graph.AddLayer(0, "output"); - TensorShape inputShape = (dataLayout == DataLayout::NCHW) ? TensorShape{ 2, 3, 1, 1 } : TensorShape{ 2, 1, 1, 3 }; - TensorShape outputShape = (dataLayout == DataLayout::NCHW) ? TensorShape{ 2, 3, 1, 1 } : TensorShape{ 2, 1, 1, 3 }; - - // Connects up. - Connect(input, layer, TensorInfo(inputShape, DataType)); - Connect(layer, output, TensorInfo(outputShape, DataType)); + //Connects up. + armnn::TensorInfo tensorInfo(tensorShape, DataType); + Connect(input, layer, tensorInfo); + Connect(layer, output, tensorInfo); CreateTensorHandles(graph, factory); // Makes the workload and checks it. auto workload = MakeAndCheckWorkload(*layer, graph, factory); BatchNormalizationQueueDescriptor queueDescriptor = workload->GetData(); BOOST_TEST(queueDescriptor.m_Parameters.m_Eps == 0.05f); - BOOST_TEST((queueDescriptor.m_Parameters.m_DataLayout == dataLayout)); BOOST_TEST(queueDescriptor.m_Inputs.size() == 1); BOOST_TEST(queueDescriptor.m_Outputs.size() == 1); BOOST_TEST((queueDescriptor.m_Mean->GetTensorInfo() == TensorInfo({3}, DataType))); BOOST_TEST((queueDescriptor.m_Variance->GetTensorInfo() == TensorInfo({3}, DataType))); BOOST_TEST((queueDescriptor.m_Gamma->GetTensorInfo() == TensorInfo({3}, DataType))); BOOST_TEST((queueDescriptor.m_Beta->GetTensorInfo() == TensorInfo({3}, DataType))); + BOOST_TEST((queueDescriptor.m_Parameters.m_DataLayout.GetDataLayout() == dataLayout)); // Returns so we can do extra, backend-specific tests. return workload; -- cgit v1.2.1