diff options
author | Teresa Charlin <teresa.charlinreyes@arm.com> | 2024-03-13 16:10:32 +0000 |
---|---|---|
committer | TeresaARM <teresa.charlinreyes@arm.com> | 2024-05-08 14:22:08 +0000 |
commit | 21bda1405d2cb49fc873583b41a48836b33d285e (patch) | |
tree | 6d57debcf6be6aeb28a3e3951757c73ae5636250 /src | |
parent | 8208e2b8b1d09d0e89394ae134eb61e390dfd93c (diff) | |
download | armnn-21bda1405d2cb49fc873583b41a48836b33d285e.tar.gz |
IVGCVSW-8235 ScatterNd Operator Implementation (CL)
Signed-off-by: Teresa Charlin <teresa.charlinreyes@arm.com>
Change-Id: I59fe96b0a272fa6984bfc172bf3e110476f3ce7b
Diffstat (limited to 'src')
-rw-r--r-- | src/backends/aclCommon/ArmComputeTensorUtils.cpp | 26 | ||||
-rw-r--r-- | src/backends/aclCommon/ArmComputeTensorUtils.hpp | 6 | ||||
-rw-r--r-- | src/backends/cl/ClLayerSupport.cpp | 24 | ||||
-rw-r--r-- | src/backends/cl/ClLayerSupport.hpp | 9 | ||||
-rw-r--r-- | src/backends/cl/ClWorkloadFactory.cpp | 7 | ||||
-rw-r--r-- | src/backends/cl/backend.mk | 3 | ||||
-rw-r--r-- | src/backends/cl/test/ClEndToEndTests.cpp | 22 | ||||
-rw-r--r-- | src/backends/cl/test/ClLayerTests.cpp | 89 | ||||
-rw-r--r-- | src/backends/cl/workloads/CMakeLists.txt | 4 | ||||
-rw-r--r-- | src/backends/cl/workloads/ClScatterNdWorkload.cpp | 77 | ||||
-rw-r--r-- | src/backends/cl/workloads/ClScatterNdWorkload.hpp | 35 | ||||
-rw-r--r-- | src/backends/cl/workloads/ClWorkloads.hpp | 3 | ||||
-rw-r--r-- | src/backends/reference/test/RefEndToEndTests.cpp | 2 |
13 files changed, 298 insertions, 9 deletions
diff --git a/src/backends/aclCommon/ArmComputeTensorUtils.cpp b/src/backends/aclCommon/ArmComputeTensorUtils.cpp index c5b4fa157e..cfd2e0e110 100644 --- a/src/backends/aclCommon/ArmComputeTensorUtils.cpp +++ b/src/backends/aclCommon/ArmComputeTensorUtils.cpp @@ -459,5 +459,31 @@ unsigned int ComputeDepthwiseConv2dDepthMultiplier(armnn::DataLayout layout, return depthMultiplier; } +arm_compute::ScatterInfo BuildArmComputeScatterInfo(const ScatterNdDescriptor& descriptor) +{ + arm_compute::ScatterFunction scatterFunction; + switch(descriptor.m_Function) + { + case ScatterNdFunction::Update: + scatterFunction = arm_compute::ScatterFunction::Update; + break; + case ScatterNdFunction::Add: + scatterFunction = arm_compute::ScatterFunction::Add; + break; + case ScatterNdFunction::Sub: + scatterFunction = arm_compute::ScatterFunction::Sub; + break; + case ScatterNdFunction::Max: + scatterFunction = arm_compute::ScatterFunction::Max; + break; + case ScatterNdFunction::Min: + scatterFunction = arm_compute::ScatterFunction::Min; + break; + default: throw InvalidArgumentException("Unknown ArmNN::ScatterNd Function: [" + + std::to_string(static_cast<int>(descriptor.m_Function)) + "]"); + } + + return arm_compute::ScatterInfo(scatterFunction, !descriptor.m_InputEnabled); +} } // namespace armcomputetensorutils } // namespace armnn diff --git a/src/backends/aclCommon/ArmComputeTensorUtils.hpp b/src/backends/aclCommon/ArmComputeTensorUtils.hpp index d8a41fe41f..63c70c7092 100644 --- a/src/backends/aclCommon/ArmComputeTensorUtils.hpp +++ b/src/backends/aclCommon/ArmComputeTensorUtils.hpp @@ -1,5 +1,5 @@ // -// Copyright © 2017-2023 Arm Ltd and Contributors. All rights reserved. +// Copyright © 2017-2024 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // #pragma once @@ -12,6 +12,7 @@ #include <arm_compute/core/ITensor.h> #include <arm_compute/core/TensorInfo.h> #include <arm_compute/core/Types.h> +#include <arm_compute/function_info/ScatterInfo.h> #include <Half.hpp> @@ -108,6 +109,9 @@ unsigned int ComputeDepthwiseConv2dDepthMultiplier(armnn::DataLayout layout, const arm_compute::TensorShape& weightsShape, const arm_compute::TensorShape& inputShape); +/// Utility function used to setup an arm_compute::ScatterInfo from ArmNN ScatterNd descriptor +arm_compute::ScatterInfo BuildArmComputeScatterInfo(const ScatterNdDescriptor& descriptor); + /// Utility function used to setup an arm_compute::PadStrideInfo object from an ArmNN layer descriptor. template <typename Descriptor> arm_compute::PadStrideInfo BuildArmComputePadStrideInfo(const Descriptor& descriptor) diff --git a/src/backends/cl/ClLayerSupport.cpp b/src/backends/cl/ClLayerSupport.cpp index 9f7d562df6..030b4c2d09 100644 --- a/src/backends/cl/ClLayerSupport.cpp +++ b/src/backends/cl/ClLayerSupport.cpp @@ -73,6 +73,7 @@ #include "workloads/ClResizeWorkload.hpp" #include "workloads/ClReverseV2Workload.hpp" #include "workloads/ClRsqrtWorkload.hpp" +#include "workloads/ClScatterNdWorkload.hpp" #include "workloads/ClSinWorkload.hpp" #include "workloads/ClSliceWorkload.hpp" #include "workloads/ClSoftmaxWorkload.hpp" @@ -578,6 +579,13 @@ bool ClLayerSupport::IsLayerSupported(const LayerType& type, infos[1], infos[2], reasonIfUnsupported); + case LayerType::ScatterNd: + return IsScatterNdSupported(infos[0], // input/shape + infos[1], // indices + infos[2], // updates + infos[3], // output + *(PolymorphicDowncast<const ScatterNdDescriptor*>(&descriptor)), + reasonIfUnsupported); case LayerType::Shape: return LayerSupportBase::IsShapeSupported(infos[0], infos[1], @@ -1442,6 +1450,22 @@ bool ClLayerSupport::IsReverseV2Supported(const TensorInfo& input, output); } +bool ClLayerSupport::IsScatterNdSupported(const TensorInfo& input, + const TensorInfo& indices, + const TensorInfo& updates, + const TensorInfo& output, + const ScatterNdDescriptor& descriptor, + Optional<std::string&> reasonIfUnsupported) const +{ + FORWARD_WORKLOAD_VALIDATE_FUNC(ClScatterNdWorkloadValidate, + reasonIfUnsupported, + input, + indices, + updates, + output, + descriptor); +} + bool ClLayerSupport::IsSliceSupported(const TensorInfo& input, const TensorInfo& output, const SliceDescriptor& descriptor, diff --git a/src/backends/cl/ClLayerSupport.hpp b/src/backends/cl/ClLayerSupport.hpp index 907db01b89..8e9c0be7f8 100644 --- a/src/backends/cl/ClLayerSupport.hpp +++ b/src/backends/cl/ClLayerSupport.hpp @@ -1,5 +1,5 @@ // -// Copyright © 2017-2023 Arm Ltd and Contributors. All rights reserved. +// Copyright © 2017-2024 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // #pragma once @@ -300,6 +300,13 @@ public: const TensorInfo& output, Optional<std::string&> reasonIfUnsupported) const; + bool IsScatterNdSupported(const TensorInfo& input, + const TensorInfo& indices, + const TensorInfo& updates, + const TensorInfo& output, + const ScatterNdDescriptor& descriptor, + Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const; + bool IsSliceSupported(const TensorInfo& input, const TensorInfo& output, const SliceDescriptor& descriptor, diff --git a/src/backends/cl/ClWorkloadFactory.cpp b/src/backends/cl/ClWorkloadFactory.cpp index 6fe42644c2..6a7b0e64ae 100644 --- a/src/backends/cl/ClWorkloadFactory.cpp +++ b/src/backends/cl/ClWorkloadFactory.cpp @@ -1,5 +1,5 @@ // -// Copyright © 2017-2023 Arm Ltd and Contributors. All rights reserved. +// Copyright © 2017-2024 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // #include "ClWorkloadFactory.hpp" @@ -716,6 +716,11 @@ std::unique_ptr<IWorkload> ClWorkloadFactory::CreateWorkload(LayerType type, auto reverseV2QueueDescriptor = PolymorphicDowncast<const ReverseV2QueueDescriptor*>(&descriptor); return MakeWorkload<ClReverseV2Workload>(*reverseV2QueueDescriptor, info, m_CLCompileContext); } + case LayerType::ScatterNd : + { + auto scatterNdQueueDescriptor = PolymorphicDowncast<const ScatterNdQueueDescriptor*>(&descriptor); + return MakeWorkload<ClScatterNdWorkload>(*scatterNdQueueDescriptor, info, m_CLCompileContext); + } case LayerType::Slice : { auto sliceQueueDescriptor = PolymorphicDowncast<const SliceQueueDescriptor*>(&descriptor); diff --git a/src/backends/cl/backend.mk b/src/backends/cl/backend.mk index 2143c30309..f233ffc5e1 100644 --- a/src/backends/cl/backend.mk +++ b/src/backends/cl/backend.mk @@ -1,5 +1,5 @@ # -# Copyright © 2017-2023 ARM Ltd and Contributors. All rights reserved. +# Copyright © 2017-2024 ARM Ltd and Contributors. All rights reserved. # SPDX-License-Identifier: MIT # @@ -81,6 +81,7 @@ BACKEND_SOURCES := \ workloads/ClResizeWorkload.cpp \ workloads/ClReverseV2Workload.cpp \ workloads/ClRsqrtWorkload.cpp \ + workloads/ClScatterNdWorkload.cpp \ workloads/ClSinWorkload.cpp \ workloads/ClSliceWorkload.cpp \ workloads/ClSoftmaxWorkload.cpp \ diff --git a/src/backends/cl/test/ClEndToEndTests.cpp b/src/backends/cl/test/ClEndToEndTests.cpp index 9e60843177..fa5d545547 100644 --- a/src/backends/cl/test/ClEndToEndTests.cpp +++ b/src/backends/cl/test/ClEndToEndTests.cpp @@ -25,6 +25,7 @@ #include <backendsCommon/test/ReshapeEndToEndTestImpl.hpp> #include <backendsCommon/test/ResizeEndToEndTestImpl.hpp> #include <backendsCommon/test/ReverseV2EndToEndTestImpl.hpp> +#include <backendsCommon/test/ScatterNdEndToEndTestImpl.hpp> #include <backendsCommon/test/SliceEndToEndTestImpl.hpp> #include <backendsCommon/test/SpaceToDepthEndToEndTestImpl.hpp> #include <backendsCommon/test/SplitterEndToEndTestImpl.hpp> @@ -322,6 +323,27 @@ TEST_CASE("DequantizeEndToEndOffsetTest") DequantizeEndToEndOffset<armnn::DataType::QAsymmU8>(clDefaultBackends); } +// ScatterNd +TEST_CASE("ClScatterNd1DInputEndToEndFloat32Test") +{ + ScatterNd1DimUpdateWithInputEndToEnd<armnn::DataType::Float32>(clDefaultBackends); +} + +TEST_CASE("ClScatterNd1DNoInputEndToEndFloat32Test") +{ + ScatterNd1DimUpdateNoInputEndToEnd<armnn::DataType::Float32>(clDefaultBackends); +} + +TEST_CASE("ClScatterNd2DInputEndToEndFloat32Test") +{ + ScatterNd2DimUpdateWithInputEndToEnd<armnn::DataType::Float32>(clDefaultBackends); +} + +TEST_CASE("ClScatterNd2DNoInputEndToEndFloat32Test") +{ + ScatterNd2DimUpdateNoInputEndToEnd<armnn::DataType::Float32>(clDefaultBackends); +} + // Slice TEST_CASE("ClSliceEndtoEndTestFloat32") { diff --git a/src/backends/cl/test/ClLayerTests.cpp b/src/backends/cl/test/ClLayerTests.cpp index da2b967fcb..e193ca24ea 100644 --- a/src/backends/cl/test/ClLayerTests.cpp +++ b/src/backends/cl/test/ClLayerTests.cpp @@ -1,5 +1,5 @@ // -// Copyright © 2017-2023 Arm Ltd and Contributors. All rights reserved. +// Copyright © 2017-2024 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // @@ -1531,6 +1531,93 @@ ARMNN_AUTO_TEST_FIXTURE_WITH_THF(SimpleSoftmaxBeta2Uint8, ClContextControlFixtur // LogSoftmax ARMNN_AUTO_TEST_FIXTURE_WITH_THF(LogSoftmaxFloat32_1, ClContextControlFixture, LogSoftmaxTest1<DataType::Float32>) +// ScatterNd +// With Input tensor +ARMNN_AUTO_TEST_FIXTURE_WITH_THF(ScatterNd1DUpdateTestWithInputFloat32, + ClContextControlFixture, + ScatterNd1DimUpdateWithInput<DataType::Float32>) +ARMNN_AUTO_TEST_FIXTURE_WITH_THF(ScatterNd2DUpdateTestWithInputFloat32, + ClContextControlFixture, + ScatterNd2DimUpdateWithInput<DataType::Float32>) +ARMNN_AUTO_TEST_FIXTURE_WITH_THF(ScatterNd2Dim1Outter1InnerUpdateWithInputFloat32, + ClContextControlFixture, + ScatterNd2Dim1Outter1InnerUpdateWithInput<DataType::Float32>) +ARMNN_AUTO_TEST_FIXTURE_WITH_THF(ScatterNd3DimUpdateWithInputFloat32, + ClContextControlFixture, + ScatterNd3DimUpdateWithInput<DataType::Float32>) +ARMNN_AUTO_TEST_FIXTURE_WITH_THF(ScatterNd3Dim1Outter2InnerUpdateWithInputFloat32, + ClContextControlFixture, + ScatterNd3Dim1Outter2InnerUpdateWithInput<DataType::Float32>) +ARMNN_AUTO_TEST_FIXTURE_WITH_THF(ScatterNd3Dim2Outter1InnerUpdateWithInputFloat32, + ClContextControlFixture, + ScatterNd3Dim2Outter1InnerUpdateWithInput<DataType::Float32>) +ARMNN_AUTO_TEST_FIXTURE_WITH_THF(ScatterNd4DimUpdateWithInputFloat32, + ClContextControlFixture, + ScatterNd4DimUpdateWithInput<DataType::Float32>) + +ARMNN_AUTO_TEST_FIXTURE_WITH_THF(ScatterNd2DimAddWithInputFloat32, + ClContextControlFixture, + ScatterNd2DimAddWithInput<DataType::Float32>) +ARMNN_AUTO_TEST_FIXTURE_WITH_THF(ScatterNd2DimSubWithInputFloat32, + ClContextControlFixture, + ScatterNd2DimSubWithInput<DataType::Float32>) +ARMNN_AUTO_TEST_FIXTURE_WITH_THF(ScatterNd2DimMaxWithInputFloat32, + ClContextControlFixture, + ScatterNd2DimMaxWithInput<DataType::Float32>) +ARMNN_AUTO_TEST_FIXTURE_WITH_THF(ScatterNd2DimMinWithInputFloat32, + ClContextControlFixture, + ScatterNd2DimMinWithInput<DataType::Float32>) + +ARMNN_AUTO_TEST_FIXTURE_WITH_THF(ScatterNd3DimUpdateWithInputFloat16, + ClContextControlFixture, + ScatterNd3DimUpdateWithInput<DataType::Float16>) +ARMNN_AUTO_TEST_FIXTURE_WITH_THF(ScatterNd3DimUpdateWithInputSigned32, + ClContextControlFixture, + ScatterNd3DimUpdateWithInput<DataType::Signed32>) + +// No input tensor, only shape provided +ARMNN_AUTO_TEST_FIXTURE_WITH_THF(ScatterNd1DUpdateTestNoInputFloat32, + ClContextControlFixture, + ScatterNd1DimUpdateNoInput<DataType::Float32>) +ARMNN_AUTO_TEST_FIXTURE_WITH_THF(ScatterNd2DimUpdateTestNoInputFloat32, + ClContextControlFixture, + ScatterNd2DimUpdateNoInput<DataType::Float32>) +ARMNN_AUTO_TEST_FIXTURE_WITH_THF(ScatterNd2Dim1Outter1InnerUpdateNoInputFloat32, + ClContextControlFixture, + ScatterNd2Dim1Outter1InnerUpdateNoInput<DataType::Float32>) +ARMNN_AUTO_TEST_FIXTURE_WITH_THF(ScatterNd3DimUpdateNoInputFloat32, + ClContextControlFixture, + ScatterNd3DimUpdateNoInput<DataType::Float32>) +ARMNN_AUTO_TEST_FIXTURE_WITH_THF(ScatterNd3Dim1Outter2InnerUpdateNoInputFloat32, + ClContextControlFixture, + ScatterNd3Dim1Outter2InnerUpdateNoInput<DataType::Float32>) +ARMNN_AUTO_TEST_FIXTURE_WITH_THF(ScatterNd3Dim2Outter1InnerUpdateNoInputFloat32, + ClContextControlFixture, + ScatterNd3Dim2Outter1InnerUpdateNoInput<DataType::Float32>) +ARMNN_AUTO_TEST_FIXTURE_WITH_THF(ScatterNd4DimUpdateNoInputFloat32, + ClContextControlFixture, + ScatterNd4DimUpdateNoInput<DataType::Float32>) + +ARMNN_AUTO_TEST_FIXTURE_WITH_THF(ScatterNd2DimAddNoInputFloat32, + ClContextControlFixture, + ScatterNd2DimAddNoInput<DataType::Float32>) +ARMNN_AUTO_TEST_FIXTURE_WITH_THF(ScatterNd2DimSubNoInputFloat32, + ClContextControlFixture, + ScatterNd2DimSubNoInput<DataType::Float32>) +ARMNN_AUTO_TEST_FIXTURE_WITH_THF(ScatterNd2DimMaxNoInputFloat32, + ClContextControlFixture, + ScatterNd2DimMaxNoInput<DataType::Float32>) +ARMNN_AUTO_TEST_FIXTURE_WITH_THF(ScatterNd2DimMinNoInputFloat32, + ClContextControlFixture, + ScatterNd2DimMinNoInput<DataType::Float32>) + +ARMNN_AUTO_TEST_FIXTURE_WITH_THF(ScatterNd3DimUpdateNoInputFloat16, + ClContextControlFixture, + ScatterNd3DimUpdateNoInput<DataType::Float16>) +ARMNN_AUTO_TEST_FIXTURE_WITH_THF(ScatterNd3DimUpdateNoInputSigned32, + ClContextControlFixture, + ScatterNd3DimUpdateNoInput<DataType::Signed32>) + // Space To Batch Nd ARMNN_AUTO_TEST_FIXTURE_WITH_THF(SpaceToBatchNdSimpleFloat32, ClContextControlFixture, SpaceToBatchNdSimpleFloat32Test) ARMNN_AUTO_TEST_FIXTURE_WITH_THF( diff --git a/src/backends/cl/workloads/CMakeLists.txt b/src/backends/cl/workloads/CMakeLists.txt index f38366fa57..7db602b46b 100644 --- a/src/backends/cl/workloads/CMakeLists.txt +++ b/src/backends/cl/workloads/CMakeLists.txt @@ -1,5 +1,5 @@ # -# Copyright © 2017, 2023 Arm Ltd and Contributors. All rights reserved. +# Copyright © 2017-2024 Arm Ltd and Contributors. All rights reserved. # SPDX-License-Identifier: MIT # @@ -113,6 +113,8 @@ list(APPEND armnnClBackendWorkloads_sources ClReverseV2Workload.hpp ClRsqrtWorkload.cpp ClRsqrtWorkload.hpp + ClScatterNdWorkload.cpp + ClScatterNdWorkload.hpp ClSinWorkload.cpp ClSinWorkload.hpp ClSliceWorkload.cpp diff --git a/src/backends/cl/workloads/ClScatterNdWorkload.cpp b/src/backends/cl/workloads/ClScatterNdWorkload.cpp new file mode 100644 index 0000000000..e75edf12c9 --- /dev/null +++ b/src/backends/cl/workloads/ClScatterNdWorkload.cpp @@ -0,0 +1,77 @@ +// +// Copyright © 2024 Arm Ltd and Contributors. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#include "ClScatterNdWorkload.hpp" + +#include "ClWorkloadUtils.hpp" + +#include <aclCommon/ArmComputeTensorUtils.hpp> +#include <cl/ClTensorHandle.hpp> + +#include <arm_compute/function_info/ScatterInfo.h> + +namespace armnn +{ + +using namespace armcomputetensorutils; + +arm_compute::Status ClScatterNdWorkloadValidate(const TensorInfo& inputInfo, + const TensorInfo& indicesInfo, + const TensorInfo& updatesInfo, + const TensorInfo& outputInfo, + const ScatterNdDescriptor& descriptor) +{ + const arm_compute::TensorInfo aclInputInfo = BuildArmComputeTensorInfo(inputInfo); + const arm_compute::TensorInfo aclIndicesInfo = BuildArmComputeTensorInfo(indicesInfo); + const arm_compute::TensorInfo aclUpdatesInfo = BuildArmComputeTensorInfo(updatesInfo); + const arm_compute::TensorInfo aclOutputInfo = BuildArmComputeTensorInfo(outputInfo); + + arm_compute::ScatterInfo scatterInfo = BuildArmComputeScatterInfo(descriptor); + + return arm_compute::CLScatter::validate(descriptor.m_InputEnabled ? &aclInputInfo : nullptr, + &aclUpdatesInfo, + &aclIndicesInfo, + &aclOutputInfo, + scatterInfo); +} + +ClScatterNdWorkload::ClScatterNdWorkload(const ScatterNdQueueDescriptor& descriptor, + const WorkloadInfo& info, + const arm_compute::CLCompileContext& clCompileContext) + : ClBaseWorkload<ScatterNdQueueDescriptor>(descriptor, info) +{ + // Report Profiling Details + ARMNN_REPORT_PROFILING_WORKLOAD_DESC("ClScatterNdWorkload_Construct", + descriptor.m_Parameters, + info, + this->GetGuid()); + + m_Data.ValidateInputsOutputs("ClScatterNdWorkload", 3, 1); + + arm_compute::ICLTensor& input = static_cast<IClTensorHandle*>(m_Data.m_Inputs[0])->GetTensor(); + arm_compute::ICLTensor& updates = static_cast<IClTensorHandle*>(m_Data.m_Inputs[2])->GetTensor(); + arm_compute::ICLTensor& indices = static_cast<IClTensorHandle*>(m_Data.m_Inputs[1])->GetTensor(); + arm_compute::ICLTensor& output = static_cast<IClTensorHandle*>(m_Data.m_Outputs[0])->GetTensor(); + + arm_compute::ScatterInfo scatterInfo = BuildArmComputeScatterInfo(descriptor.m_Parameters); + + { + ARMNN_SCOPED_PROFILING_EVENT_CL_NAME_GUID("ClScatterNdWorkload_configure"); + m_ScatterNdLayer.configure(clCompileContext, + descriptor.m_Parameters.m_InputEnabled ? &input : nullptr, + &updates, + &indices, + &output, + scatterInfo); + } +} + +void ClScatterNdWorkload::Execute() const +{ + ARMNN_SCOPED_PROFILING_EVENT_CL_NAME_GUID("ClScatterNdWorkload_Execute"); + RunClFunction(m_ScatterNdLayer, CHECK_LOCATION()); +} + +} //namespace armnn diff --git a/src/backends/cl/workloads/ClScatterNdWorkload.hpp b/src/backends/cl/workloads/ClScatterNdWorkload.hpp new file mode 100644 index 0000000000..070dac440d --- /dev/null +++ b/src/backends/cl/workloads/ClScatterNdWorkload.hpp @@ -0,0 +1,35 @@ +// +// Copyright © 2024 Arm Ltd and Contributors. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#pragma once + +#include <armnn/Descriptors.hpp> + +#include <arm_compute/runtime/CL/functions/CLScatter.h> + +#include "ClBaseWorkload.hpp" + +namespace armnn +{ + +arm_compute::Status ClScatterNdWorkloadValidate(const TensorInfo& input, + const TensorInfo& indices, + const TensorInfo& updates, + const TensorInfo& output, + const ScatterNdDescriptor& descriptor); + +class ClScatterNdWorkload : public ClBaseWorkload<ScatterNdQueueDescriptor> +{ +public: + ClScatterNdWorkload(const ScatterNdQueueDescriptor& descriptor, + const WorkloadInfo& info, + const arm_compute::CLCompileContext& clCompileContext); + void Execute() const override; + +private: + mutable arm_compute::CLScatter m_ScatterNdLayer; +}; + +} //namespace armnn diff --git a/src/backends/cl/workloads/ClWorkloads.hpp b/src/backends/cl/workloads/ClWorkloads.hpp index 40b3e99258..3178f6420d 100644 --- a/src/backends/cl/workloads/ClWorkloads.hpp +++ b/src/backends/cl/workloads/ClWorkloads.hpp @@ -1,5 +1,5 @@ // -// Copyright © 2017,2023 Arm Ltd and Contributors. All rights reserved. +// Copyright © 2017-2024 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // @@ -57,6 +57,7 @@ #include "ClResizeWorkload.hpp" #include "ClReverseV2Workload.hpp" #include "ClRsqrtWorkload.hpp" +#include "ClScatterNdWorkload.hpp" #include "ClSinWorkload.hpp" #include "ClSliceWorkload.hpp" #include "ClSoftmaxWorkload.hpp" diff --git a/src/backends/reference/test/RefEndToEndTests.cpp b/src/backends/reference/test/RefEndToEndTests.cpp index 68b7fbff90..6f57236dd5 100644 --- a/src/backends/reference/test/RefEndToEndTests.cpp +++ b/src/backends/reference/test/RefEndToEndTests.cpp @@ -1354,7 +1354,6 @@ TEST_CASE("QuantizationEndToEndFloat16_S16Test") } // ScatterNd - TEST_CASE("RefScatterNd1DInputEndToEndFloat32Test") { ScatterNd1DimUpdateWithInputEndToEnd<armnn::DataType::Float32>(defaultBackends); @@ -1395,7 +1394,6 @@ TEST_CASE("RefScatterNd2DNoInputEndToEndInt8Test") ScatterNd2DimUpdateNoInputEndToEnd<armnn::DataType::QAsymmS8>(defaultBackends); } - // SpaceToDepth TEST_CASE("RefSpaceToDepthNhwcEndToEndTest1") { |