diff options
author | kevmay01 <kevin.may@arm.com> | 2019-01-24 14:05:09 +0000 |
---|---|---|
committer | kevmay01 <kevin.may@arm.com> | 2019-01-24 14:05:09 +0000 |
commit | 2b4d88e34ac1f965417fd236fd4786f26bae2042 (patch) | |
tree | 4518b52c6a22e33c4b467588a2843c9d5f1a9ee6 /src/backends/reference | |
parent | 94412aff782472be54dce4328e2ecee0225b3e97 (diff) | |
download | armnn-2b4d88e34ac1f965417fd236fd4786f26bae2042.tar.gz |
IVGCVSW-2503 Refactor RefElementwiseWorkload around Equal and Greater
* Remove Equal and Greater from RefElementwiseWorkload
* Create RefComparisonWorkload and add Equal and Greater
* Update ElementwiseFunction for different input/output types
* Update TfParser to create Equal/Greater with Boolean output
* Update relevant tests to check for Boolean comparison
Change-Id: I299b7f2121769c960ac0c6139764a5f3c89c9c32
Diffstat (limited to 'src/backends/reference')
12 files changed, 252 insertions, 89 deletions
diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp index 45f108c2f8..78e44bd6a3 100644 --- a/src/backends/reference/RefLayerSupport.cpp +++ b/src/backends/reference/RefLayerSupport.cpp @@ -35,6 +35,7 @@ bool IsSupportedForDataTypeRef(Optional<std::string&> reasonIfUnsupported, floatFuncPtr, uint8FuncPtr, &FalseFunc<Params...>, + &FalseFunc<Params...>, std::forward<Params>(params)...); } @@ -111,7 +112,8 @@ bool RefLayerSupport::IsConstantSupported(const TensorInfo& output, &FalseFunc<>, &TrueFunc<>, &TrueFunc<>, - &TrueFunc<>); + &TrueFunc<>, + &FalseFunc<>); } bool RefLayerSupport::IsConvertFp16ToFp32Supported(const TensorInfo& input, @@ -123,13 +125,15 @@ bool RefLayerSupport::IsConvertFp16ToFp32Supported(const TensorInfo& input, &TrueFunc<>, &FalseInputFuncF32<>, &FalseFuncU8<>, - &FalseFuncI32<>) && + &FalseFuncI32<>, + &FalseFuncU8<>) && IsSupportedForDataTypeGeneric(reasonIfUnsupported, output.GetDataType(), &FalseOutputFuncF16<>, &TrueFunc<>, &FalseFuncU8<>, - &FalseFuncI32<>)); + &FalseFuncI32<>, + &FalseFuncU8<>)); } bool RefLayerSupport::IsConvertFp32ToFp16Supported(const TensorInfo& input, @@ -141,13 +145,15 @@ bool RefLayerSupport::IsConvertFp32ToFp16Supported(const TensorInfo& input, &FalseInputFuncF16<>, &TrueFunc<>, &FalseFuncU8<>, - &FalseFuncI32<>) && + &FalseFuncI32<>, + &FalseFuncU8<>) && IsSupportedForDataTypeGeneric(reasonIfUnsupported, output.GetDataType(), &TrueFunc<>, &FalseOutputFuncF32<>, &FalseFuncU8<>, - &FalseFuncI32<>)); + &FalseFuncI32<>, + &FalseFuncU8<>)); } bool RefLayerSupport::IsConvolution2dSupported(const TensorInfo& input, @@ -415,10 +421,13 @@ bool RefLayerSupport::IsMemCopySupported(const TensorInfo &input, Optional<std::string &> reasonIfUnsupported) const { ignore_unused(output); - return IsSupportedForDataTypeRef(reasonIfUnsupported, - input.GetDataType(), - &TrueFunc<>, - &TrueFunc<>); + return IsSupportedForDataTypeGeneric(reasonIfUnsupported, + input.GetDataType(), + &TrueFunc<>, + &TrueFunc<>, + &TrueFunc<>, + &FalseFuncI32<>, + &TrueFunc<>); } bool RefLayerSupport::IsMinimumSupported(const TensorInfo& input0, @@ -463,10 +472,13 @@ bool RefLayerSupport::IsNormalizationSupported(const TensorInfo& input, bool RefLayerSupport::IsOutputSupported(const TensorInfo& output, Optional<std::string&> reasonIfUnsupported) const { - return IsSupportedForDataTypeRef(reasonIfUnsupported, - output.GetDataType(), - &TrueFunc<>, - &TrueFunc<>); + return IsSupportedForDataTypeGeneric(reasonIfUnsupported, + output.GetDataType(), + &TrueFunc<>, + &TrueFunc<>, + &TrueFunc<>, + &FalseFuncI32<>, + &TrueFunc<>); } bool RefLayerSupport::IsPadSupported(const TensorInfo& input, diff --git a/src/backends/reference/RefWorkloadFactory.cpp b/src/backends/reference/RefWorkloadFactory.cpp index b112e9dd6a..75a9efd70f 100644 --- a/src/backends/reference/RefWorkloadFactory.cpp +++ b/src/backends/reference/RefWorkloadFactory.cpp @@ -24,7 +24,8 @@ template <typename F32Workload, typename U8Workload, typename QueueDescriptorTyp std::unique_ptr<IWorkload> RefWorkloadFactory::MakeWorkload(const QueueDescriptorType& descriptor, const WorkloadInfo& info) const { - return armnn::MakeWorkloadHelper<NullWorkload, F32Workload, U8Workload, NullWorkload>(descriptor, info); + return armnn::MakeWorkloadHelper<NullWorkload, F32Workload, U8Workload, NullWorkload, NullWorkload>(descriptor, + info); } RefWorkloadFactory::RefWorkloadFactory() @@ -90,7 +91,8 @@ std::unique_ptr<IWorkload> RefWorkloadFactory::CreateOutput(const OutputQueueDes throw InvalidArgumentException("RefWorkloadFactory::CreateOutput: data input and output differ in byte count."); } - return MakeWorkload<CopyMemGenericWorkload, CopyMemGenericWorkload>(descriptor, info); + return MakeWorkloadHelper<CopyMemGenericWorkload, CopyMemGenericWorkload, + CopyMemGenericWorkload, NullWorkload, CopyMemGenericWorkload>(descriptor, info); } std::unique_ptr<IWorkload> RefWorkloadFactory::CreateActivation(const ActivationQueueDescriptor& descriptor, @@ -127,7 +129,7 @@ std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreatePermute(const Permut const WorkloadInfo& info) const { return MakeWorkloadHelper<RefPermuteFloat16Workload, RefPermuteFloat32Workload, RefPermuteUint8Workload, - NullWorkload>(descriptor, info); + NullWorkload, NullWorkload>(descriptor, info); } std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreatePooling2d(const Pooling2dQueueDescriptor& descriptor, @@ -206,7 +208,7 @@ std::unique_ptr<IWorkload> RefWorkloadFactory::CreateConstant(const ConstantQueu const WorkloadInfo& info) const { return MakeWorkloadHelper<NullWorkload, RefConstantFloat32Workload, RefConstantUint8Workload, - RefConstantInt32Workload>(descriptor, info); + RefConstantInt32Workload, NullWorkload>(descriptor, info); } std::unique_ptr<IWorkload> RefWorkloadFactory::CreateReshape(const ReshapeQueueDescriptor& descriptor, diff --git a/src/backends/reference/backend.mk b/src/backends/reference/backend.mk index 763f26e18c..3ee07913dc 100644 --- a/src/backends/reference/backend.mk +++ b/src/backends/reference/backend.mk @@ -28,6 +28,7 @@ BACKEND_SOURCES := \ workloads/RefBatchNormalizationUint8Workload.cpp \ workloads/RefBatchToSpaceNdFloat32Workload.cpp \ workloads/RefBatchToSpaceNdUint8Workload.cpp \ + workloads/RefComparisonWorkload.cpp \ workloads/RefConstantWorkload.cpp \ workloads/RefConvertFp16ToFp32Workload.cpp \ workloads/RefConvertFp32ToFp16Workload.cpp \ diff --git a/src/backends/reference/test/RefEndToEndTests.cpp b/src/backends/reference/test/RefEndToEndTests.cpp index 330f406265..802167a3a0 100644 --- a/src/backends/reference/test/RefEndToEndTests.cpp +++ b/src/backends/reference/test/RefEndToEndTests.cpp @@ -315,18 +315,22 @@ BOOST_AUTO_TEST_CASE(TrivialMin) BOOST_AUTO_TEST_CASE(RefEqualSimpleEndToEndTest) { - const std::vector<float > expectedOutput({ 1, 1, 1, 1, 0, 0, 0, 0, - 0, 0, 0, 0, 1, 1, 1, 1 }); + const std::vector<uint8_t> expectedOutput({ 1, 1, 1, 1, 0, 0, 0, 0, + 0, 0, 0, 0, 1, 1, 1, 1 }); - ArithmeticSimpleEndToEnd<armnn::DataType::Float32>(defaultBackends, LayerType::Equal, expectedOutput); + ArithmeticSimpleEndToEnd<armnn::DataType::Float32, armnn::DataType::Boolean>(defaultBackends, + LayerType::Equal, + expectedOutput); } BOOST_AUTO_TEST_CASE(RefGreaterSimpleEndToEndTest) { - const std::vector<float> expectedOutput({ 0, 0, 0, 0, 1, 1, 1, 1, - 0, 0, 0, 0, 0, 0, 0, 0 }); + const std::vector<uint8_t> expectedOutput({ 0, 0, 0, 0, 1, 1, 1, 1, + 0, 0, 0, 0, 0, 0, 0, 0 }); - ArithmeticSimpleEndToEnd<armnn::DataType::Float32>(defaultBackends, LayerType::Greater, expectedOutput); + ArithmeticSimpleEndToEnd<armnn::DataType::Float32, armnn::DataType::Boolean>(defaultBackends, + LayerType::Greater, + expectedOutput); } BOOST_AUTO_TEST_CASE(RefEqualSimpleEndToEndUint8Test) @@ -334,7 +338,9 @@ BOOST_AUTO_TEST_CASE(RefEqualSimpleEndToEndUint8Test) const std::vector<uint8_t> expectedOutput({ 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1 }); - ArithmeticSimpleEndToEnd<armnn::DataType::QuantisedAsymm8>(defaultBackends, LayerType::Equal, expectedOutput); + ArithmeticSimpleEndToEnd<armnn::DataType::QuantisedAsymm8, armnn::DataType::Boolean>(defaultBackends, + LayerType::Equal, + expectedOutput); } BOOST_AUTO_TEST_CASE(RefGreaterSimpleEndToEndUint8Test) @@ -342,23 +348,29 @@ BOOST_AUTO_TEST_CASE(RefGreaterSimpleEndToEndUint8Test) const std::vector<uint8_t> expectedOutput({ 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0 }); - ArithmeticSimpleEndToEnd<armnn::DataType::QuantisedAsymm8>(defaultBackends, LayerType::Greater, expectedOutput); + ArithmeticSimpleEndToEnd<armnn::DataType::QuantisedAsymm8, armnn::DataType::Boolean>(defaultBackends, + LayerType::Greater, + expectedOutput); } BOOST_AUTO_TEST_CASE(RefEqualBroadcastEndToEndTest) { - const std::vector<float > expectedOutput({ 1, 0, 1, 1, 0, 0, - 0, 0, 0, 0, 0, 0 }); + const std::vector<uint8_t> expectedOutput({ 1, 0, 1, 1, 0, 0, + 0, 0, 0, 0, 0, 0 }); - ArithmeticBroadcastEndToEnd<armnn::DataType::Float32>(defaultBackends, LayerType::Equal, expectedOutput); + ArithmeticBroadcastEndToEnd<armnn::DataType::Float32, armnn::DataType::Boolean>(defaultBackends, + LayerType::Equal, + expectedOutput); } BOOST_AUTO_TEST_CASE(RefGreaterBroadcastEndToEndTest) { - const std::vector<float> expectedOutput({ 0, 1, 0, 0, 0, 1, - 1, 1, 1, 1, 1, 1 }); + const std::vector<uint8_t> expectedOutput({ 0, 1, 0, 0, 0, 1, + 1, 1, 1, 1, 1, 1 }); - ArithmeticBroadcastEndToEnd<armnn::DataType::Float32>(defaultBackends, LayerType::Greater, expectedOutput); + ArithmeticBroadcastEndToEnd<armnn::DataType::Float32, armnn::DataType::Boolean>(defaultBackends, + LayerType::Greater, + expectedOutput); } BOOST_AUTO_TEST_CASE(RefEqualBroadcastEndToEndUint8Test) @@ -366,7 +378,9 @@ BOOST_AUTO_TEST_CASE(RefEqualBroadcastEndToEndUint8Test) const std::vector<uint8_t > expectedOutput({ 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0 }); - ArithmeticBroadcastEndToEnd<armnn::DataType::QuantisedAsymm8>(defaultBackends, LayerType::Equal, expectedOutput); + ArithmeticBroadcastEndToEnd<armnn::DataType::QuantisedAsymm8, armnn::DataType::Boolean>(defaultBackends, + LayerType::Equal, + expectedOutput); } BOOST_AUTO_TEST_CASE(RefGreaterBroadcastEndToEndUint8Test) @@ -374,7 +388,9 @@ BOOST_AUTO_TEST_CASE(RefGreaterBroadcastEndToEndUint8Test) const std::vector<uint8_t> expectedOutput({ 0, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1 }); - ArithmeticBroadcastEndToEnd<armnn::DataType::QuantisedAsymm8>(defaultBackends, LayerType::Greater, expectedOutput); + ArithmeticBroadcastEndToEnd<armnn::DataType::QuantisedAsymm8, armnn::DataType::Boolean>(defaultBackends, + LayerType::Greater, + expectedOutput); } BOOST_AUTO_TEST_CASE(RefMergerEndToEndDim0Test) diff --git a/src/backends/reference/workloads/CMakeLists.txt b/src/backends/reference/workloads/CMakeLists.txt index f95fda08d1..57e89fa456 100644 --- a/src/backends/reference/workloads/CMakeLists.txt +++ b/src/backends/reference/workloads/CMakeLists.txt @@ -40,6 +40,8 @@ list(APPEND armnnRefBackendWorkloads_sources RefBatchToSpaceNdFloat32Workload.hpp RefBatchToSpaceNdUint8Workload.cpp RefBatchToSpaceNdUint8Workload.hpp + RefComparisonWorkload.cpp + RefComparisonWorkload.hpp RefConstantWorkload.cpp RefConstantWorkload.hpp RefConvertFp16ToFp32Workload.cpp diff --git a/src/backends/reference/workloads/ElementwiseFunction.cpp b/src/backends/reference/workloads/ElementwiseFunction.cpp index cb8aa7089c..c8c25ef9e9 100644 --- a/src/backends/reference/workloads/ElementwiseFunction.cpp +++ b/src/backends/reference/workloads/ElementwiseFunction.cpp @@ -13,24 +13,26 @@ namespace armnn { -template <typename Functor> -ElementwiseFunction<Functor>::ElementwiseFunction(const TensorShape& inShape0, - const TensorShape& inShape1, - const TensorShape& outShape, - const float* inData0, - const float* inData1, - float* outData) +template <typename Functor, typename dataTypeInput, typename dataTypeOutput> +ElementwiseFunction<Functor, dataTypeInput, dataTypeOutput>::ElementwiseFunction(const TensorShape& inShape0, + const TensorShape& inShape1, + const TensorShape& outShape, + const dataTypeInput* inData0, + const dataTypeInput* inData1, + dataTypeOutput* outData) { BroadcastLoop(inShape0, inShape1, outShape).Unroll(Functor(), 0, inData0, inData1, outData); } } //namespace armnn -template struct armnn::ElementwiseFunction<std::plus<float>>; -template struct armnn::ElementwiseFunction<std::minus<float>>; -template struct armnn::ElementwiseFunction<std::multiplies<float>>; -template struct armnn::ElementwiseFunction<std::divides<float>>; -template struct armnn::ElementwiseFunction<armnn::maximum<float>>; -template struct armnn::ElementwiseFunction<armnn::minimum<float>>; -template struct armnn::ElementwiseFunction<std::equal_to<float>>; -template struct armnn::ElementwiseFunction<std::greater<float>>; +template struct armnn::ElementwiseFunction<std::plus<float>, float, float>; +template struct armnn::ElementwiseFunction<std::minus<float>, float, float>; +template struct armnn::ElementwiseFunction<std::multiplies<float>, float, float>; +template struct armnn::ElementwiseFunction<std::divides<float>, float, float>; +template struct armnn::ElementwiseFunction<armnn::maximum<float>, float, float>; +template struct armnn::ElementwiseFunction<armnn::minimum<float>, float, float>; +template struct armnn::ElementwiseFunction<std::equal_to<float>, float ,uint8_t>; +template struct armnn::ElementwiseFunction<std::equal_to<uint8_t>, uint8_t, uint8_t>; +template struct armnn::ElementwiseFunction<std::greater<float>, float, uint8_t>; +template struct armnn::ElementwiseFunction<std::greater<uint8_t>, uint8_t, uint8_t>; diff --git a/src/backends/reference/workloads/ElementwiseFunction.hpp b/src/backends/reference/workloads/ElementwiseFunction.hpp index 0ac136466c..8099f3279a 100644 --- a/src/backends/reference/workloads/ElementwiseFunction.hpp +++ b/src/backends/reference/workloads/ElementwiseFunction.hpp @@ -10,15 +10,15 @@ namespace armnn { -template <typename Functor> +template <typename Functor, typename dataTypeInput, typename dataTypeOutput> struct ElementwiseFunction { ElementwiseFunction(const TensorShape& inShape0, const TensorShape& inShape1, const TensorShape& outShape, - const float* inData0, - const float* inData1, - float* outData); + const dataTypeInput* inData0, + const dataTypeInput* inData1, + dataTypeOutput* outData); }; } //namespace armnn diff --git a/src/backends/reference/workloads/RefComparisonWorkload.cpp b/src/backends/reference/workloads/RefComparisonWorkload.cpp new file mode 100644 index 0000000000..fe517ff51a --- /dev/null +++ b/src/backends/reference/workloads/RefComparisonWorkload.cpp @@ -0,0 +1,65 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#include "RefComparisonWorkload.hpp" +#include "ElementwiseFunction.hpp" +#include "RefWorkloadUtils.hpp" +#include "Profiling.hpp" +#include <vector> + +namespace armnn { + +template<typename ParentDescriptor, typename Functor> +void RefFloat32ComparisonWorkload<ParentDescriptor, Functor>::ExecuteImpl(const char* debugString) const +{ + ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, debugString); + + auto data = BaseFloat32ComparisonWorkload<ParentDescriptor>::GetData(); + const TensorShape& inShape0 = GetTensorInfo(data.m_Inputs[0]).GetShape(); + const TensorShape& inShape1 = GetTensorInfo(data.m_Inputs[1]).GetShape(); + const TensorShape& outputShape = GetTensorInfo(data.m_Outputs[0]).GetShape(); + + const float* inData0 = GetInputTensorDataFloat(0, data); + const float* inData1 = GetInputTensorDataFloat(1, data); + uint8_t* outData = GetOutputTensorData<uint8_t>(0, data); + + ElementwiseFunction<Functor, float, uint8_t>(inShape0, + inShape1, + outputShape, + inData0, + inData1, + outData); + +} + +template<typename ParentDescriptor, typename Functor> +void RefUint8ComparisonWorkload<ParentDescriptor, Functor>::ExecuteImpl(const char* debugString) const +{ + ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, debugString); + + auto data = BaseUint8ComparisonWorkload<ParentDescriptor>::GetData(); + const TensorShape& inputInfo0 = GetTensorInfo(data.m_Inputs[0]).GetShape(); + const TensorShape& inputInfo1 = GetTensorInfo(data.m_Inputs[1]).GetShape(); + const TensorShape& outputShape = GetTensorInfo(data.m_Outputs[0]).GetShape(); + + const uint8_t* inData0 = GetInputTensorData<uint8_t>(0, data); + const uint8_t* inData1 = GetInputTensorData<uint8_t>(1, data); + uint8_t* outData = GetOutputTensorData<uint8_t>(0, data); + + ElementwiseFunction<Functor, uint8_t, uint8_t>(inputInfo0, + inputInfo1, + outputShape, + inData0, + inData1, + outData); +} + +} + +template class armnn::RefFloat32ComparisonWorkload<armnn::EqualQueueDescriptor, std::equal_to<float>>; +template class armnn::RefUint8ComparisonWorkload<armnn::EqualQueueDescriptor, std::equal_to<uint8_t>>; + +template class armnn::RefFloat32ComparisonWorkload<armnn::GreaterQueueDescriptor, std::greater<float>>; +template class armnn::RefUint8ComparisonWorkload<armnn::GreaterQueueDescriptor, std::greater<uint8_t>>; diff --git a/src/backends/reference/workloads/RefComparisonWorkload.hpp b/src/backends/reference/workloads/RefComparisonWorkload.hpp new file mode 100644 index 0000000000..524d20625a --- /dev/null +++ b/src/backends/reference/workloads/RefComparisonWorkload.hpp @@ -0,0 +1,92 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#pragma once + +#include <armnn/Types.hpp> +#include <backendsCommon/Workload.hpp> +#include <backendsCommon/WorkloadData.hpp> +#include "StringMapping.hpp" + +namespace armnn +{ + +template <typename Functor, + typename armnn::DataType DataType, + typename ParentDescriptor, + typename armnn::StringMapping::Id DebugString> +class RefComparisonWorkload +{ + // Needs specialization. The default is empty on purpose. +}; + +template <typename ParentDescriptor, typename Functor> +class RefFloat32ComparisonWorkload : public BaseFloat32ComparisonWorkload<ParentDescriptor> +{ +public: + using BaseFloat32ComparisonWorkload<ParentDescriptor>::BaseFloat32ComparisonWorkload; + void ExecuteImpl(const char * debugString) const; +}; + +template <typename Functor, typename ParentDescriptor, typename armnn::StringMapping::Id DebugString> +class RefComparisonWorkload<Functor, armnn::DataType::Float32, ParentDescriptor, DebugString> + : public RefFloat32ComparisonWorkload<ParentDescriptor, Functor> +{ +public: + using RefFloat32ComparisonWorkload<ParentDescriptor, Functor>::RefFloat32ComparisonWorkload; + + virtual void Execute() const override + { + using Parent = RefFloat32ComparisonWorkload<ParentDescriptor, Functor>; + Parent::ExecuteImpl(StringMapping::Instance().Get(DebugString)); + } +}; + +template <typename ParentDescriptor, typename Functor> +class RefUint8ComparisonWorkload : public BaseUint8ComparisonWorkload<ParentDescriptor> +{ +public: + using BaseUint8ComparisonWorkload<ParentDescriptor>::BaseUint8ComparisonWorkload; + void ExecuteImpl(const char * debugString) const; +}; + +template <typename Functor, typename ParentDescriptor, typename armnn::StringMapping::Id DebugString> +class RefComparisonWorkload<Functor, armnn::DataType::QuantisedAsymm8, ParentDescriptor, DebugString> + : public RefUint8ComparisonWorkload<ParentDescriptor, Functor> +{ +public: + using RefUint8ComparisonWorkload<ParentDescriptor, Functor>::RefUint8ComparisonWorkload; + + virtual void Execute() const override + { + using Parent = RefUint8ComparisonWorkload<ParentDescriptor, Functor>; + Parent::ExecuteImpl(StringMapping::Instance().Get(DebugString)); + } +}; + +using RefEqualFloat32Workload = + RefComparisonWorkload<std::equal_to<float>, + DataType::Float32, + EqualQueueDescriptor, + StringMapping::RefEqualWorkload_Execute>; + +using RefEqualUint8Workload = + RefComparisonWorkload<std::equal_to<uint8_t>, + DataType::QuantisedAsymm8, + EqualQueueDescriptor, + StringMapping::RefEqualWorkload_Execute>; + +using RefGreaterFloat32Workload = + RefComparisonWorkload<std::greater<float>, + DataType::Float32, + GreaterQueueDescriptor, + StringMapping::RefGreaterWorkload_Execute>; + +using RefGreaterUint8Workload = + RefComparisonWorkload<std::greater<uint8_t>, + DataType::QuantisedAsymm8, + GreaterQueueDescriptor, + StringMapping::RefGreaterWorkload_Execute>; +} // armnn diff --git a/src/backends/reference/workloads/RefElementwiseWorkload.cpp b/src/backends/reference/workloads/RefElementwiseWorkload.cpp index 13d6e70a96..c9b93c8524 100644 --- a/src/backends/reference/workloads/RefElementwiseWorkload.cpp +++ b/src/backends/reference/workloads/RefElementwiseWorkload.cpp @@ -26,7 +26,7 @@ void BaseFloat32ElementwiseWorkload<ParentDescriptor, Functor>::ExecuteImpl(cons const float* inData1 = GetInputTensorDataFloat(1, data); float* outData = GetOutputTensorDataFloat(0, data); - ElementwiseFunction<Functor>(inShape0, inShape1, outShape, inData0, inData1, outData); + ElementwiseFunction<Functor, float, float>(inShape0, inShape1, outShape, inData0, inData1, outData); } template <typename ParentDescriptor, typename Functor> @@ -44,12 +44,12 @@ void BaseUint8ElementwiseWorkload<ParentDescriptor, Functor>::ExecuteImpl(const std::vector<float> results(outputInfo.GetNumElements()); - ElementwiseFunction<Functor>(inputInfo0.GetShape(), - inputInfo1.GetShape(), - outputInfo.GetShape(), - dequant0.data(), - dequant1.data(), - results.data()); + ElementwiseFunction<Functor, float, float>(inputInfo0.GetShape(), + inputInfo1.GetShape(), + outputInfo.GetShape(), + dequant0.data(), + dequant1.data(), + results.data()); Quantize(GetOutputTensorDataU8(0, data), results.data(), outputInfo); } @@ -73,9 +73,3 @@ template class armnn::BaseUint8ElementwiseWorkload<armnn::MaximumQueueDescriptor template class armnn::BaseFloat32ElementwiseWorkload<armnn::MinimumQueueDescriptor, armnn::minimum<float>>; template class armnn::BaseUint8ElementwiseWorkload<armnn::MinimumQueueDescriptor, armnn::minimum<float>>; - -template class armnn::BaseFloat32ElementwiseWorkload<armnn::EqualQueueDescriptor, std::equal_to<float>>; -template class armnn::BaseUint8ElementwiseWorkload<armnn::EqualQueueDescriptor, std::equal_to<float>>; - -template class armnn::BaseFloat32ElementwiseWorkload<armnn::GreaterQueueDescriptor, std::greater<float>>; -template class armnn::BaseUint8ElementwiseWorkload<armnn::GreaterQueueDescriptor, std::greater<float>>; diff --git a/src/backends/reference/workloads/RefElementwiseWorkload.hpp b/src/backends/reference/workloads/RefElementwiseWorkload.hpp index 6dd6865f53..a5ff376673 100644 --- a/src/backends/reference/workloads/RefElementwiseWorkload.hpp +++ b/src/backends/reference/workloads/RefElementwiseWorkload.hpp @@ -144,28 +144,4 @@ using RefMinimumUint8Workload = DataType::QuantisedAsymm8, MinimumQueueDescriptor, StringMapping::RefMinimumWorkload_Execute>; - -using RefEqualFloat32Workload = - RefElementwiseWorkload<std::equal_to<float>, - DataType::Float32, - EqualQueueDescriptor, - StringMapping::RefEqualWorkload_Execute>; - -using RefEqualUint8Workload = - RefElementwiseWorkload<std::equal_to<float>, - DataType::QuantisedAsymm8, - EqualQueueDescriptor, - StringMapping::RefEqualWorkload_Execute>; - -using RefGreaterFloat32Workload = - RefElementwiseWorkload<std::greater<float>, - DataType::Float32, - GreaterQueueDescriptor, - StringMapping::RefGreaterWorkload_Execute>; - -using RefGreaterUint8Workload = - RefElementwiseWorkload<std::greater<float>, - DataType::QuantisedAsymm8, - GreaterQueueDescriptor, - StringMapping::RefGreaterWorkload_Execute>; } // armnn diff --git a/src/backends/reference/workloads/RefWorkloads.hpp b/src/backends/reference/workloads/RefWorkloads.hpp index 1cbceb366b..d9f4dbb342 100644 --- a/src/backends/reference/workloads/RefWorkloads.hpp +++ b/src/backends/reference/workloads/RefWorkloads.hpp @@ -60,3 +60,4 @@ #include "RefBatchToSpaceNdFloat32Workload.hpp" #include "RefDebugWorkload.hpp" #include "RefRsqrtFloat32Workload.hpp" +#include "RefComparisonWorkload.hpp" |