From b63973ee1134336434a490fc9af8bba6cde79820 Mon Sep 17 00:00:00 2001 From: Matteo Martincigh Date: Tue, 16 Oct 2018 16:23:33 +0100 Subject: IVGCVSW-2018 Support NHWC in the current ref implementation * Enabled the now supported ref layer tests * Re-enabled the failing test now that the bug has been fixed in ACL 1903a9976ae24f40cb2203364211ed62fcfbb985 * Added CreateWorkload test for ref L2Normalization NHWC * Refactoring the ref L2Normalization for clarity !armnn:153723 Change-Id: Id0067e49072b3e057ffe3ae3b70d928be6091c0f --- .../RefL2NormalizationFloat32Workload.cpp | 32 +++++++++++++--------- .../RefL2NormalizationFloat32Workload.hpp | 3 +- 2 files changed, 21 insertions(+), 14 deletions(-) (limited to 'src/backends/reference/workloads') diff --git a/src/backends/reference/workloads/RefL2NormalizationFloat32Workload.cpp b/src/backends/reference/workloads/RefL2NormalizationFloat32Workload.cpp index 973c87b009..d21cfa947a 100644 --- a/src/backends/reference/workloads/RefL2NormalizationFloat32Workload.cpp +++ b/src/backends/reference/workloads/RefL2NormalizationFloat32Workload.cpp @@ -22,26 +22,32 @@ void RefL2NormalizationFloat32Workload::Execute() const const TensorInfo& inputInfo = GetTensorInfo(m_Data.m_Inputs[0]); const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]); - TensorBufferArrayView input(inputInfo.GetShape(), GetInputTensorDataFloat(0, m_Data)); - TensorBufferArrayView output(outputInfo.GetShape(), GetOutputTensorDataFloat(0, m_Data)); + TensorBufferArrayView input(inputInfo.GetShape(), + GetInputTensorDataFloat(0, m_Data), + m_Data.m_Parameters.m_DataLayout); + TensorBufferArrayView output(outputInfo.GetShape(), + GetOutputTensorDataFloat(0, m_Data), + m_Data.m_Parameters.m_DataLayout); - const unsigned int batchSize = inputInfo.GetShape()[0]; - const unsigned int depth = inputInfo.GetShape()[1]; - const unsigned int rows = inputInfo.GetShape()[2]; - const unsigned int cols = inputInfo.GetShape()[3]; + DataLayoutIndexed dataLayout(m_Data.m_Parameters.m_DataLayout); - for (unsigned int n = 0; n < batchSize; ++n) + const unsigned int batches = inputInfo.GetShape()[0]; + const unsigned int channels = inputInfo.GetShape()[dataLayout.GetChannelsIndex()]; + const unsigned int height = inputInfo.GetShape()[dataLayout.GetHeightIndex()]; + const unsigned int width = inputInfo.GetShape()[dataLayout.GetWidthIndex()]; + + for (unsigned int n = 0; n < batches; ++n) { - for (unsigned int d = 0; d < depth; ++d) + for (unsigned int c = 0; c < channels; ++c) { - for (unsigned int h = 0; h < rows; ++h) + for (unsigned int h = 0; h < height; ++h) { - for (unsigned int w = 0; w < cols; ++w) + for (unsigned int w = 0; w < width; ++w) { float reduction = 0.0; - for (unsigned int c = 0; c < depth; ++c) + for (unsigned int d = 0; d < channels; ++d) { - const float value = input.Get(n, c, h, w); + const float value = input.Get(n, d, h, w); reduction += value * value; } @@ -51,7 +57,7 @@ void RefL2NormalizationFloat32Workload::Execute() const // backend. // - The reference semantics for this operator do not include this parameter. const float scale = 1.0f / sqrtf(reduction); - output.Get(n, d, h, w) = input.Get(n, d, h, w) * scale; + output.Get(n, c, h, w) = input.Get(n, c, h, w) * scale; } } } diff --git a/src/backends/reference/workloads/RefL2NormalizationFloat32Workload.hpp b/src/backends/reference/workloads/RefL2NormalizationFloat32Workload.hpp index 67055a9c37..b2e37954f5 100644 --- a/src/backends/reference/workloads/RefL2NormalizationFloat32Workload.hpp +++ b/src/backends/reference/workloads/RefL2NormalizationFloat32Workload.hpp @@ -15,7 +15,8 @@ class RefL2NormalizationFloat32Workload : public Float32Workload::Float32Workload; - virtual void Execute() const override; + + void Execute() const override; }; } //namespace armnn -- cgit v1.2.1