aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference/workloads
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/reference/workloads')
-rw-r--r--src/backends/reference/workloads/Abs.cpp23
-rw-r--r--src/backends/reference/workloads/Abs.hpp23
-rw-r--r--src/backends/reference/workloads/Broadcast.cpp21
-rw-r--r--src/backends/reference/workloads/Broadcast.hpp35
-rw-r--r--src/backends/reference/workloads/CMakeLists.txt10
-rw-r--r--src/backends/reference/workloads/ElementwiseFunction.cpp58
-rw-r--r--src/backends/reference/workloads/ElementwiseFunction.hpp26
-rw-r--r--src/backends/reference/workloads/Exp.hpp22
-rw-r--r--src/backends/reference/workloads/RefAbsWorkload.cpp37
-rw-r--r--src/backends/reference/workloads/RefAbsWorkload.hpp21
-rw-r--r--src/backends/reference/workloads/RefComparisonWorkload.cpp12
-rw-r--r--src/backends/reference/workloads/RefElementwiseUnaryWorkload.cpp95
-rw-r--r--src/backends/reference/workloads/RefElementwiseUnaryWorkload.hpp33
-rw-r--r--src/backends/reference/workloads/RefElementwiseWorkload.cpp12
-rw-r--r--src/backends/reference/workloads/RefElementwiseWorkload.hpp4
-rw-r--r--src/backends/reference/workloads/RefRsqrtWorkload.cpp37
-rw-r--r--src/backends/reference/workloads/RefRsqrtWorkload.hpp21
-rw-r--r--src/backends/reference/workloads/RefWorkloads.hpp4
-rw-r--r--src/backends/reference/workloads/Rsqrt.cpp25
-rw-r--r--src/backends/reference/workloads/Rsqrt.hpp23
-rw-r--r--src/backends/reference/workloads/Sqrt.hpp22
21 files changed, 329 insertions, 235 deletions
diff --git a/src/backends/reference/workloads/Abs.cpp b/src/backends/reference/workloads/Abs.cpp
deleted file mode 100644
index 6a6a79ca56..0000000000
--- a/src/backends/reference/workloads/Abs.cpp
+++ /dev/null
@@ -1,23 +0,0 @@
-//
-// Copyright © 2017 Arm Ltd. All rights reserved.
-// SPDX-License-Identifier: MIT
-//
-
-#include "Abs.hpp"
-
-namespace armnn
-{
-
-void Abs(Decoder<float>& in,
- Encoder<float>& out,
- const TensorInfo& tensorInfo)
-{
- for (unsigned int i = 0u; i < tensorInfo.GetNumElements(); ++i)
- {
- out[i];
- in[i];
- out.Set(std::abs(in.Get()));
- }
-}
-
-} //namespace armnn
diff --git a/src/backends/reference/workloads/Abs.hpp b/src/backends/reference/workloads/Abs.hpp
index b1165d2d93..b05f2e3367 100644
--- a/src/backends/reference/workloads/Abs.hpp
+++ b/src/backends/reference/workloads/Abs.hpp
@@ -1,19 +1,22 @@
//
-// Copyright © 2017 Arm Ltd. All rights reserved.
+// Copyright © 2019 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
-#include "BaseIterator.hpp"
-#include <armnn/Tensor.hpp>
-#include <armnn/Types.hpp>
+#pragma once
+
+#include <iostream>
namespace armnn
{
-
-/// Performs the absolute function elementwise
-/// on the inputs to give the outputs.
-void Abs(Decoder<float>& in,
- Encoder<float>& out,
- const TensorInfo& tensorInfo);
+ template<typename T>
+struct abs : public std::unary_function<T, T>
+ {
+ T
+ operator () (const T& inputData) const
+ {
+ return std::abs(inputData);
+ }
+ };
} //namespace armnn
diff --git a/src/backends/reference/workloads/Broadcast.cpp b/src/backends/reference/workloads/Broadcast.cpp
index 8421a0a7ed..24af0fc4b1 100644
--- a/src/backends/reference/workloads/Broadcast.cpp
+++ b/src/backends/reference/workloads/Broadcast.cpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2017 Arm Ltd. All rights reserved.
+// Copyright © 2019 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
@@ -30,4 +30,23 @@ BroadcastLoop::BroadcastLoop(const TensorShape& inShape0, const TensorShape& inS
}
}
+BroadcastLoop::BroadcastLoop(const TensorShape& inShape, const TensorShape& outShape)
+: m_DimData(outShape.GetNumDimensions())
+{
+ const unsigned int numDims = GetNumDimensions();
+
+ unsigned int sIn = 1;
+ unsigned int sOut = 1;
+
+ for (unsigned int j = numDims - 1, k = 0; k < numDims ; k++, j--)
+ {
+ m_DimData[j].m_DimSize = outShape[j];
+ m_DimData[j].m_Stride1 = (inShape[j] > 1) ? sIn : 0;
+ m_DimData[j].m_StrideOut = sOut;
+
+ sIn *= inShape[j];
+ sOut *= outShape[j];
+ }
+}
+
} // namespace armnn
diff --git a/src/backends/reference/workloads/Broadcast.hpp b/src/backends/reference/workloads/Broadcast.hpp
index 5bf6be8939..a3d944ae75 100644
--- a/src/backends/reference/workloads/Broadcast.hpp
+++ b/src/backends/reference/workloads/Broadcast.hpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2017 Arm Ltd. All rights reserved.
+// Copyright © 2019 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
@@ -15,6 +15,8 @@ struct BroadcastLoop
{
BroadcastLoop(const TensorShape& inShape0, const TensorShape& inShape1, const TensorShape& outShape);
+ BroadcastLoop(const TensorShape& inShape, const TensorShape& outShape);
+
unsigned int GetNumDimensions()
{
return static_cast<unsigned int>(m_DimData.size());
@@ -56,6 +58,37 @@ struct BroadcastLoop
outData -= outDataMovement;
}
+ template <typename Func, typename DecoderOp, typename EncoderOp>
+ void Unroll(Func operationFunc,
+ unsigned int dimension,
+ DecoderOp& inData,
+ EncoderOp& outData)
+ {
+ if (dimension >= GetNumDimensions())
+ {
+ outData.Set(operationFunc(inData.Get()));
+ return;
+ }
+
+ unsigned int inDataMovement = 0;
+ unsigned int outDataMovement = 0;
+
+ for (unsigned int i = 0; i < m_DimData[dimension].m_DimSize; i++)
+ {
+ Unroll(operationFunc, dimension + 1, inData, outData);
+
+ inData += m_DimData[dimension].m_Stride1;
+ outData += m_DimData[dimension].m_StrideOut;
+
+ inDataMovement += m_DimData[dimension].m_Stride1;
+ outDataMovement += m_DimData[dimension].m_StrideOut;
+ }
+
+ // move iterator back to the start
+ inData -= inDataMovement;
+ outData -= outDataMovement;
+ }
+
private:
// Struct to hold the dimension data.
struct BroadcastDimensionData
diff --git a/src/backends/reference/workloads/CMakeLists.txt b/src/backends/reference/workloads/CMakeLists.txt
index dbbdd89fd4..6795204d59 100644
--- a/src/backends/reference/workloads/CMakeLists.txt
+++ b/src/backends/reference/workloads/CMakeLists.txt
@@ -4,7 +4,6 @@
#
list(APPEND armnnRefBackendWorkloads_sources
- Abs.cpp
Abs.hpp
ArgMinMax.cpp
ArgMinMax.hpp
@@ -33,6 +32,7 @@ list(APPEND armnnRefBackendWorkloads_sources
ElementwiseFunction.cpp
ElementwiseFunction.hpp
Encoders.hpp
+ Exp.hpp
FullyConnected.cpp
FullyConnected.hpp
Gather.cpp
@@ -55,8 +55,6 @@ list(APPEND armnnRefBackendWorkloads_sources
Pooling2d.hpp
PreluImpl.cpp
PreluImpl.hpp
- RefAbsWorkload.cpp
- RefAbsWorkload.hpp
RefActivationWorkload.cpp
RefActivationWorkload.hpp
RefArgMinMaxWorkload.cpp
@@ -89,6 +87,8 @@ list(APPEND armnnRefBackendWorkloads_sources
RefDequantizeWorkload.hpp
RefDetectionPostProcessWorkload.cpp
RefDetectionPostProcessWorkload.hpp
+ RefElementwiseUnaryWorkload.cpp
+ RefElementwiseUnaryWorkload.hpp
RefFakeQuantizationFloat32Workload.cpp
RefFakeQuantizationFloat32Workload.hpp
RefFloorWorkload.cpp
@@ -125,8 +125,6 @@ list(APPEND armnnRefBackendWorkloads_sources
RefResizeBilinearWorkload.hpp
RefResizeWorkload.cpp
RefResizeWorkload.hpp
- RefRsqrtWorkload.cpp
- RefRsqrtWorkload.hpp
RefSliceWorkload.cpp
RefSliceWorkload.hpp
RefSoftmaxWorkload.cpp
@@ -147,7 +145,6 @@ list(APPEND armnnRefBackendWorkloads_sources
RefWorkloadUtils.hpp
Resize.cpp
Resize.hpp
- Rsqrt.cpp
Rsqrt.hpp
Slice.cpp
Slice.hpp
@@ -159,6 +156,7 @@ list(APPEND armnnRefBackendWorkloads_sources
SpaceToDepth.cpp
Splitter.hpp
Splitter.cpp
+ Sqrt.hpp
Stack.cpp
Stack.hpp
StridedSlice.hpp
diff --git a/src/backends/reference/workloads/ElementwiseFunction.cpp b/src/backends/reference/workloads/ElementwiseFunction.cpp
index 888037f9a6..5687cf5861 100644
--- a/src/backends/reference/workloads/ElementwiseFunction.cpp
+++ b/src/backends/reference/workloads/ElementwiseFunction.cpp
@@ -7,36 +7,56 @@
#include "Broadcast.hpp"
#include <functional>
#include "Minimum.hpp"
-
#include "Maximum.hpp"
+#include "Abs.hpp"
+#include "Exp.hpp"
+#include "Rsqrt.hpp"
+#include "Sqrt.hpp"
+
namespace armnn
{
template <typename Functor>
-ElementwiseFunction<Functor>::ElementwiseFunction(const TensorShape& inShape0,
- const TensorShape& inShape1,
- const TensorShape& outShape,
- armnn::Decoder<InType>& inData0,
- armnn::Decoder<InType>& inData1,
- armnn::Encoder<OutType>& outData)
+ElementwiseBinaryFunction<Functor>::ElementwiseBinaryFunction(const TensorShape& inShape0,
+ const TensorShape& inShape1,
+ const TensorShape& outShape,
+ Decoder<InType>& inData0,
+ Decoder<InType>& inData1,
+ Encoder<OutType>& outData)
{
BroadcastLoop(inShape0, inShape1, outShape).Unroll(Functor(), 0, inData0, inData1, outData);
}
+template <typename Functor>
+ElementwiseUnaryFunction<Functor>::ElementwiseUnaryFunction(const TensorShape& inShape,
+ const TensorShape& outShape,
+ Decoder<InType>& inData,
+ Encoder<OutType>& outData)
+{
+ BroadcastLoop(inShape, outShape).Unroll(Functor(), 0, inData, outData);
+}
+
} //namespace armnn
-template struct armnn::ElementwiseFunction<std::plus<float>>;
-template struct armnn::ElementwiseFunction<std::minus<float>>;
-template struct armnn::ElementwiseFunction<std::multiplies<float>>;
-template struct armnn::ElementwiseFunction<std::divides<float>>;
-template struct armnn::ElementwiseFunction<armnn::maximum<float>>;
-template struct armnn::ElementwiseFunction<armnn::minimum<float>>;
+template struct armnn::ElementwiseBinaryFunction<std::plus<float>>;
+template struct armnn::ElementwiseBinaryFunction<std::minus<float>>;
+template struct armnn::ElementwiseBinaryFunction<std::multiplies<float>>;
+template struct armnn::ElementwiseBinaryFunction<std::divides<float>>;
+template struct armnn::ElementwiseBinaryFunction<armnn::maximum<float>>;
+template struct armnn::ElementwiseBinaryFunction<armnn::minimum<float>>;
// Comparison
-template struct armnn::ElementwiseFunction<std::equal_to<float>>;
-template struct armnn::ElementwiseFunction<std::greater<float>>;
-template struct armnn::ElementwiseFunction<std::greater_equal<float>>;
-template struct armnn::ElementwiseFunction<std::less<float>>;
-template struct armnn::ElementwiseFunction<std::less_equal<float>>;
-template struct armnn::ElementwiseFunction<std::not_equal_to<float>>;
+template struct armnn::ElementwiseBinaryFunction<std::equal_to<float>>;
+template struct armnn::ElementwiseBinaryFunction<std::greater<float>>;
+template struct armnn::ElementwiseBinaryFunction<std::greater_equal<float>>;
+template struct armnn::ElementwiseBinaryFunction<std::less<float>>;
+template struct armnn::ElementwiseBinaryFunction<std::less_equal<float>>;
+template struct armnn::ElementwiseBinaryFunction<std::not_equal_to<float>>;
+
+// Unary
+template struct armnn::ElementwiseUnaryFunction<armnn::abs<float>>;
+template struct armnn::ElementwiseUnaryFunction<armnn::exp<float>>;
+template struct armnn::ElementwiseUnaryFunction<std::negate<float>>;
+template struct armnn::ElementwiseUnaryFunction<armnn::rsqrt<float>>;
+template struct armnn::ElementwiseUnaryFunction<armnn::sqrt<float>>;
diff --git a/src/backends/reference/workloads/ElementwiseFunction.hpp b/src/backends/reference/workloads/ElementwiseFunction.hpp
index fd1fab0690..8259ba5ac7 100644
--- a/src/backends/reference/workloads/ElementwiseFunction.hpp
+++ b/src/backends/reference/workloads/ElementwiseFunction.hpp
@@ -12,17 +12,29 @@ namespace armnn
{
template <typename Functor>
-struct ElementwiseFunction
+struct ElementwiseBinaryFunction
{
using OutType = typename Functor::result_type;
using InType = typename Functor::first_argument_type;
- ElementwiseFunction(const TensorShape& inShape0,
- const TensorShape& inShape1,
- const TensorShape& outShape,
- armnn::Decoder<InType>& inData0,
- armnn::Decoder<InType>& inData1,
- armnn::Encoder<OutType>& outData);
+ ElementwiseBinaryFunction(const TensorShape& inShape0,
+ const TensorShape& inShape1,
+ const TensorShape& outShape,
+ Decoder<InType>& inData0,
+ Decoder<InType>& inData1,
+ Encoder<OutType>& outData);
+};
+
+template <typename Functor>
+struct ElementwiseUnaryFunction
+{
+ using OutType = typename Functor::result_type;
+ using InType = typename Functor::argument_type;
+
+ ElementwiseUnaryFunction(const TensorShape& inShape,
+ const TensorShape& outShape,
+ Decoder<InType>& inData,
+ Encoder<OutType>& outData);
};
} //namespace armnn
diff --git a/src/backends/reference/workloads/Exp.hpp b/src/backends/reference/workloads/Exp.hpp
new file mode 100644
index 0000000000..1a046728ba
--- /dev/null
+++ b/src/backends/reference/workloads/Exp.hpp
@@ -0,0 +1,22 @@
+//
+// Copyright © 2019 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include <iostream>
+
+namespace armnn
+{
+ template<typename T>
+struct exp : public std::unary_function<T, T>
+ {
+ T
+ operator () (const T& inputData) const
+ {
+ return std::exp(inputData);
+ }
+ };
+
+} //namespace armnn
diff --git a/src/backends/reference/workloads/RefAbsWorkload.cpp b/src/backends/reference/workloads/RefAbsWorkload.cpp
deleted file mode 100644
index 5c1f8c0c69..0000000000
--- a/src/backends/reference/workloads/RefAbsWorkload.cpp
+++ /dev/null
@@ -1,37 +0,0 @@
-//
-// Copyright © 2017 Arm Ltd. All rights reserved.
-// SPDX-License-Identifier: MIT
-//
-
-#include "RefAbsWorkload.hpp"
-
-#include "Abs.hpp"
-#include "Decoders.hpp"
-#include "Encoders.hpp"
-#include "RefWorkloadUtils.hpp"
-
-#include <Profiling.hpp>
-
-namespace armnn
-{
-
-void RefAbsWorkload::Execute() const
-{
- ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefAbsWorkload_Execute");
-
- const TensorInfo& inputTensorInfo = GetTensorInfo(m_Data.m_Inputs[0]);
-
- std::unique_ptr<Decoder<float>> decoderPtr = MakeDecoder<float>(inputTensorInfo, m_Data.m_Inputs[0]->Map());
- Decoder<float>& decoder = *decoderPtr;
-
- const TensorInfo& outputTensorInfo = GetTensorInfo(m_Data.m_Outputs[0]);
-
- std::unique_ptr<Encoder<float>> encoderPtr = MakeEncoder<float>(outputTensorInfo, m_Data.m_Outputs[0]->Map());
- Encoder<float>& encoder = *encoderPtr;
-
- Abs(decoder,
- encoder,
- inputTensorInfo);
-}
-
-} //namespace armnn
diff --git a/src/backends/reference/workloads/RefAbsWorkload.hpp b/src/backends/reference/workloads/RefAbsWorkload.hpp
deleted file mode 100644
index 68105556d5..0000000000
--- a/src/backends/reference/workloads/RefAbsWorkload.hpp
+++ /dev/null
@@ -1,21 +0,0 @@
-//
-// Copyright © 2017 Arm Ltd. All rights reserved.
-// SPDX-License-Identifier: MIT
-//
-
-#pragma once
-
-#include <backendsCommon/Workload.hpp>
-#include <backendsCommon/WorkloadData.hpp>
-
-namespace armnn
-{
-
-class RefAbsWorkload : public BaseWorkload<AbsQueueDescriptor>
-{
-public:
- using BaseWorkload<AbsQueueDescriptor>::BaseWorkload;
- virtual void Execute() const override;
-};
-
-} //namespace armnn
diff --git a/src/backends/reference/workloads/RefComparisonWorkload.cpp b/src/backends/reference/workloads/RefComparisonWorkload.cpp
index 60446226be..52ad9a2879 100644
--- a/src/backends/reference/workloads/RefComparisonWorkload.cpp
+++ b/src/backends/reference/workloads/RefComparisonWorkload.cpp
@@ -52,12 +52,12 @@ void RefComparisonWorkload::Execute() const
m_Input1->Reset(m_Data.m_Inputs[1]->Map());
m_Output->Reset(m_Data.m_Outputs[0]->Map());
- using EqualFunction = ElementwiseFunction<std::equal_to<InType>>;
- using GreaterFunction = ElementwiseFunction<std::greater<InType>>;
- using GreaterOrEqualFunction = ElementwiseFunction<std::greater_equal<InType>>;
- using LessFunction = ElementwiseFunction<std::less<InType>>;
- using LessOrEqualFunction = ElementwiseFunction<std::less_equal<InType>>;
- using NotEqualFunction = ElementwiseFunction<std::not_equal_to<InType>>;
+ using EqualFunction = ElementwiseBinaryFunction<std::equal_to<InType>>;
+ using GreaterFunction = ElementwiseBinaryFunction<std::greater<InType>>;
+ using GreaterOrEqualFunction = ElementwiseBinaryFunction<std::greater_equal<InType>>;
+ using LessFunction = ElementwiseBinaryFunction<std::less<InType>>;
+ using LessOrEqualFunction = ElementwiseBinaryFunction<std::less_equal<InType>>;
+ using NotEqualFunction = ElementwiseBinaryFunction<std::not_equal_to<InType>>;
switch (m_Data.m_Parameters.m_Operation)
{
diff --git a/src/backends/reference/workloads/RefElementwiseUnaryWorkload.cpp b/src/backends/reference/workloads/RefElementwiseUnaryWorkload.cpp
new file mode 100644
index 0000000000..4fbb0d123f
--- /dev/null
+++ b/src/backends/reference/workloads/RefElementwiseUnaryWorkload.cpp
@@ -0,0 +1,95 @@
+//
+// Copyright © 2019 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "RefElementwiseUnaryWorkload.hpp"
+
+#include "Decoders.hpp"
+#include "ElementwiseFunction.hpp"
+#include "Encoders.hpp"
+#include "RefWorkloadUtils.hpp"
+#include "Abs.hpp"
+#include "Exp.hpp"
+#include "Rsqrt.hpp"
+#include "Sqrt.hpp"
+
+#include <Profiling.hpp>
+
+#include <armnn/TypesUtils.hpp>
+
+#include <functional>
+
+namespace armnn
+{
+
+RefElementwiseUnaryWorkload::RefElementwiseUnaryWorkload(const ElementwiseUnaryQueueDescriptor& desc,
+ const WorkloadInfo& info)
+ : BaseWorkload<ElementwiseUnaryQueueDescriptor>(desc, info)
+{}
+
+void RefElementwiseUnaryWorkload::PostAllocationConfigure()
+{
+ const TensorInfo& inputInfo = GetTensorInfo(m_Data.m_Inputs[0]);
+ const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]);
+
+ m_Input = MakeDecoder<InType>(inputInfo);
+
+ m_Output = MakeEncoder<OutType>(outputInfo);
+}
+
+void RefElementwiseUnaryWorkload::Execute() const
+{
+ ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefElementwiseUnaryWorkload_Execute");
+
+ const TensorInfo& inputInfo = GetTensorInfo(m_Data.m_Inputs[0]);
+ const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]);
+
+ const TensorShape& inShape = inputInfo.GetShape();
+ const TensorShape& outShape = outputInfo.GetShape();
+
+ m_Input->Reset(m_Data.m_Inputs[0]->Map());
+ m_Output->Reset(m_Data.m_Outputs[0]->Map());
+
+ using AbsFunction = ElementwiseUnaryFunction<abs<InType>>;
+ using ExpFunction = ElementwiseUnaryFunction<exp<InType>>;
+ using NegFunction = ElementwiseUnaryFunction<std::negate<InType>>;
+ using RsqrtFunction = ElementwiseUnaryFunction<rsqrt<InType>>;
+ using SqrtFunction = ElementwiseUnaryFunction<sqrt<InType>>;
+
+ switch (m_Data.m_Parameters.m_Operation)
+ {
+ case UnaryOperation::Abs:
+ {
+ AbsFunction(inShape, outShape, *m_Input, *m_Output);
+ break;
+ }
+ case UnaryOperation::Exp:
+ {
+ ExpFunction(inShape, outShape, *m_Input, *m_Output);
+ break;
+ }
+ case UnaryOperation::Neg:
+ {
+ NegFunction(inShape, outShape, *m_Input, *m_Output);
+ break;
+ }
+ case UnaryOperation::Rsqrt:
+ {
+ RsqrtFunction(inShape, outShape, *m_Input, *m_Output);
+ break;
+ }
+ case UnaryOperation::Sqrt:
+ {
+ SqrtFunction(inShape, outShape, *m_Input, *m_Output);
+ break;
+ }
+ default:
+ {
+ throw InvalidArgumentException(std::string("Unsupported unary operation ") +
+ GetUnaryOperationAsCString(m_Data.m_Parameters.m_Operation), CHECK_LOCATION());
+ }
+ }
+}
+
+} // namespace armnn
diff --git a/src/backends/reference/workloads/RefElementwiseUnaryWorkload.hpp b/src/backends/reference/workloads/RefElementwiseUnaryWorkload.hpp
new file mode 100644
index 0000000000..efb2865ebd
--- /dev/null
+++ b/src/backends/reference/workloads/RefElementwiseUnaryWorkload.hpp
@@ -0,0 +1,33 @@
+//
+// Copyright © 2019 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include "BaseIterator.hpp"
+
+#include <backendsCommon/Workload.hpp>
+#include <backendsCommon/WorkloadData.hpp>
+
+namespace armnn
+{
+
+class RefElementwiseUnaryWorkload : public BaseWorkload<ElementwiseUnaryQueueDescriptor>
+{
+public:
+ using BaseWorkload<ElementwiseUnaryQueueDescriptor>::m_Data;
+
+ RefElementwiseUnaryWorkload(const ElementwiseUnaryQueueDescriptor& descriptor, const WorkloadInfo& info);
+ void PostAllocationConfigure() override;
+ void Execute() const override;
+
+private:
+ using InType = float;
+ using OutType = float;
+
+ std::unique_ptr<Decoder<InType>> m_Input;
+ std::unique_ptr<Encoder<OutType>> m_Output;
+};
+
+} // namespace armnn
diff --git a/src/backends/reference/workloads/RefElementwiseWorkload.cpp b/src/backends/reference/workloads/RefElementwiseWorkload.cpp
index 7e02f032ef..18bf0a7ad9 100644
--- a/src/backends/reference/workloads/RefElementwiseWorkload.cpp
+++ b/src/backends/reference/workloads/RefElementwiseWorkload.cpp
@@ -53,12 +53,12 @@ void RefElementwiseWorkload<Functor, ParentDescriptor, DebugString>::Execute() c
m_Input1->Reset(m_Data.m_Inputs[1]->Map());
m_Output->Reset(m_Data.m_Outputs[0]->Map());
- ElementwiseFunction<Functor>(inShape0,
- inShape1,
- outShape,
- *m_Input0,
- *m_Input1,
- *m_Output);
+ ElementwiseBinaryFunction<Functor>(inShape0,
+ inShape1,
+ outShape,
+ *m_Input0,
+ *m_Input1,
+ *m_Output);
}
} //namespace armnn
diff --git a/src/backends/reference/workloads/RefElementwiseWorkload.hpp b/src/backends/reference/workloads/RefElementwiseWorkload.hpp
index ee0d80b172..264ddce2de 100644
--- a/src/backends/reference/workloads/RefElementwiseWorkload.hpp
+++ b/src/backends/reference/workloads/RefElementwiseWorkload.hpp
@@ -21,8 +21,8 @@ template <typename Functor, typename ParentDescriptor, typename armnn::StringMap
class RefElementwiseWorkload : public BaseWorkload<ParentDescriptor>
{
public:
- using InType = typename ElementwiseFunction<Functor>::InType;
- using OutType = typename ElementwiseFunction<Functor>::OutType;
+ using InType = typename ElementwiseBinaryFunction<Functor>::InType;
+ using OutType = typename ElementwiseBinaryFunction<Functor>::OutType;
using BaseWorkload<ParentDescriptor>::m_Data;
RefElementwiseWorkload(const ParentDescriptor& descriptor, const WorkloadInfo& info);
diff --git a/src/backends/reference/workloads/RefRsqrtWorkload.cpp b/src/backends/reference/workloads/RefRsqrtWorkload.cpp
deleted file mode 100644
index fd6b9a3549..0000000000
--- a/src/backends/reference/workloads/RefRsqrtWorkload.cpp
+++ /dev/null
@@ -1,37 +0,0 @@
-//
-// Copyright © 2017 Arm Ltd. All rights reserved.
-// SPDX-License-Identifier: MIT
-//
-
-#include "RefRsqrtWorkload.hpp"
-
-#include "Decoders.hpp"
-#include "Encoders.hpp"
-#include "RefWorkloadUtils.hpp"
-#include "Rsqrt.hpp"
-
-#include <Profiling.hpp>
-
-namespace armnn
-{
-
-void RefRsqrtWorkload::Execute() const
-{
- ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefRsqrtWorkload_Execute");
-
- const TensorInfo& inputTensorInfo = GetTensorInfo(m_Data.m_Inputs[0]);
-
- std::unique_ptr<Decoder<float>> decoderPtr = MakeDecoder<float>(inputTensorInfo, m_Data.m_Inputs[0]->Map());
- Decoder<float>& decoder = *decoderPtr;
-
- const TensorInfo& outputTensorInfo = GetTensorInfo(m_Data.m_Outputs[0]);
-
- std::unique_ptr<Encoder<float>> encoderPtr = MakeEncoder<float>(outputTensorInfo, m_Data.m_Outputs[0]->Map());
- Encoder<float>& encoder = *encoderPtr;
-
- Rsqrt(decoder,
- encoder,
- GetTensorInfo(m_Data.m_Inputs[0]));
-}
-
-} //namespace armnn
diff --git a/src/backends/reference/workloads/RefRsqrtWorkload.hpp b/src/backends/reference/workloads/RefRsqrtWorkload.hpp
deleted file mode 100644
index 6c8ad5bc60..0000000000
--- a/src/backends/reference/workloads/RefRsqrtWorkload.hpp
+++ /dev/null
@@ -1,21 +0,0 @@
-//
-// Copyright © 2017 Arm Ltd. All rights reserved.
-// SPDX-License-Identifier: MIT
-//
-
-#pragma once
-
-#include <backendsCommon/Workload.hpp>
-#include <backendsCommon/WorkloadData.hpp>
-
-namespace armnn
-{
-
-class RefRsqrtWorkload : public BaseWorkload<RsqrtQueueDescriptor>
-{
-public:
- using BaseWorkload<RsqrtQueueDescriptor>::BaseWorkload;
- virtual void Execute() const override;
-};
-
-} //namespace armnn
diff --git a/src/backends/reference/workloads/RefWorkloads.hpp b/src/backends/reference/workloads/RefWorkloads.hpp
index 1f9ad4a19a..7034b67aa5 100644
--- a/src/backends/reference/workloads/RefWorkloads.hpp
+++ b/src/backends/reference/workloads/RefWorkloads.hpp
@@ -5,7 +5,6 @@
#pragma once
-#include "Abs.hpp"
#include "Activation.hpp"
#include "ArgMinMax.hpp"
#include "BatchNormImpl.hpp"
@@ -15,7 +14,6 @@
#include "FullyConnected.hpp"
#include "Gather.hpp"
#include "Pooling2d.hpp"
-#include "RefAbsWorkload.hpp"
#include "RefActivationWorkload.hpp"
#include "RefArgMinMaxWorkload.hpp"
#include "RefBatchNormalizationWorkload.hpp"
@@ -33,6 +31,7 @@
#include "RefDetectionPostProcessWorkload.hpp"
#include "RefDequantizeWorkload.hpp"
#include "RefElementwiseWorkload.hpp"
+#include "RefElementwiseUnaryWorkload.hpp"
#include "RefFullyConnectedWorkload.hpp"
#include "RefFloorWorkload.hpp"
#include "RefFakeQuantizationFloat32Workload.hpp"
@@ -51,7 +50,6 @@
#include "RefReshapeWorkload.hpp"
#include "RefResizeBilinearWorkload.hpp"
#include "RefResizeWorkload.hpp"
-#include "RefRsqrtWorkload.hpp"
#include "RefSliceWorkload.hpp"
#include "RefSplitterWorkload.hpp"
#include "RefSoftmaxWorkload.hpp"
diff --git a/src/backends/reference/workloads/Rsqrt.cpp b/src/backends/reference/workloads/Rsqrt.cpp
deleted file mode 100644
index 5abc2c8f7b..0000000000
--- a/src/backends/reference/workloads/Rsqrt.cpp
+++ /dev/null
@@ -1,25 +0,0 @@
-//
-// Copyright © 2017 Arm Ltd. All rights reserved.
-// SPDX-License-Identifier: MIT
-//
-
-#include "Rsqrt.hpp"
-
-#include <cmath>
-
-namespace armnn
-{
-
-void Rsqrt(Decoder<float>& in,
- Encoder<float>& out,
- const TensorInfo& tensorInfo)
-{
- for (unsigned int i = 0; i < tensorInfo.GetNumElements(); ++i)
- {
- out[i];
- in[i];
- out.Set(1.f / sqrtf(in.Get()));
- }
-}
-
-} //namespace armnn \ No newline at end of file
diff --git a/src/backends/reference/workloads/Rsqrt.hpp b/src/backends/reference/workloads/Rsqrt.hpp
index ffc6b18d13..47ebcf36f6 100644
--- a/src/backends/reference/workloads/Rsqrt.hpp
+++ b/src/backends/reference/workloads/Rsqrt.hpp
@@ -1,19 +1,22 @@
//
-// Copyright © 2017 Arm Ltd. All rights reserved.
+// Copyright © 2019 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
-#include "BaseIterator.hpp"
-#include <armnn/Tensor.hpp>
-#include <armnn/Types.hpp>
+#pragma once
+
+#include <iostream>
namespace armnn
{
-
-/// Performs the reciprocal squareroot function elementwise
-/// on the inputs to give the outputs.
-void Rsqrt(Decoder<float>& in,
- Encoder<float>& out,
- const TensorInfo& tensorInfo);
+ template<typename T>
+struct rsqrt : public std::unary_function<T, T>
+ {
+ T
+ operator () (const T& inputData) const
+ {
+ return 1 / std::sqrt(inputData);
+ }
+ };
} //namespace armnn
diff --git a/src/backends/reference/workloads/Sqrt.hpp b/src/backends/reference/workloads/Sqrt.hpp
new file mode 100644
index 0000000000..e4ff6a4829
--- /dev/null
+++ b/src/backends/reference/workloads/Sqrt.hpp
@@ -0,0 +1,22 @@
+//
+// Copyright © 2019 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include <iostream>
+
+namespace armnn
+{
+ template<typename T>
+struct sqrt : public std::unary_function<T, T>
+ {
+ T
+ operator () (const T& inputData) const
+ {
+ return std::sqrt(inputData);
+ }
+ };
+
+} //namespace armnn