aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference/workloads
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/reference/workloads')
-rw-r--r--src/backends/reference/workloads/CMakeLists.txt4
-rw-r--r--src/backends/reference/workloads/RefRsqrtFloat32Workload.cpp25
-rw-r--r--src/backends/reference/workloads/RefRsqrtWorkload.cpp37
-rw-r--r--src/backends/reference/workloads/RefRsqrtWorkload.hpp (renamed from src/backends/reference/workloads/RefRsqrtFloat32Workload.hpp)4
-rw-r--r--src/backends/reference/workloads/RefWorkloads.hpp2
-rw-r--r--src/backends/reference/workloads/Rsqrt.cpp10
-rw-r--r--src/backends/reference/workloads/Rsqrt.hpp5
7 files changed, 51 insertions, 36 deletions
diff --git a/src/backends/reference/workloads/CMakeLists.txt b/src/backends/reference/workloads/CMakeLists.txt
index 9d5c4442fc..4d11447280 100644
--- a/src/backends/reference/workloads/CMakeLists.txt
+++ b/src/backends/reference/workloads/CMakeLists.txt
@@ -92,8 +92,8 @@ list(APPEND armnnRefBackendWorkloads_sources
RefResizeBilinearFloat32Workload.hpp
RefResizeBilinearUint8Workload.cpp
RefResizeBilinearUint8Workload.hpp
- RefRsqrtFloat32Workload.cpp
- RefRsqrtFloat32Workload.hpp
+ RefRsqrtWorkload.cpp
+ RefRsqrtWorkload.hpp
RefSoftmaxWorkload.cpp
RefSoftmaxWorkload.hpp
RefSpaceToBatchNdWorkload.cpp
diff --git a/src/backends/reference/workloads/RefRsqrtFloat32Workload.cpp b/src/backends/reference/workloads/RefRsqrtFloat32Workload.cpp
deleted file mode 100644
index c08dbf0cab..0000000000
--- a/src/backends/reference/workloads/RefRsqrtFloat32Workload.cpp
+++ /dev/null
@@ -1,25 +0,0 @@
-//
-// Copyright © 2017 Arm Ltd. All rights reserved.
-// SPDX-License-Identifier: MIT
-//
-
-#include "RefRsqrtFloat32Workload.hpp"
-
-#include "RefWorkloadUtils.hpp"
-#include "Rsqrt.hpp"
-
-#include <Profiling.hpp>
-
-namespace armnn
-{
-
-void RefRsqrtFloat32Workload::Execute() const
-{
- ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefRsqrtFloat32Workload_Execute");
-
- Rsqrt(GetInputTensorDataFloat(0, m_Data),
- GetOutputTensorDataFloat(0, m_Data),
- GetTensorInfo(m_Data.m_Inputs[0]));
-}
-
-} //namespace armnn
diff --git a/src/backends/reference/workloads/RefRsqrtWorkload.cpp b/src/backends/reference/workloads/RefRsqrtWorkload.cpp
new file mode 100644
index 0000000000..fd6b9a3549
--- /dev/null
+++ b/src/backends/reference/workloads/RefRsqrtWorkload.cpp
@@ -0,0 +1,37 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "RefRsqrtWorkload.hpp"
+
+#include "Decoders.hpp"
+#include "Encoders.hpp"
+#include "RefWorkloadUtils.hpp"
+#include "Rsqrt.hpp"
+
+#include <Profiling.hpp>
+
+namespace armnn
+{
+
+void RefRsqrtWorkload::Execute() const
+{
+ ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefRsqrtWorkload_Execute");
+
+ const TensorInfo& inputTensorInfo = GetTensorInfo(m_Data.m_Inputs[0]);
+
+ std::unique_ptr<Decoder<float>> decoderPtr = MakeDecoder<float>(inputTensorInfo, m_Data.m_Inputs[0]->Map());
+ Decoder<float>& decoder = *decoderPtr;
+
+ const TensorInfo& outputTensorInfo = GetTensorInfo(m_Data.m_Outputs[0]);
+
+ std::unique_ptr<Encoder<float>> encoderPtr = MakeEncoder<float>(outputTensorInfo, m_Data.m_Outputs[0]->Map());
+ Encoder<float>& encoder = *encoderPtr;
+
+ Rsqrt(decoder,
+ encoder,
+ GetTensorInfo(m_Data.m_Inputs[0]));
+}
+
+} //namespace armnn
diff --git a/src/backends/reference/workloads/RefRsqrtFloat32Workload.hpp b/src/backends/reference/workloads/RefRsqrtWorkload.hpp
index 9d1b4505fe..6c8ad5bc60 100644
--- a/src/backends/reference/workloads/RefRsqrtFloat32Workload.hpp
+++ b/src/backends/reference/workloads/RefRsqrtWorkload.hpp
@@ -11,10 +11,10 @@
namespace armnn
{
-class RefRsqrtFloat32Workload : public Float32Workload<RsqrtQueueDescriptor>
+class RefRsqrtWorkload : public BaseWorkload<RsqrtQueueDescriptor>
{
public:
- using Float32Workload<RsqrtQueueDescriptor>::Float32Workload;
+ using BaseWorkload<RsqrtQueueDescriptor>::BaseWorkload;
virtual void Execute() const override;
};
diff --git a/src/backends/reference/workloads/RefWorkloads.hpp b/src/backends/reference/workloads/RefWorkloads.hpp
index 96f98ee7a8..53f7aa2efb 100644
--- a/src/backends/reference/workloads/RefWorkloads.hpp
+++ b/src/backends/reference/workloads/RefWorkloads.hpp
@@ -49,7 +49,7 @@
#include "RefBatchToSpaceNdUint8Workload.hpp"
#include "RefBatchToSpaceNdFloat32Workload.hpp"
#include "RefDebugWorkload.hpp"
-#include "RefRsqrtFloat32Workload.hpp"
+#include "RefRsqrtWorkload.hpp"
#include "RefDequantizeWorkload.hpp"
#include "RefQuantizeWorkload.hpp"
#include "RefReshapeWorkload.hpp"
diff --git a/src/backends/reference/workloads/Rsqrt.cpp b/src/backends/reference/workloads/Rsqrt.cpp
index cee38fc1f1..5abc2c8f7b 100644
--- a/src/backends/reference/workloads/Rsqrt.cpp
+++ b/src/backends/reference/workloads/Rsqrt.cpp
@@ -10,13 +10,15 @@
namespace armnn
{
-void Rsqrt(const float* in,
- float* out,
+void Rsqrt(Decoder<float>& in,
+ Encoder<float>& out,
const TensorInfo& tensorInfo)
{
- for (size_t i = 0; i < tensorInfo.GetNumElements(); i++)
+ for (unsigned int i = 0; i < tensorInfo.GetNumElements(); ++i)
{
- out[i] = 1.f / sqrtf(in[i]);
+ out[i];
+ in[i];
+ out.Set(1.f / sqrtf(in.Get()));
}
}
diff --git a/src/backends/reference/workloads/Rsqrt.hpp b/src/backends/reference/workloads/Rsqrt.hpp
index 35cacede66..ffc6b18d13 100644
--- a/src/backends/reference/workloads/Rsqrt.hpp
+++ b/src/backends/reference/workloads/Rsqrt.hpp
@@ -3,6 +3,7 @@
// SPDX-License-Identifier: MIT
//
+#include "BaseIterator.hpp"
#include <armnn/Tensor.hpp>
#include <armnn/Types.hpp>
@@ -11,8 +12,8 @@ namespace armnn
/// Performs the reciprocal squareroot function elementwise
/// on the inputs to give the outputs.
-void Rsqrt(const float* in,
- float* out,
+void Rsqrt(Decoder<float>& in,
+ Encoder<float>& out,
const TensorInfo& tensorInfo);
} //namespace armnn