aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSadik Armagan <sadik.armagan@arm.com>2019-04-03 17:48:18 +0100
committerSadik Armagan <sadik.armagan@arm.com>2019-04-08 15:48:28 +0000
commit2e6dc3a1c5d47825535db7993ba77eb1596ae99b (patch)
tree48e73fa1862d17534804d1699bedb76120e88c9f
parent0324f48e64edb99a5c8d819394545d97e0c2ae97 (diff)
downloadarmnn-2e6dc3a1c5d47825535db7993ba77eb1596ae99b.tar.gz
IVGCVSW-2861 Refactor the Reference Elementwise workload
* Refactor Reference Comparison workload
* Removed templating based on the DataType
* Implemented BaseIterator to do decode/encode

Change-Id: I18f299f47ee23772f90152c1146b42f07465e105
Signed-off-by: Sadik Armagan <sadik.armagan@arm.com>
Signed-off-by: Kevin May <kevin.may@arm.com>
-rw-r--r--src/backends/backendsCommon/WorkloadData.cpp103
-rw-r--r--src/backends/backendsCommon/test/WorkloadDataValidation.cpp14
-rw-r--r--src/backends/reference/RefWorkloadFactory.cpp16
-rw-r--r--src/backends/reference/test/RefCreateWorkloadTests.cpp16
-rw-r--r--src/backends/reference/workloads/BaseIterator.hpp155
-rw-r--r--src/backends/reference/workloads/Broadcast.hpp24
-rw-r--r--src/backends/reference/workloads/CMakeLists.txt1
-rw-r--r--src/backends/reference/workloads/ElementwiseFunction.cpp34
-rw-r--r--src/backends/reference/workloads/ElementwiseFunction.hpp9
-rw-r--r--src/backends/reference/workloads/RefComparisonWorkload.cpp91
-rw-r--r--src/backends/reference/workloads/RefComparisonWorkload.hpp76
-rw-r--r--src/backends/reference/workloads/RefElementwiseWorkload.cpp115
-rw-r--r--src/backends/reference/workloads/RefElementwiseWorkload.hpp69
13 files changed, 438 insertions, 285 deletions
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp
index b850a65acf..1360ac5d0c 100644
--- a/src/backends/backendsCommon/WorkloadData.cpp
+++ b/src/backends/backendsCommon/WorkloadData.cpp
@@ -491,13 +491,29 @@ void AdditionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
ValidateNumInputs(workloadInfo, "AdditionQueueDescriptor", 2);
ValidateNumOutputs(workloadInfo, "AdditionQueueDescriptor", 1);
+ std::vector<DataType> supportedTypes = {
+ DataType::Float32,
+ DataType::QuantisedAsymm8
+ };
+
+ ValidateDataTypes(workloadInfo.m_InputTensorInfos[0],
+ supportedTypes,
+ "AdditionQueueDescriptor");
+
+ ValidateDataTypes(workloadInfo.m_InputTensorInfos[1],
+ supportedTypes,
+ "AdditionQueueDescriptor");
+
+ ValidateDataTypes(workloadInfo.m_OutputTensorInfos[0],
+ supportedTypes,
+ "AdditionQueueDescriptor");
+
ValidateBroadcastTensorShapesMatch(workloadInfo.m_InputTensorInfos[0],
workloadInfo.m_InputTensorInfos[1],
workloadInfo.m_OutputTensorInfos[0],
"AdditionQueueDescriptor",
"first input",
"second input");
-
}
//---------------------------------------------------------------
@@ -506,6 +522,23 @@ void MultiplicationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) c
ValidateNumInputs(workloadInfo, "MultiplicationQueueDescriptor", 2);
ValidateNumOutputs(workloadInfo, "MultiplicationQueueDescriptor", 1);
+ std::vector<DataType> supportedTypes = {
+ DataType::Float32,
+ DataType::QuantisedAsymm8
+ };
+
+ ValidateDataTypes(workloadInfo.m_InputTensorInfos[0],
+ supportedTypes,
+ "MultiplicationQueueDescriptor");
+
+ ValidateDataTypes(workloadInfo.m_InputTensorInfos[1],
+ supportedTypes,
+ "MultiplicationQueueDescriptor");
+
+ ValidateDataTypes(workloadInfo.m_OutputTensorInfos[0],
+ supportedTypes,
+ "MultiplicationQueueDescriptor");
+
ValidateBroadcastTensorShapesMatch(workloadInfo.m_InputTensorInfos[0],
workloadInfo.m_InputTensorInfos[1],
workloadInfo.m_OutputTensorInfos[0],
@@ -857,6 +890,23 @@ void DivisionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
ValidateNumInputs(workloadInfo, "DivisionQueueDescriptor", 2);
ValidateNumOutputs(workloadInfo, "DivisionQueueDescriptor", 1);
+ std::vector<DataType> supportedTypes = {
+ DataType::Float32,
+ DataType::QuantisedAsymm8
+ };
+
+ ValidateDataTypes(workloadInfo.m_InputTensorInfos[0],
+ supportedTypes,
+ "DivisionQueueDescriptor");
+
+ ValidateDataTypes(workloadInfo.m_InputTensorInfos[1],
+ supportedTypes,
+ "DivisionQueueDescriptor");
+
+ ValidateDataTypes(workloadInfo.m_OutputTensorInfos[0],
+ supportedTypes,
+ "DivisionQueueDescriptor");
+
ValidateBroadcastTensorShapesMatch(workloadInfo.m_InputTensorInfos[0],
workloadInfo.m_InputTensorInfos[1],
workloadInfo.m_OutputTensorInfos[0],
@@ -870,6 +920,23 @@ void SubtractionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) cons
ValidateNumInputs(workloadInfo, "SubtractionQueueDescriptor", 2);
ValidateNumOutputs(workloadInfo, "SubtractionQueueDescriptor", 1);
+ std::vector<DataType> supportedTypes = {
+ DataType::Float32,
+ DataType::QuantisedAsymm8
+ };
+
+ ValidateDataTypes(workloadInfo.m_InputTensorInfos[0],
+ supportedTypes,
+ "SubtractionQueueDescriptor");
+
+ ValidateDataTypes(workloadInfo.m_InputTensorInfos[1],
+ supportedTypes,
+ "SubtractionQueueDescriptor");
+
+ ValidateDataTypes(workloadInfo.m_OutputTensorInfos[0],
+ supportedTypes,
+ "SubtractionQueueDescriptor");
+
ValidateBroadcastTensorShapesMatch(workloadInfo.m_InputTensorInfos[0],
workloadInfo.m_InputTensorInfos[1],
workloadInfo.m_OutputTensorInfos[0],
@@ -883,6 +950,23 @@ void MaximumQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
ValidateNumInputs(workloadInfo, "MaximumQueueDescriptor", 2);
ValidateNumOutputs(workloadInfo, "MaximumQueueDescriptor", 1);
+ std::vector<DataType> supportedTypes = {
+ DataType::Float32,
+ DataType::QuantisedAsymm8
+ };
+
+ ValidateDataTypes(workloadInfo.m_InputTensorInfos[0],
+ supportedTypes,
+ "MaximumQueueDescriptor");
+
+ ValidateDataTypes(workloadInfo.m_InputTensorInfos[1],
+ supportedTypes,
+ "MaximumQueueDescriptor");
+
+ ValidateDataTypes(workloadInfo.m_OutputTensorInfos[0],
+ supportedTypes,
+ "MaximumQueueDescriptor");
+
ValidateBroadcastTensorShapesMatch(workloadInfo.m_InputTensorInfos[0],
workloadInfo.m_InputTensorInfos[1],
workloadInfo.m_OutputTensorInfos[0],
@@ -1008,6 +1092,23 @@ void MinimumQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
ValidateNumInputs(workloadInfo, "MinimumQueueDescriptor", 2);
ValidateNumOutputs(workloadInfo, "MinimumQueueDescriptor", 1);
+ std::vector<DataType> supportedTypes = {
+ DataType::Float32,
+ DataType::QuantisedAsymm8
+ };
+
+ ValidateDataTypes(workloadInfo.m_InputTensorInfos[0],
+ supportedTypes,
+ "MinimumQueueDescriptor");
+
+ ValidateDataTypes(workloadInfo.m_InputTensorInfos[1],
+ supportedTypes,
+ "MinimumQueueDescriptor");
+
+ ValidateDataTypes(workloadInfo.m_OutputTensorInfos[0],
+ supportedTypes,
+ "MinimumQueueDescriptor");
+
ValidateBroadcastTensorShapesMatch(workloadInfo.m_InputTensorInfos[0],
workloadInfo.m_InputTensorInfos[1],
workloadInfo.m_OutputTensorInfos[0],
diff --git a/src/backends/backendsCommon/test/WorkloadDataValidation.cpp b/src/backends/backendsCommon/test/WorkloadDataValidation.cpp
index 3664d56c28..d37cc74c66 100644
--- a/src/backends/backendsCommon/test/WorkloadDataValidation.cpp
+++ b/src/backends/backendsCommon/test/WorkloadDataValidation.cpp
@@ -312,17 +312,17 @@ BOOST_AUTO_TEST_CASE(AdditionQueueDescriptor_Validate_InputNumbers)
AddOutputToWorkload(invalidData, invalidInfo, outputTensorInfo, nullptr);
// Too few inputs.
- BOOST_CHECK_THROW(RefAdditionFloat32Workload(invalidData, invalidInfo), armnn::InvalidArgumentException);
+ BOOST_CHECK_THROW(RefAdditionWorkload(invalidData, invalidInfo), armnn::InvalidArgumentException);
AddInputToWorkload(invalidData, invalidInfo, input2TensorInfo, nullptr);
// Correct.
- BOOST_CHECK_NO_THROW(RefAdditionFloat32Workload(invalidData, invalidInfo));
+ BOOST_CHECK_NO_THROW(RefAdditionWorkload(invalidData, invalidInfo));
AddInputToWorkload(invalidData, invalidInfo, input3TensorInfo, nullptr);
// Too many inputs.
- BOOST_CHECK_THROW(RefAdditionFloat32Workload(invalidData, invalidInfo), armnn::InvalidArgumentException);
+ BOOST_CHECK_THROW(RefAdditionWorkload(invalidData, invalidInfo), armnn::InvalidArgumentException);
}
BOOST_AUTO_TEST_CASE(AdditionQueueDescriptor_Validate_InputShapes)
@@ -347,7 +347,7 @@ BOOST_AUTO_TEST_CASE(AdditionQueueDescriptor_Validate_InputShapes)
AddInputToWorkload(invalidData, invalidInfo, input2TensorInfo, nullptr);
AddOutputToWorkload(invalidData, invalidInfo, outputTensorInfo, nullptr);
- BOOST_CHECK_THROW(RefAdditionFloat32Workload(invalidData, invalidInfo), armnn::InvalidArgumentException);
+ BOOST_CHECK_THROW(RefAdditionWorkload(invalidData, invalidInfo), armnn::InvalidArgumentException);
}
// Output size not compatible with input sizes.
@@ -364,7 +364,7 @@ BOOST_AUTO_TEST_CASE(AdditionQueueDescriptor_Validate_InputShapes)
AddOutputToWorkload(invalidData, invalidInfo, outputTensorInfo, nullptr);
// Output differs.
- BOOST_CHECK_THROW(RefAdditionFloat32Workload(invalidData, invalidInfo), armnn::InvalidArgumentException);
+ BOOST_CHECK_THROW(RefAdditionWorkload(invalidData, invalidInfo), armnn::InvalidArgumentException);
}
}
@@ -399,7 +399,7 @@ BOOST_AUTO_TEST_CASE(MultiplicationQueueDescriptor_Validate_InputTensorDimension
AddInputToWorkload(invalidData, invalidInfo, input0TensorInfo, nullptr);
AddInputToWorkload(invalidData, invalidInfo, input1TensorInfo, nullptr);
- BOOST_CHECK_THROW(RefMultiplicationFloat32Workload(invalidData, invalidInfo), armnn::InvalidArgumentException);
+ BOOST_CHECK_THROW(RefMultiplicationWorkload(invalidData, invalidInfo), armnn::InvalidArgumentException);
}
// Checks dimension consistency for input and output tensors.
@@ -424,7 +424,7 @@ BOOST_AUTO_TEST_CASE(MultiplicationQueueDescriptor_Validate_InputTensorDimension
AddInputToWorkload(invalidData, invalidInfo, input0TensorInfo, nullptr);
AddInputToWorkload(invalidData, invalidInfo, input1TensorInfo, nullptr);
- BOOST_CHECK_THROW(RefMultiplicationFloat32Workload(invalidData, invalidInfo), armnn::InvalidArgumentException);
+ BOOST_CHECK_THROW(RefMultiplicationWorkload(invalidData, invalidInfo), armnn::InvalidArgumentException);
}
}
diff --git a/src/backends/reference/RefWorkloadFactory.cpp b/src/backends/reference/RefWorkloadFactory.cpp
index 619c14e007..8ea923d599 100644
--- a/src/backends/reference/RefWorkloadFactory.cpp
+++ b/src/backends/reference/RefWorkloadFactory.cpp
@@ -174,13 +174,13 @@ std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreateNormalization(
std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreateAddition(const AdditionQueueDescriptor& descriptor,
const WorkloadInfo& info) const
{
- return MakeWorkload<RefAdditionFloat32Workload, RefAdditionUint8Workload>(descriptor, info);
+ return std::make_unique<RefAdditionWorkload>(descriptor, info);
}
std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreateMultiplication(
const MultiplicationQueueDescriptor& descriptor, const WorkloadInfo& info) const
{
- return MakeWorkload<RefMultiplicationFloat32Workload, RefMultiplicationUint8Workload>(descriptor, info);
+ return std::make_unique<RefMultiplicationWorkload>(descriptor, info);
}
std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreateBatchNormalization(
@@ -266,19 +266,19 @@ std::unique_ptr<IWorkload> RefWorkloadFactory::CreateConvertFp32ToFp16(
std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreateDivision(
const DivisionQueueDescriptor& descriptor, const WorkloadInfo& info) const
{
- return MakeWorkload<RefDivisionFloat32Workload, RefDivisionUint8Workload>(descriptor, info);
+ return std::make_unique<RefDivisionWorkload>(descriptor, info);
}
std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreateSubtraction(
const SubtractionQueueDescriptor& descriptor, const WorkloadInfo& info) const
{
- return MakeWorkload<RefSubtractionFloat32Workload, RefSubtractionUint8Workload>(descriptor, info);
+ return std::make_unique<RefSubtractionWorkload>(descriptor, info);
}
std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreateMaximum(
const MaximumQueueDescriptor& descriptor, const WorkloadInfo& info) const
{
- return MakeWorkload<RefMaximumFloat32Workload, RefMaximumUint8Workload>(descriptor, info);
+ return std::make_unique<RefMaximumWorkload>(descriptor, info);
}
std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreateMean(
@@ -290,7 +290,7 @@ std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreateMean(
std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreateMinimum(
const MinimumQueueDescriptor& descriptor, const WorkloadInfo& info) const
{
- return MakeWorkload<RefMinimumFloat32Workload, RefMinimumUint8Workload>(descriptor, info);
+ return std::make_unique<RefMinimumWorkload>(descriptor, info);
}
std::unique_ptr<IWorkload> RefWorkloadFactory::CreatePad(const PadQueueDescriptor& descriptor,
@@ -302,7 +302,7 @@ std::unique_ptr<IWorkload> RefWorkloadFactory::CreatePad(const PadQueueDescripto
std::unique_ptr<IWorkload> RefWorkloadFactory::CreateEqual(const EqualQueueDescriptor& descriptor,
const WorkloadInfo& info) const
{
- return MakeWorkload<RefEqualFloat32Workload, RefEqualUint8Workload>(descriptor, info);
+ return std::make_unique<RefEqualWorkload>(descriptor, info);
}
std::unique_ptr<IWorkload> RefWorkloadFactory::CreateBatchToSpaceNd(const BatchToSpaceNdQueueDescriptor& descriptor,
@@ -320,7 +320,7 @@ std::unique_ptr<IWorkload> RefWorkloadFactory::CreateStridedSlice(const StridedS
std::unique_ptr<IWorkload> RefWorkloadFactory::CreateGreater(const GreaterQueueDescriptor& descriptor,
const WorkloadInfo& info) const
{
- return MakeWorkload<RefGreaterFloat32Workload, RefGreaterUint8Workload>(descriptor, info);
+ return std::make_unique<RefGreaterWorkload>(descriptor, info);
}
std::unique_ptr<IWorkload> RefWorkloadFactory::CreateDebug(const DebugQueueDescriptor& descriptor,
diff --git a/src/backends/reference/test/RefCreateWorkloadTests.cpp b/src/backends/reference/test/RefCreateWorkloadTests.cpp
index 8621122925..09b0246895 100644
--- a/src/backends/reference/test/RefCreateWorkloadTests.cpp
+++ b/src/backends/reference/test/RefCreateWorkloadTests.cpp
@@ -82,7 +82,7 @@ static void RefCreateElementwiseWorkloadTest()
BOOST_AUTO_TEST_CASE(CreateAdditionFloatWorkload)
{
- RefCreateElementwiseWorkloadTest<RefAdditionFloat32Workload,
+ RefCreateElementwiseWorkloadTest<RefAdditionWorkload,
AdditionQueueDescriptor,
AdditionLayer,
armnn::DataType::Float32>();
@@ -90,7 +90,7 @@ BOOST_AUTO_TEST_CASE(CreateAdditionFloatWorkload)
BOOST_AUTO_TEST_CASE(CreateAdditionUint8Workload)
{
- RefCreateElementwiseWorkloadTest<RefAdditionUint8Workload,
+ RefCreateElementwiseWorkloadTest<RefAdditionWorkload,
AdditionQueueDescriptor,
AdditionLayer,
armnn::DataType::QuantisedAsymm8>();
@@ -98,7 +98,7 @@ BOOST_AUTO_TEST_CASE(CreateAdditionUint8Workload)
BOOST_AUTO_TEST_CASE(CreateSubtractionFloatWorkload)
{
- RefCreateElementwiseWorkloadTest<RefSubtractionFloat32Workload,
+ RefCreateElementwiseWorkloadTest<RefSubtractionWorkload,
SubtractionQueueDescriptor,
SubtractionLayer,
armnn::DataType::Float32>();
@@ -106,7 +106,7 @@ BOOST_AUTO_TEST_CASE(CreateSubtractionFloatWorkload)
BOOST_AUTO_TEST_CASE(CreateSubtractionUint8Workload)
{
- RefCreateElementwiseWorkloadTest<RefSubtractionUint8Workload,
+ RefCreateElementwiseWorkloadTest<RefSubtractionWorkload,
SubtractionQueueDescriptor,
SubtractionLayer,
armnn::DataType::QuantisedAsymm8>();
@@ -114,7 +114,7 @@ BOOST_AUTO_TEST_CASE(CreateSubtractionUint8Workload)
BOOST_AUTO_TEST_CASE(CreateMultiplicationFloatWorkload)
{
- RefCreateElementwiseWorkloadTest<RefMultiplicationFloat32Workload,
+ RefCreateElementwiseWorkloadTest<RefMultiplicationWorkload,
MultiplicationQueueDescriptor,
MultiplicationLayer,
armnn::DataType::Float32>();
@@ -122,7 +122,7 @@ BOOST_AUTO_TEST_CASE(CreateMultiplicationFloatWorkload)
BOOST_AUTO_TEST_CASE(CreateMultiplicationUint8Workload)
{
- RefCreateElementwiseWorkloadTest<RefMultiplicationUint8Workload,
+ RefCreateElementwiseWorkloadTest<RefMultiplicationWorkload,
MultiplicationQueueDescriptor,
MultiplicationLayer,
armnn::DataType::QuantisedAsymm8>();
@@ -130,7 +130,7 @@ BOOST_AUTO_TEST_CASE(CreateMultiplicationUint8Workload)
BOOST_AUTO_TEST_CASE(CreateDivisionFloatWorkload)
{
- RefCreateElementwiseWorkloadTest<RefDivisionFloat32Workload,
+ RefCreateElementwiseWorkloadTest<RefDivisionWorkload,
DivisionQueueDescriptor,
DivisionLayer,
armnn::DataType::Float32>();
@@ -138,7 +138,7 @@ BOOST_AUTO_TEST_CASE(CreateDivisionFloatWorkload)
BOOST_AUTO_TEST_CASE(CreateDivisionUint8Workload)
{
- RefCreateElementwiseWorkloadTest<RefDivisionUint8Workload,
+ RefCreateElementwiseWorkloadTest<RefDivisionWorkload,
DivisionQueueDescriptor,
DivisionLayer,
armnn::DataType::QuantisedAsymm8>();
diff --git a/src/backends/reference/workloads/BaseIterator.hpp b/src/backends/reference/workloads/BaseIterator.hpp
new file mode 100644
index 0000000000..cfa8ce7e91
--- /dev/null
+++ b/src/backends/reference/workloads/BaseIterator.hpp
@@ -0,0 +1,155 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include <armnn/ArmNN.hpp>
+#include <TypeUtils.hpp>
+
+namespace armnn
+{
+
+class BaseIterator
+{
+public:
+ BaseIterator() {}
+
+ virtual ~BaseIterator() {}
+
+ virtual BaseIterator& operator++() = 0;
+
+ virtual BaseIterator& operator+=(const unsigned int increment) = 0;
+
+ virtual BaseIterator& operator-=(const unsigned int increment) = 0;
+};
+
+class Decoder : public BaseIterator
+{
+public:
+ Decoder() : BaseIterator() {}
+
+ virtual ~Decoder() {}
+
+ virtual float Get() const = 0;
+};
+
+class Encoder : public BaseIterator
+{
+public:
+ Encoder() : BaseIterator() {}
+
+ virtual ~Encoder() {}
+
+ virtual void Set(const float& right) = 0;
+};
+
+class ComparisonEncoder : public BaseIterator
+{
+public:
+ ComparisonEncoder() : BaseIterator() {}
+
+ virtual ~ComparisonEncoder() {}
+
+ virtual void Set(bool right) = 0;
+};
+
+template<typename T, typename Base>
+class TypedIterator : public Base
+{
+public:
+ TypedIterator(T* data)
+ : m_Iterator(data)
+ {}
+
+ TypedIterator& operator++() override
+ {
+ ++m_Iterator;
+ return *this;
+ }
+
+ TypedIterator& operator+=(const unsigned int increment) override
+ {
+ m_Iterator += increment;
+ return *this;
+ }
+
+ TypedIterator& operator-=(const unsigned int increment) override
+ {
+ m_Iterator -= increment;
+ return *this;
+ }
+
+ T* m_Iterator;
+};
+
+class QASymm8Decoder : public TypedIterator<const uint8_t, Decoder>
+{
+public:
+ QASymm8Decoder(const uint8_t* data, const float scale, const int32_t offset)
+ : TypedIterator(data), m_Scale(scale), m_Offset(offset) {}
+
+ float Get() const override
+ {
+ return armnn::Dequantize(*m_Iterator, m_Scale, m_Offset);
+ }
+
+private:
+ const float m_Scale;
+ const int32_t m_Offset;
+};
+
+class FloatDecoder : public TypedIterator<const float, Decoder>
+{
+public:
+ FloatDecoder(const float* data)
+ : TypedIterator(data) {}
+
+ float Get() const override
+ {
+ return *m_Iterator;
+ }
+};
+
+class FloatEncoder : public TypedIterator<float, Encoder>
+{
+public:
+ FloatEncoder(float* data)
+ : TypedIterator(data) {}
+
+ void Set(const float& right) override
+ {
+ *m_Iterator = right;
+ }
+};
+
+class QASymm8Encoder : public TypedIterator<uint8_t, Encoder>
+{
+public:
+ QASymm8Encoder(uint8_t* data, const float scale, const int32_t offset)
+ : TypedIterator(data), m_Scale(scale), m_Offset(offset) {}
+
+ void Set(const float& right) override
+ {
+ *m_Iterator = armnn::Quantize<uint8_t>(right, m_Scale, m_Offset);
+ }
+
+private:
+ const float m_Scale;
+ const int32_t m_Offset;
+};
+
+class BooleanEncoder : public TypedIterator<uint8_t, ComparisonEncoder>
+{
+public:
+ BooleanEncoder(uint8_t* data)
+ : TypedIterator(data) {}
+
+ void Set(bool right) override
+ {
+ *m_Iterator = right;
+ }
+};
+
+} //namespace armnn \ No newline at end of file
diff --git a/src/backends/reference/workloads/Broadcast.hpp b/src/backends/reference/workloads/Broadcast.hpp
index e92ed0598d..5bf6be8939 100644
--- a/src/backends/reference/workloads/Broadcast.hpp
+++ b/src/backends/reference/workloads/Broadcast.hpp
@@ -3,6 +3,7 @@
// SPDX-License-Identifier: MIT
//
+#include "BaseIterator.hpp"
#include <armnn/Tensor.hpp>
#include <functional>
@@ -19,19 +20,23 @@ struct BroadcastLoop
return static_cast<unsigned int>(m_DimData.size());
}
- template <typename T0, typename T1, typename U, typename Func>
+ template <typename Func, typename DecoderOp, typename EncoderOp>
void Unroll(Func operationFunc,
unsigned int dimension,
- const T0* inData0,
- const T1* inData1,
- U* outData)
+ DecoderOp& inData0,
+ DecoderOp& inData1,
+ EncoderOp& outData)
{
if (dimension >= GetNumDimensions())
{
- *outData = operationFunc(*inData0, *inData1);
+ outData.Set(operationFunc(inData0.Get(), inData1.Get()));
return;
}
+ unsigned int inData0Movement = 0;
+ unsigned int inData1Movement = 0;
+ unsigned int outDataMovement = 0;
+
for (unsigned int i = 0; i < m_DimData[dimension].m_DimSize; i++)
{
Unroll(operationFunc, dimension + 1, inData0, inData1, outData);
@@ -39,7 +44,16 @@ struct BroadcastLoop
inData0 += m_DimData[dimension].m_Stride1;
inData1 += m_DimData[dimension].m_Stride2;
outData += m_DimData[dimension].m_StrideOut;
+
+ inData0Movement += m_DimData[dimension].m_Stride1;
+ inData1Movement += m_DimData[dimension].m_Stride2;
+ outDataMovement += m_DimData[dimension].m_StrideOut;
}
+
+ // move iterator back to the start
+ inData0 -= inData0Movement;
+ inData1 -= inData1Movement;
+ outData -= outDataMovement;
}
private:
diff --git a/src/backends/reference/workloads/CMakeLists.txt b/src/backends/reference/workloads/CMakeLists.txt
index 4f5fbb554e..4ff2466e87 100644
--- a/src/backends/reference/workloads/CMakeLists.txt
+++ b/src/backends/reference/workloads/CMakeLists.txt
@@ -6,6 +6,7 @@
list(APPEND armnnRefBackendWorkloads_sources
Activation.cpp
Activation.hpp
+ BaseIterator.hpp
BatchNormImpl.hpp
BatchToSpaceNd.cpp
BatchToSpaceNd.hpp
diff --git a/src/backends/reference/workloads/ElementwiseFunction.cpp b/src/backends/reference/workloads/ElementwiseFunction.cpp
index c8c25ef9e9..934a86217a 100644
--- a/src/backends/reference/workloads/ElementwiseFunction.cpp
+++ b/src/backends/reference/workloads/ElementwiseFunction.cpp
@@ -13,26 +13,26 @@
namespace armnn
{
-template <typename Functor, typename dataTypeInput, typename dataTypeOutput>
-ElementwiseFunction<Functor, dataTypeInput, dataTypeOutput>::ElementwiseFunction(const TensorShape& inShape0,
- const TensorShape& inShape1,
- const TensorShape& outShape,
- const dataTypeInput* inData0,
- const dataTypeInput* inData1,
- dataTypeOutput* outData)
+template <typename Functor, typename DecoderOp, typename EncoderOp>
+ElementwiseFunction<Functor, DecoderOp, EncoderOp>::ElementwiseFunction(const TensorShape& inShape0,
+ const TensorShape& inShape1,
+ const TensorShape& outShape,
+ DecoderOp& inData0,
+ DecoderOp& inData1,
+ EncoderOp& outData)
{
BroadcastLoop(inShape0, inShape1, outShape).Unroll(Functor(), 0, inData0, inData1, outData);
}
} //namespace armnn
-template struct armnn::ElementwiseFunction<std::plus<float>, float, float>;
-template struct armnn::ElementwiseFunction<std::minus<float>, float, float>;
-template struct armnn::ElementwiseFunction<std::multiplies<float>, float, float>;
-template struct armnn::ElementwiseFunction<std::divides<float>, float, float>;
-template struct armnn::ElementwiseFunction<armnn::maximum<float>, float, float>;
-template struct armnn::ElementwiseFunction<armnn::minimum<float>, float, float>;
-template struct armnn::ElementwiseFunction<std::equal_to<float>, float ,uint8_t>;
-template struct armnn::ElementwiseFunction<std::equal_to<uint8_t>, uint8_t, uint8_t>;
-template struct armnn::ElementwiseFunction<std::greater<float>, float, uint8_t>;
-template struct armnn::ElementwiseFunction<std::greater<uint8_t>, uint8_t, uint8_t>;
+template struct armnn::ElementwiseFunction<std::plus<float>, armnn::Decoder, armnn::Encoder>;
+template struct armnn::ElementwiseFunction<std::minus<float>, armnn::Decoder, armnn::Encoder>;
+template struct armnn::ElementwiseFunction<std::multiplies<float>, armnn::Decoder, armnn::Encoder>;
+template struct armnn::ElementwiseFunction<std::divides<float>, armnn::Decoder, armnn::Encoder>;
+template struct armnn::ElementwiseFunction<armnn::maximum<float>, armnn::Decoder, armnn::Encoder>;
+template struct armnn::ElementwiseFunction<armnn::minimum<float>, armnn::Decoder, armnn::Encoder>;
+
+template struct armnn::ElementwiseFunction<std::equal_to<float>, armnn::Decoder, armnn::ComparisonEncoder>;
+template struct armnn::ElementwiseFunction<std::greater<float>, armnn::Decoder, armnn::ComparisonEncoder>;
+
diff --git a/src/backends/reference/workloads/ElementwiseFunction.hpp b/src/backends/reference/workloads/ElementwiseFunction.hpp
index 8099f3279a..9eb003d5f9 100644
--- a/src/backends/reference/workloads/ElementwiseFunction.hpp
+++ b/src/backends/reference/workloads/ElementwiseFunction.hpp
@@ -5,20 +5,21 @@
#pragma once
+#include "BaseIterator.hpp"
#include <armnn/Tensor.hpp>
namespace armnn
{
-template <typename Functor, typename dataTypeInput, typename dataTypeOutput>
+template <typename Functor, typename DecoderOp, typename EncoderOp>
struct ElementwiseFunction
{
ElementwiseFunction(const TensorShape& inShape0,
const TensorShape& inShape1,
const TensorShape& outShape,
- const dataTypeInput* inData0,
- const dataTypeInput* inData1,
- dataTypeOutput* outData);
+ DecoderOp& inData0,
+ DecoderOp& inData1,
+ EncoderOp& outData);
};
} //namespace armnn
diff --git a/src/backends/reference/workloads/RefComparisonWorkload.cpp b/src/backends/reference/workloads/RefComparisonWorkload.cpp
index fe517ff51a..bb8bb04ad3 100644
--- a/src/backends/reference/workloads/RefComparisonWorkload.cpp
+++ b/src/backends/reference/workloads/RefComparisonWorkload.cpp
@@ -11,55 +11,66 @@
namespace armnn {
-template<typename ParentDescriptor, typename Functor>
-void RefFloat32ComparisonWorkload<ParentDescriptor, Functor>::ExecuteImpl(const char* debugString) const
+template <typename Functor, typename ParentDescriptor, typename armnn::StringMapping::Id DebugString>
+void RefComparisonWorkload<Functor, ParentDescriptor, DebugString>::Execute() const
{
- ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, debugString);
+ ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, StringMapping::Instance().Get(DebugString));
+ const TensorInfo& inputInfo0 = GetTensorInfo(m_Data.m_Inputs[0]);
+ const TensorInfo& inputInfo1 = GetTensorInfo(m_Data.m_Inputs[1]);
+ const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]);
- auto data = BaseFloat32ComparisonWorkload<ParentDescriptor>::GetData();
- const TensorShape& inShape0 = GetTensorInfo(data.m_Inputs[0]).GetShape();
- const TensorShape& inShape1 = GetTensorInfo(data.m_Inputs[1]).GetShape();
- const TensorShape& outputShape = GetTensorInfo(data.m_Outputs[0]).GetShape();
+ const TensorShape& inShape0 = inputInfo0.GetShape();
+ const TensorShape& inShape1 = inputInfo1.GetShape();
+ const TensorShape& outShape = outputInfo.GetShape();
- const float* inData0 = GetInputTensorDataFloat(0, data);
- const float* inData1 = GetInputTensorDataFloat(1, data);
- uint8_t* outData = GetOutputTensorData<uint8_t>(0, data);
+ switch(inputInfo0.GetDataType())
+ {
+ case armnn::DataType::QuantisedAsymm8:
+ {
+ QASymm8Decoder decodeIterator0(GetInputTensorDataU8(0, m_Data),
+ inputInfo0.GetQuantizationScale(),
+ inputInfo0.GetQuantizationOffset());
- ElementwiseFunction<Functor, float, uint8_t>(inShape0,
- inShape1,
- outputShape,
- inData0,
- inData1,
- outData);
+ QASymm8Decoder decodeIterator1(GetInputTensorDataU8(1, m_Data),
+ inputInfo1.GetQuantizationScale(),
+ inputInfo1.GetQuantizationOffset());
-}
-
-template<typename ParentDescriptor, typename Functor>
-void RefUint8ComparisonWorkload<ParentDescriptor, Functor>::ExecuteImpl(const char* debugString) const
-{
- ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, debugString);
-
- auto data = BaseUint8ComparisonWorkload<ParentDescriptor>::GetData();
- const TensorShape& inputInfo0 = GetTensorInfo(data.m_Inputs[0]).GetShape();
- const TensorShape& inputInfo1 = GetTensorInfo(data.m_Inputs[1]).GetShape();
- const TensorShape& outputShape = GetTensorInfo(data.m_Outputs[0]).GetShape();
+ BooleanEncoder encodeIterator0(GetOutputTensorDataU8(0, m_Data));
- const uint8_t* inData0 = GetInputTensorData<uint8_t>(0, data);
- const uint8_t* inData1 = GetInputTensorData<uint8_t>(1, data);
- uint8_t* outData = GetOutputTensorData<uint8_t>(0, data);
+ ElementwiseFunction<Functor, Decoder, ComparisonEncoder>(inShape0,
+ inShape1,
+ outShape,
+ decodeIterator0,
+ decodeIterator1,
+ encodeIterator0);
+ break;
+ }
+ case armnn::DataType::Float32:
+ {
+ FloatDecoder decodeIterator0(GetInputTensorDataFloat(0, m_Data));
+ FloatDecoder decodeIterator1(GetInputTensorDataFloat(1, m_Data));
+ BooleanEncoder encodeIterator0(GetOutputTensorDataU8(0, m_Data));
- ElementwiseFunction<Functor, uint8_t, uint8_t>(inputInfo0,
- inputInfo1,
- outputShape,
- inData0,
- inData1,
- outData);
+ ElementwiseFunction<Functor, Decoder, ComparisonEncoder>(inShape0,
+ inShape1,
+ outShape,
+ decodeIterator0,
+ decodeIterator1,
+ encodeIterator0);
+ break;
+ }
+ default:
+ BOOST_ASSERT_MSG(false, "RefComparisonWorkload: Not supported Data Type!");
+ break;
+ }
}
}
-template class armnn::RefFloat32ComparisonWorkload<armnn::EqualQueueDescriptor, std::equal_to<float>>;
-template class armnn::RefUint8ComparisonWorkload<armnn::EqualQueueDescriptor, std::equal_to<uint8_t>>;
+template class armnn::RefComparisonWorkload<std::equal_to<float>,
+ armnn::EqualQueueDescriptor,
+ armnn::StringMapping::RefEqualWorkload_Execute>;
-template class armnn::RefFloat32ComparisonWorkload<armnn::GreaterQueueDescriptor, std::greater<float>>;
-template class armnn::RefUint8ComparisonWorkload<armnn::GreaterQueueDescriptor, std::greater<uint8_t>>;
+template class armnn::RefComparisonWorkload<std::greater<float>,
+ armnn::GreaterQueueDescriptor,
+ armnn::StringMapping::RefGreaterWorkload_Execute>;
diff --git a/src/backends/reference/workloads/RefComparisonWorkload.hpp b/src/backends/reference/workloads/RefComparisonWorkload.hpp
index 524d20625a..cfc2dcf2aa 100644
--- a/src/backends/reference/workloads/RefComparisonWorkload.hpp
+++ b/src/backends/reference/workloads/RefComparisonWorkload.hpp
@@ -13,80 +13,24 @@
namespace armnn
{
-template <typename Functor,
- typename armnn::DataType DataType,
- typename ParentDescriptor,
- typename armnn::StringMapping::Id DebugString>
-class RefComparisonWorkload
-{
- // Needs specialization. The default is empty on purpose.
-};
-
-template <typename ParentDescriptor, typename Functor>
-class RefFloat32ComparisonWorkload : public BaseFloat32ComparisonWorkload<ParentDescriptor>
-{
-public:
- using BaseFloat32ComparisonWorkload<ParentDescriptor>::BaseFloat32ComparisonWorkload;
- void ExecuteImpl(const char * debugString) const;
-};
-
-template <typename Functor, typename ParentDescriptor, typename armnn::StringMapping::Id DebugString>
-class RefComparisonWorkload<Functor, armnn::DataType::Float32, ParentDescriptor, DebugString>
- : public RefFloat32ComparisonWorkload<ParentDescriptor, Functor>
-{
-public:
- using RefFloat32ComparisonWorkload<ParentDescriptor, Functor>::RefFloat32ComparisonWorkload;
-
- virtual void Execute() const override
- {
- using Parent = RefFloat32ComparisonWorkload<ParentDescriptor, Functor>;
- Parent::ExecuteImpl(StringMapping::Instance().Get(DebugString));
- }
-};
-
-template <typename ParentDescriptor, typename Functor>
-class RefUint8ComparisonWorkload : public BaseUint8ComparisonWorkload<ParentDescriptor>
-{
-public:
- using BaseUint8ComparisonWorkload<ParentDescriptor>::BaseUint8ComparisonWorkload;
- void ExecuteImpl(const char * debugString) const;
-};
-
template <typename Functor, typename ParentDescriptor, typename armnn::StringMapping::Id DebugString>
-class RefComparisonWorkload<Functor, armnn::DataType::QuantisedAsymm8, ParentDescriptor, DebugString>
- : public RefUint8ComparisonWorkload<ParentDescriptor, Functor>
+class RefComparisonWorkload : public BaseWorkload<ParentDescriptor>
{
public:
- using RefUint8ComparisonWorkload<ParentDescriptor, Functor>::RefUint8ComparisonWorkload;
+ using BaseWorkload<ParentDescriptor>::m_Data;
+ using BaseWorkload<ParentDescriptor>::BaseWorkload;
- virtual void Execute() const override
- {
- using Parent = RefUint8ComparisonWorkload<ParentDescriptor, Functor>;
- Parent::ExecuteImpl(StringMapping::Instance().Get(DebugString));
- }
+ void Execute() const override;
};
-using RefEqualFloat32Workload =
+using RefEqualWorkload =
RefComparisonWorkload<std::equal_to<float>,
- DataType::Float32,
- EqualQueueDescriptor,
- StringMapping::RefEqualWorkload_Execute>;
+ EqualQueueDescriptor,
+ StringMapping::RefEqualWorkload_Execute>;
-using RefEqualUint8Workload =
- RefComparisonWorkload<std::equal_to<uint8_t>,
- DataType::QuantisedAsymm8,
- EqualQueueDescriptor,
- StringMapping::RefEqualWorkload_Execute>;
-using RefGreaterFloat32Workload =
+using RefGreaterWorkload =
RefComparisonWorkload<std::greater<float>,
- DataType::Float32,
- GreaterQueueDescriptor,
- StringMapping::RefGreaterWorkload_Execute>;
-
-using RefGreaterUint8Workload =
- RefComparisonWorkload<std::greater<uint8_t>,
- DataType::QuantisedAsymm8,
- GreaterQueueDescriptor,
- StringMapping::RefGreaterWorkload_Execute>;
+ GreaterQueueDescriptor,
+ StringMapping::RefGreaterWorkload_Execute>;
} // armnn
diff --git a/src/backends/reference/workloads/RefElementwiseWorkload.cpp b/src/backends/reference/workloads/RefElementwiseWorkload.cpp
index 356d7a0c16..6e6e1d5f21 100644
--- a/src/backends/reference/workloads/RefElementwiseWorkload.cpp
+++ b/src/backends/reference/workloads/RefElementwiseWorkload.cpp
@@ -14,14 +14,10 @@
namespace armnn
{
-template <typename Functor,
- typename armnn::DataType DataType,
- typename ParentDescriptor,
- typename armnn::StringMapping::Id DebugString>
-void RefElementwiseWorkload<Functor, DataType, ParentDescriptor, DebugString>::Execute() const
+template <typename Functor, typename ParentDescriptor, typename armnn::StringMapping::Id DebugString>
+void RefElementwiseWorkload<Functor, ParentDescriptor, DebugString>::Execute() const
{
ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, StringMapping::Instance().Get(DebugString));
-
const TensorInfo& inputInfo0 = GetTensorInfo(m_Data.m_Inputs[0]);
const TensorInfo& inputInfo1 = GetTensorInfo(m_Data.m_Inputs[1]);
const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]);
@@ -30,32 +26,46 @@ void RefElementwiseWorkload<Functor, DataType, ParentDescriptor, DebugString>::E
const TensorShape& inShape1 = inputInfo1.GetShape();
const TensorShape& outShape = outputInfo.GetShape();
- switch(DataType)
+ switch(inputInfo0.GetDataType())
{
case armnn::DataType::QuantisedAsymm8:
{
- std::vector<float> results(outputInfo.GetNumElements());
- ElementwiseFunction<Functor, float, float>(inShape0,
- inShape1,
- outShape,
- Dequantize(GetInputTensorDataU8(0, m_Data), inputInfo0).data(),
- Dequantize(GetInputTensorDataU8(1, m_Data), inputInfo1).data(),
- results.data());
- Quantize(GetOutputTensorDataU8(0, m_Data), results.data(), outputInfo);
+ QASymm8Decoder decodeIterator0(GetInputTensorDataU8(0, m_Data),
+ inputInfo0.GetQuantizationScale(),
+ inputInfo0.GetQuantizationOffset());
+
+ QASymm8Decoder decodeIterator1(GetInputTensorDataU8(1, m_Data),
+ inputInfo1.GetQuantizationScale(),
+ inputInfo1.GetQuantizationOffset());
+
+ QASymm8Encoder encodeIterator0(GetOutputTensorDataU8(0, m_Data),
+ outputInfo.GetQuantizationScale(),
+ outputInfo.GetQuantizationOffset());
+
+ ElementwiseFunction<Functor, Decoder, Encoder>(inShape0,
+ inShape1,
+ outShape,
+ decodeIterator0,
+ decodeIterator1,
+ encodeIterator0);
break;
}
case armnn::DataType::Float32:
{
- ElementwiseFunction<Functor, float, float>(inShape0,
- inShape1,
- outShape,
- GetInputTensorDataFloat(0, m_Data),
- GetInputTensorDataFloat(1, m_Data),
- GetOutputTensorDataFloat(0, m_Data));
+ FloatDecoder decodeIterator0(GetInputTensorDataFloat(0, m_Data));
+ FloatDecoder decodeIterator1(GetInputTensorDataFloat(1, m_Data));
+ FloatEncoder encodeIterator0(GetOutputTensorDataFloat(0, m_Data));
+
+ ElementwiseFunction<Functor, Decoder, Encoder>(inShape0,
+ inShape1,
+ outShape,
+ decodeIterator0,
+ decodeIterator1,
+ encodeIterator0);
break;
}
default:
- BOOST_ASSERT_MSG(false, "Unknown Data Type!");
+ BOOST_ASSERT_MSG(false, "RefElementwiseWorkload: Not supported Data Type!");
break;
}
}
@@ -63,62 +73,25 @@ void RefElementwiseWorkload<Functor, DataType, ParentDescriptor, DebugString>::E
}
template class armnn::RefElementwiseWorkload<std::plus<float>,
- armnn::DataType::Float32,
- armnn::AdditionQueueDescriptor,
- armnn::StringMapping::RefAdditionWorkload_Execute>;
-
-template class armnn::RefElementwiseWorkload<std::plus<float>,
- armnn::DataType::QuantisedAsymm8,
- armnn::AdditionQueueDescriptor,
- armnn::StringMapping::RefAdditionWorkload_Execute>;
+ armnn::AdditionQueueDescriptor,
+ armnn::StringMapping::RefAdditionWorkload_Execute>;
template class armnn::RefElementwiseWorkload<std::minus<float>,
- armnn::DataType::Float32,
- armnn::SubtractionQueueDescriptor,
- armnn::StringMapping::RefSubtractionWorkload_Execute>;
-
-template class armnn::RefElementwiseWorkload<std::minus<float>,
- armnn::DataType::QuantisedAsymm8,
- armnn::SubtractionQueueDescriptor,
- armnn::StringMapping::RefSubtractionWorkload_Execute>;
-
-template class armnn::RefElementwiseWorkload<std::multiplies<float>,
- armnn::DataType::Float32,
- armnn::MultiplicationQueueDescriptor,
- armnn::StringMapping::RefMultiplicationWorkload_Execute>;
+ armnn::SubtractionQueueDescriptor,
+ armnn::StringMapping::RefSubtractionWorkload_Execute>;
template class armnn::RefElementwiseWorkload<std::multiplies<float>,
- armnn::DataType::QuantisedAsymm8,
- armnn::MultiplicationQueueDescriptor,
- armnn::StringMapping::RefMultiplicationWorkload_Execute>;
+ armnn::MultiplicationQueueDescriptor,
+ armnn::StringMapping::RefMultiplicationWorkload_Execute>;
template class armnn::RefElementwiseWorkload<std::divides<float>,
- armnn::DataType::Float32,
- armnn::DivisionQueueDescriptor,
- armnn::StringMapping::RefDivisionWorkload_Execute>;
-
-template class armnn::RefElementwiseWorkload<std::divides<float>,
- armnn::DataType::QuantisedAsymm8,
- armnn::DivisionQueueDescriptor,
- armnn::StringMapping::RefDivisionWorkload_Execute>;
-
-template class armnn::RefElementwiseWorkload<armnn::maximum<float>,
- armnn::DataType::Float32,
- armnn::MaximumQueueDescriptor,
- armnn::StringMapping::RefMaximumWorkload_Execute>;
+ armnn::DivisionQueueDescriptor,
+ armnn::StringMapping::RefDivisionWorkload_Execute>;
template class armnn::RefElementwiseWorkload<armnn::maximum<float>,
- armnn::DataType::QuantisedAsymm8,
- armnn::MaximumQueueDescriptor,
- armnn::StringMapping::RefMaximumWorkload_Execute>;
-
-
-template class armnn::RefElementwiseWorkload<armnn::minimum<float>,
- armnn::DataType::Float32,
- armnn::MinimumQueueDescriptor,
- armnn::StringMapping::RefMinimumWorkload_Execute>;
+ armnn::MaximumQueueDescriptor,
+ armnn::StringMapping::RefMaximumWorkload_Execute>;
template class armnn::RefElementwiseWorkload<armnn::minimum<float>,
- armnn::DataType::QuantisedAsymm8,
- armnn::MinimumQueueDescriptor,
- armnn::StringMapping::RefMinimumWorkload_Execute>;
+ armnn::MinimumQueueDescriptor,
+                                             armnn::StringMapping::RefMinimumWorkload_Execute>;
\ No newline at end of file
diff --git a/src/backends/reference/workloads/RefElementwiseWorkload.hpp b/src/backends/reference/workloads/RefElementwiseWorkload.hpp
index 371904977a..81af19627e 100644
--- a/src/backends/reference/workloads/RefElementwiseWorkload.hpp
+++ b/src/backends/reference/workloads/RefElementwiseWorkload.hpp
@@ -15,90 +15,43 @@
namespace armnn
{
-template <typename Functor,
- typename armnn::DataType DataType,
- typename ParentDescriptor,
- typename armnn::StringMapping::Id DebugString>
-class RefElementwiseWorkload
- : public TypedWorkload<ParentDescriptor, DataType>
+template <typename Functor, typename ParentDescriptor, typename armnn::StringMapping::Id DebugString>
+class RefElementwiseWorkload : public BaseWorkload<ParentDescriptor>
{
public:
-
- using TypedWorkload<ParentDescriptor, DataType>::m_Data;
- using TypedWorkload<ParentDescriptor, DataType>::TypedWorkload;
+ using BaseWorkload<ParentDescriptor>::m_Data;
+ using BaseWorkload<ParentDescriptor>::BaseWorkload;
void Execute() const override;
};
-using RefAdditionFloat32Workload =
- RefElementwiseWorkload<std::plus<float>,
- DataType::Float32,
- AdditionQueueDescriptor,
- StringMapping::RefAdditionWorkload_Execute>;
-
-using RefAdditionUint8Workload =
+using RefAdditionWorkload =
RefElementwiseWorkload<std::plus<float>,
- DataType::QuantisedAsymm8,
AdditionQueueDescriptor,
StringMapping::RefAdditionWorkload_Execute>;
-using RefSubtractionFloat32Workload =
- RefElementwiseWorkload<std::minus<float>,
- DataType::Float32,
- SubtractionQueueDescriptor,
- StringMapping::RefSubtractionWorkload_Execute>;
-
-using RefSubtractionUint8Workload =
+using RefSubtractionWorkload =
RefElementwiseWorkload<std::minus<float>,
- DataType::QuantisedAsymm8,
SubtractionQueueDescriptor,
StringMapping::RefSubtractionWorkload_Execute>;
-using RefMultiplicationFloat32Workload =
- RefElementwiseWorkload<std::multiplies<float>,
- DataType::Float32,
- MultiplicationQueueDescriptor,
- StringMapping::RefMultiplicationWorkload_Execute>;
-
-using RefMultiplicationUint8Workload =
+using RefMultiplicationWorkload =
RefElementwiseWorkload<std::multiplies<float>,
- DataType::QuantisedAsymm8,
MultiplicationQueueDescriptor,
StringMapping::RefMultiplicationWorkload_Execute>;
-using RefDivisionFloat32Workload =
- RefElementwiseWorkload<std::divides<float>,
- DataType::Float32,
- DivisionQueueDescriptor,
- StringMapping::RefDivisionWorkload_Execute>;
-
-using RefDivisionUint8Workload =
+using RefDivisionWorkload =
RefElementwiseWorkload<std::divides<float>,
- DataType::QuantisedAsymm8,
DivisionQueueDescriptor,
StringMapping::RefDivisionWorkload_Execute>;
-using RefMaximumFloat32Workload =
- RefElementwiseWorkload<armnn::maximum<float>,
- DataType::Float32,
- MaximumQueueDescriptor,
- StringMapping::RefMaximumWorkload_Execute>;
-
-using RefMaximumUint8Workload =
+using RefMaximumWorkload =
RefElementwiseWorkload<armnn::maximum<float>,
- DataType::QuantisedAsymm8,
MaximumQueueDescriptor,
StringMapping::RefMaximumWorkload_Execute>;
-using RefMinimumFloat32Workload =
- RefElementwiseWorkload<minimum<float>,
- DataType::Float32,
- MinimumQueueDescriptor,
- StringMapping::RefMinimumWorkload_Execute>;
-
-using RefMinimumUint8Workload =
- RefElementwiseWorkload<minimum<float>,
- DataType::QuantisedAsymm8,
+using RefMinimumWorkload =
+ RefElementwiseWorkload<armnn::minimum<float>,
MinimumQueueDescriptor,
StringMapping::RefMinimumWorkload_Execute>;
} // armnn