diff options
Diffstat (limited to 'src/backends/reference/workloads/ScatterNd.cpp')
-rw-r--r-- | src/backends/reference/workloads/ScatterNd.cpp | 336 |
1 files changed, 336 insertions, 0 deletions
diff --git a/src/backends/reference/workloads/ScatterNd.cpp b/src/backends/reference/workloads/ScatterNd.cpp new file mode 100644 index 0000000000..8eb53b00a8 --- /dev/null +++ b/src/backends/reference/workloads/ScatterNd.cpp @@ -0,0 +1,336 @@ +// +// Copyright © 2024 Arm Ltd and Contributors. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#include "ScatterNd.hpp" +#include "Encoders.hpp" +#include <armnn/backends/WorkloadData.hpp> +#include <armnn/Logging.hpp> + +#include <fmt/format.h> + +#include <numeric> + +namespace armnn +{ + +float ScatterOperation(ScatterNdFunction operation, + float input, + float update) +{ + switch (operation) + { + case ScatterNdFunction::Update: + return update; + case ScatterNdFunction::Add: + return input + update; + case ScatterNdFunction::Sub: + return input - update; + case ScatterNdFunction::Max: + return std::max(input, update); + case ScatterNdFunction::Min: + return std::min(input, update); + case ScatterNdFunction::Mul: + return input * update; + default: + throw InvalidArgumentException("ScatterNd: cannot execute this operation."); + } +} + +void ScatterNd(const TensorInfo& inputInfo, + const TensorInfo& indicesInfo, + const TensorInfo& updatesInfo, + Decoder<float>& input, + Decoder<int>& indices, + Decoder<float>& updates, + Encoder<float>& output, + ScatterNdDescriptor descriptor) +{ + // Axis Unsupported + if (descriptor.m_AxisEnabled) + { + throw InvalidArgumentException("ScatterNd: axis param not supported."); + } + + // Get the shape for indices, updates, and input + TensorShape indicesShape = indicesInfo.GetShape(); + TensorShape updatesShape = updatesInfo.GetShape(); + TensorShape inputShape = inputInfo.GetShape(); + + // Get the dimensions for indices and updates + unsigned int dimension = inputInfo.GetNumDimensions(); + unsigned int indicesDim = indicesInfo.GetNumDimensions(); + unsigned int updatesDim = updatesInfo.GetNumDimensions(); + + // Calculate the outter and inner dimensions + unsigned int outterDim = indicesShape[indicesDim - 1]; + unsigned int innerDim = dimension - outterDim; + + // Calculate the number of elements in each dimension + unsigned int numElementsCount = 1; + std::vector<unsigned int> elementInDim(dimension); + for (unsigned int dimIndex = dimension; dimIndex > 0; --dimIndex) + { + elementInDim[dimIndex - 1] = numElementsCount; + numElementsCount *= inputShape[dimIndex - 1]; + } + + // Number of updates per index + unsigned int numUpdatesPerIndex = elementInDim[dimension - innerDim - 1]; + + // Number of indices to update + unsigned int numIndices = indicesShape[0]; + + // Check Input Requirements + // Requirement 1: Indices and Updates must have rank at least 1 + if (indicesDim < 1 || updatesDim < 1) + { + throw InvalidArgumentException("ScatterNd: indices and updates must have rank >= 1."); + } + + // Requirement 2: Input, Indices and Updates must have values + if (inputInfo.GetNumElements() == 0 || + indicesInfo.GetNumElements() == 0 || + updatesInfo.GetNumElements() == 0) + { + throw InvalidArgumentException("ScatterNd: input, indices and updates tensor must have values."); + } + + // Requirement 3: Indices and Updates must match in shape + // The updates dimension should equals to 1 + inner dimension + if (updatesDim != 1 + innerDim) + { + throw InvalidArgumentException("ScatterNd: updates dimension should equal to 1 + inner dimension."); + } + // The inner dimension of updates has to match with shape of input + for (unsigned int dimBackIndex = 0; dimBackIndex < innerDim; ++dimBackIndex) + { + if (updatesShape[updatesDim - dimBackIndex - 1] != inputShape[dimension - dimBackIndex - 1]) + { + throw InvalidArgumentException( + fmt::format("ScatterNd: input and updates shape not match on dimension {}", + dimension - dimBackIndex)); + } + } + + // Requirement 4: Check duplicate indices and out of bound indices + std::set<int> indicesSet; + std::vector<int> flattenIndices(numIndices); + for (unsigned int indicesIdx = 0; indicesIdx < numIndices; ++indicesIdx) + { + // Get the index + int flattenIndex = 0; + + for (unsigned int outterIdx = 0; outterIdx < outterDim; ++outterIdx) { + + int outterIndexValue = indices.Get(); + + // Check bounds + if (outterIndexValue < 0 || outterIndexValue >= int(inputShape[outterIdx])) + { + throw InvalidArgumentException( + fmt::format("ScatterNd: indices {} out of bound [0, {})", + outterIndexValue, inputShape[outterIdx])); + } + + flattenIndex += int(elementInDim[outterIdx]) * outterIndexValue; + ++indices; + } + + // Check duplicates when executing ScatterNd::Update + if (descriptor.m_Function == ScatterNdFunction::Update && + indicesSet.find(flattenIndex) != indicesSet.end()) + { + throw InvalidArgumentException( + fmt::format("ScatterNd: duplicate indices occurs {}", flattenIndex)); + } + + flattenIndices[indicesIdx] = flattenIndex; + indicesSet.insert(flattenIndex); + } + + // Set the input data to output + for (unsigned int idx = 0; idx < inputInfo.GetNumElements(); ++idx) + { + float inputValue = input.Get(); + ++input; + output.Set(inputValue); + ++output; + } + + // Iterate through all indices to scatter updates + for (unsigned int indicesIdx = 0; indicesIdx < numIndices; ++indicesIdx) + { + // Get the index and calculate the flatten index + int flattenIndex = flattenIndices[indicesIdx]; + + // FlattenIndex is the place that we are going to update the elements + unsigned int updatesStartIdx = indicesIdx * numUpdatesPerIndex; + for (unsigned int updatesIdx = 0; updatesIdx < numUpdatesPerIndex; ++updatesIdx) + { + updates[updatesStartIdx + updatesIdx]; + input[static_cast<unsigned int>(flattenIndex) + updatesIdx]; + float updateValue = ScatterOperation(descriptor.m_Function, input.Get(), updates.Get()); + output[static_cast<unsigned int>(flattenIndex) + updatesIdx]; + output.Set(updateValue); + } + } +} + +void ScatterNd(const TensorInfo& indicesInfo, + const TensorInfo& updatesInfo, + const TensorInfo& shapeInfo, + Decoder<int>& indices, + Decoder<float>& updates, + Decoder<int>& shape, + Encoder<float>& output, + ScatterNdDescriptor descriptor) +{ + // Axis Unsupported + if (descriptor.m_AxisEnabled) + { + throw InvalidArgumentException("ScatterNd: axis param not supported."); + } + + // Get the shape for indices, updates, and input + TensorShape indicesShape = indicesInfo.GetShape(); + TensorShape updatesShape = updatesInfo.GetShape(); + + // Get the shape values + std::vector<float> shapeValues = shape.DecodeTensor(shapeInfo.GetShape()); + // Check the shape + if (shapeInfo.GetNumElements() == 0) + { + throw InvalidArgumentException("ScatterNd: shape must have values."); + } + for (auto shapeValue : shapeValues) + { + if (shapeValue <= 0) + { + throw InvalidArgumentException("ScatterNd: shape values must >= 0."); + } + } + // Get the input shape + std::vector<unsigned int> inputShape (shapeValues.begin(), shapeValues.end()); + unsigned int inputElementsNum = static_cast<unsigned int>( + std::accumulate(inputShape.begin(), inputShape.end(), 1, std::multiplies<unsigned int>())); + + // Get the dimensions for indices and updates + unsigned int dimension = shapeInfo.GetNumElements(); + unsigned int indicesDim = indicesInfo.GetNumDimensions(); + unsigned int updatesDim = updatesInfo.GetNumDimensions(); + + // Calculate the outter and inner dimensions + unsigned int outterDim = indicesShape[indicesDim - 1]; + unsigned int innerDim = dimension - outterDim; + + // Calculate the number of elements in each dimension + unsigned int numElementsCount = 1; + std::vector<unsigned int> elementInDim(dimension); + for (unsigned int dimIndex = dimension; dimIndex > 0; --dimIndex) + { + elementInDim[dimIndex - 1] = numElementsCount; + numElementsCount *= inputShape[dimIndex - 1]; + } + + // Number of updates per index + unsigned int numUpdatesPerIndex = elementInDim[dimension - innerDim - 1]; + + // Number of indices to update + unsigned int numIndices = indicesShape[0]; + + // Check Input Requirements + // Requirement 1: Indices and Updates must have rank at least 1 + if (indicesDim < 1 || updatesDim < 1) + { + throw InvalidArgumentException("ScatterNd: indices and updates must have rank >= 1."); + } + + // Requirement 2: shape, Indices and Updates must have values + if (indicesInfo.GetNumElements() == 0 || + updatesInfo.GetNumElements() == 0) + { + throw InvalidArgumentException("ScatterNd: indices and updates tensor must have values."); + } + + // Requirement 3: Indices and Updates must match in shape + // The updates dimension should equals to 1 + inner dimension + if (updatesDim != 1 + innerDim) + { + throw InvalidArgumentException("ScatterNd: updates dimension should equal to 1 + inner dimension."); + } + // The inner dimension of updates has to match with shape of input + for (unsigned int dimBackIndex = 0; dimBackIndex < innerDim; ++dimBackIndex) + { + if (updatesShape[updatesDim - dimBackIndex - 1] != inputShape[dimension - dimBackIndex - 1]) + { + throw InvalidArgumentException( + fmt::format("ScatterNd: input and updates shape not match on dimension {}", + dimension - dimBackIndex)); + } + } + + // Requirement 4: Check duplicate indices and out of bound indices + std::set<int> indicesSet; + std::vector<int> flattenIndices(numIndices); + for (unsigned int indicesIdx = 0; indicesIdx < numIndices; ++indicesIdx) + { + // Get the index + int flattenIndex = 0; + + for (unsigned int outterIdx = 0; outterIdx < outterDim; ++outterIdx) { + + int outterIndexValue = indices.Get(); + + // Check bounds + if (outterIndexValue < 0 || outterIndexValue >= int(inputShape[outterIdx])) + { + throw InvalidArgumentException( + fmt::format("ScatterNd: indices {} out of bound [0, {})", + outterIndexValue, inputShape[outterIdx])); + } + + flattenIndex += int(elementInDim[outterIdx]) * outterIndexValue; + ++indices; + } + + // Check duplicates when executing ScatterNd::Update + if (descriptor.m_Function == ScatterNdFunction::Update && + indicesSet.find(flattenIndex) != indicesSet.end()) + { + throw InvalidArgumentException( + fmt::format("ScatterNd: duplicate indices {} occurs when executing ScatterNd::Update.", + flattenIndex)); + } + + flattenIndices[indicesIdx] = flattenIndex; + indicesSet.insert(flattenIndex); + } + + // Set zeros to output + for (unsigned int idx = 0; idx < inputElementsNum; ++idx) + { + output.Set(0.0f); + ++output; + } + + // Iterate through all indices to scatter updates + for (unsigned int indicesIdx = 0; indicesIdx < numIndices; ++indicesIdx) + { + // Get the index and calculate the flatten index + int flattenIndex = flattenIndices[indicesIdx]; + + // FlattenIndex is the place that we are going to update the elements + unsigned int updatesStartIdx = indicesIdx * numUpdatesPerIndex; + for (unsigned int updatesIdx = 0; updatesIdx < numUpdatesPerIndex; ++updatesIdx) + { + updates[updatesStartIdx + updatesIdx]; + float updateValue = ScatterOperation(descriptor.m_Function, 0.0f, updates.Get()); + output[static_cast<unsigned int>(flattenIndex) + updatesIdx]; + output.Set(updateValue); + } + } +} + +} // namespace armnn
\ No newline at end of file |