author      Finn Williams <Finn.Williams@arm.com>      2020-06-22 15:58:32 +0100
committer   TeresaARM <teresa.charlinreyes@arm.com>     2020-06-30 13:00:52 +0000
commit      cbd2c230b7ce5f26e2ccccf36b7ad450f6e1ad09 (patch)
tree        eb6e5393726be21213e72f26a676b7c3809fc995 /src
parent      532a29d12d72f54549d8b71edd485c17af65698a (diff)
IVGCVSW-5007 Implement an Int32 reference Elementwise workload
Signed-off-by: Finn Williams <Finn.Williams@arm.com>
Change-Id: I6592169b74ac4294bc09647879aec0718c641f91
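The patch turns the six reference elementwise workload aliases (RefAdditionWorkload, RefSubtractionWorkload, RefMultiplicationWorkload, RefDivisionWorkload, RefMaximumWorkload, RefMinimumWorkload) into templates on their element type, defaulting to float, and has RefWorkloadFactory pick the int32_t instantiation when the first input tensor is Signed32. The following is a minimal, self-contained C++ sketch of that dispatch pattern; DataType, IWorkload, ElementwiseWorkload, AdditionWorkload and MakeAddition here are simplified hypothetical stand-ins, not the real Arm NN classes, which operate on tensor data through a queue descriptor and WorkloadInfo.

#include <cstdint>
#include <functional>
#include <iostream>
#include <memory>

// Hypothetical stand-ins for the Arm NN types used only to illustrate the pattern.
enum class DataType { Float32, Signed32 };

struct IWorkload
{
    virtual ~IWorkload() = default;
    virtual void Execute() = 0;
};

// A workload parameterised on the binary functor and the element type it operates on.
template <typename Functor, typename T>
struct ElementwiseWorkload : IWorkload
{
    T m_Lhs;
    T m_Rhs;
    ElementwiseWorkload(T lhs, T rhs) : m_Lhs(lhs), m_Rhs(rhs) {}
    void Execute() override { std::cout << Functor{}(m_Lhs, m_Rhs) << std::endl; }
};

// Alias templated on the element type with a float default, mirroring the
// RefAdditionWorkload alias introduced by this patch (so AdditionWorkload<>
// still means the float workload).
template <typename T = float>
using AdditionWorkload = ElementwiseWorkload<std::plus<T>, T>;

// Factory-style dispatch: choose the int32_t instantiation for Signed32 inputs,
// otherwise fall back to float, as RefWorkloadFactory::CreateAddition now does.
std::unique_ptr<IWorkload> MakeAddition(DataType type, double lhs, double rhs)
{
    if (type == DataType::Signed32)
    {
        return std::make_unique<AdditionWorkload<int32_t>>(static_cast<int32_t>(lhs),
                                                           static_cast<int32_t>(rhs));
    }
    return std::make_unique<AdditionWorkload<float>>(static_cast<float>(lhs),
                                                     static_cast<float>(rhs));
}

int main()
{
    MakeAddition(DataType::Signed32, 2, 3)->Execute();      // prints 5
    MakeAddition(DataType::Float32, 2.5, 3.25)->Execute();  // prints 5.75
    return 0;
}

Defaulting the template parameter to float keeps existing call sites working with only a trailing <>, which is why the unit tests below change from RefAdditionWorkload to RefAdditionWorkload<> rather than naming a type explicitly.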
Diffstat (limited to 'src')
-rw-r--r--  src/backends/backendsCommon/test/WorkloadDataValidation.cpp   | 14
-rw-r--r--  src/backends/reference/RefWorkloadFactory.cpp                  | 54
-rw-r--r--  src/backends/reference/test/RefCreateWorkloadTests.cpp         | 60
-rw-r--r--  src/backends/reference/workloads/BaseIterator.hpp              | 35
-rw-r--r--  src/backends/reference/workloads/Decoders.hpp                  | 18
-rw-r--r--  src/backends/reference/workloads/ElementwiseFunction.cpp       |  7
-rw-r--r--  src/backends/reference/workloads/Encoders.hpp                  | 18
-rw-r--r--  src/backends/reference/workloads/RefElementwiseWorkload.cpp    | 24
-rw-r--r--  src/backends/reference/workloads/RefElementwiseWorkload.hpp    | 18
9 files changed, 215 insertions(+), 33 deletions(-)
diff --git a/src/backends/backendsCommon/test/WorkloadDataValidation.cpp b/src/backends/backendsCommon/test/WorkloadDataValidation.cpp
index b3987c0a74..2eb4a06f29 100644
--- a/src/backends/backendsCommon/test/WorkloadDataValidation.cpp
+++ b/src/backends/backendsCommon/test/WorkloadDataValidation.cpp
@@ -336,17 +336,17 @@ BOOST_AUTO_TEST_CASE(AdditionQueueDescriptor_Validate_InputNumbers)
AddOutputToWorkload(invalidData, invalidInfo, outputTensorInfo, nullptr);
// Too few inputs.
- BOOST_CHECK_THROW(RefAdditionWorkload(invalidData, invalidInfo), armnn::InvalidArgumentException);
+ BOOST_CHECK_THROW(RefAdditionWorkload<>(invalidData, invalidInfo), armnn::InvalidArgumentException);
AddInputToWorkload(invalidData, invalidInfo, input2TensorInfo, nullptr);
// Correct.
- BOOST_CHECK_NO_THROW(RefAdditionWorkload(invalidData, invalidInfo));
+ BOOST_CHECK_NO_THROW(RefAdditionWorkload<>(invalidData, invalidInfo));
AddInputToWorkload(invalidData, invalidInfo, input3TensorInfo, nullptr);
// Too many inputs.
- BOOST_CHECK_THROW(RefAdditionWorkload(invalidData, invalidInfo), armnn::InvalidArgumentException);
+ BOOST_CHECK_THROW(RefAdditionWorkload<>(invalidData, invalidInfo), armnn::InvalidArgumentException);
}
BOOST_AUTO_TEST_CASE(AdditionQueueDescriptor_Validate_InputShapes)
@@ -371,7 +371,7 @@ BOOST_AUTO_TEST_CASE(AdditionQueueDescriptor_Validate_InputShapes)
AddInputToWorkload(invalidData, invalidInfo, input2TensorInfo, nullptr);
AddOutputToWorkload(invalidData, invalidInfo, outputTensorInfo, nullptr);
- BOOST_CHECK_THROW(RefAdditionWorkload(invalidData, invalidInfo), armnn::InvalidArgumentException);
+ BOOST_CHECK_THROW(RefAdditionWorkload<>(invalidData, invalidInfo), armnn::InvalidArgumentException);
}
// Output size not compatible with input sizes.
@@ -388,7 +388,7 @@ BOOST_AUTO_TEST_CASE(AdditionQueueDescriptor_Validate_InputShapes)
AddOutputToWorkload(invalidData, invalidInfo, outputTensorInfo, nullptr);
// Output differs.
- BOOST_CHECK_THROW(RefAdditionWorkload(invalidData, invalidInfo), armnn::InvalidArgumentException);
+ BOOST_CHECK_THROW(RefAdditionWorkload<>(invalidData, invalidInfo), armnn::InvalidArgumentException);
}
}
@@ -423,7 +423,7 @@ BOOST_AUTO_TEST_CASE(MultiplicationQueueDescriptor_Validate_InputTensorDimension
AddInputToWorkload(invalidData, invalidInfo, input0TensorInfo, nullptr);
AddInputToWorkload(invalidData, invalidInfo, input1TensorInfo, nullptr);
- BOOST_CHECK_THROW(RefMultiplicationWorkload(invalidData, invalidInfo), armnn::InvalidArgumentException);
+ BOOST_CHECK_THROW(RefMultiplicationWorkload<>(invalidData, invalidInfo), armnn::InvalidArgumentException);
}
// Checks dimension consistency for input and output tensors.
@@ -448,7 +448,7 @@ BOOST_AUTO_TEST_CASE(MultiplicationQueueDescriptor_Validate_InputTensorDimension
AddInputToWorkload(invalidData, invalidInfo, input0TensorInfo, nullptr);
AddInputToWorkload(invalidData, invalidInfo, input1TensorInfo, nullptr);
- BOOST_CHECK_THROW(RefMultiplicationWorkload(invalidData, invalidInfo), armnn::InvalidArgumentException);
+ BOOST_CHECK_THROW(RefMultiplicationWorkload<>(invalidData, invalidInfo), armnn::InvalidArgumentException);
}
}
diff --git a/src/backends/reference/RefWorkloadFactory.cpp b/src/backends/reference/RefWorkloadFactory.cpp
index 643684c5b0..dcdabe17ff 100644
--- a/src/backends/reference/RefWorkloadFactory.cpp
+++ b/src/backends/reference/RefWorkloadFactory.cpp
@@ -141,7 +141,14 @@ std::unique_ptr<IWorkload> RefWorkloadFactory::CreateActivation(const Activation
std::unique_ptr<IWorkload> RefWorkloadFactory::CreateAddition(const AdditionQueueDescriptor& descriptor,
const WorkloadInfo& info) const
{
- return std::make_unique<RefAdditionWorkload>(descriptor, info);
+ if (info.m_InputTensorInfos[0].GetDataType() == armnn::DataType::Signed32)
+ {
+ return std::make_unique<RefAdditionWorkload<int32_t>>(descriptor, info);
+ }
+ else
+ {
+ return std::make_unique<RefAdditionWorkload<float>>(descriptor, info);
+ }
}
std::unique_ptr<IWorkload> RefWorkloadFactory::CreateArgMinMax(const ArgMinMaxQueueDescriptor& descriptor,
@@ -279,7 +286,14 @@ std::unique_ptr<IWorkload> RefWorkloadFactory::CreateDetectionPostProcess(
std::unique_ptr<IWorkload> RefWorkloadFactory::CreateDivision(const DivisionQueueDescriptor& descriptor,
const WorkloadInfo& info) const
{
- return std::make_unique<RefDivisionWorkload>(descriptor, info);
+ if (info.m_InputTensorInfos[0].GetDataType() == armnn::DataType::Signed32)
+ {
+ return std::make_unique<RefDivisionWorkload<int32_t>>(descriptor, info);
+ }
+ else
+ {
+ return std::make_unique<RefDivisionWorkload<float>>(descriptor, info);
+ }
}
std::unique_ptr<IWorkload> RefWorkloadFactory::CreateElementwiseUnary(const ElementwiseUnaryQueueDescriptor& descriptor,
@@ -387,7 +401,14 @@ std::unique_ptr<IWorkload> RefWorkloadFactory::CreateLstm(const LstmQueueDescrip
std::unique_ptr<IWorkload> RefWorkloadFactory::CreateMaximum(const MaximumQueueDescriptor& descriptor,
const WorkloadInfo& info) const
{
- return std::make_unique<RefMaximumWorkload>(descriptor, info);
+ if (info.m_InputTensorInfos[0].GetDataType() == armnn::DataType::Signed32)
+ {
+ return std::make_unique<RefMaximumWorkload<int32_t>>(descriptor, info);
+ }
+ else
+ {
+ return std::make_unique<RefMaximumWorkload<float>>(descriptor, info);
+ }
}
std::unique_ptr<IWorkload> RefWorkloadFactory::CreateMean(const MeanQueueDescriptor& descriptor,
@@ -425,13 +446,27 @@ std::unique_ptr<IWorkload> RefWorkloadFactory::CreateMerger(const MergerQueueDes
std::unique_ptr<IWorkload> RefWorkloadFactory::CreateMinimum(const MinimumQueueDescriptor& descriptor,
const WorkloadInfo& info) const
{
- return std::make_unique<RefMinimumWorkload>(descriptor, info);
+ if (info.m_InputTensorInfos[0].GetDataType() == armnn::DataType::Signed32)
+ {
+ return std::make_unique<RefMinimumWorkload<int32_t>>(descriptor, info);
+ }
+ else
+ {
+ return std::make_unique<RefMinimumWorkload<float>>(descriptor, info);
+ }
}
std::unique_ptr<IWorkload> RefWorkloadFactory::CreateMultiplication(const MultiplicationQueueDescriptor& descriptor,
const WorkloadInfo& info) const
{
- return std::make_unique<RefMultiplicationWorkload>(descriptor, info);
+ if (info.m_InputTensorInfos[0].GetDataType() == armnn::DataType::Signed32)
+ {
+ return std::make_unique<RefMultiplicationWorkload<int32_t>>(descriptor, info);
+ }
+ else
+ {
+ return std::make_unique<RefMultiplicationWorkload<float>>(descriptor, info);
+ }
}
std::unique_ptr<IWorkload> RefWorkloadFactory::CreateNormalization(const NormalizationQueueDescriptor& descriptor,
@@ -593,7 +628,14 @@ std::unique_ptr<IWorkload> RefWorkloadFactory::CreateStridedSlice(const StridedS
std::unique_ptr<IWorkload> RefWorkloadFactory::CreateSubtraction(const SubtractionQueueDescriptor& descriptor,
const WorkloadInfo& info) const
{
- return std::make_unique<RefSubtractionWorkload>(descriptor, info);
+ if (info.m_InputTensorInfos[0].GetDataType() == armnn::DataType::Signed32)
+ {
+ return std::make_unique<RefSubtractionWorkload<int32_t>>(descriptor, info);
+ }
+ else
+ {
+ return std::make_unique<RefSubtractionWorkload<float>>(descriptor, info);
+ }
}
std::unique_ptr<IWorkload> RefWorkloadFactory::CreateTranspose(const TransposeQueueDescriptor& descriptor,
diff --git a/src/backends/reference/test/RefCreateWorkloadTests.cpp b/src/backends/reference/test/RefCreateWorkloadTests.cpp
index 9c08909e95..b1e49e6ff3 100644
--- a/src/backends/reference/test/RefCreateWorkloadTests.cpp
+++ b/src/backends/reference/test/RefCreateWorkloadTests.cpp
@@ -91,7 +91,7 @@ static void RefCreateElementwiseWorkloadTest()
BOOST_AUTO_TEST_CASE(CreateAdditionFloatWorkload)
{
- RefCreateElementwiseWorkloadTest<RefAdditionWorkload,
+ RefCreateElementwiseWorkloadTest<RefAdditionWorkload<>,
AdditionQueueDescriptor,
AdditionLayer,
armnn::DataType::Float32>();
@@ -99,7 +99,7 @@ BOOST_AUTO_TEST_CASE(CreateAdditionFloatWorkload)
BOOST_AUTO_TEST_CASE(CreateAdditionUint8Workload)
{
- RefCreateElementwiseWorkloadTest<RefAdditionWorkload,
+ RefCreateElementwiseWorkloadTest<RefAdditionWorkload<>,
AdditionQueueDescriptor,
AdditionLayer,
armnn::DataType::QAsymmU8>();
@@ -107,15 +107,23 @@ BOOST_AUTO_TEST_CASE(CreateAdditionUint8Workload)
BOOST_AUTO_TEST_CASE(CreateAdditionInt16Workload)
{
- RefCreateElementwiseWorkloadTest<RefAdditionWorkload,
+ RefCreateElementwiseWorkloadTest<RefAdditionWorkload<>,
AdditionQueueDescriptor,
AdditionLayer,
armnn::DataType::QSymmS16>();
}
+BOOST_AUTO_TEST_CASE(CreateAdditionInt32Workload)
+{
+ RefCreateElementwiseWorkloadTest<RefAdditionWorkload<int32_t>,
+ AdditionQueueDescriptor,
+ AdditionLayer,
+ armnn::DataType::Signed32>();
+}
+
BOOST_AUTO_TEST_CASE(CreateSubtractionFloat32Workload)
{
- RefCreateElementwiseWorkloadTest<RefSubtractionWorkload,
+ RefCreateElementwiseWorkloadTest<RefSubtractionWorkload<>,
SubtractionQueueDescriptor,
SubtractionLayer,
armnn::DataType::Float32>();
@@ -123,7 +131,7 @@ BOOST_AUTO_TEST_CASE(CreateSubtractionFloat32Workload)
BOOST_AUTO_TEST_CASE(CreateSubtractionFloat16Workload)
{
- RefCreateElementwiseWorkloadTest<RefSubtractionWorkload,
+ RefCreateElementwiseWorkloadTest<RefSubtractionWorkload<>,
SubtractionQueueDescriptor,
SubtractionLayer,
armnn::DataType::Float16>();
@@ -131,7 +139,7 @@ BOOST_AUTO_TEST_CASE(CreateSubtractionFloat16Workload)
BOOST_AUTO_TEST_CASE(CreateSubtractionUint8Workload)
{
- RefCreateElementwiseWorkloadTest<RefSubtractionWorkload,
+ RefCreateElementwiseWorkloadTest<RefSubtractionWorkload<>,
SubtractionQueueDescriptor,
SubtractionLayer,
armnn::DataType::QAsymmU8>();
@@ -139,15 +147,23 @@ BOOST_AUTO_TEST_CASE(CreateSubtractionUint8Workload)
BOOST_AUTO_TEST_CASE(CreateSubtractionInt16Workload)
{
- RefCreateElementwiseWorkloadTest<RefSubtractionWorkload,
+ RefCreateElementwiseWorkloadTest<RefSubtractionWorkload<>,
SubtractionQueueDescriptor,
SubtractionLayer,
armnn::DataType::QSymmS16>();
}
+BOOST_AUTO_TEST_CASE(CreateSubtractionInt32Workload)
+{
+ RefCreateElementwiseWorkloadTest<RefSubtractionWorkload<int32_t>,
+ SubtractionQueueDescriptor,
+ SubtractionLayer,
+ armnn::DataType::Signed32>();
+}
+
BOOST_AUTO_TEST_CASE(CreateMultiplicationFloatWorkload)
{
- RefCreateElementwiseWorkloadTest<RefMultiplicationWorkload,
+ RefCreateElementwiseWorkloadTest<RefMultiplicationWorkload<>,
MultiplicationQueueDescriptor,
MultiplicationLayer,
armnn::DataType::Float32>();
@@ -155,7 +171,7 @@ BOOST_AUTO_TEST_CASE(CreateMultiplicationFloatWorkload)
BOOST_AUTO_TEST_CASE(CreateMultiplicationUint8Workload)
{
- RefCreateElementwiseWorkloadTest<RefMultiplicationWorkload,
+ RefCreateElementwiseWorkloadTest<RefMultiplicationWorkload<>,
MultiplicationQueueDescriptor,
MultiplicationLayer,
armnn::DataType::QAsymmU8>();
@@ -163,15 +179,23 @@ BOOST_AUTO_TEST_CASE(CreateMultiplicationUint8Workload)
BOOST_AUTO_TEST_CASE(CreateMultiplicationInt16Workload)
{
- RefCreateElementwiseWorkloadTest<RefMultiplicationWorkload,
+ RefCreateElementwiseWorkloadTest<RefMultiplicationWorkload<>,
MultiplicationQueueDescriptor,
MultiplicationLayer,
armnn::DataType::QSymmS16>();
}
+BOOST_AUTO_TEST_CASE(CreateMultiplicationInt32Workload)
+{
+ RefCreateElementwiseWorkloadTest<RefMultiplicationWorkload<int32_t>,
+ MultiplicationQueueDescriptor,
+ MultiplicationLayer,
+ armnn::DataType::Signed32>();
+}
+
BOOST_AUTO_TEST_CASE(CreateDivisionFloat32Workload)
{
- RefCreateElementwiseWorkloadTest<RefDivisionWorkload,
+ RefCreateElementwiseWorkloadTest<RefDivisionWorkload<>,
DivisionQueueDescriptor,
DivisionLayer,
armnn::DataType::Float32>();
@@ -179,7 +203,7 @@ BOOST_AUTO_TEST_CASE(CreateDivisionFloat32Workload)
BOOST_AUTO_TEST_CASE(CreateDivisionFloat16Workload)
{
- RefCreateElementwiseWorkloadTest<RefDivisionWorkload,
+ RefCreateElementwiseWorkloadTest<RefDivisionWorkload<>,
DivisionQueueDescriptor,
DivisionLayer,
armnn::DataType::Float16>();
@@ -187,7 +211,7 @@ BOOST_AUTO_TEST_CASE(CreateDivisionFloat16Workload)
BOOST_AUTO_TEST_CASE(CreateDivisionUint8Workload)
{
- RefCreateElementwiseWorkloadTest<RefDivisionWorkload,
+ RefCreateElementwiseWorkloadTest<RefDivisionWorkload<>,
DivisionQueueDescriptor,
DivisionLayer,
armnn::DataType::QAsymmU8>();
@@ -195,12 +219,20 @@ BOOST_AUTO_TEST_CASE(CreateDivisionUint8Workload)
BOOST_AUTO_TEST_CASE(CreateDivisionInt16Workload)
{
- RefCreateElementwiseWorkloadTest<RefDivisionWorkload,
+ RefCreateElementwiseWorkloadTest<RefDivisionWorkload<>,
DivisionQueueDescriptor,
DivisionLayer,
armnn::DataType::QSymmS16>();
}
+BOOST_AUTO_TEST_CASE(CreateDivisionInt32Workload)
+{
+ RefCreateElementwiseWorkloadTest<RefDivisionWorkload<int32_t>,
+ DivisionQueueDescriptor,
+ DivisionLayer,
+ armnn::DataType::Signed32>();
+}
+
template <typename BatchNormalizationWorkloadType, armnn::DataType DataType>
static void RefCreateBatchNormalizationWorkloadTest(DataLayout dataLayout)
{
diff --git a/src/backends/reference/workloads/BaseIterator.hpp b/src/backends/reference/workloads/BaseIterator.hpp
index be20644ab7..1f4f2da717 100644
--- a/src/backends/reference/workloads/BaseIterator.hpp
+++ b/src/backends/reference/workloads/BaseIterator.hpp
@@ -274,6 +274,21 @@ public:
}
};
+class Int32ToInt32tDecoder : public TypedIterator<const int32_t, Decoder<int32_t>>
+{
+public:
+ Int32ToInt32tDecoder(const int32_t* data)
+ : TypedIterator(data){}
+
+ Int32ToInt32tDecoder()
+ : Int32ToInt32tDecoder(nullptr) {}
+
+ int32_t Get() const override
+ {
+ return *m_Iterator;
+ }
+};
+
class BooleanDecoder : public TypedIterator<const uint8_t, Decoder<float>>
{
public:
@@ -470,6 +485,26 @@ public:
}
};
+class Int32ToInt32tEncoder : public TypedIterator<int32_t, Encoder<int32_t>>
+{
+public:
+ Int32ToInt32tEncoder(int32_t* data)
+ : TypedIterator(data){}
+
+ Int32ToInt32tEncoder()
+ : Int32ToInt32tEncoder(nullptr) {}
+
+ void Set(int32_t right) override
+ {
+ *m_Iterator = right;
+ }
+
+ int32_t Get() const override
+ {
+ return *m_Iterator;
+ }
+};
+
class BooleanEncoder : public TypedIterator<uint8_t, Encoder<bool>>
{
public:
diff --git a/src/backends/reference/workloads/Decoders.hpp b/src/backends/reference/workloads/Decoders.hpp
index deb3b1f4b2..08e0140fad 100644
--- a/src/backends/reference/workloads/Decoders.hpp
+++ b/src/backends/reference/workloads/Decoders.hpp
@@ -149,4 +149,22 @@ inline std::unique_ptr<Decoder<float>> MakeDecoder(const TensorInfo& info, const
return nullptr;
}
+template<>
+inline std::unique_ptr<Decoder<int32_t>> MakeDecoder(const TensorInfo& info, const void* data)
+{
+ switch(info.GetDataType())
+ {
+ case DataType::Signed32:
+ {
+ return std::make_unique<Int32ToInt32tDecoder>(static_cast<const int32_t*>(data));
+ }
+ default:
+ {
+ ARMNN_ASSERT_MSG(false, "Unsupported Data Type!");
+ break;
+ }
+ }
+ return nullptr;
+}
+
} //namespace armnn
diff --git a/src/backends/reference/workloads/ElementwiseFunction.cpp b/src/backends/reference/workloads/ElementwiseFunction.cpp
index 5687cf5861..afae188bd6 100644
--- a/src/backends/reference/workloads/ElementwiseFunction.cpp
+++ b/src/backends/reference/workloads/ElementwiseFunction.cpp
@@ -46,6 +46,13 @@ template struct armnn::ElementwiseBinaryFunction<std::divides<float>>;
template struct armnn::ElementwiseBinaryFunction<armnn::maximum<float>>;
template struct armnn::ElementwiseBinaryFunction<armnn::minimum<float>>;
+template struct armnn::ElementwiseBinaryFunction<std::plus<int32_t>>;
+template struct armnn::ElementwiseBinaryFunction<std::minus<int32_t>>;
+template struct armnn::ElementwiseBinaryFunction<std::multiplies<int32_t>>;
+template struct armnn::ElementwiseBinaryFunction<std::divides<int32_t>>;
+template struct armnn::ElementwiseBinaryFunction<armnn::maximum<int32_t>>;
+template struct armnn::ElementwiseBinaryFunction<armnn::minimum<int32_t>>;
+
// Comparison
template struct armnn::ElementwiseBinaryFunction<std::equal_to<float>>;
template struct armnn::ElementwiseBinaryFunction<std::greater<float>>;
diff --git a/src/backends/reference/workloads/Encoders.hpp b/src/backends/reference/workloads/Encoders.hpp
index c0524a7719..a2d565ec4a 100644
--- a/src/backends/reference/workloads/Encoders.hpp
+++ b/src/backends/reference/workloads/Encoders.hpp
@@ -114,4 +114,22 @@ inline std::unique_ptr<Encoder<bool>> MakeEncoder(const TensorInfo& info, void*
return nullptr;
}
+template<>
+inline std::unique_ptr<Encoder<int32_t>> MakeEncoder(const TensorInfo& info, void* data)
+{
+ switch(info.GetDataType())
+ {
+ case DataType::Signed32:
+ {
+ return std::make_unique<Int32ToInt32tEncoder>(static_cast<int32_t*>(data));
+ }
+ default:
+ {
+ ARMNN_ASSERT_MSG(false, "Unsupported Data Type!");
+ break;
+ }
+ }
+ return nullptr;
+}
+
} //namespace armnn
diff --git a/src/backends/reference/workloads/RefElementwiseWorkload.cpp b/src/backends/reference/workloads/RefElementwiseWorkload.cpp
index 18bf0a7ad9..60acbd6252 100644
--- a/src/backends/reference/workloads/RefElementwiseWorkload.cpp
+++ b/src/backends/reference/workloads/RefElementwiseWorkload.cpp
@@ -67,22 +67,46 @@ template class armnn::RefElementwiseWorkload<std::plus<float>,
armnn::AdditionQueueDescriptor,
armnn::StringMapping::RefAdditionWorkload_Execute>;
+template class armnn::RefElementwiseWorkload<std::plus<int32_t>,
+ armnn::AdditionQueueDescriptor,
+ armnn::StringMapping::RefAdditionWorkload_Execute>;
+
template class armnn::RefElementwiseWorkload<std::minus<float>,
armnn::SubtractionQueueDescriptor,
armnn::StringMapping::RefSubtractionWorkload_Execute>;
+template class armnn::RefElementwiseWorkload<std::minus<int32_t>,
+ armnn::SubtractionQueueDescriptor,
+ armnn::StringMapping::RefSubtractionWorkload_Execute>;
+
template class armnn::RefElementwiseWorkload<std::multiplies<float>,
armnn::MultiplicationQueueDescriptor,
armnn::StringMapping::RefMultiplicationWorkload_Execute>;
+template class armnn::RefElementwiseWorkload<std::multiplies<int32_t>,
+ armnn::MultiplicationQueueDescriptor,
+ armnn::StringMapping::RefMultiplicationWorkload_Execute>;
+
template class armnn::RefElementwiseWorkload<std::divides<float>,
armnn::DivisionQueueDescriptor,
armnn::StringMapping::RefDivisionWorkload_Execute>;
+template class armnn::RefElementwiseWorkload<std::divides<int32_t>,
+ armnn::DivisionQueueDescriptor,
+ armnn::StringMapping::RefDivisionWorkload_Execute>;
+
template class armnn::RefElementwiseWorkload<armnn::maximum<float>,
armnn::MaximumQueueDescriptor,
armnn::StringMapping::RefMaximumWorkload_Execute>;
+template class armnn::RefElementwiseWorkload<armnn::maximum<int32_t>,
+ armnn::MaximumQueueDescriptor,
+ armnn::StringMapping::RefMaximumWorkload_Execute>;
+
template class armnn::RefElementwiseWorkload<armnn::minimum<float>,
armnn::MinimumQueueDescriptor,
armnn::StringMapping::RefMinimumWorkload_Execute>;
+
+template class armnn::RefElementwiseWorkload<armnn::minimum<int32_t>,
+ armnn::MinimumQueueDescriptor,
+ armnn::StringMapping::RefMinimumWorkload_Execute>;
diff --git a/src/backends/reference/workloads/RefElementwiseWorkload.hpp b/src/backends/reference/workloads/RefElementwiseWorkload.hpp
index 264ddce2de..03683b1a06 100644
--- a/src/backends/reference/workloads/RefElementwiseWorkload.hpp
+++ b/src/backends/reference/workloads/RefElementwiseWorkload.hpp
@@ -35,33 +35,39 @@ private:
std::unique_ptr<Encoder<OutType>> m_Output;
};
+template <typename DataType = float>
using RefAdditionWorkload =
- RefElementwiseWorkload<std::plus<float>,
+ RefElementwiseWorkload<std::plus<DataType>,
AdditionQueueDescriptor,
StringMapping::RefAdditionWorkload_Execute>;
+template <typename DataType = float>
using RefSubtractionWorkload =
- RefElementwiseWorkload<std::minus<float>,
+ RefElementwiseWorkload<std::minus<DataType>,
SubtractionQueueDescriptor,
StringMapping::RefSubtractionWorkload_Execute>;
+template <typename DataType = float>
using RefMultiplicationWorkload =
- RefElementwiseWorkload<std::multiplies<float>,
+ RefElementwiseWorkload<std::multiplies<DataType>,
MultiplicationQueueDescriptor,
StringMapping::RefMultiplicationWorkload_Execute>;
+template <typename DataType = float>
using RefDivisionWorkload =
- RefElementwiseWorkload<std::divides<float>,
+ RefElementwiseWorkload<std::divides<DataType>,
DivisionQueueDescriptor,
StringMapping::RefDivisionWorkload_Execute>;
+template <typename DataType = float>
using RefMaximumWorkload =
- RefElementwiseWorkload<armnn::maximum<float>,
+ RefElementwiseWorkload<armnn::maximum<DataType>,
MaximumQueueDescriptor,
StringMapping::RefMaximumWorkload_Execute>;
+template <typename DataType = float>
using RefMinimumWorkload =
- RefElementwiseWorkload<armnn::minimum<float>,
+ RefElementwiseWorkload<armnn::minimum<DataType>,
MinimumQueueDescriptor,
StringMapping::RefMinimumWorkload_Execute>;