about summary refs log tree commit diff
path: root/src
diff options
context:
space:
mode:
authorTeresa Charlin <teresa.charlinreyes@arm.com>2024-03-13 16:10:32 +0000
committerTeresaARM <teresa.charlinreyes@arm.com>2024-05-08 14:22:08 +0000
commit21bda1405d2cb49fc873583b41a48836b33d285e (patch)
tree6d57debcf6be6aeb28a3e3951757c73ae5636250 /src
parent8208e2b8b1d09d0e89394ae134eb61e390dfd93c (diff)
downloadarmnn-21bda1405d2cb49fc873583b41a48836b33d285e.tar.gz
IVGCVSW-8235 ScatterNd Operator Implementation (CL)
Signed-off-by: Teresa Charlin <teresa.charlinreyes@arm.com> Change-Id: I59fe96b0a272fa6984bfc172bf3e110476f3ce7b
Diffstat (limited to 'src')
-rw-r--r--src/backends/aclCommon/ArmComputeTensorUtils.cpp26
-rw-r--r--src/backends/aclCommon/ArmComputeTensorUtils.hpp6
-rw-r--r--src/backends/cl/ClLayerSupport.cpp24
-rw-r--r--src/backends/cl/ClLayerSupport.hpp9
-rw-r--r--src/backends/cl/ClWorkloadFactory.cpp7
-rw-r--r--src/backends/cl/backend.mk3
-rw-r--r--src/backends/cl/test/ClEndToEndTests.cpp22
-rw-r--r--src/backends/cl/test/ClLayerTests.cpp89
-rw-r--r--src/backends/cl/workloads/CMakeLists.txt4
-rw-r--r--src/backends/cl/workloads/ClScatterNdWorkload.cpp77
-rw-r--r--src/backends/cl/workloads/ClScatterNdWorkload.hpp35
-rw-r--r--src/backends/cl/workloads/ClWorkloads.hpp3
-rw-r--r--src/backends/reference/test/RefEndToEndTests.cpp2
13 files changed, 298 insertions, 9 deletions
diff --git a/src/backends/aclCommon/ArmComputeTensorUtils.cpp b/src/backends/aclCommon/ArmComputeTensorUtils.cpp
index c5b4fa157e..cfd2e0e110 100644
--- a/src/backends/aclCommon/ArmComputeTensorUtils.cpp
+++ b/src/backends/aclCommon/ArmComputeTensorUtils.cpp
@@ -459,5 +459,31 @@ unsigned int ComputeDepthwiseConv2dDepthMultiplier(armnn::DataLayout layout,
return depthMultiplier;
}
+/// Convert an ArmNN ScatterNdDescriptor into the arm_compute::ScatterInfo
+/// consumed by the Arm Compute Library scatter functions.
+/// @param descriptor ArmNN ScatterNd descriptor (reduction function + input-enabled flag).
+/// @return Equivalent arm_compute::ScatterInfo.
+/// @throws InvalidArgumentException if the scatter function is not one of the mapped values.
+arm_compute::ScatterInfo BuildArmComputeScatterInfo(const ScatterNdDescriptor& descriptor)
+{
+ arm_compute::ScatterFunction scatterFunction;
+ // Map the ArmNN reduction function onto the matching ACL enum value.
+ switch(descriptor.m_Function)
+ {
+ case ScatterNdFunction::Update:
+ scatterFunction = arm_compute::ScatterFunction::Update;
+ break;
+ case ScatterNdFunction::Add:
+ scatterFunction = arm_compute::ScatterFunction::Add;
+ break;
+ case ScatterNdFunction::Sub:
+ scatterFunction = arm_compute::ScatterFunction::Sub;
+ break;
+ case ScatterNdFunction::Max:
+ scatterFunction = arm_compute::ScatterFunction::Max;
+ break;
+ case ScatterNdFunction::Min:
+ scatterFunction = arm_compute::ScatterFunction::Min;
+ break;
+ default: throw InvalidArgumentException("Unknown ArmNN::ScatterNd Function: [" +
+ std::to_string(static_cast<int>(descriptor.m_Function)) + "]");
+ }
+
+ // ACL's second constructor argument is the zero-initialisation flag: when
+ // no input tensor is supplied the output starts zeroed, hence the negation
+ // of m_InputEnabled.
+ return arm_compute::ScatterInfo(scatterFunction, !descriptor.m_InputEnabled);
+}
} // namespace armcomputetensorutils
} // namespace armnn
diff --git a/src/backends/aclCommon/ArmComputeTensorUtils.hpp b/src/backends/aclCommon/ArmComputeTensorUtils.hpp
index d8a41fe41f..63c70c7092 100644
--- a/src/backends/aclCommon/ArmComputeTensorUtils.hpp
+++ b/src/backends/aclCommon/ArmComputeTensorUtils.hpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2017-2023 Arm Ltd and Contributors. All rights reserved.
+// Copyright © 2017-2024 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once
@@ -12,6 +12,7 @@
#include <arm_compute/core/ITensor.h>
#include <arm_compute/core/TensorInfo.h>
#include <arm_compute/core/Types.h>
+#include <arm_compute/function_info/ScatterInfo.h>
#include <Half.hpp>
@@ -108,6 +109,9 @@ unsigned int ComputeDepthwiseConv2dDepthMultiplier(armnn::DataLayout layout,
const arm_compute::TensorShape& weightsShape,
const arm_compute::TensorShape& inputShape);
+/// Utility function used to setup an arm_compute::ScatterInfo from ArmNN ScatterNd descriptor
+arm_compute::ScatterInfo BuildArmComputeScatterInfo(const ScatterNdDescriptor& descriptor);
+
/// Utility function used to setup an arm_compute::PadStrideInfo object from an ArmNN layer descriptor.
template <typename Descriptor>
arm_compute::PadStrideInfo BuildArmComputePadStrideInfo(const Descriptor& descriptor)
diff --git a/src/backends/cl/ClLayerSupport.cpp b/src/backends/cl/ClLayerSupport.cpp
index 9f7d562df6..030b4c2d09 100644
--- a/src/backends/cl/ClLayerSupport.cpp
+++ b/src/backends/cl/ClLayerSupport.cpp
@@ -73,6 +73,7 @@
#include "workloads/ClResizeWorkload.hpp"
#include "workloads/ClReverseV2Workload.hpp"
#include "workloads/ClRsqrtWorkload.hpp"
+#include "workloads/ClScatterNdWorkload.hpp"
#include "workloads/ClSinWorkload.hpp"
#include "workloads/ClSliceWorkload.hpp"
#include "workloads/ClSoftmaxWorkload.hpp"
@@ -578,6 +579,13 @@ bool ClLayerSupport::IsLayerSupported(const LayerType& type,
infos[1],
infos[2],
reasonIfUnsupported);
+ case LayerType::ScatterNd:
+ return IsScatterNdSupported(infos[0], // input/shape
+ infos[1], // indices
+ infos[2], // updates
+ infos[3], // output
+ *(PolymorphicDowncast<const ScatterNdDescriptor*>(&descriptor)),
+ reasonIfUnsupported);
case LayerType::Shape:
return LayerSupportBase::IsShapeSupported(infos[0],
infos[1],
@@ -1442,6 +1450,22 @@ bool ClLayerSupport::IsReverseV2Supported(const TensorInfo& input,
output);
}
+/// Query CL backend support for the ScatterNd layer by forwarding the tensor
+/// infos and descriptor to the CL workload's static validate function.
+bool ClLayerSupport::IsScatterNdSupported(const TensorInfo& input,
+ const TensorInfo& indices,
+ const TensorInfo& updates,
+ const TensorInfo& output,
+ const ScatterNdDescriptor& descriptor,
+ Optional<std::string&> reasonIfUnsupported) const
+{
+ FORWARD_WORKLOAD_VALIDATE_FUNC(ClScatterNdWorkloadValidate,
+ reasonIfUnsupported,
+ input,
+ indices,
+ updates,
+ output,
+ descriptor);
+}
+
bool ClLayerSupport::IsSliceSupported(const TensorInfo& input,
const TensorInfo& output,
const SliceDescriptor& descriptor,
diff --git a/src/backends/cl/ClLayerSupport.hpp b/src/backends/cl/ClLayerSupport.hpp
index 907db01b89..8e9c0be7f8 100644
--- a/src/backends/cl/ClLayerSupport.hpp
+++ b/src/backends/cl/ClLayerSupport.hpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2017-2023 Arm Ltd and Contributors. All rights reserved.
+// Copyright © 2017-2024 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once
@@ -300,6 +300,13 @@ public:
const TensorInfo& output,
Optional<std::string&> reasonIfUnsupported) const;
+ bool IsScatterNdSupported(const TensorInfo& input,
+ const TensorInfo& indices,
+ const TensorInfo& updates,
+ const TensorInfo& output,
+ const ScatterNdDescriptor& descriptor,
+ Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const;
+
bool IsSliceSupported(const TensorInfo& input,
const TensorInfo& output,
const SliceDescriptor& descriptor,
diff --git a/src/backends/cl/ClWorkloadFactory.cpp b/src/backends/cl/ClWorkloadFactory.cpp
index 6fe42644c2..6a7b0e64ae 100644
--- a/src/backends/cl/ClWorkloadFactory.cpp
+++ b/src/backends/cl/ClWorkloadFactory.cpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2017-2023 Arm Ltd and Contributors. All rights reserved.
+// Copyright © 2017-2024 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#include "ClWorkloadFactory.hpp"
@@ -716,6 +716,11 @@ std::unique_ptr<IWorkload> ClWorkloadFactory::CreateWorkload(LayerType type,
auto reverseV2QueueDescriptor = PolymorphicDowncast<const ReverseV2QueueDescriptor*>(&descriptor);
return MakeWorkload<ClReverseV2Workload>(*reverseV2QueueDescriptor, info, m_CLCompileContext);
}
+ case LayerType::ScatterNd :
+ {
+ auto scatterNdQueueDescriptor = PolymorphicDowncast<const ScatterNdQueueDescriptor*>(&descriptor);
+ return MakeWorkload<ClScatterNdWorkload>(*scatterNdQueueDescriptor, info, m_CLCompileContext);
+ }
case LayerType::Slice :
{
auto sliceQueueDescriptor = PolymorphicDowncast<const SliceQueueDescriptor*>(&descriptor);
diff --git a/src/backends/cl/backend.mk b/src/backends/cl/backend.mk
index 2143c30309..f233ffc5e1 100644
--- a/src/backends/cl/backend.mk
+++ b/src/backends/cl/backend.mk
@@ -1,5 +1,5 @@
#
-# Copyright © 2017-2023 ARM Ltd and Contributors. All rights reserved.
+# Copyright © 2017-2024 ARM Ltd and Contributors. All rights reserved.
# SPDX-License-Identifier: MIT
#
@@ -81,6 +81,7 @@ BACKEND_SOURCES := \
workloads/ClResizeWorkload.cpp \
workloads/ClReverseV2Workload.cpp \
workloads/ClRsqrtWorkload.cpp \
+ workloads/ClScatterNdWorkload.cpp \
workloads/ClSinWorkload.cpp \
workloads/ClSliceWorkload.cpp \
workloads/ClSoftmaxWorkload.cpp \
diff --git a/src/backends/cl/test/ClEndToEndTests.cpp b/src/backends/cl/test/ClEndToEndTests.cpp
index 9e60843177..fa5d545547 100644
--- a/src/backends/cl/test/ClEndToEndTests.cpp
+++ b/src/backends/cl/test/ClEndToEndTests.cpp
@@ -25,6 +25,7 @@
#include <backendsCommon/test/ReshapeEndToEndTestImpl.hpp>
#include <backendsCommon/test/ResizeEndToEndTestImpl.hpp>
#include <backendsCommon/test/ReverseV2EndToEndTestImpl.hpp>
+#include <backendsCommon/test/ScatterNdEndToEndTestImpl.hpp>
#include <backendsCommon/test/SliceEndToEndTestImpl.hpp>
#include <backendsCommon/test/SpaceToDepthEndToEndTestImpl.hpp>
#include <backendsCommon/test/SplitterEndToEndTestImpl.hpp>
@@ -322,6 +323,27 @@ TEST_CASE("DequantizeEndToEndOffsetTest")
DequantizeEndToEndOffset<armnn::DataType::QAsymmU8>(clDefaultBackends);
}
+// ScatterNd
+TEST_CASE("ClScatterNd1DInputEndToEndFloat32Test")
+{
+ ScatterNd1DimUpdateWithInputEndToEnd<armnn::DataType::Float32>(clDefaultBackends);
+}
+
+TEST_CASE("ClScatterNd1DNoInputEndToEndFloat32Test")
+{
+ ScatterNd1DimUpdateNoInputEndToEnd<armnn::DataType::Float32>(clDefaultBackends);
+}
+
+TEST_CASE("ClScatterNd2DInputEndToEndFloat32Test")
+{
+ ScatterNd2DimUpdateWithInputEndToEnd<armnn::DataType::Float32>(clDefaultBackends);
+}
+
+TEST_CASE("ClScatterNd2DNoInputEndToEndFloat32Test")
+{
+ ScatterNd2DimUpdateNoInputEndToEnd<armnn::DataType::Float32>(clDefaultBackends);
+}
+
// Slice
TEST_CASE("ClSliceEndtoEndTestFloat32")
{
diff --git a/src/backends/cl/test/ClLayerTests.cpp b/src/backends/cl/test/ClLayerTests.cpp
index da2b967fcb..e193ca24ea 100644
--- a/src/backends/cl/test/ClLayerTests.cpp
+++ b/src/backends/cl/test/ClLayerTests.cpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2017-2023 Arm Ltd and Contributors. All rights reserved.
+// Copyright © 2017-2024 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
@@ -1531,6 +1531,93 @@ ARMNN_AUTO_TEST_FIXTURE_WITH_THF(SimpleSoftmaxBeta2Uint8, ClContextControlFixtur
// LogSoftmax
ARMNN_AUTO_TEST_FIXTURE_WITH_THF(LogSoftmaxFloat32_1, ClContextControlFixture, LogSoftmaxTest1<DataType::Float32>)
+// ScatterNd
+// With Input tensor
+ARMNN_AUTO_TEST_FIXTURE_WITH_THF(ScatterNd1DUpdateTestWithInputFloat32,
+ ClContextControlFixture,
+ ScatterNd1DimUpdateWithInput<DataType::Float32>)
+ARMNN_AUTO_TEST_FIXTURE_WITH_THF(ScatterNd2DUpdateTestWithInputFloat32,
+ ClContextControlFixture,
+ ScatterNd2DimUpdateWithInput<DataType::Float32>)
+ARMNN_AUTO_TEST_FIXTURE_WITH_THF(ScatterNd2Dim1Outter1InnerUpdateWithInputFloat32,
+ ClContextControlFixture,
+ ScatterNd2Dim1Outter1InnerUpdateWithInput<DataType::Float32>)
+ARMNN_AUTO_TEST_FIXTURE_WITH_THF(ScatterNd3DimUpdateWithInputFloat32,
+ ClContextControlFixture,
+ ScatterNd3DimUpdateWithInput<DataType::Float32>)
+ARMNN_AUTO_TEST_FIXTURE_WITH_THF(ScatterNd3Dim1Outter2InnerUpdateWithInputFloat32,
+ ClContextControlFixture,
+ ScatterNd3Dim1Outter2InnerUpdateWithInput<DataType::Float32>)
+ARMNN_AUTO_TEST_FIXTURE_WITH_THF(ScatterNd3Dim2Outter1InnerUpdateWithInputFloat32,
+ ClContextControlFixture,
+ ScatterNd3Dim2Outter1InnerUpdateWithInput<DataType::Float32>)
+ARMNN_AUTO_TEST_FIXTURE_WITH_THF(ScatterNd4DimUpdateWithInputFloat32,
+ ClContextControlFixture,
+ ScatterNd4DimUpdateWithInput<DataType::Float32>)
+
+ARMNN_AUTO_TEST_FIXTURE_WITH_THF(ScatterNd2DimAddWithInputFloat32,
+ ClContextControlFixture,
+ ScatterNd2DimAddWithInput<DataType::Float32>)
+ARMNN_AUTO_TEST_FIXTURE_WITH_THF(ScatterNd2DimSubWithInputFloat32,
+ ClContextControlFixture,
+ ScatterNd2DimSubWithInput<DataType::Float32>)
+ARMNN_AUTO_TEST_FIXTURE_WITH_THF(ScatterNd2DimMaxWithInputFloat32,
+ ClContextControlFixture,
+ ScatterNd2DimMaxWithInput<DataType::Float32>)
+ARMNN_AUTO_TEST_FIXTURE_WITH_THF(ScatterNd2DimMinWithInputFloat32,
+ ClContextControlFixture,
+ ScatterNd2DimMinWithInput<DataType::Float32>)
+
+ARMNN_AUTO_TEST_FIXTURE_WITH_THF(ScatterNd3DimUpdateWithInputFloat16,
+ ClContextControlFixture,
+ ScatterNd3DimUpdateWithInput<DataType::Float16>)
+ARMNN_AUTO_TEST_FIXTURE_WITH_THF(ScatterNd3DimUpdateWithInputSigned32,
+ ClContextControlFixture,
+ ScatterNd3DimUpdateWithInput<DataType::Signed32>)
+
+// No input tensor, only shape provided
+ARMNN_AUTO_TEST_FIXTURE_WITH_THF(ScatterNd1DUpdateTestNoInputFloat32,
+ ClContextControlFixture,
+ ScatterNd1DimUpdateNoInput<DataType::Float32>)
+ARMNN_AUTO_TEST_FIXTURE_WITH_THF(ScatterNd2DimUpdateTestNoInputFloat32,
+ ClContextControlFixture,
+ ScatterNd2DimUpdateNoInput<DataType::Float32>)
+ARMNN_AUTO_TEST_FIXTURE_WITH_THF(ScatterNd2Dim1Outter1InnerUpdateNoInputFloat32,
+ ClContextControlFixture,
+ ScatterNd2Dim1Outter1InnerUpdateNoInput<DataType::Float32>)
+ARMNN_AUTO_TEST_FIXTURE_WITH_THF(ScatterNd3DimUpdateNoInputFloat32,
+ ClContextControlFixture,
+ ScatterNd3DimUpdateNoInput<DataType::Float32>)
+ARMNN_AUTO_TEST_FIXTURE_WITH_THF(ScatterNd3Dim1Outter2InnerUpdateNoInputFloat32,
+ ClContextControlFixture,
+ ScatterNd3Dim1Outter2InnerUpdateNoInput<DataType::Float32>)
+ARMNN_AUTO_TEST_FIXTURE_WITH_THF(ScatterNd3Dim2Outter1InnerUpdateNoInputFloat32,
+ ClContextControlFixture,
+ ScatterNd3Dim2Outter1InnerUpdateNoInput<DataType::Float32>)
+ARMNN_AUTO_TEST_FIXTURE_WITH_THF(ScatterNd4DimUpdateNoInputFloat32,
+ ClContextControlFixture,
+ ScatterNd4DimUpdateNoInput<DataType::Float32>)
+
+ARMNN_AUTO_TEST_FIXTURE_WITH_THF(ScatterNd2DimAddNoInputFloat32,
+ ClContextControlFixture,
+ ScatterNd2DimAddNoInput<DataType::Float32>)
+ARMNN_AUTO_TEST_FIXTURE_WITH_THF(ScatterNd2DimSubNoInputFloat32,
+ ClContextControlFixture,
+ ScatterNd2DimSubNoInput<DataType::Float32>)
+ARMNN_AUTO_TEST_FIXTURE_WITH_THF(ScatterNd2DimMaxNoInputFloat32,
+ ClContextControlFixture,
+ ScatterNd2DimMaxNoInput<DataType::Float32>)
+ARMNN_AUTO_TEST_FIXTURE_WITH_THF(ScatterNd2DimMinNoInputFloat32,
+ ClContextControlFixture,
+ ScatterNd2DimMinNoInput<DataType::Float32>)
+
+ARMNN_AUTO_TEST_FIXTURE_WITH_THF(ScatterNd3DimUpdateNoInputFloat16,
+ ClContextControlFixture,
+ ScatterNd3DimUpdateNoInput<DataType::Float16>)
+ARMNN_AUTO_TEST_FIXTURE_WITH_THF(ScatterNd3DimUpdateNoInputSigned32,
+ ClContextControlFixture,
+ ScatterNd3DimUpdateNoInput<DataType::Signed32>)
+
// Space To Batch Nd
ARMNN_AUTO_TEST_FIXTURE_WITH_THF(SpaceToBatchNdSimpleFloat32, ClContextControlFixture, SpaceToBatchNdSimpleFloat32Test)
ARMNN_AUTO_TEST_FIXTURE_WITH_THF(
diff --git a/src/backends/cl/workloads/CMakeLists.txt b/src/backends/cl/workloads/CMakeLists.txt
index f38366fa57..7db602b46b 100644
--- a/src/backends/cl/workloads/CMakeLists.txt
+++ b/src/backends/cl/workloads/CMakeLists.txt
@@ -1,5 +1,5 @@
#
-# Copyright © 2017, 2023 Arm Ltd and Contributors. All rights reserved.
+# Copyright © 2017-2024 Arm Ltd and Contributors. All rights reserved.
# SPDX-License-Identifier: MIT
#
@@ -113,6 +113,8 @@ list(APPEND armnnClBackendWorkloads_sources
ClReverseV2Workload.hpp
ClRsqrtWorkload.cpp
ClRsqrtWorkload.hpp
+ ClScatterNdWorkload.cpp
+ ClScatterNdWorkload.hpp
ClSinWorkload.cpp
ClSinWorkload.hpp
ClSliceWorkload.cpp
diff --git a/src/backends/cl/workloads/ClScatterNdWorkload.cpp b/src/backends/cl/workloads/ClScatterNdWorkload.cpp
new file mode 100644
index 0000000000..e75edf12c9
--- /dev/null
+++ b/src/backends/cl/workloads/ClScatterNdWorkload.cpp
@@ -0,0 +1,77 @@
+//
+// Copyright © 2024 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "ClScatterNdWorkload.hpp"
+
+#include "ClWorkloadUtils.hpp"
+
+#include <aclCommon/ArmComputeTensorUtils.hpp>
+#include <cl/ClTensorHandle.hpp>
+
+#include <arm_compute/function_info/ScatterInfo.h>
+
+namespace armnn
+{
+
+using namespace armcomputetensorutils;
+
+/// Validate whether ACL's CLScatter can handle the given tensors/descriptor
+/// without instantiating the workload.
+/// NOTE: ACL's argument order is (input, updates, indices, output) — updates
+/// precede indices, unlike the ArmNN parameter order. A null input is passed
+/// when the descriptor disables the input tensor (shape-only variant).
+arm_compute::Status ClScatterNdWorkloadValidate(const TensorInfo& inputInfo,
+ const TensorInfo& indicesInfo,
+ const TensorInfo& updatesInfo,
+ const TensorInfo& outputInfo,
+ const ScatterNdDescriptor& descriptor)
+{
+ const arm_compute::TensorInfo aclInputInfo = BuildArmComputeTensorInfo(inputInfo);
+ const arm_compute::TensorInfo aclIndicesInfo = BuildArmComputeTensorInfo(indicesInfo);
+ const arm_compute::TensorInfo aclUpdatesInfo = BuildArmComputeTensorInfo(updatesInfo);
+ const arm_compute::TensorInfo aclOutputInfo = BuildArmComputeTensorInfo(outputInfo);
+
+ arm_compute::ScatterInfo scatterInfo = BuildArmComputeScatterInfo(descriptor);
+
+ return arm_compute::CLScatter::validate(descriptor.m_InputEnabled ? &aclInputInfo : nullptr,
+ &aclUpdatesInfo,
+ &aclIndicesInfo,
+ &aclOutputInfo,
+ scatterInfo);
+}
+
+/// Construct the CL ScatterNd workload: validates the queue descriptor layout,
+/// fetches the CL tensor handles and configures the ACL CLScatter function.
+ClScatterNdWorkload::ClScatterNdWorkload(const ScatterNdQueueDescriptor& descriptor,
+ const WorkloadInfo& info,
+ const arm_compute::CLCompileContext& clCompileContext)
+ : ClBaseWorkload<ScatterNdQueueDescriptor>(descriptor, info)
+{
+ // Report Profiling Details
+ ARMNN_REPORT_PROFILING_WORKLOAD_DESC("ClScatterNdWorkload_Construct",
+ descriptor.m_Parameters,
+ info,
+ this->GetGuid());
+
+ // Expects exactly 3 inputs (input/shape, indices, updates) and 1 output.
+ m_Data.ValidateInputsOutputs("ClScatterNdWorkload", 3, 1);
+
+ // ArmNN input order: [0]=input/shape, [1]=indices, [2]=updates.
+ arm_compute::ICLTensor& input = static_cast<IClTensorHandle*>(m_Data.m_Inputs[0])->GetTensor();
+ arm_compute::ICLTensor& updates = static_cast<IClTensorHandle*>(m_Data.m_Inputs[2])->GetTensor();
+ arm_compute::ICLTensor& indices = static_cast<IClTensorHandle*>(m_Data.m_Inputs[1])->GetTensor();
+ arm_compute::ICLTensor& output = static_cast<IClTensorHandle*>(m_Data.m_Outputs[0])->GetTensor();
+
+ arm_compute::ScatterInfo scatterInfo = BuildArmComputeScatterInfo(descriptor.m_Parameters);
+
+ {
+ ARMNN_SCOPED_PROFILING_EVENT_CL_NAME_GUID("ClScatterNdWorkload_configure");
+ // ACL argument order is (input, updates, indices, output); input is
+ // null when only a shape (no input tensor) is provided.
+ m_ScatterNdLayer.configure(clCompileContext,
+ descriptor.m_Parameters.m_InputEnabled ? &input : nullptr,
+ &updates,
+ &indices,
+ &output,
+ scatterInfo);
+ }
+}
+
+/// Run the configured ACL scatter function, wrapping CL errors via RunClFunction.
+void ClScatterNdWorkload::Execute() const
+{
+ ARMNN_SCOPED_PROFILING_EVENT_CL_NAME_GUID("ClScatterNdWorkload_Execute");
+ RunClFunction(m_ScatterNdLayer, CHECK_LOCATION());
+}
+
+} //namespace armnn
diff --git a/src/backends/cl/workloads/ClScatterNdWorkload.hpp b/src/backends/cl/workloads/ClScatterNdWorkload.hpp
new file mode 100644
index 0000000000..070dac440d
--- /dev/null
+++ b/src/backends/cl/workloads/ClScatterNdWorkload.hpp
@@ -0,0 +1,35 @@
+//
+// Copyright © 2024 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include <armnn/Descriptors.hpp>
+
+#include <arm_compute/runtime/CL/functions/CLScatter.h>
+
+#include "ClBaseWorkload.hpp"
+
+namespace armnn
+{
+
+/// Validate ScatterNd support on the CL backend for the given tensor infos
+/// and descriptor, without constructing a workload.
+arm_compute::Status ClScatterNdWorkloadValidate(const TensorInfo& input,
+ const TensorInfo& indices,
+ const TensorInfo& updates,
+ const TensorInfo& output,
+ const ScatterNdDescriptor& descriptor);
+
+/// CL backend workload wrapping arm_compute::CLScatter for the ScatterNd layer.
+class ClScatterNdWorkload : public ClBaseWorkload<ScatterNdQueueDescriptor>
+{
+public:
+ ClScatterNdWorkload(const ScatterNdQueueDescriptor& descriptor,
+ const WorkloadInfo& info,
+ const arm_compute::CLCompileContext& clCompileContext);
+ void Execute() const override;
+
+private:
+ // mutable because Execute() is const while ACL function objects' run() is not.
+ mutable arm_compute::CLScatter m_ScatterNdLayer;
+};
+
+} //namespace armnn
diff --git a/src/backends/cl/workloads/ClWorkloads.hpp b/src/backends/cl/workloads/ClWorkloads.hpp
index 40b3e99258..3178f6420d 100644
--- a/src/backends/cl/workloads/ClWorkloads.hpp
+++ b/src/backends/cl/workloads/ClWorkloads.hpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2017,2023 Arm Ltd and Contributors. All rights reserved.
+// Copyright © 2017-2024 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
@@ -57,6 +57,7 @@
#include "ClResizeWorkload.hpp"
#include "ClReverseV2Workload.hpp"
#include "ClRsqrtWorkload.hpp"
+#include "ClScatterNdWorkload.hpp"
#include "ClSinWorkload.hpp"
#include "ClSliceWorkload.hpp"
#include "ClSoftmaxWorkload.hpp"
diff --git a/src/backends/reference/test/RefEndToEndTests.cpp b/src/backends/reference/test/RefEndToEndTests.cpp
index 68b7fbff90..6f57236dd5 100644
--- a/src/backends/reference/test/RefEndToEndTests.cpp
+++ b/src/backends/reference/test/RefEndToEndTests.cpp
@@ -1354,7 +1354,6 @@ TEST_CASE("QuantizationEndToEndFloat16_S16Test")
}
// ScatterNd
-
TEST_CASE("RefScatterNd1DInputEndToEndFloat32Test")
{
ScatterNd1DimUpdateWithInputEndToEnd<armnn::DataType::Float32>(defaultBackends);
@@ -1395,7 +1394,6 @@ TEST_CASE("RefScatterNd2DNoInputEndToEndInt8Test")
ScatterNd2DimUpdateNoInputEndToEnd<armnn::DataType::QAsymmS8>(defaultBackends);
}
-
// SpaceToDepth
TEST_CASE("RefSpaceToDepthNhwcEndToEndTest1")
{