From 2b030d9e6d24cfba615f8803047e914b56cb79b5 Mon Sep 17 00:00:00 2001 From: Teresa Charlin Date: Fri, 27 Mar 2020 16:40:56 +0000 Subject: IVGCVSW-4603 Support comparison operators in CL * Deprecate ClGreaterWorkload * Add ClComparisonWorkload to encompass all comparison operators Signed-off-by: Teresa Charlin Change-Id: Ida0ed7f59899d75b0fe7de1e7433b1ade018c6f1 --- src/backends/aclCommon/ArmComputeUtils.hpp | 14 +++++ src/backends/backendsCommon/WorkloadData.hpp | 1 + src/backends/cl/ClLayerSupport.cpp | 18 +++--- src/backends/cl/ClWorkloadFactory.cpp | 9 +-- src/backends/cl/ClWorkloadFactory.hpp | 2 + src/backends/cl/backend.mk | 2 +- src/backends/cl/test/ClLayerTests.cpp | 69 ++++++++++++++++++++++ src/backends/cl/workloads/CMakeLists.txt | 4 +- src/backends/cl/workloads/ClComparisonWorkload.cpp | 62 +++++++++++++++++++ src/backends/cl/workloads/ClComparisonWorkload.hpp | 30 ++++++++++ src/backends/cl/workloads/ClGreaterWorkload.cpp | 59 ------------------ src/backends/cl/workloads/ClGreaterWorkload.hpp | 29 --------- src/backends/cl/workloads/ClWorkloads.hpp | 2 +- src/backends/reference/test/RefLayerTests.cpp | 2 +- 14 files changed, 191 insertions(+), 112 deletions(-) create mode 100644 src/backends/cl/workloads/ClComparisonWorkload.cpp create mode 100644 src/backends/cl/workloads/ClComparisonWorkload.hpp delete mode 100644 src/backends/cl/workloads/ClGreaterWorkload.cpp delete mode 100644 src/backends/cl/workloads/ClGreaterWorkload.hpp diff --git a/src/backends/aclCommon/ArmComputeUtils.hpp b/src/backends/aclCommon/ArmComputeUtils.hpp index c3cfb5cb78..9c6f46462e 100644 --- a/src/backends/aclCommon/ArmComputeUtils.hpp +++ b/src/backends/aclCommon/ArmComputeUtils.hpp @@ -78,6 +78,20 @@ ConvertActivationDescriptorToAclActivationLayerInfo(const ActivationDescriptor& actDesc.m_A, actDesc.m_B); } +inline arm_compute::ComparisonOperation ConvertComparisonOperationToAcl(const ComparisonDescriptor& descriptor) +{ + switch (descriptor.m_Operation) + { + case ComparisonOperation::Greater: return arm_compute::ComparisonOperation::Greater; + case ComparisonOperation::GreaterOrEqual: return arm_compute::ComparisonOperation::GreaterEqual; + case ComparisonOperation::Less: return arm_compute::ComparisonOperation::Less; + case ComparisonOperation::LessOrEqual: return arm_compute::ComparisonOperation::LessEqual; + case ComparisonOperation::Equal: return arm_compute::ComparisonOperation::Equal; + case ComparisonOperation::NotEqual: return arm_compute::ComparisonOperation::NotEqual; + default: throw InvalidArgumentException("Unsupported comparison function"); + } +} + inline arm_compute::PoolingType ConvertPoolingAlgorithmToAclPoolingType(PoolingAlgorithm poolingAlgorithm) { using arm_compute::PoolingType; diff --git a/src/backends/backendsCommon/WorkloadData.hpp b/src/backends/backendsCommon/WorkloadData.hpp index 448de6a1ee..ad71e3c0a9 100644 --- a/src/backends/backendsCommon/WorkloadData.hpp +++ b/src/backends/backendsCommon/WorkloadData.hpp @@ -443,6 +443,7 @@ struct MinimumQueueDescriptor : QueueDescriptor void Validate(const WorkloadInfo& workloadInfo) const; }; +// Deprecated use ComparisonQueueDescriptor instead struct GreaterQueueDescriptor : QueueDescriptor { void Validate(const WorkloadInfo& workloadInfo) const; diff --git a/src/backends/cl/ClLayerSupport.cpp b/src/backends/cl/ClLayerSupport.cpp index 7f7554ab54..12c71c0f70 100644 --- a/src/backends/cl/ClLayerSupport.cpp +++ b/src/backends/cl/ClLayerSupport.cpp @@ -22,6 +22,7 @@ #include "workloads/ClArgMinMaxWorkload.hpp" #include "workloads/ClBatchNormalizationFloatWorkload.hpp" #include "workloads/ClBatchToSpaceNdWorkload.hpp" +#include "workloads/ClComparisonWorkload.hpp" #include "workloads/ClConvertFp16ToFp32Workload.hpp" #include "workloads/ClConvertFp32ToFp16Workload.hpp" #include "workloads/ClConvolution2dWorkload.hpp" @@ -31,7 +32,6 @@ #include "workloads/ClDivisionFloatWorkload.hpp" #include "workloads/ClFloorFloatWorkload.hpp" #include "workloads/ClFullyConnectedWorkload.hpp" -#include "workloads/ClGreaterWorkload.hpp" #include "workloads/ClInstanceNormalizationWorkload.hpp" #include "workloads/ClL2NormalizationFloatWorkload.hpp" #include "workloads/ClLstmFloatWorkload.hpp" @@ -232,16 +232,12 @@ bool ClLayerSupport::IsComparisonSupported(const TensorInfo& input0, const ComparisonDescriptor& descriptor, Optional reasonIfUnsupported) const { - if (descriptor.m_Operation == ComparisonOperation::Greater) - { - FORWARD_WORKLOAD_VALIDATE_FUNC(ClGreaterWorkloadValidate, - reasonIfUnsupported, - input0, - input1, - output); - } - - return false; + FORWARD_WORKLOAD_VALIDATE_FUNC(ClComparisonWorkloadValidate, + reasonIfUnsupported, + input0, + input1, + output, + descriptor); } bool ClLayerSupport::IsConcatSupported(const std::vector inputs, diff --git a/src/backends/cl/ClWorkloadFactory.cpp b/src/backends/cl/ClWorkloadFactory.cpp index ead0bc36a4..b1bd46c4d7 100644 --- a/src/backends/cl/ClWorkloadFactory.cpp +++ b/src/backends/cl/ClWorkloadFactory.cpp @@ -173,14 +173,7 @@ std::unique_ptr ClWorkloadFactory::CreateBatchToSpaceNd(const BatchTo std::unique_ptr ClWorkloadFactory::CreateComparison(const ComparisonQueueDescriptor& descriptor, const WorkloadInfo& info) const { - if (descriptor.m_Parameters.m_Operation == ComparisonOperation::Greater) - { - GreaterQueueDescriptor greaterQueueDescriptor; - greaterQueueDescriptor.m_Inputs = descriptor.m_Inputs; - greaterQueueDescriptor.m_Outputs = descriptor.m_Outputs; - return MakeWorkload(greaterQueueDescriptor, info); - } - return MakeWorkload(descriptor, info); + return MakeWorkload(descriptor, info); } std::unique_ptr ClWorkloadFactory::CreateConcat(const ConcatQueueDescriptor& descriptor, diff --git a/src/backends/cl/ClWorkloadFactory.hpp b/src/backends/cl/ClWorkloadFactory.hpp index a7168010f2..b1c63211d6 100644 --- a/src/backends/cl/ClWorkloadFactory.hpp +++ b/src/backends/cl/ClWorkloadFactory.hpp @@ -96,6 +96,7 @@ public: std::unique_ptr CreateElementwiseUnary(const ElementwiseUnaryQueueDescriptor& descriptor, const WorkloadInfo& info) const override; + ARMNN_DEPRECATED_MSG("Use CreateComparison instead") std::unique_ptr CreateEqual(const EqualQueueDescriptor& descriptor, const WorkloadInfo& info) const override; @@ -108,6 +109,7 @@ public: std::unique_ptr CreateGather(const GatherQueueDescriptor& descriptor, const WorkloadInfo& info) const override; + ARMNN_DEPRECATED_MSG("Use CreateComparison instead") std::unique_ptr CreateGreater(const GreaterQueueDescriptor& descriptor, const WorkloadInfo& info) const override; diff --git a/src/backends/cl/backend.mk b/src/backends/cl/backend.mk index c8da9b714b..3f2e80824d 100644 --- a/src/backends/cl/backend.mk +++ b/src/backends/cl/backend.mk @@ -28,6 +28,7 @@ BACKEND_SOURCES := \ workloads/ClArgMinMaxWorkload.cpp \ workloads/ClBatchNormalizationFloatWorkload.cpp \ workloads/ClBatchToSpaceNdWorkload.cpp \ + workloads/ClComparisonWorkload.cpp \ workloads/ClConcatWorkload.cpp \ workloads/ClConstantWorkload.cpp \ workloads/ClConvertFp16ToFp32Workload.cpp \ @@ -39,7 +40,6 @@ BACKEND_SOURCES := \ workloads/ClDivisionFloatWorkload.cpp \ workloads/ClFloorFloatWorkload.cpp \ workloads/ClFullyConnectedWorkload.cpp \ - workloads/ClGreaterWorkload.cpp \ workloads/ClInstanceNormalizationWorkload.cpp \ workloads/ClL2NormalizationFloatWorkload.cpp \ workloads/ClLstmFloatWorkload.cpp \ diff --git a/src/backends/cl/test/ClLayerTests.cpp b/src/backends/cl/test/ClLayerTests.cpp index df80da215e..509da41f81 100644 --- a/src/backends/cl/test/ClLayerTests.cpp +++ b/src/backends/cl/test/ClLayerTests.cpp @@ -532,15 +532,84 @@ ARMNN_AUTO_TEST_CASE(MinimumBroadcast1Element1, MinimumBroadcast1ElementTest1) ARMNN_AUTO_TEST_CASE(MinimumBroadcast1Element2, MinimumBroadcast1ElementTest2) ARMNN_AUTO_TEST_CASE(MinimumBroadcast1DVectorUint8, MinimumBroadcast1DVectorUint8Test) +// Equal +ARMNN_AUTO_TEST_CASE(EqualSimple, EqualSimpleTest) +ARMNN_AUTO_TEST_CASE(EqualBroadcast1Element, EqualBroadcast1ElementTest) +ARMNN_AUTO_TEST_CASE(EqualBroadcast1dVector, EqualBroadcast1dVectorTest) + +ARMNN_AUTO_TEST_CASE(EqualSimpleFloat16, EqualSimpleFloat16Test) +ARMNN_AUTO_TEST_CASE(EqualBroadcast1ElementFloat16, EqualBroadcast1ElementFloat16Test) +ARMNN_AUTO_TEST_CASE(EqualBroadcast1dVectorFloat16, EqualBroadcast1dVectorFloat16Test) + +ARMNN_AUTO_TEST_CASE(EqualSimpleUint8, EqualSimpleUint8Test) +ARMNN_AUTO_TEST_CASE(EqualBroadcast1ElementUint8, EqualBroadcast1ElementUint8Test) +ARMNN_AUTO_TEST_CASE(EqualBroadcast1dVectorUint8, EqualBroadcast1dVectorUint8Test) + // Greater ARMNN_AUTO_TEST_CASE(GreaterSimple, GreaterSimpleTest) ARMNN_AUTO_TEST_CASE(GreaterBroadcast1Element, GreaterBroadcast1ElementTest) ARMNN_AUTO_TEST_CASE(GreaterBroadcast1dVector, GreaterBroadcast1dVectorTest) +ARMNN_AUTO_TEST_CASE(GreaterSimpleFloat16, GreaterSimpleFloat16Test) +ARMNN_AUTO_TEST_CASE(GreaterBroadcast1ElementFloat16, GreaterBroadcast1ElementFloat16Test) +ARMNN_AUTO_TEST_CASE(GreaterBroadcast1dVectorFloat16, GreaterBroadcast1dVectorFloat16Test) + ARMNN_AUTO_TEST_CASE(GreaterSimpleUint8, GreaterSimpleUint8Test) ARMNN_AUTO_TEST_CASE(GreaterBroadcast1ElementUint8, GreaterBroadcast1ElementUint8Test) ARMNN_AUTO_TEST_CASE(GreaterBroadcast1dVectorUint8, GreaterBroadcast1dVectorUint8Test) +// GreaterOrEqual +ARMNN_AUTO_TEST_CASE(GreaterOrEqualSimple, GreaterOrEqualSimpleTest) +ARMNN_AUTO_TEST_CASE(GreaterOrEqualBroadcast1Element, GreaterOrEqualBroadcast1ElementTest) +ARMNN_AUTO_TEST_CASE(GreaterOrEqualBroadcast1dVector, GreaterOrEqualBroadcast1dVectorTest) + +ARMNN_AUTO_TEST_CASE(GreaterOrEqualSimpleFloat16, GreaterOrEqualSimpleFloat16Test) +ARMNN_AUTO_TEST_CASE(GreaterOrEqualBroadcast1ElementFloat16, GreaterOrEqualBroadcast1ElementFloat16Test) +ARMNN_AUTO_TEST_CASE(GreaterOrEqualBroadcast1dVectorFloat16, GreaterOrEqualBroadcast1dVectorFloat16Test) + +ARMNN_AUTO_TEST_CASE(GreaterOrEqualSimpleUint8, GreaterOrEqualSimpleUint8Test) +ARMNN_AUTO_TEST_CASE(GreaterOrEqualBroadcast1ElementUint8, GreaterOrEqualBroadcast1ElementUint8Test) +ARMNN_AUTO_TEST_CASE(GreaterOrEqualBroadcast1dVectorUint8, GreaterOrEqualBroadcast1dVectorUint8Test) + +// Less +ARMNN_AUTO_TEST_CASE(LessSimple, LessSimpleTest) +ARMNN_AUTO_TEST_CASE(LessBroadcast1Element, LessBroadcast1ElementTest) +ARMNN_AUTO_TEST_CASE(LessBroadcast1dVector, LessBroadcast1dVectorTest) + +ARMNN_AUTO_TEST_CASE(LessSimpleFloat16, LessSimpleFloat16Test) +ARMNN_AUTO_TEST_CASE(LessBroadcast1ElementFloat16, LessBroadcast1ElementFloat16Test) +ARMNN_AUTO_TEST_CASE(LessBroadcast1dVectorFloat16, LessBroadcast1dVectorFloat16Test) + +ARMNN_AUTO_TEST_CASE(LessSimpleUint8, LessSimpleUint8Test) +ARMNN_AUTO_TEST_CASE(LessBroadcast1ElementUint8, LessBroadcast1ElementUint8Test) +ARMNN_AUTO_TEST_CASE(LessBroadcast1dVectorUint8, LessBroadcast1dVectorUint8Test) + +// LessOrEqual +ARMNN_AUTO_TEST_CASE(LessOrEqualSimple, LessOrEqualSimpleTest) +ARMNN_AUTO_TEST_CASE(LessOrEqualBroadcast1Element, LessOrEqualBroadcast1ElementTest) +ARMNN_AUTO_TEST_CASE(LessOrEqualBroadcast1dVector, LessOrEqualBroadcast1dVectorTest) + +ARMNN_AUTO_TEST_CASE(LessOrEqualSimpleFloat16, LessOrEqualSimpleFloat16Test) +ARMNN_AUTO_TEST_CASE(LessOrEqualBroadcast1ElementFloat16, LessOrEqualBroadcast1ElementFloat16Test) +ARMNN_AUTO_TEST_CASE(LessOrEqualBroadcast1dVectorFloat16, LessOrEqualBroadcast1dVectorFloat16Test) + +ARMNN_AUTO_TEST_CASE(LessOrEqualSimpleUint8, LessOrEqualSimpleUint8Test) +ARMNN_AUTO_TEST_CASE(LessOrEqualBroadcast1ElementUint8, LessOrEqualBroadcast1ElementUint8Test) +ARMNN_AUTO_TEST_CASE(LessOrEqualBroadcast1dVectorUint8, LessOrEqualBroadcast1dVectorUint8Test) + +// NotEqual +ARMNN_AUTO_TEST_CASE(NotEqualSimple, NotEqualSimpleTest) +ARMNN_AUTO_TEST_CASE(NotEqualBroadcast1Element, NotEqualBroadcast1ElementTest) +ARMNN_AUTO_TEST_CASE(NotEqualBroadcast1dVector, NotEqualBroadcast1dVectorTest) + +ARMNN_AUTO_TEST_CASE(NotEqualSimpleFloat16, NotEqualSimpleFloat16Test) +ARMNN_AUTO_TEST_CASE(NotEqualBroadcast1ElementFloat16, NotEqualBroadcast1ElementFloat16Test) +ARMNN_AUTO_TEST_CASE(NotEqualBroadcast1dVectorFloat16, NotEqualBroadcast1dVectorFloat16Test) + +ARMNN_AUTO_TEST_CASE(NotEqualSimpleUint8, NotEqualSimpleUint8Test) +ARMNN_AUTO_TEST_CASE(NotEqualBroadcast1ElementUint8, NotEqualBroadcast1ElementUint8Test) +ARMNN_AUTO_TEST_CASE(NotEqualBroadcast1dVectorUint8, NotEqualBroadcast1dVectorUint8Test) + // Softmax ARMNN_AUTO_TEST_CASE(SimpleSoftmaxBeta1, SimpleSoftmaxTest, 1.0f) ARMNN_AUTO_TEST_CASE(SimpleSoftmaxBeta2, SimpleSoftmaxTest, 2.0f) diff --git a/src/backends/cl/workloads/CMakeLists.txt b/src/backends/cl/workloads/CMakeLists.txt index 3f964eb1a6..161ad96361 100644 --- a/src/backends/cl/workloads/CMakeLists.txt +++ b/src/backends/cl/workloads/CMakeLists.txt @@ -16,6 +16,8 @@ list(APPEND armnnClBackendWorkloads_sources ClBatchNormalizationFloatWorkload.hpp ClBatchToSpaceNdWorkload.cpp ClBatchToSpaceNdWorkload.hpp + ClComparisonWorkload.cpp + ClComparisonWorkload.hpp ClConcatWorkload.cpp ClConcatWorkload.hpp ClConstantWorkload.cpp @@ -38,8 +40,6 @@ list(APPEND armnnClBackendWorkloads_sources ClFloorFloatWorkload.hpp ClFullyConnectedWorkload.cpp ClFullyConnectedWorkload.hpp - ClGreaterWorkload.cpp - ClGreaterWorkload.hpp ClInstanceNormalizationWorkload.cpp ClInstanceNormalizationWorkload.hpp ClL2NormalizationFloatWorkload.cpp diff --git a/src/backends/cl/workloads/ClComparisonWorkload.cpp b/src/backends/cl/workloads/ClComparisonWorkload.cpp new file mode 100644 index 0000000000..30b336dd94 --- /dev/null +++ b/src/backends/cl/workloads/ClComparisonWorkload.cpp @@ -0,0 +1,62 @@ +// +// Copyright © 2020 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#include "ClComparisonWorkload.hpp" + +#include "ClWorkloadUtils.hpp" + +#include +#include + +#include + +#include +#include +#include + +namespace armnn +{ + +using namespace armcomputetensorutils; + +arm_compute::Status ClComparisonWorkloadValidate(const TensorInfo& input0, + const TensorInfo& input1, + const TensorInfo& output, + const ComparisonDescriptor& descriptor) +{ + const arm_compute::TensorInfo aclInput0Info = BuildArmComputeTensorInfo(input0); + const arm_compute::TensorInfo aclInput1Info = BuildArmComputeTensorInfo(input1); + const arm_compute::TensorInfo aclOutputInfo = BuildArmComputeTensorInfo(output); + + const arm_compute::ComparisonOperation comparisonOperation = ConvertComparisonOperationToAcl(descriptor); + + const arm_compute::Status aclStatus = arm_compute::CLComparison::validate(&aclInput0Info, + &aclInput1Info, + &aclOutputInfo, + comparisonOperation); + return aclStatus; +} + +ClComparisonWorkload::ClComparisonWorkload(const ComparisonQueueDescriptor& descriptor, const WorkloadInfo& info) + : BaseWorkload(descriptor, info) +{ + m_Data.ValidateInputsOutputs("ClComparisonWorkload", 2, 1); + + arm_compute::ICLTensor& input0 = static_cast(m_Data.m_Inputs[0])->GetTensor(); + arm_compute::ICLTensor& input1 = static_cast(m_Data.m_Inputs[1])->GetTensor(); + arm_compute::ICLTensor& output = static_cast(m_Data.m_Outputs[0])->GetTensor(); + + const arm_compute::ComparisonOperation comparisonOperation = ConvertComparisonOperationToAcl(m_Data.m_Parameters); + + m_ComparisonLayer.configure(&input0, &input1, &output, comparisonOperation); +} + +void ClComparisonWorkload::Execute() const +{ + ARMNN_SCOPED_PROFILING_EVENT_CL("ClComparisonWorkload_Execute"); + RunClFunction(m_ComparisonLayer, CHECK_LOCATION()); +} + +} //namespace armnn diff --git a/src/backends/cl/workloads/ClComparisonWorkload.hpp b/src/backends/cl/workloads/ClComparisonWorkload.hpp new file mode 100644 index 0000000000..e842152fed --- /dev/null +++ b/src/backends/cl/workloads/ClComparisonWorkload.hpp @@ -0,0 +1,30 @@ +// +// Copyright © 2020 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#pragma once + +#include + +#include + +namespace armnn +{ + +arm_compute::Status ClComparisonWorkloadValidate(const TensorInfo& input0, + const TensorInfo& input1, + const TensorInfo& output, + const ComparisonDescriptor& descriptor); + +class ClComparisonWorkload : public BaseWorkload +{ +public: + ClComparisonWorkload(const ComparisonQueueDescriptor& descriptor, const WorkloadInfo& info); + void Execute() const override; + +private: + mutable arm_compute::CLComparison m_ComparisonLayer; +}; + +} //namespace armnn diff --git a/src/backends/cl/workloads/ClGreaterWorkload.cpp b/src/backends/cl/workloads/ClGreaterWorkload.cpp deleted file mode 100644 index 2051cc3aa3..0000000000 --- a/src/backends/cl/workloads/ClGreaterWorkload.cpp +++ /dev/null @@ -1,59 +0,0 @@ -// -// Copyright © 2017 Arm Ltd. All rights reserved. -// SPDX-License-Identifier: MIT -// - -#include "ClGreaterWorkload.hpp" - -#include "ClWorkloadUtils.hpp" - -#include -#include - -#include - -#include -#include -#include - -namespace armnn -{ - -using namespace armcomputetensorutils; - -arm_compute::Status ClGreaterWorkloadValidate(const TensorInfo& input0, - const TensorInfo& input1, - const TensorInfo& output) -{ - const arm_compute::TensorInfo aclInput0Info = BuildArmComputeTensorInfo(input0); - const arm_compute::TensorInfo aclInput1Info = BuildArmComputeTensorInfo(input1); - const arm_compute::TensorInfo aclOutputInfo = BuildArmComputeTensorInfo(output); - - const arm_compute::Status aclStatus = arm_compute::CLComparison::validate( - &aclInput0Info, - &aclInput1Info, - &aclOutputInfo, - arm_compute::ComparisonOperation::Greater); - - return aclStatus; -} - -ClGreaterWorkload::ClGreaterWorkload(const GreaterQueueDescriptor& descriptor, const WorkloadInfo& info) - : BaseWorkload(descriptor, info) -{ - m_Data.ValidateInputsOutputs("ClGreaterWorkload", 2, 1); - - arm_compute::ICLTensor& input0 = static_cast(m_Data.m_Inputs[0])->GetTensor(); - arm_compute::ICLTensor& input1 = static_cast(m_Data.m_Inputs[1])->GetTensor(); - arm_compute::ICLTensor& output = static_cast(m_Data.m_Outputs[0])->GetTensor(); - - m_GreaterLayer.configure(&input0, &input1, &output, arm_compute::ComparisonOperation::Greater); -} - -void ClGreaterWorkload::Execute() const -{ - ARMNN_SCOPED_PROFILING_EVENT_CL("ClGreaterWorkload_Execute"); - RunClFunction(m_GreaterLayer, CHECK_LOCATION()); -} - -} //namespace armnn diff --git a/src/backends/cl/workloads/ClGreaterWorkload.hpp b/src/backends/cl/workloads/ClGreaterWorkload.hpp deleted file mode 100644 index 9b2a1710bc..0000000000 --- a/src/backends/cl/workloads/ClGreaterWorkload.hpp +++ /dev/null @@ -1,29 +0,0 @@ -// -// Copyright © 2017 Arm Ltd. All rights reserved. -// SPDX-License-Identifier: MIT -// - -#pragma once - -#include - -#include - -namespace armnn -{ - -arm_compute::Status ClGreaterWorkloadValidate(const TensorInfo& input0, - const TensorInfo& input1, - const TensorInfo& output); - -class ClGreaterWorkload : public BaseWorkload -{ -public: - ClGreaterWorkload(const GreaterQueueDescriptor& descriptor, const WorkloadInfo& info); - void Execute() const override; - -private: - mutable arm_compute::CLComparison m_GreaterLayer; -}; - -} //namespace armnn diff --git a/src/backends/cl/workloads/ClWorkloads.hpp b/src/backends/cl/workloads/ClWorkloads.hpp index c7c016379e..ffe66a0716 100644 --- a/src/backends/cl/workloads/ClWorkloads.hpp +++ b/src/backends/cl/workloads/ClWorkloads.hpp @@ -8,6 +8,7 @@ #include "ClActivationWorkload.hpp" #include "ClAdditionWorkload.hpp" #include "ClArgMinMaxWorkload.hpp" +#include "ClComparisonWorkload.hpp" #include "ClConstantWorkload.hpp" #include "ClBatchNormalizationFloatWorkload.hpp" #include "ClBatchToSpaceNdWorkload.hpp" @@ -18,7 +19,6 @@ #include "ClDivisionFloatWorkload.hpp" #include "ClFloorFloatWorkload.hpp" #include "ClFullyConnectedWorkload.hpp" -#include "ClGreaterWorkload.hpp" #include "ClInstanceNormalizationWorkload.hpp" #include "ClL2NormalizationFloatWorkload.hpp" #include "ClLstmFloatWorkload.hpp" diff --git a/src/backends/reference/test/RefLayerTests.cpp b/src/backends/reference/test/RefLayerTests.cpp index 13b3ecfd4e..1c96db1c37 100644 --- a/src/backends/reference/test/RefLayerTests.cpp +++ b/src/backends/reference/test/RefLayerTests.cpp @@ -657,7 +657,7 @@ ARMNN_AUTO_TEST_CASE(LessSimpleUint8, LessSimpleUint8Test) ARMNN_AUTO_TEST_CASE(LessBroadcast1ElementUint8, LessBroadcast1ElementUint8Test) ARMNN_AUTO_TEST_CASE(LessBroadcast1dVectorUint8, LessBroadcast1dVectorUint8Test) -// GreaterOrEqual +// LessOrEqual ARMNN_AUTO_TEST_CASE(LessOrEqualSimple, LessOrEqualSimpleTest) ARMNN_AUTO_TEST_CASE(LessOrEqualBroadcast1Element, LessOrEqualBroadcast1ElementTest) ARMNN_AUTO_TEST_CASE(LessOrEqualBroadcast1dVector, LessOrEqualBroadcast1dVectorTest) -- cgit v1.2.1