aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authornikraj01 <nikhil.raj@arm.com>2019-06-06 10:31:27 +0100
committerNikhil Raj Arm <nikhil.raj@arm.com>2019-06-06 16:06:10 +0000
commit99a663140294afd2a4ea91ccc61b7266f735b46a (patch)
treebf6984e3734f0d5182a7fa510b80e3600e640e4f
parent0434df6030eee9bb6842b02a8b28598cfe8f3460 (diff)
downloadarmnn-99a663140294afd2a4ea91ccc61b7266f735b46a.tar.gz
IVGCVSW-3211 Refactor reference Rsqrt workload
Change-Id: Ia413c6b5352dbb3390e7d84e837a542c24ae8813 Signed-off-by: nikraj01 <nikhil.raj@arm.com>
-rw-r--r--src/backends/reference/RefWorkloadFactory.cpp26
-rw-r--r--src/backends/reference/backend.mk2
-rw-r--r--src/backends/reference/test/RefCreateWorkloadTests.cpp2
-rw-r--r--src/backends/reference/workloads/CMakeLists.txt4
-rw-r--r--src/backends/reference/workloads/RefRsqrtFloat32Workload.cpp25
-rw-r--r--src/backends/reference/workloads/RefRsqrtWorkload.cpp37
-rw-r--r--src/backends/reference/workloads/RefRsqrtWorkload.hpp (renamed from src/backends/reference/workloads/RefRsqrtFloat32Workload.hpp)4
-rw-r--r--src/backends/reference/workloads/RefWorkloads.hpp2
-rw-r--r--src/backends/reference/workloads/Rsqrt.cpp10
-rw-r--r--src/backends/reference/workloads/Rsqrt.hpp5
10 files changed, 78 insertions, 39 deletions
diff --git a/src/backends/reference/RefWorkloadFactory.cpp b/src/backends/reference/RefWorkloadFactory.cpp
index 319a620d2b..1ef88a090e 100644
--- a/src/backends/reference/RefWorkloadFactory.cpp
+++ b/src/backends/reference/RefWorkloadFactory.cpp
@@ -43,6 +43,22 @@ bool IsFloat16(const WorkloadInfo& info)
return false;
}
+bool IsUint8(const WorkloadInfo& info)
+{
+ auto checkUint8 = [](const TensorInfo& tensorInfo) {return tensorInfo.GetDataType() == DataType::QuantisedAsymm8;};
+ auto it = std::find_if(std::begin(info.m_InputTensorInfos), std::end(info.m_InputTensorInfos), checkUint8);
+ if (it != std::end(info.m_InputTensorInfos))
+ {
+ return true;
+ }
+ it = std::find_if(std::begin(info.m_OutputTensorInfos), std::end(info.m_OutputTensorInfos), checkUint8);
+ if (it != std::end(info.m_OutputTensorInfos))
+ {
+ return true;
+ }
+ return false;
+}
+
RefWorkloadFactory::RefWorkloadFactory()
{
}
@@ -382,7 +398,15 @@ std::unique_ptr<IWorkload> RefWorkloadFactory::CreateDebug(const DebugQueueDescr
std::unique_ptr<IWorkload> RefWorkloadFactory::CreateRsqrt(const RsqrtQueueDescriptor& descriptor,
const WorkloadInfo& info) const
{
- return MakeWorkload<RefRsqrtFloat32Workload, NullWorkload>(descriptor, info);
+ if (IsFloat16(info))
+ {
+ return MakeWorkload<NullWorkload, NullWorkload>(descriptor, info);
+ }
+ else if(IsUint8(info))
+ {
+ return MakeWorkload<NullWorkload, NullWorkload>(descriptor, info);
+ }
+ return std::make_unique<RefRsqrtWorkload>(descriptor, info);
}
std::unique_ptr<IWorkload> RefWorkloadFactory::CreateGather(const armnn::GatherQueueDescriptor& descriptor,
diff --git a/src/backends/reference/backend.mk b/src/backends/reference/backend.mk
index 2822c305f5..0d2b65d433 100644
--- a/src/backends/reference/backend.mk
+++ b/src/backends/reference/backend.mk
@@ -55,7 +55,7 @@ BACKEND_SOURCES := \
workloads/RefReshapeWorkload.cpp \
workloads/RefResizeBilinearFloat32Workload.cpp \
workloads/RefResizeBilinearUint8Workload.cpp \
- workloads/RefRsqrtFloat32Workload.cpp \
+ workloads/RefRsqrtWorkload.cpp \
workloads/RefSoftmaxWorkload.cpp \
workloads/RefSpaceToBatchNdWorkload.cpp \
workloads/RefStridedSliceWorkload.cpp \
diff --git a/src/backends/reference/test/RefCreateWorkloadTests.cpp b/src/backends/reference/test/RefCreateWorkloadTests.cpp
index 82a4120d9a..5139888e39 100644
--- a/src/backends/reference/test/RefCreateWorkloadTests.cpp
+++ b/src/backends/reference/test/RefCreateWorkloadTests.cpp
@@ -674,7 +674,7 @@ static void RefCreateRsqrtTest()
BOOST_AUTO_TEST_CASE(CreateRsqrtFloat32)
{
- RefCreateRsqrtTest<RefRsqrtFloat32Workload, armnn::DataType::Float32>();
+ RefCreateRsqrtTest<RefRsqrtWorkload, armnn::DataType::Float32>();
}
template <typename L2NormalizationWorkloadType, armnn::DataType DataType>
diff --git a/src/backends/reference/workloads/CMakeLists.txt b/src/backends/reference/workloads/CMakeLists.txt
index 9d5c4442fc..4d11447280 100644
--- a/src/backends/reference/workloads/CMakeLists.txt
+++ b/src/backends/reference/workloads/CMakeLists.txt
@@ -92,8 +92,8 @@ list(APPEND armnnRefBackendWorkloads_sources
RefResizeBilinearFloat32Workload.hpp
RefResizeBilinearUint8Workload.cpp
RefResizeBilinearUint8Workload.hpp
- RefRsqrtFloat32Workload.cpp
- RefRsqrtFloat32Workload.hpp
+ RefRsqrtWorkload.cpp
+ RefRsqrtWorkload.hpp
RefSoftmaxWorkload.cpp
RefSoftmaxWorkload.hpp
RefSpaceToBatchNdWorkload.cpp
diff --git a/src/backends/reference/workloads/RefRsqrtFloat32Workload.cpp b/src/backends/reference/workloads/RefRsqrtFloat32Workload.cpp
deleted file mode 100644
index c08dbf0cab..0000000000
--- a/src/backends/reference/workloads/RefRsqrtFloat32Workload.cpp
+++ /dev/null
@@ -1,25 +0,0 @@
-//
-// Copyright © 2017 Arm Ltd. All rights reserved.
-// SPDX-License-Identifier: MIT
-//
-
-#include "RefRsqrtFloat32Workload.hpp"
-
-#include "RefWorkloadUtils.hpp"
-#include "Rsqrt.hpp"
-
-#include <Profiling.hpp>
-
-namespace armnn
-{
-
-void RefRsqrtFloat32Workload::Execute() const
-{
- ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefRsqrtFloat32Workload_Execute");
-
- Rsqrt(GetInputTensorDataFloat(0, m_Data),
- GetOutputTensorDataFloat(0, m_Data),
- GetTensorInfo(m_Data.m_Inputs[0]));
-}
-
-} //namespace armnn
diff --git a/src/backends/reference/workloads/RefRsqrtWorkload.cpp b/src/backends/reference/workloads/RefRsqrtWorkload.cpp
new file mode 100644
index 0000000000..fd6b9a3549
--- /dev/null
+++ b/src/backends/reference/workloads/RefRsqrtWorkload.cpp
@@ -0,0 +1,37 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "RefRsqrtWorkload.hpp"
+
+#include "Decoders.hpp"
+#include "Encoders.hpp"
+#include "RefWorkloadUtils.hpp"
+#include "Rsqrt.hpp"
+
+#include <Profiling.hpp>
+
+namespace armnn
+{
+
+void RefRsqrtWorkload::Execute() const
+{
+ ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefRsqrtWorkload_Execute");
+
+ const TensorInfo& inputTensorInfo = GetTensorInfo(m_Data.m_Inputs[0]);
+
+ std::unique_ptr<Decoder<float>> decoderPtr = MakeDecoder<float>(inputTensorInfo, m_Data.m_Inputs[0]->Map());
+ Decoder<float>& decoder = *decoderPtr;
+
+ const TensorInfo& outputTensorInfo = GetTensorInfo(m_Data.m_Outputs[0]);
+
+ std::unique_ptr<Encoder<float>> encoderPtr = MakeEncoder<float>(outputTensorInfo, m_Data.m_Outputs[0]->Map());
+ Encoder<float>& encoder = *encoderPtr;
+
+ Rsqrt(decoder,
+ encoder,
+ GetTensorInfo(m_Data.m_Inputs[0]));
+}
+
+} //namespace armnn
diff --git a/src/backends/reference/workloads/RefRsqrtFloat32Workload.hpp b/src/backends/reference/workloads/RefRsqrtWorkload.hpp
index 9d1b4505fe..6c8ad5bc60 100644
--- a/src/backends/reference/workloads/RefRsqrtFloat32Workload.hpp
+++ b/src/backends/reference/workloads/RefRsqrtWorkload.hpp
@@ -11,10 +11,10 @@
namespace armnn
{
-class RefRsqrtFloat32Workload : public Float32Workload<RsqrtQueueDescriptor>
+class RefRsqrtWorkload : public BaseWorkload<RsqrtQueueDescriptor>
{
public:
- using Float32Workload<RsqrtQueueDescriptor>::Float32Workload;
+ using BaseWorkload<RsqrtQueueDescriptor>::BaseWorkload;
virtual void Execute() const override;
};
diff --git a/src/backends/reference/workloads/RefWorkloads.hpp b/src/backends/reference/workloads/RefWorkloads.hpp
index 96f98ee7a8..53f7aa2efb 100644
--- a/src/backends/reference/workloads/RefWorkloads.hpp
+++ b/src/backends/reference/workloads/RefWorkloads.hpp
@@ -49,7 +49,7 @@
#include "RefBatchToSpaceNdUint8Workload.hpp"
#include "RefBatchToSpaceNdFloat32Workload.hpp"
#include "RefDebugWorkload.hpp"
-#include "RefRsqrtFloat32Workload.hpp"
+#include "RefRsqrtWorkload.hpp"
#include "RefDequantizeWorkload.hpp"
#include "RefQuantizeWorkload.hpp"
#include "RefReshapeWorkload.hpp"
diff --git a/src/backends/reference/workloads/Rsqrt.cpp b/src/backends/reference/workloads/Rsqrt.cpp
index cee38fc1f1..5abc2c8f7b 100644
--- a/src/backends/reference/workloads/Rsqrt.cpp
+++ b/src/backends/reference/workloads/Rsqrt.cpp
@@ -10,13 +10,15 @@
namespace armnn
{
-void Rsqrt(const float* in,
- float* out,
+void Rsqrt(Decoder<float>& in,
+ Encoder<float>& out,
const TensorInfo& tensorInfo)
{
- for (size_t i = 0; i < tensorInfo.GetNumElements(); i++)
+ for (unsigned int i = 0; i < tensorInfo.GetNumElements(); ++i)
{
- out[i] = 1.f / sqrtf(in[i]);
+ out[i];
+ in[i];
+ out.Set(1.f / sqrtf(in.Get()));
}
}
diff --git a/src/backends/reference/workloads/Rsqrt.hpp b/src/backends/reference/workloads/Rsqrt.hpp
index 35cacede66..ffc6b18d13 100644
--- a/src/backends/reference/workloads/Rsqrt.hpp
+++ b/src/backends/reference/workloads/Rsqrt.hpp
@@ -3,6 +3,7 @@
// SPDX-License-Identifier: MIT
//
+#include "BaseIterator.hpp"
#include <armnn/Tensor.hpp>
#include <armnn/Types.hpp>
@@ -11,8 +12,8 @@ namespace armnn
/// Performs the reciprocal squareroot function elementwise
/// on the inputs to give the outputs.
-void Rsqrt(const float* in,
- float* out,
+void Rsqrt(Decoder<float>& in,
+ Encoder<float>& out,
const TensorInfo& tensorInfo);
} //namespace armnn