aboutsummaryrefslogtreecommitdiff
path: root/src/backends/cl/workloads
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/cl/workloads')
-rw-r--r--src/backends/cl/workloads/CMakeLists.txt4
-rw-r--r--src/backends/cl/workloads/ClScatterNdWorkload.cpp77
-rw-r--r--src/backends/cl/workloads/ClScatterNdWorkload.hpp35
-rw-r--r--src/backends/cl/workloads/ClSplitterWorkload.cpp3
-rw-r--r--src/backends/cl/workloads/ClWorkloads.hpp3
5 files changed, 119 insertions, 3 deletions
diff --git a/src/backends/cl/workloads/CMakeLists.txt b/src/backends/cl/workloads/CMakeLists.txt
index f38366fa57..7db602b46b 100644
--- a/src/backends/cl/workloads/CMakeLists.txt
+++ b/src/backends/cl/workloads/CMakeLists.txt
@@ -1,5 +1,5 @@
#
-# Copyright © 2017, 2023 Arm Ltd and Contributors. All rights reserved.
+# Copyright © 2017-2024 Arm Ltd and Contributors. All rights reserved.
# SPDX-License-Identifier: MIT
#
@@ -113,6 +113,8 @@ list(APPEND armnnClBackendWorkloads_sources
ClReverseV2Workload.hpp
ClRsqrtWorkload.cpp
ClRsqrtWorkload.hpp
+ ClScatterNdWorkload.cpp
+ ClScatterNdWorkload.hpp
ClSinWorkload.cpp
ClSinWorkload.hpp
ClSliceWorkload.cpp
diff --git a/src/backends/cl/workloads/ClScatterNdWorkload.cpp b/src/backends/cl/workloads/ClScatterNdWorkload.cpp
new file mode 100644
index 0000000000..e75edf12c9
--- /dev/null
+++ b/src/backends/cl/workloads/ClScatterNdWorkload.cpp
@@ -0,0 +1,77 @@
+//
+// Copyright © 2024 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "ClScatterNdWorkload.hpp"
+
+#include "ClWorkloadUtils.hpp"
+
+#include <aclCommon/ArmComputeTensorUtils.hpp>
+#include <cl/ClTensorHandle.hpp>
+
+#include <arm_compute/function_info/ScatterInfo.h>
+
+namespace armnn
+{
+
+using namespace armcomputetensorutils;
+
+arm_compute::Status ClScatterNdWorkloadValidate(const TensorInfo& inputInfo,
+ const TensorInfo& indicesInfo,
+ const TensorInfo& updatesInfo,
+ const TensorInfo& outputInfo,
+ const ScatterNdDescriptor& descriptor)
+{
+ const arm_compute::TensorInfo aclInputInfo = BuildArmComputeTensorInfo(inputInfo);
+ const arm_compute::TensorInfo aclIndicesInfo = BuildArmComputeTensorInfo(indicesInfo);
+ const arm_compute::TensorInfo aclUpdatesInfo = BuildArmComputeTensorInfo(updatesInfo);
+ const arm_compute::TensorInfo aclOutputInfo = BuildArmComputeTensorInfo(outputInfo);
+
+ arm_compute::ScatterInfo scatterInfo = BuildArmComputeScatterInfo(descriptor);
+
+ return arm_compute::CLScatter::validate(descriptor.m_InputEnabled ? &aclInputInfo : nullptr,
+ &aclUpdatesInfo,
+ &aclIndicesInfo,
+ &aclOutputInfo,
+ scatterInfo);
+}
+
+ClScatterNdWorkload::ClScatterNdWorkload(const ScatterNdQueueDescriptor& descriptor,
+ const WorkloadInfo& info,
+ const arm_compute::CLCompileContext& clCompileContext)
+ : ClBaseWorkload<ScatterNdQueueDescriptor>(descriptor, info)
+{
+ // Report Profiling Details
+ ARMNN_REPORT_PROFILING_WORKLOAD_DESC("ClScatterNdWorkload_Construct",
+ descriptor.m_Parameters,
+ info,
+ this->GetGuid());
+
+ m_Data.ValidateInputsOutputs("ClScatterNdWorkload", 3, 1);
+
+ arm_compute::ICLTensor& input = static_cast<IClTensorHandle*>(m_Data.m_Inputs[0])->GetTensor();
+ arm_compute::ICLTensor& updates = static_cast<IClTensorHandle*>(m_Data.m_Inputs[2])->GetTensor();
+ arm_compute::ICLTensor& indices = static_cast<IClTensorHandle*>(m_Data.m_Inputs[1])->GetTensor();
+ arm_compute::ICLTensor& output = static_cast<IClTensorHandle*>(m_Data.m_Outputs[0])->GetTensor();
+
+ arm_compute::ScatterInfo scatterInfo = BuildArmComputeScatterInfo(descriptor.m_Parameters);
+
+ {
+ ARMNN_SCOPED_PROFILING_EVENT_CL_NAME_GUID("ClScatterNdWorkload_configure");
+ m_ScatterNdLayer.configure(clCompileContext,
+ descriptor.m_Parameters.m_InputEnabled ? &input : nullptr,
+ &updates,
+ &indices,
+ &output,
+ scatterInfo);
+ }
+}
+
+void ClScatterNdWorkload::Execute() const
+{
+ ARMNN_SCOPED_PROFILING_EVENT_CL_NAME_GUID("ClScatterNdWorkload_Execute");
+ RunClFunction(m_ScatterNdLayer, CHECK_LOCATION());
+}
+
+} //namespace armnn
diff --git a/src/backends/cl/workloads/ClScatterNdWorkload.hpp b/src/backends/cl/workloads/ClScatterNdWorkload.hpp
new file mode 100644
index 0000000000..070dac440d
--- /dev/null
+++ b/src/backends/cl/workloads/ClScatterNdWorkload.hpp
@@ -0,0 +1,35 @@
+//
+// Copyright © 2024 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include <armnn/Descriptors.hpp>
+
+#include <arm_compute/runtime/CL/functions/CLScatter.h>
+
+#include "ClBaseWorkload.hpp"
+
+namespace armnn
+{
+
+arm_compute::Status ClScatterNdWorkloadValidate(const TensorInfo& input,
+ const TensorInfo& indices,
+ const TensorInfo& updates,
+ const TensorInfo& output,
+ const ScatterNdDescriptor& descriptor);
+
+class ClScatterNdWorkload : public ClBaseWorkload<ScatterNdQueueDescriptor>
+{
+public:
+ ClScatterNdWorkload(const ScatterNdQueueDescriptor& descriptor,
+ const WorkloadInfo& info,
+ const arm_compute::CLCompileContext& clCompileContext);
+ void Execute() const override;
+
+private:
+ mutable arm_compute::CLScatter m_ScatterNdLayer;
+};
+
+} //namespace armnn
diff --git a/src/backends/cl/workloads/ClSplitterWorkload.cpp b/src/backends/cl/workloads/ClSplitterWorkload.cpp
index ec904eb51b..074ce5db72 100644
--- a/src/backends/cl/workloads/ClSplitterWorkload.cpp
+++ b/src/backends/cl/workloads/ClSplitterWorkload.cpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2019-2023 Arm Ltd and Contributors. All rights reserved.
+// Copyright © 2019-2024 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
@@ -11,6 +11,7 @@
#include <aclCommon/ArmComputeUtils.hpp>
#include <armnn/utility/PolymorphicDowncast.hpp>
#include <armnn/backends/TensorHandle.hpp>
+#include <backendsCommon/WorkloadUtils.hpp>
#include <cl/ClTensorHandle.hpp>
diff --git a/src/backends/cl/workloads/ClWorkloads.hpp b/src/backends/cl/workloads/ClWorkloads.hpp
index 40b3e99258..3178f6420d 100644
--- a/src/backends/cl/workloads/ClWorkloads.hpp
+++ b/src/backends/cl/workloads/ClWorkloads.hpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2017,2023 Arm Ltd and Contributors. All rights reserved.
+// Copyright © 2017-2024 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
@@ -57,6 +57,7 @@
#include "ClResizeWorkload.hpp"
#include "ClReverseV2Workload.hpp"
#include "ClRsqrtWorkload.hpp"
+#include "ClScatterNdWorkload.hpp"
#include "ClSinWorkload.hpp"
#include "ClSliceWorkload.hpp"
#include "ClSoftmaxWorkload.hpp"