From cbd2c230b7ce5f26e2ccccf36b7ad450f6e1ad09 Mon Sep 17 00:00:00 2001 From: Finn Williams Date: Mon, 22 Jun 2020 15:58:32 +0100 Subject: IVGCVSW-5007 Implement an Int32 reference Elementwise workload Signed-off-by: Finn Williams Change-Id: I6592169b74ac4294bc09647879aec0718c641f91 --- src/backends/reference/RefWorkloadFactory.cpp | 54 ++++++++++++++++++++++++--- 1 file changed, 48 insertions(+), 6 deletions(-) (limited to 'src/backends/reference/RefWorkloadFactory.cpp') diff --git a/src/backends/reference/RefWorkloadFactory.cpp b/src/backends/reference/RefWorkloadFactory.cpp index 643684c5b0..dcdabe17ff 100644 --- a/src/backends/reference/RefWorkloadFactory.cpp +++ b/src/backends/reference/RefWorkloadFactory.cpp @@ -141,7 +141,14 @@ std::unique_ptr RefWorkloadFactory::CreateActivation(const Activation std::unique_ptr RefWorkloadFactory::CreateAddition(const AdditionQueueDescriptor& descriptor, const WorkloadInfo& info) const { - return std::make_unique(descriptor, info); + if (info.m_InputTensorInfos[0].GetDataType() == armnn::DataType::Signed32) + { + return std::make_unique>(descriptor, info); + } + else + { + return std::make_unique>(descriptor, info); + } } std::unique_ptr RefWorkloadFactory::CreateArgMinMax(const ArgMinMaxQueueDescriptor& descriptor, @@ -279,7 +286,14 @@ std::unique_ptr RefWorkloadFactory::CreateDetectionPostProcess( std::unique_ptr RefWorkloadFactory::CreateDivision(const DivisionQueueDescriptor& descriptor, const WorkloadInfo& info) const { - return std::make_unique(descriptor, info); + if (info.m_InputTensorInfos[0].GetDataType() == armnn::DataType::Signed32) + { + return std::make_unique>(descriptor, info); + } + else + { + return std::make_unique>(descriptor, info); + } } std::unique_ptr RefWorkloadFactory::CreateElementwiseUnary(const ElementwiseUnaryQueueDescriptor& descriptor, @@ -387,7 +401,14 @@ std::unique_ptr RefWorkloadFactory::CreateLstm(const LstmQueueDescrip std::unique_ptr RefWorkloadFactory::CreateMaximum(const MaximumQueueDescriptor& descriptor, const WorkloadInfo& info) const { - return std::make_unique(descriptor, info); + if (info.m_InputTensorInfos[0].GetDataType() == armnn::DataType::Signed32) + { + return std::make_unique>(descriptor, info); + } + else + { + return std::make_unique>(descriptor, info); + } } std::unique_ptr RefWorkloadFactory::CreateMean(const MeanQueueDescriptor& descriptor, @@ -425,13 +446,27 @@ std::unique_ptr RefWorkloadFactory::CreateMerger(const MergerQueueDes std::unique_ptr RefWorkloadFactory::CreateMinimum(const MinimumQueueDescriptor& descriptor, const WorkloadInfo& info) const { - return std::make_unique(descriptor, info); + if (info.m_InputTensorInfos[0].GetDataType() == armnn::DataType::Signed32) + { + return std::make_unique>(descriptor, info); + } + else + { + return std::make_unique>(descriptor, info); + } } std::unique_ptr RefWorkloadFactory::CreateMultiplication(const MultiplicationQueueDescriptor& descriptor, const WorkloadInfo& info) const { - return std::make_unique(descriptor, info); + if (info.m_InputTensorInfos[0].GetDataType() == armnn::DataType::Signed32) + { + return std::make_unique>(descriptor, info); + } + else + { + return std::make_unique>(descriptor, info); + } } std::unique_ptr RefWorkloadFactory::CreateNormalization(const NormalizationQueueDescriptor& descriptor, @@ -593,7 +628,14 @@ std::unique_ptr RefWorkloadFactory::CreateStridedSlice(const StridedS std::unique_ptr RefWorkloadFactory::CreateSubtraction(const SubtractionQueueDescriptor& descriptor, const WorkloadInfo& info) const { - return std::make_unique(descriptor, info); + if (info.m_InputTensorInfos[0].GetDataType() == armnn::DataType::Signed32) + { + return std::make_unique>(descriptor, info); + } + else + { + return std::make_unique>(descriptor, info); + } } std::unique_ptr RefWorkloadFactory::CreateTranspose(const TransposeQueueDescriptor& descriptor, -- cgit v1.2.1