aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMatteo Martincigh <matteo.martincigh@arm.com>2018-10-16 16:23:33 +0100
committerMatthew Bentham <matthew.bentham@arm.com>2018-10-22 16:57:54 +0100
commitb63973ee1134336434a490fc9af8bba6cde79820 (patch)
tree1304b693044697454bc10cd52b7a4746444b5feb
parent177d8d26925a58a579943e010d28d1ceaa033d64 (diff)
downloadarmnn-b63973ee1134336434a490fc9af8bba6cde79820.tar.gz
IVGCVSW-2018 Support NHWC in the current ref implementation
* Enabled the now supported ref layer tests * Re-enabled the failing test now that the bug has been fixed in ACL 1903a9976ae24f40cb2203364211ed62fcfbb985 * Added CreateWorkload test for ref L2Normalization NHWC * Refactoring the ref L2Normalization for clarity !armnn:153723 Change-Id: Id0067e49072b3e057ffe3ae3b70d928be6091c0f
-rwxr-xr-xsrc/backends/cl/test/ClLayerTests.cpp3
-rw-r--r--src/backends/reference/test/RefCreateWorkloadTests.cpp38
-rw-r--r--src/backends/reference/test/RefLayerTests.cpp9
-rw-r--r--src/backends/reference/workloads/RefL2NormalizationFloat32Workload.cpp32
-rw-r--r--src/backends/reference/workloads/RefL2NormalizationFloat32Workload.hpp3
5 files changed, 57 insertions, 28 deletions
diff --git a/src/backends/cl/test/ClLayerTests.cpp b/src/backends/cl/test/ClLayerTests.cpp
index 62ce2cb18f..3b1603c13c 100755
--- a/src/backends/cl/test/ClLayerTests.cpp
+++ b/src/backends/cl/test/ClLayerTests.cpp
@@ -188,8 +188,7 @@ ARMNN_AUTO_TEST_CASE(L2Normalization2d, L2Normalization2dTest)
ARMNN_AUTO_TEST_CASE(L2Normalization3d, L2Normalization3dTest)
ARMNN_AUTO_TEST_CASE(L2Normalization4d, L2Normalization4dTest)
-// NOTE: The following test hits a bug in ACL that makes it fail, keep it disabled until a patch is available in ACL
-//ARMNN_AUTO_TEST_CASE(L2Normalization1dNhwc, L2Normalization1dNhwcTest)
+ARMNN_AUTO_TEST_CASE(L2Normalization1dNhwc, L2Normalization1dNhwcTest)
ARMNN_AUTO_TEST_CASE(L2Normalization2dNhwc, L2Normalization2dNhwcTest)
ARMNN_AUTO_TEST_CASE(L2Normalization3dNhwc, L2Normalization3dNhwcTest)
ARMNN_AUTO_TEST_CASE(L2Normalization4dNhwc, L2Normalization4dNhwcTest)
diff --git a/src/backends/reference/test/RefCreateWorkloadTests.cpp b/src/backends/reference/test/RefCreateWorkloadTests.cpp
index a8901d2cc5..dc0348dc10 100644
--- a/src/backends/reference/test/RefCreateWorkloadTests.cpp
+++ b/src/backends/reference/test/RefCreateWorkloadTests.cpp
@@ -449,18 +449,42 @@ BOOST_AUTO_TEST_CASE(CreateResizeBilinearFloat32Nhwc)
RefCreateResizeBilinearTest<RefResizeBilinearFloat32Workload, armnn::DataType::Float32>(DataLayout::NHWC);
}
-BOOST_AUTO_TEST_CASE(CreateL2NormalizationFloat32)
+template <typename L2NormalizationWorkloadType, armnn::DataType DataType>
+static void RefCreateL2NormalizationTest(DataLayout dataLayout)
{
Graph graph;
RefWorkloadFactory factory;
- auto workload = CreateL2NormalizationWorkloadTest<RefL2NormalizationFloat32Workload, armnn::DataType::Float32>
- (factory, graph);
+ auto workload =
+ CreateL2NormalizationWorkloadTest<L2NormalizationWorkloadType, DataType>(factory, graph, dataLayout);
+
+ TensorShape inputShape;
+ TensorShape outputShape;
+
+ switch (dataLayout)
+ {
+ case DataLayout::NHWC:
+ inputShape = { 5, 50, 67, 20 };
+ outputShape = { 5, 50, 67, 20 };
+ break;
+ case DataLayout::NCHW:
+ default:
+ inputShape = { 5, 20, 50, 67 };
+ outputShape = { 5, 20, 50, 67 };
+ break;
+ }
// Checks that outputs and inputs are as we expect them (see definition of CreateL2NormalizationWorkloadTest).
- CheckInputOutput(
- std::move(workload),
- TensorInfo({ 5, 20, 50, 67 }, armnn::DataType::Float32),
- TensorInfo({ 5, 20, 50, 67 }, armnn::DataType::Float32));
+ CheckInputOutput(std::move(workload), TensorInfo(inputShape, DataType), TensorInfo(outputShape, DataType));
+}
+
+BOOST_AUTO_TEST_CASE(CreateL2NormalizationFloat32)
+{
+ RefCreateL2NormalizationTest<RefL2NormalizationFloat32Workload, armnn::DataType::Float32>(DataLayout::NCHW);
+}
+
+BOOST_AUTO_TEST_CASE(CreateL2NormalizationFloat32Nhwc)
+{
+ RefCreateL2NormalizationTest<RefL2NormalizationFloat32Workload, armnn::DataType::Float32>(DataLayout::NHWC);
}
template <typename ReshapeWorkloadType, armnn::DataType DataType>
diff --git a/src/backends/reference/test/RefLayerTests.cpp b/src/backends/reference/test/RefLayerTests.cpp
index 797051ee18..2815e342c0 100644
--- a/src/backends/reference/test/RefLayerTests.cpp
+++ b/src/backends/reference/test/RefLayerTests.cpp
@@ -211,11 +211,10 @@ ARMNN_AUTO_TEST_CASE(Pad2d, Pad2dTest)
ARMNN_AUTO_TEST_CASE(Pad3d, Pad3dTest)
ARMNN_AUTO_TEST_CASE(Pad4d, Pad4dTest)
-// NOTE: These tests are disabled until NHWC is supported by the reference L2Normalization implementation.
-//ARMNN_AUTO_TEST_CASE(L2Normalization1dNhwc, L2Normalization1dNhwcTest);
-//ARMNN_AUTO_TEST_CASE(L2Normalization2dNhwc, L2Normalization2dNhwcTest);
-//ARMNN_AUTO_TEST_CASE(L2Normalization3dNhwc, L2Normalization3dNhwcTest);
-//ARMNN_AUTO_TEST_CASE(L2Normalization4dNhwc, L2Normalization4dNhwcTest);
+ARMNN_AUTO_TEST_CASE(L2Normalization1dNhwc, L2Normalization1dNhwcTest)
+ARMNN_AUTO_TEST_CASE(L2Normalization2dNhwc, L2Normalization2dNhwcTest)
+ARMNN_AUTO_TEST_CASE(L2Normalization3dNhwc, L2Normalization3dNhwcTest)
+ARMNN_AUTO_TEST_CASE(L2Normalization4dNhwc, L2Normalization4dNhwcTest)
// Constant
ARMNN_AUTO_TEST_CASE(Constant, ConstantTest)
diff --git a/src/backends/reference/workloads/RefL2NormalizationFloat32Workload.cpp b/src/backends/reference/workloads/RefL2NormalizationFloat32Workload.cpp
index 973c87b009..d21cfa947a 100644
--- a/src/backends/reference/workloads/RefL2NormalizationFloat32Workload.cpp
+++ b/src/backends/reference/workloads/RefL2NormalizationFloat32Workload.cpp
@@ -22,26 +22,32 @@ void RefL2NormalizationFloat32Workload::Execute() const
const TensorInfo& inputInfo = GetTensorInfo(m_Data.m_Inputs[0]);
const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]);
- TensorBufferArrayView<const float> input(inputInfo.GetShape(), GetInputTensorDataFloat(0, m_Data));
- TensorBufferArrayView<float> output(outputInfo.GetShape(), GetOutputTensorDataFloat(0, m_Data));
+ TensorBufferArrayView<const float> input(inputInfo.GetShape(),
+ GetInputTensorDataFloat(0, m_Data),
+ m_Data.m_Parameters.m_DataLayout);
+ TensorBufferArrayView<float> output(outputInfo.GetShape(),
+ GetOutputTensorDataFloat(0, m_Data),
+ m_Data.m_Parameters.m_DataLayout);
- const unsigned int batchSize = inputInfo.GetShape()[0];
- const unsigned int depth = inputInfo.GetShape()[1];
- const unsigned int rows = inputInfo.GetShape()[2];
- const unsigned int cols = inputInfo.GetShape()[3];
+ DataLayoutIndexed dataLayout(m_Data.m_Parameters.m_DataLayout);
- for (unsigned int n = 0; n < batchSize; ++n)
+ const unsigned int batches = inputInfo.GetShape()[0];
+ const unsigned int channels = inputInfo.GetShape()[dataLayout.GetChannelsIndex()];
+ const unsigned int height = inputInfo.GetShape()[dataLayout.GetHeightIndex()];
+ const unsigned int width = inputInfo.GetShape()[dataLayout.GetWidthIndex()];
+
+ for (unsigned int n = 0; n < batches; ++n)
{
- for (unsigned int d = 0; d < depth; ++d)
+ for (unsigned int c = 0; c < channels; ++c)
{
- for (unsigned int h = 0; h < rows; ++h)
+ for (unsigned int h = 0; h < height; ++h)
{
- for (unsigned int w = 0; w < cols; ++w)
+ for (unsigned int w = 0; w < width; ++w)
{
float reduction = 0.0;
- for (unsigned int c = 0; c < depth; ++c)
+ for (unsigned int d = 0; d < channels; ++d)
{
- const float value = input.Get(n, c, h, w);
+ const float value = input.Get(n, d, h, w);
reduction += value * value;
}
@@ -51,7 +57,7 @@ void RefL2NormalizationFloat32Workload::Execute() const
// backend.
// - The reference semantics for this operator do not include this parameter.
const float scale = 1.0f / sqrtf(reduction);
- output.Get(n, d, h, w) = input.Get(n, d, h, w) * scale;
+ output.Get(n, c, h, w) = input.Get(n, c, h, w) * scale;
}
}
}
diff --git a/src/backends/reference/workloads/RefL2NormalizationFloat32Workload.hpp b/src/backends/reference/workloads/RefL2NormalizationFloat32Workload.hpp
index 67055a9c37..b2e37954f5 100644
--- a/src/backends/reference/workloads/RefL2NormalizationFloat32Workload.hpp
+++ b/src/backends/reference/workloads/RefL2NormalizationFloat32Workload.hpp
@@ -15,7 +15,8 @@ class RefL2NormalizationFloat32Workload : public Float32Workload<L2Normalization
{
public:
using Float32Workload<L2NormalizationQueueDescriptor>::Float32Workload;
- virtual void Execute() const override;
+
+ void Execute() const override;
};
} //namespace armnn