aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference/RefWorkloadFactory.cpp
diff options
context:
space:
mode:
authorFinn Williams <Finn.Williams@arm.com>2020-06-22 15:58:32 +0100
committerTeresaARM <teresa.charlinreyes@arm.com>2020-06-30 13:00:52 +0000
commitcbd2c230b7ce5f26e2ccccf36b7ad450f6e1ad09 (patch)
treeeb6e5393726be21213e72f26a676b7c3809fc995 /src/backends/reference/RefWorkloadFactory.cpp
parent532a29d12d72f54549d8b71edd485c17af65698a (diff)
downloadarmnn-cbd2c230b7ce5f26e2ccccf36b7ad450f6e1ad09.tar.gz
IVGCVSW-5007 Implement an Int32 reference Elementwise workload
Signed-off-by: Finn Williams <Finn.Williams@arm.com> Change-Id: I6592169b74ac4294bc09647879aec0718c641f91
Diffstat (limited to 'src/backends/reference/RefWorkloadFactory.cpp')
-rw-r--r--src/backends/reference/RefWorkloadFactory.cpp54
1 files changed, 48 insertions, 6 deletions
diff --git a/src/backends/reference/RefWorkloadFactory.cpp b/src/backends/reference/RefWorkloadFactory.cpp
index 643684c5b0..dcdabe17ff 100644
--- a/src/backends/reference/RefWorkloadFactory.cpp
+++ b/src/backends/reference/RefWorkloadFactory.cpp
@@ -141,7 +141,14 @@ std::unique_ptr<IWorkload> RefWorkloadFactory::CreateActivation(const Activation
std::unique_ptr<IWorkload> RefWorkloadFactory::CreateAddition(const AdditionQueueDescriptor& descriptor,
const WorkloadInfo& info) const
{
- return std::make_unique<RefAdditionWorkload>(descriptor, info);
+ if (info.m_InputTensorInfos[0].GetDataType() == armnn::DataType::Signed32)
+ {
+ return std::make_unique<RefAdditionWorkload<int32_t>>(descriptor, info);
+ }
+ else
+ {
+ return std::make_unique<RefAdditionWorkload<float>>(descriptor, info);
+ }
}
std::unique_ptr<IWorkload> RefWorkloadFactory::CreateArgMinMax(const ArgMinMaxQueueDescriptor& descriptor,
@@ -279,7 +286,14 @@ std::unique_ptr<IWorkload> RefWorkloadFactory::CreateDetectionPostProcess(
std::unique_ptr<IWorkload> RefWorkloadFactory::CreateDivision(const DivisionQueueDescriptor& descriptor,
const WorkloadInfo& info) const
{
- return std::make_unique<RefDivisionWorkload>(descriptor, info);
+ if (info.m_InputTensorInfos[0].GetDataType() == armnn::DataType::Signed32)
+ {
+ return std::make_unique<RefDivisionWorkload<int32_t>>(descriptor, info);
+ }
+ else
+ {
+ return std::make_unique<RefDivisionWorkload<float>>(descriptor, info);
+ }
}
std::unique_ptr<IWorkload> RefWorkloadFactory::CreateElementwiseUnary(const ElementwiseUnaryQueueDescriptor& descriptor,
@@ -387,7 +401,14 @@ std::unique_ptr<IWorkload> RefWorkloadFactory::CreateLstm(const LstmQueueDescrip
std::unique_ptr<IWorkload> RefWorkloadFactory::CreateMaximum(const MaximumQueueDescriptor& descriptor,
const WorkloadInfo& info) const
{
- return std::make_unique<RefMaximumWorkload>(descriptor, info);
+ if (info.m_InputTensorInfos[0].GetDataType() == armnn::DataType::Signed32)
+ {
+ return std::make_unique<RefMaximumWorkload<int32_t>>(descriptor, info);
+ }
+ else
+ {
+ return std::make_unique<RefMaximumWorkload<float>>(descriptor, info);
+ }
}
std::unique_ptr<IWorkload> RefWorkloadFactory::CreateMean(const MeanQueueDescriptor& descriptor,
@@ -425,13 +446,27 @@ std::unique_ptr<IWorkload> RefWorkloadFactory::CreateMerger(const MergerQueueDes
std::unique_ptr<IWorkload> RefWorkloadFactory::CreateMinimum(const MinimumQueueDescriptor& descriptor,
const WorkloadInfo& info) const
{
- return std::make_unique<RefMinimumWorkload>(descriptor, info);
+ if (info.m_InputTensorInfos[0].GetDataType() == armnn::DataType::Signed32)
+ {
+ return std::make_unique<RefMinimumWorkload<int32_t>>(descriptor, info);
+ }
+ else
+ {
+ return std::make_unique<RefMinimumWorkload<float>>(descriptor, info);
+ }
}
std::unique_ptr<IWorkload> RefWorkloadFactory::CreateMultiplication(const MultiplicationQueueDescriptor& descriptor,
const WorkloadInfo& info) const
{
- return std::make_unique<RefMultiplicationWorkload>(descriptor, info);
+ if (info.m_InputTensorInfos[0].GetDataType() == armnn::DataType::Signed32)
+ {
+ return std::make_unique<RefMultiplicationWorkload<int32_t>>(descriptor, info);
+ }
+ else
+ {
+ return std::make_unique<RefMultiplicationWorkload<float>>(descriptor, info);
+ }
}
std::unique_ptr<IWorkload> RefWorkloadFactory::CreateNormalization(const NormalizationQueueDescriptor& descriptor,
@@ -593,7 +628,14 @@ std::unique_ptr<IWorkload> RefWorkloadFactory::CreateStridedSlice(const StridedS
std::unique_ptr<IWorkload> RefWorkloadFactory::CreateSubtraction(const SubtractionQueueDescriptor& descriptor,
const WorkloadInfo& info) const
{
- return std::make_unique<RefSubtractionWorkload>(descriptor, info);
+ if (info.m_InputTensorInfos[0].GetDataType() == armnn::DataType::Signed32)
+ {
+ return std::make_unique<RefSubtractionWorkload<int32_t>>(descriptor, info);
+ }
+ else
+ {
+ return std::make_unique<RefSubtractionWorkload<float>>(descriptor, info);
+ }
}
std::unique_ptr<IWorkload> RefWorkloadFactory::CreateTranspose(const TransposeQueueDescriptor& descriptor,