From 2b030d9e6d24cfba615f8803047e914b56cb79b5 Mon Sep 17 00:00:00 2001 From: Teresa Charlin Date: Fri, 27 Mar 2020 16:40:56 +0000 Subject: IVGCVSW-4603 Support comparison operators in CL * Deprecate ClGreaterWorkload * Add ClComparisonWorkload to encompass all comparison operators Signed-off-by: Teresa Charlin Change-Id: Ida0ed7f59899d75b0fe7de1e7433b1ade018c6f1 --- src/backends/cl/workloads/CMakeLists.txt | 4 +- src/backends/cl/workloads/ClComparisonWorkload.cpp | 62 ++++++++++++++++++++++ src/backends/cl/workloads/ClComparisonWorkload.hpp | 30 +++++++++++ src/backends/cl/workloads/ClGreaterWorkload.cpp | 59 -------------------- src/backends/cl/workloads/ClGreaterWorkload.hpp | 29 ---------- src/backends/cl/workloads/ClWorkloads.hpp | 2 +- 6 files changed, 95 insertions(+), 91 deletions(-) create mode 100644 src/backends/cl/workloads/ClComparisonWorkload.cpp create mode 100644 src/backends/cl/workloads/ClComparisonWorkload.hpp delete mode 100644 src/backends/cl/workloads/ClGreaterWorkload.cpp delete mode 100644 src/backends/cl/workloads/ClGreaterWorkload.hpp (limited to 'src/backends/cl/workloads') diff --git a/src/backends/cl/workloads/CMakeLists.txt b/src/backends/cl/workloads/CMakeLists.txt index 3f964eb1a6..161ad96361 100644 --- a/src/backends/cl/workloads/CMakeLists.txt +++ b/src/backends/cl/workloads/CMakeLists.txt @@ -16,6 +16,8 @@ list(APPEND armnnClBackendWorkloads_sources ClBatchNormalizationFloatWorkload.hpp ClBatchToSpaceNdWorkload.cpp ClBatchToSpaceNdWorkload.hpp + ClComparisonWorkload.cpp + ClComparisonWorkload.hpp ClConcatWorkload.cpp ClConcatWorkload.hpp ClConstantWorkload.cpp @@ -38,8 +40,6 @@ list(APPEND armnnClBackendWorkloads_sources ClFloorFloatWorkload.hpp ClFullyConnectedWorkload.cpp ClFullyConnectedWorkload.hpp - ClGreaterWorkload.cpp - ClGreaterWorkload.hpp ClInstanceNormalizationWorkload.cpp ClInstanceNormalizationWorkload.hpp ClL2NormalizationFloatWorkload.cpp diff --git a/src/backends/cl/workloads/ClComparisonWorkload.cpp b/src/backends/cl/workloads/ClComparisonWorkload.cpp new file mode 100644 index 0000000000..30b336dd94 --- /dev/null +++ b/src/backends/cl/workloads/ClComparisonWorkload.cpp @@ -0,0 +1,62 @@ +// +// Copyright © 2020 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#include "ClComparisonWorkload.hpp" + +#include "ClWorkloadUtils.hpp" + +#include +#include + +#include + +#include +#include +#include + +namespace armnn +{ + +using namespace armcomputetensorutils; + +arm_compute::Status ClComparisonWorkloadValidate(const TensorInfo& input0, + const TensorInfo& input1, + const TensorInfo& output, + const ComparisonDescriptor& descriptor) +{ + const arm_compute::TensorInfo aclInput0Info = BuildArmComputeTensorInfo(input0); + const arm_compute::TensorInfo aclInput1Info = BuildArmComputeTensorInfo(input1); + const arm_compute::TensorInfo aclOutputInfo = BuildArmComputeTensorInfo(output); + + const arm_compute::ComparisonOperation comparisonOperation = ConvertComparisonOperationToAcl(descriptor); + + const arm_compute::Status aclStatus = arm_compute::CLComparison::validate(&aclInput0Info, + &aclInput1Info, + &aclOutputInfo, + comparisonOperation); + return aclStatus; +} + +ClComparisonWorkload::ClComparisonWorkload(const ComparisonQueueDescriptor& descriptor, const WorkloadInfo& info) + : BaseWorkload(descriptor, info) +{ + m_Data.ValidateInputsOutputs("ClComparisonWorkload", 2, 1); + + arm_compute::ICLTensor& input0 = static_cast(m_Data.m_Inputs[0])->GetTensor(); + arm_compute::ICLTensor& input1 = static_cast(m_Data.m_Inputs[1])->GetTensor(); + arm_compute::ICLTensor& output = static_cast(m_Data.m_Outputs[0])->GetTensor(); + + const arm_compute::ComparisonOperation comparisonOperation = ConvertComparisonOperationToAcl(m_Data.m_Parameters); + + m_ComparisonLayer.configure(&input0, &input1, &output, comparisonOperation); +} + +void ClComparisonWorkload::Execute() const +{ + ARMNN_SCOPED_PROFILING_EVENT_CL("ClComparisonWorkload_Execute"); + RunClFunction(m_ComparisonLayer, CHECK_LOCATION()); +} + +} //namespace armnn diff --git a/src/backends/cl/workloads/ClComparisonWorkload.hpp b/src/backends/cl/workloads/ClComparisonWorkload.hpp new file mode 100644 index 0000000000..e842152fed --- /dev/null +++ b/src/backends/cl/workloads/ClComparisonWorkload.hpp @@ -0,0 +1,30 @@ +// +// Copyright © 2020 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#pragma once + +#include + +#include + +namespace armnn +{ + +arm_compute::Status ClComparisonWorkloadValidate(const TensorInfo& input0, + const TensorInfo& input1, + const TensorInfo& output, + const ComparisonDescriptor& descriptor); + +class ClComparisonWorkload : public BaseWorkload +{ +public: + ClComparisonWorkload(const ComparisonQueueDescriptor& descriptor, const WorkloadInfo& info); + void Execute() const override; + +private: + mutable arm_compute::CLComparison m_ComparisonLayer; +}; + +} //namespace armnn diff --git a/src/backends/cl/workloads/ClGreaterWorkload.cpp b/src/backends/cl/workloads/ClGreaterWorkload.cpp deleted file mode 100644 index 2051cc3aa3..0000000000 --- a/src/backends/cl/workloads/ClGreaterWorkload.cpp +++ /dev/null @@ -1,59 +0,0 @@ -// -// Copyright © 2017 Arm Ltd. All rights reserved. -// SPDX-License-Identifier: MIT -// - -#include "ClGreaterWorkload.hpp" - -#include "ClWorkloadUtils.hpp" - -#include -#include - -#include - -#include -#include -#include - -namespace armnn -{ - -using namespace armcomputetensorutils; - -arm_compute::Status ClGreaterWorkloadValidate(const TensorInfo& input0, - const TensorInfo& input1, - const TensorInfo& output) -{ - const arm_compute::TensorInfo aclInput0Info = BuildArmComputeTensorInfo(input0); - const arm_compute::TensorInfo aclInput1Info = BuildArmComputeTensorInfo(input1); - const arm_compute::TensorInfo aclOutputInfo = BuildArmComputeTensorInfo(output); - - const arm_compute::Status aclStatus = arm_compute::CLComparison::validate( - &aclInput0Info, - &aclInput1Info, - &aclOutputInfo, - arm_compute::ComparisonOperation::Greater); - - return aclStatus; -} - -ClGreaterWorkload::ClGreaterWorkload(const GreaterQueueDescriptor& descriptor, const WorkloadInfo& info) - : BaseWorkload(descriptor, info) -{ - m_Data.ValidateInputsOutputs("ClGreaterWorkload", 2, 1); - - arm_compute::ICLTensor& input0 = static_cast(m_Data.m_Inputs[0])->GetTensor(); - arm_compute::ICLTensor& input1 = static_cast(m_Data.m_Inputs[1])->GetTensor(); - arm_compute::ICLTensor& output = static_cast(m_Data.m_Outputs[0])->GetTensor(); - - m_GreaterLayer.configure(&input0, &input1, &output, arm_compute::ComparisonOperation::Greater); -} - -void ClGreaterWorkload::Execute() const -{ - ARMNN_SCOPED_PROFILING_EVENT_CL("ClGreaterWorkload_Execute"); - RunClFunction(m_GreaterLayer, CHECK_LOCATION()); -} - -} //namespace armnn diff --git a/src/backends/cl/workloads/ClGreaterWorkload.hpp b/src/backends/cl/workloads/ClGreaterWorkload.hpp deleted file mode 100644 index 9b2a1710bc..0000000000 --- a/src/backends/cl/workloads/ClGreaterWorkload.hpp +++ /dev/null @@ -1,29 +0,0 @@ -// -// Copyright © 2017 Arm Ltd. All rights reserved. -// SPDX-License-Identifier: MIT -// - -#pragma once - -#include - -#include - -namespace armnn -{ - -arm_compute::Status ClGreaterWorkloadValidate(const TensorInfo& input0, - const TensorInfo& input1, - const TensorInfo& output); - -class ClGreaterWorkload : public BaseWorkload -{ -public: - ClGreaterWorkload(const GreaterQueueDescriptor& descriptor, const WorkloadInfo& info); - void Execute() const override; - -private: - mutable arm_compute::CLComparison m_GreaterLayer; -}; - -} //namespace armnn diff --git a/src/backends/cl/workloads/ClWorkloads.hpp b/src/backends/cl/workloads/ClWorkloads.hpp index c7c016379e..ffe66a0716 100644 --- a/src/backends/cl/workloads/ClWorkloads.hpp +++ b/src/backends/cl/workloads/ClWorkloads.hpp @@ -8,6 +8,7 @@ #include "ClActivationWorkload.hpp" #include "ClAdditionWorkload.hpp" #include "ClArgMinMaxWorkload.hpp" +#include "ClComparisonWorkload.hpp" #include "ClConstantWorkload.hpp" #include "ClBatchNormalizationFloatWorkload.hpp" #include "ClBatchToSpaceNdWorkload.hpp" @@ -18,7 +19,6 @@ #include "ClDivisionFloatWorkload.hpp" #include "ClFloorFloatWorkload.hpp" #include "ClFullyConnectedWorkload.hpp" -#include "ClGreaterWorkload.hpp" #include "ClInstanceNormalizationWorkload.hpp" #include "ClL2NormalizationFloatWorkload.hpp" #include "ClLstmFloatWorkload.hpp" -- cgit v1.2.1