From 69482271d3e02af950d2d0f1947ae6c3eeed537b Mon Sep 17 00:00:00 2001 From: James Conroy Date: Fri, 19 Oct 2018 10:41:35 +0100 Subject: IVGCVSW-2024: Support NHWC for Pooling2D CpuRef * Adds implementation to plumb DataLayout parameter for Pooling2D on CpuRef. * Adds unit tests to execute Pooling2D on CpuRef using NHWC data layout. * Refactors original tests to use DataLayoutIndexed and removes duplicate code. Change-Id: Ife7e0861a886cf58a2042e5be20e5b27af4528c9 --- .../reference/test/RefCreateWorkloadTests.cpp | 50 ++++++++++++++++------ 1 file changed, 37 insertions(+), 13 deletions(-) (limited to 'src/backends/reference/test/RefCreateWorkloadTests.cpp') diff --git a/src/backends/reference/test/RefCreateWorkloadTests.cpp b/src/backends/reference/test/RefCreateWorkloadTests.cpp index 8bad5497a2..d9322709b2 100644 --- a/src/backends/reference/test/RefCreateWorkloadTests.cpp +++ b/src/backends/reference/test/RefCreateWorkloadTests.cpp @@ -308,27 +308,51 @@ BOOST_AUTO_TEST_CASE(CreateRefNormalizationNhwcWorkload) } template -static void RefCreatePooling2dWorkloadTest() +static void RefCreatePooling2dWorkloadTest(DataLayout dataLayout) { Graph graph; RefWorkloadFactory factory; - auto workload = CreatePooling2dWorkloadTest(factory, graph); + auto workload = CreatePooling2dWorkloadTest(factory, graph, dataLayout); + + TensorShape inputShape; + TensorShape outputShape; + + switch (dataLayout) + { + case DataLayout::NHWC: + inputShape = { 3, 5, 5, 2 }; + outputShape = { 3, 2, 4, 2 }; + break; + case DataLayout::NCHW: + default: + inputShape = { 3, 2, 5, 5 }; + outputShape = { 3, 2, 2, 4 }; + } // Checks that outputs and inputs are as we expect them (see definition of CreatePooling2dWorkloadTest). - CheckInputOutput( - std::move(workload), - TensorInfo({3, 2, 5, 5}, DataType), - TensorInfo({3, 2, 2, 4}, DataType)); + CheckInputOutput(std::move(workload), + TensorInfo(inputShape, DataType), + TensorInfo(outputShape, DataType)); } BOOST_AUTO_TEST_CASE(CreatePooling2dFloat32Workload) { - RefCreatePooling2dWorkloadTest(); + RefCreatePooling2dWorkloadTest(DataLayout::NCHW); +} + +BOOST_AUTO_TEST_CASE(CreatePooling2dFloat32NhwcWorkload) +{ + RefCreatePooling2dWorkloadTest(DataLayout::NHWC); } BOOST_AUTO_TEST_CASE(CreatePooling2dUint8Workload) { - RefCreatePooling2dWorkloadTest(); + RefCreatePooling2dWorkloadTest(DataLayout::NCHW); +} + +BOOST_AUTO_TEST_CASE(CreatePooling2dUint8NhwcWorkload) +{ + RefCreatePooling2dWorkloadTest(DataLayout::NHWC); } template @@ -496,16 +520,16 @@ static void RefCreateResizeBilinearTest(DataLayout dataLayout) inputShape = { 2, 4, 4, 3 }; outputShape = { 2, 2, 2, 3 }; break; - default: // NCHW + case DataLayout::NCHW: + default: inputShape = { 2, 3, 4, 4 }; outputShape = { 2, 3, 2, 2 }; } // Checks that outputs and inputs are as we expect them (see definition of CreateResizeBilinearWorkloadTest). - CheckInputOutput( - std::move(workload), - TensorInfo(inputShape, DataType), - TensorInfo(outputShape, DataType)); + CheckInputOutput(std::move(workload), + TensorInfo(inputShape, DataType), + TensorInfo(outputShape, DataType)); } BOOST_AUTO_TEST_CASE(CreateResizeBilinearFloat32) -- cgit v1.2.1