diff options
Diffstat (limited to 'src/backends/reference/workloads')
-rw-r--r-- | src/backends/reference/workloads/CMakeLists.txt | 4
-rw-r--r-- | src/backends/reference/workloads/RefNormalizationWorkload.cpp (renamed from src/backends/reference/workloads/RefNormalizationFloat32Workload.cpp) | 106
-rw-r--r-- | src/backends/reference/workloads/RefNormalizationWorkload.hpp (renamed from src/backends/reference/workloads/RefNormalizationFloat32Workload.hpp) | 8
-rw-r--r-- | src/backends/reference/workloads/RefWorkloads.hpp | 2 |
4 files changed, 71 insertions, 49 deletions
diff --git a/src/backends/reference/workloads/CMakeLists.txt b/src/backends/reference/workloads/CMakeLists.txt index 82502c513c..9d5c4442fc 100644 --- a/src/backends/reference/workloads/CMakeLists.txt +++ b/src/backends/reference/workloads/CMakeLists.txt @@ -76,8 +76,8 @@ list(APPEND armnnRefBackendWorkloads_sources RefLstmWorkload.hpp RefConcatWorkload.cpp RefConcatWorkload.hpp - RefNormalizationFloat32Workload.cpp - RefNormalizationFloat32Workload.hpp + RefNormalizationWorkload.cpp + RefNormalizationWorkload.hpp RefPadWorkload.cpp RefPadWorkload.hpp RefPermuteWorkload.cpp diff --git a/src/backends/reference/workloads/RefNormalizationFloat32Workload.cpp b/src/backends/reference/workloads/RefNormalizationWorkload.cpp index 3a2f2b9658..8ff2d9cf92 100644 --- a/src/backends/reference/workloads/RefNormalizationFloat32Workload.cpp +++ b/src/backends/reference/workloads/RefNormalizationWorkload.cpp @@ -3,31 +3,34 @@ // SPDX-License-Identifier: MIT // -#include "RefNormalizationFloat32Workload.hpp" +#include "RefNormalizationWorkload.hpp" #include "RefWorkloadUtils.hpp" -#include "TensorBufferArrayView.hpp" - -#include "Profiling.hpp" +#include "Decoders.hpp" +#include "Encoders.hpp" #include <armnn/Tensor.hpp> +#include <DataLayoutIndexed.hpp> +#include <Profiling.hpp> + #include <boost/log/trivial.hpp> #include <boost/numeric/conversion/cast.hpp> +using namespace armnn; using namespace armnnUtils; -namespace armnn +namespace { // Helper function to compute "Within" normalization using Krichevsky 2012: Local Brightness Normalization. 
-static void NormalizeWithinUingLbr(const float* inputData, - float* outputData, - const TensorShape& tensorShape, - uint32_t norm_size, - float alpha, - float beta, - float kappa) +void NormalizeWithinUingLbr(Decoder<float>& inputData, + Encoder<float>& outputData, + const TensorShape& tensorShape, + uint32_t norm_size, + float alpha, + float beta, + float kappa) { const unsigned int batchSize = tensorShape[0]; const unsigned int depth = tensorShape[1]; @@ -62,21 +65,24 @@ static void NormalizeWithinUingLbr(const float* inputData, continue; } - float inval = inputData[n * cols * rows * depth + - c * cols * rows + - boost::numeric_cast<unsigned int>(j) * cols + - boost::numeric_cast<unsigned int>(i)]; + unsigned int inputIndex = n * cols * rows * depth + + c * cols * rows + + boost::numeric_cast<unsigned int>(j) * cols + + boost::numeric_cast<unsigned int>(i); + inputData[inputIndex]; + float inval = inputData.Get(); accumulated_scale += inval*inval; } } - outputData[n * cols * rows * depth + - c * cols * rows + - h * cols + - w] = inputData[n * cols * rows * depth + - c * cols * rows + - h * cols + - w] / (powf((kappa + (accumulated_scale * alpha)), beta)); + + unsigned int index = n * cols * rows * depth + + c * cols * rows + + h * cols + + w; + inputData[index]; + outputData[index]; + outputData.Set(inputData.Get() / (powf((kappa + (accumulated_scale * alpha)), beta))); } } } @@ -84,8 +90,8 @@ static void NormalizeWithinUingLbr(const float* inputData, } // Helper function to compute "Across" normalization using Krichevsky 2012: Local Brightness Normalization. 
-void NormalizeAcrossUingLbr(const float* inputData, - float* outputData, +void NormalizeAcrossUingLbr(Decoder<float>& inputData, + Encoder<float>& outputData, const TensorShape& tensorShape, uint32_t norm_size, float alpha, @@ -93,13 +99,6 @@ void NormalizeAcrossUingLbr(const float* inputData, float kappa, DataLayout dataLayout) { - TensorBufferArrayView<const float> input(tensorShape, - inputData, - dataLayout); - TensorBufferArrayView<float> output(tensorShape, - outputData, - dataLayout); - DataLayoutIndexed dataLayoutIndexed(dataLayout); const unsigned int batchSize = tensorShape[0]; @@ -127,7 +126,14 @@ void NormalizeAcrossUingLbr(const float* inputData, continue; } - float inval = input.Get(n, boost::numeric_cast<unsigned int>(k), h, w); + unsigned inputIndex = dataLayoutIndexed.GetIndex(tensorShape, + n, + boost::numeric_cast<unsigned int>(k), + h, + w); + + inputData[inputIndex]; + float inval = inputData.Get(); accumulated_scale += inval * inval; } @@ -135,28 +141,42 @@ void NormalizeAcrossUingLbr(const float* inputData, float scale = kappa + (accumulated_scale * alpha); scale = powf(scale, -beta); - output.Get(n, c, h, w) = scale * input.Get(n, c, h, w); + unsigned index = dataLayoutIndexed.GetIndex(tensorShape, n, c, h, w); + + inputData[index]; + outputData[index]; + outputData.Set(scale * inputData.Get()); } } } } } -void RefNormalizationFloat32Workload::Execute() const +} // Anonymous namespace + +namespace armnn +{ + +RefNormalizationWorkload::RefNormalizationWorkload(const NormalizationQueueDescriptor& descriptor, + const WorkloadInfo& info) + : BaseWorkload(descriptor, info) +{} + +void RefNormalizationWorkload::Execute() const { - ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefNormalizationFloat32Workload_Execute"); + ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefNormalizationWorkload_Execute"); const TensorInfo& inputInfo = GetTensorInfo(m_Data.m_Inputs[0]); - float* outputData = GetOutputTensorDataFloat(0, m_Data); - const float* 
inputData = GetInputTensorDataFloat(0, m_Data); + auto inputDecoder = MakeDecoder<float>(inputInfo, m_Data.m_Inputs[0]->Map()); + auto outputEncoder = MakeEncoder<float>(inputInfo, m_Data.m_Outputs[0]->Map()); if (NormalizationAlgorithmMethod::LocalBrightness == m_Data.m_Parameters.m_NormMethodType) { if (NormalizationAlgorithmChannel::Within == m_Data.m_Parameters.m_NormChannelType) { - NormalizeWithinUingLbr(inputData, - outputData, + NormalizeWithinUingLbr(*inputDecoder, + *outputEncoder, inputInfo.GetShape(), m_Data.m_Parameters.m_NormSize, m_Data.m_Parameters.m_Alpha, @@ -165,8 +185,8 @@ void RefNormalizationFloat32Workload::Execute() const } else if (NormalizationAlgorithmChannel::Across == m_Data.m_Parameters.m_NormChannelType) { - NormalizeAcrossUingLbr(inputData, - outputData, + NormalizeAcrossUingLbr(*inputDecoder, + *outputEncoder, inputInfo.GetShape(), m_Data.m_Parameters.m_NormSize, m_Data.m_Parameters.m_Alpha, @@ -187,4 +207,4 @@ void RefNormalizationFloat32Workload::Execute() const } } -} //namespace armnn +} // namespace armnn diff --git a/src/backends/reference/workloads/RefNormalizationFloat32Workload.hpp b/src/backends/reference/workloads/RefNormalizationWorkload.hpp index 9dff187bd4..6d33c8afb2 100644 --- a/src/backends/reference/workloads/RefNormalizationFloat32Workload.hpp +++ b/src/backends/reference/workloads/RefNormalizationWorkload.hpp @@ -11,11 +11,13 @@ namespace armnn { -class RefNormalizationFloat32Workload : public Float32Workload<NormalizationQueueDescriptor> +class RefNormalizationWorkload : public BaseWorkload<NormalizationQueueDescriptor> { public: - using Float32Workload<NormalizationQueueDescriptor>::Float32Workload; + explicit RefNormalizationWorkload(const NormalizationQueueDescriptor& descriptor, + const WorkloadInfo& info); + virtual void Execute() const override; }; -} //namespace armnn +} // namespace armnn diff --git a/src/backends/reference/workloads/RefWorkloads.hpp b/src/backends/reference/workloads/RefWorkloads.hpp 
index ce1e688dcf..96f98ee7a8 100644 --- a/src/backends/reference/workloads/RefWorkloads.hpp +++ b/src/backends/reference/workloads/RefWorkloads.hpp @@ -30,7 +30,7 @@ #include "RefSoftmaxWorkload.hpp" #include "RefResizeBilinearFloat32Workload.hpp" #include "ResizeBilinear.hpp" -#include "RefNormalizationFloat32Workload.hpp" +#include "RefNormalizationWorkload.hpp" #include "RefDetectionPostProcessWorkload.hpp" #include "BatchNormImpl.hpp" #include "Activation.hpp"