diff options
Diffstat (limited to 'src/backends/cl/workloads')
-rw-r--r-- | src/backends/cl/workloads/CMakeLists.txt | 4 | ||||
-rw-r--r-- | src/backends/cl/workloads/ClScatterNdWorkload.cpp | 77 | ||||
-rw-r--r-- | src/backends/cl/workloads/ClScatterNdWorkload.hpp | 35 | ||||
-rw-r--r-- | src/backends/cl/workloads/ClSplitterWorkload.cpp | 3 | ||||
-rw-r--r-- | src/backends/cl/workloads/ClWorkloads.hpp | 3 |
5 files changed, 119 insertions, 3 deletions
diff --git a/src/backends/cl/workloads/CMakeLists.txt b/src/backends/cl/workloads/CMakeLists.txt index f38366fa57..7db602b46b 100644 --- a/src/backends/cl/workloads/CMakeLists.txt +++ b/src/backends/cl/workloads/CMakeLists.txt @@ -1,5 +1,5 @@ # -# Copyright © 2017, 2023 Arm Ltd and Contributors. All rights reserved. +# Copyright © 2017-2024 Arm Ltd and Contributors. All rights reserved. # SPDX-License-Identifier: MIT # @@ -113,6 +113,8 @@ list(APPEND armnnClBackendWorkloads_sources ClReverseV2Workload.hpp ClRsqrtWorkload.cpp ClRsqrtWorkload.hpp + ClScatterNdWorkload.cpp + ClScatterNdWorkload.hpp ClSinWorkload.cpp ClSinWorkload.hpp ClSliceWorkload.cpp diff --git a/src/backends/cl/workloads/ClScatterNdWorkload.cpp b/src/backends/cl/workloads/ClScatterNdWorkload.cpp new file mode 100644 index 0000000000..e75edf12c9 --- /dev/null +++ b/src/backends/cl/workloads/ClScatterNdWorkload.cpp @@ -0,0 +1,77 @@ +// +// Copyright © 2024 Arm Ltd and Contributors. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#include "ClScatterNdWorkload.hpp" + +#include "ClWorkloadUtils.hpp" + +#include <aclCommon/ArmComputeTensorUtils.hpp> +#include <cl/ClTensorHandle.hpp> + +#include <arm_compute/function_info/ScatterInfo.h> + +namespace armnn +{ + +using namespace armcomputetensorutils; + +arm_compute::Status ClScatterNdWorkloadValidate(const TensorInfo& inputInfo, + const TensorInfo& indicesInfo, + const TensorInfo& updatesInfo, + const TensorInfo& outputInfo, + const ScatterNdDescriptor& descriptor) +{ + const arm_compute::TensorInfo aclInputInfo = BuildArmComputeTensorInfo(inputInfo); + const arm_compute::TensorInfo aclIndicesInfo = BuildArmComputeTensorInfo(indicesInfo); + const arm_compute::TensorInfo aclUpdatesInfo = BuildArmComputeTensorInfo(updatesInfo); + const arm_compute::TensorInfo aclOutputInfo = BuildArmComputeTensorInfo(outputInfo); + + arm_compute::ScatterInfo scatterInfo = BuildArmComputeScatterInfo(descriptor); + + return arm_compute::CLScatter::validate(descriptor.m_InputEnabled ? &aclInputInfo : nullptr, + &aclUpdatesInfo, + &aclIndicesInfo, + &aclOutputInfo, + scatterInfo); +} + +ClScatterNdWorkload::ClScatterNdWorkload(const ScatterNdQueueDescriptor& descriptor, + const WorkloadInfo& info, + const arm_compute::CLCompileContext& clCompileContext) + : ClBaseWorkload<ScatterNdQueueDescriptor>(descriptor, info) +{ + // Report Profiling Details + ARMNN_REPORT_PROFILING_WORKLOAD_DESC("ClScatterNdWorkload_Construct", + descriptor.m_Parameters, + info, + this->GetGuid()); + + m_Data.ValidateInputsOutputs("ClScatterNdWorkload", 3, 1); + + arm_compute::ICLTensor& input = static_cast<IClTensorHandle*>(m_Data.m_Inputs[0])->GetTensor(); + arm_compute::ICLTensor& updates = static_cast<IClTensorHandle*>(m_Data.m_Inputs[2])->GetTensor(); + arm_compute::ICLTensor& indices = static_cast<IClTensorHandle*>(m_Data.m_Inputs[1])->GetTensor(); + arm_compute::ICLTensor& output = static_cast<IClTensorHandle*>(m_Data.m_Outputs[0])->GetTensor(); + + arm_compute::ScatterInfo scatterInfo = BuildArmComputeScatterInfo(descriptor.m_Parameters); + + { + ARMNN_SCOPED_PROFILING_EVENT_CL_NAME_GUID("ClScatterNdWorkload_configure"); + m_ScatterNdLayer.configure(clCompileContext, + descriptor.m_Parameters.m_InputEnabled ? &input : nullptr, + &updates, + &indices, + &output, + scatterInfo); + } +} + +void ClScatterNdWorkload::Execute() const +{ + ARMNN_SCOPED_PROFILING_EVENT_CL_NAME_GUID("ClScatterNdWorkload_Execute"); + RunClFunction(m_ScatterNdLayer, CHECK_LOCATION()); +} + +} //namespace armnn diff --git a/src/backends/cl/workloads/ClScatterNdWorkload.hpp b/src/backends/cl/workloads/ClScatterNdWorkload.hpp new file mode 100644 index 0000000000..070dac440d --- /dev/null +++ b/src/backends/cl/workloads/ClScatterNdWorkload.hpp @@ -0,0 +1,35 @@ +// +// Copyright © 2024 Arm Ltd and Contributors. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#pragma once + +#include <armnn/Descriptors.hpp> + +#include <arm_compute/runtime/CL/functions/CLScatter.h> + +#include "ClBaseWorkload.hpp" + +namespace armnn +{ + +arm_compute::Status ClScatterNdWorkloadValidate(const TensorInfo& input, + const TensorInfo& indices, + const TensorInfo& updates, + const TensorInfo& output, + const ScatterNdDescriptor& descriptor); + +class ClScatterNdWorkload : public ClBaseWorkload<ScatterNdQueueDescriptor> +{ +public: + ClScatterNdWorkload(const ScatterNdQueueDescriptor& descriptor, + const WorkloadInfo& info, + const arm_compute::CLCompileContext& clCompileContext); + void Execute() const override; + +private: + mutable arm_compute::CLScatter m_ScatterNdLayer; +}; + +} //namespace armnn diff --git a/src/backends/cl/workloads/ClSplitterWorkload.cpp b/src/backends/cl/workloads/ClSplitterWorkload.cpp index ec904eb51b..074ce5db72 100644 --- a/src/backends/cl/workloads/ClSplitterWorkload.cpp +++ b/src/backends/cl/workloads/ClSplitterWorkload.cpp @@ -1,5 +1,5 @@ // -// Copyright © 2019-2023 Arm Ltd and Contributors. All rights reserved. +// Copyright © 2019-2024 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // @@ -11,6 +11,7 @@ #include <aclCommon/ArmComputeUtils.hpp> #include <armnn/utility/PolymorphicDowncast.hpp> #include <armnn/backends/TensorHandle.hpp> +#include <backendsCommon/WorkloadUtils.hpp> #include <cl/ClTensorHandle.hpp> diff --git a/src/backends/cl/workloads/ClWorkloads.hpp b/src/backends/cl/workloads/ClWorkloads.hpp index 40b3e99258..3178f6420d 100644 --- a/src/backends/cl/workloads/ClWorkloads.hpp +++ b/src/backends/cl/workloads/ClWorkloads.hpp @@ -1,5 +1,5 @@ // -// Copyright © 2017,2023 Arm Ltd and Contributors. All rights reserved. +// Copyright © 2017-2024 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // @@ -57,6 +57,7 @@ #include "ClResizeWorkload.hpp" #include "ClReverseV2Workload.hpp" #include "ClRsqrtWorkload.hpp" +#include "ClScatterNdWorkload.hpp" #include "ClSinWorkload.hpp" #include "ClSliceWorkload.hpp" #include "ClSoftmaxWorkload.hpp" |