Diffstat (limited to 'src/backends/reference')
-rw-r--r--  src/backends/reference/test/RefCreateWorkloadTests.cpp | 50
-rw-r--r--  src/backends/reference/test/RefLayerTests.cpp          |  2
-rw-r--r--  src/backends/reference/workloads/Pooling2d.cpp         | 14
3 files changed, 48 insertions(+), 18 deletions(-)
diff --git a/src/backends/reference/test/RefCreateWorkloadTests.cpp b/src/backends/reference/test/RefCreateWorkloadTests.cpp
index 8bad5497a2..d9322709b2 100644
--- a/src/backends/reference/test/RefCreateWorkloadTests.cpp
+++ b/src/backends/reference/test/RefCreateWorkloadTests.cpp
@@ -308,27 +308,51 @@ BOOST_AUTO_TEST_CASE(CreateRefNormalizationNhwcWorkload)
}
template <typename Pooling2dWorkloadType, armnn::DataType DataType>
-static void RefCreatePooling2dWorkloadTest()
+static void RefCreatePooling2dWorkloadTest(DataLayout dataLayout)
{
Graph graph;
RefWorkloadFactory factory;
- auto workload = CreatePooling2dWorkloadTest<Pooling2dWorkloadType, DataType>(factory, graph);
+ auto workload = CreatePooling2dWorkloadTest<Pooling2dWorkloadType, DataType>(factory, graph, dataLayout);
+
+ TensorShape inputShape;
+ TensorShape outputShape;
+
+ switch (dataLayout)
+ {
+ case DataLayout::NHWC:
+ inputShape = { 3, 5, 5, 2 };
+ outputShape = { 3, 2, 4, 2 };
+ break;
+ case DataLayout::NCHW:
+ default:
+ inputShape = { 3, 2, 5, 5 };
+ outputShape = { 3, 2, 2, 4 };
+ }
// Checks that outputs and inputs are as we expect them (see definition of CreatePooling2dWorkloadTest).
- CheckInputOutput(
- std::move(workload),
- TensorInfo({3, 2, 5, 5}, DataType),
- TensorInfo({3, 2, 2, 4}, DataType));
+ CheckInputOutput(std::move(workload),
+ TensorInfo(inputShape, DataType),
+ TensorInfo(outputShape, DataType));
}
BOOST_AUTO_TEST_CASE(CreatePooling2dFloat32Workload)
{
- RefCreatePooling2dWorkloadTest<RefPooling2dFloat32Workload, armnn::DataType::Float32>();
+ RefCreatePooling2dWorkloadTest<RefPooling2dFloat32Workload, armnn::DataType::Float32>(DataLayout::NCHW);
+}
+
+BOOST_AUTO_TEST_CASE(CreatePooling2dFloat32NhwcWorkload)
+{
+ RefCreatePooling2dWorkloadTest<RefPooling2dFloat32Workload, armnn::DataType::Float32>(DataLayout::NHWC);
}
BOOST_AUTO_TEST_CASE(CreatePooling2dUint8Workload)
{
- RefCreatePooling2dWorkloadTest<RefPooling2dUint8Workload, armnn::DataType::QuantisedAsymm8>();
+ RefCreatePooling2dWorkloadTest<RefPooling2dUint8Workload, armnn::DataType::QuantisedAsymm8>(DataLayout::NCHW);
+}
+
+BOOST_AUTO_TEST_CASE(CreatePooling2dUint8NhwcWorkload)
+{
+ RefCreatePooling2dWorkloadTest<RefPooling2dUint8Workload, armnn::DataType::QuantisedAsymm8>(DataLayout::NHWC);
}
template <typename SoftmaxWorkloadType, armnn::DataType DataType>
@@ -496,16 +520,16 @@ static void RefCreateResizeBilinearTest(DataLayout dataLayout)
inputShape = { 2, 4, 4, 3 };
outputShape = { 2, 2, 2, 3 };
break;
- default: // NCHW
+ case DataLayout::NCHW:
+ default:
inputShape = { 2, 3, 4, 4 };
outputShape = { 2, 3, 2, 2 };
}
// Checks that outputs and inputs are as we expect them (see definition of CreateResizeBilinearWorkloadTest).
- CheckInputOutput(
- std::move(workload),
- TensorInfo(inputShape, DataType),
- TensorInfo(outputShape, DataType));
+ CheckInputOutput(std::move(workload),
+ TensorInfo(inputShape, DataType),
+ TensorInfo(outputShape, DataType));
}
BOOST_AUTO_TEST_CASE(CreateResizeBilinearFloat32)
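
Note: the two expected output shapes in the updated pooling test above describe the same logical tensor (N=3, C=2, H=2, W=4), just ordered per layout. A minimal illustrative sketch of that correspondence, assuming armnn::TensorShape; the PermuteNCHWToNHWC helper below is hypothetical and not part of the patch:

    #include <armnn/Tensor.hpp>

    // Illustrative only: maps an NCHW shape to its NHWC equivalent by
    // moving the channel dimension to the end.
    armnn::TensorShape PermuteNCHWToNHWC(const armnn::TensorShape& nchw)
    {
        // N stays at index 0; H and W move ahead of C.
        return armnn::TensorShape({ nchw[0], nchw[2], nchw[3], nchw[1] });
    }

    // PermuteNCHWToNHWC({ 3, 2, 2, 4 }) yields { 3, 2, 4, 2 }, matching the
    // NHWC case in RefCreatePooling2dWorkloadTest above.
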
diff --git a/src/backends/reference/test/RefLayerTests.cpp b/src/backends/reference/test/RefLayerTests.cpp
index 9f044cdbaf..30f5b10b0e 100644
--- a/src/backends/reference/test/RefLayerTests.cpp
+++ b/src/backends/reference/test/RefLayerTests.cpp
@@ -80,7 +80,9 @@ ARMNN_AUTO_TEST_CASE(IgnorePaddingL2Pooling2dSize3, IgnorePaddingL2Pooling2dSize
ARMNN_AUTO_TEST_CASE(IgnorePaddingL2Pooling2dSize3Uint8, IgnorePaddingL2Pooling2dSize3Uint8Test)
ARMNN_AUTO_TEST_CASE(SimpleAveragePooling2d, SimpleAveragePooling2dTest)
+ARMNN_AUTO_TEST_CASE(SimpleAveragePooling2dNhwc, SimpleAveragePooling2dNhwcTest)
ARMNN_AUTO_TEST_CASE(SimpleAveragePooling2dUint8, SimpleAveragePooling2dUint8Test)
+ARMNN_AUTO_TEST_CASE(SimpleAveragePooling2dUint8Nhwc, SimpleAveragePooling2dUint8NhwcTest)
ARMNN_AUTO_TEST_CASE(IgnorePaddingAveragePooling2dSize3x2Stride2x2,
IgnorePaddingAveragePooling2dSize3x2Stride2x2Test, false)
ARMNN_AUTO_TEST_CASE(IgnorePaddingAveragePooling2dSize3x2Stride2x2NoPadding,
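
Note: the two new ARMNN_AUTO_TEST_CASE entries register NHWC variants of the simple average-pooling layer tests. As a rough illustration of what such a test exercises (a sketch only, not the actual test code; the name and signature here are made up), average pooling over an NHWC buffer reduces each window to its mean per batch and channel:

    #include <cstddef>
    #include <vector>

    // Mean of one pooling window of an NHWC tensor, for batch n and channel c.
    float AverageWindowNhwc(const std::vector<float>& in,
                            std::size_t height, std::size_t width, std::size_t channels,
                            std::size_t n, std::size_t yStart, std::size_t xStart,
                            std::size_t poolH, std::size_t poolW, std::size_t c)
    {
        float sum = 0.0f;
        for (std::size_t y = yStart; y < yStart + poolH; ++y)
        {
            for (std::size_t x = xStart; x < xStart + poolW; ++x)
            {
                // NHWC flat index: ((n * H + y) * W + x) * C + c
                sum += in[((n * height + y) * width + x) * channels + c];
            }
        }
        return sum / static_cast<float>(poolH * poolW);
    }
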
diff --git a/src/backends/reference/workloads/Pooling2d.cpp b/src/backends/reference/workloads/Pooling2d.cpp
index 5812a290e7..9890920113 100644
--- a/src/backends/reference/workloads/Pooling2d.cpp
+++ b/src/backends/reference/workloads/Pooling2d.cpp
@@ -143,12 +143,16 @@ void Pooling2d(const float* in,
const TensorInfo& outputInfo,
const Pooling2dDescriptor& params)
{
+ const unsigned int channelsIndex = params.m_DataLayout.GetChannelsIndex();
+ const unsigned int heightIndex = params.m_DataLayout.GetHeightIndex();
+ const unsigned int widthIndex = params.m_DataLayout.GetWidthIndex();
+
const int batchSize = boost::numeric_cast<int>(outputInfo.GetShape()[0]);
- const int channels = boost::numeric_cast<int>(outputInfo.GetShape()[1]);
- const int heightOutput = boost::numeric_cast<int>(outputInfo.GetShape()[2]);
- const int widthOutput = boost::numeric_cast<int>(outputInfo.GetShape()[3]);
- const int heightInput = boost::numeric_cast<int>(inputInfo.GetShape()[2]);
- const int widthInput = boost::numeric_cast<int>(inputInfo.GetShape()[3]);
+ const int channels = boost::numeric_cast<int>(outputInfo.GetShape()[channelsIndex]);
+ const int heightOutput = boost::numeric_cast<int>(outputInfo.GetShape()[heightIndex]);
+ const int widthOutput = boost::numeric_cast<int>(outputInfo.GetShape()[widthIndex]);
+ const int heightInput = boost::numeric_cast<int>(inputInfo.GetShape()[heightIndex]);
+ const int widthInput = boost::numeric_cast<int>(inputInfo.GetShape()[widthIndex]);
const int padLeft = boost::numeric_cast<int>(params.m_PadLeft);
const int padRight = boost::numeric_cast<int>(params.m_PadRight);
const int padTop = boost::numeric_cast<int>(params.m_PadTop);
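
Note: the Pooling2d.cpp change replaces hard-coded NCHW dimension indices with indices derived from params.m_DataLayout. A minimal sketch of the mapping it relies on, assuming the usual NCHW/NHWC conventions; the GetChannelsIndexFor/GetHeightIndexFor/GetWidthIndexFor helpers below are illustrative stand-ins, not armnn API:

    #include <armnn/Types.hpp>

    //                channels  height  width
    //   NCHW  ->        1        2       3
    //   NHWC  ->        3        1       2
    unsigned int GetChannelsIndexFor(armnn::DataLayout layout)
    {
        return layout == armnn::DataLayout::NHWC ? 3u : 1u;
    }

    unsigned int GetHeightIndexFor(armnn::DataLayout layout)
    {
        return layout == armnn::DataLayout::NHWC ? 1u : 2u;
    }

    unsigned int GetWidthIndexFor(armnn::DataLayout layout)
    {
        return layout == armnn::DataLayout::NHWC ? 2u : 3u;
    }

The batch dimension is still read from index 0 in the patch, which holds for both layouts.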