diff options
Diffstat (limited to 'src/backends/reference')
-rw-r--r-- | src/backends/reference/RefLayerSupport.cpp | 65 | ||||
-rw-r--r-- | src/backends/reference/RefLayerSupport.hpp | 9 | ||||
-rw-r--r-- | src/backends/reference/RefWorkloadFactory.cpp | 7 | ||||
-rw-r--r-- | src/backends/reference/backend.mk | 4 | ||||
-rw-r--r-- | src/backends/reference/test/RefLayerTests.cpp | 54 | ||||
-rw-r--r-- | src/backends/reference/workloads/CMakeLists.txt | 6 | ||||
-rw-r--r-- | src/backends/reference/workloads/RefScatterNdWorkload.cpp | 100 | ||||
-rw-r--r-- | src/backends/reference/workloads/RefScatterNdWorkload.hpp | 30 | ||||
-rw-r--r-- | src/backends/reference/workloads/RefWorkloads.hpp | 3 | ||||
-rw-r--r-- | src/backends/reference/workloads/ScatterNd.cpp | 336 | ||||
-rw-r--r-- | src/backends/reference/workloads/ScatterNd.hpp | 34 |
11 files changed, 642 insertions, 6 deletions
diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp index f97d03a26e..654aeb55dc 100644 --- a/src/backends/reference/RefLayerSupport.cpp +++ b/src/backends/reference/RefLayerSupport.cpp @@ -356,6 +356,13 @@ bool RefLayerSupport::IsLayerSupported(const LayerType& type, infos[1], *(PolymorphicDowncast<const ReduceDescriptor*>(&descriptor)), reasonIfUnsupported); + case LayerType::ScatterNd: + return IsScatterNdSupported(infos[0], + infos[1], + infos[2], + infos[3], + *(PolymorphicDowncast<const ScatterNdDescriptor*>(&descriptor)), + reasonIfUnsupported); case LayerType::Slice: return IsSliceSupported(infos[0], infos[1], @@ -2442,6 +2449,64 @@ bool RefLayerSupport::IsReverseV2Supported(const TensorInfo& input0, return supported; } +bool RefLayerSupport::IsScatterNdSupported(const TensorInfo& input, + const TensorInfo& indices, + const TensorInfo& updates, + const TensorInfo& output, + const ScatterNdDescriptor& descriptor, + Optional<std::string&> reasonIfUnsupported) const +{ + IgnoreUnused(descriptor); + + bool supported = true; + + std::array<DataType, 7> supportedTypes + { + DataType::Float32, + DataType::Float16, + DataType::QAsymmS8, + DataType::QAsymmU8, + DataType::QSymmS8, + DataType::QSymmS16, + DataType::Signed32 + }; + + std::array<DataType, 1> indicesSupportedTypes = + { + DataType::Signed32 + }; + + supported &= CheckSupportRule(TypeAnyOf(indices, indicesSupportedTypes), reasonIfUnsupported, + "ScatterNd: indices type not supported."); + + supported &= CheckSupportRule(TypeAnyOf(updates, supportedTypes), reasonIfUnsupported, + "ScatterNd: updates type not supported."); + + supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported, + "ScatterNd: output type not supported"); + + supported &= CheckSupportRule(TypesAreEqual(updates, output), reasonIfUnsupported, + "ScatterNd: input and updates types are mismatched"); + + if (descriptor.m_InputEnabled) + { + // If the input slot is enabled, we have the input tensor in this slot + supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported, + "ScatterNd: input type not supported."); + + supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported, + "ScatterNd: input and output types are mismatched"); + } + else + { + // If the input slot is not enabled, we have the shape tensor in this slot + supported &= CheckSupportRule(TypeAnyOf(input, indicesSupportedTypes), reasonIfUnsupported, + "ScatterNd: shape type not supported."); + } + + return supported; +} + bool RefLayerSupport::IsShapeSupported(const TensorInfo& input, const TensorInfo& output, Optional<std::string&> reasonIfUnsupported) const diff --git a/src/backends/reference/RefLayerSupport.hpp b/src/backends/reference/RefLayerSupport.hpp index 9e7175389a..1b0f4a2bb5 100644 --- a/src/backends/reference/RefLayerSupport.hpp +++ b/src/backends/reference/RefLayerSupport.hpp @@ -1,5 +1,5 @@ // -// Copyright © 2017-2023 Arm Ltd and Contributors. All rights reserved. +// Copyright © 2017-2024 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // #pragma once @@ -309,6 +309,13 @@ public: const TensorInfo& output, Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const; + bool IsScatterNdSupported(const TensorInfo& input, + const TensorInfo& indices, + const TensorInfo& updates, + const TensorInfo& output, + const ScatterNdDescriptor& descriptor, + Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const; + bool IsShapeSupported(const TensorInfo& input, const TensorInfo& output, Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const; diff --git a/src/backends/reference/RefWorkloadFactory.cpp b/src/backends/reference/RefWorkloadFactory.cpp index ad6ec9a792..df458c1a6d 100644 --- a/src/backends/reference/RefWorkloadFactory.cpp +++ b/src/backends/reference/RefWorkloadFactory.cpp @@ -1,5 +1,5 @@ // -// Copyright © 2017-2023 Arm Ltd and Contributors. All rights reserved. +// Copyright © 2017-2024 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // #include <Layer.hpp> @@ -567,6 +567,11 @@ std::unique_ptr<IWorkload> RefWorkloadFactory::CreateWorkload(LayerType type, auto reverseV2QueueDescriptor = PolymorphicDowncast<const ReverseV2QueueDescriptor*>(&descriptor); return std::make_unique<RefReverseV2Workload>(*reverseV2QueueDescriptor, info); } + case LayerType::ScatterNd: + { + auto scatterQueueDescriptor = PolymorphicDowncast<const ScatterNdQueueDescriptor*>(&descriptor); + return std::make_unique<RefScatterNdWorkload>(*scatterQueueDescriptor, info); + } case LayerType::Shape: { auto shapeQueueDescriptor = PolymorphicDowncast<const ShapeQueueDescriptor*>(&descriptor); diff --git a/src/backends/reference/backend.mk b/src/backends/reference/backend.mk index 27ca8f607a..752255607a 100644 --- a/src/backends/reference/backend.mk +++ b/src/backends/reference/backend.mk @@ -1,5 +1,5 @@ # -# Copyright © 2017-2023 ARM Ltd and Contributors. All rights reserved. +# Copyright © 2017-2024 ARM Ltd and Contributors. All rights reserved. # SPDX-License-Identifier: MIT # @@ -96,6 +96,7 @@ BACKEND_SOURCES := \ workloads/RefReshapeWorkload.cpp \ workloads/RefResizeWorkload.cpp \ workloads/RefReverseV2Workload.cpp \ + workloads/RefScatterNdWorkload.cpp \ workloads/RefSliceWorkload.cpp \ workloads/RefSoftmaxWorkload.cpp \ workloads/RefSpaceToBatchNdWorkload.cpp \ @@ -109,6 +110,7 @@ BACKEND_SOURCES := \ workloads/RefUnidirectionalSequenceLstmWorkload.cpp \ workloads/Resize.cpp \ workloads/ReverseV2Impl.cpp \ + workloads/ScatterNd.cpp \ workloads/Slice.cpp \ workloads/SpaceToBatchNd.cpp \ workloads/SpaceToDepth.cpp \ diff --git a/src/backends/reference/test/RefLayerTests.cpp b/src/backends/reference/test/RefLayerTests.cpp index cfe85594b3..078338163f 100644 --- a/src/backends/reference/test/RefLayerTests.cpp +++ b/src/backends/reference/test/RefLayerTests.cpp @@ -1,5 +1,5 @@ // -// Copyright © 2017,2022-2023 Arm Ltd and Contributors. All rights reserved. +// Copyright © 2017,2022-2024 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // @@ -2885,4 +2885,56 @@ ARMNN_AUTO_TEST_CASE_WITH_THF(BroadcastTo3dAxis0TestSigned32, BroadcastTo3dAxis0 ARMNN_AUTO_TEST_CASE_WITH_THF(BroadcastTo3dAxis1TestSigned32, BroadcastTo3dAxis1Test<DataType::Signed32>) ARMNN_AUTO_TEST_CASE_WITH_THF(BroadcastTo3dAxis2TestSigned32, BroadcastTo3dAxis2Test<DataType::Signed32>) ARMNN_AUTO_TEST_CASE_WITH_THF(BroadcastTo4dTestSigned32, BroadcastTo4dTest<DataType::Signed32>) + +// ScatterNd +// With Input tensor +ARMNN_AUTO_TEST_CASE_WITH_THF(ScatterNd1DUpdateTestWithInputFloat32, ScatterNd1DimUpdateWithInput<DataType::Float32>) +ARMNN_AUTO_TEST_CASE_WITH_THF(ScatterNd2DUpdateTestWithInputFloat32, ScatterNd2DimUpdateWithInput<DataType::Float32>) +ARMNN_AUTO_TEST_CASE_WITH_THF(ScatterNd2Dim1Outter1InnerUpdateWithInputFloat32, + ScatterNd2Dim1Outter1InnerUpdateWithInput<DataType::Float32>) +ARMNN_AUTO_TEST_CASE_WITH_THF(ScatterNd3DimUpdateWithInputFloat32, ScatterNd3DimUpdateWithInput<DataType::Float32>) +ARMNN_AUTO_TEST_CASE_WITH_THF(ScatterNd3Dim1Outter2InnerUpdateWithInputFloat32, + ScatterNd3Dim1Outter2InnerUpdateWithInput<DataType::Float32>) +ARMNN_AUTO_TEST_CASE_WITH_THF(ScatterNd3Dim2Outter1InnerUpdateWithInputFloat32, + ScatterNd3Dim2Outter1InnerUpdateWithInput<DataType::Float32>) +ARMNN_AUTO_TEST_CASE_WITH_THF(ScatterNd4DimUpdateWithInputFloat32, ScatterNd4DimUpdateWithInput<DataType::Float32>) + +ARMNN_AUTO_TEST_CASE_WITH_THF(ScatterNd2DimAddWithInputFloat32, ScatterNd2DimAddWithInput<DataType::Float32>) +ARMNN_AUTO_TEST_CASE_WITH_THF(ScatterNd2DimSubWithInputFloat32, ScatterNd2DimSubWithInput<DataType::Float32>) +ARMNN_AUTO_TEST_CASE_WITH_THF(ScatterNd2DimMaxWithInputFloat32, ScatterNd2DimMaxWithInput<DataType::Float32>) +ARMNN_AUTO_TEST_CASE_WITH_THF(ScatterNd2DimMinWithInputFloat32, ScatterNd2DimMinWithInput<DataType::Float32>) +ARMNN_AUTO_TEST_CASE_WITH_THF(ScatterNd2DimMulWithInputFloat32, ScatterNd2DimMulWithInput<DataType::Float32>) + +ARMNN_AUTO_TEST_CASE_WITH_THF(ScatterNd3DimUpdateWithInputFloat16, ScatterNd3DimUpdateWithInput<DataType::Float16>) +ARMNN_AUTO_TEST_CASE_WITH_THF(ScatterNd3DimUpdateWithInputQAsymmS8, ScatterNd3DimUpdateWithInput<DataType::QAsymmS8>) +ARMNN_AUTO_TEST_CASE_WITH_THF(ScatterNd3DimUpdateWithInputQAsymmU8, ScatterNd3DimUpdateWithInput<DataType::QAsymmU8>) +ARMNN_AUTO_TEST_CASE_WITH_THF(ScatterNd3DimUpdateWithInputQSymmS8, ScatterNd3DimUpdateWithInput<DataType::QSymmS8>) +ARMNN_AUTO_TEST_CASE_WITH_THF(ScatterNd3DimUpdateWithInputQSymmS16, ScatterNd3DimUpdateWithInput<DataType::QSymmS16>) +ARMNN_AUTO_TEST_CASE_WITH_THF(ScatterNd3DimUpdateWithInputSigned32, ScatterNd3DimUpdateWithInput<DataType::Signed32>) + +// No input tensor, only shape provided +ARMNN_AUTO_TEST_CASE_WITH_THF(ScatterNd1DUpdateTestNoInputFloat32, ScatterNd1DimUpdateNoInput<DataType::Float32>) +ARMNN_AUTO_TEST_CASE_WITH_THF(ScatterNd2DimUpdateTestNoInputFloat32, ScatterNd2DimUpdateNoInput<DataType::Float32>) +ARMNN_AUTO_TEST_CASE_WITH_THF(ScatterNd2Dim1Outter1InnerUpdateNoInputFloat32, + ScatterNd2Dim1Outter1InnerUpdateNoInput<DataType::Float32>) +ARMNN_AUTO_TEST_CASE_WITH_THF(ScatterNd3DimUpdateNoInputFloat32, ScatterNd3DimUpdateNoInput<DataType::Float32>) +ARMNN_AUTO_TEST_CASE_WITH_THF(ScatterNd3Dim1Outter2InnerUpdateNoInputFloat32, + ScatterNd3Dim1Outter2InnerUpdateNoInput<DataType::Float32>) +ARMNN_AUTO_TEST_CASE_WITH_THF(ScatterNd3Dim2Outter1InnerUpdateNoInputFloat32, + ScatterNd3Dim2Outter1InnerUpdateNoInput<DataType::Float32>) +ARMNN_AUTO_TEST_CASE_WITH_THF(ScatterNd4DimUpdateNoInputFloat32, ScatterNd4DimUpdateNoInput<DataType::Float32>) + +ARMNN_AUTO_TEST_CASE_WITH_THF(ScatterNd2DimAddNoInputFloat32, ScatterNd2DimAddNoInput<DataType::Float32>) +ARMNN_AUTO_TEST_CASE_WITH_THF(ScatterNd2DimSubNoInputFloat32, ScatterNd2DimSubNoInput<DataType::Float32>) +ARMNN_AUTO_TEST_CASE_WITH_THF(ScatterNd2DimMaxNoInputFloat32, ScatterNd2DimMaxNoInput<DataType::Float32>) +ARMNN_AUTO_TEST_CASE_WITH_THF(ScatterNd2DimMinNoInputFloat32, ScatterNd2DimMinNoInput<DataType::Float32>) +ARMNN_AUTO_TEST_CASE_WITH_THF(ScatterNd2DimMulNoInputFloat32, ScatterNd2DimMulNoInput<DataType::Float32>) + +ARMNN_AUTO_TEST_CASE_WITH_THF(ScatterNd3DimUpdateNoInputFloat16, ScatterNd3DimUpdateNoInput<DataType::Float16>) +ARMNN_AUTO_TEST_CASE_WITH_THF(ScatterNd3DimUpdateNoInputQAsymmS8, ScatterNd3DimUpdateNoInput<DataType::QAsymmS8>) +ARMNN_AUTO_TEST_CASE_WITH_THF(ScatterNd3DimUpdateNoInputQAsymmU8, ScatterNd3DimUpdateNoInput<DataType::QAsymmU8>) +ARMNN_AUTO_TEST_CASE_WITH_THF(ScatterNd3DimUpdateNoInputQSymmS8, ScatterNd3DimUpdateNoInput<DataType::QSymmS8>) +ARMNN_AUTO_TEST_CASE_WITH_THF(ScatterNd3DimUpdateNoInputQSymmS16, ScatterNd3DimUpdateNoInput<DataType::QSymmS16>) +ARMNN_AUTO_TEST_CASE_WITH_THF(ScatterNd3DimUpdateNoInputSigned32, ScatterNd3DimUpdateNoInput<DataType::Signed32>) + }
\ No newline at end of file diff --git a/src/backends/reference/workloads/CMakeLists.txt b/src/backends/reference/workloads/CMakeLists.txt index 42f92aec1d..0f70cb0022 100644 --- a/src/backends/reference/workloads/CMakeLists.txt +++ b/src/backends/reference/workloads/CMakeLists.txt @@ -1,5 +1,5 @@ # -# Copyright © 2017-2023 Arm Ltd and Contributors. All rights reserved. +# Copyright © 2017-2024 Arm Ltd and Contributors. All rights reserved. # SPDX-License-Identifier: MIT # @@ -85,6 +85,8 @@ list(APPEND armnnRefBackendWorkloads_sources RefCastWorkload.hpp RefChannelShuffleWorkload.cpp RefChannelShuffleWorkload.hpp + RefScatterNdWorkload.cpp + RefScatterNdWorkload.hpp RefShapeWorkload.hpp RefComparisonWorkload.cpp RefComparisonWorkload.hpp @@ -195,6 +197,8 @@ list(APPEND armnnRefBackendWorkloads_sources Resize.cpp Resize.hpp Rsqrt.hpp + ScatterNd.cpp + ScatterNd.hpp Sin.hpp Slice.cpp Slice.hpp diff --git a/src/backends/reference/workloads/RefScatterNdWorkload.cpp b/src/backends/reference/workloads/RefScatterNdWorkload.cpp new file mode 100644 index 0000000000..4713add0e9 --- /dev/null +++ b/src/backends/reference/workloads/RefScatterNdWorkload.cpp @@ -0,0 +1,100 @@ +// +// Copyright © 2024 Arm Ltd and Contributors. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#include <fmt/format.h> +#include "RefScatterNdWorkload.hpp" +#include "RefWorkloadUtils.hpp" +#include "ScatterNd.hpp" +#include "Profiling.hpp" + +namespace armnn +{ + + RefScatterNdWorkload::RefScatterNdWorkload(const ScatterNdQueueDescriptor& descriptor, const WorkloadInfo& info) + : RefBaseWorkload(descriptor, info) + {} + + void RefScatterNdWorkload::Execute() const + { + Execute(m_Data.m_Inputs, m_Data.m_Outputs); + } + + void RefScatterNdWorkload::ExecuteAsync(ExecutionData& executionData) + { + WorkingMemDescriptor* workingMemDescriptor = static_cast<WorkingMemDescriptor*>(executionData.m_Data); + Execute(workingMemDescriptor->m_Inputs, workingMemDescriptor->m_Outputs); + } + + void RefScatterNdWorkload::Execute(std::vector<ITensorHandle*> inputs, std::vector<ITensorHandle*> outputs) const + { + ARMNN_SCOPED_PROFILING_EVENT_REF_NAME_GUID("RefScatterNdWorkload_Execute"); + + if (m_Data.m_Parameters.m_InputEnabled) + { + // Getting TensorInfos for three inputs slots + const TensorInfo& inputInfo = GetTensorInfo(inputs[0]); + const TensorInfo& indicesInfo = GetTensorInfo(inputs[1]); + const TensorInfo& updatesInfo = GetTensorInfo(inputs[2]); + + // Getting Decoder for input + std::unique_ptr<Decoder<float>> inputDecoder = MakeDecoder<float>(GetTensorInfo(inputs[0]), + inputs[0]->Map()); + + // Getting Decoder for indices + std::unique_ptr<Decoder<int>> indicesDecoder = MakeDecoder<int>(GetTensorInfo(inputs[1]), + inputs[1]->Map()); + + // Getting Decoder for updates + std::unique_ptr<Decoder<float>> updatesDecoder = MakeDecoder<float>(GetTensorInfo(inputs[2]), + inputs[2]->Map()); + + // Getting Encoder for output + std::unique_ptr<Encoder<float>> outputEncoder = MakeEncoder<float>(GetTensorInfo(outputs[0]), + outputs[0]->Map()); + + ScatterNd(inputInfo, + indicesInfo, + updatesInfo, + *inputDecoder, + *indicesDecoder, + *updatesDecoder, + *outputEncoder, + m_Data.m_Parameters); + } + else + { + // Getting TensorInfos for three inputs slots + const TensorInfo& shapeInfo = GetTensorInfo(inputs[0]); + const TensorInfo& indicesInfo = GetTensorInfo(inputs[1]); + const TensorInfo& updatesInfo = GetTensorInfo(inputs[2]); + + // Getting Decoder for shape + std::unique_ptr<Decoder<int>> shapeDecoder = MakeDecoder<int>(GetTensorInfo(inputs[0]), + inputs[0]->Map()); + + // Getting Decoder for indices + std::unique_ptr<Decoder<int>> indicesDecoder = MakeDecoder<int>(GetTensorInfo(inputs[1]), + inputs[1]->Map()); + + // Getting Decoder for updates + std::unique_ptr<Decoder<float>> updatesDecoder = MakeDecoder<float>(GetTensorInfo(inputs[2]), + inputs[2]->Map()); + + // Getting Encoder for output + std::unique_ptr<Encoder<float>> outputEncoder = MakeEncoder<float>(GetTensorInfo(outputs[0]), + outputs[0]->Map()); + + ScatterNd(indicesInfo, + updatesInfo, + shapeInfo, + *indicesDecoder, + *updatesDecoder, + *shapeDecoder, + *outputEncoder, + m_Data.m_Parameters); + } + } + +} // namespace armnn
\ No newline at end of file diff --git a/src/backends/reference/workloads/RefScatterNdWorkload.hpp b/src/backends/reference/workloads/RefScatterNdWorkload.hpp new file mode 100644 index 0000000000..c9cf5a3af3 --- /dev/null +++ b/src/backends/reference/workloads/RefScatterNdWorkload.hpp @@ -0,0 +1,30 @@ +// +// Copyright © 2024 Arm Ltd and Contributors. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#pragma once + +#include "RefBaseWorkload.hpp" +#include <armnn/backends/WorkloadData.hpp> + +#include "ScatterNd.hpp" + +namespace armnn +{ + + class RefScatterNdWorkload : public RefBaseWorkload<ScatterNdQueueDescriptor> + { + public: + explicit RefScatterNdWorkload(const ScatterNdQueueDescriptor& descriptor, + const WorkloadInfo& info); + + void Execute() const override; + void ExecuteAsync(ExecutionData& executionData) override; + + private: + void Execute(std::vector<ITensorHandle*> inputs, std::vector<ITensorHandle*> outputs) const; + + }; + +} // namespace armnn
\ No newline at end of file diff --git a/src/backends/reference/workloads/RefWorkloads.hpp b/src/backends/reference/workloads/RefWorkloads.hpp index 98aa27b8a9..92b178c3d5 100644 --- a/src/backends/reference/workloads/RefWorkloads.hpp +++ b/src/backends/reference/workloads/RefWorkloads.hpp @@ -1,5 +1,5 @@ // -// Copyright © 2017-2023 Arm Ltd and Contributors. All rights reserved. +// Copyright © 2017-2024 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // @@ -55,6 +55,7 @@ #include "RefReshapeWorkload.hpp" #include "RefResizeWorkload.hpp" #include "RefReverseV2Workload.hpp" +#include "RefScatterNdWorkload.hpp" #include "RefShapeWorkload.hpp" #include "RefSliceWorkload.hpp" #include "RefSplitterWorkload.hpp" diff --git a/src/backends/reference/workloads/ScatterNd.cpp b/src/backends/reference/workloads/ScatterNd.cpp new file mode 100644 index 0000000000..8eb53b00a8 --- /dev/null +++ b/src/backends/reference/workloads/ScatterNd.cpp @@ -0,0 +1,336 @@ +// +// Copyright © 2024 Arm Ltd and Contributors. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#include "ScatterNd.hpp" +#include "Encoders.hpp" +#include <armnn/backends/WorkloadData.hpp> +#include <armnn/Logging.hpp> + +#include <fmt/format.h> + +#include <numeric> + +namespace armnn +{ + +float ScatterOperation(ScatterNdFunction operation, + float input, + float update) +{ + switch (operation) + { + case ScatterNdFunction::Update: + return update; + case ScatterNdFunction::Add: + return input + update; + case ScatterNdFunction::Sub: + return input - update; + case ScatterNdFunction::Max: + return std::max(input, update); + case ScatterNdFunction::Min: + return std::min(input, update); + case ScatterNdFunction::Mul: + return input * update; + default: + throw InvalidArgumentException("ScatterNd: cannot execute this operation."); + } +} + +void ScatterNd(const TensorInfo& inputInfo, + const TensorInfo& indicesInfo, + const TensorInfo& updatesInfo, + Decoder<float>& input, + Decoder<int>& indices, + Decoder<float>& updates, + Encoder<float>& output, + ScatterNdDescriptor descriptor) +{ + // Axis Unsupported + if (descriptor.m_AxisEnabled) + { + throw InvalidArgumentException("ScatterNd: axis param not supported."); + } + + // Get the shape for indices, updates, and input + TensorShape indicesShape = indicesInfo.GetShape(); + TensorShape updatesShape = updatesInfo.GetShape(); + TensorShape inputShape = inputInfo.GetShape(); + + // Get the dimensions for indices and updates + unsigned int dimension = inputInfo.GetNumDimensions(); + unsigned int indicesDim = indicesInfo.GetNumDimensions(); + unsigned int updatesDim = updatesInfo.GetNumDimensions(); + + // Calculate the outter and inner dimensions + unsigned int outterDim = indicesShape[indicesDim - 1]; + unsigned int innerDim = dimension - outterDim; + + // Calculate the number of elements in each dimension + unsigned int numElementsCount = 1; + std::vector<unsigned int> elementInDim(dimension); + for (unsigned int dimIndex = dimension; dimIndex > 0; --dimIndex) + { + elementInDim[dimIndex - 1] = numElementsCount; + numElementsCount *= inputShape[dimIndex - 1]; + } + + // Number of updates per index + unsigned int numUpdatesPerIndex = elementInDim[dimension - innerDim - 1]; + + // Number of indices to update + unsigned int numIndices = indicesShape[0]; + + // Check Input Requirements + // Requirement 1: Indices and Updates must have rank at least 1 + if (indicesDim < 1 || updatesDim < 1) + { + throw InvalidArgumentException("ScatterNd: indices and updates must have rank >= 1."); + } + + // Requirement 2: Input, Indices and Updates must have values + if (inputInfo.GetNumElements() == 0 || + indicesInfo.GetNumElements() == 0 || + updatesInfo.GetNumElements() == 0) + { + throw InvalidArgumentException("ScatterNd: input, indices and updates tensor must have values."); + } + + // Requirement 3: Indices and Updates must match in shape + // The updates dimension should equals to 1 + inner dimension + if (updatesDim != 1 + innerDim) + { + throw InvalidArgumentException("ScatterNd: updates dimension should equal to 1 + inner dimension."); + } + // The inner dimension of updates has to match with shape of input + for (unsigned int dimBackIndex = 0; dimBackIndex < innerDim; ++dimBackIndex) + { + if (updatesShape[updatesDim - dimBackIndex - 1] != inputShape[dimension - dimBackIndex - 1]) + { + throw InvalidArgumentException( + fmt::format("ScatterNd: input and updates shape not match on dimension {}", + dimension - dimBackIndex)); + } + } + + // Requirement 4: Check duplicate indices and out of bound indices + std::set<int> indicesSet; + std::vector<int> flattenIndices(numIndices); + for (unsigned int indicesIdx = 0; indicesIdx < numIndices; ++indicesIdx) + { + // Get the index + int flattenIndex = 0; + + for (unsigned int outterIdx = 0; outterIdx < outterDim; ++outterIdx) { + + int outterIndexValue = indices.Get(); + + // Check bounds + if (outterIndexValue < 0 || outterIndexValue >= int(inputShape[outterIdx])) + { + throw InvalidArgumentException( + fmt::format("ScatterNd: indices {} out of bound [0, {})", + outterIndexValue, inputShape[outterIdx])); + } + + flattenIndex += int(elementInDim[outterIdx]) * outterIndexValue; + ++indices; + } + + // Check duplicates when executing ScatterNd::Update + if (descriptor.m_Function == ScatterNdFunction::Update && + indicesSet.find(flattenIndex) != indicesSet.end()) + { + throw InvalidArgumentException( + fmt::format("ScatterNd: duplicate indices occurs {}", flattenIndex)); + } + + flattenIndices[indicesIdx] = flattenIndex; + indicesSet.insert(flattenIndex); + } + + // Set the input data to output + for (unsigned int idx = 0; idx < inputInfo.GetNumElements(); ++idx) + { + float inputValue = input.Get(); + ++input; + output.Set(inputValue); + ++output; + } + + // Iterate through all indices to scatter updates + for (unsigned int indicesIdx = 0; indicesIdx < numIndices; ++indicesIdx) + { + // Get the index and calculate the flatten index + int flattenIndex = flattenIndices[indicesIdx]; + + // FlattenIndex is the place that we are going to update the elements + unsigned int updatesStartIdx = indicesIdx * numUpdatesPerIndex; + for (unsigned int updatesIdx = 0; updatesIdx < numUpdatesPerIndex; ++updatesIdx) + { + updates[updatesStartIdx + updatesIdx]; + input[static_cast<unsigned int>(flattenIndex) + updatesIdx]; + float updateValue = ScatterOperation(descriptor.m_Function, input.Get(), updates.Get()); + output[static_cast<unsigned int>(flattenIndex) + updatesIdx]; + output.Set(updateValue); + } + } +} + +void ScatterNd(const TensorInfo& indicesInfo, + const TensorInfo& updatesInfo, + const TensorInfo& shapeInfo, + Decoder<int>& indices, + Decoder<float>& updates, + Decoder<int>& shape, + Encoder<float>& output, + ScatterNdDescriptor descriptor) +{ + // Axis Unsupported + if (descriptor.m_AxisEnabled) + { + throw InvalidArgumentException("ScatterNd: axis param not supported."); + } + + // Get the shape for indices, updates, and input + TensorShape indicesShape = indicesInfo.GetShape(); + TensorShape updatesShape = updatesInfo.GetShape(); + + // Get the shape values + std::vector<float> shapeValues = shape.DecodeTensor(shapeInfo.GetShape()); + // Check the shape + if (shapeInfo.GetNumElements() == 0) + { + throw InvalidArgumentException("ScatterNd: shape must have values."); + } + for (auto shapeValue : shapeValues) + { + if (shapeValue <= 0) + { + throw InvalidArgumentException("ScatterNd: shape values must >= 0."); + } + } + // Get the input shape + std::vector<unsigned int> inputShape (shapeValues.begin(), shapeValues.end()); + unsigned int inputElementsNum = static_cast<unsigned int>( + std::accumulate(inputShape.begin(), inputShape.end(), 1, std::multiplies<unsigned int>())); + + // Get the dimensions for indices and updates + unsigned int dimension = shapeInfo.GetNumElements(); + unsigned int indicesDim = indicesInfo.GetNumDimensions(); + unsigned int updatesDim = updatesInfo.GetNumDimensions(); + + // Calculate the outter and inner dimensions + unsigned int outterDim = indicesShape[indicesDim - 1]; + unsigned int innerDim = dimension - outterDim; + + // Calculate the number of elements in each dimension + unsigned int numElementsCount = 1; + std::vector<unsigned int> elementInDim(dimension); + for (unsigned int dimIndex = dimension; dimIndex > 0; --dimIndex) + { + elementInDim[dimIndex - 1] = numElementsCount; + numElementsCount *= inputShape[dimIndex - 1]; + } + + // Number of updates per index + unsigned int numUpdatesPerIndex = elementInDim[dimension - innerDim - 1]; + + // Number of indices to update + unsigned int numIndices = indicesShape[0]; + + // Check Input Requirements + // Requirement 1: Indices and Updates must have rank at least 1 + if (indicesDim < 1 || updatesDim < 1) + { + throw InvalidArgumentException("ScatterNd: indices and updates must have rank >= 1."); + } + + // Requirement 2: shape, Indices and Updates must have values + if (indicesInfo.GetNumElements() == 0 || + updatesInfo.GetNumElements() == 0) + { + throw InvalidArgumentException("ScatterNd: indices and updates tensor must have values."); + } + + // Requirement 3: Indices and Updates must match in shape + // The updates dimension should equals to 1 + inner dimension + if (updatesDim != 1 + innerDim) + { + throw InvalidArgumentException("ScatterNd: updates dimension should equal to 1 + inner dimension."); + } + // The inner dimension of updates has to match with shape of input + for (unsigned int dimBackIndex = 0; dimBackIndex < innerDim; ++dimBackIndex) + { + if (updatesShape[updatesDim - dimBackIndex - 1] != inputShape[dimension - dimBackIndex - 1]) + { + throw InvalidArgumentException( + fmt::format("ScatterNd: input and updates shape not match on dimension {}", + dimension - dimBackIndex)); + } + } + + // Requirement 4: Check duplicate indices and out of bound indices + std::set<int> indicesSet; + std::vector<int> flattenIndices(numIndices); + for (unsigned int indicesIdx = 0; indicesIdx < numIndices; ++indicesIdx) + { + // Get the index + int flattenIndex = 0; + + for (unsigned int outterIdx = 0; outterIdx < outterDim; ++outterIdx) { + + int outterIndexValue = indices.Get(); + + // Check bounds + if (outterIndexValue < 0 || outterIndexValue >= int(inputShape[outterIdx])) + { + throw InvalidArgumentException( + fmt::format("ScatterNd: indices {} out of bound [0, {})", + outterIndexValue, inputShape[outterIdx])); + } + + flattenIndex += int(elementInDim[outterIdx]) * outterIndexValue; + ++indices; + } + + // Check duplicates when executing ScatterNd::Update + if (descriptor.m_Function == ScatterNdFunction::Update && + indicesSet.find(flattenIndex) != indicesSet.end()) + { + throw InvalidArgumentException( + fmt::format("ScatterNd: duplicate indices {} occurs when executing ScatterNd::Update.", + flattenIndex)); + } + + flattenIndices[indicesIdx] = flattenIndex; + indicesSet.insert(flattenIndex); + } + + // Set zeros to output + for (unsigned int idx = 0; idx < inputElementsNum; ++idx) + { + output.Set(0.0f); + ++output; + } + + // Iterate through all indices to scatter updates + for (unsigned int indicesIdx = 0; indicesIdx < numIndices; ++indicesIdx) + { + // Get the index and calculate the flatten index + int flattenIndex = flattenIndices[indicesIdx]; + + // FlattenIndex is the place that we are going to update the elements + unsigned int updatesStartIdx = indicesIdx * numUpdatesPerIndex; + for (unsigned int updatesIdx = 0; updatesIdx < numUpdatesPerIndex; ++updatesIdx) + { + updates[updatesStartIdx + updatesIdx]; + float updateValue = ScatterOperation(descriptor.m_Function, 0.0f, updates.Get()); + output[static_cast<unsigned int>(flattenIndex) + updatesIdx]; + output.Set(updateValue); + } + } +} + +} // namespace armnn
\ No newline at end of file diff --git a/src/backends/reference/workloads/ScatterNd.hpp b/src/backends/reference/workloads/ScatterNd.hpp new file mode 100644 index 0000000000..e40d3640a7 --- /dev/null +++ b/src/backends/reference/workloads/ScatterNd.hpp @@ -0,0 +1,34 @@ +// +// Copyright © 2024 Arm Ltd and Contributors. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#pragma once + +#include <armnn/Tensor.hpp> +#include "Encoders.hpp" +#include "Decoders.hpp" +#include "armnn/Descriptors.hpp" + +namespace armnn +{ +// ScatterNd with input tensor +void ScatterNd(const TensorInfo& inputInfo, + const TensorInfo& indicesInfo, + const TensorInfo& updatesInfo, + Decoder<float>& input, + Decoder<int>& indices, + Decoder<float>& updates, + Encoder<float>& output, + ScatterNdDescriptor descriptor); + +// ScatterNd without input tensor, only shape provided +void ScatterNd(const TensorInfo& indicesInfo, + const TensorInfo& updatesInfo, + const TensorInfo& shapeInfo, + Decoder<int>& indices, + Decoder<float>& updates, + Decoder<int>& shape, + Encoder<float>& output, + ScatterNdDescriptor descriptor); +} // namespace armnn
\ No newline at end of file |