From e4dfd6ead59e17828f8814f0ecc5fa67f0c72868 Mon Sep 17 00:00:00 2001 From: Nikhil Raj Date: Thu, 18 Oct 2018 10:11:04 +0100 Subject: IVGCVSW-1865 - Support NHWC for Convolution2D (CpuRef) * Updated the ConvImpl.hpp to use DataLayoutIndex * Enabled unit test for CpuRef * Update CreateWorkload Tests for ref with NHWC Change-Id: Id309b7ef677489d63dcb5e09bd48ab9624b5ebfb --- .../reference/test/RefCreateWorkloadTests.cpp | 27 +++++++++++++++++----- src/backends/reference/test/RefLayerTests.cpp | 2 ++ src/backends/reference/workloads/ConvImpl.hpp | 21 ++++++++++------- 3 files changed, 36 insertions(+), 14 deletions(-) diff --git a/src/backends/reference/test/RefCreateWorkloadTests.cpp b/src/backends/reference/test/RefCreateWorkloadTests.cpp index dc0348dc10..236267c177 100644 --- a/src/backends/reference/test/RefCreateWorkloadTests.cpp +++ b/src/backends/reference/test/RefCreateWorkloadTests.cpp @@ -177,17 +177,32 @@ BOOST_AUTO_TEST_CASE(CreateConvertFp32ToFp16Float16Workload) std::move(workload), TensorInfo({1, 3, 2, 3}, DataType::Float32), TensorInfo({1, 3, 2, 3}, DataType::Float16)); } -BOOST_AUTO_TEST_CASE(CreateConvolution2dWorkload) +static void RefCreateConvolution2dWorkloadTest(DataLayout dataLayout = DataLayout::NCHW) { - Graph graph; + Graph graph; RefWorkloadFactory factory; - auto workload = CreateConvolution2dWorkloadTest(factory, graph); + auto workload = CreateConvolution2dWorkloadTest + (factory, graph, dataLayout); + + std::initializer_list inputShape = (dataLayout == DataLayout::NCHW) ? + std::initializer_list({2, 3, 8, 16}) : std::initializer_list({2, 8, 16, 3}); + std::initializer_list outputShape = (dataLayout == DataLayout::NCHW) ? + std::initializer_list({2, 2, 2, 10}) : std::initializer_list({2, 2, 10, 2}); // Checks that outputs and inputs are as we expect them (see definition of CreateConvolution2dWorkloadTest). CheckInputOutput(std::move(workload), - TensorInfo({2, 3, 8, 16}, DataType::Float32), - TensorInfo({2, 2, 2, 10}, DataType::Float32)); + TensorInfo(inputShape, DataType::Float32), + TensorInfo(outputShape, DataType::Float32)); +} + +BOOST_AUTO_TEST_CASE(CreateConvolution2dFloatNchwWorkload) +{ + RefCreateConvolution2dWorkloadTest(DataLayout::NCHW); +} + +BOOST_AUTO_TEST_CASE(CreateConvolution2dFloatNhwcWorkload) +{ + RefCreateConvolution2dWorkloadTest(DataLayout::NHWC); } template diff --git a/src/backends/reference/test/RefLayerTests.cpp b/src/backends/reference/test/RefLayerTests.cpp index 21371611bb..259739ba55 100644 --- a/src/backends/reference/test/RefLayerTests.cpp +++ b/src/backends/reference/test/RefLayerTests.cpp @@ -36,6 +36,8 @@ ARMNN_AUTO_TEST_CASE(SimpleConvolution2dAsymmetricPaddingLargerThanHalfKernelSiz Convolution2dAsymmetricPaddingLargerThanHalfKernelSizeTest) ARMNN_AUTO_TEST_CASE(SimpleConvolution2dAsymmetricPadding, Convolution2dAsymmetricPaddingTest) +ARMNN_AUTO_TEST_CASE(SimpleConvolution2dSquareNhwc, SimpleConvolution2d3x3NhwcTest, false) + // Depthwise Convolution ARMNN_AUTO_TEST_CASE(DepthwiseConvolution2d, DepthwiseConvolution2dTest, true) ARMNN_AUTO_TEST_CASE(DepthwiseConvolution2dUint8, DepthwiseConvolution2dUint8Test, true) diff --git a/src/backends/reference/workloads/ConvImpl.hpp b/src/backends/reference/workloads/ConvImpl.hpp index 4c9ab2a644..60a3622c55 100644 --- a/src/backends/reference/workloads/ConvImpl.hpp +++ b/src/backends/reference/workloads/ConvImpl.hpp @@ -63,21 +63,26 @@ static void ConvImpl(ConvData data, throw InvalidArgumentException("Bias is enabled but the bias data is invalid"); } - const TensorInfo& inputInfo0 = GetTensorInfo(data.m_Inputs[0]); + const TensorInfo& inputInfo0 = GetTensorInfo(data.m_Inputs[0]); const TensorInfo& outputInfo0 = GetTensorInfo(data.m_Outputs[0]); + const DataLayoutIndexed dataLayoutIndexed(data.m_Parameters.m_DataLayout); + const unsigned int channelsIndex = dataLayoutIndexed.GetChannelsIndex(); + const unsigned int heightIndex = dataLayoutIndexed.GetHeightIndex(); + const unsigned int widthIndex = dataLayoutIndexed.GetWidthIndex(); + unsigned int depthMult = depthwise ? filterInfo.GetShape()[0] : 1; - unsigned int channelsInput = filterInfo.GetShape()[1]; + unsigned int channelsInput = filterInfo.GetShape()[channelsIndex]; unsigned int channelsOutput = depthwise ? channelsInput * depthMult : filterInfo.GetShape()[0]; unsigned int batchSize = outputInfo0.GetShape()[0]; - unsigned int heightOutput = outputInfo0.GetShape()[2]; - unsigned int widthOutput = outputInfo0.GetShape()[3]; - unsigned int heightInput = inputInfo0.GetShape()[2]; - unsigned int widthInput = inputInfo0.GetShape()[3]; + unsigned int heightOutput = outputInfo0.GetShape()[heightIndex]; + unsigned int widthOutput = outputInfo0.GetShape()[widthIndex]; + unsigned int heightInput = inputInfo0.GetShape()[heightIndex]; + unsigned int widthInput = inputInfo0.GetShape()[widthIndex]; - unsigned int heightFilter = filterInfo.GetShape()[2]; - unsigned int widthFilter = filterInfo.GetShape()[3]; + unsigned int heightFilter = filterInfo.GetShape()[heightIndex]; + unsigned int widthFilter = filterInfo.GetShape()[widthIndex]; unsigned int paddingTop = data.m_Parameters.m_PadTop; unsigned int paddingLeft = data.m_Parameters.m_PadLeft; -- cgit v1.2.1