diff options
Diffstat (limited to 'src/backends/reference/workloads')
-rw-r--r-- | src/backends/reference/workloads/CMakeLists.txt | 6 | ||||
-rw-r--r-- | src/backends/reference/workloads/RefScatterNdWorkload.cpp | 100 | ||||
-rw-r--r-- | src/backends/reference/workloads/RefScatterNdWorkload.hpp | 30 | ||||
-rw-r--r-- | src/backends/reference/workloads/RefWorkloads.hpp | 3 | ||||
-rw-r--r-- | src/backends/reference/workloads/ScatterNd.cpp | 336 | ||||
-rw-r--r-- | src/backends/reference/workloads/ScatterNd.hpp | 34 |
6 files changed, 507 insertions, 2 deletions
diff --git a/src/backends/reference/workloads/CMakeLists.txt b/src/backends/reference/workloads/CMakeLists.txt index 42f92aec1d..0f70cb0022 100644 --- a/src/backends/reference/workloads/CMakeLists.txt +++ b/src/backends/reference/workloads/CMakeLists.txt @@ -1,5 +1,5 @@ # -# Copyright © 2017-2023 Arm Ltd and Contributors. All rights reserved. +# Copyright © 2017-2024 Arm Ltd and Contributors. All rights reserved. # SPDX-License-Identifier: MIT # @@ -85,6 +85,8 @@ list(APPEND armnnRefBackendWorkloads_sources RefCastWorkload.hpp RefChannelShuffleWorkload.cpp RefChannelShuffleWorkload.hpp + RefScatterNdWorkload.cpp + RefScatterNdWorkload.hpp RefShapeWorkload.hpp RefComparisonWorkload.cpp RefComparisonWorkload.hpp @@ -195,6 +197,8 @@ list(APPEND armnnRefBackendWorkloads_sources Resize.cpp Resize.hpp Rsqrt.hpp + ScatterNd.cpp + ScatterNd.hpp Sin.hpp Slice.cpp Slice.hpp diff --git a/src/backends/reference/workloads/RefScatterNdWorkload.cpp b/src/backends/reference/workloads/RefScatterNdWorkload.cpp new file mode 100644 index 0000000000..4713add0e9 --- /dev/null +++ b/src/backends/reference/workloads/RefScatterNdWorkload.cpp @@ -0,0 +1,100 @@ +// +// Copyright © 2024 Arm Ltd and Contributors. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#include <fmt/format.h> +#include "RefScatterNdWorkload.hpp" +#include "RefWorkloadUtils.hpp" +#include "ScatterNd.hpp" +#include "Profiling.hpp" + +namespace armnn +{ + + RefScatterNdWorkload::RefScatterNdWorkload(const ScatterNdQueueDescriptor& descriptor, const WorkloadInfo& info) + : RefBaseWorkload(descriptor, info) + {} + + void RefScatterNdWorkload::Execute() const + { + Execute(m_Data.m_Inputs, m_Data.m_Outputs); + } + + void RefScatterNdWorkload::ExecuteAsync(ExecutionData& executionData) + { + WorkingMemDescriptor* workingMemDescriptor = static_cast<WorkingMemDescriptor*>(executionData.m_Data); + Execute(workingMemDescriptor->m_Inputs, workingMemDescriptor->m_Outputs); + } + + void RefScatterNdWorkload::Execute(std::vector<ITensorHandle*> inputs, std::vector<ITensorHandle*> outputs) const + { + ARMNN_SCOPED_PROFILING_EVENT_REF_NAME_GUID("RefScatterNdWorkload_Execute"); + + if (m_Data.m_Parameters.m_InputEnabled) + { + // Getting TensorInfos for three inputs slots + const TensorInfo& inputInfo = GetTensorInfo(inputs[0]); + const TensorInfo& indicesInfo = GetTensorInfo(inputs[1]); + const TensorInfo& updatesInfo = GetTensorInfo(inputs[2]); + + // Getting Decoder for input + std::unique_ptr<Decoder<float>> inputDecoder = MakeDecoder<float>(GetTensorInfo(inputs[0]), + inputs[0]->Map()); + + // Getting Decoder for indices + std::unique_ptr<Decoder<int>> indicesDecoder = MakeDecoder<int>(GetTensorInfo(inputs[1]), + inputs[1]->Map()); + + // Getting Decoder for updates + std::unique_ptr<Decoder<float>> updatesDecoder = MakeDecoder<float>(GetTensorInfo(inputs[2]), + inputs[2]->Map()); + + // Getting Encoder for output + std::unique_ptr<Encoder<float>> outputEncoder = MakeEncoder<float>(GetTensorInfo(outputs[0]), + outputs[0]->Map()); + + ScatterNd(inputInfo, + indicesInfo, + updatesInfo, + *inputDecoder, + *indicesDecoder, + *updatesDecoder, + *outputEncoder, + m_Data.m_Parameters); + } + else + { + // Getting TensorInfos for three inputs slots + const TensorInfo& shapeInfo = GetTensorInfo(inputs[0]); + const TensorInfo& indicesInfo = GetTensorInfo(inputs[1]); + const TensorInfo& updatesInfo = GetTensorInfo(inputs[2]); + + // Getting Decoder for shape + std::unique_ptr<Decoder<int>> shapeDecoder = MakeDecoder<int>(GetTensorInfo(inputs[0]), + inputs[0]->Map()); + + // Getting Decoder for indices + std::unique_ptr<Decoder<int>> indicesDecoder = MakeDecoder<int>(GetTensorInfo(inputs[1]), + inputs[1]->Map()); + + // Getting Decoder for updates + std::unique_ptr<Decoder<float>> updatesDecoder = MakeDecoder<float>(GetTensorInfo(inputs[2]), + inputs[2]->Map()); + + // Getting Encoder for output + std::unique_ptr<Encoder<float>> outputEncoder = MakeEncoder<float>(GetTensorInfo(outputs[0]), + outputs[0]->Map()); + + ScatterNd(indicesInfo, + updatesInfo, + shapeInfo, + *indicesDecoder, + *updatesDecoder, + *shapeDecoder, + *outputEncoder, + m_Data.m_Parameters); + } + } + +} // namespace armnn
\ No newline at end of file diff --git a/src/backends/reference/workloads/RefScatterNdWorkload.hpp b/src/backends/reference/workloads/RefScatterNdWorkload.hpp new file mode 100644 index 0000000000..c9cf5a3af3 --- /dev/null +++ b/src/backends/reference/workloads/RefScatterNdWorkload.hpp @@ -0,0 +1,30 @@ +// +// Copyright © 2024 Arm Ltd and Contributors. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#pragma once + +#include "RefBaseWorkload.hpp" +#include <armnn/backends/WorkloadData.hpp> + +#include "ScatterNd.hpp" + +namespace armnn +{ + + class RefScatterNdWorkload : public RefBaseWorkload<ScatterNdQueueDescriptor> + { + public: + explicit RefScatterNdWorkload(const ScatterNdQueueDescriptor& descriptor, + const WorkloadInfo& info); + + void Execute() const override; + void ExecuteAsync(ExecutionData& executionData) override; + + private: + void Execute(std::vector<ITensorHandle*> inputs, std::vector<ITensorHandle*> outputs) const; + + }; + +} // namespace armnn
\ No newline at end of file diff --git a/src/backends/reference/workloads/RefWorkloads.hpp b/src/backends/reference/workloads/RefWorkloads.hpp index 98aa27b8a9..92b178c3d5 100644 --- a/src/backends/reference/workloads/RefWorkloads.hpp +++ b/src/backends/reference/workloads/RefWorkloads.hpp @@ -1,5 +1,5 @@ // -// Copyright © 2017-2023 Arm Ltd and Contributors. All rights reserved. +// Copyright © 2017-2024 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // @@ -55,6 +55,7 @@ #include "RefReshapeWorkload.hpp" #include "RefResizeWorkload.hpp" #include "RefReverseV2Workload.hpp" +#include "RefScatterNdWorkload.hpp" #include "RefShapeWorkload.hpp" #include "RefSliceWorkload.hpp" #include "RefSplitterWorkload.hpp" diff --git a/src/backends/reference/workloads/ScatterNd.cpp b/src/backends/reference/workloads/ScatterNd.cpp new file mode 100644 index 0000000000..8eb53b00a8 --- /dev/null +++ b/src/backends/reference/workloads/ScatterNd.cpp @@ -0,0 +1,336 @@ +// +// Copyright © 2024 Arm Ltd and Contributors. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#include "ScatterNd.hpp" +#include "Encoders.hpp" +#include <armnn/backends/WorkloadData.hpp> +#include <armnn/Logging.hpp> + +#include <fmt/format.h> + +#include <numeric> + +namespace armnn +{ + +float ScatterOperation(ScatterNdFunction operation, + float input, + float update) +{ + switch (operation) + { + case ScatterNdFunction::Update: + return update; + case ScatterNdFunction::Add: + return input + update; + case ScatterNdFunction::Sub: + return input - update; + case ScatterNdFunction::Max: + return std::max(input, update); + case ScatterNdFunction::Min: + return std::min(input, update); + case ScatterNdFunction::Mul: + return input * update; + default: + throw InvalidArgumentException("ScatterNd: cannot execute this operation."); + } +} + +void ScatterNd(const TensorInfo& inputInfo, + const TensorInfo& indicesInfo, + const TensorInfo& updatesInfo, + Decoder<float>& input, + Decoder<int>& indices, + Decoder<float>& updates, + Encoder<float>& output, + ScatterNdDescriptor descriptor) +{ + // Axis Unsupported + if (descriptor.m_AxisEnabled) + { + throw InvalidArgumentException("ScatterNd: axis param not supported."); + } + + // Get the shape for indices, updates, and input + TensorShape indicesShape = indicesInfo.GetShape(); + TensorShape updatesShape = updatesInfo.GetShape(); + TensorShape inputShape = inputInfo.GetShape(); + + // Get the dimensions for indices and updates + unsigned int dimension = inputInfo.GetNumDimensions(); + unsigned int indicesDim = indicesInfo.GetNumDimensions(); + unsigned int updatesDim = updatesInfo.GetNumDimensions(); + + // Calculate the outter and inner dimensions + unsigned int outterDim = indicesShape[indicesDim - 1]; + unsigned int innerDim = dimension - outterDim; + + // Calculate the number of elements in each dimension + unsigned int numElementsCount = 1; + std::vector<unsigned int> elementInDim(dimension); + for (unsigned int dimIndex = dimension; dimIndex > 0; --dimIndex) + { + elementInDim[dimIndex - 1] = numElementsCount; + numElementsCount *= inputShape[dimIndex - 1]; + } + + // Number of updates per index + unsigned int numUpdatesPerIndex = elementInDim[dimension - innerDim - 1]; + + // Number of indices to update + unsigned int numIndices = indicesShape[0]; + + // Check Input Requirements + // Requirement 1: Indices and Updates must have rank at least 1 + if (indicesDim < 1 || updatesDim < 1) + { + throw InvalidArgumentException("ScatterNd: indices and updates must have rank >= 1."); + } + + // Requirement 2: Input, Indices and Updates must have values + if (inputInfo.GetNumElements() == 0 || + indicesInfo.GetNumElements() == 0 || + updatesInfo.GetNumElements() == 0) + { + throw InvalidArgumentException("ScatterNd: input, indices and updates tensor must have values."); + } + + // Requirement 3: Indices and Updates must match in shape + // The updates dimension should equals to 1 + inner dimension + if (updatesDim != 1 + innerDim) + { + throw InvalidArgumentException("ScatterNd: updates dimension should equal to 1 + inner dimension."); + } + // The inner dimension of updates has to match with shape of input + for (unsigned int dimBackIndex = 0; dimBackIndex < innerDim; ++dimBackIndex) + { + if (updatesShape[updatesDim - dimBackIndex - 1] != inputShape[dimension - dimBackIndex - 1]) + { + throw InvalidArgumentException( + fmt::format("ScatterNd: input and updates shape not match on dimension {}", + dimension - dimBackIndex)); + } + } + + // Requirement 4: Check duplicate indices and out of bound indices + std::set<int> indicesSet; + std::vector<int> flattenIndices(numIndices); + for (unsigned int indicesIdx = 0; indicesIdx < numIndices; ++indicesIdx) + { + // Get the index + int flattenIndex = 0; + + for (unsigned int outterIdx = 0; outterIdx < outterDim; ++outterIdx) { + + int outterIndexValue = indices.Get(); + + // Check bounds + if (outterIndexValue < 0 || outterIndexValue >= int(inputShape[outterIdx])) + { + throw InvalidArgumentException( + fmt::format("ScatterNd: indices {} out of bound [0, {})", + outterIndexValue, inputShape[outterIdx])); + } + + flattenIndex += int(elementInDim[outterIdx]) * outterIndexValue; + ++indices; + } + + // Check duplicates when executing ScatterNd::Update + if (descriptor.m_Function == ScatterNdFunction::Update && + indicesSet.find(flattenIndex) != indicesSet.end()) + { + throw InvalidArgumentException( + fmt::format("ScatterNd: duplicate indices occurs {}", flattenIndex)); + } + + flattenIndices[indicesIdx] = flattenIndex; + indicesSet.insert(flattenIndex); + } + + // Set the input data to output + for (unsigned int idx = 0; idx < inputInfo.GetNumElements(); ++idx) + { + float inputValue = input.Get(); + ++input; + output.Set(inputValue); + ++output; + } + + // Iterate through all indices to scatter updates + for (unsigned int indicesIdx = 0; indicesIdx < numIndices; ++indicesIdx) + { + // Get the index and calculate the flatten index + int flattenIndex = flattenIndices[indicesIdx]; + + // FlattenIndex is the place that we are going to update the elements + unsigned int updatesStartIdx = indicesIdx * numUpdatesPerIndex; + for (unsigned int updatesIdx = 0; updatesIdx < numUpdatesPerIndex; ++updatesIdx) + { + updates[updatesStartIdx + updatesIdx]; + input[static_cast<unsigned int>(flattenIndex) + updatesIdx]; + float updateValue = ScatterOperation(descriptor.m_Function, input.Get(), updates.Get()); + output[static_cast<unsigned int>(flattenIndex) + updatesIdx]; + output.Set(updateValue); + } + } +} + +void ScatterNd(const TensorInfo& indicesInfo, + const TensorInfo& updatesInfo, + const TensorInfo& shapeInfo, + Decoder<int>& indices, + Decoder<float>& updates, + Decoder<int>& shape, + Encoder<float>& output, + ScatterNdDescriptor descriptor) +{ + // Axis Unsupported + if (descriptor.m_AxisEnabled) + { + throw InvalidArgumentException("ScatterNd: axis param not supported."); + } + + // Get the shape for indices, updates, and input + TensorShape indicesShape = indicesInfo.GetShape(); + TensorShape updatesShape = updatesInfo.GetShape(); + + // Get the shape values + std::vector<float> shapeValues = shape.DecodeTensor(shapeInfo.GetShape()); + // Check the shape + if (shapeInfo.GetNumElements() == 0) + { + throw InvalidArgumentException("ScatterNd: shape must have values."); + } + for (auto shapeValue : shapeValues) + { + if (shapeValue <= 0) + { + throw InvalidArgumentException("ScatterNd: shape values must >= 0."); + } + } + // Get the input shape + std::vector<unsigned int> inputShape (shapeValues.begin(), shapeValues.end()); + unsigned int inputElementsNum = static_cast<unsigned int>( + std::accumulate(inputShape.begin(), inputShape.end(), 1, std::multiplies<unsigned int>())); + + // Get the dimensions for indices and updates + unsigned int dimension = shapeInfo.GetNumElements(); + unsigned int indicesDim = indicesInfo.GetNumDimensions(); + unsigned int updatesDim = updatesInfo.GetNumDimensions(); + + // Calculate the outter and inner dimensions + unsigned int outterDim = indicesShape[indicesDim - 1]; + unsigned int innerDim = dimension - outterDim; + + // Calculate the number of elements in each dimension + unsigned int numElementsCount = 1; + std::vector<unsigned int> elementInDim(dimension); + for (unsigned int dimIndex = dimension; dimIndex > 0; --dimIndex) + { + elementInDim[dimIndex - 1] = numElementsCount; + numElementsCount *= inputShape[dimIndex - 1]; + } + + // Number of updates per index + unsigned int numUpdatesPerIndex = elementInDim[dimension - innerDim - 1]; + + // Number of indices to update + unsigned int numIndices = indicesShape[0]; + + // Check Input Requirements + // Requirement 1: Indices and Updates must have rank at least 1 + if (indicesDim < 1 || updatesDim < 1) + { + throw InvalidArgumentException("ScatterNd: indices and updates must have rank >= 1."); + } + + // Requirement 2: shape, Indices and Updates must have values + if (indicesInfo.GetNumElements() == 0 || + updatesInfo.GetNumElements() == 0) + { + throw InvalidArgumentException("ScatterNd: indices and updates tensor must have values."); + } + + // Requirement 3: Indices and Updates must match in shape + // The updates dimension should equals to 1 + inner dimension + if (updatesDim != 1 + innerDim) + { + throw InvalidArgumentException("ScatterNd: updates dimension should equal to 1 + inner dimension."); + } + // The inner dimension of updates has to match with shape of input + for (unsigned int dimBackIndex = 0; dimBackIndex < innerDim; ++dimBackIndex) + { + if (updatesShape[updatesDim - dimBackIndex - 1] != inputShape[dimension - dimBackIndex - 1]) + { + throw InvalidArgumentException( + fmt::format("ScatterNd: input and updates shape not match on dimension {}", + dimension - dimBackIndex)); + } + } + + // Requirement 4: Check duplicate indices and out of bound indices + std::set<int> indicesSet; + std::vector<int> flattenIndices(numIndices); + for (unsigned int indicesIdx = 0; indicesIdx < numIndices; ++indicesIdx) + { + // Get the index + int flattenIndex = 0; + + for (unsigned int outterIdx = 0; outterIdx < outterDim; ++outterIdx) { + + int outterIndexValue = indices.Get(); + + // Check bounds + if (outterIndexValue < 0 || outterIndexValue >= int(inputShape[outterIdx])) + { + throw InvalidArgumentException( + fmt::format("ScatterNd: indices {} out of bound [0, {})", + outterIndexValue, inputShape[outterIdx])); + } + + flattenIndex += int(elementInDim[outterIdx]) * outterIndexValue; + ++indices; + } + + // Check duplicates when executing ScatterNd::Update + if (descriptor.m_Function == ScatterNdFunction::Update && + indicesSet.find(flattenIndex) != indicesSet.end()) + { + throw InvalidArgumentException( + fmt::format("ScatterNd: duplicate indices {} occurs when executing ScatterNd::Update.", + flattenIndex)); + } + + flattenIndices[indicesIdx] = flattenIndex; + indicesSet.insert(flattenIndex); + } + + // Set zeros to output + for (unsigned int idx = 0; idx < inputElementsNum; ++idx) + { + output.Set(0.0f); + ++output; + } + + // Iterate through all indices to scatter updates + for (unsigned int indicesIdx = 0; indicesIdx < numIndices; ++indicesIdx) + { + // Get the index and calculate the flatten index + int flattenIndex = flattenIndices[indicesIdx]; + + // FlattenIndex is the place that we are going to update the elements + unsigned int updatesStartIdx = indicesIdx * numUpdatesPerIndex; + for (unsigned int updatesIdx = 0; updatesIdx < numUpdatesPerIndex; ++updatesIdx) + { + updates[updatesStartIdx + updatesIdx]; + float updateValue = ScatterOperation(descriptor.m_Function, 0.0f, updates.Get()); + output[static_cast<unsigned int>(flattenIndex) + updatesIdx]; + output.Set(updateValue); + } + } +} + +} // namespace armnn
\ No newline at end of file diff --git a/src/backends/reference/workloads/ScatterNd.hpp b/src/backends/reference/workloads/ScatterNd.hpp new file mode 100644 index 0000000000..e40d3640a7 --- /dev/null +++ b/src/backends/reference/workloads/ScatterNd.hpp @@ -0,0 +1,34 @@ +// +// Copyright © 2024 Arm Ltd and Contributors. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#pragma once + +#include <armnn/Tensor.hpp> +#include "Encoders.hpp" +#include "Decoders.hpp" +#include "armnn/Descriptors.hpp" + +namespace armnn +{ +// ScatterNd with input tensor +void ScatterNd(const TensorInfo& inputInfo, + const TensorInfo& indicesInfo, + const TensorInfo& updatesInfo, + Decoder<float>& input, + Decoder<int>& indices, + Decoder<float>& updates, + Encoder<float>& output, + ScatterNdDescriptor descriptor); + +// ScatterNd without input tensor, only shape provided +void ScatterNd(const TensorInfo& indicesInfo, + const TensorInfo& updatesInfo, + const TensorInfo& shapeInfo, + Decoder<int>& indices, + Decoder<float>& updates, + Decoder<int>& shape, + Encoder<float>& output, + ScatterNdDescriptor descriptor); +} // namespace armnn
\ No newline at end of file |