From d73d14fd77fe1405a33b3ecf3c56e1ac65647ff7 Mon Sep 17 00:00:00 2001 From: Ferran Balaguer Date: Mon, 10 Jun 2019 10:29:54 +0100 Subject: IVGCVSW-3229 Refactor L2Normalization workload to support multiple data types Signed-off-by: Ferran Balaguer Change-Id: I848056aad4b172d432664633eea000843d85a85d --- src/backends/reference/workloads/CMakeLists.txt | 4 +- .../RefL2NormalizationFloat32Workload.cpp | 69 -------------------- .../RefL2NormalizationFloat32Workload.hpp | 22 ------- .../workloads/RefL2NormalizationWorkload.cpp | 75 ++++++++++++++++++++++ .../workloads/RefL2NormalizationWorkload.hpp | 23 +++++++ src/backends/reference/workloads/RefWorkloads.hpp | 2 +- 6 files changed, 101 insertions(+), 94 deletions(-) delete mode 100644 src/backends/reference/workloads/RefL2NormalizationFloat32Workload.cpp delete mode 100644 src/backends/reference/workloads/RefL2NormalizationFloat32Workload.hpp create mode 100644 src/backends/reference/workloads/RefL2NormalizationWorkload.cpp create mode 100644 src/backends/reference/workloads/RefL2NormalizationWorkload.hpp (limited to 'src/backends/reference/workloads') diff --git a/src/backends/reference/workloads/CMakeLists.txt b/src/backends/reference/workloads/CMakeLists.txt index 4d11447280..41a553482d 100644 --- a/src/backends/reference/workloads/CMakeLists.txt +++ b/src/backends/reference/workloads/CMakeLists.txt @@ -70,8 +70,8 @@ list(APPEND armnnRefBackendWorkloads_sources RefFullyConnectedWorkload.hpp RefGatherWorkload.cpp RefGatherWorkload.hpp - RefL2NormalizationFloat32Workload.cpp - RefL2NormalizationFloat32Workload.hpp + RefL2NormalizationWorkload.cpp + RefL2NormalizationWorkload.hpp RefLstmWorkload.cpp RefLstmWorkload.hpp RefConcatWorkload.cpp diff --git a/src/backends/reference/workloads/RefL2NormalizationFloat32Workload.cpp b/src/backends/reference/workloads/RefL2NormalizationFloat32Workload.cpp deleted file mode 100644 index bc82739f6e..0000000000 --- a/src/backends/reference/workloads/RefL2NormalizationFloat32Workload.cpp +++ /dev/null @@ -1,69 +0,0 @@ -// -// Copyright © 2017 Arm Ltd. All rights reserved. -// SPDX-License-Identifier: MIT -// - -#include "RefL2NormalizationFloat32Workload.hpp" - -#include "RefWorkloadUtils.hpp" -#include "TensorBufferArrayView.hpp" - -#include "Profiling.hpp" - -#include - -using namespace armnnUtils; - -namespace armnn -{ - -void RefL2NormalizationFloat32Workload::Execute() const -{ - ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefL2NormalizationFloat32Workload_Execute"); - - const TensorInfo& inputInfo = GetTensorInfo(m_Data.m_Inputs[0]); - const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]); - - TensorBufferArrayView input(inputInfo.GetShape(), - GetInputTensorDataFloat(0, m_Data), - m_Data.m_Parameters.m_DataLayout); - TensorBufferArrayView output(outputInfo.GetShape(), - GetOutputTensorDataFloat(0, m_Data), - m_Data.m_Parameters.m_DataLayout); - - DataLayoutIndexed dataLayout(m_Data.m_Parameters.m_DataLayout); - - const unsigned int batches = inputInfo.GetShape()[0]; - const unsigned int channels = inputInfo.GetShape()[dataLayout.GetChannelsIndex()]; - const unsigned int height = inputInfo.GetShape()[dataLayout.GetHeightIndex()]; - const unsigned int width = inputInfo.GetShape()[dataLayout.GetWidthIndex()]; - - for (unsigned int n = 0; n < batches; ++n) - { - for (unsigned int c = 0; c < channels; ++c) - { - for (unsigned int h = 0; h < height; ++h) - { - for (unsigned int w = 0; w < width; ++w) - { - float reduction = 0.0; - for (unsigned int d = 0; d < channels; ++d) - { - const float value = input.Get(n, d, h, w); - reduction += value * value; - } - - // Using std::max(reduction, epsilon) below would prevent against division by 0. - // However, at the time of writing: - // - This is not supported by the ACL functions used to implement L2Normalization in the CL - // backend. - // - The reference semantics for this operator do not include this parameter. - const float scale = 1.0f / sqrtf(reduction); - output.Get(n, c, h, w) = input.Get(n, c, h, w) * scale; - } - } - } - } -} - -} //namespace armnn diff --git a/src/backends/reference/workloads/RefL2NormalizationFloat32Workload.hpp b/src/backends/reference/workloads/RefL2NormalizationFloat32Workload.hpp deleted file mode 100644 index 50ece0e905..0000000000 --- a/src/backends/reference/workloads/RefL2NormalizationFloat32Workload.hpp +++ /dev/null @@ -1,22 +0,0 @@ -// -// Copyright © 2017 Arm Ltd. All rights reserved. -// SPDX-License-Identifier: MIT -// - -#pragma once - -#include -#include - -namespace armnn -{ - -class RefL2NormalizationFloat32Workload : public Float32Workload -{ -public: - using Float32Workload::Float32Workload; - - void Execute() const override; -}; - -} //namespace armnn diff --git a/src/backends/reference/workloads/RefL2NormalizationWorkload.cpp b/src/backends/reference/workloads/RefL2NormalizationWorkload.cpp new file mode 100644 index 0000000000..ce5699ef0b --- /dev/null +++ b/src/backends/reference/workloads/RefL2NormalizationWorkload.cpp @@ -0,0 +1,75 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#include "RefL2NormalizationWorkload.hpp" + +#include "RefWorkloadUtils.hpp" +#include "Decoders.hpp" +#include "Encoders.hpp" +#include "DataLayoutIndexed.hpp" + + +#include "Profiling.hpp" + +#include + +using namespace armnnUtils; + +namespace armnn +{ +RefL2NormalizationWorkload::RefL2NormalizationWorkload( + const L2NormalizationQueueDescriptor& descriptor, + const WorkloadInfo& info) + : BaseWorkload(descriptor, info) {} + + void RefL2NormalizationWorkload::Execute() const + { + ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefL2NormalizationWorkload_Execute"); + + const TensorInfo& inputInfo = GetTensorInfo(m_Data.m_Inputs[0]); + const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]); + + auto inputDecoder = MakeDecoder(inputInfo, m_Data.m_Inputs[0]->Map()); + auto outputEncoder = MakeEncoder(outputInfo, m_Data.m_Outputs[0]->Map()); + + DataLayoutIndexed dataLayout(m_Data.m_Parameters.m_DataLayout); + + const unsigned int batches = inputInfo.GetShape()[0]; + const unsigned int channels = inputInfo.GetShape()[dataLayout.GetChannelsIndex()]; + const unsigned int height = inputInfo.GetShape()[dataLayout.GetHeightIndex()]; + const unsigned int width = inputInfo.GetShape()[dataLayout.GetWidthIndex()]; + + for (unsigned int n = 0; n < batches; ++n) + { + for (unsigned int c = 0; c < channels; ++c) + { + for (unsigned int h = 0; h < height; ++h) + { + for (unsigned int w = 0; w < width; ++w) + { + float reduction = 0.0; + for (unsigned int d = 0; d < channels; ++d) + { + unsigned int inputIndex = dataLayout.GetIndex(inputInfo.GetShape(), n, d, h, w); + + (*inputDecoder)[inputIndex]; + const float value = inputDecoder->Get(); + reduction += value * value; + } + + unsigned int index = dataLayout.GetIndex(inputInfo.GetShape(), n, c, h, w); + + const float scale = 1.0f / sqrtf(reduction); + + (*inputDecoder)[index]; + (*outputEncoder)[index]; + outputEncoder->Set(inputDecoder->Get() * scale); + } + } + } + } + } + +} //namespace armnn diff --git a/src/backends/reference/workloads/RefL2NormalizationWorkload.hpp b/src/backends/reference/workloads/RefL2NormalizationWorkload.hpp new file mode 100644 index 0000000000..4beedc9992 --- /dev/null +++ b/src/backends/reference/workloads/RefL2NormalizationWorkload.hpp @@ -0,0 +1,23 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#pragma once + +#include +#include + +namespace armnn +{ + +class RefL2NormalizationWorkload : public BaseWorkload +{ +public: + explicit RefL2NormalizationWorkload(const L2NormalizationQueueDescriptor& descriptor, + const WorkloadInfo& info); + + void Execute() const override; +}; + +} //namespace armnn diff --git a/src/backends/reference/workloads/RefWorkloads.hpp b/src/backends/reference/workloads/RefWorkloads.hpp index 53f7aa2efb..1a2dec402e 100644 --- a/src/backends/reference/workloads/RefWorkloads.hpp +++ b/src/backends/reference/workloads/RefWorkloads.hpp @@ -12,7 +12,7 @@ #include "RefConvolution2dWorkload.hpp" #include "RefSplitterWorkload.hpp" #include "RefResizeBilinearUint8Workload.hpp" -#include "RefL2NormalizationFloat32Workload.hpp" +#include "RefL2NormalizationWorkload.hpp" #include "RefActivationWorkload.hpp" #include "RefPooling2dWorkload.hpp" #include "RefWorkloadUtils.hpp" -- cgit v1.2.1