diff options
author | Aron Virginas-Tar <Aron.Virginas-Tar@arm.com> | 2019-10-16 17:45:38 +0100 |
---|---|---|
committer | Áron Virginás-Tar <aron.virginas-tar@arm.com> | 2019-10-21 08:52:04 +0000 |
commit | 77bfb5e32faadb1383d48364a6f54adbff84ad80 (patch) | |
tree | 0bf5dfb48cb8d5c248baf716f02b9f481400316e /src/backends/reference | |
parent | 5884708e650a80e355398532bc320bbabdbb53f4 (diff) | |
download | armnn-77bfb5e32faadb1383d48364a6f54adbff84ad80.tar.gz |
IVGCVSW-3993 Add frontend and reference workload for ComparisonLayer
* Added frontend for ComparisonLayer
* Added RefComparisonWorkload
* Deprecated and removed Equal and Greater layers and workloads
* Updated tests to ensure backward compatibility
Signed-off-by: Aron Virginas-Tar <Aron.Virginas-Tar@arm.com>
Change-Id: Id50c880be1b567c531efff919c0c366d0a71cbe9
Diffstat (limited to 'src/backends/reference')
-rw-r--r-- | src/backends/reference/RefLayerSupport.cpp | 85 | ||||
-rw-r--r-- | src/backends/reference/RefLayerSupport.hpp | 8 | ||||
-rw-r--r-- | src/backends/reference/RefWorkloadFactory.cpp | 16 | ||||
-rw-r--r-- | src/backends/reference/RefWorkloadFactory.hpp | 5 | ||||
-rw-r--r-- | src/backends/reference/backend.mk | 1 | ||||
-rw-r--r-- | src/backends/reference/test/RefEndToEndTests.cpp | 50 | ||||
-rw-r--r-- | src/backends/reference/workloads/CMakeLists.txt | 2 | ||||
-rw-r--r-- | src/backends/reference/workloads/ElementwiseFunction.cpp | 7 | ||||
-rw-r--r-- | src/backends/reference/workloads/RefComparisonWorkload.cpp | 102 | ||||
-rw-r--r-- | src/backends/reference/workloads/RefComparisonWorkload.hpp | 34 | ||||
-rw-r--r-- | src/backends/reference/workloads/RefElementwiseWorkload.cpp | 8 | ||||
-rw-r--r-- | src/backends/reference/workloads/RefElementwiseWorkload.hpp | 9 | ||||
-rw-r--r-- | src/backends/reference/workloads/RefWorkloads.hpp | 1 | ||||
-rw-r--r-- | src/backends/reference/workloads/StringMapping.hpp | 4 |
14 files changed, 237 insertions, 95 deletions
diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp index 9342b29f47..c65886ba4d 100644 --- a/src/backends/reference/RefLayerSupport.cpp +++ b/src/backends/reference/RefLayerSupport.cpp @@ -308,6 +308,35 @@ bool RefLayerSupport::IsBatchToSpaceNdSupported(const TensorInfo& input, return supported; } +bool RefLayerSupport::IsComparisonSupported(const TensorInfo& input0, + const TensorInfo& input1, + const TensorInfo& output, + const ComparisonDescriptor& descriptor, + Optional<std::string&> reasonIfUnsupported) const +{ + boost::ignore_unused(descriptor); + + std::array<DataType, 4> supportedInputTypes = + { + DataType::Float32, + DataType::Float16, + DataType::QuantisedAsymm8, + DataType::QuantisedSymm16 + }; + + bool supported = true; + supported &= CheckSupportRule(TypeAnyOf(input0, supportedInputTypes), reasonIfUnsupported, + "Reference comparison: input 0 is not a supported type"); + + supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported, + "Reference comparison: input 0 and Input 1 types are mismatched"); + + supported &= CheckSupportRule(TypeIs(output, DataType::Boolean), reasonIfUnsupported, + "Reference comparison: output is not of type Boolean"); + + return supported; +} + bool RefLayerSupport::IsConcatSupported(const std::vector<const TensorInfo*> inputs, const TensorInfo& output, const ConcatDescriptor& descriptor, @@ -644,29 +673,11 @@ bool RefLayerSupport::IsEqualSupported(const TensorInfo& input0, const TensorInfo& output, Optional<std::string&> reasonIfUnsupported) const { - bool supported = true; - - std::array<DataType,4> supportedTypes = - { - DataType::Float32, - DataType::Float16, - DataType::QuantisedAsymm8, - DataType::QuantisedSymm16 - }; - - supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported, - "Reference equal: input 0 is not a supported type."); - - supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported, - "Reference equal: input 1 is not a supported type."); - - supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported, - "Reference equal: input 0 and Input 1 types are mismatched"); - - supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported, - "Reference equal: shapes are not suitable for implicit broadcast."); - - return supported; + return IsComparisonSupported(input0, + input1, + output, + ComparisonDescriptor(ComparisonOperation::Equal), + reasonIfUnsupported); } bool RefLayerSupport::IsFakeQuantizationSupported(const TensorInfo& input, @@ -802,29 +813,11 @@ bool RefLayerSupport::IsGreaterSupported(const TensorInfo& input0, const TensorInfo& output, Optional<std::string&> reasonIfUnsupported) const { - bool supported = true; - - std::array<DataType,4> supportedTypes = - { - DataType::Float32, - DataType::Float16, - DataType::QuantisedAsymm8, - DataType::QuantisedSymm16 - }; - - supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported, - "Reference greater: input 0 is not a supported type."); - - supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported, - "Reference greater: input 1 is not a supported type."); - - supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported, - "Reference greater: input 0 and Input 1 types are mismatched"); - - supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported, - "Reference greater: shapes are not suitable for implicit broadcast."); - - return supported; + return IsComparisonSupported(input0, + input1, + output, + ComparisonDescriptor(ComparisonOperation::Greater), + reasonIfUnsupported); } bool RefLayerSupport::IsInputSupported(const TensorInfo& input, diff --git a/src/backends/reference/RefLayerSupport.hpp b/src/backends/reference/RefLayerSupport.hpp index 5c71e8d337..04b355ee0a 100644 --- a/src/backends/reference/RefLayerSupport.hpp +++ b/src/backends/reference/RefLayerSupport.hpp @@ -45,6 +45,12 @@ public: const BatchToSpaceNdDescriptor& descriptor, Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override; + bool IsComparisonSupported(const TensorInfo& input0, + const TensorInfo& input1, + const TensorInfo& output, + const ComparisonDescriptor& descriptor, + Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override; + bool IsConcatSupported(const std::vector<const TensorInfo*> inputs, const TensorInfo& output, const ConcatDescriptor& descriptor, @@ -106,6 +112,7 @@ public: const TensorInfo& output, Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override; + ARMNN_DEPRECATED_MSG("Use IsComparisonSupported instead") bool IsEqualSupported(const TensorInfo& input0, const TensorInfo& input1, const TensorInfo& output, @@ -131,6 +138,7 @@ public: const TensorInfo& output, Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override; + ARMNN_DEPRECATED_MSG("Use IsComparisonSupported instead") bool IsGreaterSupported(const TensorInfo& input0, const TensorInfo& input1, const TensorInfo& output, diff --git a/src/backends/reference/RefWorkloadFactory.cpp b/src/backends/reference/RefWorkloadFactory.cpp index 1f6d1d7e8b..c2cb51abf3 100644 --- a/src/backends/reference/RefWorkloadFactory.cpp +++ b/src/backends/reference/RefWorkloadFactory.cpp @@ -131,6 +131,12 @@ std::unique_ptr<IWorkload> RefWorkloadFactory::CreateBatchToSpaceNd(const BatchT return std::make_unique<RefBatchToSpaceNdWorkload>(descriptor, info); } +std::unique_ptr<IWorkload> RefWorkloadFactory::CreateComparison(const ComparisonQueueDescriptor& descriptor, + const WorkloadInfo& info) const +{ + return std::make_unique<RefComparisonWorkload>(descriptor, info); +} + std::unique_ptr<IWorkload> RefWorkloadFactory::CreateConcat(const ConcatQueueDescriptor& descriptor, const WorkloadInfo& info) const { @@ -208,7 +214,10 @@ std::unique_ptr<IWorkload> RefWorkloadFactory::CreateDivision(const DivisionQueu std::unique_ptr<IWorkload> RefWorkloadFactory::CreateEqual(const EqualQueueDescriptor& descriptor, const WorkloadInfo& info) const { - return std::make_unique<RefEqualWorkload>(descriptor, info); + ComparisonQueueDescriptor comparisonDescriptor; + comparisonDescriptor.m_Parameters.m_Operation = ComparisonOperation::Equal; + + return CreateComparison(comparisonDescriptor, info); } std::unique_ptr<IWorkload> RefWorkloadFactory::CreateFakeQuantization( @@ -240,7 +249,10 @@ std::unique_ptr<IWorkload> RefWorkloadFactory::CreateGather(const GatherQueueDes std::unique_ptr<IWorkload> RefWorkloadFactory::CreateGreater(const GreaterQueueDescriptor& descriptor, const WorkloadInfo& info) const { - return std::make_unique<RefGreaterWorkload>(descriptor, info); + ComparisonQueueDescriptor comparisonDescriptor; + comparisonDescriptor.m_Parameters.m_Operation = ComparisonOperation::Greater; + + return CreateComparison(comparisonDescriptor, info); } std::unique_ptr<IWorkload> RefWorkloadFactory::CreateInput(const InputQueueDescriptor& descriptor, diff --git a/src/backends/reference/RefWorkloadFactory.hpp b/src/backends/reference/RefWorkloadFactory.hpp index 41e9b28ea2..7b73d5b21f 100644 --- a/src/backends/reference/RefWorkloadFactory.hpp +++ b/src/backends/reference/RefWorkloadFactory.hpp @@ -78,6 +78,9 @@ public: std::unique_ptr<IWorkload> CreateBatchToSpaceNd(const BatchToSpaceNdQueueDescriptor& descriptor, const WorkloadInfo& info) const override; + std::unique_ptr<IWorkload> CreateComparison(const ComparisonQueueDescriptor& descriptor, + const WorkloadInfo& info) const override; + std::unique_ptr<IWorkload> CreateConcat(const ConcatQueueDescriptor& descriptor, const WorkloadInfo& info) const override; @@ -111,6 +114,7 @@ public: std::unique_ptr<IWorkload> CreateDivision(const DivisionQueueDescriptor& descriptor, const WorkloadInfo& info) const override; + ARMNN_DEPRECATED_MSG("Use CreateComparison instead") std::unique_ptr<IWorkload> CreateEqual(const EqualQueueDescriptor& descriptor, const WorkloadInfo& info) const override; @@ -126,6 +130,7 @@ public: std::unique_ptr<IWorkload> CreateGather(const GatherQueueDescriptor& descriptor, const WorkloadInfo& info) const override; + ARMNN_DEPRECATED_MSG("Use CreateComparison instead") std::unique_ptr<IWorkload> CreateGreater(const GreaterQueueDescriptor& descriptor, const WorkloadInfo& info) const override; diff --git a/src/backends/reference/backend.mk b/src/backends/reference/backend.mk index 49b07a41d2..7e97acdee2 100644 --- a/src/backends/reference/backend.mk +++ b/src/backends/reference/backend.mk @@ -47,6 +47,7 @@ BACKEND_SOURCES := \ workloads/RefArgMinMaxWorkload.cpp \ workloads/RefBatchNormalizationWorkload.cpp \ workloads/RefBatchToSpaceNdWorkload.cpp \ + workloads/RefComparisonWorkload.cpp \ workloads/RefConcatWorkload.cpp \ workloads/RefConstantWorkload.cpp \ workloads/RefConvertFp16ToFp32Workload.cpp \ diff --git a/src/backends/reference/test/RefEndToEndTests.cpp b/src/backends/reference/test/RefEndToEndTests.cpp index 370ef6599b..1968e4da7e 100644 --- a/src/backends/reference/test/RefEndToEndTests.cpp +++ b/src/backends/reference/test/RefEndToEndTests.cpp @@ -6,8 +6,8 @@ #include <backendsCommon/test/EndToEndTestImpl.hpp> #include <backendsCommon/test/AbsEndToEndTestImpl.hpp> -#include <backendsCommon/test/ArithmeticTestImpl.hpp> #include <backendsCommon/test/BatchToSpaceNdEndToEndTestImpl.hpp> +#include <backendsCommon/test/ComparisonEndToEndTestImpl.hpp> #include <backendsCommon/test/ConcatEndToEndTestImpl.hpp> #include <backendsCommon/test/DepthToSpaceEndToEndTestImpl.hpp> #include <backendsCommon/test/DequantizeEndToEndTestImpl.hpp> @@ -348,9 +348,9 @@ BOOST_AUTO_TEST_CASE(RefEqualSimpleEndToEndTest) const std::vector<uint8_t> expectedOutput({ 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1 }); - ArithmeticSimpleEndToEnd<armnn::DataType::Float32, armnn::DataType::Boolean>(defaultBackends, - LayerType::Equal, - expectedOutput); + ComparisonSimpleEndToEnd<armnn::DataType::Float32>(defaultBackends, + ComparisonOperation::Equal, + expectedOutput); } BOOST_AUTO_TEST_CASE(RefGreaterSimpleEndToEndTest) @@ -358,9 +358,9 @@ BOOST_AUTO_TEST_CASE(RefGreaterSimpleEndToEndTest) const std::vector<uint8_t> expectedOutput({ 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0 }); - ArithmeticSimpleEndToEnd<armnn::DataType::Float32, armnn::DataType::Boolean>(defaultBackends, - LayerType::Greater, - expectedOutput); + ComparisonSimpleEndToEnd<armnn::DataType::Float32>(defaultBackends, + ComparisonOperation::Greater, + expectedOutput); } BOOST_AUTO_TEST_CASE(RefEqualSimpleEndToEndUint8Test) @@ -368,9 +368,9 @@ BOOST_AUTO_TEST_CASE(RefEqualSimpleEndToEndUint8Test) const std::vector<uint8_t> expectedOutput({ 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1 }); - ArithmeticSimpleEndToEnd<armnn::DataType::QuantisedAsymm8, armnn::DataType::Boolean>(defaultBackends, - LayerType::Equal, - expectedOutput); + ComparisonSimpleEndToEnd<armnn::DataType::QuantisedAsymm8>(defaultBackends, + ComparisonOperation::Equal, + expectedOutput); } BOOST_AUTO_TEST_CASE(RefGreaterSimpleEndToEndUint8Test) @@ -378,9 +378,9 @@ BOOST_AUTO_TEST_CASE(RefGreaterSimpleEndToEndUint8Test) const std::vector<uint8_t> expectedOutput({ 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0 }); - ArithmeticSimpleEndToEnd<armnn::DataType::QuantisedAsymm8, armnn::DataType::Boolean>(defaultBackends, - LayerType::Greater, - expectedOutput); + ComparisonSimpleEndToEnd<armnn::DataType::QuantisedAsymm8>(defaultBackends, + ComparisonOperation::Greater, + expectedOutput); } BOOST_AUTO_TEST_CASE(RefEqualBroadcastEndToEndTest) @@ -388,9 +388,9 @@ BOOST_AUTO_TEST_CASE(RefEqualBroadcastEndToEndTest) const std::vector<uint8_t> expectedOutput({ 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0 }); - ArithmeticBroadcastEndToEnd<armnn::DataType::Float32, armnn::DataType::Boolean>(defaultBackends, - LayerType::Equal, - expectedOutput); + ComparisonBroadcastEndToEnd<armnn::DataType::Float32>(defaultBackends, + ComparisonOperation::Equal, + expectedOutput); } BOOST_AUTO_TEST_CASE(RefGreaterBroadcastEndToEndTest) @@ -398,9 +398,9 @@ BOOST_AUTO_TEST_CASE(RefGreaterBroadcastEndToEndTest) const std::vector<uint8_t> expectedOutput({ 0, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1 }); - ArithmeticBroadcastEndToEnd<armnn::DataType::Float32, armnn::DataType::Boolean>(defaultBackends, - LayerType::Greater, - expectedOutput); + ComparisonBroadcastEndToEnd<armnn::DataType::Float32>(defaultBackends, + ComparisonOperation::Greater, + expectedOutput); } BOOST_AUTO_TEST_CASE(RefEqualBroadcastEndToEndUint8Test) @@ -408,9 +408,9 @@ BOOST_AUTO_TEST_CASE(RefEqualBroadcastEndToEndUint8Test) const std::vector<uint8_t > expectedOutput({ 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0 }); - ArithmeticBroadcastEndToEnd<armnn::DataType::QuantisedAsymm8, armnn::DataType::Boolean>(defaultBackends, - LayerType::Equal, - expectedOutput); + ComparisonBroadcastEndToEnd<armnn::DataType::QuantisedAsymm8>(defaultBackends, + ComparisonOperation::Equal, + expectedOutput); } BOOST_AUTO_TEST_CASE(RefGreaterBroadcastEndToEndUint8Test) @@ -418,9 +418,9 @@ BOOST_AUTO_TEST_CASE(RefGreaterBroadcastEndToEndUint8Test) const std::vector<uint8_t> expectedOutput({ 0, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1 }); - ArithmeticBroadcastEndToEnd<armnn::DataType::QuantisedAsymm8, armnn::DataType::Boolean>(defaultBackends, - LayerType::Greater, - expectedOutput); + ComparisonBroadcastEndToEnd<armnn::DataType::QuantisedAsymm8>(defaultBackends, + ComparisonOperation::Greater, + expectedOutput); } BOOST_AUTO_TEST_CASE(RefBatchToSpaceNdEndToEndFloat32NHWCTest) diff --git a/src/backends/reference/workloads/CMakeLists.txt b/src/backends/reference/workloads/CMakeLists.txt index b8eb95c729..7844518620 100644 --- a/src/backends/reference/workloads/CMakeLists.txt +++ b/src/backends/reference/workloads/CMakeLists.txt @@ -63,6 +63,8 @@ list(APPEND armnnRefBackendWorkloads_sources RefBatchNormalizationWorkload.hpp RefBatchToSpaceNdWorkload.cpp RefBatchToSpaceNdWorkload.hpp + RefComparisonWorkload.cpp + RefComparisonWorkload.hpp RefConcatWorkload.cpp RefConcatWorkload.hpp RefConstantWorkload.cpp diff --git a/src/backends/reference/workloads/ElementwiseFunction.cpp b/src/backends/reference/workloads/ElementwiseFunction.cpp index 7a5c071f70..888037f9a6 100644 --- a/src/backends/reference/workloads/ElementwiseFunction.cpp +++ b/src/backends/reference/workloads/ElementwiseFunction.cpp @@ -32,6 +32,11 @@ template struct armnn::ElementwiseFunction<std::multiplies<float>>; template struct armnn::ElementwiseFunction<std::divides<float>>; template struct armnn::ElementwiseFunction<armnn::maximum<float>>; template struct armnn::ElementwiseFunction<armnn::minimum<float>>; + +// Comparison template struct armnn::ElementwiseFunction<std::equal_to<float>>; template struct armnn::ElementwiseFunction<std::greater<float>>; - +template struct armnn::ElementwiseFunction<std::greater_equal<float>>; +template struct armnn::ElementwiseFunction<std::less<float>>; +template struct armnn::ElementwiseFunction<std::less_equal<float>>; +template struct armnn::ElementwiseFunction<std::not_equal_to<float>>; diff --git a/src/backends/reference/workloads/RefComparisonWorkload.cpp b/src/backends/reference/workloads/RefComparisonWorkload.cpp new file mode 100644 index 0000000000..60446226be --- /dev/null +++ b/src/backends/reference/workloads/RefComparisonWorkload.cpp @@ -0,0 +1,102 @@ +// +// Copyright © 2019 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#include "RefComparisonWorkload.hpp" + +#include "Decoders.hpp" +#include "ElementwiseFunction.hpp" +#include "Encoders.hpp" +#include "RefWorkloadUtils.hpp" + +#include <Profiling.hpp> + +#include <armnn/TypesUtils.hpp> + +#include <functional> + +namespace armnn +{ + +RefComparisonWorkload::RefComparisonWorkload(const ComparisonQueueDescriptor& desc, + const WorkloadInfo& info) + : BaseWorkload<ComparisonQueueDescriptor>(desc, info) +{} + +void RefComparisonWorkload::PostAllocationConfigure() +{ + const TensorInfo& inputInfo0 = GetTensorInfo(m_Data.m_Inputs[0]); + const TensorInfo& inputInfo1 = GetTensorInfo(m_Data.m_Inputs[1]); + const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]); + + m_Input0 = MakeDecoder<InType>(inputInfo0); + m_Input1 = MakeDecoder<InType>(inputInfo1); + + m_Output = MakeEncoder<OutType>(outputInfo); +} + +void RefComparisonWorkload::Execute() const +{ + ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefComparisonWorkload_Execute"); + + const TensorInfo& inputInfo0 = GetTensorInfo(m_Data.m_Inputs[0]); + const TensorInfo& inputInfo1 = GetTensorInfo(m_Data.m_Inputs[1]); + const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]); + + const TensorShape& inShape0 = inputInfo0.GetShape(); + const TensorShape& inShape1 = inputInfo1.GetShape(); + const TensorShape& outShape = outputInfo.GetShape(); + + m_Input0->Reset(m_Data.m_Inputs[0]->Map()); + m_Input1->Reset(m_Data.m_Inputs[1]->Map()); + m_Output->Reset(m_Data.m_Outputs[0]->Map()); + + using EqualFunction = ElementwiseFunction<std::equal_to<InType>>; + using GreaterFunction = ElementwiseFunction<std::greater<InType>>; + using GreaterOrEqualFunction = ElementwiseFunction<std::greater_equal<InType>>; + using LessFunction = ElementwiseFunction<std::less<InType>>; + using LessOrEqualFunction = ElementwiseFunction<std::less_equal<InType>>; + using NotEqualFunction = ElementwiseFunction<std::not_equal_to<InType>>; + + switch (m_Data.m_Parameters.m_Operation) + { + case ComparisonOperation::Equal: + { + EqualFunction(inShape0, inShape1, outShape, *m_Input0, *m_Input1, *m_Output); + break; + } + case ComparisonOperation::Greater: + { + GreaterFunction(inShape0, inShape1, outShape, *m_Input0, *m_Input1, *m_Output); + break; + } + case ComparisonOperation::GreaterOrEqual: + { + GreaterOrEqualFunction(inShape0, inShape1, outShape, *m_Input0, *m_Input1, *m_Output); + break; + } + case ComparisonOperation::Less: + { + LessFunction(inShape0, inShape1, outShape, *m_Input0, *m_Input1, *m_Output); + break; + } + case ComparisonOperation::LessOrEqual: + { + LessOrEqualFunction(inShape0, inShape1, outShape, *m_Input0, *m_Input1, *m_Output); + break; + } + case ComparisonOperation::NotEqual: + { + NotEqualFunction(inShape0, inShape1, outShape, *m_Input0, *m_Input1, *m_Output); + break; + } + default: + { + throw InvalidArgumentException(std::string("Unsupported comparison operation ") + + GetComparisonOperationAsCString(m_Data.m_Parameters.m_Operation), CHECK_LOCATION()); + } + } +} + +} // namespace armnn diff --git a/src/backends/reference/workloads/RefComparisonWorkload.hpp b/src/backends/reference/workloads/RefComparisonWorkload.hpp new file mode 100644 index 0000000000..a19e4a0540 --- /dev/null +++ b/src/backends/reference/workloads/RefComparisonWorkload.hpp @@ -0,0 +1,34 @@ +// +// Copyright © 2019 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#pragma once + +#include "BaseIterator.hpp" + +#include <backendsCommon/Workload.hpp> +#include <backendsCommon/WorkloadData.hpp> + +namespace armnn +{ + +class RefComparisonWorkload : public BaseWorkload<ComparisonQueueDescriptor> +{ +public: + using BaseWorkload<ComparisonQueueDescriptor>::m_Data; + + RefComparisonWorkload(const ComparisonQueueDescriptor& descriptor, const WorkloadInfo& info); + void PostAllocationConfigure() override; + void Execute() const override; + +private: + using InType = float; + using OutType = bool; + + std::unique_ptr<Decoder<InType>> m_Input0; + std::unique_ptr<Decoder<InType>> m_Input1; + std::unique_ptr<Encoder<OutType>> m_Output; +}; + +} // namespace armnn diff --git a/src/backends/reference/workloads/RefElementwiseWorkload.cpp b/src/backends/reference/workloads/RefElementwiseWorkload.cpp index 6431348bc2..7e02f032ef 100644 --- a/src/backends/reference/workloads/RefElementwiseWorkload.cpp +++ b/src/backends/reference/workloads/RefElementwiseWorkload.cpp @@ -86,11 +86,3 @@ template class armnn::RefElementwiseWorkload<armnn::maximum<float>, template class armnn::RefElementwiseWorkload<armnn::minimum<float>, armnn::MinimumQueueDescriptor, armnn::StringMapping::RefMinimumWorkload_Execute>; - -template class armnn::RefElementwiseWorkload<std::equal_to<float>, - armnn::EqualQueueDescriptor, - armnn::StringMapping::RefEqualWorkload_Execute>; - -template class armnn::RefElementwiseWorkload<std::greater<float>, - armnn::GreaterQueueDescriptor, - armnn::StringMapping::RefGreaterWorkload_Execute>; diff --git a/src/backends/reference/workloads/RefElementwiseWorkload.hpp b/src/backends/reference/workloads/RefElementwiseWorkload.hpp index 651942e9e5..ee0d80b172 100644 --- a/src/backends/reference/workloads/RefElementwiseWorkload.hpp +++ b/src/backends/reference/workloads/RefElementwiseWorkload.hpp @@ -65,13 +65,4 @@ using RefMinimumWorkload = MinimumQueueDescriptor, StringMapping::RefMinimumWorkload_Execute>; -using RefEqualWorkload = - RefElementwiseWorkload<std::equal_to<float>, - armnn::EqualQueueDescriptor, - armnn::StringMapping::RefEqualWorkload_Execute>; - -using RefGreaterWorkload = - RefElementwiseWorkload<std::greater<float>, - armnn::GreaterQueueDescriptor, - armnn::StringMapping::RefGreaterWorkload_Execute>; } // armnn diff --git a/src/backends/reference/workloads/RefWorkloads.hpp b/src/backends/reference/workloads/RefWorkloads.hpp index 79d1935823..1f9ad4a19a 100644 --- a/src/backends/reference/workloads/RefWorkloads.hpp +++ b/src/backends/reference/workloads/RefWorkloads.hpp @@ -20,6 +20,7 @@ #include "RefArgMinMaxWorkload.hpp" #include "RefBatchNormalizationWorkload.hpp" #include "RefBatchToSpaceNdWorkload.hpp" +#include "RefComparisonWorkload.hpp" #include "RefConvolution2dWorkload.hpp" #include "RefConstantWorkload.hpp" #include "RefConcatWorkload.hpp" diff --git a/src/backends/reference/workloads/StringMapping.hpp b/src/backends/reference/workloads/StringMapping.hpp index 073a5a6833..1654b78088 100644 --- a/src/backends/reference/workloads/StringMapping.hpp +++ b/src/backends/reference/workloads/StringMapping.hpp @@ -18,9 +18,7 @@ struct StringMapping public: enum Id { RefAdditionWorkload_Execute, - RefEqualWorkload_Execute, RefDivisionWorkload_Execute, - RefGreaterWorkload_Execute, RefMaximumWorkload_Execute, RefMinimumWorkload_Execute, RefMultiplicationWorkload_Execute, @@ -40,8 +38,6 @@ private: { m_Strings[RefAdditionWorkload_Execute] = "RefAdditionWorkload_Execute"; m_Strings[RefDivisionWorkload_Execute] = "RefDivisionWorkload_Execute"; - m_Strings[RefEqualWorkload_Execute] = "RefEqualWorkload_Execute"; - m_Strings[RefGreaterWorkload_Execute] = "RefGreaterWorkload_Execute"; m_Strings[RefMaximumWorkload_Execute] = "RefMaximumWorkload_Execute"; m_Strings[RefMinimumWorkload_Execute] = "RefMinimumWorkload_Execute"; m_Strings[RefMultiplicationWorkload_Execute] = "RefMultiplicationWorkload_Execute"; |