diff options
author | Matteo Martincigh <matteo.martincigh@arm.com> | 2018-10-16 16:23:33 +0100 |
---|---|---|
committer | Matthew Bentham <matthew.bentham@arm.com> | 2018-10-22 16:57:54 +0100 |
commit | b63973ee1134336434a490fc9af8bba6cde79820 (patch) | |
tree | 1304b693044697454bc10cd52b7a4746444b5feb /src/backends/reference/workloads | |
parent | 177d8d26925a58a579943e010d28d1ceaa033d64 (diff) | |
download | armnn-b63973ee1134336434a490fc9af8bba6cde79820.tar.gz |
IVGCVSW-2018 Support NHWC in the current ref implementation
* Enabled the now supported ref layer tests
* Re-enabled the failing test now that the bug has been fixed in
ACL 1903a9976ae24f40cb2203364211ed62fcfbb985
* Added CreateWorkload test for ref L2Normalization NHWC
* Refactoring the ref L2Normalization for clarity
!armnn:153723
Change-Id: Id0067e49072b3e057ffe3ae3b70d928be6091c0f
Diffstat (limited to 'src/backends/reference/workloads')
-rw-r--r-- | src/backends/reference/workloads/RefL2NormalizationFloat32Workload.cpp | 32 | ||||
-rw-r--r-- | src/backends/reference/workloads/RefL2NormalizationFloat32Workload.hpp | 3 |
2 files changed, 21 insertions, 14 deletions
diff --git a/src/backends/reference/workloads/RefL2NormalizationFloat32Workload.cpp b/src/backends/reference/workloads/RefL2NormalizationFloat32Workload.cpp index 973c87b009..d21cfa947a 100644 --- a/src/backends/reference/workloads/RefL2NormalizationFloat32Workload.cpp +++ b/src/backends/reference/workloads/RefL2NormalizationFloat32Workload.cpp @@ -22,26 +22,32 @@ void RefL2NormalizationFloat32Workload::Execute() const const TensorInfo& inputInfo = GetTensorInfo(m_Data.m_Inputs[0]); const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]); - TensorBufferArrayView<const float> input(inputInfo.GetShape(), GetInputTensorDataFloat(0, m_Data)); - TensorBufferArrayView<float> output(outputInfo.GetShape(), GetOutputTensorDataFloat(0, m_Data)); + TensorBufferArrayView<const float> input(inputInfo.GetShape(), + GetInputTensorDataFloat(0, m_Data), + m_Data.m_Parameters.m_DataLayout); + TensorBufferArrayView<float> output(outputInfo.GetShape(), + GetOutputTensorDataFloat(0, m_Data), + m_Data.m_Parameters.m_DataLayout); - const unsigned int batchSize = inputInfo.GetShape()[0]; - const unsigned int depth = inputInfo.GetShape()[1]; - const unsigned int rows = inputInfo.GetShape()[2]; - const unsigned int cols = inputInfo.GetShape()[3]; + DataLayoutIndexed dataLayout(m_Data.m_Parameters.m_DataLayout); - for (unsigned int n = 0; n < batchSize; ++n) + const unsigned int batches = inputInfo.GetShape()[0]; + const unsigned int channels = inputInfo.GetShape()[dataLayout.GetChannelsIndex()]; + const unsigned int height = inputInfo.GetShape()[dataLayout.GetHeightIndex()]; + const unsigned int width = inputInfo.GetShape()[dataLayout.GetWidthIndex()]; + + for (unsigned int n = 0; n < batches; ++n) { - for (unsigned int d = 0; d < depth; ++d) + for (unsigned int c = 0; c < channels; ++c) { - for (unsigned int h = 0; h < rows; ++h) + for (unsigned int h = 0; h < height; ++h) { - for (unsigned int w = 0; w < cols; ++w) + for (unsigned int w = 0; w < width; ++w) { float reduction = 0.0; - for (unsigned int c = 0; c < depth; ++c) + for (unsigned int d = 0; d < channels; ++d) { - const float value = input.Get(n, c, h, w); + const float value = input.Get(n, d, h, w); reduction += value * value; } @@ -51,7 +57,7 @@ void RefL2NormalizationFloat32Workload::Execute() const // backend. // - The reference semantics for this operator do not include this parameter. const float scale = 1.0f / sqrtf(reduction); - output.Get(n, d, h, w) = input.Get(n, d, h, w) * scale; + output.Get(n, c, h, w) = input.Get(n, c, h, w) * scale; } } } diff --git a/src/backends/reference/workloads/RefL2NormalizationFloat32Workload.hpp b/src/backends/reference/workloads/RefL2NormalizationFloat32Workload.hpp index 67055a9c37..b2e37954f5 100644 --- a/src/backends/reference/workloads/RefL2NormalizationFloat32Workload.hpp +++ b/src/backends/reference/workloads/RefL2NormalizationFloat32Workload.hpp @@ -15,7 +15,8 @@ class RefL2NormalizationFloat32Workload : public Float32Workload<L2Normalization { public: using Float32Workload<L2NormalizationQueueDescriptor>::Float32Workload; - virtual void Execute() const override; + + void Execute() const override; }; } //namespace armnn |