Diffstat (limited to 'src/backends/reference')
-rw-r--r--   src/backends/reference/RefLayerSupport.cpp                        | 85
-rw-r--r--   src/backends/reference/RefLayerSupport.hpp                        |  9
-rw-r--r--   src/backends/reference/RefWorkloadFactory.cpp                     | 18
-rw-r--r--   src/backends/reference/RefWorkloadFactory.hpp                     |  5
-rw-r--r--   src/backends/reference/backend.mk                                 |  5
-rw-r--r--   src/backends/reference/test/RefCreateWorkloadTests.cpp            | 35
-rw-r--r--   src/backends/reference/test/RefEndToEndTests.cpp                  | 34
-rw-r--r--   src/backends/reference/workloads/Abs.cpp                          | 23
-rw-r--r--   src/backends/reference/workloads/Abs.hpp                          | 23
-rw-r--r--   src/backends/reference/workloads/Broadcast.cpp                    | 21
-rw-r--r--   src/backends/reference/workloads/Broadcast.hpp                    | 35
-rw-r--r--   src/backends/reference/workloads/CMakeLists.txt                   | 10
-rw-r--r--   src/backends/reference/workloads/ElementwiseFunction.cpp          | 58
-rw-r--r--   src/backends/reference/workloads/ElementwiseFunction.hpp          | 26
-rw-r--r--   src/backends/reference/workloads/Exp.hpp                          | 22
-rw-r--r--   src/backends/reference/workloads/RefAbsWorkload.cpp               | 37
-rw-r--r--   src/backends/reference/workloads/RefAbsWorkload.hpp               | 21
-rw-r--r--   src/backends/reference/workloads/RefComparisonWorkload.cpp        | 12
-rw-r--r--   src/backends/reference/workloads/RefElementwiseUnaryWorkload.cpp  | 95
-rw-r--r--   src/backends/reference/workloads/RefElementwiseUnaryWorkload.hpp  | 33
-rw-r--r--   src/backends/reference/workloads/RefElementwiseWorkload.cpp       | 12
-rw-r--r--   src/backends/reference/workloads/RefElementwiseWorkload.hpp       |  4
-rw-r--r--   src/backends/reference/workloads/RefRsqrtWorkload.cpp             | 37
-rw-r--r--   src/backends/reference/workloads/RefRsqrtWorkload.hpp             | 21
-rw-r--r--   src/backends/reference/workloads/RefWorkloads.hpp                 |  4
-rw-r--r--   src/backends/reference/workloads/Rsqrt.cpp                        | 25
-rw-r--r--   src/backends/reference/workloads/Rsqrt.hpp                        | 23
-rw-r--r--   src/backends/reference/workloads/Sqrt.hpp                         | 22
28 files changed, 430 insertions, 325 deletions
diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp
index 26a61d45d5..491081dbac 100644
--- a/src/backends/reference/RefLayerSupport.cpp
+++ b/src/backends/reference/RefLayerSupport.cpp
@@ -70,28 +70,10 @@ std::string CreateIncorrectDimensionsErrorMsg(unsigned int expected,
bool RefLayerSupport::IsAbsSupported(const TensorInfo& input, const TensorInfo& output,
Optional<std::string&> reasonIfUnsupported) const
{
- bool supported = true;
- std::array<DataType,4> supportedTypes =
- {
- DataType::Float32,
- DataType::Float16,
- DataType::QAsymmU8,
- DataType::QSymmS16
- };
-
- supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
- "Reference abs: input type not supported");
-
- supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
- "Reference abs: output type not supported");
-
- supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
- "Reference abs: input and output types not matching");
-
- supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
- "Reference abs: input and output shapes have different number of total elements");
-
- return supported;
+ return IsElementwiseUnarySupported(input,
+ output,
+ ElementwiseUnaryDescriptor(UnaryOperation::Abs),
+ reasonIfUnsupported);
}
bool RefLayerSupport::IsActivationSupported(const TensorInfo& input,
@@ -714,6 +696,39 @@ bool RefLayerSupport::IsDivisionSupported(const TensorInfo& input0,
return supported;
}
+bool RefLayerSupport::IsElementwiseUnarySupported(const TensorInfo& input,
+ const TensorInfo& output,
+ const ElementwiseUnaryDescriptor& descriptor,
+ Optional<std::string&> reasonIfUnsupported) const
+{
+ boost::ignore_unused(descriptor);
+
+ std::array<DataType, 4> supportedTypes =
+ {
+ DataType::Float32,
+ DataType::Float16,
+ DataType::QAsymmU8,
+ DataType::QSymmS16
+ };
+
+ bool supported = true;
+
+ supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
+ "Reference elementwise unary: input type not supported");
+
+ supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
+ "Reference elementwise unary: output type not supported");
+
+ supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
+ "Reference elementwise unary: input and output types not matching");
+
+ supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
+ "Reference elementwise unary: input and output shapes"
+ "have different number of total elements");
+
+ return supported;
+}
+
bool RefLayerSupport::IsEqualSupported(const TensorInfo& input0,
const TensorInfo& input1,
const TensorInfo& output,
@@ -1499,28 +1514,10 @@ bool RefLayerSupport::IsRsqrtSupported(const TensorInfo& input,
const TensorInfo& output,
Optional<std::string&> reasonIfUnsupported) const
{
- bool supported = true;
- std::array<DataType,4> supportedTypes =
- {
- DataType::Float32,
- DataType::Float16,
- DataType::QAsymmU8,
- DataType::QSymmS16
- };
-
- supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
- "Reference rsqrt: input type not supported");
-
- supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
- "Reference rsqrt: output type not supported");
-
- supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
- "Reference rsqrt: input and output types not matching");
-
- supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
- "Reference Rsqrt: input and output shapes have different number of total elements");
-
- return supported;
+ return IsElementwiseUnarySupported(input,
+ output,
+ ElementwiseUnaryDescriptor(UnaryOperation::Rsqrt),
+ reasonIfUnsupported);
}
bool RefLayerSupport::IsSliceSupported(const TensorInfo& input,
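
With the consolidation above, Abs and Rsqrt support queries now route through the single IsElementwiseUnarySupported entry point. A hedged sketch of such a query follows, assuming the Arm NN public headers and this backend's RefLayerSupport.hpp are reachable on the include path; the include paths and the checkRsqrtSupport helper are illustrative, not part of this change:

    #include <armnn/Descriptors.hpp>
    #include <armnn/Optional.hpp>
    #include <armnn/Tensor.hpp>
    #include <armnn/Types.hpp>

    #include <RefLayerSupport.hpp> // lives in src/backends/reference; path shown is illustrative

    #include <iostream>
    #include <string>

    // Ask the reference backend whether it can run Rsqrt on a small Float32 tensor.
    void checkRsqrtSupport(const armnn::RefLayerSupport& layerSupport)
    {
        armnn::TensorInfo info({ 2, 3 }, armnn::DataType::Float32);
        std::string reason;

        bool supported = layerSupport.IsElementwiseUnarySupported(
            info,                                                             // input
            info,                                                             // output
            armnn::ElementwiseUnaryDescriptor(armnn::UnaryOperation::Rsqrt),  // which unary op
            armnn::Optional<std::string&>(reason));

        std::cout << (supported ? "supported" : reason) << std::endl;
    }
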
diff --git a/src/backends/reference/RefLayerSupport.hpp b/src/backends/reference/RefLayerSupport.hpp
index a7d6303d86..123c2643df 100644
--- a/src/backends/reference/RefLayerSupport.hpp
+++ b/src/backends/reference/RefLayerSupport.hpp
@@ -12,6 +12,7 @@ namespace armnn
class RefLayerSupport : public LayerSupportBase
{
public:
+ ARMNN_DEPRECATED_MSG("Use IsElementwiseUnarySupported instead")
bool IsAbsSupported(const TensorInfo& input,
const TensorInfo& output,
Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
@@ -117,6 +118,11 @@ public:
const TensorInfo& output,
Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
+ bool IsElementwiseUnarySupported(const TensorInfo& input,
+ const TensorInfo& output,
+ const ElementwiseUnaryDescriptor& descriptor,
+ Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
+
ARMNN_DEPRECATED_MSG("Use IsComparisonSupported instead")
bool IsEqualSupported(const TensorInfo& input0,
const TensorInfo& input1,
@@ -247,7 +253,8 @@ public:
const TensorInfo& output,
const ResizeDescriptor& descriptor,
Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
-
+
+ ARMNN_DEPRECATED_MSG("Use IsElementwiseUnarySupported instead")
bool IsRsqrtSupported(const TensorInfo& input,
const TensorInfo& output,
Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
diff --git a/src/backends/reference/RefWorkloadFactory.cpp b/src/backends/reference/RefWorkloadFactory.cpp
index 2db47d35c2..e7a9c19fc7 100644
--- a/src/backends/reference/RefWorkloadFactory.cpp
+++ b/src/backends/reference/RefWorkloadFactory.cpp
@@ -98,7 +98,11 @@ std::unique_ptr<ITensorHandle> RefWorkloadFactory::CreateTensorHandle(const Tens
std::unique_ptr<IWorkload> RefWorkloadFactory::CreateAbs(const AbsQueueDescriptor& descriptor,
const WorkloadInfo& info) const
{
- return std::make_unique<RefAbsWorkload>(descriptor, info);
+ boost::ignore_unused(descriptor);
+ ElementwiseUnaryQueueDescriptor elementwiseUnaryDescriptor;
+ elementwiseUnaryDescriptor.m_Parameters.m_Operation = UnaryOperation::Abs;
+
+ return CreateElementwiseUnary(elementwiseUnaryDescriptor, info);
}
std::unique_ptr<IWorkload> RefWorkloadFactory::CreateActivation(const ActivationQueueDescriptor& descriptor,
@@ -221,6 +225,12 @@ std::unique_ptr<IWorkload> RefWorkloadFactory::CreateDivision(const DivisionQueu
return std::make_unique<RefDivisionWorkload>(descriptor, info);
}
+std::unique_ptr<IWorkload> RefWorkloadFactory::CreateElementwiseUnary(const ElementwiseUnaryQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const
+{
+ return std::make_unique<RefElementwiseUnaryWorkload>(descriptor, info);
+}
+
std::unique_ptr<IWorkload> RefWorkloadFactory::CreateEqual(const EqualQueueDescriptor& descriptor,
const WorkloadInfo& info) const
{
@@ -463,7 +473,11 @@ std::unique_ptr<IWorkload> RefWorkloadFactory::CreateResizeBilinear(const Resize
std::unique_ptr<IWorkload> RefWorkloadFactory::CreateRsqrt(const RsqrtQueueDescriptor& descriptor,
const WorkloadInfo& info) const
{
- return std::make_unique<RefRsqrtWorkload>(descriptor, info);
+ boost::ignore_unused(descriptor);
+ ElementwiseUnaryQueueDescriptor elementwiseUnaryDescriptor;
+ elementwiseUnaryDescriptor.m_Parameters.m_Operation = UnaryOperation::Rsqrt;
+
+ return CreateElementwiseUnary(elementwiseUnaryDescriptor, info);
}
std::unique_ptr<IWorkload> RefWorkloadFactory::CreateSlice(const SliceQueueDescriptor& descriptor,
diff --git a/src/backends/reference/RefWorkloadFactory.hpp b/src/backends/reference/RefWorkloadFactory.hpp
index 80393c3f3a..b5b9b0faf0 100644
--- a/src/backends/reference/RefWorkloadFactory.hpp
+++ b/src/backends/reference/RefWorkloadFactory.hpp
@@ -59,6 +59,7 @@ public:
DataLayout dataLayout,
const bool IsMemoryManaged = true) const override;
+ ARMNN_DEPRECATED_MSG("Use CreateElementwiseUnary instead")
std::unique_ptr<IWorkload> CreateAbs(const AbsQueueDescriptor& descriptor,
const WorkloadInfo& info) const override;
@@ -113,6 +114,9 @@ public:
std::unique_ptr<IWorkload> CreateDivision(const DivisionQueueDescriptor& descriptor,
const WorkloadInfo& info) const override;
+ std::unique_ptr<IWorkload> CreateElementwiseUnary(const ElementwiseUnaryQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const override;
+
ARMNN_DEPRECATED_MSG("Use CreateComparison instead")
std::unique_ptr<IWorkload> CreateEqual(const EqualQueueDescriptor& descriptor,
const WorkloadInfo& info) const override;
@@ -204,6 +208,7 @@ public:
std::unique_ptr<IWorkload> CreateResizeBilinear(const ResizeBilinearQueueDescriptor& descriptor,
const WorkloadInfo& info) const override;
+ ARMNN_DEPRECATED_MSG("Use CreateElementwiseUnary instead")
std::unique_ptr<IWorkload> CreateRsqrt(const RsqrtQueueDescriptor& descriptor,
const WorkloadInfo& info) const override;
diff --git a/src/backends/reference/backend.mk b/src/backends/reference/backend.mk
index 5f9af59e74..412dc9438c 100644
--- a/src/backends/reference/backend.mk
+++ b/src/backends/reference/backend.mk
@@ -21,7 +21,6 @@ BACKEND_SOURCES := \
RefWorkloadFactory.cpp \
RefRegistryInitializer.cpp \
RefTensorHandleFactory.cpp \
- workloads/Abs.cpp \
workloads/Activation.cpp \
workloads/ArgMinMax.cpp \
workloads/BatchNormImpl.cpp \
@@ -43,7 +42,6 @@ BACKEND_SOURCES := \
workloads/Pad.cpp \
workloads/Pooling2d.cpp \
workloads/PreluImpl.cpp \
- workloads/RefAbsWorkload.cpp \
workloads/RefActivationWorkload.cpp \
workloads/RefArgMinMaxWorkload.cpp \
workloads/RefBatchNormalizationWorkload.cpp \
@@ -60,6 +58,7 @@ BACKEND_SOURCES := \
workloads/RefDequantizeWorkload.cpp \
workloads/RefDetectionPostProcessWorkload.cpp \
workloads/RefElementwiseWorkload.cpp \
+ workloads/RefElementwiseUnaryWorkload.cpp \
workloads/RefFakeQuantizationFloat32Workload.cpp \
workloads/RefFloorWorkload.cpp \
workloads/RefFullyConnectedWorkload.cpp \
@@ -78,7 +77,6 @@ BACKEND_SOURCES := \
workloads/RefReshapeWorkload.cpp \
workloads/RefResizeBilinearWorkload.cpp \
workloads/RefResizeWorkload.cpp \
- workloads/RefRsqrtWorkload.cpp \
workloads/RefSliceWorkload.cpp \
workloads/RefSoftmaxWorkload.cpp \
workloads/RefSpaceToBatchNdWorkload.cpp \
@@ -88,7 +86,6 @@ BACKEND_SOURCES := \
workloads/RefSplitterWorkload.cpp \
workloads/RefTransposeConvolution2dWorkload.cpp \
workloads/Resize.cpp \
- workloads/Rsqrt.cpp \
workloads/Slice.cpp \
workloads/SpaceToBatchNd.cpp \
workloads/SpaceToDepth.cpp \
diff --git a/src/backends/reference/test/RefCreateWorkloadTests.cpp b/src/backends/reference/test/RefCreateWorkloadTests.cpp
index 23a8e9b9e9..b83d205970 100644
--- a/src/backends/reference/test/RefCreateWorkloadTests.cpp
+++ b/src/backends/reference/test/RefCreateWorkloadTests.cpp
@@ -717,41 +717,6 @@ BOOST_AUTO_TEST_CASE(CreateResizeBilinearFloat32Nhwc)
RefCreateResizeBilinearTest<RefResizeWorkload, armnn::DataType::Float32>(DataLayout::NHWC);
}
-template <typename RsqrtWorkloadType, armnn::DataType DataType>
-static void RefCreateRsqrtTest()
-{
- Graph graph;
- RefWorkloadFactory factory = GetFactory();
-
- auto workload = CreateRsqrtWorkloadTest<RsqrtWorkloadType, DataType>(factory, graph);
-
- // Checks that outputs are as we expect them (see definition of CreateRsqrtWorkloadTest).
- CheckInputOutput(std::move(workload),
- TensorInfo({ 1, 1 }, DataType),
- TensorInfo({ 1, 1 }, DataType));
-
-}
-
-BOOST_AUTO_TEST_CASE(CreateRsqrtFloat32)
-{
- RefCreateRsqrtTest<RefRsqrtWorkload, armnn::DataType::Float32>();
-}
-
-BOOST_AUTO_TEST_CASE(CreateRsqrtFloat16)
-{
- RefCreateRsqrtTest<RefRsqrtWorkload, armnn::DataType::Float16>();
-}
-
-BOOST_AUTO_TEST_CASE(CreateRsqrtUint8)
-{
- RefCreateRsqrtTest<RefRsqrtWorkload, armnn::DataType::QAsymmU8>();
-}
-
-BOOST_AUTO_TEST_CASE(CreateRsqrtQsymm16)
-{
- RefCreateRsqrtTest<RefRsqrtWorkload, armnn::DataType::QSymmS16>();
-}
-
template <typename BatchToSpaceNdWorkloadType, armnn::DataType DataType>
static void RefCreateBatchToSpaceNdTest()
{
diff --git a/src/backends/reference/test/RefEndToEndTests.cpp b/src/backends/reference/test/RefEndToEndTests.cpp
index 75eccdee88..54a68810f6 100644
--- a/src/backends/reference/test/RefEndToEndTests.cpp
+++ b/src/backends/reference/test/RefEndToEndTests.cpp
@@ -5,7 +5,6 @@
#include <backendsCommon/test/EndToEndTestImpl.hpp>
-#include <backendsCommon/test/AbsEndToEndTestImpl.hpp>
#include <backendsCommon/test/ArgMinMaxEndToEndTestImpl.hpp>
#include <backendsCommon/test/BatchToSpaceNdEndToEndTestImpl.hpp>
#include <backendsCommon/test/ComparisonEndToEndTestImpl.hpp>
@@ -13,6 +12,7 @@
#include <backendsCommon/test/DepthToSpaceEndToEndTestImpl.hpp>
#include <backendsCommon/test/DequantizeEndToEndTestImpl.hpp>
#include <backendsCommon/test/DetectionPostProcessEndToEndTestImpl.hpp>
+#include <backendsCommon/test/ElementwiseUnaryEndToEndTestImpl.hpp>
#include <backendsCommon/test/GatherEndToEndTestImpl.hpp>
#include <backendsCommon/test/InstanceNormalizationEndToEndTestImpl.hpp>
#include <backendsCommon/test/LogSoftmaxEndToEndTestImpl.hpp>
@@ -32,17 +32,43 @@ std::vector<armnn::BackendId> defaultBackends = {armnn::Compute::CpuRef};
// Abs
BOOST_AUTO_TEST_CASE(RefAbsEndToEndTestFloat32)
{
- AbsEndToEnd<armnn::DataType::Float32>(defaultBackends);
+ std::vector<float> expectedOutput =
+ {
+ 1.f, 1.f, 1.f, 1.f, 5.f, 5.f, 5.f, 5.f,
+ 3.f, 3.f, 3.f, 3.f, 4.f, 4.f, 4.f, 4.f
+ };
+
+ ElementwiseUnarySimpleEndToEnd<armnn::DataType::Float32>(defaultBackends,
+ UnaryOperation::Abs,
+ expectedOutput);
}
BOOST_AUTO_TEST_CASE(RefAbsEndToEndTestUint8)
{
- AbsEndToEnd<armnn::DataType::QAsymmU8>(defaultBackends);
+ // Note the expected output will be implicitly quantized by the below test function
+ std::vector<float> expectedOutput =
+ {
+ 1.f, 1.f, 1.f, 1.f, 5.f, 5.f, 5.f, 5.f,
+ 3.f, 3.f, 3.f, 3.f, 4.f, 4.f, 4.f, 4.f
+ };
+
+ ElementwiseUnarySimpleEndToEnd<armnn::DataType::QAsymmU8>(defaultBackends,
+ UnaryOperation::Abs,
+ expectedOutput);
}
BOOST_AUTO_TEST_CASE(RefAbsEndToEndTestInt16)
{
- AbsEndToEnd<armnn::DataType::QSymmS16>(defaultBackends);
+ // Note the expected output will be implicitly quantized by the below test function
+ std::vector<float> expectedOutput =
+ {
+ 1.f, 1.f, 1.f, 1.f, 5.f, 5.f, 5.f, 5.f,
+ 3.f, 3.f, 3.f, 3.f, 4.f, 4.f, 4.f, 4.f
+ };
+
+ ElementwiseUnarySimpleEndToEnd<armnn::DataType::QSymmS16>(defaultBackends,
+ UnaryOperation::Abs,
+ expectedOutput);
}
// Constant
diff --git a/src/backends/reference/workloads/Abs.cpp b/src/backends/reference/workloads/Abs.cpp
deleted file mode 100644
index 6a6a79ca56..0000000000
--- a/src/backends/reference/workloads/Abs.cpp
+++ /dev/null
@@ -1,23 +0,0 @@
-//
-// Copyright © 2017 Arm Ltd. All rights reserved.
-// SPDX-License-Identifier: MIT
-//
-
-#include "Abs.hpp"
-
-namespace armnn
-{
-
-void Abs(Decoder<float>& in,
- Encoder<float>& out,
- const TensorInfo& tensorInfo)
-{
- for (unsigned int i = 0u; i < tensorInfo.GetNumElements(); ++i)
- {
- out[i];
- in[i];
- out.Set(std::abs(in.Get()));
- }
-}
-
-} //namespace armnn
diff --git a/src/backends/reference/workloads/Abs.hpp b/src/backends/reference/workloads/Abs.hpp
index b1165d2d93..b05f2e3367 100644
--- a/src/backends/reference/workloads/Abs.hpp
+++ b/src/backends/reference/workloads/Abs.hpp
@@ -1,19 +1,22 @@
//
-// Copyright © 2017 Arm Ltd. All rights reserved.
+// Copyright © 2019 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
-#include "BaseIterator.hpp"
-#include <armnn/Tensor.hpp>
-#include <armnn/Types.hpp>
+#pragma once
+
+#include <iostream>
namespace armnn
{
-
-/// Performs the absolute function elementwise
-/// on the inputs to give the outputs.
-void Abs(Decoder<float>& in,
- Encoder<float>& out,
- const TensorInfo& tensorInfo);
+template<typename T>
+struct abs : public std::unary_function<T, T>
+{
+    T
+    operator()(const T& inputData) const
+    {
+        return std::abs(inputData);
+    }
+};
} //namespace armnn
diff --git a/src/backends/reference/workloads/Broadcast.cpp b/src/backends/reference/workloads/Broadcast.cpp
index 8421a0a7ed..24af0fc4b1 100644
--- a/src/backends/reference/workloads/Broadcast.cpp
+++ b/src/backends/reference/workloads/Broadcast.cpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2017 Arm Ltd. All rights reserved.
+// Copyright © 2019 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
@@ -30,4 +30,23 @@ BroadcastLoop::BroadcastLoop(const TensorShape& inShape0, const TensorShape& inS
}
}
+BroadcastLoop::BroadcastLoop(const TensorShape& inShape, const TensorShape& outShape)
+: m_DimData(outShape.GetNumDimensions())
+{
+ const unsigned int numDims = GetNumDimensions();
+
+ unsigned int sIn = 1;
+ unsigned int sOut = 1;
+
+ for (unsigned int j = numDims - 1, k = 0; k < numDims ; k++, j--)
+ {
+ m_DimData[j].m_DimSize = outShape[j];
+ m_DimData[j].m_Stride1 = (inShape[j] > 1) ? sIn : 0;
+ m_DimData[j].m_StrideOut = sOut;
+
+ sIn *= inShape[j];
+ sOut *= outShape[j];
+ }
+}
+
} // namespace armnn
diff --git a/src/backends/reference/workloads/Broadcast.hpp b/src/backends/reference/workloads/Broadcast.hpp
index 5bf6be8939..a3d944ae75 100644
--- a/src/backends/reference/workloads/Broadcast.hpp
+++ b/src/backends/reference/workloads/Broadcast.hpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2017 Arm Ltd. All rights reserved.
+// Copyright © 2019 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
@@ -15,6 +15,8 @@ struct BroadcastLoop
{
BroadcastLoop(const TensorShape& inShape0, const TensorShape& inShape1, const TensorShape& outShape);
+ BroadcastLoop(const TensorShape& inShape, const TensorShape& outShape);
+
unsigned int GetNumDimensions()
{
return static_cast<unsigned int>(m_DimData.size());
@@ -56,6 +58,37 @@ struct BroadcastLoop
outData -= outDataMovement;
}
+ template <typename Func, typename DecoderOp, typename EncoderOp>
+ void Unroll(Func operationFunc,
+ unsigned int dimension,
+ DecoderOp& inData,
+ EncoderOp& outData)
+ {
+ if (dimension >= GetNumDimensions())
+ {
+ outData.Set(operationFunc(inData.Get()));
+ return;
+ }
+
+ unsigned int inDataMovement = 0;
+ unsigned int outDataMovement = 0;
+
+ for (unsigned int i = 0; i < m_DimData[dimension].m_DimSize; i++)
+ {
+ Unroll(operationFunc, dimension + 1, inData, outData);
+
+ inData += m_DimData[dimension].m_Stride1;
+ outData += m_DimData[dimension].m_StrideOut;
+
+ inDataMovement += m_DimData[dimension].m_Stride1;
+ outDataMovement += m_DimData[dimension].m_StrideOut;
+ }
+
+ // move iterator back to the start
+ inData -= inDataMovement;
+ outData -= outDataMovement;
+ }
+
private:
// Struct to hold the dimension data.
struct BroadcastDimensionData
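
For reference, a standalone sketch of the stride-based traversal that the new single-input BroadcastLoop constructor and unary Unroll overload perform; plain std::vector buffers stand in for armnn's Decoder/Encoder, and the shapes and values are illustrative only:

    #include <cmath>
    #include <iostream>
    #include <vector>

    // Broadcast a [1, 3] input against a [2, 3] output shape and apply a unary op.
    // A stride of 0 on a broadcast dimension re-reads the same input element,
    // which is what m_Stride1 = 0 achieves in BroadcastLoop.
    int main()
    {
        const std::vector<unsigned int> inShape  = { 1, 3 };
        const std::vector<unsigned int> outShape = { 2, 3 };
        const std::vector<float> input = { 1.0f, 4.0f, 9.0f };
        std::vector<float> output(6, 0.0f);

        // Compute per-dimension strides, innermost dimension last (mirrors the new constructor).
        unsigned int sIn = 1, sOut = 1;
        std::vector<unsigned int> inStride(2), outStride(2), dimSize(2);
        for (int j = 1; j >= 0; --j)
        {
            dimSize[j]   = outShape[j];
            inStride[j]  = (inShape[j] > 1) ? sIn : 0; // 0 => broadcast this dimension
            outStride[j] = sOut;
            sIn  *= inShape[j];
            sOut *= outShape[j];
        }

        // Two nested loops stand in for the recursive Unroll over the dimensions.
        for (unsigned int i = 0; i < dimSize[0]; ++i)
        {
            for (unsigned int k = 0; k < dimSize[1]; ++k)
            {
                const unsigned int inIdx  = i * inStride[0]  + k * inStride[1];
                const unsigned int outIdx = i * outStride[0] + k * outStride[1];
                output[outIdx] = std::sqrt(input[inIdx]); // e.g. UnaryOperation::Sqrt
            }
        }

        for (float v : output) { std::cout << v << ' '; } // prints: 1 2 3 1 2 3
        std::cout << std::endl;
        return 0;
    }
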
diff --git a/src/backends/reference/workloads/CMakeLists.txt b/src/backends/reference/workloads/CMakeLists.txt
index dbbdd89fd4..6795204d59 100644
--- a/src/backends/reference/workloads/CMakeLists.txt
+++ b/src/backends/reference/workloads/CMakeLists.txt
@@ -4,7 +4,6 @@
#
list(APPEND armnnRefBackendWorkloads_sources
- Abs.cpp
Abs.hpp
ArgMinMax.cpp
ArgMinMax.hpp
@@ -33,6 +32,7 @@ list(APPEND armnnRefBackendWorkloads_sources
ElementwiseFunction.cpp
ElementwiseFunction.hpp
Encoders.hpp
+ Exp.hpp
FullyConnected.cpp
FullyConnected.hpp
Gather.cpp
@@ -55,8 +55,6 @@ list(APPEND armnnRefBackendWorkloads_sources
Pooling2d.hpp
PreluImpl.cpp
PreluImpl.hpp
- RefAbsWorkload.cpp
- RefAbsWorkload.hpp
RefActivationWorkload.cpp
RefActivationWorkload.hpp
RefArgMinMaxWorkload.cpp
@@ -89,6 +87,8 @@ list(APPEND armnnRefBackendWorkloads_sources
RefDequantizeWorkload.hpp
RefDetectionPostProcessWorkload.cpp
RefDetectionPostProcessWorkload.hpp
+ RefElementwiseUnaryWorkload.cpp
+ RefElementwiseUnaryWorkload.hpp
RefFakeQuantizationFloat32Workload.cpp
RefFakeQuantizationFloat32Workload.hpp
RefFloorWorkload.cpp
@@ -125,8 +125,6 @@ list(APPEND armnnRefBackendWorkloads_sources
RefResizeBilinearWorkload.hpp
RefResizeWorkload.cpp
RefResizeWorkload.hpp
- RefRsqrtWorkload.cpp
- RefRsqrtWorkload.hpp
RefSliceWorkload.cpp
RefSliceWorkload.hpp
RefSoftmaxWorkload.cpp
@@ -147,7 +145,6 @@ list(APPEND armnnRefBackendWorkloads_sources
RefWorkloadUtils.hpp
Resize.cpp
Resize.hpp
- Rsqrt.cpp
Rsqrt.hpp
Slice.cpp
Slice.hpp
@@ -159,6 +156,7 @@ list(APPEND armnnRefBackendWorkloads_sources
SpaceToDepth.cpp
Splitter.hpp
Splitter.cpp
+ Sqrt.hpp
Stack.cpp
Stack.hpp
StridedSlice.hpp
diff --git a/src/backends/reference/workloads/ElementwiseFunction.cpp b/src/backends/reference/workloads/ElementwiseFunction.cpp
index 888037f9a6..5687cf5861 100644
--- a/src/backends/reference/workloads/ElementwiseFunction.cpp
+++ b/src/backends/reference/workloads/ElementwiseFunction.cpp
@@ -7,36 +7,56 @@
#include "Broadcast.hpp"
#include <functional>
#include "Minimum.hpp"
-
#include "Maximum.hpp"
+#include "Abs.hpp"
+#include "Exp.hpp"
+#include "Rsqrt.hpp"
+#include "Sqrt.hpp"
+
namespace armnn
{
template <typename Functor>
-ElementwiseFunction<Functor>::ElementwiseFunction(const TensorShape& inShape0,
- const TensorShape& inShape1,
- const TensorShape& outShape,
- armnn::Decoder<InType>& inData0,
- armnn::Decoder<InType>& inData1,
- armnn::Encoder<OutType>& outData)
+ElementwiseBinaryFunction<Functor>::ElementwiseBinaryFunction(const TensorShape& inShape0,
+ const TensorShape& inShape1,
+ const TensorShape& outShape,
+ Decoder<InType>& inData0,
+ Decoder<InType>& inData1,
+ Encoder<OutType>& outData)
{
BroadcastLoop(inShape0, inShape1, outShape).Unroll(Functor(), 0, inData0, inData1, outData);
}
+template <typename Functor>
+ElementwiseUnaryFunction<Functor>::ElementwiseUnaryFunction(const TensorShape& inShape,
+ const TensorShape& outShape,
+ Decoder<InType>& inData,
+ Encoder<OutType>& outData)
+{
+ BroadcastLoop(inShape, outShape).Unroll(Functor(), 0, inData, outData);
+}
+
} //namespace armnn
-template struct armnn::ElementwiseFunction<std::plus<float>>;
-template struct armnn::ElementwiseFunction<std::minus<float>>;
-template struct armnn::ElementwiseFunction<std::multiplies<float>>;
-template struct armnn::ElementwiseFunction<std::divides<float>>;
-template struct armnn::ElementwiseFunction<armnn::maximum<float>>;
-template struct armnn::ElementwiseFunction<armnn::minimum<float>>;
+template struct armnn::ElementwiseBinaryFunction<std::plus<float>>;
+template struct armnn::ElementwiseBinaryFunction<std::minus<float>>;
+template struct armnn::ElementwiseBinaryFunction<std::multiplies<float>>;
+template struct armnn::ElementwiseBinaryFunction<std::divides<float>>;
+template struct armnn::ElementwiseBinaryFunction<armnn::maximum<float>>;
+template struct armnn::ElementwiseBinaryFunction<armnn::minimum<float>>;
// Comparison
-template struct armnn::ElementwiseFunction<std::equal_to<float>>;
-template struct armnn::ElementwiseFunction<std::greater<float>>;
-template struct armnn::ElementwiseFunction<std::greater_equal<float>>;
-template struct armnn::ElementwiseFunction<std::less<float>>;
-template struct armnn::ElementwiseFunction<std::less_equal<float>>;
-template struct armnn::ElementwiseFunction<std::not_equal_to<float>>;
+template struct armnn::ElementwiseBinaryFunction<std::equal_to<float>>;
+template struct armnn::ElementwiseBinaryFunction<std::greater<float>>;
+template struct armnn::ElementwiseBinaryFunction<std::greater_equal<float>>;
+template struct armnn::ElementwiseBinaryFunction<std::less<float>>;
+template struct armnn::ElementwiseBinaryFunction<std::less_equal<float>>;
+template struct armnn::ElementwiseBinaryFunction<std::not_equal_to<float>>;
+
+// Unary
+template struct armnn::ElementwiseUnaryFunction<armnn::abs<float>>;
+template struct armnn::ElementwiseUnaryFunction<armnn::exp<float>>;
+template struct armnn::ElementwiseUnaryFunction<std::negate<float>>;
+template struct armnn::ElementwiseUnaryFunction<armnn::rsqrt<float>>;
+template struct armnn::ElementwiseUnaryFunction<armnn::sqrt<float>>;
diff --git a/src/backends/reference/workloads/ElementwiseFunction.hpp b/src/backends/reference/workloads/ElementwiseFunction.hpp
index fd1fab0690..8259ba5ac7 100644
--- a/src/backends/reference/workloads/ElementwiseFunction.hpp
+++ b/src/backends/reference/workloads/ElementwiseFunction.hpp
@@ -12,17 +12,29 @@ namespace armnn
{
template <typename Functor>
-struct ElementwiseFunction
+struct ElementwiseBinaryFunction
{
using OutType = typename Functor::result_type;
using InType = typename Functor::first_argument_type;
- ElementwiseFunction(const TensorShape& inShape0,
- const TensorShape& inShape1,
- const TensorShape& outShape,
- armnn::Decoder<InType>& inData0,
- armnn::Decoder<InType>& inData1,
- armnn::Encoder<OutType>& outData);
+ ElementwiseBinaryFunction(const TensorShape& inShape0,
+ const TensorShape& inShape1,
+ const TensorShape& outShape,
+ Decoder<InType>& inData0,
+ Decoder<InType>& inData1,
+ Encoder<OutType>& outData);
+};
+
+template <typename Functor>
+struct ElementwiseUnaryFunction
+{
+ using OutType = typename Functor::result_type;
+ using InType = typename Functor::argument_type;
+
+ ElementwiseUnaryFunction(const TensorShape& inShape,
+ const TensorShape& outShape,
+ Decoder<InType>& inData,
+ Encoder<OutType>& outData);
};
} //namespace armnn
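
The unary functor headers introduced below (Abs.hpp, Exp.hpp, Rsqrt.hpp, Sqrt.hpp) feed ElementwiseUnaryFunction through the argument_type/result_type typedefs that std::unary_function provides. A standalone sketch of that mechanism, using an illustrative rsqrt_like functor rather than the real armnn::rsqrt (std::unary_function is deprecated in C++11 and removed in C++17, so compile as C++14 or earlier):

    #include <cmath>
    #include <functional>
    #include <iostream>
    #include <type_traits>

    // Stand-in for armnn::rsqrt<float>: deriving from std::unary_function<T, T>
    // supplies the typedefs that ElementwiseUnaryFunction uses as InType/OutType.
    template <typename T>
    struct rsqrt_like : public std::unary_function<T, T>
    {
        T operator()(const T& x) const { return static_cast<T>(1) / std::sqrt(x); }
    };

    int main()
    {
        using InType  = rsqrt_like<float>::argument_type;  // what the Decoder produces
        using OutType = rsqrt_like<float>::result_type;    // what the Encoder consumes

        static_assert(std::is_same<InType,  float>::value, "decoder element type");
        static_assert(std::is_same<OutType, float>::value, "encoder element type");

        std::cout << rsqrt_like<float>()(4.0f) << std::endl; // prints 0.5
        return 0;
    }
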
diff --git a/src/backends/reference/workloads/Exp.hpp b/src/backends/reference/workloads/Exp.hpp
new file mode 100644
index 0000000000..1a046728ba
--- /dev/null
+++ b/src/backends/reference/workloads/Exp.hpp
@@ -0,0 +1,22 @@
+//
+// Copyright © 2019 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include <iostream>
+
+namespace armnn
+{
+template<typename T>
+struct exp : public std::unary_function<T, T>
+{
+    T
+    operator()(const T& inputData) const
+    {
+        return std::exp(inputData);
+    }
+};
+
+} //namespace armnn
diff --git a/src/backends/reference/workloads/RefAbsWorkload.cpp b/src/backends/reference/workloads/RefAbsWorkload.cpp
deleted file mode 100644
index 5c1f8c0c69..0000000000
--- a/src/backends/reference/workloads/RefAbsWorkload.cpp
+++ /dev/null
@@ -1,37 +0,0 @@
-//
-// Copyright © 2017 Arm Ltd. All rights reserved.
-// SPDX-License-Identifier: MIT
-//
-
-#include "RefAbsWorkload.hpp"
-
-#include "Abs.hpp"
-#include "Decoders.hpp"
-#include "Encoders.hpp"
-#include "RefWorkloadUtils.hpp"
-
-#include <Profiling.hpp>
-
-namespace armnn
-{
-
-void RefAbsWorkload::Execute() const
-{
- ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefAbsWorkload_Execute");
-
- const TensorInfo& inputTensorInfo = GetTensorInfo(m_Data.m_Inputs[0]);
-
- std::unique_ptr<Decoder<float>> decoderPtr = MakeDecoder<float>(inputTensorInfo, m_Data.m_Inputs[0]->Map());
- Decoder<float>& decoder = *decoderPtr;
-
- const TensorInfo& outputTensorInfo = GetTensorInfo(m_Data.m_Outputs[0]);
-
- std::unique_ptr<Encoder<float>> encoderPtr = MakeEncoder<float>(outputTensorInfo, m_Data.m_Outputs[0]->Map());
- Encoder<float>& encoder = *encoderPtr;
-
- Abs(decoder,
- encoder,
- inputTensorInfo);
-}
-
-} //namespace armnn
diff --git a/src/backends/reference/workloads/RefAbsWorkload.hpp b/src/backends/reference/workloads/RefAbsWorkload.hpp
deleted file mode 100644
index 68105556d5..0000000000
--- a/src/backends/reference/workloads/RefAbsWorkload.hpp
+++ /dev/null
@@ -1,21 +0,0 @@
-//
-// Copyright © 2017 Arm Ltd. All rights reserved.
-// SPDX-License-Identifier: MIT
-//
-
-#pragma once
-
-#include <backendsCommon/Workload.hpp>
-#include <backendsCommon/WorkloadData.hpp>
-
-namespace armnn
-{
-
-class RefAbsWorkload : public BaseWorkload<AbsQueueDescriptor>
-{
-public:
- using BaseWorkload<AbsQueueDescriptor>::BaseWorkload;
- virtual void Execute() const override;
-};
-
-} //namespace armnn
diff --git a/src/backends/reference/workloads/RefComparisonWorkload.cpp b/src/backends/reference/workloads/RefComparisonWorkload.cpp
index 60446226be..52ad9a2879 100644
--- a/src/backends/reference/workloads/RefComparisonWorkload.cpp
+++ b/src/backends/reference/workloads/RefComparisonWorkload.cpp
@@ -52,12 +52,12 @@ void RefComparisonWorkload::Execute() const
m_Input1->Reset(m_Data.m_Inputs[1]->Map());
m_Output->Reset(m_Data.m_Outputs[0]->Map());
- using EqualFunction = ElementwiseFunction<std::equal_to<InType>>;
- using GreaterFunction = ElementwiseFunction<std::greater<InType>>;
- using GreaterOrEqualFunction = ElementwiseFunction<std::greater_equal<InType>>;
- using LessFunction = ElementwiseFunction<std::less<InType>>;
- using LessOrEqualFunction = ElementwiseFunction<std::less_equal<InType>>;
- using NotEqualFunction = ElementwiseFunction<std::not_equal_to<InType>>;
+ using EqualFunction = ElementwiseBinaryFunction<std::equal_to<InType>>;
+ using GreaterFunction = ElementwiseBinaryFunction<std::greater<InType>>;
+ using GreaterOrEqualFunction = ElementwiseBinaryFunction<std::greater_equal<InType>>;
+ using LessFunction = ElementwiseBinaryFunction<std::less<InType>>;
+ using LessOrEqualFunction = ElementwiseBinaryFunction<std::less_equal<InType>>;
+ using NotEqualFunction = ElementwiseBinaryFunction<std::not_equal_to<InType>>;
switch (m_Data.m_Parameters.m_Operation)
{
diff --git a/src/backends/reference/workloads/RefElementwiseUnaryWorkload.cpp b/src/backends/reference/workloads/RefElementwiseUnaryWorkload.cpp
new file mode 100644
index 0000000000..4fbb0d123f
--- /dev/null
+++ b/src/backends/reference/workloads/RefElementwiseUnaryWorkload.cpp
@@ -0,0 +1,95 @@
+//
+// Copyright © 2019 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "RefElementwiseUnaryWorkload.hpp"
+
+#include "Decoders.hpp"
+#include "ElementwiseFunction.hpp"
+#include "Encoders.hpp"
+#include "RefWorkloadUtils.hpp"
+#include "Abs.hpp"
+#include "Exp.hpp"
+#include "Rsqrt.hpp"
+#include "Sqrt.hpp"
+
+#include <Profiling.hpp>
+
+#include <armnn/TypesUtils.hpp>
+
+#include <functional>
+
+namespace armnn
+{
+
+RefElementwiseUnaryWorkload::RefElementwiseUnaryWorkload(const ElementwiseUnaryQueueDescriptor& desc,
+ const WorkloadInfo& info)
+ : BaseWorkload<ElementwiseUnaryQueueDescriptor>(desc, info)
+{}
+
+void RefElementwiseUnaryWorkload::PostAllocationConfigure()
+{
+ const TensorInfo& inputInfo = GetTensorInfo(m_Data.m_Inputs[0]);
+ const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]);
+
+ m_Input = MakeDecoder<InType>(inputInfo);
+
+ m_Output = MakeEncoder<OutType>(outputInfo);
+}
+
+void RefElementwiseUnaryWorkload::Execute() const
+{
+ ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefElementwiseUnaryWorkload_Execute");
+
+ const TensorInfo& inputInfo = GetTensorInfo(m_Data.m_Inputs[0]);
+ const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]);
+
+ const TensorShape& inShape = inputInfo.GetShape();
+ const TensorShape& outShape = outputInfo.GetShape();
+
+ m_Input->Reset(m_Data.m_Inputs[0]->Map());
+ m_Output->Reset(m_Data.m_Outputs[0]->Map());
+
+ using AbsFunction = ElementwiseUnaryFunction<abs<InType>>;
+ using ExpFunction = ElementwiseUnaryFunction<exp<InType>>;
+ using NegFunction = ElementwiseUnaryFunction<std::negate<InType>>;
+ using RsqrtFunction = ElementwiseUnaryFunction<rsqrt<InType>>;
+ using SqrtFunction = ElementwiseUnaryFunction<sqrt<InType>>;
+
+ switch (m_Data.m_Parameters.m_Operation)
+ {
+ case UnaryOperation::Abs:
+ {
+ AbsFunction(inShape, outShape, *m_Input, *m_Output);
+ break;
+ }
+ case UnaryOperation::Exp:
+ {
+ ExpFunction(inShape, outShape, *m_Input, *m_Output);
+ break;
+ }
+ case UnaryOperation::Neg:
+ {
+ NegFunction(inShape, outShape, *m_Input, *m_Output);
+ break;
+ }
+ case UnaryOperation::Rsqrt:
+ {
+ RsqrtFunction(inShape, outShape, *m_Input, *m_Output);
+ break;
+ }
+ case UnaryOperation::Sqrt:
+ {
+ SqrtFunction(inShape, outShape, *m_Input, *m_Output);
+ break;
+ }
+ default:
+ {
+ throw InvalidArgumentException(std::string("Unsupported unary operation ") +
+ GetUnaryOperationAsCString(m_Data.m_Parameters.m_Operation), CHECK_LOCATION());
+ }
+ }
+}
+
+} // namespace armnn
diff --git a/src/backends/reference/workloads/RefElementwiseUnaryWorkload.hpp b/src/backends/reference/workloads/RefElementwiseUnaryWorkload.hpp
new file mode 100644
index 0000000000..efb2865ebd
--- /dev/null
+++ b/src/backends/reference/workloads/RefElementwiseUnaryWorkload.hpp
@@ -0,0 +1,33 @@
+//
+// Copyright © 2019 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include "BaseIterator.hpp"
+
+#include <backendsCommon/Workload.hpp>
+#include <backendsCommon/WorkloadData.hpp>
+
+namespace armnn
+{
+
+class RefElementwiseUnaryWorkload : public BaseWorkload<ElementwiseUnaryQueueDescriptor>
+{
+public:
+ using BaseWorkload<ElementwiseUnaryQueueDescriptor>::m_Data;
+
+ RefElementwiseUnaryWorkload(const ElementwiseUnaryQueueDescriptor& descriptor, const WorkloadInfo& info);
+ void PostAllocationConfigure() override;
+ void Execute() const override;
+
+private:
+ using InType = float;
+ using OutType = float;
+
+ std::unique_ptr<Decoder<InType>> m_Input;
+ std::unique_ptr<Encoder<OutType>> m_Output;
+};
+
+} // namespace armnn
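
RefElementwiseUnaryWorkload above fixes InType and OutType to float because the Decoder/Encoder pair dequantizes inputs and requantizes results around the functor. A standalone sketch of that round trip for a QAsymmU8-style tensor; the scale and offset values are made up for illustration:

    #include <algorithm>
    #include <cmath>
    #include <cstdint>
    #include <iostream>

    int main()
    {
        const float   scale  = 0.1f;
        const int32_t offset = 128;

        const uint8_t quantizedIn = 88;                                                   // encodes -4.0f
        const float   dequantized = scale * (static_cast<int32_t>(quantizedIn) - offset); // Decoder step
        const float   result      = std::abs(dequantized);                                // UnaryOperation::Abs

        const int32_t requantized  = static_cast<int32_t>(std::round(result / scale)) + offset; // Encoder step
        const uint8_t quantizedOut = static_cast<uint8_t>(std::min<int32_t>(255, std::max<int32_t>(0, requantized)));

        std::cout << static_cast<int>(quantizedOut) << std::endl; // prints 168, i.e. +4.0f
        return 0;
    }
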
diff --git a/src/backends/reference/workloads/RefElementwiseWorkload.cpp b/src/backends/reference/workloads/RefElementwiseWorkload.cpp
index 7e02f032ef..18bf0a7ad9 100644
--- a/src/backends/reference/workloads/RefElementwiseWorkload.cpp
+++ b/src/backends/reference/workloads/RefElementwiseWorkload.cpp
@@ -53,12 +53,12 @@ void RefElementwiseWorkload<Functor, ParentDescriptor, DebugString>::Execute() c
m_Input1->Reset(m_Data.m_Inputs[1]->Map());
m_Output->Reset(m_Data.m_Outputs[0]->Map());
- ElementwiseFunction<Functor>(inShape0,
- inShape1,
- outShape,
- *m_Input0,
- *m_Input1,
- *m_Output);
+ ElementwiseBinaryFunction<Functor>(inShape0,
+ inShape1,
+ outShape,
+ *m_Input0,
+ *m_Input1,
+ *m_Output);
}
} //namespace armnn
diff --git a/src/backends/reference/workloads/RefElementwiseWorkload.hpp b/src/backends/reference/workloads/RefElementwiseWorkload.hpp
index ee0d80b172..264ddce2de 100644
--- a/src/backends/reference/workloads/RefElementwiseWorkload.hpp
+++ b/src/backends/reference/workloads/RefElementwiseWorkload.hpp
@@ -21,8 +21,8 @@ template <typename Functor, typename ParentDescriptor, typename armnn::StringMap
class RefElementwiseWorkload : public BaseWorkload<ParentDescriptor>
{
public:
- using InType = typename ElementwiseFunction<Functor>::InType;
- using OutType = typename ElementwiseFunction<Functor>::OutType;
+ using InType = typename ElementwiseBinaryFunction<Functor>::InType;
+ using OutType = typename ElementwiseBinaryFunction<Functor>::OutType;
using BaseWorkload<ParentDescriptor>::m_Data;
RefElementwiseWorkload(const ParentDescriptor& descriptor, const WorkloadInfo& info);
diff --git a/src/backends/reference/workloads/RefRsqrtWorkload.cpp b/src/backends/reference/workloads/RefRsqrtWorkload.cpp
deleted file mode 100644
index fd6b9a3549..0000000000
--- a/src/backends/reference/workloads/RefRsqrtWorkload.cpp
+++ /dev/null
@@ -1,37 +0,0 @@
-//
-// Copyright © 2017 Arm Ltd. All rights reserved.
-// SPDX-License-Identifier: MIT
-//
-
-#include "RefRsqrtWorkload.hpp"
-
-#include "Decoders.hpp"
-#include "Encoders.hpp"
-#include "RefWorkloadUtils.hpp"
-#include "Rsqrt.hpp"
-
-#include <Profiling.hpp>
-
-namespace armnn
-{
-
-void RefRsqrtWorkload::Execute() const
-{
- ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefRsqrtWorkload_Execute");
-
- const TensorInfo& inputTensorInfo = GetTensorInfo(m_Data.m_Inputs[0]);
-
- std::unique_ptr<Decoder<float>> decoderPtr = MakeDecoder<float>(inputTensorInfo, m_Data.m_Inputs[0]->Map());
- Decoder<float>& decoder = *decoderPtr;
-
- const TensorInfo& outputTensorInfo = GetTensorInfo(m_Data.m_Outputs[0]);
-
- std::unique_ptr<Encoder<float>> encoderPtr = MakeEncoder<float>(outputTensorInfo, m_Data.m_Outputs[0]->Map());
- Encoder<float>& encoder = *encoderPtr;
-
- Rsqrt(decoder,
- encoder,
- GetTensorInfo(m_Data.m_Inputs[0]));
-}
-
-} //namespace armnn
diff --git a/src/backends/reference/workloads/RefRsqrtWorkload.hpp b/src/backends/reference/workloads/RefRsqrtWorkload.hpp
deleted file mode 100644
index 6c8ad5bc60..0000000000
--- a/src/backends/reference/workloads/RefRsqrtWorkload.hpp
+++ /dev/null
@@ -1,21 +0,0 @@
-//
-// Copyright © 2017 Arm Ltd. All rights reserved.
-// SPDX-License-Identifier: MIT
-//
-
-#pragma once
-
-#include <backendsCommon/Workload.hpp>
-#include <backendsCommon/WorkloadData.hpp>
-
-namespace armnn
-{
-
-class RefRsqrtWorkload : public BaseWorkload<RsqrtQueueDescriptor>
-{
-public:
- using BaseWorkload<RsqrtQueueDescriptor>::BaseWorkload;
- virtual void Execute() const override;
-};
-
-} //namespace armnn
diff --git a/src/backends/reference/workloads/RefWorkloads.hpp b/src/backends/reference/workloads/RefWorkloads.hpp
index 1f9ad4a19a..7034b67aa5 100644
--- a/src/backends/reference/workloads/RefWorkloads.hpp
+++ b/src/backends/reference/workloads/RefWorkloads.hpp
@@ -5,7 +5,6 @@
#pragma once
-#include "Abs.hpp"
#include "Activation.hpp"
#include "ArgMinMax.hpp"
#include "BatchNormImpl.hpp"
@@ -15,7 +14,6 @@
#include "FullyConnected.hpp"
#include "Gather.hpp"
#include "Pooling2d.hpp"
-#include "RefAbsWorkload.hpp"
#include "RefActivationWorkload.hpp"
#include "RefArgMinMaxWorkload.hpp"
#include "RefBatchNormalizationWorkload.hpp"
@@ -33,6 +31,7 @@
#include "RefDetectionPostProcessWorkload.hpp"
#include "RefDequantizeWorkload.hpp"
#include "RefElementwiseWorkload.hpp"
+#include "RefElementwiseUnaryWorkload.hpp"
#include "RefFullyConnectedWorkload.hpp"
#include "RefFloorWorkload.hpp"
#include "RefFakeQuantizationFloat32Workload.hpp"
@@ -51,7 +50,6 @@
#include "RefReshapeWorkload.hpp"
#include "RefResizeBilinearWorkload.hpp"
#include "RefResizeWorkload.hpp"
-#include "RefRsqrtWorkload.hpp"
#include "RefSliceWorkload.hpp"
#include "RefSplitterWorkload.hpp"
#include "RefSoftmaxWorkload.hpp"
diff --git a/src/backends/reference/workloads/Rsqrt.cpp b/src/backends/reference/workloads/Rsqrt.cpp
deleted file mode 100644
index 5abc2c8f7b..0000000000
--- a/src/backends/reference/workloads/Rsqrt.cpp
+++ /dev/null
@@ -1,25 +0,0 @@
-//
-// Copyright © 2017 Arm Ltd. All rights reserved.
-// SPDX-License-Identifier: MIT
-//
-
-#include "Rsqrt.hpp"
-
-#include <cmath>
-
-namespace armnn
-{
-
-void Rsqrt(Decoder<float>& in,
- Encoder<float>& out,
- const TensorInfo& tensorInfo)
-{
- for (unsigned int i = 0; i < tensorInfo.GetNumElements(); ++i)
- {
- out[i];
- in[i];
- out.Set(1.f / sqrtf(in.Get()));
- }
-}
-
-} //namespace armnn
\ No newline at end of file
diff --git a/src/backends/reference/workloads/Rsqrt.hpp b/src/backends/reference/workloads/Rsqrt.hpp
index ffc6b18d13..47ebcf36f6 100644
--- a/src/backends/reference/workloads/Rsqrt.hpp
+++ b/src/backends/reference/workloads/Rsqrt.hpp
@@ -1,19 +1,22 @@
//
-// Copyright © 2017 Arm Ltd. All rights reserved.
+// Copyright © 2019 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
-#include "BaseIterator.hpp"
-#include <armnn/Tensor.hpp>
-#include <armnn/Types.hpp>
+#pragma once
+
+#include <iostream>
namespace armnn
{
-
-/// Performs the reciprocal squareroot function elementwise
-/// on the inputs to give the outputs.
-void Rsqrt(Decoder<float>& in,
- Encoder<float>& out,
- const TensorInfo& tensorInfo);
+template<typename T>
+struct rsqrt : public std::unary_function<T, T>
+{
+    T
+    operator()(const T& inputData) const
+    {
+        return 1 / std::sqrt(inputData);
+    }
+};
} //namespace armnn
diff --git a/src/backends/reference/workloads/Sqrt.hpp b/src/backends/reference/workloads/Sqrt.hpp
new file mode 100644
index 0000000000..e4ff6a4829
--- /dev/null
+++ b/src/backends/reference/workloads/Sqrt.hpp
@@ -0,0 +1,22 @@
+//
+// Copyright © 2019 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include <iostream>
+
+namespace armnn
+{
+template<typename T>
+struct sqrt : public std::unary_function<T, T>
+{
+    T
+    operator()(const T& inputData) const
+    {
+        return std::sqrt(inputData);
+    }
+};
+
+} //namespace armnn