aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference/workloads
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/reference/workloads')
-rw-r--r--src/backends/reference/workloads/CMakeLists.txt4
-rw-r--r--src/backends/reference/workloads/RefNormalizationWorkload.cpp (renamed from src/backends/reference/workloads/RefNormalizationFloat32Workload.cpp)106
-rw-r--r--src/backends/reference/workloads/RefNormalizationWorkload.hpp (renamed from src/backends/reference/workloads/RefNormalizationFloat32Workload.hpp)8
-rw-r--r--src/backends/reference/workloads/RefWorkloads.hpp2
4 files changed, 71 insertions, 49 deletions
diff --git a/src/backends/reference/workloads/CMakeLists.txt b/src/backends/reference/workloads/CMakeLists.txt
index 82502c513c..9d5c4442fc 100644
--- a/src/backends/reference/workloads/CMakeLists.txt
+++ b/src/backends/reference/workloads/CMakeLists.txt
@@ -76,8 +76,8 @@ list(APPEND armnnRefBackendWorkloads_sources
RefLstmWorkload.hpp
RefConcatWorkload.cpp
RefConcatWorkload.hpp
- RefNormalizationFloat32Workload.cpp
- RefNormalizationFloat32Workload.hpp
+ RefNormalizationWorkload.cpp
+ RefNormalizationWorkload.hpp
RefPadWorkload.cpp
RefPadWorkload.hpp
RefPermuteWorkload.cpp
diff --git a/src/backends/reference/workloads/RefNormalizationFloat32Workload.cpp b/src/backends/reference/workloads/RefNormalizationWorkload.cpp
index 3a2f2b9658..8ff2d9cf92 100644
--- a/src/backends/reference/workloads/RefNormalizationFloat32Workload.cpp
+++ b/src/backends/reference/workloads/RefNormalizationWorkload.cpp
@@ -3,31 +3,34 @@
// SPDX-License-Identifier: MIT
//
-#include "RefNormalizationFloat32Workload.hpp"
+#include "RefNormalizationWorkload.hpp"
#include "RefWorkloadUtils.hpp"
-#include "TensorBufferArrayView.hpp"
-
-#include "Profiling.hpp"
+#include "Decoders.hpp"
+#include "Encoders.hpp"
#include <armnn/Tensor.hpp>
+#include <DataLayoutIndexed.hpp>
+#include <Profiling.hpp>
+
#include <boost/log/trivial.hpp>
#include <boost/numeric/conversion/cast.hpp>
+using namespace armnn;
using namespace armnnUtils;
-namespace armnn
+namespace
{
// Helper function to compute "Within" normalization using Krichevsky 2012: Local Brightness Normalization.
-static void NormalizeWithinUingLbr(const float* inputData,
- float* outputData,
- const TensorShape& tensorShape,
- uint32_t norm_size,
- float alpha,
- float beta,
- float kappa)
+void NormalizeWithinUingLbr(Decoder<float>& inputData,
+ Encoder<float>& outputData,
+ const TensorShape& tensorShape,
+ uint32_t norm_size,
+ float alpha,
+ float beta,
+ float kappa)
{
const unsigned int batchSize = tensorShape[0];
const unsigned int depth = tensorShape[1];
@@ -62,21 +65,24 @@ static void NormalizeWithinUingLbr(const float* inputData,
continue;
}
- float inval = inputData[n * cols * rows * depth +
- c * cols * rows +
- boost::numeric_cast<unsigned int>(j) * cols +
- boost::numeric_cast<unsigned int>(i)];
+ unsigned int inputIndex = n * cols * rows * depth +
+ c * cols * rows +
+ boost::numeric_cast<unsigned int>(j) * cols +
+ boost::numeric_cast<unsigned int>(i);
+ inputData[inputIndex];
+ float inval = inputData.Get();
accumulated_scale += inval*inval;
}
}
- outputData[n * cols * rows * depth +
- c * cols * rows +
- h * cols +
- w] = inputData[n * cols * rows * depth +
- c * cols * rows +
- h * cols +
- w] / (powf((kappa + (accumulated_scale * alpha)), beta));
+
+ unsigned int index = n * cols * rows * depth +
+ c * cols * rows +
+ h * cols +
+ w;
+ inputData[index];
+ outputData[index];
+ outputData.Set(inputData.Get() / (powf((kappa + (accumulated_scale * alpha)), beta)));
}
}
}
@@ -84,8 +90,8 @@ static void NormalizeWithinUingLbr(const float* inputData,
}
// Helper function to compute "Across" normalization using Krichevsky 2012: Local Brightness Normalization.
-void NormalizeAcrossUingLbr(const float* inputData,
- float* outputData,
+void NormalizeAcrossUingLbr(Decoder<float>& inputData,
+ Encoder<float>& outputData,
const TensorShape& tensorShape,
uint32_t norm_size,
float alpha,
@@ -93,13 +99,6 @@ void NormalizeAcrossUingLbr(const float* inputData,
float kappa,
DataLayout dataLayout)
{
- TensorBufferArrayView<const float> input(tensorShape,
- inputData,
- dataLayout);
- TensorBufferArrayView<float> output(tensorShape,
- outputData,
- dataLayout);
-
DataLayoutIndexed dataLayoutIndexed(dataLayout);
const unsigned int batchSize = tensorShape[0];
@@ -127,7 +126,14 @@ void NormalizeAcrossUingLbr(const float* inputData,
continue;
}
- float inval = input.Get(n, boost::numeric_cast<unsigned int>(k), h, w);
+ unsigned inputIndex = dataLayoutIndexed.GetIndex(tensorShape,
+ n,
+ boost::numeric_cast<unsigned int>(k),
+ h,
+ w);
+
+ inputData[inputIndex];
+ float inval = inputData.Get();
accumulated_scale += inval * inval;
}
@@ -135,28 +141,42 @@ void NormalizeAcrossUingLbr(const float* inputData,
float scale = kappa + (accumulated_scale * alpha);
scale = powf(scale, -beta);
- output.Get(n, c, h, w) = scale * input.Get(n, c, h, w);
+ unsigned index = dataLayoutIndexed.GetIndex(tensorShape, n, c, h, w);
+
+ inputData[index];
+ outputData[index];
+ outputData.Set(scale * inputData.Get());
}
}
}
}
}
-void RefNormalizationFloat32Workload::Execute() const
+} // Anonymous namespace
+
+namespace armnn
+{
+
+RefNormalizationWorkload::RefNormalizationWorkload(const NormalizationQueueDescriptor& descriptor,
+ const WorkloadInfo& info)
+ : BaseWorkload(descriptor, info)
+{}
+
+void RefNormalizationWorkload::Execute() const
{
- ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefNormalizationFloat32Workload_Execute");
+ ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefNormalizationWorkload_Execute");
const TensorInfo& inputInfo = GetTensorInfo(m_Data.m_Inputs[0]);
- float* outputData = GetOutputTensorDataFloat(0, m_Data);
- const float* inputData = GetInputTensorDataFloat(0, m_Data);
+ auto inputDecoder = MakeDecoder<float>(inputInfo, m_Data.m_Inputs[0]->Map());
+ auto outputEncoder = MakeEncoder<float>(inputInfo, m_Data.m_Outputs[0]->Map());
if (NormalizationAlgorithmMethod::LocalBrightness == m_Data.m_Parameters.m_NormMethodType)
{
if (NormalizationAlgorithmChannel::Within == m_Data.m_Parameters.m_NormChannelType)
{
- NormalizeWithinUingLbr(inputData,
- outputData,
+ NormalizeWithinUingLbr(*inputDecoder,
+ *outputEncoder,
inputInfo.GetShape(),
m_Data.m_Parameters.m_NormSize,
m_Data.m_Parameters.m_Alpha,
@@ -165,8 +185,8 @@ void RefNormalizationFloat32Workload::Execute() const
}
else if (NormalizationAlgorithmChannel::Across == m_Data.m_Parameters.m_NormChannelType)
{
- NormalizeAcrossUingLbr(inputData,
- outputData,
+ NormalizeAcrossUingLbr(*inputDecoder,
+ *outputEncoder,
inputInfo.GetShape(),
m_Data.m_Parameters.m_NormSize,
m_Data.m_Parameters.m_Alpha,
@@ -187,4 +207,4 @@ void RefNormalizationFloat32Workload::Execute() const
}
}
-} //namespace armnn
+} // namespace armnn
diff --git a/src/backends/reference/workloads/RefNormalizationFloat32Workload.hpp b/src/backends/reference/workloads/RefNormalizationWorkload.hpp
index 9dff187bd4..6d33c8afb2 100644
--- a/src/backends/reference/workloads/RefNormalizationFloat32Workload.hpp
+++ b/src/backends/reference/workloads/RefNormalizationWorkload.hpp
@@ -11,11 +11,13 @@
namespace armnn
{
-class RefNormalizationFloat32Workload : public Float32Workload<NormalizationQueueDescriptor>
+class RefNormalizationWorkload : public BaseWorkload<NormalizationQueueDescriptor>
{
public:
- using Float32Workload<NormalizationQueueDescriptor>::Float32Workload;
+ explicit RefNormalizationWorkload(const NormalizationQueueDescriptor& descriptor,
+ const WorkloadInfo& info);
+
virtual void Execute() const override;
};
-} //namespace armnn
+} // namespace armnn
diff --git a/src/backends/reference/workloads/RefWorkloads.hpp b/src/backends/reference/workloads/RefWorkloads.hpp
index ce1e688dcf..96f98ee7a8 100644
--- a/src/backends/reference/workloads/RefWorkloads.hpp
+++ b/src/backends/reference/workloads/RefWorkloads.hpp
@@ -30,7 +30,7 @@
#include "RefSoftmaxWorkload.hpp"
#include "RefResizeBilinearFloat32Workload.hpp"
#include "ResizeBilinear.hpp"
-#include "RefNormalizationFloat32Workload.hpp"
+#include "RefNormalizationWorkload.hpp"
#include "RefDetectionPostProcessWorkload.hpp"
#include "BatchNormImpl.hpp"
#include "Activation.hpp"